work on type inference some more
[ccc.git] / ast.c
diff --git a/ast.c b/ast.c
index 597dfa6..1e93a71 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -4,21 +4,19 @@
 
 #include "util.h"
 #include "ast.h"
-#include "y.tab.h"
+#include "type.h"
+#include "list.h"
+#include "parse.h"
 
-static const char *binop_str[] = {
+const char *binop_str[] = {
        [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
        [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
        [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
        [modulo] = "%", [power] = "^",
 };
-static const char *fieldspec_str[] = {
+const char *fieldspec_str[] = {
        [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
-static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
-static const char *basictype_str[] = {
-       [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
-       [btvoid] = "Void",
-};
+const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
 
 struct ast *ast(struct list *decls)
 {
@@ -35,22 +33,23 @@ struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr)
        res->expr = expr;
        return res;
 }
+struct fundecl *fundecl(char *ident, struct list *args, struct list *atypes,
+       struct type *rtype, struct list *body)
+{
+       struct fundecl *res = safe_malloc(sizeof(struct fundecl));
+       res->ident = ident;
+       res->args = (char **)list_to_array(args, &res->nargs, true);
+       res->atypes = (struct type **)list_to_array(atypes, &res->natypes, true);
+       res->rtype = rtype;
+       res->body = (struct stmt **)list_to_array(body, &res->nbody, true);
+       return res;
+}
 
-struct decl *decl_fun(char *ident, struct list *args, struct list *atypes,
-       struct type *rtype, struct list *vars, struct list *body)
+struct decl *decl_fun(struct fundecl *fundecl)
 {
        struct decl *res = safe_malloc(sizeof(struct decl));
        res->type = dfundecl;
-       res->data.dfun.ident = ident;
-       res->data.dfun.args = (char **)
-               list_to_array(args, &res->data.dfun.nargs, true);
-       res->data.dfun.atypes = (struct type **)
-               list_to_array(atypes, &res->data.dfun.natypes, true);
-       res->data.dfun.rtype = rtype;
-       res->data.dfun.vars = (struct vardecl **)
-               list_to_array(vars, &res->data.dfun.nvar, true);
-       res->data.dfun.body = (struct stmt **)
-               list_to_array(body, &res->data.dfun.nbody, true);
+       res->data.dfun = fundecl;
        return res;
 }
 
@@ -68,7 +67,7 @@ struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr)
        res->type = sassign;
        res->data.sassign.ident = ident;
        res->data.sassign.fields = (char **)
-               list_to_array(fields, &res->data.sassign.nfield, true);
+               list_to_array(fields, &res->data.sassign.nfields, true);
        res->data.sassign.expr = expr;
        return res;
 }
@@ -214,13 +213,13 @@ struct expr *expr_string(char *str)
 {
        struct expr *res = safe_malloc(sizeof(struct expr));
        res->type = estring;
-       res->data.estring.nchar = 0;
+       res->data.estring.nchars = 0;
        res->data.estring.chars = safe_malloc(strlen(str)+1);
        char *p = res->data.estring.chars;
        while(*str != '\0') {
                str = unescape_char(str);
                *p++ = *str++;
-               res->data.estring.nchar++;
+               res->data.estring.nchars++;
        }
        *p = '\0';
        return res;
@@ -235,63 +234,12 @@ struct expr *expr_unop(enum unop op, struct expr *l)
        return res;
 }
 
-struct type *type_basic(enum basictype type)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       res->type = tbasic;
-       res->data.tbasic = type;
-       return res;
-}
-
-struct type *type_list(struct type *type)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       res->type = tlist;
-       res->data.tlist = type;
-       return res;
-}
-
-struct type *type_tuple(struct type *l, struct type *r)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       res->type = ttuple;
-       res->data.ttuple.l = l;
-       res->data.ttuple.r = r;
-       return res;
-}
-
-struct type *type_var(char *ident)
-{
-       struct type *res = safe_malloc(sizeof(struct type));
-       if (strcmp(ident, "Int") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btint;
-               free(ident);
-       } else if (strcmp(ident, "Char") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btchar;
-               free(ident);
-       } else if (strcmp(ident, "Bool") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btbool;
-               free(ident);
-       } else if (strcmp(ident, "Void") == 0) {
-               res->type = tbasic;
-               res->data.tbasic = btvoid;
-               free(ident);
-       } else {
-               res->type = tvar;
-               res->data.tvar = ident;
-       }
-       return res;
-}
-
 void ast_print(struct ast *ast, FILE *out)
 {
        if (ast == NULL)
                return;
        for (int i = 0; i<ast->ndecls; i++)
-               decl_print(ast->decls[i], 0, out);
+               decl_print(ast->decls[i], out);
 }
 
 void vardecl_print(struct vardecl *decl, int indent, FILE *out)
@@ -306,39 +254,46 @@ void vardecl_print(struct vardecl *decl, int indent, FILE *out)
        safe_fprintf(out, ";\n");
 }
 
-void decl_print(struct decl *decl, int indent, FILE *out)
+void fundecl_print(struct fundecl *decl, FILE *out)
+{
+       safe_fprintf(out, "%s (", decl->ident);
+       for (int i = 0; i<decl->nargs; i++) {
+               safe_fprintf(out, "%s", decl->args[i]);
+               if (i < decl->nargs - 1)
+                       safe_fprintf(out, ", ");
+       }
+       safe_fprintf(out, ")");
+       if (decl->rtype != NULL) {
+               safe_fprintf(out, " :: ");
+               for (int i = 0; i<decl->natypes; i++) {
+                       type_print(decl->atypes[i], out);
+                       safe_fprintf(out, " ");
+               }
+               safe_fprintf(out, "-> ");
+               type_print(decl->rtype, out);
+       }
+       safe_fprintf(out, " {\n");
+       for (int i = 0; i<decl->nbody; i++)
+               stmt_print(decl->body[i], 1, out);
+       safe_fprintf(out, "}\n");
+}
+
+void decl_print(struct decl *decl, FILE *out)
 {
        if (decl == NULL)
                return;
        switch(decl->type) {
        case dfundecl:
-               pindent(indent, out);
-               safe_fprintf(out, "%s (", decl->data.dfun.ident);
-               for (int i = 0; i<decl->data.dfun.nargs; i++) {
-                       safe_fprintf(out, "%s", decl->data.dfun.args[i]);
-                       if (i < decl->data.dfun.nargs - 1)
-                               safe_fprintf(out, ", ");
-               }
-               safe_fprintf(out, ")");
-               if (decl->data.dfun.rtype != NULL) {
-                       safe_fprintf(out, " :: ");
-                       for (int i = 0; i<decl->data.dfun.natypes; i++) {
-                               type_print(decl->data.dfun.atypes[i], out);
-                               safe_fprintf(out, " ");
-                       }
-                       safe_fprintf(out, "-> ");
-                       type_print(decl->data.dfun.rtype, out);
-               }
-               safe_fprintf(out, " {\n");
-               for (int i = 0; i<decl->data.dfun.nvar; i++)
-                       vardecl_print(decl->data.dfun.vars[i], indent+1, out);
-               for (int i = 0; i<decl->data.dfun.nbody; i++)
-                       stmt_print(decl->data.dfun.body[i], indent+1, out);
-               pindent(indent, out);
-               safe_fprintf(out, "}\n");
+               fundecl_print(decl->data.dfun, out);
                break;
        case dvardecl:
-               vardecl_print(decl->data.dvar, indent, out);
+               vardecl_print(decl->data.dvar, 0, out);
+               break;
+       case dcomp:
+               fprintf(out, "//<<<comp\n");
+               for (int i = 0; i<decl->data.dcomp.ndecls; i++)
+                       fundecl_print(decl->data.dcomp.decls[i], out);
+               fprintf(out, "//>>>comp\n");
                break;
        default:
                die("Unsupported decl node\n");
@@ -353,7 +308,7 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
        case sassign:
                pindent(indent, out);
                fprintf(out, "%s", stmt->data.sassign.ident);
-               for (int i = 0; i<stmt->data.sassign.nfield; i++)
+               for (int i = 0; i<stmt->data.sassign.nfields; i++)
                        fprintf(out, ".%s", stmt->data.sassign.fields[i]);
                safe_fprintf(out, " = ");
                expr_print(stmt->data.sassign.expr, out);
@@ -392,9 +347,8 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
                safe_fprintf(out, "while (");
                expr_print(stmt->data.swhile.pred, out);
                safe_fprintf(out, ") {\n");
-               for (int i = 0; i<stmt->data.swhile.nbody; i++) {
+               for (int i = 0; i<stmt->data.swhile.nbody; i++)
                        stmt_print(stmt->data.swhile.body[i], indent+1, out);
-               }
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
@@ -456,7 +410,7 @@ void expr_print(struct expr *expr, FILE *out)
                break;
        case estring:
                safe_fprintf(out, "\"");
-               for (int i = 0; i<expr->data.estring.nchar; i++)
+               for (int i = 0; i<expr->data.estring.nchars; i++)
                        safe_fprintf(out, "%s", escape_char(
                                expr->data.estring.chars[i], buf, true));
                safe_fprintf(out, "\"");
@@ -471,34 +425,6 @@ void expr_print(struct expr *expr, FILE *out)
        }
 }
 
-void type_print(struct type *type, FILE *out)
-{
-       if (type == NULL)
-               return;
-       switch (type->type) {
-       case tbasic:
-               safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
-               break;
-       case tlist:
-               safe_fprintf(out, "[");
-               type_print(type->data.tlist, out);
-               safe_fprintf(out, "]");
-               break;
-       case ttuple:
-               safe_fprintf(out, "(");
-               type_print(type->data.ttuple.l, out);
-               safe_fprintf(out, ",");
-               type_print(type->data.ttuple.r, out);
-               safe_fprintf(out, ")");
-               break;
-       case tvar:
-               safe_fprintf(out, "%s", type->data.tvar);
-               break;
-       default:
-               die("Unsupported type node\n");
-       }
-}
-
 void ast_free(struct ast *ast)
 {
        if (ast == NULL)
@@ -517,26 +443,34 @@ void vardecl_free(struct vardecl *decl)
        free(decl);
 }
 
+void fundecl_free(struct fundecl *decl)
+{
+       free(decl->ident);
+       for (int i = 0; i<decl->nargs; i++)
+               free(decl->args[i]);
+       free(decl->args);
+       for (int i = 0; i<decl->natypes; i++)
+               type_free(decl->atypes[i]);
+       free(decl->atypes);
+       type_free(decl->rtype);
+       for (int i = 0; i<decl->nbody; i++)
+               stmt_free(decl->body[i]);
+       free(decl->body);
+       free(decl);
+}
+
 void decl_free(struct decl *decl)
 {
        if (decl == NULL)
                return;
        switch(decl->type) {
+       case dcomp:
+               for (int i = 0; i<decl->data.dcomp.ndecls; i++)
+                       fundecl_free(decl->data.dcomp.decls[i]);
+               free(decl->data.dcomp.decls);
+               break;
        case dfundecl:
-               free(decl->data.dfun.ident);
-               for (int i = 0; i<decl->data.dfun.nargs; i++)
-                       free(decl->data.dfun.args[i]);
-               free(decl->data.dfun.args);
-               for (int i = 0; i<decl->data.dfun.natypes; i++)
-                       type_free(decl->data.dfun.atypes[i]);
-               free(decl->data.dfun.atypes);
-               type_free(decl->data.dfun.rtype);
-               for (int i = 0; i<decl->data.dfun.nvar; i++)
-                       vardecl_free(decl->data.dfun.vars[i]);
-               free(decl->data.dfun.vars);
-               for (int i = 0; i<decl->data.dfun.nbody; i++)
-                       stmt_free(decl->data.dfun.body[i]);
-               free(decl->data.dfun.body);
+               fundecl_free(decl->data.dfun);
                break;
        case dvardecl:
                vardecl_free(decl->data.dvar);
@@ -554,7 +488,7 @@ void stmt_free(struct stmt *stmt)
        switch(stmt->type) {
        case sassign:
                free(stmt->data.sassign.ident);
-               for (int i = 0; i<stmt->data.sassign.nfield; i++)
+               for (int i = 0; i<stmt->data.sassign.nfields; i++)
                        free(stmt->data.sassign.fields[i]);
                free(stmt->data.sassign.fields);
                expr_free(stmt->data.sassign.expr);
@@ -632,26 +566,3 @@ void expr_free(struct expr *expr)
        }
        free(expr);
 }
-
-void type_free(struct type *type)
-{
-       if (type == NULL)
-               return;
-       switch (type->type) {
-       case tbasic:
-               break;
-       case tlist:
-               type_free(type->data.tlist);
-               break;
-       case ttuple:
-               type_free(type->data.ttuple.l);
-               type_free(type->data.ttuple.r);
-               break;
-       case tvar:
-               free(type->data.tvar);
-               break;
-       default:
-               die("Unsupported type node\n");
-       }
-       free(type);
-}