work on type inference some more develop master
authorMart Lubbers <mart@martlubbers.net>
Wed, 17 Feb 2021 10:43:04 +0000 (11:43 +0100)
committerMart Lubbers <mart@martlubbers.net>
Wed, 17 Feb 2021 10:43:04 +0000 (11:43 +0100)
18 files changed:
Makefile
ast.c
ast.h
sem.c
sem/hm.c
sem/hm.h
sem/hm/gamma.c [new file with mode: 0644]
sem/hm/gamma.h [new file with mode: 0644]
sem/hm/scheme.c [new file with mode: 0644]
sem/hm/scheme.h [new file with mode: 0644]
sem/hm/subst.c [new file with mode: 0644]
sem/hm/subst.h [new file with mode: 0644]
sem/scc.c
test/Makefile [new file with mode: 0644]
test/test_sem_hm_gamma.c [new file with mode: 0644]
type.c [new file with mode: 0644]
type.h [new file with mode: 0644]
util.c

index 74272aa..5328897 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,9 +1,10 @@
-CFLAGS+=-Wall -Wextra -std=c99 -pedantic -D_XOPEN_SOURCE=700 -ggdb
+CFLAGS+=-Wall -Wextra -std=c99 -pedantic -ggdb
 YFLAGS+=-d --locations -v --defines=parse.h
 LFLAGS+=--header-file=scan.h
 
-OBJECTS:=scan.o parse.o ast.o util.o list.o sem.o genc.o \
-       sem/scc.o sem/hm.o
+OBJECTS:=scan.o parse.o ast.o type.o util.o list.o sem.o genc.o \
+       sem/scc.o\
+       $(addprefix sem/hm, .o /gamma.o /subst.o /scheme.o)
 
 all: splc
 splc: $(OBJECTS)
@@ -11,5 +12,13 @@ scan.c: scan.l parse.h
 parse.h: parse.c
 expr.c: y.tab.h
 
+scan.o: CFLAGS+=-D_XOPEN_SOURCE=700
+
+.PHONY: test
+
+test:
+       CFLAGS="$(CFLAGS)" $(MAKE) -C test test
+
 clean:
        $(RM) $(OBJECTS) y.output parse.h scan.h scan.c parse.c expr a.c
+       $(MAKE) -C test clean
diff --git a/ast.c b/ast.c
index bb8a541..1e93a71 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -4,6 +4,7 @@
 
 #include "util.h"
 #include "ast.h"
+#include "type.h"
 #include "list.h"
 #include "parse.h"
 
@@ -16,10 +17,6 @@ const char *binop_str[] = {
 const char *fieldspec_str[] = {
        [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
 const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
-static const char *basictype_str[] = {
-       [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
-       [btvoid] = "Void",
-};
 
 struct ast *ast(struct list *decls)
 {
@@ -237,57 +234,6 @@ struct expr *expr_unop(enum unop op, struct expr *l)
        return res;
 }
 
-struct type *type_basic(enum basictype type)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       res->type = tbasic;
-       res->data.tbasic = type;
-       return res;
-}
-
-struct type *type_list(struct type *type)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       res->type = tlist;
-       res->data.tlist = type;
-       return res;
-}
-
-struct type *type_tuple(struct type *l, struct type *r)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       res->type = ttuple;
-       res->data.ttuple.l = l;
-       res->data.ttuple.r = r;
-       return res;
-}
-
-struct type *type_var(char *ident)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       if (strcmp(ident, "Int") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btint;
-               free(ident);
-       } else if (strcmp(ident, "Char") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btchar;
-               free(ident);
-       } else if (strcmp(ident, "Bool") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btbool;
-               free(ident);
-       } else if (strcmp(ident, "Void") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btvoid;
-               free(ident);
-       } else {
-               res->type = tvar;
-               res->data.tvar = ident;
-       }
-       return res;
-}
-
 void ast_print(struct ast *ast, FILE *out)
 {
        if (ast == NULL)
@@ -479,34 +425,6 @@ void expr_print(struct expr *expr, FILE *out)
        }
 }
 
-void type_print(struct type *type, FILE *out)
-{
-       if (type == NULL)
-               return;
-       switch (type->type) {
-       case tbasic:
-               safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
-               break;
-       case tlist:
-               safe_fprintf(out, "[");
-               type_print(type->data.tlist, out);
-               safe_fprintf(out, "]");
-               break;
-       case ttuple:
-               safe_fprintf(out, "(");
-               type_print(type->data.ttuple.l, out);
-               safe_fprintf(out, ",");
-               type_print(type->data.ttuple.r, out);
-               safe_fprintf(out, ")");
-               break;
-       case tvar:
-               safe_fprintf(out, "%s", type->data.tvar);
-               break;
-       default:
-               die("Unsupported type node\n");
-       }
-}
-
 void ast_free(struct ast *ast)
 {
        if (ast == NULL)
@@ -648,26 +566,3 @@ void expr_free(struct expr *expr)
        }
        free(expr);
 }
-
-void type_free(struct type *type)
-{
-       if (type == NULL)
-               return;
-       switch (type->type) {
-       case tbasic:
-               break;
-       case tlist:
-               type_free(type->data.tlist);
-               break;
-       case ttuple:
-               type_free(type->data.ttuple.l);
-               type_free(type->data.ttuple.r);
-               break;
-       case tvar:
-               free(type->data.tvar);
-               break;
-       default:
-               die("Unsupported type node\n");
-       }
-       free(type);
-}
diff --git a/ast.h b/ast.h
index db69333..635d0c8 100644 (file)
--- a/ast.h
+++ b/ast.h
@@ -5,6 +5,7 @@
 #include <stdbool.h>
 
 #include "util.h"
+#include "type.h"
 struct ast;
 #include "parse.h"
 
@@ -32,20 +33,6 @@ struct fundecl {
        struct stmt **body;
 };
 
-enum basictype {btbool, btchar, btint, btvoid};
-struct type {
-       enum {tbasic,tlist,ttuple,tvar} type;
-       union {
-               enum basictype tbasic;
-               struct type *tlist;
-               struct {
-                       struct type *l;
-                       struct type *r;
-               } ttuple;
-               char *tvar;
-       } data;
-};
-
 struct decl {
        //NOTE: DON'T CHANGE THIS ORDER
        enum {dcomp, dvardecl, dfundecl} type;
@@ -158,18 +145,12 @@ struct expr *expr_tuple(struct expr *left, struct expr *right);
 struct expr *expr_string(char *str);
 struct expr *expr_unop(enum unop op, struct expr *l);
 
-struct type *type_basic(enum basictype type);
-struct type *type_list(struct type *type);
-struct type *type_tuple(struct type *l, struct type *r);
-struct type *type_var(char *ident);
-
 void ast_print(struct ast *, FILE *out);
 void vardecl_print(struct vardecl *decl, int indent, FILE *out);
 void fundecl_print(struct fundecl *decl, FILE *out);
 void decl_print(struct decl *ast, FILE *out);
 void stmt_print(struct stmt *ast, int indent, FILE *out);
 void expr_print(struct expr *ast, FILE *out);
-void type_print(struct type *type, FILE *out);
 
 void ast_free(struct ast *);
 void vardecl_free(struct vardecl *decl);
@@ -177,6 +158,5 @@ void fundecl_free(struct fundecl *fundecl);
 void decl_free(struct decl *ast);
 void stmt_free(struct stmt *ast);
 void expr_free(struct expr *ast);
-void type_free(struct type *type);
 
 #endif
diff --git a/sem.c b/sem.c
index 1cea48f..3404fb4 100644 (file)
--- a/sem.c
+++ b/sem.c
@@ -3,6 +3,7 @@
 
 #include "list.h"
 #include "sem/scc.h"
+#include "sem/hm.h"
 #include "ast.h"
 
 void type_error(const char *msg, ...)
@@ -34,38 +35,54 @@ void check_expr_constant(struct expr *expr)
        }
 }
 
-struct vardecl *type_vardecl(struct vardecl *vardecl)
+struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl)
 {
-       return vardecl;
-}
+       struct type *t = vardecl->type == NULL
+               ? gamma_fresh(gamma) : type_dup(vardecl->type);
+       struct subst *s = infer_expr(gamma, vardecl->expr, t);
 
-struct decl *type_decl(struct decl *decl)
-{
-       switch (decl->type) {
-       case dcomp:
-               fprintf(stderr, "type_decl:component unsupported\n");
-               break;
-       case dfundecl:
-               fprintf(stderr, "type_decl:fundecl unsupported\n");
-               break;
-       case dvardecl:
-               decl->data.dvar = type_vardecl(decl->data.dvar);
-               break;
-       }
-       return decl;
+       if (s == NULL)
+               die("error inferring variable\n");
+       vardecl->type = subst_apply_t(s, t);
+
+       //subst_free(s);
+
+       return vardecl;
 }
 
 struct ast *sem(struct ast *ast)
 {
        ast = ast_scc(ast);
 
-       //Check that all globals are constant
+       struct gamma *gamma = gamma_init();
+
+       //Check all vardecls
        for (int i = 0; i<ast->ndecls; i++) {
-               if (ast->decls[i]->type == dvardecl) {
-                       //Check globals
+               switch(ast->decls[i]->type) {
+               case dvardecl:
+                       //Check if constant
                        check_expr_constant(ast->decls[i]->data.dvar->expr);
+                       //Infer if necessary
+                       type_vardecl(gamma, ast->decls[i]->data.dvar);
+                       break;
+               case dfundecl: {
+//                     struct type *f1 = gamma_fresh(gamma);
+//                     gamma_insert(gamma, ast->decls[i]->data.dfun->ident
+//                             , scheme_create(f1));
+//infer env (Let [(x, e1)] e2)
+//     =              fresh
+//     >>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
+//                    in infer env` e1
+//     >>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
+//     >>= \(s2, t2)->pure (s1 oo s2, t2)
+                       break;
+               }
+               case dcomp:
                        break;
                }
        }
+
+       gamma_free(gamma);
+
        return ast;
 }
index d91715e..fb3d2ce 100644 (file)
--- a/sem/hm.c
+++ b/sem/hm.c
 #include <string.h>
 
 #include "hm.h"
-#include "../util.h"
+#include "hm/subst.h"
+#include "hm/gamma.h"
+#include "hm/scheme.h"
 #include "../ast.h"
 
-struct gamma {
-       int fresh;
-       int nschemes;
-       char **vars;
-       struct scheme *schemes;
-};
-
-struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
-{
-       for (int i = 0; i<nschemes; i++) {
-               if (strcmp(ident, gamma->vars[i]) == 0) {
-                       //TODO
-               }
-       }
-       return NULL;
-}
-
-struct type *fresh(struct gamma *gamma)
-{
-       char *buf = safe_malloc(10);
-       sprintf(buf, "%d", gamma->fresh);
-       gamma->fresh++;
-       return type_var(buf);
-}
-
-void ftv_type(struct type *r, int *nftv, char **ftv)
-{
-       switch (r->type) {
-       case tbasic:
-               break;
-       case tlist:
-               ftv_type(r->data.tlist, nftv, ftv);
-               break;
-       case ttuple:
-               ftv_type(r->data.ttuple.l, nftv, ftv);
-               ftv_type(r->data.ttuple.r, nftv, ftv);
-               break;
-       case tvar:
-               *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
-               if (*ftv == NULL)
-                       perror("realloc");
-               ftv[(*nftv)++] = r->data.tvar;
-               break;
-       }
-}
-
 bool occurs_check(char *var, struct type *r)
 {
        int nftv = 0;
-       char *ftv = NULL;
-       ftv_type(r, &nftv, &ftv);
+       char **ftv = NULL;
+       type_ftv(r, &nftv, &ftv);
        for (int i = 0; i<nftv; i++)
-               if (strcmp(ftv+i, var) == 0)
+               if (strcmp(ftv[i], var) == 0)
                        return true;
        return false;
 }
 
-struct type *dup_type(struct type *r)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       *res = *r;
-       switch (r->type) {
-       case tbasic:
-               break;
-       case tlist:
-               res->data.tlist = dup_type(r->data.tlist);
-               break;
-       case ttuple:
-               res->data.ttuple.l = dup_type(r->data.ttuple.l);
-               res->data.ttuple.r = dup_type(r->data.ttuple.r);
-               break;
-       case tvar:
-               res->data.tvar = strdup(r->data.tvar);
-               break;
-       }
-       return res;
-}
-
-struct type *subst_apply_t(struct substitution *subst, struct type *l)
+struct subst *unify(struct type *l, struct type *r)
 {
+       if (l == NULL || r == NULL)
+               return NULL;
+       if (r->type == tvar && l->type != tvar)
+               return unify(r, l);
+       struct subst *s1, *s2;
        switch (l->type) {
-       case tbasic:
-               break;
-       case tlist:
-               l->data.tlist = subst_apply_t(subst, l->data.tlist);
-               break;
-       case ttuple:
-               l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
-               l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
-               break;
-       case tvar:
-               for (int i = 0; i<subst->nvar; i++) {
-                       if (strcmp(subst->vars[i], l->data.tvar) == 0) {
-                               free(l->data.tvar);
-                               struct type *r = subst->types[i];
-                               *l = *r;
-                               free(r);
-                       }
+       case tarrow:
+               if (r->type == tarrow) {
+                       s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
+                       s2 = unify(subst_apply_t(s1, l->data.tarrow.l),
+                               subst_apply_t(s1, r->data.tarrow.l));
+                       return subst_union(s1, s2);
                }
                break;
-       }
-       return l;
-}
-struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma)
-{
-       //TODO
-       return gamma;
-}
-
-void subst_print(struct substitution *s, FILE *out)
-{
-       if (s == NULL) {
-               fprintf(out, "no substitution\n");
-       } else {
-               fprintf(out, "[");
-               for (int i = 0; i<s->nvar; i++) {
-                       fprintf(out, "%s->", s->vars[i]);
-                       type_print(s->types[i], out);
-                       if (i + 1 < s->nvar)
-                               fprintf(out, ", ");
-               }
-               fprintf(out, "]\n");
-       }
-}
-
-struct substitution *subst_id()
-{
-       struct substitution *res = safe_malloc(sizeof(struct substitution));
-       res->nvar = 0;
-       res->vars = NULL;
-       res->types = NULL;
-       return res;
-}
-
-struct substitution *subst_singleton(char *ident, struct type *t)
-{
-       struct substitution *res = safe_malloc(sizeof(struct substitution));
-       res->nvar = 1;
-       res->vars = safe_malloc(sizeof(char *));
-       res->vars[0] = safe_strdup(ident);
-       res->types = safe_malloc(sizeof(struct type *));
-       res->types[0] = dup_type(t);
-       return res;
-}
-
-struct substitution *subst_union(struct substitution *l, struct substitution *r)
-{
-       struct substitution *res = safe_malloc(sizeof(struct substitution));
-       res->nvar = l->nvar+r->nvar;
-       res->vars = safe_malloc(res->nvar*sizeof(char *));
-       res->types = safe_malloc(res->nvar*sizeof(struct type *));
-       for (int i = 0; i<l->nvar; i++) {
-               res->vars[i] = l->vars[i];
-               res->types[i] = l->types[i];
-       }
-       for (int i = 0; i<r->nvar; i++) {
-               res->vars[l->nvar+i] = l->vars[i];
-               res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
-       }
-       return res;
-}
-
-struct substitution *unify(struct type *l, struct type *r) {
-       switch (l->type) {
        case tbasic:
                if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
                        return subst_id();
@@ -179,11 +44,8 @@ struct substitution *unify(struct type *l, struct type *r) {
                break;
        case ttuple:
                if (r->type == ttuple) {
-                       struct substitution *s1 = unify(
-                               l->data.ttuple.l,
-                               r->data.ttuple.l);
-                       struct substitution *s2 = unify(
-                               subst_apply_t(s1, l->data.ttuple.l),
+                       s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
+                       s2 = unify(subst_apply_t(s1, l->data.ttuple.l),
                                subst_apply_t(s1, r->data.ttuple.l));
                        return subst_union(s1, s2);
                }
@@ -197,11 +59,23 @@ struct substitution *unify(struct type *l, struct type *r) {
                        return subst_singleton(l->data.tvar, r);
                break;
        }
+       fprintf(stderr, "cannot unify ");
+       type_print(l, stderr);
+       fprintf(stderr, " with ");
+       type_print(r, stderr);
+       fprintf(stderr, "\n");
        return NULL;
 }
 
-struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
+struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
 {
+       fprintf(stderr, "infer expr: ");
+       expr_print(expr, stderr);
+       fprintf(stderr, "\ngamma: ");
+       gamma_print(gamma, stderr);
+       fprintf(stderr, "\ntype: ");
+       type_print(type, stderr);
+       fprintf(stderr, "\n");
 
 #define infbop(l, r, a1, a2, rt, sigma) {\
        s1 = infer_expr(gamma, l, a1);\
@@ -210,13 +84,14 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 }
 #define infbinop(e, a1, a2, rt, sigma)\
        infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
-       struct substitution *s1, *s2;
+       struct subst *s1, *s2;
        struct type *f1, *f2;
+       struct scheme *s;
        switch (expr->type) {
        case ebool:
                return unify(type_basic(btbool), type);
        case ebinop:
-               switch(expr->data.ebinop.op) {
+               switch (expr->data.ebinop.op) {
                case binor:
                case binand:
                        infbinop(expr, type_basic(btbool), type_basic(btbool),
@@ -227,10 +102,10 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                case le:
                case geq:
                case ge:
-                       f1 = fresh(gamma);
+                       f1 = gamma_fresh(gamma);
                        infbinop(expr, f1, f1, type_basic(btbool), type);
                case cons:
-                       f1 = fresh(gamma);
+                       f1 = gamma_fresh(gamma);
                        infbinop(expr, f1, type_list(f1), type_list(f1), type);
                case plus:
                case minus:
@@ -241,20 +116,28 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                        infbinop(expr, type_basic(btint), type_basic(btint),
                                type_basic(btint), type);
                }
+               break;
        case echar:
                return unify(type_basic(btchar), type);
        case efuncall:
+               if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
+                       die("Unbound function: %s\n", expr->data.efuncall.ident);
                //TODO
+               //TODO fields
                return NULL;
        case eint:
                return unify(type_basic(btint), type);
        case eident:
-
-               //TODO
-               return NULL;
+               if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
+                       die("Unbound variable: %s\n", expr->data.eident.ident);
+               //TODO fields
+               return unify(scheme_instantiate(gamma, s), type);
+       case enil:
+               f1 = gamma_fresh(gamma);
+               return unify(type_list(f1), type);
        case etuple:
-               f1 = fresh(gamma);
-               f2 = fresh(gamma);
+               f1 = gamma_fresh(gamma);
+               f2 = gamma_fresh(gamma);
                infbop(expr->data.etuple.left, expr->data.etuple.right,
                       f1, f2, type_tuple(f1, f2), type);
        case estring:
@@ -264,6 +147,8 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                case negate:
                        s1 = infer_expr(gamma,
                                expr->data.eunop.l, type_basic(btint));
+                       if (s1 == NULL)
+                               return NULL;
                        return subst_union(s1,
                                unify(subst_apply_t(s1, type),
                                type_basic(btint)));
@@ -275,5 +160,5 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
                                type_basic(btbool)));
                }
        }
-
+       return NULL;
 }
index efb1afa..6792c3b 100644 (file)
--- a/sem/hm.h
+++ b/sem/hm.h
@@ -2,19 +2,11 @@
 #define SEM_HM_C
 
 #include "../ast.h"
-
-struct scheme {
-       struct type *type;
-       int nvar;
-       char **var;
-};
-
-struct substitution {
-       int nvar;
-       char **vars;
-       struct type **types;
-};
+#include "hm/gamma.h"
+#include "hm/subst.h"
+#include "hm/scheme.h"
 
 struct ast *infer(struct ast *ast);
+struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type);
 
 #endif
diff --git a/sem/hm/gamma.c b/sem/hm/gamma.c
new file mode 100644 (file)
index 0000000..9a97c27
--- /dev/null
@@ -0,0 +1,62 @@
+#include <stdlib.h>
+#include <string.h>
+
+#include "../hm.h"
+
+struct gamma *gamma_init()
+{
+       struct gamma *gamma = safe_malloc(sizeof(struct gamma));
+       gamma->fresh = 0;
+       gamma->nschemes = 0;
+       gamma->vars = NULL;
+       gamma->schemes = NULL;
+       return gamma;
+}
+
+void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme)
+{
+       gamma->nschemes++;
+       gamma->vars = realloc(gamma->vars, gamma->nschemes*sizeof(char *));
+       gamma->schemes = realloc(gamma->schemes,
+               gamma->nschemes*sizeof(struct scheme *));
+       gamma->vars[gamma->nschemes-1] = safe_strdup(ident);
+       gamma->schemes[gamma->nschemes-1] = scheme;
+}
+
+struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
+{
+       for (int i = 0; i<gamma->nschemes; i++)
+               if (strcmp(ident, gamma->vars[i]) == 0)
+                       return gamma->schemes[i];
+       return NULL;
+}
+
+struct type *gamma_fresh(struct gamma *gamma)
+{
+       char buf[10] = {0};
+       sprintf(buf, "%d", gamma->fresh++);
+       return type_var(safe_strdup(buf));
+}
+
+void gamma_print(struct gamma *gamma, FILE *out)
+{
+       fprintf(out, "{");
+       for (int i = 0; i<gamma->nschemes; i++) {
+               fprintf(out, "%s=", gamma->vars[i]);
+               scheme_print(gamma->schemes[i], out);
+               if (i + 1 < gamma->nschemes)
+                       fprintf(out, ", ");
+       }
+       fprintf(out, "}");
+}
+
+void gamma_free(struct gamma *gamma)
+{
+       for (int i = 0; i<gamma->nschemes; i++) {
+               free(gamma->vars[i]);
+               scheme_free(gamma->schemes[i]);
+       }
+       free(gamma->vars);
+       free(gamma->schemes);
+       free(gamma);
+}
diff --git a/sem/hm/gamma.h b/sem/hm/gamma.h
new file mode 100644 (file)
index 0000000..8499144
--- /dev/null
@@ -0,0 +1,24 @@
+#ifndef SEM_HM_GAMMA_H
+#define SEM_HM_GAMMA_H
+
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct gamma {
+       int fresh;
+       int nschemes;
+       char **vars;
+       struct scheme **schemes;
+};
+
+struct gamma *gamma_init();
+void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme);
+
+struct scheme *gamma_lookup(struct gamma *gamma, char *ident);
+struct type *gamma_fresh(struct gamma *gamma);
+
+void gamma_print(struct gamma *gamma, FILE *out);
+void gamma_free(struct gamma *gamma);
+
+#endif
diff --git a/sem/hm/scheme.c b/sem/hm/scheme.c
new file mode 100644 (file)
index 0000000..177bb64
--- /dev/null
@@ -0,0 +1,68 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct type *scheme_instantiate(struct gamma *gamma, struct scheme *sch)
+{
+       struct subst *s = subst_id();
+       for (int i = 0; i<sch->nvar; i++) {
+               s = subst_union(s, subst_singleton(sch->var[i], gamma_fresh(gamma)));
+       }
+       struct type *t = subst_apply_t(s, type_dup(sch->type));
+       for (int i = 0; i<s->nvar; i++)
+               free(s->vars[i]);
+       free(s);
+       return t;
+}
+
+struct scheme *scheme_create(struct type *t)
+{
+       struct scheme *s = safe_malloc(sizeof(struct scheme));
+       s->type = t;
+       s->nvar = 0;
+       s->var = NULL;
+}
+
+struct scheme *scheme_generalise(struct gamma *gamma, struct type *t)
+{
+       struct scheme *s = safe_malloc(sizeof(struct scheme));
+       int nftv = 0;
+       char **ftv = NULL;
+       type_ftv(t, &nftv, &ftv);
+
+       s->type = t;
+       s->nvar = 0;
+       s->var = safe_malloc(nftv*sizeof(char *));
+       for (int i = 0; i<nftv; i++) {
+               bool skip = false;
+               for (int j = 0; j<gamma->nschemes; j++)
+                       if (strcmp(gamma->vars[j], ftv[i]) == 0)
+                               skip = true;
+               if (skip)
+                       continue;
+               s->nvar++;
+               s->var[i] = ftv[i];
+       }
+       return s;
+}
+
+void scheme_print(struct scheme *scheme, FILE *out)
+{
+       if (scheme->nvar > 0) {
+               fprintf(out, "A.");
+               for (int i = 0; i<scheme->nvar; i++)
+                       fprintf(out, "%s", scheme->var[i]);
+               fprintf(out, ": ");
+       }
+       type_print(scheme->type, out);
+}
+
+void scheme_free(struct scheme *scheme)
+{
+       type_free(scheme->type);
+       for (int i = 0; i<scheme->nvar; i++)
+               free(scheme->var[i]);
+       free(scheme->var);
+       free(scheme);
+}
diff --git a/sem/hm/scheme.h b/sem/hm/scheme.h
new file mode 100644 (file)
index 0000000..eaab3fc
--- /dev/null
@@ -0,0 +1,19 @@
+#ifndef SEM_HM_SCHEME_H
+#define SEM_HM_SCHEME_H
+
+#include "../hm.h"
+
+struct scheme {
+       struct type *type;
+       int nvar;
+       char **var;
+};
+
+struct type *scheme_instantiate(struct gamma *gamma, struct scheme *s);
+struct scheme *scheme_create(struct type *t);
+struct scheme *scheme_generalise(struct gamma *gamma, struct type *t);
+
+void scheme_print(struct scheme *scheme, FILE *out);
+void scheme_free(struct scheme *scheme);
+
+#endif
diff --git a/sem/hm/subst.c b/sem/hm/subst.c
new file mode 100644 (file)
index 0000000..55d6cfe
--- /dev/null
@@ -0,0 +1,124 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct subst *subst_id()
+{
+       struct subst *res = safe_malloc(sizeof(struct subst));
+       res->nvar = 0;
+       res->vars = NULL;
+       res->types = NULL;
+       return res;
+}
+
+struct subst *subst_singleton(char *ident, struct type *t)
+{
+       struct subst *res = safe_malloc(sizeof(struct subst));
+       res->nvar = 1;
+       res->vars = safe_malloc(sizeof(char *));
+       res->vars[0] = safe_strdup(ident);
+       res->types = safe_malloc(sizeof(struct type *));
+       res->types[0] = type_dup(t);
+       return res;
+}
+
+struct subst *subst_union(struct subst *l, struct subst *r)
+{
+       if (l == NULL || r == NULL)
+               return NULL;
+       struct subst *res = safe_malloc(sizeof(struct subst));
+       res->nvar = l->nvar+r->nvar;
+       res->vars = safe_malloc(res->nvar*sizeof(char *));
+       res->types = safe_malloc(res->nvar*sizeof(struct type *));
+       for (int i = 0; i<l->nvar; i++) {
+               res->vars[i] = l->vars[i];
+               res->types[i] = l->types[i];
+       }
+       for (int i = 0; i<r->nvar; i++) {
+               res->vars[l->nvar+i] = r->vars[i];
+               res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
+       }
+       return res;
+}
+
+struct type *subst_apply_t(struct subst *subst, struct type *l)
+{
+       if (subst == NULL)
+               return l;
+       switch (l->type) {
+       case tarrow:
+               l->data.tarrow.l = subst_apply_t(subst, l->data.tarrow.l);
+               l->data.tarrow.r = subst_apply_t(subst, l->data.tarrow.r);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               l->data.tlist = subst_apply_t(subst, l->data.tlist);
+               break;
+       case ttuple:
+               l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
+               l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
+               break;
+       case tvar:
+               for (int i = 0; i<subst->nvar; i++) {
+                       if (strcmp(subst->vars[i], l->data.tvar) == 0) {
+                               free(l->data.tvar);
+                               struct type *r = type_dup(subst->types[i]);
+                               *l = *r;
+                               free(r);
+                               break;
+                       }
+               }
+               break;
+       }
+       return l;
+}
+
+struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme)
+{
+       for (int i = 0; i<scheme->nvar; i++) {
+               for (int j = 0; j<subst->nvar; j++) {
+                       if (strcmp(scheme->var[i], subst->vars[j]) != 0) {
+                               struct subst *t = subst_singleton(
+                                       subst->vars[j], subst->types[j]);
+                               scheme->type = subst_apply_t(t, scheme->type);
+                       }
+               }
+       }
+       return scheme;
+}
+
+struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma)
+{
+       for (int i = 0; i<gamma->nschemes; i++)
+               subst_apply_s(subst, gamma->schemes[i]);
+       return gamma;
+}
+
+void subst_print(struct subst *s, FILE *out)
+{
+       if (s == NULL) {
+               fprintf(out, "no subst\n");
+       } else {
+               fprintf(out, "[");
+               for (int i = 0; i<s->nvar; i++) {
+                       fprintf(out, "%s->", s->vars[i]);
+                       type_print(s->types[i], out);
+                       if (i + 1 < s->nvar)
+                               fprintf(out, ", ");
+               }
+               fprintf(out, "]\n");
+       }
+}
+
+void subst_free(struct subst *s, bool type)
+{
+       if (s != NULL) {
+               for (int i = 0; i<s->nvar; i++) {
+                       free(s->vars[i]);
+                       if (type)
+                               type_free(s->types[i]);
+               }
+       }
+}
diff --git a/sem/hm/subst.h b/sem/hm/subst.h
new file mode 100644 (file)
index 0000000..ffec9a7
--- /dev/null
@@ -0,0 +1,24 @@
+#ifndef SEM_HM_SUBST_H
+#define SEM_HM_SUBST_H
+
+#include "../../ast.h"
+#include "../hm.h"
+
+struct subst {
+       int nvar;
+       char **vars;
+       struct type **types;
+};
+
+struct subst *subst_id();
+struct subst *subst_singleton(char *ident, struct type *t);
+struct subst *subst_union(struct subst *l, struct subst *r);
+
+struct type *subst_apply_t(struct subst *subst, struct type *l);
+struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme);
+struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma);
+
+void subst_print(struct subst *s, FILE *out);
+void subst_free(struct subst *s, bool type);
+
+#endif
index fa47590..6d7b02e 100644 (file)
--- a/sem/scc.c
+++ b/sem/scc.c
@@ -271,10 +271,11 @@ struct ast *ast_scc(struct ast *ast)
        struct edge **edata = (struct edge **)
                list_to_array(edges, &nedges, false);
 
+       fprintf(stderr, "nfun: %d, ffun: %d, nedges: %d\n", nfun, ffun, nedges);
        // Do tarjan's and convert back into the declaration list
        struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata);
-       if (cs == NULL)
-               die("malformed edges in tarjan's????");
+//     if (cs == NULL)
+//             die("malformed edges in tarjan's????");
 
        int i = ffun;
        for (struct components *c = cs; c != NULL; c = c->next) {
diff --git a/test/Makefile b/test/Makefile
new file mode 100644 (file)
index 0000000..faaff6a
--- /dev/null
@@ -0,0 +1,19 @@
+TESTOBJECTS:=$(patsubst %.c,%.o,$(wildcard *.c))
+TESTS:=$(patsubst %.o,%,$(TESTOBJECTS))
+
+LDLIBS+=$(shell pkg-config --libs check)
+
+.PHONY: test
+
+test_sem_hm_gamma.o: CFLAGS+=$(shell pkg-config --cflags check)
+
+test_sem_hm_gamma: test_sem_hm_gamma.o $(addprefix ../sem/hm/,gamma.o scheme.o subst.o) ../util.o ../type.o
+
+test: $(TESTS)
+       $(foreach f,$^,./$(f);)
+
+clean:
+       $(RM) $(TESTOBJECTS) $(TESTS)
+ifeq ($(MAKELEVEL), 0)
+       $(MAKE) -C ../ clean
+endif
diff --git a/test/test_sem_hm_gamma.c b/test/test_sem_hm_gamma.c
new file mode 100644 (file)
index 0000000..1665031
--- /dev/null
@@ -0,0 +1,97 @@
+#include <stdbool.h>
+#include <stdlib.h>
+#include <check.h>
+
+#include "../sem/hm/gamma.h"
+#include "../sem/hm/subst.h"
+#include "../sem/hm/scheme.h"
+
+START_TEST(test_gamma_lookup)
+{
+       struct gamma *gamma = gamma_init();
+
+       ck_assert_ptr_null(gamma_lookup(gamma, "fun"));
+       ck_assert_ptr_null(gamma_lookup(gamma, "fun2"));
+
+       gamma_insert(gamma, "fun", scheme_generalise(gamma, type_basic(btint)));
+
+       ck_assert_ptr_nonnull(gamma_lookup(gamma, "fun"));
+       ck_assert_ptr_null(gamma_lookup(gamma, "fun2"));
+
+       struct type *t1 = gamma_fresh(gamma);
+       ck_assert(t1->type == tvar);
+       struct type *t2 = gamma_fresh(gamma);
+       ck_assert(t2->type == tvar);
+       struct type *t3 = gamma_fresh(gamma);
+       ck_assert(t3->type == tvar);
+       struct type *t4 = gamma_fresh(gamma);
+       ck_assert(t4->type == tvar);
+
+       ck_assert_str_ne(t1->data.tvar, t2->data.tvar);
+       ck_assert_str_ne(t2->data.tvar, t3->data.tvar);
+       ck_assert_str_ne(t3->data.tvar, t4->data.tvar);
+}
+END_TEST
+
+START_TEST(test_scheme)
+{
+       struct gamma *gamma = gamma_init();
+
+       char **var = malloc(sizeof(char *));
+       var[0] = safe_strdup("a");
+       struct scheme scheme = {.type=type_var("a"), .nvar=1, .var=var};
+
+       struct type *t = scheme_instantiate(gamma, &scheme);
+       ck_assert(t->type == tvar);
+       ck_assert_str_eq(t->data.tvar, "0");
+
+       scheme.type = type_list(type_var("a"));
+       t = scheme_instantiate(gamma, &scheme);
+       ck_assert(t->type == tlist);
+       ck_assert(t->data.tlist->type == tvar);
+       ck_assert_str_eq(t->data.tlist->data.tvar, "1");
+}
+END_TEST
+
+START_TEST(test_subst)
+{
+       struct subst *s1 = subst_id();
+       ck_assert_int_eq(0, s1->nvar);
+       s1 = subst_singleton("i1", type_basic(btint));
+       ck_assert_int_eq(1, s1->nvar);
+       s1 = subst_union(subst_id(), subst_singleton("i1", type_basic(btint)));
+       ck_assert_int_eq(1, s1->nvar);
+       s1 = subst_union(subst_singleton("i2", type_basic(btbool)),
+               subst_singleton("i1", type_basic(btint)));
+       ck_assert_int_eq(2, s1->nvar);
+
+}
+END_TEST
+
+Suite *util_suite(void)
+{
+       Suite *s = suite_create("List");
+
+       TCase *tc_gamma = tcase_create("Gamma lookup");
+       tcase_add_test(tc_gamma, test_gamma_lookup);
+       tcase_add_test(tc_gamma, test_scheme);
+       tcase_add_test(tc_gamma, test_subst);
+       suite_add_tcase(s, tc_gamma);
+
+       return s;
+}
+
+int main(void)
+{
+       int failed;
+       Suite *s;
+       SRunner *sr;
+
+       s = util_suite();
+       sr = srunner_create(s);
+
+       srunner_run_all(sr, CK_NORMAL);
+       failed = srunner_ntests_failed(sr);
+       srunner_free(sr);
+       return (failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
+}
diff --git a/type.c b/type.c
new file mode 100644 (file)
index 0000000..8fc6102
--- /dev/null
+++ b/type.c
@@ -0,0 +1,183 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "util.h"
+#include "type.h"
+
+static const char *basictype_str[] = {
+       [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
+       [btvoid] = "Void",
+};
+
+struct type *type_arrow(struct type *l, struct type *r)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tarrow;
+       res->data.tarrow.l = l;
+       res->data.tarrow.r = r;
+       return res;
+
+}
+
+struct type *type_basic(enum basictype type)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tbasic;
+       res->data.tbasic = type;
+       return res;
+}
+
+struct type *type_list(struct type *type)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tlist;
+       res->data.tlist = type;
+       return res;
+}
+
+struct type *type_tuple(struct type *l, struct type *r)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = ttuple;
+       res->data.ttuple.l = l;
+       res->data.ttuple.r = r;
+       return res;
+}
+
+struct type *type_var(char *ident)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       if (strcmp(ident, "Int") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btint;
+               free(ident);
+       } else if (strcmp(ident, "Char") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btchar;
+               free(ident);
+       } else if (strcmp(ident, "Bool") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btbool;
+               free(ident);
+       } else if (strcmp(ident, "Void") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btvoid;
+               free(ident);
+       } else {
+               res->type = tvar;
+               res->data.tvar = ident;
+       }
+       return res;
+}
+
+void type_print(struct type *type, FILE *out)
+{
+       if (type == NULL)
+               return;
+       switch (type->type) {
+       case tarrow:
+               safe_fprintf(out, "(");
+               type_print(type->data.tarrow.l, out);
+               safe_fprintf(out, "->");
+               type_print(type->data.tarrow.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case tbasic:
+               safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
+               break;
+       case tlist:
+               safe_fprintf(out, "[");
+               type_print(type->data.tlist, out);
+               safe_fprintf(out, "]");
+               break;
+       case ttuple:
+               safe_fprintf(out, "(");
+               type_print(type->data.ttuple.l, out);
+               safe_fprintf(out, ",");
+               type_print(type->data.ttuple.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case tvar:
+               safe_fprintf(out, "%s", type->data.tvar);
+               break;
+       default:
+               die("Unsupported type node\n");
+       }
+}
+
+void type_free(struct type *type)
+{
+       if (type == NULL)
+               return;
+       switch (type->type) {
+       case tarrow:
+               type_free(type->data.tarrow.l);
+               type_free(type->data.tarrow.r);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               type_free(type->data.tlist);
+               break;
+       case ttuple:
+               type_free(type->data.ttuple.l);
+               type_free(type->data.ttuple.r);
+               break;
+       case tvar:
+               free(type->data.tvar);
+               break;
+       default:
+               die("Unsupported type node\n");
+       }
+       free(type);
+}
+
+struct type *type_dup(struct type *r)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       *res = *r;
+       switch (r->type) {
+       case tarrow:
+               res->data.tarrow.l = type_dup(r->data.tarrow.l);
+               res->data.tarrow.r = type_dup(r->data.tarrow.r);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               res->data.tlist = type_dup(r->data.tlist);
+               break;
+       case ttuple:
+               res->data.ttuple.l = type_dup(r->data.ttuple.l);
+               res->data.ttuple.r = type_dup(r->data.ttuple.r);
+               break;
+       case tvar:
+               res->data.tvar = safe_strdup(r->data.tvar);
+               break;
+       }
+       return res;
+}
+
+void type_ftv(struct type *r, int *nftv, char ***ftv)
+{
+       switch (r->type) {
+       case tarrow:
+               type_ftv(r->data.ttuple.l, nftv, ftv);
+               type_ftv(r->data.ttuple.r, nftv, ftv);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               type_ftv(r->data.tlist, nftv, ftv);
+               break;
+       case ttuple:
+               type_ftv(r->data.ttuple.l, nftv, ftv);
+               type_ftv(r->data.ttuple.r, nftv, ftv);
+               break;
+       case tvar:
+               *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
+               if (*ftv == NULL)
+                       perror("realloc");
+               (*ftv)[(*nftv)++] = r->data.tvar;
+               break;
+       }
+}
diff --git a/type.h b/type.h
new file mode 100644 (file)
index 0000000..0bbb2e0
--- /dev/null
+++ b/type.h
@@ -0,0 +1,36 @@
+#ifndef TYPE_H
+#define TYPE_H
+
+#include <stdio.h>
+
+enum basictype {btbool, btchar, btint, btvoid};
+struct type {
+       enum {tarrow,tbasic,tlist,ttuple,tvar} type;
+       union {
+               struct {
+                       struct type *l;
+                       struct type *r;
+               } tarrow;
+               enum basictype tbasic;
+               struct type *tlist;
+               struct {
+                       struct type *l;
+                       struct type *r;
+               } ttuple;
+               char *tvar;
+       } data;
+};
+
+struct type *type_arrow(struct type *l, struct type *r);
+struct type *type_basic(enum basictype type);
+struct type *type_list(struct type *type);
+struct type *type_tuple(struct type *l, struct type *r);
+struct type *type_var(char *ident);
+
+void type_print(struct type *type, FILE *out);
+void type_free(struct type *type);
+
+struct type *type_dup(struct type *t);
+void type_ftv(struct type *r, int *nftv, char ***ftv);
+
+#endif
diff --git a/util.c b/util.c
index efe22c7..fa60ddc 100644 (file)
--- a/util.c
+++ b/util.c
@@ -154,9 +154,11 @@ void *safe_malloc(size_t size)
 
 void *safe_strdup(const char *c)
 {
-       char *res = strdup(c);
+       size_t nchar = strlen(c);
+       char *res = malloc((nchar+1)*sizeof(char));
        if (res == NULL)
                pdie("strdup");
+       memcpy(res, c, nchar+1);
        return res;
 }