work on type inference some more
[ccc.git] / sem / hm.c
index d91715e..fb3d2ce 100644 (file)
--- a/sem/hm.c
+++ b/sem/hm.c
 #include <string.h>
 
 #include "hm.h"
-#include "../util.h"
+#include "hm/subst.h"
+#include "hm/gamma.h"
+#include "hm/scheme.h"
 #include "../ast.h"
 
-struct gamma {
-       int fresh;
-       int nschemes;
-       char **vars;
-       struct scheme *schemes;
-};
-
-struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
-{
-       for (int i = 0; i<nschemes; i++) {
-               if (strcmp(ident, gamma->vars[i]) == 0) {
-                       //TODO
-               }
-       }
-       return NULL;
-}
-
-struct type *fresh(struct gamma *gamma)
-{
-       char *buf = safe_malloc(10);
-       sprintf(buf, "%d", gamma->fresh);
-       gamma->fresh++;
-       return type_var(buf);
-}
-
-void ftv_type(struct type *r, int *nftv, char **ftv)
-{
-       switch (r->type) {
-       case tbasic:
-               break;
-       case tlist:
-               ftv_type(r->data.tlist, nftv, ftv);
-               break;
-       case ttuple:
-               ftv_type(r->data.ttuple.l, nftv, ftv);
-               ftv_type(r->data.ttuple.r, nftv, ftv);
-               break;
-       case tvar:
-               *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
-               if (*ftv == NULL)
-                       perror("realloc");
-               ftv[(*nftv)++] = r->data.tvar;
-               break;
-       }
-}
-
 /*
  * Report whether the type variable named `var` occurs in the type `r`.
  *
  * type_ftv() is asked to fill `ftv` with `nftv` variable names taken from
  * `r` (presumably the free type variables of `r`, going by the name --
  * confirm against the hm/ headers); each entry is then compared with `var`
  * using strcmp().
  *
  * Returns true as soon as one entry equals `var`, false when no entry
  * matches (including the case where nftv stays 0).
  *
  * NOTE(review): `ftv` is never freed in this function; whether that is a
  * leak depends on who owns the buffer type_ftv() hands back -- confirm.
  */
 bool occurs_check(char *var, struct type *r)
 {
        int nftv = 0;
-       char *ftv = NULL;
-       ftv_type(r, &nftv, &ftv);
+       char **ftv = NULL;
+       type_ftv(r, &nftv, &ftv);
        for (int i = 0; i<nftv; i++)
-               if (strcmp(ftv+i, var) == 0)
+               if (strcmp(ftv[i], var) == 0)
                        return true;
        return false;
 }
 
-struct type *dup_type(struct type *r)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       *res = *r;
-       switch (r->type) {
-       case tbasic:
-               break;
-       case tlist:
-               res->data.tlist = dup_type(r->data.tlist);
-               break;
-       case ttuple:
-               res->data.ttuple.l = dup_type(r->data.ttuple.l);
-               res->data.ttuple.r = dup_type(r->data.ttuple.r);
-               break;
-       case tvar:
-               res->data.tvar = strdup(r->data.tvar);
-               break;
-       }
-       return res;
-}
-
-struct type *subst_apply_t(struct substitution *subst, struct type *l)
+struct subst *unify(struct type *l, struct type *r)
 {
+       if (l == NULL || r == NULL)
+               return NULL;
+       if (r->type == tvar && l->type != tvar)
+               return unify(r, l);
+       struct subst *s1, *s2;
        switch (l->type) {
-       case tbasic:
-               break;
-       case tlist:
-               l->data.tlist = subst_apply_t(subst, l->data.tlist);
-               break;
-       case ttuple:
-               l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
-               l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
-               break;
-       case tvar:
-               for (int i = 0; i<subst->nvar; i++) {
-                       if (strcmp(subst->vars[i], l->data.tvar) == 0) {
-                               free(l->data.tvar);
-                               struct type *r = subst->types[i];
-                               *l = *r;
-                               free(r);
-                       }
+       case tarrow:
+               if (r->type == tarrow) {
+                       s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
+                       s2 = unify(subst_apply_t(s1, l->data.tarrow.l),
+                               subst_apply_t(s1, r->data.tarrow.l));
+                       return subst_union(s1, s2);
                }
                break;
-       }
-       return l;
-}
-struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma)
-{
-       //TODO
-       return gamma;
-}
-
-void subst_print(struct substitution *s, FILE *out)
-{
-       if (s == NULL) {
-               fprintf(out, "no substitution\n");
-       } else {
-               fprintf(out, "[");
-               for (int i = 0; i<s->nvar; i++) {
-                       fprintf(out, "%s->", s->vars[i]);
-                       type_print(s->types[i], out);
-                       if (i + 1 < s->nvar)
-                               fprintf(out, ", ");
-               }
-               fprintf(out, "]\n");
-       }
-}
-
-struct substitution *subst_id()
-{
-       struct substitution *res = safe_malloc(sizeof(struct substitution));
-       res->nvar = 0;
-       res->vars = NULL;
-       res->types = NULL;
-       return res;
-}
-
-struct substitution *subst_singleton(char *ident, struct type *t)
-{
-       struct substitution *res = safe_malloc(sizeof(struct substitution));
-       res->nvar = 1;
-       res->vars = safe_malloc(sizeof(char *));
-       res->vars[0] = safe_strdup(ident);
-       res->types = safe_malloc(sizeof(struct type *));
-       res->types[0] = dup_type(t);
-       return res;
-}
-
-struct substitution *subst_union(struct substitution *l, struct substitution *r)
-{
-       struct substitution *res = safe_malloc(sizeof(struct substitution));
-       res->nvar = l->nvar+r->nvar;
-       res->vars = safe_malloc(res->nvar*sizeof(char *));
-       res->types = safe_malloc(res->nvar*sizeof(struct type *));
-       for (int i = 0; i<l->nvar; i++) {
-               res->vars[i] = l->vars[i];
-               res->types[i] = l->types[i];
-       }
-       for (int i = 0; i<r->nvar; i++) {
-               res->vars[l->nvar+i] = l->vars[i];
-               res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
-       }
-       return res;
-}
-
-struct substitution *unify(struct type *l, struct type *r) {
-       switch (l->type) {
        case tbasic:
                if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
                        return subst_id();
@@ -179,11 +44,8 @@ struct substitution *unify(struct type *l, struct type *r) {
                break;
        case ttuple:
                if (r->type == ttuple) {
-                       struct substitution *s1 = unify(
-                               l->data.ttuple.l,
-                               r->data.ttuple.l);
-                       struct substitution *s2 = unify(
-                               subst_apply_t(s1, l->data.ttuple.l),
+                       s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
+                       s2 = unify(subst_apply_t(s1, l->data.ttuple.l),
                                subst_apply_t(s1, r->data.ttuple.l));
                        return subst_union(s1, s2);
                }
@@ -197,11 +59,23 @@ struct substitution *unify(struct type *l, struct type *r) {
                        return subst_singleton(l->data.tvar, r);
                break;
        }
+       fprintf(stderr, "cannot unify ");
+       type_print(l, stderr);
+       fprintf(stderr, " with ");
+       type_print(r, stderr);
+       fprintf(stderr, "\n");
        return NULL;
 }
 
-struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
+struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
 {
+       fprintf(stderr, "infer expr: ");
+       expr_print(expr, stderr);
+       fprintf(stderr, "\ngamma: ");
+       gamma_print(gamma, stderr);
+       fprintf(stderr, "\ntype: ");
+       type_print(type, stderr);
+       fprintf(stderr, "\n");
 
 #define infbop(l, r, a1, a2, rt, sigma) {\
        s1 = infer_expr(gamma, l, a1);\
@@ -210,13 +84,14 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 }
 #define infbinop(e, a1, a2, rt, sigma)\
        infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
-       struct substitution *s1, *s2;
+       struct subst *s1, *s2;
        struct type *f1, *f2;
+       struct scheme *s;
        switch (expr->type) {
        case ebool:
                return unify(type_basic(btbool), type);
        case ebinop:
-               switch(expr->data.ebinop.op) {
+               switch (expr->data.ebinop.op) {
                case binor:
                case binand:
                        infbinop(expr, type_basic(btbool), type_basic(btbool),
@@ -227,10 +102,10 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                case le:
                case geq:
                case ge:
-                       f1 = fresh(gamma);
+                       f1 = gamma_fresh(gamma);
                        infbinop(expr, f1, f1, type_basic(btbool), type);
                case cons:
-                       f1 = fresh(gamma);
+                       f1 = gamma_fresh(gamma);
                        infbinop(expr, f1, type_list(f1), type_list(f1), type);
                case plus:
                case minus:
@@ -241,20 +116,28 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                        infbinop(expr, type_basic(btint), type_basic(btint),
                                type_basic(btint), type);
                }
+               break;
        case echar:
                return unify(type_basic(btchar), type);
        case efuncall:
+               if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
+                       die("Unbound function: %s\n", expr->data.efuncall.ident);
                //TODO
+               //TODO fields
                return NULL;
        case eint:
                return unify(type_basic(btint), type);
        case eident:
-
-               //TODO
-               return NULL;
+               if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
+                       die("Unbound variable: %s\n", expr->data.eident.ident);
+               //TODO fields
+               return unify(scheme_instantiate(gamma, s), type);
+       case enil:
+               f1 = gamma_fresh(gamma);
+               return unify(type_list(f1), type);
        case etuple:
-               f1 = fresh(gamma);
-               f2 = fresh(gamma);
+               f1 = gamma_fresh(gamma);
+               f2 = gamma_fresh(gamma);
                infbop(expr->data.etuple.left, expr->data.etuple.right,
                       f1, f2, type_tuple(f1, f2), type);
        case estring:
@@ -264,6 +147,8 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                case negate:
                        s1 = infer_expr(gamma,
                                expr->data.eunop.l, type_basic(btint));
+                       if (s1 == NULL)
+                               return NULL;
                        return subst_union(s1,
                                unify(subst_apply_t(s1, type),
                                type_basic(btint)));
@@ -275,5 +160,5 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                                type_basic(btbool)));
                }
        }
-
+       return NULL;
 }