modify array interface
authorMart Lubbers <mart@martlubbers.net>
Tue, 16 Mar 2021 09:59:45 +0000 (10:59 +0100)
committerMart Lubbers <mart@martlubbers.net>
Tue, 16 Mar 2021 09:59:45 +0000 (10:59 +0100)
29 files changed:
.gitignore
Makefile
array.c
array.h
ast.c
ast.h
gen.c
gen/c.c
gen/c.h
gen/ssm.c
gen/ssm.h
parse.y
sem.c
sem/constant.c [new file with mode: 0644]
sem/constant.h [new file with mode: 0644]
sem/hm.c
sem/main.c [new file with mode: 0644]
sem/main.h [new file with mode: 0644]
sem/return.c [new file with mode: 0644]
sem/return.h [new file with mode: 0644]
sem/scc.c
sem/scc.h
sem/type.c [new file with mode: 0644]
sem/type.h [new file with mode: 0644]
sem/vardecl.c [new file with mode: 0644]
sem/vardecl.h [new file with mode: 0644]
splc.c
util.c
util.h

index bf48a04..1c29613 100644 (file)
@@ -1,4 +1,5 @@
 splc
+splc.exe
 parse.[ch]
 scan.[ch]
 *.o
@@ -6,6 +7,7 @@ y.output
 a.c
 a.ssm
 a.out
+a.exe
 
 callgrind.out.*
 massif.out.*
index 34518dd..e6bc8f0 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -5,7 +5,8 @@ LFLAGS+=--header-file=scan.h
 
 OBJECTS:=array.o scan.o parse.o ast.o type.o util.o sem.o ident.o\
        $(addprefix gen,.o /c.o /ssm.o)\
-       $(addprefix sem,.o /scc.o $(addprefix /hm, .o /gamma.o /subst.o /scheme.o))
+       $(addprefix sem,.o /main.o /constant.o /return.o /scc.o /type.o /vardecl.o\
+               $(addprefix /hm, .o /gamma.o /subst.o /scheme.o))
 
 all: splc
 splc: $(OBJECTS)
diff --git a/array.c b/array.c
index 22e3772..6b7d313 100644 (file)
--- a/array.c
+++ b/array.c
@@ -6,6 +6,13 @@
 
 const struct array array_null = {.nel=0, .cap=0, .el=NULL};
 
+struct array *array_new(size_t cap)
+{
+       struct array *res = xalloc(1, struct array);
+       array_init(res, cap);
+       return res;
+}
+
 void array_init(struct array *array, size_t cap)
 {
        array->nel = 0;
@@ -13,43 +20,39 @@ void array_init(struct array *array, size_t cap)
        array->el = xalloc(cap, void *);
 }
 
-struct array array_resize(struct array a, size_t cap)
+void array_resize(struct array *a, size_t cap)
 {
-       if (cap > a.cap)
-               a.el = xrealloc(a.el, a.cap = cap, void *);
-       return a;
+       if (cap > a->cap)
+               a->el = xrealloc(a->el, a->cap = cap, void *);
 }
 
-struct array array_append(struct array a, void *x)
+void array_append(struct array *a, void *x)
 {
-       if (a.nel >= a.cap)
-               a = array_resize(a, a.cap == 0 ? 8 : 2*a.cap);
-       a.el[a.nel++] = x;
-       return a;
+       if (a->nel >= a->cap)
+               array_resize(a, a->cap == 0 ? 8 : 2*a->cap);
+       a->el[a->nel++] = x;
 }
 
-struct array array_insert(struct array a, size_t idx, void *x)
+void array_insert(struct array *a, size_t idx, void *x)
 {
-       a = array_append(a, NULL);
-       for (size_t i = a.nel-1; i>idx; i--)
-               a.el[i] = a.el[i-1];
-       a.el[idx] = x;
-       return a;
+       array_append(a, NULL);
+       for (size_t i = a->nel-1; i>idx; i--)
+               a->el[i] = a->el[i-1];
+       a->el[idx] = x;
 }
 
-void array_free(struct array a, void (*freefun)(void *))
+void array_free(struct array *a, void (*freefun)(void *))
 {
        array_clean(a, freefun);
-       free(a.el);
+       free(a->el);
 }
 
-struct array array_clean(struct array a, void (*freefun)(void *))
+void array_clean(struct array *a, void (*freefun)(void *))
 {
        if (freefun != NULL)
                ARRAY_ITERI(i, a)
-                       freefun(a.el[i]);
-       a.nel = 0;
-       return a;
+                       freefun(a->el[i]);
+       a->nel = 0;
 }
 
 static const void *bsearchfail;
@@ -59,22 +62,24 @@ static int bscmp(const void *l, const void *r)
        bsearchfail = r;
        return realcmp(l, r);
 }
-struct array array_binsert(void *key, struct array a, int (*cmp)(const void *, const void *)) {
-       if (ARRAY_SIZE(a) == 0)
-               return array_append(a, key);
+void array_binsert(void *key, struct array *a, int (*cmp)(const void *, const void *)) {
+       if (ARRAY_SIZE(a) == 0) {
+               array_append(a, key);
+               return;
+       }
        bsearchfail = NULL;
        realcmp = cmp;
-       void *e = bsearch(key, a.el, a.nel, sizeof(void *), bscmp);
+       void *e = bsearch(key, a->el, a->nel, sizeof(void *), bscmp);
        if (e != NULL)
-               return a;
-       size_t idx = ((intptr_t)a.el-(intptr_t)bsearchfail)/sizeof(void *);
+               return;
+       size_t idx = ((intptr_t)a->el-(intptr_t)bsearchfail)/sizeof(void *);
        //check if it is smaller than the smallest
        if (idx == 0) {
-               if (cmp(key, a.el) > 0)
+               if (cmp(key, a->el) > 0)
                        idx++;
-       } else if (idx >= a.nel) {
-               idx = a.nel;
+       } else if (idx >= a->nel) {
+               idx = a->nel;
        }
 
-       return array_insert(a, idx, key);
+       array_insert(a, idx, key);
 }
diff --git a/array.h b/array.h
index a3e5932..e53698a 100644 (file)
--- a/array.h
+++ b/array.h
@@ -5,17 +5,17 @@
 #include <stdlib.h>
 
 /* Select an element */
-#define ARRAY_EL(type, array, idx) ((type)((array).el[idx]))
+#define ARRAY_EL(type, array, idx) ((type)((array)->el[idx]))
 /* Iterate over the indices of an array */
-#define ARRAY_ITERI(iter, a) for (size_t (iter) = 0; (iter)<(a).nel; (iter)++)
+#define ARRAY_ITERI(iter, a) for (size_t (iter) = 0; (iter)<(a)->nel; (iter)++)
 /* Iterate over the indices and elements of an array */
 #define ARRAY_ITER(type, x, iter, a) ARRAY_ITERI (iter, a) {\
                type (x) = ARRAY_EL(type, a, iter);
 #define AIEND }
 /* Get the size of the array */
-#define ARRAY_SIZE(a) (a).nel
+#define ARRAY_SIZE(a) (a)->nel
 /* Get a pointer to the elements of the array */
-#define ARRAY_ELS(type, a) ((type *)(a).el)
+#define ARRAY_ELS(type, a) ((type *)(a)->el)
 
 #define ARRAY_BSEARCH(type, key, a, cmp) (type)bsearch(key, ARRAY_ELS(type, a),\
        ARRAY_SIZE(a), sizeof(void *), (int (*)(const void *, const void *))cmp)
@@ -28,25 +28,25 @@ struct array {
        void **el;
 };
 
+//* Create an array with the given capacity
+struct array *array_new(size_t cap);
+
 //* Initialise an array
 void array_init(struct array *array, size_t cap);
 
 //* Resize an array to hold at least a certain capacity
-struct array array_resize(struct array a, size_t cap);
-
-//* Create an array with the given capacity
-struct array *array_new(size_t cap);
+void array_resize(struct array *a, size_t cap);
 
 //* Append an item to the array
-struct array array_append(struct array a, void *x);
+void array_append(struct array *a, void *x);
 
 //* free all elements and free the array
-void array_free(struct array, void (*freefun)(void *));
+void array_free(struct array *a, void (*freefun)(void *));
 
 //* free all element and keep the array
-struct array array_clean(struct array array, void (*freefun)(void *));
+void array_clean(struct array *a, void (*freefun)(void *));
 
 //* insert an item in a sorted array
-struct array array_binsert(void *key, struct array a, int (*cmp)(const void *, const void *));
+void array_binsert(void *key, struct array *a, int (*cmp)(const void *, const void *));
 
 #endif
diff --git a/ast.c b/ast.c
index ade4615..ccc7a46 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -21,8 +21,8 @@ struct ast *ast(struct array decls, YYLTYPE l)
 {
        struct ast *res = xalloc(1, struct ast);
        res->loc = l;
-       res->ndecls = ARRAY_SIZE(decls);
-       res->decls = ARRAY_ELS(struct decl *, decls);
+       res->ndecls = ARRAY_SIZE(&decls);
+       res->decls = ARRAY_ELS(struct decl *, &decls);
        return res;
 }
 
@@ -36,7 +36,7 @@ struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr, YYLTY
        return res;
 }
 
-struct fundecl *fundecl(char *ident, struct array args, struct array atypes,
+struct fundecl *fundecl(char *ident, struct array args, struct array *atypes,
        struct type *rtype, struct array body, YYLTYPE l)
 {
        struct fundecl *res = xalloc(1, struct fundecl);
@@ -184,20 +184,21 @@ static struct expr *expr_funcall_real(char *ident, struct array args, YYLTYPE l)
        res->data.efuncall.ident = ident;
        res->data.efuncall.args = args;
        res->data.efuncall.type = NULL;
+       res->data.efuncall.ref = NULL;
        return res;
 }
 
 static struct expr *expr_apply_fields(struct expr *r, struct array fields, YYLTYPE l)
 {
-       ARRAY_ITER(char *, f, i, fields) {
+       ARRAY_ITER(char *, f, i, &fields) {
                if (is_valid_field(f)) {
                        struct array as;
                        array_init(&as, 1);
-                       as = array_append(as, r);
+                       array_append(&as, r);
                        r = expr_funcall_real(f, as, l);
                }
        } AIEND
-       array_free(fields, NULL);
+       array_free(&fields, NULL);
        return r;
 }
 
@@ -275,8 +276,6 @@ struct expr *expr_unop(enum unop op, struct expr *e, YYLTYPE l)
 
 void ast_print(struct ast *ast, FILE *out)
 {
-       if (ast == NULL)
-               return;
        for (int i = 0; i<ast->ndecls; i++)
                decl_print(ast->decls[i], out);
 }
@@ -296,23 +295,29 @@ void vardecl_print(struct vardecl *decl, int indent, FILE *out)
 void fundecl_print(struct fundecl *decl, FILE *out)
 {
        safe_fprintf(out, "%s (", decl->ident);
-       ARRAY_ITER(char *, arg, i, decl->args) {
+       ARRAY_ITER(char *, arg, i, &decl->args) {
                safe_fprintf(out, "%s", arg);
-               if (i < ARRAY_SIZE(decl->args) - 1)
+               if (i < ARRAY_SIZE(&decl->args) - 1)
                        safe_fprintf(out, ", ");
        } AIEND
        safe_fprintf(out, ")");
        if (decl->rtype != NULL) {
-               safe_fprintf(out, " :: ");
-               ARRAY_ITER(struct type *, atype, i, decl->atypes) {
-                       type_print(atype, out);
-                       safe_fprintf(out, " ");
-               } AIEND
-               safe_fprintf(out, "-> ");
-               type_print(decl->rtype, out);
+               if (decl->atypes == NULL) {
+                       safe_fprintf(out, "/* :: ?? -> ");
+                       type_print(decl->rtype, out);
+                       safe_fprintf(out, " */");
+               } else {
+                       safe_fprintf(out, " :: ");
+                       ARRAY_ITER(struct type *, atype, i, decl->atypes) {
+                               type_print(atype, out);
+                               safe_fprintf(out, " ");
+                       } AIEND
+                       safe_fprintf(out, "-> ");
+                       type_print(decl->rtype, out);
+               }
        }
        safe_fprintf(out, " {\n");
-       ARRAY_ITER(struct stmt *, stmt, i, decl->body)
+       ARRAY_ITER(struct stmt *, stmt, i, &decl->body)
                stmt_print(stmt, 1, out);
        AIEND
        safe_fprintf(out, "}\n");
@@ -320,8 +325,6 @@ void fundecl_print(struct fundecl *decl, FILE *out)
 
 void decl_print(struct decl *decl, FILE *out)
 {
-       if (decl == NULL)
-               return;
        switch(decl->type) {
        case dfundecl:
                fundecl_print(decl->data.dfun, out);
@@ -331,7 +334,7 @@ void decl_print(struct decl *decl, FILE *out)
                break;
        case dcomp:
                safe_fprintf(out, "//<<<comp\n");
-               ARRAY_ITER(struct fundecl *, d, i, decl->data.dcomp)
+               ARRAY_ITER(struct fundecl *, d, i, &decl->data.dcomp)
                        fundecl_print(d, out);
                AIEND
                safe_fprintf(out, "//>>>comp\n");
@@ -343,13 +346,11 @@ void decl_print(struct decl *decl, FILE *out)
 
 void stmt_print(struct stmt *stmt, int indent, FILE *out)
 {
-       if (stmt == NULL)
-               return;
        switch(stmt->type) {
        case sassign:
                pindent(indent, out);
                safe_fprintf(out, "%s", stmt->data.sassign.ident);
-               ARRAY_ITER(char *, f, i, stmt->data.sassign.fields)
+               ARRAY_ITER(char *, f, i, &stmt->data.sassign.fields)
                        safe_fprintf(out, ".%s", f);
                AIEND
                safe_fprintf(out, " = ");
@@ -361,12 +362,12 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
                safe_fprintf(out, "if (");
                expr_print(stmt->data.sif.pred, out);
                safe_fprintf(out, ") {\n");
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.sif.then)
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.sif.then)
                        stmt_print(s, indent+1, out);
                AIEND
                pindent(indent, out);
                safe_fprintf(out, "} else {\n");
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.sif.els)
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.sif.els)
                        stmt_print(s, indent+1, out);
                AIEND
                pindent(indent, out);
@@ -375,7 +376,8 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
        case sreturn:
                pindent(indent, out);
                safe_fprintf(out, "return ");
-               expr_print(stmt->data.sreturn, out);
+               if (stmt->data.sreturn != NULL)
+                       expr_print(stmt->data.sreturn, out);
                safe_fprintf(out, ";\n");
                break;
        case sexpr:
@@ -391,7 +393,7 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
                safe_fprintf(out, "while (");
                expr_print(stmt->data.swhile.pred, out);
                safe_fprintf(out, ") {\n");
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.swhile.body)
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.swhile.body)
                        stmt_print(s, indent+1, out);
                AIEND
                pindent(indent, out);
@@ -445,8 +447,6 @@ static inline bool brace(struct ctx this, struct ctx outer)
 
 static void expr_print2(struct expr *expr, FILE *out, struct ctx ctx)
 {
-       if (expr == NULL)
-               return;
        char buf[] = "\\xff";
        struct ctx this;
        switch(expr->type) {
@@ -457,7 +457,7 @@ static void expr_print2(struct expr *expr, FILE *out, struct ctx ctx)
                this.branch = left;
                expr_print2(expr->data.ebinop.l, out, this);
                safe_fprintf(out, " %s ", binop_str[expr->data.ebinop.op]);
-               if (expr->data.efuncall.type != NULL) {
+               if (expr->data.ebinop.type != NULL) {
                        safe_fprintf(out, " /* ");
                        type_print(expr->data.ebinop.type, out);
                        safe_fprintf(out, " */ ");
@@ -482,9 +482,9 @@ static void expr_print2(struct expr *expr, FILE *out, struct ctx ctx)
                        safe_fprintf(out, " */ ");
                }
                safe_fprintf(out, "(");
-               ARRAY_ITER(struct expr *, e, i, expr->data.efuncall.args) {
+               ARRAY_ITER(struct expr *, e, i, &expr->data.efuncall.args) {
                        expr_print2(e, out, nfctx);
-                       if (i+1 < ARRAY_SIZE(expr->data.efuncall.args))
+                       if (i+1 < ARRAY_SIZE(&expr->data.efuncall.args))
                                safe_fprintf(out, ", ");
                } AIEND
                safe_fprintf(out, ")");
@@ -534,8 +534,6 @@ void expr_print(struct expr *expr, FILE *out)
 
 void ast_free(struct ast *ast)
 {
-       if (ast == NULL)
-               return;
        for (int i = 0; i<ast->ndecls; i++)
                decl_free(ast->decls[i]);
        free(ast->decls);
@@ -553,20 +551,19 @@ void vardecl_free(struct vardecl *decl)
 void fundecl_free(struct fundecl *decl)
 {
        free(decl->ident);
-       array_free(decl->args, free);
+       array_free(&decl->args, free);
        array_free(decl->atypes, (void (*)(void *))type_free);
+       free(decl->atypes);
        type_free(decl->rtype);
-       array_free(decl->body, (void (*)(void *))stmt_free);
+       array_free(&decl->body, (void (*)(void *))stmt_free);
        free(decl);
 }
 
 void decl_free(struct decl *decl)
 {
-       if (decl == NULL)
-               return;
        switch(decl->type) {
        case dcomp:
-               array_free(decl->data.dcomp, (void (*)(void *))fundecl_free);
+               array_free(&decl->data.dcomp, (void (*)(void *))fundecl_free);
                break;
        case dfundecl:
                fundecl_free(decl->data.dfun);
@@ -582,28 +579,27 @@ void decl_free(struct decl *decl)
 
 void stmt_free(struct stmt *stmt)
 {
-       if (stmt == NULL)
-               return;
        switch(stmt->type) {
        case sassign:
                free(stmt->data.sassign.ident);
-               array_free(stmt->data.sassign.fields, free);
+               array_free(&stmt->data.sassign.fields, free);
                expr_free(stmt->data.sassign.expr);
                break;
        case sif:
                expr_free(stmt->data.sif.pred);
-               array_free(stmt->data.sif.then, (void (*)(void *))stmt_free);
-               array_free(stmt->data.sif.els, (void (*)(void *))stmt_free);
+               array_free(&stmt->data.sif.then, (void (*)(void *))stmt_free);
+               array_free(&stmt->data.sif.els, (void (*)(void *))stmt_free);
                break;
        case sreturn:
-               expr_free(stmt->data.sreturn);
+               if (stmt->data.sreturn != NULL)
+                       expr_free(stmt->data.sreturn);
                break;
        case sexpr:
                expr_free(stmt->data.sexpr);
                break;
        case swhile:
                expr_free(stmt->data.swhile.pred);
-               array_free(stmt->data.swhile.body, (void (*)(void *))stmt_free);
+               array_free(&stmt->data.swhile.body, (void (*)(void *))stmt_free);
                break;
        case svardecl:
                vardecl_free(stmt->data.svardecl);
@@ -616,8 +612,6 @@ void stmt_free(struct stmt *stmt)
 
 void expr_free(struct expr *expr)
 {
-       if (expr == NULL)
-               return;
        switch(expr->type) {
        case ebinop:
                expr_free(expr->data.ebinop.l);
@@ -633,7 +627,7 @@ void expr_free(struct expr *expr)
                free(expr->data.efuncall.ident);
                if (expr->data.efuncall.type != NULL)
                        type_free(expr->data.efuncall.type);
-               array_free(expr->data.efuncall.args, (void (*)(void *))expr_free);
+               array_free(&expr->data.efuncall.args, (void (*)(void *))expr_free);
                break;
        case eint:
                break;
diff --git a/ast.h b/ast.h
index 7e63b47..541b64f 100644 (file)
--- a/ast.h
+++ b/ast.h
@@ -29,7 +29,7 @@ struct fundecl {
        YYLTYPE loc;
        char *ident;
        struct array args; // char *
-       struct array atypes; // struct type *
+       struct array *atypes; // struct type *
        struct type *rtype;
        struct array body; //struct stmt *
 };
@@ -93,6 +93,7 @@ struct expr {
                        char *ident;
                        struct array args; // struct expr *
                        struct type *type; // type for overloaded functions
+                       struct fundecl *ref;
                } efuncall;
                int eint;
                char *eident;
@@ -114,7 +115,7 @@ struct expr {
 struct ast *ast(struct array decls, YYLTYPE l);
 
 struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr, YYLTYPE l);
-struct fundecl *fundecl(char *ident, struct array args, struct array atypes, struct type *rtype, struct array body, YYLTYPE l);
+struct fundecl *fundecl(char *ident, struct array args, struct array *atypes, struct type *rtype, struct array body, YYLTYPE l);
 
 struct decl *decl_fun(struct fundecl *fundecl, YYLTYPE l);
 struct decl *decl_var(struct vardecl *vardecl, YYLTYPE l);
diff --git a/gen.c b/gen.c
index 0ec9190..dfacb2b 100644 (file)
--- a/gen.c
+++ b/gen.c
@@ -37,7 +37,7 @@ static int type_cmpv(const void *l, const void *r)
 
 static void call_register(struct array *st, struct type *type)
 {
-       *st = array_binsert(type, *st, type_cmpv);
+       array_binsert(type, st, type_cmpv);
        switch(type->type) {
        case tlist:
                call_register(st, type->data.tlist);
@@ -63,7 +63,7 @@ static void ol_expr(struct overload *st, struct expr *expr)
        case efuncall:
                if (strcmp(expr->data.efuncall.ident, "print") == 0)
                        call_register(&st->print, expr->data.efuncall.type);
-               ARRAY_ITER(struct expr *, e, i, expr->data.efuncall.args)
+               ARRAY_ITER(struct expr *, e, i, &expr->data.efuncall.args)
                        ol_expr(st, e);
                AIEND
                break;
@@ -114,7 +114,7 @@ static void ol_stmt(struct overload *st, struct stmt *stmt)
 
 static void ol_body(struct overload *st, struct array body)
 {
-       ARRAY_ITER(struct stmt *, s, i, body)
+       ARRAY_ITER(struct stmt *, s, i, &body)
                ol_stmt(st, s);
        AIEND
 }
@@ -126,7 +126,7 @@ void gen(struct ast *res, enum lang lang, FILE *cout)
                struct decl *decl = res->decls[i];
                switch(decl->type) {
                case dcomp:
-                       ARRAY_ITER(struct fundecl *, d, i, decl->data.dcomp)
+                       ARRAY_ITER(struct fundecl *, d, i, &decl->data.dcomp)
                                ol_body(&st, d->body);
                        AIEND
                        break;
@@ -148,6 +148,6 @@ void gen(struct ast *res, enum lang lang, FILE *cout)
        default:
                die("unsupported language\n");
        }
-       array_free(st.print, NULL);
-       array_free(st.eq, NULL);
+       array_free(&st.print, NULL);
+       array_free(&st.eq, NULL);
 }
diff --git a/gen/c.c b/gen/c.c
index 6514471..663dbde 100644 (file)
--- a/gen/c.c
+++ b/gen/c.c
@@ -6,21 +6,18 @@
 #include "../sem.h"
 #include "../gen.h"
 
-struct gencst {
-       struct array printtypes; // struct type *
-};
-static void expr_genc(struct gencst *st, struct expr *expr, FILE *cout);
+static void expr_genc(struct expr *expr, FILE *cout);
 
-static void binop_genc(struct gencst *st, char *fun, struct expr *l, struct expr *r, FILE *cout)
+static void binop_genc(char *fun, struct expr *l, struct expr *r, FILE *cout)
 {
        safe_fprintf(cout, "%s(", fun);
-       expr_genc(st, l, cout);
+       expr_genc(l, cout);
        safe_fprintf(cout, ", ");
-       expr_genc(st, r, cout);
+       expr_genc(r, cout);
        safe_fprintf(cout, ")");
 }
 
-static void expr_genc(struct gencst *st, struct expr *expr, FILE *cout)
+static void expr_genc(struct expr *expr, FILE *cout)
 {
        char buf[] = "\\x55";
        if (expr == NULL)
@@ -33,21 +30,21 @@ static void expr_genc(struct gencst *st, struct expr *expr, FILE *cout)
                        safe_fprintf(cout, "eq_");
                        overloaded_type(expr->loc, expr->data.ebinop.type, cout);
                        safe_fprintf(cout, "(");
-                       expr_genc(st, expr->data.ebinop.l, cout);
+                       expr_genc(expr->data.ebinop.l, cout);
                        safe_fprintf(cout, ",");
-                       expr_genc(st, expr->data.ebinop.r, cout);
+                       expr_genc(expr->data.ebinop.r, cout);
                        safe_fprintf(cout, ")");
                } else if (expr->data.ebinop.op == cons) {
-                       binop_genc(st, "splc_cons", expr->data.ebinop.l,
+                       binop_genc("splc_cons", expr->data.ebinop.l,
                                expr->data.ebinop.r, cout);
                } else if (expr->data.ebinop.op == power) {
-                       binop_genc(st, "splc_power", expr->data.ebinop.l,
+                       binop_genc("splc_power", expr->data.ebinop.l,
                                expr->data.ebinop.r, cout);
                } else {
                        safe_fprintf(cout, "(");
-                       expr_genc(st, expr->data.ebinop.l, cout);
+                       expr_genc(expr->data.ebinop.l, cout);
                        safe_fprintf(cout, "%s", binop_str[expr->data.ebinop.op]);
-                       expr_genc(st, expr->data.ebinop.r, cout);
+                       expr_genc(expr->data.ebinop.r, cout);
                        safe_fprintf(cout, ")");
                }
                break;
@@ -66,9 +63,9 @@ static void expr_genc(struct gencst *st, struct expr *expr, FILE *cout)
                        safe_fprintf(cout, "%s", expr->data.efuncall.ident);
                }
                safe_fprintf(cout, "(");
-               ARRAY_ITER(struct expr *, e, i, expr->data.efuncall.args) {
-                       expr_genc(st, e, cout);
-                       if (i+1 < ARRAY_SIZE(expr->data.efuncall.args))
+               ARRAY_ITER(struct expr *, e, i, &expr->data.efuncall.args) {
+                       expr_genc(e, cout);
+                       if (i+1 < ARRAY_SIZE(&expr->data.efuncall.args))
                                safe_fprintf(cout, ", ");
                } AIEND
                safe_fprintf(cout, ")");
@@ -83,7 +80,7 @@ static void expr_genc(struct gencst *st, struct expr *expr, FILE *cout)
                safe_fprintf(cout, "NULL");
                break;
        case etuple:
-               binop_genc(st, "splc_tuple", expr->data.etuple.left,
+               binop_genc("splc_tuple", expr->data.etuple.left,
                        expr->data.etuple.right, cout);
                break;
        case estring:
@@ -95,7 +92,7 @@ static void expr_genc(struct gencst *st, struct expr *expr, FILE *cout)
                break;
        case eunop:
                safe_fprintf(cout, "(%s", unop_str[expr->data.eunop.op]);
-               expr_genc(st, expr->data.eunop.l, cout);
+               expr_genc(expr->data.eunop.l, cout);
                safe_fprintf(cout, ")");
                break;
        default:
@@ -131,18 +128,18 @@ static void type_genc(struct type *type, FILE *cout)
        }
 }
 
-static void vardecl_genc(struct gencst *st, struct vardecl *vardecl, int indent, FILE *cout)
+static void vardecl_genc(struct vardecl *vardecl, int indent, FILE *cout)
 {
        if (vardecl == NULL)
                return;
        pindent(indent, cout);
        type_genc(vardecl->type, cout);
        safe_fprintf(cout, "%s = ", vardecl->ident);
-       expr_genc(st, vardecl->expr, cout);
+       expr_genc(vardecl->expr, cout);
        safe_fprintf(cout, ";\n");
 }
 
-static void stmt_genc(struct gencst *st, struct stmt *stmt, int indent, FILE *cout)
+static void stmt_genc(struct stmt *stmt, int indent, FILE *cout)
 {
        if (stmt == NULL)
                return;
@@ -150,25 +147,25 @@ static void stmt_genc(struct gencst *st, struct stmt *stmt, int indent, FILE *co
        case sassign:
                pindent(indent, cout);
                safe_fprintf(cout, "%s", stmt->data.sassign.ident);
-               ARRAY_ITER(char *, f, i, stmt->data.sassign.fields)
+               ARRAY_ITER(char *, f, i, &stmt->data.sassign.fields)
                        safe_fprintf(cout, "->%s", f);
                AIEND
                safe_fprintf(cout, " = ");
-               expr_genc(st, stmt->data.sassign.expr, cout);
+               expr_genc(stmt->data.sassign.expr, cout);
                safe_fprintf(cout, ";\n");
                break;
        case sif:
                pindent(indent, cout);
                safe_fprintf(cout, "if (");
-               expr_genc(st, stmt->data.sif.pred, cout);
+               expr_genc(stmt->data.sif.pred, cout);
                safe_fprintf(cout, ") {\n");
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.sif.then)
-                       stmt_genc(st, s, indent+1, cout);
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.sif.then)
+                       stmt_genc(s, indent+1, cout);
                AIEND
                pindent(indent, cout);
                safe_fprintf(cout, "} else {\n");
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.sif.els)
-                       stmt_genc(st, s, indent+1, cout);
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.sif.els)
+                       stmt_genc(s, indent+1, cout);
                AIEND
                pindent(indent, cout);
                safe_fprintf(cout, "}\n");
@@ -176,24 +173,24 @@ static void stmt_genc(struct gencst *st, struct stmt *stmt, int indent, FILE *co
        case sreturn:
                pindent(indent, cout);
                safe_fprintf(cout, "return ");
-               expr_genc(st, stmt->data.sreturn, cout);
+               expr_genc(stmt->data.sreturn, cout);
                safe_fprintf(cout, ";\n");
                break;
        case sexpr:
                pindent(indent, cout);
-               expr_genc(st, stmt->data.sexpr, cout);
+               expr_genc(stmt->data.sexpr, cout);
                safe_fprintf(cout, ";\n");
                break;
        case svardecl:
-               vardecl_genc(st, stmt->data.svardecl, indent, cout);
+               vardecl_genc(stmt->data.svardecl, indent, cout);
                break;
        case swhile:
                pindent(indent, cout);
                safe_fprintf(cout, "while (");
-               expr_genc(st, stmt->data.swhile.pred, cout);
+               expr_genc(stmt->data.swhile.pred, cout);
                safe_fprintf(cout, ") {\n");
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.swhile.body)
-                       stmt_genc(st, s, indent+1, cout);
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.swhile.body)
+                       stmt_genc(s, indent+1, cout);
                AIEND
                pindent(indent, cout);
                safe_fprintf(cout, "}\n");
@@ -207,46 +204,46 @@ static void fundecl_sig(struct fundecl *decl, FILE *cout)
 {
        type_genc(decl->rtype, cout);
        safe_fprintf(cout, "%s (", decl->ident);
-       ARRAY_ITER(char *, a, i, decl->args) {
+       ARRAY_ITER(char *, a, i, &decl->args) {
                if (i >= ARRAY_SIZE(decl->atypes))
                        die("function with unmatched type\n");
                type_genc(ARRAY_EL(struct type *, decl->atypes, i), cout);
                safe_fprintf(cout, "%s", a);
-               if (i < ARRAY_SIZE(decl->args) - 1)
+               if (i < ARRAY_SIZE(&decl->args) - 1)
                        safe_fprintf(cout, ", ");
        } AIEND
        safe_fprintf(cout, ")");
 }
 
-static void fundecl_genc(struct gencst *st, struct fundecl *decl, FILE *cout)
+static void fundecl_genc(struct fundecl *decl, FILE *cout)
 {
        fundecl_sig(decl, cout);
        safe_fprintf(cout, "{\n");
-       ARRAY_ITER(struct stmt *, s, i, decl->body)
-               stmt_genc(st, s, 1, cout);
+       ARRAY_ITER(struct stmt *, s, i, &decl->body)
+               stmt_genc(s, 1, cout);
        AIEND
        safe_fprintf(cout, "}\n");
 }
 
-static void decl_genc(struct gencst *st, struct decl *decl, FILE *cout)
+static void decl_genc(struct decl *decl, FILE *cout)
 {
        switch (decl->type) {
        case dcomp:
-               if (ARRAY_SIZE(decl->data.dcomp) > 1) {
-                       ARRAY_ITER(struct fundecl *, d, i, decl->data.dcomp)
+               if (ARRAY_SIZE(&decl->data.dcomp) > 1) {
+                       ARRAY_ITER(struct fundecl *, d, i, &decl->data.dcomp)
                                fundecl_sig(d, cout);
                                safe_fprintf(cout, ";\n");
                        AIEND
                }
-               ARRAY_ITER(struct fundecl *, d, i, decl->data.dcomp)
-                       fundecl_genc(st, d, cout);
+               ARRAY_ITER(struct fundecl *, d, i, &decl->data.dcomp)
+                       fundecl_genc(d, cout);
                AIEND
                break;
        case dfundecl:
                die("fundecls should be gone by now\n");
                break;
        case dvardecl:
-               vardecl_genc(st, decl->data.dvar, 0, cout);
+               vardecl_genc(decl->data.dvar, 0, cout);
                break;
        }
 }
@@ -278,7 +275,7 @@ static void generate_eq(struct type *type, FILE *cout)
        case ttuple:
                safe_fprintf(cout, "\treturn eq_");
                overloaded_type(loc, type->data.ttuple.l, cout);
-               safe_fprintf(cout, "(x->fst, y->fst)");
+               safe_fprintf(cout, "(x->fy->fst)");
                safe_fprintf(cout, " && eq_");
                overloaded_type(loc, type->data.ttuple.r, cout);
                safe_fprintf(cout, "(x->snd, y->snd);");
@@ -331,23 +328,22 @@ static void generate_print(struct type *type, FILE *cout)
        safe_fprintf(cout, "}\n");
 }
 
-void genc(struct ast *ast, struct overload ol, FILE *cout)
+void genc(const struct ast *ast, const struct overload ol, FILE *cout)
 {
        //Header
        safe_fprintf(cout, "#include \"rts.h\"\n");
 
        //Overloaded functions
-       ARRAY_ITER(struct type *, t, i, ol.print) {
+       ARRAY_ITER(struct type *, t, i, &ol.print) {
                generate_print(t, cout);
        } AIEND
-       ARRAY_ITER(struct type *, t, i, ol.eq) {
+       ARRAY_ITER(struct type *, t, i, &ol.eq) {
                generate_eq(t, cout);
        } AIEND
 
        //Code
-       struct gencst st = {.printtypes = array_null};
        for (int i = 0; i<ast->ndecls; i++) {
                safe_fprintf(cout, "\n");
-               decl_genc(&st, ast->decls[i], cout);
+               decl_genc(ast->decls[i], cout);
        }
 }
diff --git a/gen/c.h b/gen/c.h
index 4c8ef90..7b400b2 100644 (file)
--- a/gen/c.h
+++ b/gen/c.h
@@ -6,6 +6,6 @@
 #include "../ast.h"
 #include "../gen.h"
 
-void genc(struct ast *res, struct overload ol, FILE *cout);
+void genc(const struct ast *res, const struct overload ol, FILE *cout);
 
 #endif
index b4f7aaa..2e37d57 100644 (file)
--- a/gen/ssm.c
+++ b/gen/ssm.c
@@ -3,9 +3,19 @@
 #include "../ast.h"
 #include "../sem.h"
 #include "../gen.h"
+#include <uthash.h>
 
+enum vreftype {global, arg, local};
+struct vref {
+       const char *id;
+       enum vreftype type;
+       int num;
+       UT_hash_handle hh;
+};
 struct genssmst {
        int fresh;
+       int vdecl;
+       struct vref *refs;
 };
 
 static const char *unop_instr[] = { [inverse] = "not", [negate] = "neg" };
@@ -16,6 +26,15 @@ static const char *binop_instr[] = {
        [modulo] = "mod", [power] = "^",
 };
 
+static void add_vref(struct genssmst *st, char *id, enum vreftype type, int no)
+{
+       struct vref *r = xalloc(1, struct vref);
+       r->id = id;
+       r->type = type;
+       r->num = no;
+       HASH_ADD_KEYPTR(hh, st->refs, r->id, strlen(r->id), r);
+}
+
 static void generate_eq(struct type *type, FILE *cout)
 {
        YYLTYPE loc;
@@ -216,6 +235,7 @@ static void call_print_type(YYLTYPE loc, struct type *type, FILE *cout)
 //}
 static void expr_genssm(struct genssmst *st, struct expr *expr, FILE *cout)
 {
+       struct vref *el;
        switch(expr->type) {
        case ebinop:
                expr_genssm(st, expr->data.ebinop.l, cout);
@@ -246,26 +266,41 @@ static void expr_genssm(struct genssmst *st, struct expr *expr, FILE *cout)
                safe_fprintf(cout, "ldc %d\n", expr->data.echar);
                break;
        case efuncall:
-               ARRAY_ITER(struct expr *, e, i, expr->data.efuncall.args)
+               ARRAY_ITER(struct expr *, e, i, &expr->data.efuncall.args)
                        expr_genssm(st, e, cout);
                AIEND
                if (strcmp(expr->data.efuncall.ident, "print") == 0) {
                        safe_fprintf(cout, "bsr print");
-//                     call_print_register(st, expr->data.efuncall.type);
                        call_print_type(expr->loc, expr->data.efuncall.type, cout);
                        safe_fprintf(cout, "\n");
                } else {
                        safe_fprintf(cout, "bsr %s\n", expr->data.efuncall.ident);
                        safe_fprintf(cout, "ajs -%u\n",
-                               ARRAY_SIZE(expr->data.efuncall.args));
+                               ARRAY_SIZE(&expr->data.efuncall.args));
+                       //TODO don't do this when the function returns void
+                       safe_fprintf(cout, "ldr RR\n");
                }
                break;
        case eint:
                safe_fprintf(cout, "ldc %d\n", expr->data.eint);
                break;
-       //case eident:
-       //      safe_fprintf(cout, "%s", expr->data.eident);
-       //      break;
+       case eident:
+               HASH_FIND_STR(st->refs, expr->data.eident, el);
+               if (el == NULL)
+                       die("unknown variable: %s???", expr->data.eident);
+               switch(el->type) {
+               case global:
+                       safe_fprintf(cout, "ldr R5\n");
+                       safe_fprintf(cout, "lda %d\n", el->num);
+                       break;
+               case arg:
+                       safe_fprintf(cout, "ldl -%d\n", el->num);
+                       break;
+               case local:
+                       safe_fprintf(cout, "ldl %d\n", el->num);
+                       break;
+               }
+               break;
        case enil:
                safe_fprintf(cout, "ldc 0\n");
                break;
@@ -293,7 +328,7 @@ static void expr_genssm(struct genssmst *st, struct expr *expr, FILE *cout)
 static void stmt_genssm(struct genssmst *st, struct stmt *stmt, FILE *cout);
 static void body_genssm(struct genssmst *st, struct array body, FILE *cout)
 {
-       ARRAY_ITER(struct stmt *, s, i, body)
+       ARRAY_ITER(struct stmt *, s, i, &body)
                stmt_genssm(st, s, cout);
        AIEND
 }
@@ -331,7 +366,8 @@ static void stmt_genssm(struct genssmst *st, struct stmt *stmt, FILE *cout)
                safe_fprintf(cout, "ajs -1\n");
                break;
        case svardecl:
-//             vardecl_genc(st, stmt->data.svardecl, indent, cout);
+               expr_genssm(st, stmt->data.svardecl->expr, cout);
+               add_vref(st, stmt->data.svardecl->ident, local, st->vdecl++);
                break;
        case swhile:
                safe_fprintf(cout, "_while%d: \n", st->fresh);
@@ -346,39 +382,68 @@ static void stmt_genssm(struct genssmst *st, struct stmt *stmt, FILE *cout)
        }
 }
 
+//static void vardecl_genssm(struct genssmst *st, struct vardecl *vardecl, FILE *cout)
+//{
+//     //TODO add to dictionary
+//     expr_genssm(st, vardecl->expr, cout);
+//}
 
-static void vardecl_genssm(struct genssmst *st, struct vardecl *vardecl, FILE *cout)
+static void global_genssm(int no, struct genssmst *st, struct vardecl *vardecl, FILE *cout)
 {
-       //TODO add to dictionary
+       add_vref(st, vardecl->ident, global, no+1);
        expr_genssm(st, vardecl->expr, cout);
 }
 
+static int count_locals(struct array stmts)
+{
+       int r = 0;
+       ARRAY_ITER(struct stmt *, s, i, &stmts) {
+               switch(s->type) {
+               case sif:
+                       r += count_locals(s->data.sif.then);
+                       r += count_locals(s->data.sif.els);
+                       break;
+               case svardecl:
+                       r++;
+                       break;
+               case swhile:
+                       r += count_locals(s->data.swhile.body);
+                       break;
+               default:
+                       break;
+               }
+       } AIEND
+       return r;
+}
+
 static void fundecl_genssm(struct genssmst *st, struct fundecl *decl, FILE *cout)
 {
-       safe_fprintf(cout, "%s: link 0\n", decl->ident);
-       //TODO add args to dictionary
+       safe_fprintf(cout, "%s: link %d\n", decl->ident, count_locals(decl->body));
+       ARRAY_ITER(char *, a, i, &decl->args) {
+               add_vref(st, a, arg, ARRAY_SIZE(&decl->args)-i+1);
+       } AIEND
        body_genssm(st, decl->body, cout);
        safe_fprintf(cout, "unlink\n");
        safe_fprintf(cout, "ret\n");
 }
 
-void genssm(struct ast *ast, struct overload ol, FILE *cout)
+void genssm(const struct ast *ast, const struct overload ol, FILE *cout)
 {
        //Header
        safe_fprintf(cout, "ldrr R5 R1\n");
-       struct genssmst st = { .fresh=0 };
+       struct genssmst st = { .fresh=0, .refs=NULL };
        for (int i = 0; i<ast->ndecls; i++)
                if (ast->decls[i]->type == dvardecl)
-                       vardecl_genssm(&st, ast->decls[i]->data.dvar, cout);
+                       global_genssm(i, &st, ast->decls[i]->data.dvar, cout);
        safe_fprintf(cout, "bsr main\n");
        safe_fprintf(cout, "halt\n");
 
        //Generate overloaded functions
-       ARRAY_ITER(struct type *, t, i, ol.print)
+       ARRAY_ITER(struct type *, t, i, &ol.print)
                if (t->type != tbasic)
                        generate_print(t, cout);
        AIEND
-       ARRAY_ITER(struct type *, t, i, ol.eq)
+       ARRAY_ITER(struct type *, t, i, &ol.eq)
                if (t->type != tbasic)
                        generate_eq(t, cout);
        AIEND
@@ -386,9 +451,18 @@ void genssm(struct ast *ast, struct overload ol, FILE *cout)
        //Generate code
        for (int i = 0; i<ast->ndecls; i++)
                if (ast->decls[i]->type == dcomp)
-                       ARRAY_ITER(struct fundecl *, d, j, ast->decls[i]->data.dcomp)
+                       ARRAY_ITER(struct fundecl *, d, j, &ast->decls[i]->data.dcomp)
                                fundecl_genssm(&st, d, cout);
                        AIEND
+
+       //Free again
+       struct vref *el, *tmp;
+       HASH_ITER(hh, st.refs, el, tmp) {
+               HASH_DEL(st.refs, el);
+               free(el);
+       }
+
+       //Include rts
        FILE *rts = fopen("rts.ssm", "r");
        if (rts == NULL)
                pdie("fopen");
index 93e5247..64ae852 100644 (file)
--- a/gen/ssm.h
+++ b/gen/ssm.h
@@ -6,6 +6,6 @@
 #include "../ast.h"
 #include "../gen.h"
 
-void genssm(struct ast *res, struct overload ol, FILE *cout);
+void genssm(const struct ast *res, const struct overload ol, FILE *cout);
 
 #endif
diff --git a/parse.y b/parse.y
index 8d53296..8008a0e 100644 (file)
--- a/parse.y
+++ b/parse.y
@@ -35,6 +35,7 @@ int yywrap()
        struct expr *expr;
        struct stmt *stmt;
        struct array array;
+       struct array *arrayp;
        struct vardecl *vardecl;
        struct fundecl *fundecl;
        struct type *type;
@@ -57,7 +58,8 @@ int yywrap()
 
 %type <ast> start
 %type <expr> expr tuply
-%type <array> decls funtype args nargs body bbody fargs fnargs field
+%type <array> decls args nargs body bbody fargs fnargs field
+%type <arrayp> funtype
 %type <stmt> stmt
 %type <type> type ftype
 %type <vardecl> vardecl
@@ -68,8 +70,8 @@ int yywrap()
 start : decls { *result = ast($1, @$); } ;
 decls
        : /* empty */ { $$ = array_null; }
-       | decls vardecl { $$ = array_append($1, decl_var($2, @2)); }
-       | decls fundecl { $$ = array_append($1, decl_fun($2, @2)); }
+       | decls vardecl { array_append(&$1, decl_var($2, @2)); $$ = $1; }
+       | decls fundecl { array_append(&$1, decl_fun($2, @2)); $$ = $1; }
        ;
 vardecl
        : VAR IDENT ASSIGN expr SEMICOLON { $$ = vardecl(NULL, $2, $4, @$); }
@@ -77,13 +79,13 @@ vardecl
        ;
 fundecl
        : IDENT BOPEN args BCLOSE bbody
-               { $$ = fundecl($1, $3, array_null, NULL, $5, @$); }
+               { $$ = fundecl($1, $3, NULL, NULL, $5, @$); }
        | IDENT BOPEN args BCLOSE CONS CONS funtype ARROW ftype bbody
                { $$ = fundecl($1, $3, $7, $9, $10, @$); }
        ;
 funtype
-       : /* empty */ { $$ = array_null; }
-       | funtype ftype { $$ = array_append($1, $2); }
+       : /* empty */ { $$ = array_new(8); }
+       | funtype ftype { array_append($1, $2); $$ = $1; }
        ;
 /* don't allow vardecls to be fully polymorph, this complicates parsing a lot */
 type
@@ -103,28 +105,28 @@ args
        | nargs
        ;
 nargs
-       : nargs COMMA IDENT { $$ = array_append($1, $3); }
-       | IDENT { array_init(&$$, 8); $$ = array_append($$, $1); }
+       : nargs COMMA IDENT { array_append(&$1, $3); $$ = $1; }
+       | IDENT { array_init(&$$, 8); array_append(&$$, $1); }
        ;
 fargs
        : /* empty */ { $$ = array_null; }
        | fnargs
        ;
 fnargs
-       : fnargs COMMA expr { $$ = array_append($1, $3); }
-       | expr { array_init(&$$, 8); $$ = array_append($$, $1); }
+       : fnargs COMMA expr { array_append(&$1, $3); $$ = $1; }
+       | expr { array_init(&$$, 8); array_append(&$$, $1); }
        ;
 body
        : /* empty */ { $$ = array_null; }
-       | body stmt { $$ = array_append($1, $2); }
+       | body stmt { array_append(&$1, $2); $$ = $1; }
        ;
 field
        : /* empty */ { $$ = array_null; }
-       | field DOT IDENT { $$ = array_append($1, $3); }
+       | field DOT IDENT { array_append(&$1, $3); $$ = $1; }
        ;
 bbody
        : COPEN body CCLOSE { $$ = $2; }
-       | stmt { array_init(&$$, 1); $$ = array_append($$, $1); }
+       | stmt { array_init(&$$, 1); array_append(&$$, $1); }
        ;
 stmt
        : IF BOPEN expr BCLOSE bbody { $$ = stmt_if($3, $5, array_null, @$); }
diff --git a/sem.c b/sem.c
index 78cf747..181ea84 100644 (file)
--- a/sem.c
+++ b/sem.c
@@ -1,12 +1,9 @@
-#include <stdlib.h>
-#include <string.h>
-#include <stdint.h>
-
-#include "list.h"
+#include "sem/constant.h"
+#include "sem/main.h"
+#include "sem/return.h"
 #include "sem/scc.h"
-#include "sem/hm/scheme.h"
-#include "sem/hm/gamma.h"
-#include "ast.h"
+#include "sem/type.h"
+#include "sem/vardecl.h"
 
 void type_error(YYLTYPE l, bool d, const char *msg, ...)
 {
@@ -14,306 +11,31 @@ void type_error(YYLTYPE l, bool d, const char *msg, ...)
        va_start(ap, msg);
        safe_fprintf(stderr, "Type error\n%d-%d: ", l.first_line, l.first_column);
        safe_vfprintf(stderr, msg, ap);
+       safe_fprintf(stderr, "\n");
        va_end(ap);
        if (d)
                die("");
 }
 
-static void check_expr_constant(struct expr *expr)
-{
-       switch (expr->type) {
-       case ebinop:
-               check_expr_constant(expr->data.ebinop.l);
-               check_expr_constant(expr->data.ebinop.r);
-               break;
-       case eunop:
-               check_expr_constant(expr->data.eunop.l);
-               break;
-       case efuncall:
-       case eident:
-               type_error(expr->loc, true,
-                       "Initialiser is not constant (identifier used)\n");
-               break;
-       default:
-               break;
-       }
-}
-
-static struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl)
-{
-       struct type *t = vardecl->type == NULL
-               ? gamma_fresh(gamma) : vardecl->type;
-       struct subst *s = infer_expr(gamma, vardecl->expr, t);
-
-       vardecl->type = subst_apply_t(s, t);
-       gamma_insert(gamma, ident_str(vardecl->ident), scheme_create(vardecl->type));
-
-       subst_free(s);
-
-       return vardecl;
-}
-
-static void patch_overload_stmt(struct subst *subst, struct stmt *s);
-static void patch_overload_body(struct subst *subst, struct array body)
-{
-       ARRAY_ITER(struct stmt *, s, j, body)
-               patch_overload_stmt(subst, s);
-       AIEND
-}
-
-static void patch_overload_expr(struct subst *subst, struct expr *expr)
-{
-       if (expr == NULL)
-               return;
-       switch (expr->type) {
-       case ebinop:
-               if (expr->data.ebinop.op == eq || expr->data.ebinop.op == neq)
-                       expr->data.ebinop.type = subst_apply_t(subst,
-                               expr->data.ebinop.type);
-               patch_overload_expr(subst, expr->data.ebinop.l);
-               patch_overload_expr(subst, expr->data.ebinop.r);
-               break;
-       case efuncall:
-               if (strcmp(expr->data.efuncall.ident, "print") == 0)
-                       expr->data.efuncall.type = subst_apply_t(subst,
-                               expr->data.efuncall.type);
-               break;
-       case etuple:
-               patch_overload_expr(subst, expr->data.etuple.left);
-               patch_overload_expr(subst, expr->data.etuple.right);
-               break;
-       default:
-               break;
-       }
-
-}
-static void patch_overload_stmt(struct subst *subst, struct stmt *stmt)
-{
-       switch (stmt->type) {
-       case sassign:
-               patch_overload_expr(subst, stmt->data.sassign.expr);
-               break;
-       case sif:
-               patch_overload_expr(subst, stmt->data.sif.pred);
-               patch_overload_body(subst, stmt->data.sif.then);
-               patch_overload_body(subst, stmt->data.sif.els);
-               break;
-       case sreturn:
-               patch_overload_expr(subst, stmt->data.sreturn);
-               break;
-       case sexpr:
-               patch_overload_expr(subst, stmt->data.sexpr);
-               break;
-       case svardecl:
-               stmt->data.svardecl->type = subst_apply_t(subst,
-                       stmt->data.svardecl->type);
-               patch_overload_expr(subst, stmt->data.svardecl->expr);
-               break;
-       case swhile:
-               patch_overload_expr(subst, stmt->data.swhile.pred);
-               patch_overload_body(subst, stmt->data.swhile.body);
-               break;
-       }
-}
-
-static void type_comp(struct gamma *gamma, struct array decl)
-{
-       //Create a fresh variable for every function in the component
-       struct type **fs = xalloc(ARRAY_SIZE(decl), struct type *);
-       ARRAY_ITER(struct fundecl *, d, i, decl) {
-               fs[i] = gamma_fresh(gamma);
-               ARRAY_ITERI(j, d->args) {
-                       struct type *a = gamma_fresh(gamma);
-                       fs[i] = type_arrow(a, fs[i]);
-               }
-               gamma_insert(gamma, ident_str(d->ident), scheme_create(fs[i]));
-       } AIEND
-
-       //Infer each function
-       struct subst *s0 = subst_id();
-       ARRAY_ITERI(i, decl) {
-               struct subst *s1 = infer_fundecl(gamma,
-                       ARRAY_EL(struct fundecl *, decl, i),
-                       subst_apply_t(s0, fs[i]));
-               s0 = subst_union(s1, s0);
-       }
-
-       //Generalise all functions and put in gamma
-       ARRAY_ITER(struct fundecl *, d, i, decl) {
-               struct type *t = subst_apply_t(s0, fs[i]);
-
-               //unify against given type specification
-               if (d->rtype != NULL) {
-                       struct type *dt = d->rtype;
-                       for (int j = (int)ARRAY_SIZE(d->atypes)-1; j>=0; j--)
-                               dt = type_arrow(ARRAY_EL(struct type *,
-                                       d->atypes, j), dt);
-                       struct subst *s1 = unify(d->loc, dt, t);
-                       subst_apply_t(s1, fs[i]);
-                       subst_free(s1);
-                       type_free(dt);
-               }
-
-               gamma_insert(gamma, ident_str(d->ident), scheme_generalise(gamma, t));
-
-               //Put the type in the ast
-               d->atypes = array_clean(d->atypes, NULL);
-               d->atypes = array_resize(d->atypes, ARRAY_SIZE(d->args));
-
-               ARRAY_ITERI(j, d->args) {
-                       d->atypes = array_append(d->atypes, (void *)type_dup(t->data.tarrow.l));
-                       t = t->data.tarrow.r;
-               }
-               d->rtype = type_dup(t);
-       } AIEND
-
-       //Free all types
-       for (size_t i = 0; i<ARRAY_SIZE(decl); i++)
-               type_free(fs[i]);
-       free(fs);
-
-       //Patch all overloaded functions
-       ARRAY_ITER(struct fundecl *, d, i, decl) {
-               patch_overload_body(s0, d->body);
-       } AIEND
-       subst_free(s0);
-}
-
-static void gamma_preamble(struct gamma *gamma)
-{
-       struct type *t = type_arrow(type_tuple(type_var_str("a")
-               , type_var_str("b")) ,type_var_str("a"));
-       gamma_insert(gamma, ident_str("fst"), scheme_generalise(gamma, t));
-       type_free(t);
-
-       t = type_arrow(type_tuple(type_var_str("a")
-               , type_var_str("b")) ,type_var_str("b"));
-       gamma_insert(gamma, ident_str("snd"), scheme_generalise(gamma, t));
-       type_free(t);
-
-       t = type_arrow(type_list(type_var_str("a")),
-                       type_var_str("a"));
-       gamma_insert(gamma, ident_str("hd"), scheme_generalise(gamma, t));
-       type_free(t);
-
-       t = type_arrow(type_list(type_var_str("a")),
-                       type_list(type_var_str("a")));
-       gamma_insert(gamma, ident_str("tl"), scheme_generalise(gamma, t));
-       type_free(t);
-
-       t = type_arrow(type_list(type_var_str("a")),
-                       type_basic(btbool));
-       gamma_insert(gamma, ident_str("isEmpty"), scheme_generalise(gamma, t));
-       type_free(t);
-
-       t = type_arrow(type_var_str("a"), type_basic(btvoid));
-       gamma_insert(gamma, ident_str("print"), scheme_generalise(gamma, t));
-       type_free(t);
-}
-
-static bool check_return_stmt(struct stmt *stmt);
-static bool check_return_body(struct array body)
-{
-       ARRAY_ITER(struct stmt *, s, i, body)
-               if (check_return_stmt(s))
-                       return true;
-       AIEND
-       return false;
-}
-
-
-static bool check_return_stmt(struct stmt *stmt)
-{
-       switch (stmt->type) {
-       case sassign:
-               return false;
-       case sif:
-               return check_return_body(stmt->data.sif.then)
-                       && check_return_body(stmt->data.sif.els);
-       case swhile:
-               return check_return_body(stmt->data.swhile.body);
-       case sreturn:
-               return true;
-       default:
-               return false;
-       }
-}
-
-static void check_return_comp(struct array decl)
-{
-       ARRAY_ITER(struct fundecl *, d, i, decl) {
-               if (d->rtype->type == tbasic && d->rtype->data.tbasic == btvoid)
-                       continue;
-               if (!check_return_body(d->body))
-                       type_error(d->loc, true,
-                               "%s doesn't return properly", d->ident);
-       } AIEND
-}
-
-static void add_return_if_none(struct array decl)
-{
-       ARRAY_ITER(struct fundecl *, d, i, decl)
-               if (d->rtype == NULL && !check_return_body(d->body))
-                       d->body = array_append(d->body, stmt_return(NULL, d->loc));
-       AIEND
-}
-
-static bool checkmain (struct fundecl *d)
-{
-       if (strcmp(d->ident, "main") == 0) {
-               if (ARRAY_SIZE(d->args) != 0)
-                       type_error(d->loc, true, "main cannot have arguments");
-               if (d->rtype->type != tbasic || d->rtype->data.tbasic != btvoid)
-                       type_error(d->loc, true, "main must return void");
-               return true;
-       }
-       return false;
-}
-
 struct ast *sem(struct ast *ast)
 {
        //Break up into strongly connected components
-       ast = ast_scc(ast);
+       sem_check_scc(ast);
+
+       //Check whether all globals are constant
+       sem_check_constant(ast);
+
+       // Check that all functions return and mark void
+       sem_check_return(ast);
 
-       struct gamma *gamma = gamma_init();
-       gamma_preamble(gamma);
+       // Check the types
+       sem_check_types(ast);
 
-       //Check all vardecls
-       for (int i = 0; i<ast->ndecls; i++) {
-               struct decl *decl = ast->decls[i];
-               switch(decl->type) {
-               case dvardecl:
-                       //Check if constant
-                       check_expr_constant(decl->data.dvar->expr);
-                       //Infer if necessary
-                       type_vardecl(gamma, decl->data.dvar);
-                       break;
-               case dcomp:
-                       //Infer function as singleton component
-                       add_return_if_none(decl->data.dcomp);
-                       type_comp(gamma, decl->data.dcomp);
-                       check_return_comp(decl->data.dcomp);
-                       break;
-               case dfundecl:
-                       die("fundecls should be gone by now\n");
-                       break;
-               }
-       }
-       gamma_free(gamma);
+       // Check that a main function exists with the correct type
+       sem_check_main(ast);
 
-       //Check for a main function
-       bool found = false;
-       for (int i = 0; i<ast->ndecls && !found; i++) {
-               struct decl *decl = ast->decls[i];
-               if (decl->type == dcomp) {
-                       ARRAY_ITER(struct fundecl *, d, i, decl->data.dcomp) {
-                               if ((found = checkmain(d)))
-                                       break;
-                       } AIEND
-               }
-       }
-       if (!found)
-               type_error(ast->loc, true, "no main function found\n");
+       // Move all vardecls to the top of the function
+       sem_check_vardecls(ast);
 
        return ast;
 }
diff --git a/sem/constant.c b/sem/constant.c
new file mode 100644 (file)
index 0000000..e827e10
--- /dev/null
@@ -0,0 +1,29 @@
+#include "../sem.h"
+
+static void check_expr_constant(struct expr *expr)
+{
+       switch (expr->type) {
+       case ebinop:
+               check_expr_constant(expr->data.ebinop.l);
+               check_expr_constant(expr->data.ebinop.r);
+               break;
+       case eunop:
+               check_expr_constant(expr->data.eunop.l);
+               break;
+       case efuncall:
+       case eident:
+               type_error(expr->loc, true,
+                       "Initialiser is not constant (identifier used)\n");
+               break;
+       default:
+               break;
+       }
+}
+
+
+void sem_check_constant(struct ast *ast)
+{
+       for (int i = 0; i<ast->ndecls; i++)
+               if (ast->decls[i]->type == dvardecl)
+                       check_expr_constant(ast->decls[i]->data.dvar->expr);
+}
diff --git a/sem/constant.h b/sem/constant.h
new file mode 100644 (file)
index 0000000..bdaf5b5
--- /dev/null
@@ -0,0 +1,8 @@
+#ifndef SEM_CONSTANT_H
+#define SEM_CONSTANT_H
+
+#include "../ast.h"
+
+void sem_check_constant(struct ast *ast);
+
+#endif
index f1aa557..097010d 100644 (file)
--- a/sem/hm.c
+++ b/sem/hm.c
@@ -140,7 +140,7 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty
                struct type *t = ft;
                s0 = subst_id();
                //Infer args
-               ARRAY_ITER(struct expr *, a, i, expr->data.efuncall.args) {
+               ARRAY_ITER(struct expr *, a, i, &expr->data.efuncall.args) {
                        if (t->type != tarrow)
                                type_error(expr->loc, true,
                                        "too many arguments to %s\n",
@@ -203,7 +203,7 @@ static struct subst *infer_body(struct gamma *gamma, struct array stmts, struct
 {
        gamma_increment_scope(gamma);
        struct subst *s0 = subst_id(), *s1;
-       ARRAY_ITER(struct stmt *, s, i, stmts) {
+       ARRAY_ITER(struct stmt *, s, i, &stmts) {
                s1 = infer_stmt(gamma, s, type);
                s0 = subst_union(s1, s0);
                subst_apply_g(s0, gamma);
@@ -225,7 +225,7 @@ struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *ty
                f1 = scheme_instantiate(gamma, s);
 
                s0 = subst_id();
-               ARRAY_ITER(char *, f, i, stmt->data.sassign.fields) {
+               ARRAY_ITER(char *, f, i, &stmt->data.sassign.fields) {
                        if (strcmp(f, "fst") == 0) {
                                //f1 is the type of the variable in gamma
                                struct type *fl = gamma_fresh(gamma);
@@ -328,7 +328,7 @@ struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct
        // Put arguments in gamma
        gamma_increment_scope(gamma);
        struct type *at = ftype;
-       ARRAY_ITER(char *, a, i, fundecl->args) {
+       ARRAY_ITER(char *, a, i, &fundecl->args) {
                if (at->type != tarrow)
                        die("malformed ftype\n");
                gamma_insert(gamma, ident_str(a), scheme_create(at->data.tarrow.l));
@@ -338,7 +338,7 @@ struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct
                die("malformed ftype\n");
 
        struct subst *s = subst_id();
-       ARRAY_ITER(struct stmt *, st, i, fundecl->body) {
+       ARRAY_ITER(struct stmt *, st, i, &fundecl->body) {
                struct subst *s1 = infer_stmt(gamma, st, at);
                s = subst_union(s1, s);
        } AIEND
diff --git a/sem/main.c b/sem/main.c
new file mode 100644 (file)
index 0000000..2a49a4a
--- /dev/null
@@ -0,0 +1,31 @@
+#include <string.h>
+
+#include "../array.h"
+#include "../sem.h"
+
+static bool checkmain (struct fundecl *d)
+{
+       if (strcmp(d->ident, "main") == 0) {
+               if (ARRAY_SIZE(&d->args) != 0)
+                       type_error(d->loc, true, "main cannot have arguments");
+               if (d->rtype->type != tbasic || d->rtype->data.tbasic != btvoid)
+                       type_error(d->loc, true, "main must return void");
+               return true;
+       }
+       return false;
+}
+
+void sem_check_main(struct ast *ast)
+{
+       bool found = false;
+       for (int i = 0; i<ast->ndecls && !found; i++) {
+               struct decl *decl = ast->decls[i];
+               if (decl->type == dcomp) {
+                       ARRAY_ITER(struct fundecl *, d, i, &decl->data.dcomp) {
+                               if ((found = checkmain(d)))
+                                       return;
+                       } AIEND
+               }
+       }
+       type_error(ast->loc, true, "no main function found\n");
+}
diff --git a/sem/main.h b/sem/main.h
new file mode 100644 (file)
index 0000000..63a6cc0
--- /dev/null
@@ -0,0 +1,8 @@
+#ifndef SEM_MAIN_H
+#define SEM_MAIN_H
+
+#include "../ast.h"
+
+void sem_check_main(struct ast *ast);
+
+#endif
diff --git a/sem/return.c b/sem/return.c
new file mode 100644 (file)
index 0000000..63df8b9
--- /dev/null
@@ -0,0 +1,64 @@
+#include "../sem.h"
+
+static bool check_return_stmt(struct stmt *stmt);
+static bool check_return_body(struct array body)
+{
+       ARRAY_ITER(struct stmt *, s, i, &body)
+               if (check_return_stmt(s))
+                       return true;
+       AIEND
+       return false;
+}
+
+static bool check_return_stmt(struct stmt *stmt)
+{
+       switch (stmt->type) {
+       case sassign:
+               return false;
+       case sif:
+               return check_return_body(stmt->data.sif.then)
+                       && check_return_body(stmt->data.sif.els);
+       case swhile:
+               return check_return_body(stmt->data.swhile.body);
+       case sreturn:
+               return true;
+       default:
+               return false;
+       }
+}
+
+static void check_return_comp(struct array decl)
+{
+       ARRAY_ITER(struct fundecl *, d, i, &decl) {
+               //If the function has no type, set return to void if there is a
+               //return and otherwise do nothing
+               if (d->rtype == NULL) {
+                       if (!check_return_body(d->body))
+                               d->rtype = type_basic(btvoid);
+               //If the function is typed
+               } else {
+                       //If the function is typed as void
+                       if (d->rtype->type == tbasic
+                                       && d->rtype->data.tbasic == btvoid) {
+                               if (!check_return_body(d->body))
+                                       array_append(&d->body,
+                                               stmt_return(NULL, d->loc));
+                       // If the function is some other type
+                       } else {
+                               if (!check_return_body(d->body))
+                                       type_error(d->loc, true,
+                                               "%s doesn't return properly",
+                                               d->ident);
+                       }
+               }
+       } AIEND
+}
+
+void sem_check_return(struct ast *ast)
+{
+       for (int i = 0; i<ast->ndecls; i++) {
+               if (ast->decls[i]->type == dcomp) {
+                       check_return_comp(ast->decls[i]->data.dcomp);
+               }
+       }
+}
diff --git a/sem/return.h b/sem/return.h
new file mode 100644 (file)
index 0000000..7ccc0b9
--- /dev/null
@@ -0,0 +1,8 @@
+#ifndef SEM_RETURN_H
+#define SEM_RETURN_H
+
+#include "../ast.h"
+
+void sem_check_return(struct ast *ast);
+
+#endif
index 05c527b..10681c0 100644 (file)
--- a/sem/scc.c
+++ b/sem/scc.c
@@ -162,7 +162,7 @@ struct array edges_expr(int ndecls, struct decl **decls, void *parent,
        case echar:
                break;
        case efuncall:
-               ARRAY_ITER(struct expr *, e, i, expr->data.efuncall.args )
+               ARRAY_ITER(struct expr *, e, i, &expr->data.efuncall.args )
                        l = edges_expr(ndecls, decls, parent, e, l);
                AIEND
                bool found = false;
@@ -172,7 +172,7 @@ struct array edges_expr(int ndecls, struct decl **decls, void *parent,
                                struct edge *edge = xalloc(1, struct edge);
                                edge->from = parent;
                                edge->to = (void *)decls[i];
-                               l = array_append(l, edge);
+                               array_append(&l, edge);
                                found = true;
                        }
                }
@@ -212,10 +212,10 @@ struct array edges_stmt(int ndecls, struct decl **decls, void *parent,
                break;
        case sif:
                l = edges_expr(ndecls, decls, parent, stmt->data.sif.pred, l);
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.sif.then)
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.sif.then)
                        l = edges_stmt(ndecls, decls, parent, s, l);
                AIEND
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.sif.els)
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.sif.els)
                        l = edges_stmt(ndecls, decls, parent, s, l);
                AIEND
                break;
@@ -232,7 +232,7 @@ struct array edges_stmt(int ndecls, struct decl **decls, void *parent,
        case swhile:
                l = edges_expr(ndecls, decls, parent,
                        stmt->data.swhile.pred, l);
-               ARRAY_ITER(struct stmt *, s, i, stmt->data.swhile.body)
+               ARRAY_ITER(struct stmt *, s, i, &stmt->data.swhile.body)
                        l = edges_stmt(ndecls, decls, parent, s, l);
                AIEND
                break;
@@ -247,7 +247,7 @@ int declcmp(const void *l, const void *r)
        return (*(struct decl **)l)->type - (*(struct decl **)r)->type;
 }
 
-struct ast *ast_scc(struct ast *ast)
+void sem_check_scc(struct ast *ast)
 {
        //Sort so that the functions are at the end
        qsort(ast->decls, ast->ndecls, sizeof(struct decl *), declcmp);
@@ -264,14 +264,14 @@ struct ast *ast_scc(struct ast *ast)
        struct array edges;
        array_init(&edges, nfun*2);
        for (size_t i = 0; i<nfun; i++) {
-               ARRAY_ITER(struct stmt *, s, j, fundecls[i]->data.dfun->body)
+               ARRAY_ITER(struct stmt *, s, j, &fundecls[i]->data.dfun->body)
                        edges = edges_stmt(nfun, fundecls, fundecls[i], s, edges);
                AIEND
        }
 
        // Do tarjan's and convert back into the declaration list
        struct components *cs = tarjans(nfun, (void **)fundecls,
-               ARRAY_SIZE(edges), ARRAY_ELS(struct edge *, edges));
+               ARRAY_SIZE(&edges), ARRAY_ELS(struct edge *, &edges));
 
        int i = ffun;
        for (struct components *c = cs; c != NULL; c = c->next) {
@@ -279,14 +279,14 @@ struct ast *ast_scc(struct ast *ast)
                d->type = dcomp;
                array_init(&d->data.dcomp, c->nnodes);
                for (int j = 0; j<c->nnodes; j++)
-                       d->data.dcomp = array_append(d->data.dcomp,
+                       array_append(&d->data.dcomp,
                                ((struct decl *)c->nodes[j])->data.dfun);
                ast->decls[i++] = d;
        }
        ast->ndecls = i;
 
        //Cleanup
-       array_free(edges, free);
+       array_free(&edges, free);
 
        struct components *t;
        while (cs != NULL) {
@@ -297,5 +297,4 @@ struct ast *ast_scc(struct ast *ast)
                free(cs);
                cs = t;
        }
-       return ast;
 }
index 784ed30..e948435 100644 (file)
--- a/sem/scc.h
+++ b/sem/scc.h
@@ -3,7 +3,6 @@
 
 #include "../ast.h"
 
-// Split up the AST in strongly connected components
-struct ast *ast_scc(struct ast *ast);
+void sem_check_scc(struct ast *ast);
 
 #endif
diff --git a/sem/type.c b/sem/type.c
new file mode 100644 (file)
index 0000000..5e04324
--- /dev/null
@@ -0,0 +1,214 @@
+#include <string.h>
+
+#include "../sem.h"
+#include "hm.h"
+
+static void patch_overload_stmt(struct subst *subst, struct stmt *s);
+static void patch_overload_body(struct subst *subst, struct array body)
+{
+       ARRAY_ITER(struct stmt *, s, j, &body)
+               patch_overload_stmt(subst, s);
+       AIEND
+}
+
+static void patch_overload_expr(struct subst *subst, struct expr *expr)
+{
+       if (expr == NULL)
+               return;
+       switch (expr->type) {
+       case ebinop:
+               if (expr->data.ebinop.op == eq || expr->data.ebinop.op == neq)
+                       expr->data.ebinop.type = subst_apply_t(subst,
+                               expr->data.ebinop.type);
+               patch_overload_expr(subst, expr->data.ebinop.l);
+               patch_overload_expr(subst, expr->data.ebinop.r);
+               break;
+       case efuncall:
+               if (strcmp(expr->data.efuncall.ident, "print") == 0)
+                       expr->data.efuncall.type = subst_apply_t(subst,
+                               expr->data.efuncall.type);
+               break;
+       case etuple:
+               patch_overload_expr(subst, expr->data.etuple.left);
+               patch_overload_expr(subst, expr->data.etuple.right);
+               break;
+       default:
+               break;
+       }
+}
+
+static void patch_overload_stmt(struct subst *subst, struct stmt *stmt)
+{
+       switch (stmt->type) {
+       case sassign:
+               patch_overload_expr(subst, stmt->data.sassign.expr);
+               break;
+       case sif:
+               patch_overload_expr(subst, stmt->data.sif.pred);
+               patch_overload_body(subst, stmt->data.sif.then);
+               patch_overload_body(subst, stmt->data.sif.els);
+               break;
+       case sreturn:
+               patch_overload_expr(subst, stmt->data.sreturn);
+               break;
+       case sexpr:
+               patch_overload_expr(subst, stmt->data.sexpr);
+               break;
+       case svardecl:
+               stmt->data.svardecl->type = subst_apply_t(subst,
+                       stmt->data.svardecl->type);
+               patch_overload_expr(subst, stmt->data.svardecl->expr);
+               break;
+       case swhile:
+               patch_overload_expr(subst, stmt->data.swhile.pred);
+               patch_overload_body(subst, stmt->data.swhile.body);
+               break;
+       }
+}
+
+static void gamma_preamble(struct gamma *gamma)
+{
+       struct type *t = type_arrow(type_tuple(type_var_str("a")
+               , type_var_str("b")) ,type_var_str("a"));
+       gamma_insert(gamma, ident_str("fst"), scheme_generalise(gamma, t));
+       type_free(t);
+
+       t = type_arrow(type_tuple(type_var_str("a")
+               , type_var_str("b")) ,type_var_str("b"));
+       gamma_insert(gamma, ident_str("snd"), scheme_generalise(gamma, t));
+       type_free(t);
+
+       t = type_arrow(type_list(type_var_str("a")),
+                       type_var_str("a"));
+       gamma_insert(gamma, ident_str("hd"), scheme_generalise(gamma, t));
+       type_free(t);
+
+       t = type_arrow(type_list(type_var_str("a")),
+                       type_list(type_var_str("a")));
+       gamma_insert(gamma, ident_str("tl"), scheme_generalise(gamma, t));
+       type_free(t);
+
+       t = type_arrow(type_list(type_var_str("a")),
+                       type_basic(btbool));
+       gamma_insert(gamma, ident_str("isEmpty"), scheme_generalise(gamma, t));
+       type_free(t);
+
+       t = type_arrow(type_var_str("a"), type_basic(btvoid));
+       gamma_insert(gamma, ident_str("print"), scheme_generalise(gamma, t));
+       type_free(t);
+}
+
+static struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl)
+{
+       struct type *t = vardecl->type == NULL
+               ? gamma_fresh(gamma) : vardecl->type;
+       struct subst *s = infer_expr(gamma, vardecl->expr, t);
+
+       vardecl->type = subst_apply_t(s, t);
+       gamma_insert(gamma, ident_str(vardecl->ident), scheme_create(vardecl->type));
+
+       subst_free(s);
+
+       return vardecl;
+}
+
+static void type_comp(struct gamma *gamma, struct array decl)
+{
+       //Create a fresh variable for every function in the component
+       struct type **fs = xalloc(ARRAY_SIZE(&decl), struct type *);
+       ARRAY_ITER(struct fundecl *, d, i, &decl) {
+               fs[i] = gamma_fresh(gamma);
+               ARRAY_ITERI(j, &d->args) {
+                       struct type *a = gamma_fresh(gamma);
+                       fs[i] = type_arrow(a, fs[i]);
+               }
+               gamma_insert(gamma, ident_str(d->ident), scheme_create(fs[i]));
+       } AIEND
+
+       //Infer each function
+       struct subst *s0 = subst_id();
+       ARRAY_ITERI(i, &decl) {
+               struct subst *s1 = infer_fundecl(gamma,
+                       ARRAY_EL(struct fundecl *, &decl, i),
+                       subst_apply_t(s0, fs[i]));
+               s0 = subst_union(s1, s0);
+       }
+
+       //Generalise all functions and put in gamma
+       ARRAY_ITER(struct fundecl *, d, i, &decl) {
+               struct type *t = subst_apply_t(s0, fs[i]);
+
+               //unify against given type specification
+               if (d->rtype != NULL) {
+                       // only check return type
+                       if (d->atypes == NULL) {
+                               struct type *dt = d->rtype;
+                               for (int j = (int)ARRAY_SIZE(&d->args)-1; j>=0; j--)
+                                       dt = type_arrow(gamma_fresh(gamma), dt);
+                               struct subst *s1 = unify(d->loc, dt, t);
+                               subst_apply_t(s1, fs[i]);
+                               subst_free(s1);
+                               type_free(dt);
+                       } else {
+                               struct type *dt = d->rtype;
+                               for (int j = (int)ARRAY_SIZE(d->atypes)-1; j>=0; j--)
+                                       dt = type_arrow(ARRAY_EL(struct type *,
+                                               d->atypes, j), dt);
+                               struct subst *s1 = unify(d->loc, dt, t);
+                               subst_apply_t(s1, fs[i]);
+                               subst_free(s1);
+                               type_free(dt);
+                       }
+               }
+
+               gamma_insert(gamma, ident_str(d->ident), scheme_generalise(gamma, t));
+
+               //Put the type in the ast
+               if (d->atypes != NULL)
+                       array_clean(d->atypes, NULL);
+               else
+                       d->atypes = array_new(ARRAY_SIZE(&d->args));
+               array_resize(d->atypes, ARRAY_SIZE(&d->args));
+
+               ARRAY_ITERI(j, &d->args) {
+                       array_append(d->atypes, (void *)type_dup(t->data.tarrow.l));
+                       t = t->data.tarrow.r;
+               }
+               d->rtype = type_dup(t);
+       } AIEND
+
+       //Free all types
+       for (size_t i = 0; i<ARRAY_SIZE(&decl); i++)
+               type_free(fs[i]);
+       free(fs);
+
+       //Patch all overloaded functions
+       ARRAY_ITER(struct fundecl *, d, i, &decl) {
+               patch_overload_body(s0, d->body);
+       } AIEND
+       subst_free(s0);
+}
+
+void sem_check_types(struct ast *ast)
+{
+       struct gamma *gamma = gamma_init();
+       gamma_preamble(gamma);
+
+       for (int i = 0; i<ast->ndecls; i++) {
+               struct decl *decl = ast->decls[i];
+               switch(decl->type) {
+               case dvardecl:
+                       //Infer if necessary
+                       type_vardecl(gamma, decl->data.dvar);
+                       break;
+               case dcomp:
+                       //Infer function as singleton component
+                       type_comp(gamma, decl->data.dcomp);
+                       break;
+               case dfundecl:
+                       die("fundecls should be gone by now\n");
+                       break;
+               }
+       }
+       gamma_free(gamma);
+}
diff --git a/sem/type.h b/sem/type.h
new file mode 100644 (file)
index 0000000..352cbb5
--- /dev/null
@@ -0,0 +1,8 @@
+#ifndef SEM_TYPE_H
+#define SEM_TYPE_H
+
+#include "../ast.h"
+
+void sem_check_types(struct ast *ast);
+
+#endif
diff --git a/sem/vardecl.c b/sem/vardecl.c
new file mode 100644 (file)
index 0000000..a1215c2
--- /dev/null
@@ -0,0 +1,17 @@
+#include "../ast.h"
+
+void fix_vars_fun(struct fundecl *d)
+{
+       (void)d;
+}
+
+void sem_check_vardecls(struct ast *ast)
+{
+       for (int i = 0; i<ast->ndecls; i++) {
+               if (ast->decls[i]->type == dcomp) {
+                       ARRAY_ITER(struct fundecl *, d, i, &ast->decls[i]->data.dcomp)
+                               fix_vars_fun(d);
+                       AIEND
+               }
+       }
+}
diff --git a/sem/vardecl.h b/sem/vardecl.h
new file mode 100644 (file)
index 0000000..8f72e2c
--- /dev/null
@@ -0,0 +1,8 @@
+#ifndef SEM_VARDECL_H
+#define SEM_VARDECL_H
+
+#include "../ast.h"
+
+void sem_check_vardecls(struct ast *ast);
+
+#endif
diff --git a/splc.c b/splc.c
index 48311c8..8a4f1f3 100644 (file)
--- a/splc.c
+++ b/splc.c
@@ -96,7 +96,7 @@ int main(int argc, char *argv[])
        cout = safe_fopen(cfile, "w+");
        free(cfile);
 
-       gen(result, lang, cout);
+//     gen(result, lang, cout);
 
        safe_fclose(cout);
        safe_fprintf(stderr, "code generation done\n");
diff --git a/util.c b/util.c
index f473dc6..ad9b50c 100644 (file)
--- a/util.c
+++ b/util.c
@@ -131,8 +131,7 @@ void die(const char *msg, ...)
 void pindent(int indent, FILE *out)
 {
        for (int i = 0; i<indent; i++)
-               if (fputc('\t', out) == EOF)
-                       pdie("fputc");
+               safe_fprintf(out, "\t");
 }
 
 void safe_vfprintf(FILE *out, const char *msg, va_list ap)
diff --git a/util.h b/util.h
index a1f68de..bfdd21c 100644 (file)
--- a/util.h
+++ b/util.h
@@ -5,8 +5,6 @@
 #include <stdbool.h>
 #include <stdio.h>
 
-#define min(x, y) ((x)<(y)?(x):(y))
-
 /* exit with an error message */
 void die(const char *msg, ...);
 /* exit with the system's error message prefixed by msg */