types and locations
[ccc.git] / ast.c
diff --git a/ast.c b/ast.c
index 622ba74..cb484c4 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -4,6 +4,7 @@
 
 #include "util.h"
 #include "ast.h"
+#include "y.tab.h"
 
 static const char *binop_str[] = {
        [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
@@ -14,6 +15,10 @@ static const char *binop_str[] = {
 static const char *fieldspec_str[] = {
        [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
+static const char *basictype_str[] = {
+       [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
+       [btvoid] = "Void",
+};
 
 struct ast *ast(struct list *decls)
 {
@@ -22,19 +27,34 @@ struct ast *ast(struct list *decls)
        return res;
 }
 
-struct decl *decl_fun(char *ident, struct list *args, struct list *body)
+struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr)
+{
+       struct vardecl *res = safe_malloc(sizeof(struct vardecl));
+       res->type = type;
+       res->ident = ident;
+       res->expr = expr;
+       return res;
+}
+
+struct decl *decl_fun(char *ident, struct list *args, struct list *atypes,
+       struct type *rtype, struct list *vars, struct list *body)
 {
        struct decl *res = safe_malloc(sizeof(struct decl));
        res->type = dfundecl;
        res->data.dfun.ident = ident;
        res->data.dfun.args = (char **)
                list_to_array(args, &res->data.dfun.nargs, true);
+       res->data.dfun.atypes = (struct type **)
+               list_to_array(atypes, &res->data.dfun.natypes, true);
+       res->data.dfun.rtype = rtype;
+       res->data.dfun.vars = (struct vardecl **)
+               list_to_array(vars, &res->data.dfun.nvar, true);
        res->data.dfun.body = (struct stmt **)
                list_to_array(body, &res->data.dfun.nbody, true);
        return res;
 }
 
-struct decl *decl_var(struct vardecl vardecl)
+struct decl *decl_var(struct vardecl *vardecl)
 {
        struct decl *res = safe_malloc(sizeof(struct decl));
        res->type = dvardecl;
@@ -79,14 +99,6 @@ struct stmt *stmt_expr(struct expr *expr)
        return res;
 }
 
-struct stmt *stmt_vardecl(struct vardecl vardecl)
-{
-       struct stmt *res = safe_malloc(sizeof(struct stmt));
-       res->type = svardecl;
-       res->data.svardecl = vardecl;
-       return res;
-}
-
 struct stmt *stmt_while(struct expr *pred, struct list *body)
 {
        struct stmt *res = safe_malloc(sizeof(struct stmt));
@@ -217,6 +229,50 @@ struct expr *expr_unop(enum unop op, struct expr *l)
        return res;
 }
 
+struct type *type_list(struct type *type)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tlist;
+       res->data.tlist = type;
+       return res;
+}
+
+struct type *type_tuple(struct type *l, struct type *r)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = ttuple;
+       res->data.ttuple.l = l;
+       res->data.ttuple.r = r;
+       return res;
+}
+
+struct type *type_var(char *ident)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       if (strcmp(ident, "Int") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btint;
+               free(ident);
+       } else if (strcmp(ident, "Char") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btchar;
+               free(ident);
+       } else if (strcmp(ident, "Bool") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btbool;
+               free(ident);
+       } else if (strcmp(ident, "Void") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btvoid;
+               free(ident);
+       } else {
+               res->type = tvar;
+               res->data.tvar = ident;
+       }
+       return res;
+}
+
+
 const char *cescapes[] = {
        [0] = "\\0", [1] = "\\x01", [2] = "\\x02", [3] = "\\x03", 
        [4] = "\\x04", [5] = "\\x05", [6] = "\\x06", [7] = "\\a", [8] = "\\b",
@@ -237,6 +293,18 @@ void ast_print(struct ast *ast, FILE *out)
                decl_print(ast->decls[i], 0, out);
 }
 
+void vardecl_print(struct vardecl *decl, int indent, FILE *out)
+{
+       pindent(indent, out);
+       if (decl->type == NULL)
+               safe_fprintf(out, "var");
+       else
+               type_print(decl->type, out);
+       safe_fprintf(out, " %s = ", decl->ident);
+       expr_print(decl->expr, out);
+       safe_fprintf(out, ";\n");
+}
+
 void decl_print(struct decl *decl, int indent, FILE *out)
 {
        if (decl == NULL)
@@ -250,17 +318,26 @@ void decl_print(struct decl *decl, int indent, FILE *out)
                        if (i < decl->data.dfun.nargs - 1)
                                safe_fprintf(out, ", ");
                }
-               safe_fprintf(out, ") {\n");
+               safe_fprintf(out, ")");
+               if (decl->data.dfun.rtype != NULL) {
+                       safe_fprintf(out, " :: ");
+                       for (int i = 0; i<decl->data.dfun.natypes; i++) {
+                               type_print(decl->data.dfun.atypes[i], out);
+                               safe_fprintf(out, " ");
+                       }
+                       safe_fprintf(out, "-> ");
+                       type_print(decl->data.dfun.rtype, out);
+               }
+               safe_fprintf(out, " {\n");
+               for (int i = 0; i<decl->data.dfun.nvar; i++)
+                       vardecl_print(decl->data.dfun.vars[i], indent+1, out);
                for (int i = 0; i<decl->data.dfun.nbody; i++)
                        stmt_print(decl->data.dfun.body[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
        case dvardecl:
-               pindent(indent, out);
-               safe_fprintf(out, "var %s = ", decl->data.dvar.ident);
-               expr_print(decl->data.dvar.expr, out);
-               safe_fprintf(out, ";\n");
+               vardecl_print(decl->data.dvar, indent, out);
                break;
        default:
                die("Unsupported decl node\n");
@@ -304,12 +381,6 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
                expr_print(stmt->data.sexpr, out);
                safe_fprintf(out, ";\n");
                break;
-       case svardecl:
-               pindent(indent, out);
-               safe_fprintf(out, "var %s = ", stmt->data.svardecl.ident);
-               expr_print(stmt->data.svardecl.expr, out);
-               safe_fprintf(out, ";\n");
-               break;
        case swhile:
                pindent(indent, out);
                safe_fprintf(out, "while (");
@@ -388,6 +459,34 @@ void expr_print(struct expr *expr, FILE *out)
        }
 }
 
+void type_print(struct type *type, FILE *out)
+{
+       if (type == NULL)
+               return;
+       switch (type->type) {
+       case tbasic:
+               safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
+               break;
+       case tlist:
+               safe_fprintf(out, "[");
+               type_print(type->data.tlist, out);
+               safe_fprintf(out, "]");
+               break;
+       case ttuple:
+               safe_fprintf(out, "(");
+               type_print(type->data.ttuple.l, out);
+               safe_fprintf(out, ",");
+               type_print(type->data.ttuple.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case tvar:
+               safe_fprintf(out, "%s", type->data.tvar);
+               break;
+       default:
+               die("Unsupported type node\n");
+       }
+}
+
 void ast_free(struct ast *ast)
 {
        if (ast == NULL)
@@ -398,6 +497,14 @@ void ast_free(struct ast *ast)
        free(ast);
 }
 
+void vardecl_free(struct vardecl *decl)
+{
+       type_free(decl->type);
+       free(decl->ident);
+       expr_free(decl->expr);
+       free(decl);
+}
+
 void decl_free(struct decl *decl)
 {
        if (decl == NULL)
@@ -408,13 +515,19 @@ void decl_free(struct decl *decl)
                for (int i = 0; i<decl->data.dfun.nargs; i++)
                        free(decl->data.dfun.args[i]);
                free(decl->data.dfun.args);
+               for (int i = 0; i<decl->data.dfun.natypes; i++)
+                       type_free(decl->data.dfun.atypes[i]);
+               free(decl->data.dfun.atypes);
+               type_free(decl->data.dfun.rtype);
+               for (int i = 0; i<decl->data.dfun.nvar; i++)
+                       vardecl_free(decl->data.dfun.vars[i]);
+               free(decl->data.dfun.vars);
                for (int i = 0; i<decl->data.dfun.nbody; i++)
                        stmt_free(decl->data.dfun.body[i]);
                free(decl->data.dfun.body);
                break;
        case dvardecl:
-               free(decl->data.dvar.ident);
-               expr_free(decl->data.dvar.expr);
+               vardecl_free(decl->data.dvar);
                break;
        default:
                die("Unsupported decl node\n");
@@ -446,10 +559,6 @@ void stmt_free(struct stmt *stmt)
        case sexpr:
                expr_free(stmt->data.sexpr);
                break;
-       case svardecl:
-               free(stmt->data.svardecl.ident);
-               expr_free(stmt->data.svardecl.expr);
-               break;
        case swhile:
                expr_free(stmt->data.swhile.pred);
                for (int i = 0; i<stmt->data.swhile.nbody; i++)
@@ -464,6 +573,8 @@ void stmt_free(struct stmt *stmt)
 
 void expr_free(struct expr *expr)
 {
+       if (expr == NULL)
+               return;
        switch(expr->type) {
        case ebinop:
                expr_free(expr->data.ebinop.l);
@@ -499,3 +610,26 @@ void expr_free(struct expr *expr)
        }
        free(expr);
 }
+
+void type_free(struct type *type)
+{
+       if (type == NULL)
+               return;
+       switch (type->type) {
+       case tbasic:
+               break;
+       case tlist:
+               type_free(type->data.tlist);
+               break;
+       case ttuple:
+               type_free(type->data.ttuple.l);
+               type_free(type->data.ttuple.r);
+               break;
+       case tvar:
+               free(type->data.tvar);
+               break;
+       default:
+               die("Unsupported type node\n");
+       }
+       free(type);
+}