more type checking
authorMart Lubbers <mart@martlubbers.net>
Tue, 23 Feb 2021 14:33:57 +0000 (15:33 +0100)
committerMart Lubbers <mart@martlubbers.net>
Tue, 23 Feb 2021 14:33:57 +0000 (15:33 +0100)
sem.c
sem/hm.c
sem/hm.h
sem/hm/gamma.c
sem/hm/scheme.c
sem/hm/subst.c
sem/hm/subst.h
sem/scc.c
util.c
util.h

diff --git a/sem.c b/sem.c
index 3404fb4..5c9a528 100644 (file)
--- a/sem.c
+++ b/sem.c
@@ -38,14 +38,14 @@ void check_expr_constant(struct expr *expr)
 struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl)
 {
        struct type *t = vardecl->type == NULL
-               ? gamma_fresh(gamma) : type_dup(vardecl->type);
+               ? gamma_fresh(gamma) : vardecl->type;
        struct subst *s = infer_expr(gamma, vardecl->expr, t);
 
        if (s == NULL)
                die("error inferring variable\n");
        vardecl->type = subst_apply_t(s, t);
 
-       //subst_free(s);
+       subst_free(s);
 
        return vardecl;
 }
@@ -66,9 +66,11 @@ struct ast *sem(struct ast *ast)
                        type_vardecl(gamma, ast->decls[i]->data.dvar);
                        break;
                case dfundecl: {
-//                     struct type *f1 = gamma_fresh(gamma);
-//                     gamma_insert(gamma, ast->decls[i]->data.dfun->ident
-//                             , scheme_create(f1));
+                       struct type *f1 = gamma_fresh(gamma);
+                       gamma_insert(gamma, ast->decls[i]->data.dfun->ident
+                               , scheme_create(f1));
+                       struct subst *s = infer_fundecl(gamma, ast->decls[i]->data.dfun);
+                       subst_free(s);
 //infer env (Let [(x, e1)] e2)
 //     =              fresh
 //     >>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
index fb3d2ce..4b88076 100644 (file)
--- a/sem/hm.c
+++ b/sem/hm.c
@@ -11,11 +11,16 @@ bool occurs_check(char *var, struct type *r)
 {
        int nftv = 0;
        char **ftv = NULL;
+       bool res = false;
        type_ftv(r, &nftv, &ftv);
-       for (int i = 0; i<nftv; i++)
-               if (strcmp(ftv[i], var) == 0)
-                       return true;
-       return false;
+       for (int i = 0; i<nftv; i++) {
+               if (strcmp(ftv[i], var) == 0) {
+                       res = true;
+                       break;
+               }
+       }
+       free(ftv);
+       return res;
 }
 
 struct subst *unify(struct type *l, struct type *r)
@@ -24,14 +29,15 @@ struct subst *unify(struct type *l, struct type *r)
                return NULL;
        if (r->type == tvar && l->type != tvar)
                return unify(r, l);
-       struct subst *s1, *s2;
+       struct subst *s1, *s2, *s3;
        switch (l->type) {
        case tarrow:
                if (r->type == tarrow) {
                        s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
                        s2 = unify(subst_apply_t(s1, l->data.tarrow.l),
                                subst_apply_t(s1, r->data.tarrow.l));
-                       return subst_union(s1, s2);
+                       s3 = subst_union(s1, s2);
+                       return s3;
                }
                break;
        case tbasic:
@@ -47,7 +53,10 @@ struct subst *unify(struct type *l, struct type *r)
                        s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
                        s2 = unify(subst_apply_t(s1, l->data.ttuple.l),
                                subst_apply_t(s1, r->data.ttuple.l));
-                       return subst_union(s1, s2);
+                       s3 = subst_union(s1, s2);
+                       subst_free(s1);
+                       subst_free(s2);
+                       return s3;
                }
                break;
        case tvar:
@@ -67,6 +76,47 @@ struct subst *unify(struct type *l, struct type *r)
        return NULL;
 }
 
+struct subst *unifyfree(struct type *l, struct type *r, bool freel, bool freer)
+{
+       struct subst *s = unify(l, r);
+       if (freel)
+               type_free(l);
+       if (freer)
+               type_free(r);
+       return s;
+}
+
+struct subst *infer_binop(struct gamma *gamma, struct expr *l, struct expr *r,
+       struct type *a1, struct type *a2, struct type *rt, struct type *sigma)
+{
+       struct subst *s1 = infer_expr(gamma, l, a1);
+       struct subst *s2 = infer_expr(subst_apply_g(s1, gamma), r, a2);
+       struct subst *s3 = subst_union(s1, s2);
+       struct subst *s4 = unify(subst_apply_t(s3, sigma), rt);
+       struct subst *s5 = subst_union(s3, s4);
+       subst_free(s1);
+       subst_free(s2);
+       subst_free(s3);
+       subst_free(s4);
+       return s5;
+}
+
+struct subst *infer_unop(struct gamma *gamma, struct expr *e,
+       struct type *a, struct type  *rt, struct type *sigma)
+{
+       struct subst *s1 = infer_expr(gamma, e, a);
+       struct subst *s2 = unify(subst_apply_t(s1, sigma), rt);
+       struct subst *s3 = subst_union(s1, s2);
+       subst_free(s1);
+       subst_free(s2);
+       return s3;
+}
+
+static struct type tybool = {.type=tbasic, .data={.tbasic=btbool}};
+static struct type tychar = {.type=tbasic, .data={.tbasic=btchar}};
+static struct type tyint = {.type=tbasic, .data={.tbasic=btint}};
+static struct type tystring = {.type=tlist, .data={.tlist=&tychar}};
+
 struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
 {
        fprintf(stderr, "infer expr: ");
@@ -77,25 +127,19 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty
        type_print(type, stderr);
        fprintf(stderr, "\n");
 
-#define infbop(l, r, a1, a2, rt, sigma) {\
-       s1 = infer_expr(gamma, l, a1);\
-       s2 = subst_union(s1, infer_expr(subst_apply_g(s1, gamma), r, a2));\
-       return subst_union(s2, unify(subst_apply_t(s2, sigma), rt));\
-}
-#define infbinop(e, a1, a2, rt, sigma)\
-       infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
-       struct subst *s1, *s2;
-       struct type *f1, *f2;
+#define infbinop(e, a1, a2, rt) infer_binop(\
+       gamma, e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, type)
+       struct subst *s1;
+       struct type *f1, *f2, *f3;
        struct scheme *s;
        switch (expr->type) {
        case ebool:
-               return unify(type_basic(btbool), type);
+               return unify(&tybool, type);
        case ebinop:
                switch (expr->data.ebinop.op) {
                case binor:
                case binand:
-                       infbinop(expr, type_basic(btbool), type_basic(btbool),
-                               type_basic(btbool), type);
+                       return infbinop(expr, &tybool, &tybool, &tybool);
                case eq:
                case neq:
                case leq:
@@ -103,22 +147,26 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty
                case geq:
                case ge:
                        f1 = gamma_fresh(gamma);
-                       infbinop(expr, f1, f1, type_basic(btbool), type);
+                       s1 = infbinop(expr, f1, f1, &tybool);
+                       type_free(f1);
+                       return s1;
                case cons:
                        f1 = gamma_fresh(gamma);
-                       infbinop(expr, f1, type_list(f1), type_list(f1), type);
+                       f2 = type_list(f1);
+                       s1 = infbinop(expr, f1, f2, f2);
+                       type_free(f2);
+                       return s1;
                case plus:
                case minus:
                case times:
                case divide:
                case modulo:
                case power:
-                       infbinop(expr, type_basic(btint), type_basic(btint),
-                               type_basic(btint), type);
+                       return infbinop(expr, &tyint, &tyint, &tyint);
                }
                break;
        case echar:
-               return unify(type_basic(btchar), type);
+               return unify(&tychar, type);
        case efuncall:
                if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
                        die("Unbound function: %s\n", expr->data.efuncall.ident);
@@ -126,39 +174,147 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty
                //TODO fields
                return NULL;
        case eint:
-               return unify(type_basic(btint), type);
+               return unify(&tyint, type);
        case eident:
                if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
                        die("Unbound variable: %s\n", expr->data.eident.ident);
-               //TODO fields
-               return unify(scheme_instantiate(gamma, s), type);
+               f1 = scheme_instantiate(gamma, s);
+               s1 = unify(f1, type);
+               type_free(f1);
+               //TODO field
+               return s1;
        case enil:
                f1 = gamma_fresh(gamma);
-               return unify(type_list(f1), type);
+               return unifyfree(type_list(f1), type, true, false);
        case etuple:
                f1 = gamma_fresh(gamma);
                f2 = gamma_fresh(gamma);
-               infbop(expr->data.etuple.left, expr->data.etuple.right,
-                      f1, f2, type_tuple(f1, f2), type);
+               f3 = type_tuple(f1, f2);
+               s1 = infer_binop(gamma, expr->data.etuple.left,
+                       expr->data.etuple.right, f1, f2, f3, type);
+               type_free(f3);
+               return s1;
        case estring:
-               return unify(type_list(type_basic(btchar)), type);
+               return unify(&tystring, type);
        case eunop:
-               switch(expr->data.eunop.op) {
+               switch (expr->data.eunop.op) {
                case negate:
-                       s1 = infer_expr(gamma,
-                               expr->data.eunop.l, type_basic(btint));
-                       if (s1 == NULL)
-                               return NULL;
-                       return subst_union(s1,
-                               unify(subst_apply_t(s1, type),
-                               type_basic(btint)));
+                       return infer_unop(gamma, expr->data.eunop.l,
+                               &tyint, &tyint, type);
                case inverse:
-                       s1 = infer_expr(gamma,
-                               expr->data.eunop.l, type_basic(btbool));
-                       return subst_union(s1,
-                               unify(subst_apply_t(s1, type),
-                               type_basic(btbool)));
+                       return infer_unop(gamma, expr->data.eunop.l,
+                               &tybool, &tybool, type);
                }
        }
        return NULL;
 }
+
+struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type)
+{
+       fprintf(stderr, "infer stmt: ");
+       stmt_print(stmt, 0, stderr);
+       fprintf(stderr, "\ngamma: ");
+       gamma_print(gamma, stderr);
+       fprintf(stderr, "\ntype: ");
+       type_print(type, stderr);
+       fprintf(stderr, "\n");
+
+//     struct subst *s1, *s2;
+//     struct type *f1, *f2, *f3;
+//     struct scheme *s;
+       switch (stmt->type) {
+       case sassign:
+               break;
+       case sif:
+               break;
+       case sreturn:
+               return infer_expr(gamma, stmt->data.sreturn, type);
+       case sexpr:
+               break;
+       case svardecl:
+               break;
+       case swhile:
+               break;
+       }
+       return subst_id();
+//     union {
+//             struct {
+//                     char *ident;
+//                     int nfields;
+//                     char **fields;
+//                     struct expr *expr;
+//             } sassign;
+//             struct {
+//                     struct expr *pred;
+//                     int nthen;
+//                     struct stmt **then;
+//                     int nels;
+//                     struct stmt **els;
+//             } sif;
+//             struct vardecl *svardecl;
+//             struct expr *sreturn;
+//             struct expr *sexpr;
+//             struct {
+//                     struct expr *pred;
+//                     int nbody;
+//                     struct stmt **body;
+//             } swhile;
+//     } data;
+}
+
+struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl)
+{
+       //struct type *t;
+       if (fundecl->rtype == NULL || fundecl->atypes == NULL) {
+               fundecl->rtype = gamma_fresh(gamma);
+               fundecl->atypes = safe_realloc(fundecl->atypes, fundecl->nargs*sizeof(struct type));
+               for (int i = 0; i<fundecl->nargs; i++)
+                       fundecl->atypes[i] = gamma_fresh(gamma);
+       }
+       fprintf(stderr, "fundecl with type: ");
+       for (int i = 0; i<fundecl->nargs; i++) {
+               type_print(fundecl->atypes[i], stderr);
+               fprintf(stderr, " ");
+       }
+       fprintf(stderr, "-> ");
+       type_print(fundecl->rtype, stderr);
+       fprintf(stderr, "\n");
+
+       for (int i = 0; i<fundecl->nargs; i++)
+               gamma_insert(gamma, fundecl->args[i],
+                       scheme_create(fundecl->atypes[i]));
+
+       struct subst *s = subst_id();
+       for (int i = 0; i<fundecl->nbody; i++) {
+               struct subst *s1 = infer_stmt(gamma, fundecl->body[i], fundecl->rtype);
+               struct subst *s2 = s;
+               s = subst_union(s2, s1);
+               subst_free(s1);
+               subst_free(s2);
+       }
+
+       fprintf(stderr, "inferred function substitution: ");
+       subst_print(s, stderr);
+
+       for (int i = 0; i<fundecl->nargs; i++)
+               fundecl->atypes[i] = subst_apply_t(s, fundecl->atypes[i]);
+       fundecl->rtype = subst_apply_t(s, fundecl->rtype);
+       fprintf(stderr, "fundecl with type: ");
+       for (int i = 0; i<fundecl->nargs; i++) {
+               type_print(fundecl->atypes[i], stderr);
+               fprintf(stderr, " ");
+       }
+       fprintf(stderr, "-> ");
+       type_print(fundecl->rtype, stderr);
+       fprintf(stderr, "\n");
+       //char *ident;
+       //int nargs;
+       //char **args;
+       //int natypes;
+       //struct type **atypes;
+       //struct type *rtype;
+       //int nbody;
+       //struct stmt **body;
+
+       return s;
+}
index 6792c3b..106655b 100644 (file)
--- a/sem/hm.h
+++ b/sem/hm.h
@@ -8,5 +8,7 @@
 
 struct ast *infer(struct ast *ast);
 struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type);
+struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type);
+struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl);
 
 #endif
index 9a97c27..4202154 100644 (file)
@@ -15,6 +15,13 @@ struct gamma *gamma_init()
 
 void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme)
 {
+       for (int i = 0; i<gamma->nschemes; i++) {
+               if(strcmp(gamma->vars[i], ident) == 0) {
+                       scheme_free(gamma->schemes[i]);
+                       gamma->schemes[i] = scheme;
+                       return;
+               }
+       }
        gamma->nschemes++;
        gamma->vars = realloc(gamma->vars, gamma->nschemes*sizeof(char *));
        gamma->schemes = realloc(gamma->schemes,
index 177bb64..41d5401 100644 (file)
@@ -22,6 +22,7 @@ struct scheme *scheme_create(struct type *t)
        s->type = t;
        s->nvar = 0;
        s->var = NULL;
+       return s;
 }
 
 struct scheme *scheme_generalise(struct gamma *gamma, struct type *t)
index 55d6cfe..e8fb35c 100644 (file)
@@ -32,12 +32,12 @@ struct subst *subst_union(struct subst *l, struct subst *r)
        res->vars = safe_malloc(res->nvar*sizeof(char *));
        res->types = safe_malloc(res->nvar*sizeof(struct type *));
        for (int i = 0; i<l->nvar; i++) {
-               res->vars[i] = l->vars[i];
-               res->types[i] = l->types[i];
+               res->vars[i] = safe_strdup(l->vars[i]);
+               res->types[i] = type_dup(l->types[i]);
        }
        for (int i = 0; i<r->nvar; i++) {
-               res->vars[l->nvar+i] = r->vars[i];
-               res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
+               res->vars[l->nvar+i] = safe_strdup(r->vars[i]);
+               res->types[l->nvar+i] = subst_apply_t(l, type_dup(r->types[i]));
        }
        return res;
 }
@@ -112,13 +112,15 @@ void subst_print(struct subst *s, FILE *out)
        }
 }
 
-void subst_free(struct subst *s, bool type)
+void subst_free(struct subst *s)
 {
        if (s != NULL) {
                for (int i = 0; i<s->nvar; i++) {
                        free(s->vars[i]);
-                       if (type)
-                               type_free(s->types[i]);
+                       type_free(s->types[i]);
                }
+               free(s->vars);
+               free(s->types);
+               free(s);
        }
 }
index ffec9a7..0fba436 100644 (file)
@@ -19,6 +19,6 @@ struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme);
 struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma);
 
 void subst_print(struct subst *s, FILE *out);
-void subst_free(struct subst *s, bool type);
+void subst_free(struct subst *s);
 
 #endif
index 6d7b02e..f728f4f 100644 (file)
--- a/sem/scc.c
+++ b/sem/scc.c
@@ -271,7 +271,6 @@ struct ast *ast_scc(struct ast *ast)
        struct edge **edata = (struct edge **)
                list_to_array(edges, &nedges, false);
 
-       fprintf(stderr, "nfun: %d, ffun: %d, nedges: %d\n", nfun, ffun, nedges);
        // Do tarjan's and convert back into the declaration list
        struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata);
 //     if (cs == NULL)
diff --git a/util.c b/util.c
index fa60ddc..6c70c63 100644 (file)
--- a/util.c
+++ b/util.c
@@ -152,6 +152,14 @@ void *safe_malloc(size_t size)
        return res;
 }
 
+void *safe_realloc(void *ptr, size_t size)
+{
+       void *res = realloc(ptr, size);
+       if (res == NULL)
+               pdie("realloc");
+       return res;
+}
+
 void *safe_strdup(const char *c)
 {
        size_t nchar = strlen(c);
diff --git a/util.h b/util.h
index ac47fa6..6704cf4 100644 (file)
--- a/util.h
+++ b/util.h
@@ -20,6 +20,7 @@ void pindent(int indent, FILE *out);
 
 void safe_fprintf(FILE *out, const char *msg, ...);
 void *safe_malloc(size_t size);
+void *safe_realloc(void *ptr, size_t size);
 void *safe_strdup(const char *c);
 FILE *safe_fopen(const char *path, const char *mode);
 void safe_fclose(FILE *file);