#include "util.h"
 #include "ast.h"
 
+#ifdef DEBUG
 static const char *ast_type_str[] = {
-       [an_bool] = "bool", [an_binop] = "binop", [an_char] = "char",
-       [an_cons] = "cons", [an_fundecl] = "fundecl", [an_ident] = "ident",
-       [an_if] = "if", [an_int] = "int", [an_list] = "list",
+       [an_assign] = "assign", [an_bool] = "bool", [an_binop] = "binop",
+       [an_char] = "char", [an_cons] = "cons", [an_funcall] = "funcall",
+       [an_fundecl] = "fundecl", [an_ident] = "ident", [an_if] = "if",
+       [an_int] = "int", [an_list] = "list", [an_return] = "return",
        [an_stmt_expr] = "stmt_expr", [an_unop] = "unop",
        [an_vardecl] = "vardecl", [an_while] = "while",
 };
+#endif
 static const char *binop_str[] = {
        [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
        [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
 };
 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
 
-
+#ifdef DEBUG
 #define must_be(node, ntype, msg) {\
        if ((node)->type != (ntype)) {\
                fprintf(stderr, "%s can't be %s\n",\
                exit(1);\
        }\
 }
+#else
+#define must_be(node, ntype, msg) ;
+#endif
+
 #define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
 
-struct ast *ast_bool(bool b)
+struct ast *ast_assign(struct ast *ident, struct ast *expr)
 {
        struct ast *res = ast_alloc();
-       res->type = an_bool;
-       res->data.an_bool = b;
+       res->type = an_assign;
+
+       must_be(ident, an_ident, "ident of an assign");
+       res->data.an_assign.ident = ident->data.an_ident;
+       free(ident);
+
+       res->data.an_assign.expr = expr;
        return res;
 }
 
        return res;
 }
 
+struct ast *ast_bool(bool b)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_bool;
+       res->data.an_bool = b;
+       return res;
+}
+
 int fromHex(char c)
 {
        if (c >= '0' && c <= '9')
        return res;
 }
 
+struct ast *ast_funcall(struct ast *ident, struct ast *args)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_funcall;
+
+       //ident
+       must_be(ident, an_ident, "ident of a funcall");
+       res->data.an_funcall.ident = ident->data.an_ident;
+       free(ident);
+
+       //args
+       must_be(args, an_list, "args of a funcall");
+       res->data.an_funcall.nargs = args->data.an_list.n;
+       res->data.an_funcall.args = args->data.an_list.ptr;
+       free(args);
+       return res;
+}
+
 struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
 {
        struct ast *res = ast_alloc();
 
        //body
        must_be(body, an_list, "body of a fundecl");
-       res->data.an_fundecl.body = body;
+       res->data.an_fundecl.nbody = body->data.an_list.n;
+       res->data.an_fundecl.body = body->data.an_list.ptr;
+       free(body);
 
        return res;
 }
        struct ast *res = ast_alloc();
        res->type = an_if;
        res->data.an_if.pred = pred;
-       res->data.an_if.then = then;
-       res->data.an_if.els = els;
+
+       must_be(body, an_list, "body of a then");
+       res->data.an_if.nthen = then->data.an_list.n;
+       res->data.an_if.then = then->data.an_list.ptr;
+       free(then);
+
+       must_be(body, an_list, "body of a els");
+       res->data.an_if.nels = els->data.an_list.n;
+       res->data.an_if.els = els->data.an_list.ptr;
+       free(els);
+
        return res;
 }
 
        return res;
 }
 
+struct ast *ast_nil()
+{
+       struct ast *res = ast_alloc();
+       res->type = an_nil;
+       return res;
+}
+
+struct ast *ast_return(struct ast *r)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_return;
+       res->data.an_return = r;
+       return res;
+}
+
 struct ast *ast_stmt_expr(struct ast *expr)
 {
        struct ast *res = ast_alloc();
        struct ast *res = ast_alloc();
        res->type = an_while;
        res->data.an_while.pred = pred;
-       res->data.an_while.body = body;
+       must_be(body, an_list, "body of a while");
+       res->data.an_while.nbody = body->data.an_list.n;
+       res->data.an_while.body = body->data.an_list.ptr;
+       free(body);
        return res;
 }
 
        [127] = "\\x7F"
 };
 
-void ast_print(struct ast * ast, int indent, FILE *out)
+void ast_print(struct ast *ast, int indent, FILE *out)
 {
        if (ast == NULL)
                return;
+#ifdef DEBUG
+       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
+#endif
        switch(ast->type) {
-       case an_bool:
-               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
+       case an_assign:
+               pindent(indent, out);
+               safe_fprintf(out, "%s = ", ast->data.an_assign.ident);
+               ast_print(ast->data.an_assign.expr, indent, out);
+               safe_fprintf(out, ";\n");
                break;
        case an_binop:
                safe_fprintf(out, "(");
                ast_print(ast->data.an_binop.r, indent, out);
                safe_fprintf(out, ")");
                break;
+       case an_bool:
+               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
+               break;
        case an_char:
                if (ast->data.an_char < 0)
                        safe_fprintf(out, "'?'");
                else
                        safe_fprintf(out, "'%c'", ast->data.an_char);
                break;
+       case an_funcall:
+               safe_fprintf(out, "%s(", ast->data.an_funcall.ident);
+               for(int i = 0; i<ast->data.an_fundecl.nargs; i++) {
+                       ast_print(ast->data.an_funcall.args[i], indent, out);
+                       if (i+1 < ast->data.an_fundecl.nargs)
+                               safe_fprintf(out, ", ");
+               }
+               safe_fprintf(out, ")");
+               break;
        case an_fundecl:
                pindent(indent, out);
                safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
                                safe_fprintf(out, ", ");
                }
                safe_fprintf(out, ") {\n");
-               ast_print(ast->data.an_fundecl.body, indent+1, out);
+               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
+                       ast_print(ast->data.an_fundecl.body[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
                safe_fprintf(out, "if (");
                ast_print(ast->data.an_if.pred, indent, out);
                safe_fprintf(out, ") {\n");
-               if (ast->data.an_if.then->data.an_list.n > 0) {
-                       pindent(indent, out);
-                       ast_print(ast->data.an_if.then, indent+1, out);
-               }
+               for (int i = 0; i<ast->data.an_if.nthen; i++)
+                       ast_print(ast->data.an_if.then[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "} else {\n");
-               if (ast->data.an_if.els->data.an_list.n > 0) {
-                       pindent(indent, out);
-                       ast_print(ast->data.an_if.els, indent+1, out);
-               }
+               for (int i = 0; i<ast->data.an_if.nels; i++)
+                       ast_print(ast->data.an_if.els[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
                for (int i = 0; i<ast->data.an_list.n; i++)
                        ast_print(ast->data.an_list.ptr[i], indent, out);
                break;
+       case an_nil:
+               safe_fprintf(out, "[]");
+               break;
+       case an_return:
+               pindent(indent, out);
+               safe_fprintf(out, "return ");
+               ast_print(ast->data.an_return, indent, out);
+               safe_fprintf(out, ";\n");
+               break;
        case an_stmt_expr:
                pindent(indent, out);
                ast_print(ast->data.an_stmt_expr, indent, out);
                safe_fprintf(out, "while (");
                ast_print(ast->data.an_while.pred, indent, out);
                safe_fprintf(out, ") {\n");
-               if (ast->data.an_while.body->data.an_list.n > 0) {
-                       pindent(indent, out);
-                       ast_print(ast->data.an_while.body, indent+1, out);
+               for (int i = 0; i<ast->data.an_while.nbody; i++) {
+                       ast_print(ast->data.an_while.body[i], indent+1, out);
                }
                pindent(indent, out);
                safe_fprintf(out, "}\n");
 {
        if (ast == NULL)
                return;
+#ifdef DEBUG
+       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
+#endif
        switch(ast->type) {
-       case an_bool:
+       case an_assign:
+               free(ast->data.an_assign.ident);
+               ast_free(ast->data.an_assign.expr);
                break;
        case an_binop:
                ast_free(ast->data.an_binop.l);
                ast_free(ast->data.an_binop.r);
                break;
+       case an_bool:
+               break;
        case an_char:
                break;
        case an_cons:
                ast_free(ast->data.an_cons.el);
                ast_free(ast->data.an_cons.tail);
                break;
+       case an_funcall:
+               free(ast->data.an_funcall.ident);
+               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
+                       free(ast->data.an_fundecl.args[i]);
+               free(ast->data.an_funcall.args);
+               break;
        case an_fundecl:
                free(ast->data.an_fundecl.ident);
                for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
                        free(ast->data.an_fundecl.args[i]);
                free(ast->data.an_fundecl.args);
-               ast_free(ast->data.an_fundecl.body);
+               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
+                       ast_free(ast->data.an_fundecl.body[i]);
+               free(ast->data.an_fundecl.body);
                break;
        case an_if:
                ast_free(ast->data.an_if.pred);
-               ast_free(ast->data.an_if.then);
-               ast_free(ast->data.an_if.els);
+               for (int i = 0; i<ast->data.an_if.nthen; i++)
+                       ast_free(ast->data.an_if.then[i]);
+               free(ast->data.an_if.then);
+               for (int i = 0; i<ast->data.an_if.nels; i++)
+                       ast_free(ast->data.an_if.els[i]);
+               free(ast->data.an_if.els);
+               break;
        case an_int:
                break;
        case an_ident:
                        ast_free(ast->data.an_list.ptr[i]);
                free(ast->data.an_list.ptr);
                break;
+       case an_nil:
+               break;
+       case an_return:
+               ast_free(ast->data.an_return);
+               break;
        case an_stmt_expr:
                ast_free(ast->data.an_stmt_expr);
                break;
                break;
        case an_while:
                ast_free(ast->data.an_while.pred);
-               ast_free(ast->data.an_while.body);
+               for (int i = 0; i<ast->data.an_while.nbody; i++)
+                       ast_free(ast->data.an_while.body[i]);
+               free(ast->data.an_while.body);
                break;
        default:
                die("Unsupported AST node: %d\n", ast->type);
 
 };
 enum unop {negate,inverse};
 enum ast_type {
-       an_binop, an_bool, an_char, an_cons, an_fundecl, an_ident, an_if,
-       an_int, an_list, an_stmt_expr, an_unop, an_vardecl, an_while
+       an_assign, an_binop, an_bool, an_char, an_cons, an_funcall, an_fundecl,
+       an_ident, an_if, an_int, an_list, an_nil, an_return, an_stmt_expr,
+       an_unop, an_vardecl, an_while
 };
 struct ast {
        enum ast_type type;
        union {
+               struct {
+                       char *ident;
+                       struct ast *expr;
+               } an_assign;
                bool an_bool;
                struct {
                        struct ast *l;
                        struct ast *el;
                        struct ast *tail;
                } an_cons;
+               struct {
+                       char *ident;
+                       int nargs;
+                       struct ast **args;
+               } an_funcall;
                struct {
                        char *ident;
                        int nargs;
                        char **args;
-                       struct ast *body; // make struct ast **?
+                       int nbody;
+                       struct ast **body;
                } an_fundecl;
                struct {
                        struct ast *pred;
-                       struct ast *then;
-                       struct ast *els;
+                       int nthen;
+                       struct ast **then;
+                       int nels;
+                       struct ast **els;
                } an_if;
                int an_int;
                char *an_ident;
                        int n;
                        struct ast **ptr;
                } an_list;
+               //struct { } an_nil;
+               struct ast *an_return;
                struct ast *an_stmt_expr;
                struct {
                        enum unop op;
                } an_vardecl;
                struct {
                        struct ast *pred;
-                       struct ast *body;
+                       int nbody;
+                       struct ast **body;
                } an_while;
        } data;
 };
 
-struct ast *ast_bool(bool b);
+struct ast *ast_assign(struct ast *ident, struct ast *expr);
 struct ast *ast_binop(struct ast *l, enum binop op, struct ast *tail);
+struct ast *ast_bool(bool b);
 struct ast *ast_char(const char *c);
 struct ast *ast_cons(struct ast *el, struct ast *tail);
+struct ast *ast_funcall(struct ast *ident, struct ast *args);
 struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body);
 struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els);
 struct ast *ast_int(int integer);
 struct ast *ast_ident(char *ident);
 struct ast *ast_list(struct ast *llist);
+struct ast *ast_nil();
+struct ast *ast_return(struct ast *rtrn);
 struct ast *ast_stmt_expr(struct ast *expr);
 struct ast *ast_unop(enum unop op, struct ast *l);
 struct ast *ast_vardecl(struct ast *ident, struct ast *l);