add assign, nil
[ccc.git] / ast.c
diff --git a/ast.c b/ast.c
index a757079..23619ec 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -5,13 +5,16 @@
 #include "util.h"
 #include "ast.h"
 
+#ifdef DEBUG
 static const char *ast_type_str[] = {
-       [an_bool] = "bool", [an_binop] = "binop", [an_char] = "char",
-       [an_cons] = "cons", [an_fundecl] = "fundecl", [an_ident] = "ident",
-       [an_if] = "if", [an_int] = "int", [an_list] = "list",
+       [an_assign] = "assign", [an_bool] = "bool", [an_binop] = "binop",
+       [an_char] = "char", [an_cons] = "cons", [an_funcall] = "funcall",
+       [an_fundecl] = "fundecl", [an_ident] = "ident", [an_if] = "if",
+       [an_int] = "int", [an_list] = "list", [an_return] = "return",
        [an_stmt_expr] = "stmt_expr", [an_unop] = "unop",
        [an_vardecl] = "vardecl", [an_while] = "while",
 };
+#endif
 static const char *binop_str[] = {
        [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
        [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
@@ -20,7 +23,7 @@ static const char *binop_str[] = {
 };
 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
 
-
+#ifdef DEBUG
 #define must_be(node, ntype, msg) {\
        if ((node)->type != (ntype)) {\
                fprintf(stderr, "%s can't be %s\n",\
@@ -28,13 +31,22 @@ static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
                exit(1);\
        }\
 }
+#else
+#define must_be(node, ntype, msg) ;
+#endif
+
 #define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
 
-struct ast *ast_bool(bool b)
+struct ast *ast_assign(struct ast *ident, struct ast *expr)
 {
        struct ast *res = ast_alloc();
-       res->type = an_bool;
-       res->data.an_bool = b;
+       res->type = an_assign;
+
+       must_be(ident, an_ident, "ident of an assign");
+       res->data.an_assign.ident = ident->data.an_ident;
+       free(ident);
+
+       res->data.an_assign.expr = expr;
        return res;
 }
 
@@ -48,6 +60,14 @@ struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
        return res;
 }
 
+struct ast *ast_bool(bool b)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_bool;
+       res->data.an_bool = b;
+       return res;
+}
+
 int fromHex(char c)
 {
        if (c >= '0' && c <= '9')
@@ -92,6 +112,24 @@ struct ast *ast_cons(struct ast *el, struct ast *tail)
        return res;
 }
 
+struct ast *ast_funcall(struct ast *ident, struct ast *args)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_funcall;
+
+       //ident
+       must_be(ident, an_ident, "ident of a funcall");
+       res->data.an_funcall.ident = ident->data.an_ident;
+       free(ident);
+
+       //args
+       must_be(args, an_list, "args of a funcall");
+       res->data.an_funcall.nargs = args->data.an_list.n;
+       res->data.an_funcall.args = args->data.an_list.ptr;
+       free(args);
+       return res;
+}
+
 struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
 {
        struct ast *res = ast_alloc();
@@ -117,7 +155,9 @@ struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
 
        //body
        must_be(body, an_list, "body of a fundecl");
-       res->data.an_fundecl.body = body;
+       res->data.an_fundecl.nbody = body->data.an_list.n;
+       res->data.an_fundecl.body = body->data.an_list.ptr;
+       free(body);
 
        return res;
 }
@@ -127,8 +167,17 @@ struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
        struct ast *res = ast_alloc();
        res->type = an_if;
        res->data.an_if.pred = pred;
-       res->data.an_if.then = then;
-       res->data.an_if.els = els;
+
+       must_be(body, an_list, "body of a then");
+       res->data.an_if.nthen = then->data.an_list.n;
+       res->data.an_if.then = then->data.an_list.ptr;
+       free(then);
+
+       must_be(body, an_list, "body of a els");
+       res->data.an_if.nels = els->data.an_list.n;
+       res->data.an_if.els = els->data.an_list.ptr;
+       free(els);
+
        return res;
 }
 
@@ -171,6 +220,21 @@ struct ast *ast_list(struct ast *llist)
        return res;
 }
 
+struct ast *ast_nil()
+{
+       struct ast *res = ast_alloc();
+       res->type = an_nil;
+       return res;
+}
+
+struct ast *ast_return(struct ast *r)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_return;
+       res->data.an_return = r;
+       return res;
+}
+
 struct ast *ast_stmt_expr(struct ast *expr)
 {
        struct ast *res = ast_alloc();
@@ -205,7 +269,10 @@ struct ast *ast_while(struct ast *pred, struct ast *body)
        struct ast *res = ast_alloc();
        res->type = an_while;
        res->data.an_while.pred = pred;
-       res->data.an_while.body = body;
+       must_be(body, an_list, "body of a while");
+       res->data.an_while.nbody = body->data.an_list.n;
+       res->data.an_while.body = body->data.an_list.ptr;
+       free(body);
        return res;
 }
 
@@ -234,13 +301,19 @@ const char *cescapes[] = {
        [127] = "\\x7F"
 };
 
-void ast_print(struct ast * ast, int indent, FILE *out)
+void ast_print(struct ast *ast, int indent, FILE *out)
 {
        if (ast == NULL)
                return;
+#ifdef DEBUG
+       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
+#endif
        switch(ast->type) {
-       case an_bool:
-               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
+       case an_assign:
+               pindent(indent, out);
+               safe_fprintf(out, "%s = ", ast->data.an_assign.ident);
+               ast_print(ast->data.an_assign.expr, indent, out);
+               safe_fprintf(out, ";\n");
                break;
        case an_binop:
                safe_fprintf(out, "(");
@@ -249,6 +322,9 @@ void ast_print(struct ast * ast, int indent, FILE *out)
                ast_print(ast->data.an_binop.r, indent, out);
                safe_fprintf(out, ")");
                break;
+       case an_bool:
+               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
+               break;
        case an_char:
                if (ast->data.an_char < 0)
                        safe_fprintf(out, "'?'");
@@ -258,6 +334,15 @@ void ast_print(struct ast * ast, int indent, FILE *out)
                else
                        safe_fprintf(out, "'%c'", ast->data.an_char);
                break;
+       case an_funcall:
+               safe_fprintf(out, "%s(", ast->data.an_funcall.ident);
+               for(int i = 0; i<ast->data.an_fundecl.nargs; i++) {
+                       ast_print(ast->data.an_funcall.args[i], indent, out);
+                       if (i+1 < ast->data.an_fundecl.nargs)
+                               safe_fprintf(out, ", ");
+               }
+               safe_fprintf(out, ")");
+               break;
        case an_fundecl:
                pindent(indent, out);
                safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
@@ -267,7 +352,8 @@ void ast_print(struct ast * ast, int indent, FILE *out)
                                safe_fprintf(out, ", ");
                }
                safe_fprintf(out, ") {\n");
-               ast_print(ast->data.an_fundecl.body, indent+1, out);
+               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
+                       ast_print(ast->data.an_fundecl.body[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
@@ -276,16 +362,12 @@ void ast_print(struct ast * ast, int indent, FILE *out)
                safe_fprintf(out, "if (");
                ast_print(ast->data.an_if.pred, indent, out);
                safe_fprintf(out, ") {\n");
-               if (ast->data.an_if.then->data.an_list.n > 0) {
-                       pindent(indent, out);
-                       ast_print(ast->data.an_if.then, indent+1, out);
-               }
+               for (int i = 0; i<ast->data.an_if.nthen; i++)
+                       ast_print(ast->data.an_if.then[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "} else {\n");
-               if (ast->data.an_if.els->data.an_list.n > 0) {
-                       pindent(indent, out);
-                       ast_print(ast->data.an_if.els, indent+1, out);
-               }
+               for (int i = 0; i<ast->data.an_if.nels; i++)
+                       ast_print(ast->data.an_if.els[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
@@ -303,6 +385,15 @@ void ast_print(struct ast * ast, int indent, FILE *out)
                for (int i = 0; i<ast->data.an_list.n; i++)
                        ast_print(ast->data.an_list.ptr[i], indent, out);
                break;
+       case an_nil:
+               safe_fprintf(out, "[]");
+               break;
+       case an_return:
+               pindent(indent, out);
+               safe_fprintf(out, "return ");
+               ast_print(ast->data.an_return, indent, out);
+               safe_fprintf(out, ";\n");
+               break;
        case an_stmt_expr:
                pindent(indent, out);
                ast_print(ast->data.an_stmt_expr, indent, out);
@@ -324,9 +415,8 @@ void ast_print(struct ast * ast, int indent, FILE *out)
                safe_fprintf(out, "while (");
                ast_print(ast->data.an_while.pred, indent, out);
                safe_fprintf(out, ") {\n");
-               if (ast->data.an_while.body->data.an_list.n > 0) {
-                       pindent(indent, out);
-                       ast_print(ast->data.an_while.body, indent+1, out);
+               for (int i = 0; i<ast->data.an_while.nbody; i++) {
+                       ast_print(ast->data.an_while.body[i], indent+1, out);
                }
                pindent(indent, out);
                safe_fprintf(out, "}\n");
@@ -340,30 +430,50 @@ void ast_free(struct ast *ast)
 {
        if (ast == NULL)
                return;
+#ifdef DEBUG
+       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
+#endif
        switch(ast->type) {
-       case an_bool:
+       case an_assign:
+               free(ast->data.an_assign.ident);
+               ast_free(ast->data.an_assign.expr);
                break;
        case an_binop:
                ast_free(ast->data.an_binop.l);
                ast_free(ast->data.an_binop.r);
                break;
+       case an_bool:
+               break;
        case an_char:
                break;
        case an_cons:
                ast_free(ast->data.an_cons.el);
                ast_free(ast->data.an_cons.tail);
                break;
+       case an_funcall:
+               free(ast->data.an_funcall.ident);
+               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
+                       free(ast->data.an_fundecl.args[i]);
+               free(ast->data.an_funcall.args);
+               break;
        case an_fundecl:
                free(ast->data.an_fundecl.ident);
                for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
                        free(ast->data.an_fundecl.args[i]);
                free(ast->data.an_fundecl.args);
-               ast_free(ast->data.an_fundecl.body);
+               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
+                       ast_free(ast->data.an_fundecl.body[i]);
+               free(ast->data.an_fundecl.body);
                break;
        case an_if:
                ast_free(ast->data.an_if.pred);
-               ast_free(ast->data.an_if.then);
-               ast_free(ast->data.an_if.els);
+               for (int i = 0; i<ast->data.an_if.nthen; i++)
+                       ast_free(ast->data.an_if.then[i]);
+               free(ast->data.an_if.then);
+               for (int i = 0; i<ast->data.an_if.nels; i++)
+                       ast_free(ast->data.an_if.els[i]);
+               free(ast->data.an_if.els);
+               break;
        case an_int:
                break;
        case an_ident:
@@ -374,6 +484,11 @@ void ast_free(struct ast *ast)
                        ast_free(ast->data.an_list.ptr[i]);
                free(ast->data.an_list.ptr);
                break;
+       case an_nil:
+               break;
+       case an_return:
+               ast_free(ast->data.an_return);
+               break;
        case an_stmt_expr:
                ast_free(ast->data.an_stmt_expr);
                break;
@@ -386,7 +501,9 @@ void ast_free(struct ast *ast)
                break;
        case an_while:
                ast_free(ast->data.an_while.pred);
-               ast_free(ast->data.an_while.body);
+               for (int i = 0; i<ast->data.an_while.nbody; i++)
+                       ast_free(ast->data.an_while.body[i]);
+               free(ast->data.an_while.body);
                break;
        default:
                die("Unsupported AST node: %d\n", ast->type);