work on type inference some more
[ccc.git] / ast.c
diff --git a/ast.c b/ast.c
index 97bd1ad..1e93a71 100644 (file)
--- a/ast.c
+++ b/ast.c
 
 #include "util.h"
 #include "ast.h"
+#include "type.h"
+#include "list.h"
+#include "parse.h"
 
-#ifdef DEBUG
-static const char *ast_type_str[] = {
-       [an_assign] = "assign", [an_bool] = "bool", [an_binop] = "binop",
-       [an_char] = "char", [an_cons] = "cons", [an_funcall] = "funcall",
-       [an_fundecl] = "fundecl", [an_ident] = "ident", [an_if] = "if",
-       [an_int] = "int", [an_nil] = "nil", [an_list] = "list",
-       [an_return] = "return", [an_stmt_expr] = "stmt_expr",
-       [an_unop] = "unop", [an_vardecl] = "vardecl", [an_while] = "while",
-};
-#endif
-static const char *binop_str[] = {
+const char *binop_str[] = {
        [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
        [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
        [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
        [modulo] = "%", [power] = "^",
 };
-static const char *fieldspec_str[] = {
+const char *fieldspec_str[] = {
        [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
-static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
+const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
 
-#ifdef DEBUG
-#define must_be(node, ntype, msg) {\
-       if ((node)->type != (ntype)) {\
-               fprintf(stderr, "%s can't be %s\n",\
-                       msg, ast_type_str[node->type]);\
-               exit(1);\
-       }\
+struct ast *ast(struct list *decls)
+{
+       struct ast *res = safe_malloc(sizeof(struct ast));
+       res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true);
+       return res;
 }
-#else
-#define must_be(node, ntype, msg) ;
-#endif
 
-#define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
-
-struct ast *ast_assign(struct ast *ident, struct ast *expr)
+struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_assign;
-       res->data.an_assign.ident = ident;
-       res->data.an_assign.expr = expr;
+       struct vardecl *res = safe_malloc(sizeof(struct vardecl));
+       res->type = type;
+       res->ident = ident;
+       res->expr = expr;
        return res;
 }
-
-struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
+struct fundecl *fundecl(char *ident, struct list *args, struct list *atypes,
+       struct type *rtype, struct list *body)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_binop;
-       res->data.an_binop.l = l;
-       res->data.an_binop.op = op;
-       res->data.an_binop.r = r;
+       struct fundecl *res = safe_malloc(sizeof(struct fundecl));
+       res->ident = ident;
+       res->args = (char **)list_to_array(args, &res->nargs, true);
+       res->atypes = (struct type **)list_to_array(atypes, &res->natypes, true);
+       res->rtype = rtype;
+       res->body = (struct stmt **)list_to_array(body, &res->nbody, true);
        return res;
 }
 
-struct ast *ast_bool(bool b)
+struct decl *decl_fun(struct fundecl *fundecl)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_bool;
-       res->data.an_bool = b;
+       struct decl *res = safe_malloc(sizeof(struct decl));
+       res->type = dfundecl;
+       res->data.dfun = fundecl;
        return res;
 }
 
-int fromHex(char c)
-{
-       if (c >= '0' && c <= '9')
-               return c-'0';
-       if (c >= 'a' && c <= 'f')
-               return c-'a'+10;
-       if (c >= 'A' && c <= 'F')
-               return c-'A'+10;
-       return -1;
-}
-
-struct ast *ast_char(const char *c)
-{
-       struct ast *res = ast_alloc();
-       res->type = an_char;
-       //regular char
-       if (strlen(c) == 3)
-               res->data.an_char = c[1];
-       //escape
-       if (strlen(c) == 4)
-               switch(c[2]) {
-               case '0': res->data.an_char = '\0'; break;
-               case 'a': res->data.an_char = '\a'; break;
-               case 'b': res->data.an_char = '\b'; break;
-               case 't': res->data.an_char = '\t'; break;
-               case 'v': res->data.an_char = '\v'; break;
-               case 'f': res->data.an_char = '\f'; break;
-               case 'r': res->data.an_char = '\r'; break;
-               }
-       //hex escape
-       if (strlen(c) == 6)
-               res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
+struct decl *decl_var(struct vardecl *vardecl)
+{
+       struct decl *res = safe_malloc(sizeof(struct decl));
+       res->type = dvardecl;
+       res->data.dvar = vardecl;
        return res;
 }
 
-struct ast *ast_cons(struct ast *el, struct ast *tail)
+struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_cons;
-       res->data.an_cons.el = el;
-       res->data.an_cons.tail = tail;
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sassign;
+       res->data.sassign.ident = ident;
+       res->data.sassign.fields = (char **)
+               list_to_array(fields, &res->data.sassign.nfields, true);
+       res->data.sassign.expr = expr;
        return res;
 }
 
-struct ast *ast_funcall(struct ast *ident, struct ast *args)
+struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_funcall;
-
-       //ident
-       must_be(ident, an_ident, "ident of a funcall");
-       res->data.an_funcall.ident = ident->data.an_ident.ident;
-       free(ident->data.an_ident.fields);
-       free(ident);
-
-       //args
-       must_be(args, an_list, "args of a funcall");
-       res->data.an_funcall.nargs = args->data.an_list.n;
-       res->data.an_funcall.args = args->data.an_list.ptr;
-       free(args);
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sif;
+       res->data.sif.pred = pred;
+       res->data.sif.then = (struct stmt **)
+               list_to_array(then, &res->data.sif.nthen, true);
+       res->data.sif.els = (struct stmt **)
+               list_to_array(els, &res->data.sif.nels, true);
        return res;
 }
 
-struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
+struct stmt *stmt_return(struct expr *rtrn)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_fundecl;
-
-       //ident
-       must_be(ident, an_ident, "ident of a fundecl");
-       res->data.an_fundecl.ident = ident->data.an_ident.ident;
-       free(ident->data.an_ident.fields);
-       free(ident);
-
-       //args
-       must_be(args, an_list, "args of a fundecl");
-       res->data.an_fundecl.nargs = args->data.an_list.n;
-       res->data.an_fundecl.args = (char **)args->data.an_list.ptr;
-       for (int i = 0; i<args->data.an_list.n; i++) {
-               struct ast *e = args->data.an_list.ptr[i];
-               must_be(e, an_ident, "arg of a fundecl")
-               res->data.an_fundecl.args[i] = e->data.an_ident.ident;
-               free(e->data.an_ident.fields);
-               free(e);
-       }
-       free(args);
-
-       //body
-       must_be(body, an_list, "body of a fundecl");
-       res->data.an_fundecl.nbody = body->data.an_list.n;
-       res->data.an_fundecl.body = body->data.an_list.ptr;
-       free(body);
-
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sreturn;
+       res->data.sreturn = rtrn;
        return res;
 }
 
-struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
+struct stmt *stmt_expr(struct expr *expr)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_if;
-       res->data.an_if.pred = pred;
-
-       must_be(then, an_list, "body of a then");
-       res->data.an_if.nthen = then->data.an_list.n;
-       res->data.an_if.then = then->data.an_list.ptr;
-       free(then);
-
-       must_be(els, an_list, "body of a els");
-       res->data.an_if.nels = els->data.an_list.n;
-       res->data.an_if.els = els->data.an_list.ptr;
-       free(els);
-
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sexpr;
+       res->data.sexpr = expr;
        return res;
 }
 
-struct ast *ast_int(int integer)
+struct stmt *stmt_vardecl(struct vardecl *vardecl)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_int;
-       res->data.an_int = integer;
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = svardecl;
+       res->data.svardecl = vardecl;
        return res;
 }
 
-struct ast *ast_identc(char *ident)
+struct stmt *stmt_while(struct expr *pred, struct list *body)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_ident;
-       res->data.an_ident.ident = safe_strdup(ident);
-       res->data.an_ident.nfields = 0;
-       res->data.an_ident.fields = NULL;
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = swhile;
+       res->data.swhile.pred = pred;
+       res->data.swhile.body = (struct stmt **)
+               list_to_array(body, &res->data.swhile.nbody, true);
        return res;
 }
 
-struct ast *ast_ident(struct ast *ident, struct ast *fields)
-{
-       struct ast *res = ast_alloc();
-       res->type = an_ident;
-       must_be(fields, an_ident, "ident of an ident");
-       res->data.an_ident.ident = ident->data.an_ident.ident;
-       free(ident);
-
-       must_be(fields, an_list, "fields of an ident");
-       res->data.an_ident.nfields = fields->data.an_list.n;
-       res->data.an_ident.fields = (enum fieldspec *)safe_malloc(
-               fields->data.an_list.n*sizeof(enum fieldspec));
-       for (int i = 0; i<fields->data.an_list.n; i++) {
-               struct ast *t = fields->data.an_list.ptr[i];
-               must_be(t, an_ident, "field of an ident");
-               if (strcmp(t->data.an_ident.ident, "fst") == 0)
-                       res->data.an_ident.fields[i] = fst;
-               else if (strcmp(t->data.an_ident.ident, "snd") == 0)
-                       res->data.an_ident.fields[i] = snd;
-               else if (strcmp(t->data.an_ident.ident, "hd") == 0)
-                       res->data.an_ident.fields[i] = hd;
-               else if (strcmp(t->data.an_ident.ident, "tl") == 0)
-                       res->data.an_ident.fields[i] = tl;
-               free(t->data.an_ident.ident);
-               free(t);
-       }
-       free(fields->data.an_list.ptr);
-       free(fields);
+struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r)
+{
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = ebinop;
+       res->data.ebinop.l = l;
+       res->data.ebinop.op = op;
+       res->data.ebinop.r = r;
        return res;
 }
 
-struct ast *ast_list(struct ast *llist)
+struct expr *expr_bool(bool b)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_list;
-       res->data.an_list.n = 0;
-
-       int i = ast_llistlength(llist);
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = ebool;
+       res->data.ebool = b;
+       return res;
+}
 
-       //Allocate array
-       res->data.an_list.n = i;
-       res->data.an_list.ptr = (struct ast **)safe_malloc(
-               res->data.an_list.n*sizeof(struct ast *));
+struct expr *expr_char(char *c)
+{
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = echar;
+       res->data.echar = unescape_char(c)[0];
+       return res;
+}
 
-       struct ast *r = llist;
-       while(i > 0) {
-               res->data.an_list.ptr[--i] = r->data.an_cons.el;
-               struct ast *t = r;
-               r = r->data.an_cons.tail;
+static void set_fields(enum fieldspec **farray, int *n, struct list *fields)
+{
+       void **els = list_to_array(fields, n, true);
+       *farray = (enum fieldspec *)safe_malloc(*n*sizeof(enum fieldspec));
+       for (int i = 0; i<*n; i++) {
+               char *t = els[i];
+               if (strcmp(t, "fst") == 0)
+                       (*farray)[i] = fst;
+               else if (strcmp(t, "snd") == 0)
+                       (*farray)[i] = snd;
+               else if (strcmp(t, "hd") == 0)
+                       (*farray)[i] = hd;
+               else if (strcmp(t, "tl") == 0)
+                       (*farray)[i] = tl;
                free(t);
        }
-       return res;
+       free(els);
 }
 
-struct ast *ast_nil()
+
+struct expr *expr_funcall(char *ident, struct list *args, struct list *fields)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_nil;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = efuncall;
+       res->data.efuncall.ident = ident;
+       res->data.efuncall.args = (struct expr **)
+               list_to_array(args, &res->data.efuncall.nargs, true);
+       set_fields(&res->data.efuncall.fields,
+               &res->data.efuncall.nfields, fields);
        return res;
 }
 
-struct ast *ast_return(struct ast *r)
+struct expr *expr_int(int integer)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_return;
-       res->data.an_return = r;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = eint;
+       res->data.eint = integer;
        return res;
 }
 
-struct ast *ast_stmt_expr(struct ast *expr)
+struct expr *expr_ident(char *ident, struct list *fields)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_stmt_expr;
-       res->data.an_stmt_expr = expr;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = eident;
+       res->data.eident.ident = ident;
+       set_fields(&res->data.eident.fields, &res->data.eident.nfields, fields);
        return res;
 }
 
-struct ast *ast_unop(enum unop op, struct ast *l)
+struct expr *expr_nil()
 {
-       struct ast *res = ast_alloc();
-       res->type = an_unop;
-       res->data.an_unop.op = op;
-       res->data.an_unop.l = l;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = enil;
        return res;
 }
 
-struct ast *ast_vardecl(struct ast *ident, struct ast *l)
+struct expr *expr_tuple(struct expr *left, struct expr *right)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_vardecl;
-       must_be(ident, an_ident, "ident of a vardecl");
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = etuple;
+       res->data.etuple.left = left;
+       res->data.etuple.right = right;
+       return res;
+}
 
-       res->data.an_vardecl.ident = ident->data.an_ident.ident;
-       free(ident->data.an_ident.fields);
-       free(ident);
-       res->data.an_vardecl.l = l;
+struct expr *expr_string(char *str)
+{
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = estring;
+       res->data.estring.nchars = 0;
+       res->data.estring.chars = safe_malloc(strlen(str)+1);
+       char *p = res->data.estring.chars;
+       while(*str != '\0') {
+               str = unescape_char(str);
+               *p++ = *str++;
+               res->data.estring.nchars++;
+       }
+       *p = '\0';
        return res;
 }
 
-struct ast *ast_while(struct ast *pred, struct ast *body)
+struct expr *expr_unop(enum unop op, struct expr *l)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_while;
-       res->data.an_while.pred = pred;
-       must_be(body, an_list, "body of a while");
-       res->data.an_while.nbody = body->data.an_list.n;
-       res->data.an_while.body = body->data.an_list.ptr;
-       free(body);
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = eunop;
+       res->data.eunop.op = op;
+       res->data.eunop.l = l;
        return res;
 }
 
-int ast_llistlength(struct ast *r)
+void ast_print(struct ast *ast, FILE *out)
+{
+       if (ast == NULL)
+               return;
+       for (int i = 0; i<ast->ndecls; i++)
+               decl_print(ast->decls[i], out);
+}
+
+void vardecl_print(struct vardecl *decl, int indent, FILE *out)
 {
-       int i = 0;
-       while(r != NULL) {
-               i++;
-               if (r->type != an_cons) {
-                       return 1;
+       pindent(indent, out);
+       if (decl->type == NULL)
+               safe_fprintf(out, "var");
+       else
+               type_print(decl->type, out);
+       safe_fprintf(out, " %s = ", decl->ident);
+       expr_print(decl->expr, out);
+       safe_fprintf(out, ";\n");
+}
+
+void fundecl_print(struct fundecl *decl, FILE *out)
+{
+       safe_fprintf(out, "%s (", decl->ident);
+       for (int i = 0; i<decl->nargs; i++) {
+               safe_fprintf(out, "%s", decl->args[i]);
+               if (i < decl->nargs - 1)
+                       safe_fprintf(out, ", ");
+       }
+       safe_fprintf(out, ")");
+       if (decl->rtype != NULL) {
+               safe_fprintf(out, " :: ");
+               for (int i = 0; i<decl->natypes; i++) {
+                       type_print(decl->atypes[i], out);
+                       safe_fprintf(out, " ");
                }
-               r = r->data.an_cons.tail;
+               safe_fprintf(out, "-> ");
+               type_print(decl->rtype, out);
        }
-       return i;
-}
-
-const char *cescapes[] = {
-       [0] = "\\0", [1] = "\\x01", [2] = "\\x02", [3] = "\\x03", 
-       [4] = "\\x04", [5] = "\\x05", [6] = "\\x06", [7] = "\\a", [8] = "\\b",
-       [9] = "\\t", [10] = "\\n", [11] = "\\v", [12] = "\\f", [13] = "\\r",
-       [14] = "\\x0E", [15] = "\\x0F", [16] = "\\x10", [17] = "\\x11",
-       [18] = "\\x12", [19] = "\\x13", [20] = "\\x14", [21] = "\\x15",
-       [22] = "\\x16", [23] = "\\x17", [24] = "\\x18", [25] = "\\x19",
-       [26] = "\\x1A", [27] = "\\x1B", [28] = "\\x1C", [29] = "\\x1D",
-       [30] = "\\x1E", [31] = "\\x1F",
-       [127] = "\\x7F"
-};
+       safe_fprintf(out, " {\n");
+       for (int i = 0; i<decl->nbody; i++)
+               stmt_print(decl->body[i], 1, out);
+       safe_fprintf(out, "}\n");
+}
 
-void ast_print(struct ast *ast, int indent, FILE *out)
+void decl_print(struct decl *decl, FILE *out)
 {
-       if (ast == NULL)
+       if (decl == NULL)
                return;
-#ifdef DEBUG
-       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
-#endif
-       switch(ast->type) {
-       case an_assign:
-               pindent(indent, out);
-               ast_print(ast->data.an_assign.ident, indent, out);
-               safe_fprintf(out, " = ");
-               ast_print(ast->data.an_assign.expr, indent, out);
-               safe_fprintf(out, ";\n");
+       switch(decl->type) {
+       case dfundecl:
+               fundecl_print(decl->data.dfun, out);
                break;
-       case an_binop:
-               safe_fprintf(out, "(");
-               ast_print(ast->data.an_binop.l, indent, out);
-               safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
-               ast_print(ast->data.an_binop.r, indent, out);
-               safe_fprintf(out, ")");
+       case dvardecl:
+               vardecl_print(decl->data.dvar, 0, out);
                break;
-       case an_bool:
-               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
-               break;
-       case an_char:
-               if (ast->data.an_char < 0)
-                       safe_fprintf(out, "'?'");
-               if (ast->data.an_char < ' ' || ast->data.an_char == 127)
-                       safe_fprintf(out, "'%s'",
-                               cescapes[(int)ast->data.an_char]);
-               else
-                       safe_fprintf(out, "'%c'", ast->data.an_char);
-               break;
-       case an_funcall:
-               safe_fprintf(out, "%s(", ast->data.an_funcall.ident);
-               for(int i = 0; i<ast->data.an_fundecl.nargs; i++) {
-                       ast_print(ast->data.an_funcall.args[i], indent, out);
-                       if (i+1 < ast->data.an_fundecl.nargs)
-                               safe_fprintf(out, ", ");
-               }
-               safe_fprintf(out, ")");
+       case dcomp:
+               fprintf(out, "//<<<comp\n");
+               for (int i = 0; i<decl->data.dcomp.ndecls; i++)
+                       fundecl_print(decl->data.dcomp.decls[i], out);
+               fprintf(out, "//>>>comp\n");
                break;
-       case an_fundecl:
-               pindent(indent, out);
-               safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
-               for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
-                       safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
-                       if (i < ast->data.an_fundecl.nargs - 1)
-                               safe_fprintf(out, ", ");
-               }
-               safe_fprintf(out, ") {\n");
-               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
-                       ast_print(ast->data.an_fundecl.body[i], indent+1, out);
+       default:
+               die("Unsupported decl node\n");
+       }
+}
+
+void stmt_print(struct stmt *stmt, int indent, FILE *out)
+{
+       if (stmt == NULL)
+               return;
+       switch(stmt->type) {
+       case sassign:
                pindent(indent, out);
-               safe_fprintf(out, "}\n");
+               fprintf(out, "%s", stmt->data.sassign.ident);
+               for (int i = 0; i<stmt->data.sassign.nfields; i++)
+                       fprintf(out, ".%s", stmt->data.sassign.fields[i]);
+               safe_fprintf(out, " = ");
+               expr_print(stmt->data.sassign.expr, out);
+               safe_fprintf(out, ";\n");
                break;
-       case an_if:
+       case sif:
                pindent(indent, out);
                safe_fprintf(out, "if (");
-               ast_print(ast->data.an_if.pred, indent, out);
+               expr_print(stmt->data.sif.pred, out);
                safe_fprintf(out, ") {\n");
-               for (int i = 0; i<ast->data.an_if.nthen; i++)
-                       ast_print(ast->data.an_if.then[i], indent+1, out);
+               for (int i = 0; i<stmt->data.sif.nthen; i++)
+                       stmt_print(stmt->data.sif.then[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "} else {\n");
-               for (int i = 0; i<ast->data.an_if.nels; i++)
-                       ast_print(ast->data.an_if.els[i], indent+1, out);
+               for (int i = 0; i<stmt->data.sif.nels; i++)
+                       stmt_print(stmt->data.sif.els[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
-       case an_int:
-               safe_fprintf(out, "%d", ast->data.an_int);
-               break;
-       case an_ident:
-               fprintf(out, "%s", ast->data.an_ident.ident);
-               for (int i = 0; i<ast->data.an_ident.nfields; i++)
-                       fprintf(out, ".%s",
-                               fieldspec_str[ast->data.an_ident.fields[i]]);
-               break;
-       case an_cons:
-               ast_print(ast->data.an_cons.el, indent, out);
-               ast_print(ast->data.an_cons.tail, indent, out);
-               break;
-       case an_list:
-               for (int i = 0; i<ast->data.an_list.n; i++)
-                       ast_print(ast->data.an_list.ptr[i], indent, out);
-               break;
-       case an_nil:
-               safe_fprintf(out, "[]");
-               break;
-       case an_return:
+       case sreturn:
                pindent(indent, out);
                safe_fprintf(out, "return ");
-               ast_print(ast->data.an_return, indent, out);
+               expr_print(stmt->data.sreturn, out);
                safe_fprintf(out, ";\n");
                break;
-       case an_stmt_expr:
+       case sexpr:
                pindent(indent, out);
-               ast_print(ast->data.an_stmt_expr, indent, out);
+               expr_print(stmt->data.sexpr, out);
                safe_fprintf(out, ";\n");
                break;
-       case an_unop:
-               safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
-               ast_print(ast->data.an_unop.l, indent, out);
-               safe_fprintf(out, ")");
+       case svardecl:
+               vardecl_print(stmt->data.svardecl, indent, out);
                break;
-       case an_vardecl:
-               pindent(indent, out);
-               safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
-               ast_print(ast->data.an_vardecl.l, indent, out);
-               safe_fprintf(out, ";\n");
-               break;
-       case an_while:
+       case swhile:
                pindent(indent, out);
                safe_fprintf(out, "while (");
-               ast_print(ast->data.an_while.pred, indent, out);
+               expr_print(stmt->data.swhile.pred, out);
                safe_fprintf(out, ") {\n");
-               for (int i = 0; i<ast->data.an_while.nbody; i++) {
-                       ast_print(ast->data.an_while.body[i], indent+1, out);
-               }
+               for (int i = 0; i<stmt->data.swhile.nbody; i++)
+                       stmt_print(stmt->data.swhile.body[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
        default:
-               die("Unsupported AST node\n");
+               die("Unsupported stmt node\n");
+       }
+}
+
+void expr_print(struct expr *expr, FILE *out)
+{
+       if (expr == NULL)
+               return;
+       char buf[] = "\\xff";
+       switch(expr->type) {
+       case ebinop:
+               safe_fprintf(out, "(");
+               expr_print(expr->data.ebinop.l, out);
+               safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
+               expr_print(expr->data.ebinop.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case ebool:
+               safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
+               break;
+       case echar:
+               safe_fprintf(out, "'%s'",
+                       escape_char(expr->data.echar, buf, false));
+               break;
+       case efuncall:
+               safe_fprintf(out, "%s(", expr->data.efuncall.ident);
+               for(int i = 0; i<expr->data.efuncall.nargs; i++) {
+                       expr_print(expr->data.efuncall.args[i], out);
+                       if (i+1 < expr->data.efuncall.nargs)
+                               safe_fprintf(out, ", ");
+               }
+               safe_fprintf(out, ")");
+               for (int i = 0; i<expr->data.efuncall.nfields; i++)
+                       fprintf(out, ".%s",
+                               fieldspec_str[expr->data.efuncall.fields[i]]);
+               break;
+       case eint:
+               safe_fprintf(out, "%d", expr->data.eint);
+               break;
+       case eident:
+               fprintf(out, "%s", expr->data.eident.ident);
+               for (int i = 0; i<expr->data.eident.nfields; i++)
+                       fprintf(out, ".%s",
+                               fieldspec_str[expr->data.eident.fields[i]]);
+               break;
+       case enil:
+               safe_fprintf(out, "[]");
+               break;
+       case etuple:
+               safe_fprintf(out, "(");
+               expr_print(expr->data.etuple.left, out);
+               safe_fprintf(out, ", ");
+               expr_print(expr->data.etuple.right, out);
+               safe_fprintf(out, ")");
+               break;
+       case estring:
+               safe_fprintf(out, "\"");
+               for (int i = 0; i<expr->data.estring.nchars; i++)
+                       safe_fprintf(out, "%s", escape_char(
+                               expr->data.estring.chars[i], buf, true));
+               safe_fprintf(out, "\"");
+               break;
+       case eunop:
+               safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
+               expr_print(expr->data.eunop.l, out);
+               safe_fprintf(out, ")");
+               break;
+       default:
+               die("Unsupported expr node\n");
        }
 }
 
@@ -468,84 +429,140 @@ void ast_free(struct ast *ast)
 {
        if (ast == NULL)
                return;
-#ifdef DEBUG
-       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
-#endif
-       switch(ast->type) {
-       case an_assign:
-               ast_free(ast->data.an_assign.ident);
-               ast_free(ast->data.an_assign.expr);
-               break;
-       case an_binop:
-               ast_free(ast->data.an_binop.l);
-               ast_free(ast->data.an_binop.r);
-               break;
-       case an_bool:
-               break;
-       case an_char:
-               break;
-       case an_cons:
-               ast_free(ast->data.an_cons.el);
-               ast_free(ast->data.an_cons.tail);
-               break;
-       case an_funcall:
-               free(ast->data.an_funcall.ident);
-               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
-                       ast_free(ast->data.an_funcall.args[i]);
-               free(ast->data.an_funcall.args);
-               break;
-       case an_fundecl:
-               free(ast->data.an_fundecl.ident);
-               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
-                       free(ast->data.an_fundecl.args[i]);
-               free(ast->data.an_fundecl.args);
-               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
-                       ast_free(ast->data.an_fundecl.body[i]);
-               free(ast->data.an_fundecl.body);
-               break;
-       case an_if:
-               ast_free(ast->data.an_if.pred);
-               for (int i = 0; i<ast->data.an_if.nthen; i++)
-                       ast_free(ast->data.an_if.then[i]);
-               free(ast->data.an_if.then);
-               for (int i = 0; i<ast->data.an_if.nels; i++)
-                       ast_free(ast->data.an_if.els[i]);
-               free(ast->data.an_if.els);
-               break;
-       case an_int:
-               break;
-       case an_ident:
-               free(ast->data.an_ident.ident);
-               free(ast->data.an_ident.fields);
-               break;
-       case an_list:
-               for (int i = 0; i<ast->data.an_list.n; i++)
-                       ast_free(ast->data.an_list.ptr[i]);
-               free(ast->data.an_list.ptr);
-               break;
-       case an_nil:
-               break;
-       case an_return:
-               ast_free(ast->data.an_return);
-               break;
-       case an_stmt_expr:
-               ast_free(ast->data.an_stmt_expr);
-               break;
-       case an_unop:
-               ast_free(ast->data.an_unop.l);
-               break;
-       case an_vardecl:
-               free(ast->data.an_vardecl.ident);
-               ast_free(ast->data.an_vardecl.l);
-               break;
-       case an_while:
-               ast_free(ast->data.an_while.pred);
-               for (int i = 0; i<ast->data.an_while.nbody; i++)
-                       ast_free(ast->data.an_while.body[i]);
-               free(ast->data.an_while.body);
+       for (int i = 0; i<ast->ndecls; i++)
+               decl_free(ast->decls[i]);
+       free(ast->decls);
+       free(ast);
+}
+
+void vardecl_free(struct vardecl *decl)
+{
+       type_free(decl->type);
+       free(decl->ident);
+       expr_free(decl->expr);
+       free(decl);
+}
+
+void fundecl_free(struct fundecl *decl)
+{
+       free(decl->ident);
+       for (int i = 0; i<decl->nargs; i++)
+               free(decl->args[i]);
+       free(decl->args);
+       for (int i = 0; i<decl->natypes; i++)
+               type_free(decl->atypes[i]);
+       free(decl->atypes);
+       type_free(decl->rtype);
+       for (int i = 0; i<decl->nbody; i++)
+               stmt_free(decl->body[i]);
+       free(decl->body);
+       free(decl);
+}
+
+void decl_free(struct decl *decl)
+{
+       if (decl == NULL)
+               return;
+       switch(decl->type) {
+       case dcomp:
+               for (int i = 0; i<decl->data.dcomp.ndecls; i++)
+                       fundecl_free(decl->data.dcomp.decls[i]);
+               free(decl->data.dcomp.decls);
+               break;
+       case dfundecl:
+               fundecl_free(decl->data.dfun);
+               break;
+       case dvardecl:
+               vardecl_free(decl->data.dvar);
                break;
        default:
-               die("Unsupported AST node: %d\n", ast->type);
+               die("Unsupported decl node\n");
        }
-       free(ast);
+       free(decl);
+}
+
+void stmt_free(struct stmt *stmt)
+{
+       if (stmt == NULL)
+               return;
+       switch(stmt->type) {
+       case sassign:
+               free(stmt->data.sassign.ident);
+               for (int i = 0; i<stmt->data.sassign.nfields; i++)
+                       free(stmt->data.sassign.fields[i]);
+               free(stmt->data.sassign.fields);
+               expr_free(stmt->data.sassign.expr);
+               break;
+       case sif:
+               expr_free(stmt->data.sif.pred);
+               for (int i = 0; i<stmt->data.sif.nthen; i++)
+                       stmt_free(stmt->data.sif.then[i]);
+               free(stmt->data.sif.then);
+               for (int i = 0; i<stmt->data.sif.nels; i++)
+                       stmt_free(stmt->data.sif.els[i]);
+               free(stmt->data.sif.els);
+               break;
+       case sreturn:
+               expr_free(stmt->data.sreturn);
+               break;
+       case sexpr:
+               expr_free(stmt->data.sexpr);
+               break;
+       case swhile:
+               expr_free(stmt->data.swhile.pred);
+               for (int i = 0; i<stmt->data.swhile.nbody; i++)
+                       stmt_free(stmt->data.swhile.body[i]);
+               free(stmt->data.swhile.body);
+               break;
+       case svardecl:
+               vardecl_free(stmt->data.svardecl);
+               break;
+       default:
+               die("Unsupported stmt node\n");
+       }
+       free(stmt);
+}
+
+void expr_free(struct expr *expr)
+{
+       if (expr == NULL)
+               return;
+       switch(expr->type) {
+       case ebinop:
+               expr_free(expr->data.ebinop.l);
+               expr_free(expr->data.ebinop.r);
+               break;
+       case ebool:
+               break;
+       case echar:
+               break;
+       case efuncall:
+               free(expr->data.efuncall.ident);
+               for (int i = 0; i<expr->data.efuncall.nargs; i++)
+                       expr_free(expr->data.efuncall.args[i]);
+               free(expr->data.efuncall.fields);
+               free(expr->data.efuncall.args);
+               break;
+       case eint:
+               break;
+       case eident:
+               free(expr->data.eident.ident);
+               free(expr->data.eident.fields);
+               break;
+       case enil:
+               break;
+       case etuple:
+               expr_free(expr->data.etuple.left);
+               expr_free(expr->data.etuple.right);
+               break;
+       case estring:
+               free(expr->data.estring.chars);
+               break;
+       case eunop:
+               expr_free(expr->data.eunop.l);
+               break;
+       default:
+               die("Unsupported expr node\n");
+       }
+       free(expr);
 }