+#include <stdlib.h>
#include <string.h>
#include "hm.h"
-#include "../util.h"
+#include "hm/subst.h"
+#include "hm/gamma.h"
+#include "hm/scheme.h"
#include "../ast.h"
-struct gamma {
- int fresh;
- int nschemes;
- char **vars;
- struct scheme *schemes;
-};
-
-struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
-{
- for (int i = 0; i<nschemes; i++) {
- if (strcmp(ident, gamma->vars[i]) == 0) {
- //TODO
- }
- }
- return NULL;
-}
-
-struct type *fresh(struct gamma *gamma)
-{
- char *buf = safe_malloc(10);
- sprintf(buf, "%d", gamma->fresh);
- gamma->fresh++;
- return type_var(buf);
-}
-
-void ftv_type(struct type *r, int *nftv, char **ftv)
-{
- switch (r->type) {
- case tbasic:
- break;
- case tlist:
- ftv_type(r->data.tlist, nftv, ftv);
- break;
- case ttuple:
- ftv_type(r->data.ttuple.l, nftv, ftv);
- ftv_type(r->data.ttuple.r, nftv, ftv);
- break;
- case tvar:
- *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
- if (*ftv == NULL)
- perror("realloc");
- ftv[(*nftv)++] = r->data.tvar;
- break;
- }
-}
-
+/*
+ * Report whether the type variable named var is among the free type
+ * variables that type_ftv collects for r.
+ *
+ * Returns true when one of the collected names compares equal to var,
+ * false otherwise (including when nothing was collected).
+ */
bool occurs_check(char *var, struct type *r)
{
int nftv = 0;
- char *ftv = NULL;
- ftv_type(r, &nftv, &ftv);
+ char **ftv = NULL;
+ bool found = false;
+ type_ftv(r, &nftv, &ftv);
for (int i = 0; i<nftv; i++)
- if (strcmp(ftv+i, var) == 0)
- return true;
- return false;
+ if (strcmp(ftv[i], var) == 0) {
+ found = true;
+ break;
+ }
+ /*
+ * Release the name array on every path instead of leaking it.
+ * NOTE(review): assumes type_ftv heap-allocates the array and that
+ * the strings themselves stay owned by the type -- confirm against
+ * the type_ftv implementation.
+ */
+ free(ftv);
+ return found;
}
-struct type *dup_type(struct type *r)
-{
- struct type *res = safe_malloc(sizeof(struct type));
- *res = *r;
- switch (r->type) {
- case tbasic:
- break;
- case tlist:
- res->data.tlist = dup_type(r->data.tlist);
- break;
- case ttuple:
- res->data.ttuple.l = dup_type(r->data.ttuple.l);
- res->data.ttuple.r = dup_type(r->data.ttuple.r);
- break;
- case tvar:
- res->data.tvar = strdup(r->data.tvar);
- break;
- }
- return res;
-}
-
-struct type *subst_apply_t(struct substitution *subst, struct type *l)
+/*
+ * Unify the types l and r.
+ *
+ * When r is a type variable and l is not, the arguments are swapped so
+ * the variable is handled on the left. For arrow and tuple types the
+ * left components are unified first; the resulting substitution is
+ * applied to both right components before those are unified, and the
+ * two substitutions are combined with subst_union.
+ *
+ * Returns NULL when either argument is NULL, when unifying the left
+ * components fails, or when no case applies; in the last situation a
+ * "cannot unify" message naming both types is written to stderr.
+ */
+struct subst *unify(struct type *l, struct type *r)
{
+ if (l == NULL || r == NULL)
+ return NULL;
+ if (r->type == tvar && l->type != tvar)
+ return unify(r, l);
+ struct subst *s1, *s2;
switch (l->type) {
- case tbasic:
- break;
- case tlist:
- l->data.tlist = subst_apply_t(subst, l->data.tlist);
- break;
- case ttuple:
- l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
- l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
- break;
- case tvar:
- for (int i = 0; i<subst->nvar; i++) {
- if (strcmp(subst->vars[i], l->data.tvar) == 0) {
- free(l->data.tvar);
- struct type *r = subst->types[i];
- *l = *r;
- free(r);
- }
+ case tarrow:
+ if (r->type == tarrow) {
+ s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
+ /* Left sides did not unify: propagate the failure. */
+ if (s1 == NULL)
+ return NULL;
+ /* The right sides are unified under s1, not the left ones again. */
+ s2 = unify(subst_apply_t(s1, l->data.tarrow.r),
+ subst_apply_t(s1, r->data.tarrow.r));
+ return subst_union(s1, s2);
}
break;
- }
- return l;
-}
-struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma)
-{
- //TODO
- return gamma;
-}
-
-void subst_print(struct substitution *s, FILE *out)
-{
- if (s == NULL) {
- fprintf(out, "no substitution\n");
- } else {
- fprintf(out, "[");
- for (int i = 0; i<s->nvar; i++) {
- fprintf(out, "%s->", s->vars[i]);
- type_print(s->types[i], out);
- if (i + 1 < s->nvar)
- fprintf(out, ", ");
- }
- fprintf(out, "]\n");
- }
-}
-
-struct substitution *subst_id()
-{
- struct substitution *res = safe_malloc(sizeof(struct substitution));
- res->nvar = 0;
- res->vars = NULL;
- res->types = NULL;
- return res;
-}
-
-struct substitution *subst_singleton(char *ident, struct type *t)
-{
- struct substitution *res = safe_malloc(sizeof(struct substitution));
- res->nvar = 1;
- res->vars = safe_malloc(sizeof(char *));
- res->vars[0] = safe_strdup(ident);
- res->types = safe_malloc(sizeof(struct type *));
- res->types[0] = dup_type(t);
- return res;
-}
-
-struct substitution *subst_union(struct substitution *l, struct substitution *r)
-{
- struct substitution *res = safe_malloc(sizeof(struct substitution));
- res->nvar = l->nvar+r->nvar;
- res->vars = safe_malloc(res->nvar*sizeof(char *));
- res->types = safe_malloc(res->nvar*sizeof(struct type *));
- for (int i = 0; i<l->nvar; i++) {
- res->vars[i] = l->vars[i];
- res->types[i] = l->types[i];
- }
- for (int i = 0; i<r->nvar; i++) {
- res->vars[l->nvar+i] = l->vars[i];
- res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
- }
- return res;
-}
-
-struct substitution *unify(struct type *l, struct type *r) {
- switch (l->type) {
case tbasic:
if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
return subst_id();
break;
case ttuple:
if (r->type == ttuple) {
- struct substitution *s1 = unify(
- l->data.ttuple.l,
- r->data.ttuple.l);
- struct substitution *s2 = unify(
- subst_apply_t(s1, l->data.ttuple.l),
+ s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
+ /* Left components did not unify: propagate the failure. */
+ if (s1 == NULL)
+ return NULL;
+ /* The right components are unified under s1, not the left ones again. */
+ s2 = unify(subst_apply_t(s1, l->data.ttuple.r),
- subst_apply_t(s1, r->data.ttuple.l));
+ subst_apply_t(s1, r->data.ttuple.r));
return subst_union(s1, s2);
}
return subst_singleton(l->data.tvar, r);
break;
}
+ fprintf(stderr, "cannot unify ");
+ type_print(l, stderr);
+ fprintf(stderr, " with ");
+ type_print(r, stderr);
+ fprintf(stderr, "\n");
return NULL;
}
-struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
+struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
{
+ fprintf(stderr, "infer expr: ");
+ expr_print(expr, stderr);
+ fprintf(stderr, "\ngamma: ");
+ gamma_print(gamma, stderr);
+ fprintf(stderr, "\ntype: ");
+ type_print(type, stderr);
+ fprintf(stderr, "\n");
#define infbop(l, r, a1, a2, rt, sigma) {\
s1 = infer_expr(gamma, l, a1);\
}
#define infbinop(e, a1, a2, rt, sigma)\
infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
- struct substitution *s1, *s2;
+ struct subst *s1, *s2;
struct type *f1, *f2;
+ struct scheme *s;
switch (expr->type) {
case ebool:
return unify(type_basic(btbool), type);
case ebinop:
- switch(expr->data.ebinop.op) {
+ switch (expr->data.ebinop.op) {
case binor:
case binand:
infbinop(expr, type_basic(btbool), type_basic(btbool),
case le:
case geq:
case ge:
- f1 = fresh(gamma);
+ f1 = gamma_fresh(gamma);
infbinop(expr, f1, f1, type_basic(btbool), type);
case cons:
- f1 = fresh(gamma);
+ f1 = gamma_fresh(gamma);
infbinop(expr, f1, type_list(f1), type_list(f1), type);
case plus:
case minus:
infbinop(expr, type_basic(btint), type_basic(btint),
type_basic(btint), type);
}
+ break;
case echar:
return unify(type_basic(btchar), type);
case efuncall:
+ if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
+ die("Unbound function: %s\n", expr->data.efuncall.ident);
//TODO
+ //TODO fields
return NULL;
case eint:
return unify(type_basic(btint), type);
case eident:
-
- //TODO
- return NULL;
+ if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
+ die("Unbound variable: %s\n", expr->data.eident.ident);
+ //TODO fields
+ return unify(scheme_instantiate(gamma, s), type);
+ case enil:
+ f1 = gamma_fresh(gamma);
+ return unify(type_list(f1), type);
case etuple:
- f1 = fresh(gamma);
- f2 = fresh(gamma);
+ f1 = gamma_fresh(gamma);
+ f2 = gamma_fresh(gamma);
infbop(expr->data.etuple.left, expr->data.etuple.right,
f1, f2, type_tuple(f1, f2), type);
case estring:
case negate:
s1 = infer_expr(gamma,
expr->data.eunop.l, type_basic(btint));
+ if (s1 == NULL)
+ return NULL;
return subst_union(s1,
unify(subst_apply_t(s1, type),
type_basic(btint)));
type_basic(btbool)));
}
}
-
+ return NULL;
}