rewrite to union type, much better
authorMart Lubbers <mart@martlubbers.net>
Mon, 8 Feb 2021 09:59:28 +0000 (10:59 +0100)
committerMart Lubbers <mart@martlubbers.net>
Mon, 8 Feb 2021 09:59:28 +0000 (10:59 +0100)
Makefile
ast.c
ast.h
expr.c
parse.y
scan.l
util.c
util.h

index b605a85..fec3916 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,5 +1,5 @@
 CFLAGS:=-Wall -Wextra -Werror -std=c99 -pedantic-errors -D_XOPEN_SOURCE=700 -ggdb
-YFLAGS:=-d
+YFLAGS:=-d -Wcounterexamples
 
 OBJECTS:=scan.o parse.o ast.o util.o
 
diff --git a/ast.c b/ast.c
index 95ef8fd..f7748f6 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -5,17 +5,6 @@
 #include "util.h"
 #include "ast.h"
 
-#ifdef DEBUG
-static const char *ast_type_str[] = {
-       [an_assign] = "assign", [an_bool] = "bool", [an_binop] = "binop",
-       [an_char] = "char", [an_cons] = "cons", [an_funcall] = "funcall",
-       [an_fundecl] = "fundecl", [an_ident] = "ident", [an_if] = "if",
-       [an_int] = "int", [an_nil] = "nil", [an_list] = "list",
-       [an_return] = "return", [an_stmt_expr] = "stmt_expr",
-       [an_tuple] = "tuple", [an_unop] = "unop", [an_vardecl] = "vardecl",
-       [an_while] = "while",
-};
-#endif
 static const char *binop_str[] = {
        [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
        [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
@@ -26,312 +15,211 @@ static const char *fieldspec_str[] = {
        [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
 
-#ifdef DEBUG
-#define must_be(node, ntype, msg) {\
-       if ((node)->type != (ntype)) {\
-               fprintf(stderr, "%s can't be %s\n",\
-                       msg, ast_type_str[node->type]);\
-               exit(1);\
-       }\
-}
-#else
-#define must_be(node, ntype, msg) ;
-#endif
-
-#define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
-
-struct ast *ast_assign(struct ast *ident, struct ast *expr)
+struct ast *ast(struct list *decls)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_assign;
-       res->data.an_assign.ident = ident;
-       res->data.an_assign.expr = expr;
+       struct ast *res = safe_malloc(sizeof(struct ast));
+       res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true);
        return res;
 }
 
-struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
+struct vardecl vardecl(char *ident, struct expr *expr)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_binop;
-       res->data.an_binop.l = l;
-       res->data.an_binop.op = op;
-       res->data.an_binop.r = r;
-       return res;
+       return (struct vardecl) {.ident=ident, .expr=expr};
 }
 
-struct ast *ast_bool(bool b)
+struct decl *decl_fun(char *ident, struct list *args, struct list *body)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_bool;
-       res->data.an_bool = b;
+       struct decl *res = safe_malloc(sizeof(struct decl));
+       res->type = dfundecl;
+       res->data.dfun.ident = ident;
+       res->data.dfun.args = (char **)
+               list_to_array(args, &res->data.dfun.nargs, true);
+       res->data.dfun.body = (struct stmt **)
+               list_to_array(body, &res->data.dfun.nbody, true);
        return res;
 }
 
-int fromHex(char c)
-{
-       if (c >= '0' && c <= '9')
-               return c-'0';
-       if (c >= 'a' && c <= 'f')
-               return c-'a'+10;
-       if (c >= 'A' && c <= 'F')
-               return c-'A'+10;
-       return -1;
-}
-
-struct ast *ast_char(const char *c)
+struct decl *decl_var(struct vardecl vardecl)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_char;
-       //regular char
-       if (strlen(c) == 3)
-               res->data.an_char = c[1];
-       //escape
-       if (strlen(c) == 4)
-               switch(c[2]) {
-               case '0': res->data.an_char = '\0'; break;
-               case 'a': res->data.an_char = '\a'; break;
-               case 'b': res->data.an_char = '\b'; break;
-               case 't': res->data.an_char = '\t'; break;
-               case 'v': res->data.an_char = '\v'; break;
-               case 'f': res->data.an_char = '\f'; break;
-               case 'r': res->data.an_char = '\r'; break;
-               }
-       //hex escape
-       if (strlen(c) == 6)
-               res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
+       struct decl *res = safe_malloc(sizeof(struct decl));
+       res->type = dvardecl;
+       res->data.dvar = vardecl;
        return res;
 }
 
-struct ast *ast_cons(struct ast *el, struct ast *tail)
+struct stmt *stmt_assign(char *ident, struct expr *expr)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_cons;
-       res->data.an_cons.el = el;
-       res->data.an_cons.tail = tail;
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sassign;
+       res->data.sassign.ident = ident;
+       res->data.sassign.expr = expr;
        return res;
 }
 
-struct ast *ast_funcall(struct ast *ident, struct ast *args)
+struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_funcall;
-
-       //ident
-       must_be(ident, an_ident, "ident of a funcall");
-       res->data.an_funcall.ident = ident->data.an_ident.ident;
-       free(ident->data.an_ident.fields);
-       free(ident);
-
-       //args
-       must_be(args, an_list, "args of a funcall");
-       res->data.an_funcall.nargs = args->data.an_list.n;
-       res->data.an_funcall.args = args->data.an_list.ptr;
-       free(args);
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sif;
+       res->data.sif.pred = pred;
+       res->data.sif.then = (struct stmt **)
+               list_to_array(then, &res->data.sif.nthen, true);
+       res->data.sif.els = (struct stmt **)
+               list_to_array(els, &res->data.sif.nels, true);
        return res;
 }
 
-struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
+struct stmt *stmt_return(struct expr *rtrn)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_fundecl;
-
-       //ident
-       must_be(ident, an_ident, "ident of a fundecl");
-       res->data.an_fundecl.ident = ident->data.an_ident.ident;
-       free(ident->data.an_ident.fields);
-       free(ident);
-
-       //args
-       must_be(args, an_list, "args of a fundecl");
-       res->data.an_fundecl.nargs = args->data.an_list.n;
-       res->data.an_fundecl.args = (char **)args->data.an_list.ptr;
-       for (int i = 0; i<args->data.an_list.n; i++) {
-               struct ast *e = args->data.an_list.ptr[i];
-               must_be(e, an_ident, "arg of a fundecl")
-               res->data.an_fundecl.args[i] = e->data.an_ident.ident;
-               free(e->data.an_ident.fields);
-               free(e);
-       }
-       free(args);
-
-       //body
-       must_be(body, an_list, "body of a fundecl");
-       res->data.an_fundecl.nbody = body->data.an_list.n;
-       res->data.an_fundecl.body = body->data.an_list.ptr;
-       free(body);
-
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sreturn;
+       res->data.sreturn = rtrn;
        return res;
 }
 
-struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
+struct stmt *stmt_expr(struct expr *expr)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_if;
-       res->data.an_if.pred = pred;
-
-       must_be(then, an_list, "body of a then");
-       res->data.an_if.nthen = then->data.an_list.n;
-       res->data.an_if.then = then->data.an_list.ptr;
-       free(then);
-
-       must_be(els, an_list, "body of a els");
-       res->data.an_if.nels = els->data.an_list.n;
-       res->data.an_if.els = els->data.an_list.ptr;
-       free(els);
-
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = sexpr;
+       res->data.sexpr = expr;
        return res;
 }
 
-struct ast *ast_int(int integer)
+struct stmt *stmt_vardecl(struct vardecl vardecl)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_int;
-       res->data.an_int = integer;
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = svardecl;
+       res->data.svardecl = vardecl;
        return res;
 }
 
-struct ast *ast_identc(char *ident)
+struct stmt *stmt_while(struct expr *pred, struct list *body)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_ident;
-       res->data.an_ident.ident = safe_strdup(ident);
-       res->data.an_ident.nfields = 0;
-       res->data.an_ident.fields = NULL;
+       struct stmt *res = safe_malloc(sizeof(struct stmt));
+       res->type = swhile;
+       res->data.swhile.pred = pred;
+       res->data.swhile.body = (struct stmt **)
+               list_to_array(body, &res->data.swhile.nbody, true);
        return res;
 }
 
-struct ast *ast_ident(struct ast *ident, struct ast *fields)
+struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_ident;
-       must_be(fields, an_ident, "ident of an ident");
-       res->data.an_ident.ident = ident->data.an_ident.ident;
-       free(ident);
-
-       must_be(fields, an_list, "fields of an ident");
-       res->data.an_ident.nfields = fields->data.an_list.n;
-       res->data.an_ident.fields = (enum fieldspec *)safe_malloc(
-               fields->data.an_list.n*sizeof(enum fieldspec));
-       for (int i = 0; i<fields->data.an_list.n; i++) {
-               struct ast *t = fields->data.an_list.ptr[i];
-               must_be(t, an_ident, "field of an ident");
-               if (strcmp(t->data.an_ident.ident, "fst") == 0)
-                       res->data.an_ident.fields[i] = fst;
-               else if (strcmp(t->data.an_ident.ident, "snd") == 0)
-                       res->data.an_ident.fields[i] = snd;
-               else if (strcmp(t->data.an_ident.ident, "hd") == 0)
-                       res->data.an_ident.fields[i] = hd;
-               else if (strcmp(t->data.an_ident.ident, "tl") == 0)
-                       res->data.an_ident.fields[i] = tl;
-               free(t->data.an_ident.ident);
-               free(t);
-       }
-       free(fields->data.an_list.ptr);
-       free(fields);
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = ebinop;
+       res->data.ebinop.l = l;
+       res->data.ebinop.op = op;
+       res->data.ebinop.r = r;
        return res;
 }
 
-struct ast *ast_list(struct ast *llist)
+struct expr *expr_bool(bool b)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_list;
-       res->data.an_list.n = 0;
-
-       int i = ast_llistlength(llist);
-
-       //Allocate array
-       res->data.an_list.n = i;
-       res->data.an_list.ptr = (struct ast **)safe_malloc(
-               res->data.an_list.n*sizeof(struct ast *));
-
-       struct ast *r = llist;
-       while(i > 0) {
-               res->data.an_list.ptr[--i] = r->data.an_cons.el;
-               struct ast *t = r;
-               r = r->data.an_cons.tail;
-               free(t);
-       }
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = ebool;
+       res->data.ebool = b;
        return res;
 }
-
-struct ast *ast_nil()
+int fromHex(char c)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_nil;
-       return res;
+       if (c >= '0' && c <= '9')
+               return c-'0';
+       if (c >= 'a' && c <= 'f')
+               return c-'a'+10;
+       if (c >= 'A' && c <= 'F')
+               return c-'A'+10;
+       return -1;
 }
 
-struct ast *ast_return(struct ast *r)
+struct expr *expr_char(const char *c)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_return;
-       res->data.an_return = r;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = echar;
+       //regular char
+       if (strlen(c) == 3)
+               res->data.echar = c[1];
+       //escape
+       if (strlen(c) == 4)
+               switch(c[2]) {
+               case '0': res->data.echar = '\0'; break;
+               case 'a': res->data.echar = '\a'; break;
+               case 'b': res->data.echar = '\b'; break;
+               case 't': res->data.echar = '\t'; break;
+               case 'v': res->data.echar = '\v'; break;
+               case 'f': res->data.echar = '\f'; break;
+               case 'r': res->data.echar = '\r'; break;
+               }
+       //hex escape
+       if (strlen(c) == 6)
+               res->data.echar = (fromHex(c[3])<<4)+fromHex(c[4]);
        return res;
 }
 
-struct ast *ast_stmt_expr(struct ast *expr)
+struct expr *expr_funcall(char *ident, struct list *args)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_stmt_expr;
-       res->data.an_stmt_expr = expr;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = efuncall;
+       res->data.efuncall.ident = ident;
+       res->data.efuncall.args = (struct expr **)
+               list_to_array(args, &res->data.efuncall.nargs, true);
        return res;
 }
 
-struct ast *ast_tuple(struct ast *left, struct ast *right)
+struct expr *expr_int(int integer)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_tuple;
-       res->data.an_tuple.left = left;
-       res->data.an_tuple.right = right;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = eint;
+       res->data.eint = integer;
        return res;
 }
 
-
-struct ast *ast_unop(enum unop op, struct ast *l)
+struct expr *expr_ident(char *ident, struct list *fields)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_unop;
-       res->data.an_unop.op = op;
-       res->data.an_unop.l = l;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = eident;
+       res->data.eident.ident = ident;
+
+       void **els = list_to_array(fields, &res->data.eident.nfields, true);
+       res->data.eident.fields = (enum fieldspec *)safe_malloc(
+               res->data.eident.nfields*sizeof(enum fieldspec));
+       for (int i = 0; i<res->data.eident.nfields; i++) {
+               char *t = els[i];
+               if (strcmp(t, "fst") == 0)
+                       res->data.eident.fields[i] = fst;
+               else if (strcmp(t, "snd") == 0)
+                       res->data.eident.fields[i] = snd;
+               else if (strcmp(t, "hd") == 0)
+                       res->data.eident.fields[i] = hd;
+               else if (strcmp(t, "tl") == 0)
+                       res->data.eident.fields[i] = tl;
+               free(t);
+       }
+       free(els);
        return res;
 }
 
-struct ast *ast_vardecl(struct ast *ident, struct ast *l)
+struct expr *expr_nil()
 {
-       struct ast *res = ast_alloc();
-       res->type = an_vardecl;
-       must_be(ident, an_ident, "ident of a vardecl");
-
-       res->data.an_vardecl.ident = ident->data.an_ident.ident;
-       free(ident->data.an_ident.fields);
-       free(ident);
-       res->data.an_vardecl.l = l;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = enil;
        return res;
 }
 
-struct ast *ast_while(struct ast *pred, struct ast *body)
+struct expr *expr_tuple(struct expr *left, struct expr *right)
 {
-       struct ast *res = ast_alloc();
-       res->type = an_while;
-       res->data.an_while.pred = pred;
-       must_be(body, an_list, "body of a while");
-       res->data.an_while.nbody = body->data.an_list.n;
-       res->data.an_while.body = body->data.an_list.ptr;
-       free(body);
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = etuple;
+       res->data.etuple.left = left;
+       res->data.etuple.right = right;
        return res;
 }
 
-int ast_llistlength(struct ast *r)
+struct expr *expr_unop(enum unop op, struct expr *l)
 {
-       int i = 0;
-       while(r != NULL) {
-               i++;
-               if (r->type != an_cons) {
-                       return 1;
-               }
-               r = r->data.an_cons.tail;
-       }
-       return i;
+       struct expr *res = safe_malloc(sizeof(struct expr));
+       res->type = eunop;
+       res->data.eunop.op = op;
+       res->data.eunop.l = l;
+       return res;
 }
 
 const char *cescapes[] = {
@@ -346,139 +234,162 @@ const char *cescapes[] = {
        [127] = "\\x7F"
 };
 
-void ast_print(struct ast *ast, int indent, FILE *out)
+void ast_print(struct ast *ast, FILE *out)
 {
        if (ast == NULL)
                return;
-#ifdef DEBUG
-       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
-#endif
-       switch(ast->type) {
-       case an_assign:
-               pindent(indent, out);
-               ast_print(ast->data.an_assign.ident, indent, out);
-               safe_fprintf(out, " = ");
-               ast_print(ast->data.an_assign.expr, indent, out);
-               safe_fprintf(out, ";\n");
-               break;
-       case an_binop:
-               safe_fprintf(out, "(");
-               ast_print(ast->data.an_binop.l, indent, out);
-               safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
-               ast_print(ast->data.an_binop.r, indent, out);
-               safe_fprintf(out, ")");
-               break;
-       case an_bool:
-               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
-               break;
-       case an_char:
-               if (ast->data.an_char < 0)
-                       safe_fprintf(out, "'?'");
-               if (ast->data.an_char < ' ' || ast->data.an_char == 127)
-                       safe_fprintf(out, "'%s'",
-                               cescapes[(int)ast->data.an_char]);
-               else
-                       safe_fprintf(out, "'%c'", ast->data.an_char);
-               break;
-       case an_funcall:
-               safe_fprintf(out, "%s(", ast->data.an_funcall.ident);
-               for(int i = 0; i<ast->data.an_fundecl.nargs; i++) {
-                       ast_print(ast->data.an_funcall.args[i], indent, out);
-                       if (i+1 < ast->data.an_fundecl.nargs)
-                               safe_fprintf(out, ", ");
-               }
-               safe_fprintf(out, ")");
-               break;
-       case an_fundecl:
+       for (int i = 0; i<ast->ndecls; i++)
+               decl_print(ast->decls[i], 0, out);
+}
+
+void decl_print(struct decl *decl, int indent, FILE *out)
+{
+       if (decl == NULL)
+               return;
+       switch(decl->type) {
+       case dfundecl:
                pindent(indent, out);
-               safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
-               for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
-                       safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
-                       if (i < ast->data.an_fundecl.nargs - 1)
+               safe_fprintf(out, "%s (", decl->data.dfun.ident);
+               for (int i = 0; i<decl->data.dfun.nargs; i++) {
+                       safe_fprintf(out, "%s", decl->data.dfun.args[i]);
+                       if (i < decl->data.dfun.nargs - 1)
                                safe_fprintf(out, ", ");
                }
                safe_fprintf(out, ") {\n");
-               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
-                       ast_print(ast->data.an_fundecl.body[i], indent+1, out);
+               for (int i = 0; i<decl->data.dfun.nbody; i++)
+                       stmt_print(decl->data.dfun.body[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
-       case an_if:
+       case dvardecl:
+               pindent(indent, out);
+               safe_fprintf(out, "var %s = ", decl->data.dvar.ident);
+               expr_print(decl->data.dvar.expr, out);
+               safe_fprintf(out, ";\n");
+               break;
+       default:
+               die("Unsupported decl node\n");
+       }
+}
+
+void stmt_print(struct stmt *stmt, int indent, FILE *out)
+{
+       if (stmt == NULL)
+               return;
+       switch(stmt->type) {
+       case sassign:
+               pindent(indent, out);
+               fprintf(out, "%s", stmt->data.sassign.ident);
+               safe_fprintf(out, " = ");
+               expr_print(stmt->data.sassign.expr, out);
+               safe_fprintf(out, ";\n");
+               break;
+       case sif:
                pindent(indent, out);
                safe_fprintf(out, "if (");
-               ast_print(ast->data.an_if.pred, indent, out);
+               expr_print(stmt->data.sif.pred, out);
                safe_fprintf(out, ") {\n");
-               for (int i = 0; i<ast->data.an_if.nthen; i++)
-                       ast_print(ast->data.an_if.then[i], indent+1, out);
+               for (int i = 0; i<stmt->data.sif.nthen; i++)
+                       stmt_print(stmt->data.sif.then[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "} else {\n");
-               for (int i = 0; i<ast->data.an_if.nels; i++)
-                       ast_print(ast->data.an_if.els[i], indent+1, out);
+               for (int i = 0; i<stmt->data.sif.nels; i++)
+                       stmt_print(stmt->data.sif.els[i], indent+1, out);
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
-       case an_int:
-               safe_fprintf(out, "%d", ast->data.an_int);
-               break;
-       case an_ident:
-               fprintf(out, "%s", ast->data.an_ident.ident);
-               for (int i = 0; i<ast->data.an_ident.nfields; i++)
-                       fprintf(out, ".%s",
-                               fieldspec_str[ast->data.an_ident.fields[i]]);
-               break;
-       case an_cons:
-               ast_print(ast->data.an_cons.el, indent, out);
-               ast_print(ast->data.an_cons.tail, indent, out);
-               break;
-       case an_list:
-               for (int i = 0; i<ast->data.an_list.n; i++)
-                       ast_print(ast->data.an_list.ptr[i], indent, out);
-               break;
-       case an_nil:
-               safe_fprintf(out, "[]");
-               break;
-       case an_return:
+       case sreturn:
                pindent(indent, out);
                safe_fprintf(out, "return ");
-               ast_print(ast->data.an_return, indent, out);
+               expr_print(stmt->data.sreturn, out);
                safe_fprintf(out, ";\n");
                break;
-       case an_stmt_expr:
+       case sexpr:
                pindent(indent, out);
-               ast_print(ast->data.an_stmt_expr, indent, out);
+               expr_print(stmt->data.sexpr, out);
                safe_fprintf(out, ";\n");
                break;
-       case an_tuple:
-               safe_fprintf(out, "(");
-               ast_print(ast->data.an_tuple.left, indent, out);
-               safe_fprintf(out, ", ");
-               ast_print(ast->data.an_tuple.right, indent, out);
-               safe_fprintf(out, ")");
-               break;
-       case an_unop:
-               safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
-               ast_print(ast->data.an_unop.l, indent, out);
-               safe_fprintf(out, ")");
-               break;
-       case an_vardecl:
+       case svardecl:
                pindent(indent, out);
-               safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
-               ast_print(ast->data.an_vardecl.l, indent, out);
+               safe_fprintf(out, "var %s = ", stmt->data.svardecl.ident);
+               expr_print(stmt->data.svardecl.expr, out);
                safe_fprintf(out, ";\n");
                break;
-       case an_while:
+       case swhile:
                pindent(indent, out);
                safe_fprintf(out, "while (");
-               ast_print(ast->data.an_while.pred, indent, out);
+               expr_print(stmt->data.swhile.pred, out);
                safe_fprintf(out, ") {\n");
-               for (int i = 0; i<ast->data.an_while.nbody; i++) {
-                       ast_print(ast->data.an_while.body[i], indent+1, out);
+               for (int i = 0; i<stmt->data.swhile.nbody; i++) {
+                       stmt_print(stmt->data.swhile.body[i], indent+1, out);
                }
                pindent(indent, out);
                safe_fprintf(out, "}\n");
                break;
        default:
-               die("Unsupported AST node\n");
+               die("Unsupported stmt node\n");
+       }
+}
+
+void expr_print(struct expr *expr, FILE *out)
+{
+       if (expr == NULL)
+               return;
+       switch(expr->type) {
+       case ebinop:
+               safe_fprintf(out, "(");
+               expr_print(expr->data.ebinop.l, out);
+               safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
+               expr_print(expr->data.ebinop.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case ebool:
+               safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
+               break;
+       case echar:
+               if (expr->data.echar < 0)
+                       safe_fprintf(out, "'?'");
+               if (expr->data.echar < ' ' || expr->data.echar == 127)
+                       safe_fprintf(out, "'%s'",
+                               cescapes[(int)expr->data.echar]);
+               else
+                       safe_fprintf(out, "'%c'", expr->data.echar);
+               break;
+       case efuncall:
+               safe_fprintf(out, "%s(", expr->data.efuncall.ident);
+               for(int i = 0; i<expr->data.efuncall.nargs; i++) {
+                       expr_print(expr->data.efuncall.args[i], out);
+                       if (i+1 < expr->data.efuncall.nargs)
+                               safe_fprintf(out, ", ");
+               }
+               safe_fprintf(out, ")");
+               break;
+       case eint:
+               safe_fprintf(out, "%d", expr->data.eint);
+               break;
+       case eident:
+               fprintf(out, "%s", expr->data.eident.ident);
+               for (int i = 0; i<expr->data.eident.nfields; i++)
+                       fprintf(out, ".%s",
+                               fieldspec_str[expr->data.eident.fields[i]]);
+               break;
+       case enil:
+               safe_fprintf(out, "[]");
+               break;
+       case etuple:
+               safe_fprintf(out, "(");
+               expr_print(expr->data.etuple.left, out);
+               safe_fprintf(out, ", ");
+               expr_print(expr->data.etuple.right, out);
+               safe_fprintf(out, ")");
+               break;
+       case eunop:
+               safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
+               expr_print(expr->data.eunop.l, out);
+               safe_fprintf(out, ")");
+               break;
+       default:
+               die("Unsupported expr node\n");
        }
 }
 
@@ -486,88 +397,109 @@ void ast_free(struct ast *ast)
 {
        if (ast == NULL)
                return;
-#ifdef DEBUG
-       fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
-#endif
-       switch(ast->type) {
-       case an_assign:
-               ast_free(ast->data.an_assign.ident);
-               ast_free(ast->data.an_assign.expr);
-               break;
-       case an_binop:
-               ast_free(ast->data.an_binop.l);
-               ast_free(ast->data.an_binop.r);
-               break;
-       case an_bool:
-               break;
-       case an_char:
-               break;
-       case an_cons:
-               ast_free(ast->data.an_cons.el);
-               ast_free(ast->data.an_cons.tail);
-               break;
-       case an_funcall:
-               free(ast->data.an_funcall.ident);
-               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
-                       ast_free(ast->data.an_funcall.args[i]);
-               free(ast->data.an_funcall.args);
-               break;
-       case an_fundecl:
-               free(ast->data.an_fundecl.ident);
-               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
-                       free(ast->data.an_fundecl.args[i]);
-               free(ast->data.an_fundecl.args);
-               for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
-                       ast_free(ast->data.an_fundecl.body[i]);
-               free(ast->data.an_fundecl.body);
-               break;
-       case an_if:
-               ast_free(ast->data.an_if.pred);
-               for (int i = 0; i<ast->data.an_if.nthen; i++)
-                       ast_free(ast->data.an_if.then[i]);
-               free(ast->data.an_if.then);
-               for (int i = 0; i<ast->data.an_if.nels; i++)
-                       ast_free(ast->data.an_if.els[i]);
-               free(ast->data.an_if.els);
-               break;
-       case an_int:
-               break;
-       case an_ident:
-               free(ast->data.an_ident.ident);
-               free(ast->data.an_ident.fields);
-               break;
-       case an_list:
-               for (int i = 0; i<ast->data.an_list.n; i++)
-                       ast_free(ast->data.an_list.ptr[i]);
-               free(ast->data.an_list.ptr);
-               break;
-       case an_nil:
-               break;
-       case an_return:
-               ast_free(ast->data.an_return);
-               break;
-       case an_stmt_expr:
-               ast_free(ast->data.an_stmt_expr);
-               break;
-       case an_tuple:
-               ast_free(ast->data.an_tuple.left);
-               ast_free(ast->data.an_tuple.right);
-               break;
-       case an_unop:
-               ast_free(ast->data.an_unop.l);
-               break;
-       case an_vardecl:
-               free(ast->data.an_vardecl.ident);
-               ast_free(ast->data.an_vardecl.l);
-               break;
-       case an_while:
-               ast_free(ast->data.an_while.pred);
-               for (int i = 0; i<ast->data.an_while.nbody; i++)
-                       ast_free(ast->data.an_while.body[i]);
-               free(ast->data.an_while.body);
+       for (int i = 0; i<ast->ndecls; i++)
+               decl_free(ast->decls[i]);
+       free(ast);
+}
+
+void decl_free(struct decl *decl)
+{
+       if (decl == NULL)
+               return;
+       switch(decl->type) {
+       case dfundecl:
+               free(decl->data.dfun.ident);
+               for (int i = 0; i<decl->data.dfun.nargs; i++)
+                       free(decl->data.dfun.args[i]);
+               free(decl->data.dfun.args);
+               for (int i = 0; i<decl->data.dfun.nbody; i++)
+                       stmt_free(decl->data.dfun.body[i]);
+               free(decl->data.dfun.body);
+               break;
+       case dvardecl:
+               free(decl->data.dvar.ident);
+               expr_free(decl->data.dvar.expr);
                break;
        default:
-               die("Unsupported AST node: %d\n", ast->type);
+               die("Unsupported decl node\n");
        }
-       free(ast);
+       free(decl);
+}
+
+void stmt_free(struct stmt *stmt)
+{
+       if (stmt == NULL)
+               return;
+       switch(stmt->type) {
+       case sassign:
+               free(stmt->data.sassign.ident);
+               expr_free(stmt->data.sassign.expr);
+               break;
+       case sif:
+               expr_free(stmt->data.sif.pred);
+               for (int i = 0; i<stmt->data.sif.nthen; i++)
+                       stmt_free(stmt->data.sif.then[i]);
+               free(stmt->data.sif.then);
+               for (int i = 0; i<stmt->data.sif.nels; i++)
+                       stmt_free(stmt->data.sif.els[i]);
+               free(stmt->data.sif.els);
+               break;
+       case sreturn:
+               expr_free(stmt->data.sreturn);
+               break;
+       case sexpr:
+               expr_free(stmt->data.sexpr);
+               break;
+       case svardecl:
+               free(stmt->data.svardecl.ident);
+               expr_free(stmt->data.svardecl.expr);
+               break;
+       case swhile:
+               expr_free(stmt->data.swhile.pred);
+               for (int i = 0; i<stmt->data.swhile.nbody; i++)
+                       stmt_free(stmt->data.swhile.body[i]);
+               free(stmt->data.swhile.body);
+               break;
+       default:
+               die("Unsupported stmt node\n");
+       }
+       free(stmt);
+}
+
+void expr_free(struct expr *expr)
+{
+       switch(expr->type) {
+       case ebinop:
+               expr_free(expr->data.ebinop.l);
+               expr_free(expr->data.ebinop.r);
+               break;
+       case ebool:
+               break;
+       case echar:
+               break;
+       case efuncall:
+               free(expr->data.efuncall.ident);
+               for (int i = 0; i<expr->data.efuncall.nargs; i++)
+                       expr_free(expr->data.efuncall.args[i]);
+               free(expr->data.efuncall.args);
+               break;
+       case eint:
+               break;
+       case eident:
+               free(expr->data.eident.ident);
+               free(expr->data.eident.fields);
+               break;
+       case enil:
+               break;
+       case etuple:
+               expr_free(expr->data.etuple.left);
+               expr_free(expr->data.etuple.right);
+               break;
+       case eunop:
+               expr_free(expr->data.eunop.l);
+               break;
+       default:
+               die("Unsupported expr node\n");
+       }
+       free(expr);
 }
diff --git a/ast.h b/ast.h
index 12263a7..a001757 100644 (file)
--- a/ast.h
+++ b/ast.h
 #include <stdio.h>
 #include <stdbool.h>
 
-enum binop {
-       binor,binand,
-       eq,neq,leq,le,geq,ge,
-       cons,plus,minus,times,divide,modulo,power
+#include "util.h"
+
+struct ast {
+       int ndecls;
+       struct decl **decls;
 };
-enum fieldspec {fst,snd,hd,tl};
-enum unop {negate,inverse};
-enum ast_type {
-       an_assign, an_binop, an_bool, an_char, an_cons, an_funcall, an_fundecl,
-       an_ident, an_if, an_int, an_list, an_nil, an_return, an_stmt_expr,
-       an_tuple, an_unop, an_vardecl, an_while
+
+struct vardecl {
+       char *ident;
+       struct expr *expr;
 };
-struct ast {
-       enum ast_type type;
+
+struct decl {
+       enum {dfundecl, dvardecl} type;
        union {
-               struct {
-                       struct ast *ident;
-                       struct ast *expr;
-               } an_assign;
-               bool an_bool;
-               struct {
-                       struct ast *l;
-                       enum binop op;
-                       struct ast *r;
-               } an_binop;
-               char an_char;
-               struct {
-                       struct ast *el;
-                       struct ast *tail;
-               } an_cons;
-               struct {
-                       char *ident;
-                       int nargs;
-                       struct ast **args;
-               } an_funcall;
                struct {
                        char *ident;
                        int nargs;
                        char **args;
                        int nbody;
-                       struct ast **body;
-               } an_fundecl;
+                       struct stmt **body;
+               } dfun;
+               struct vardecl dvar;
+       } data;
+};
+
+struct stmt {
+       enum {sassign, sif, sreturn, sexpr, svardecl, swhile} type;
+       union {
+               struct {
+                       char *ident;
+                       struct expr *expr;
+               } sassign;
                struct {
-                       struct ast *pred;
+                       struct expr *pred;
                        int nthen;
-                       struct ast **then;
+                       struct stmt **then;
                        int nels;
-                       struct ast **els;
-               } an_if;
-               int an_int;
+                       struct stmt **els;
+               } sif;
+               struct expr *sreturn;
+               struct expr *sexpr;
+               struct vardecl svardecl;
+               struct {
+                       struct expr *pred;
+                       int nbody;
+                       struct stmt **body;
+               } swhile;
+       } data;
+};
+
+enum binop {
+       binor, binand, eq, neq, leq, le, geq, ge, cons, plus, minus, times,
+       divide, modulo, power,
+};
+enum fieldspec {fst,snd,hd,tl};
+enum unop {negate,inverse};
+struct expr {
+       enum {ebinop, ebool, echar, efuncall, eident, eint, enil, etuple,
+               eunop} type;
+       union {
+               bool ebool;
+               struct {
+                       struct expr *l;
+                       enum binop op;
+                       struct expr *r;
+               } ebinop;
+               char echar;
+               struct {
+                       char *ident;
+                       int nargs;
+                       struct expr **args;
+               } efuncall;
+               int eint;
                struct {
                        char *ident;
                        int nfields;
                        enum fieldspec *fields;
-               } an_ident;
-               struct {
-                       int n;
-                       struct ast **ptr;
-               } an_list;
-               //struct { } an_nil;
-               struct ast *an_return;
-               struct ast *an_stmt_expr;
+               } eident;
                struct {
-                       struct ast *left;
-                       struct ast *right;
-               } an_tuple;
+                       struct expr *left;
+                       struct expr *right;
+               } etuple;
                struct {
                        enum unop op;
-                       struct ast *l;
-               } an_unop;
-               struct {
-                       char *ident;
-                       struct ast *l;
-               } an_vardecl;
-               struct {
-                       struct ast *pred;
-                       int nbody;
-                       struct ast **body;
-               } an_while;
+                       struct expr *l;
+               } eunop;
        } data;
 };
 
-struct ast *ast_assign(struct ast *ident, struct ast *expr);
-struct ast *ast_binop(struct ast *l, enum binop op, struct ast *tail);
-struct ast *ast_bool(bool b);
-struct ast *ast_char(const char *c);
-struct ast *ast_cons(struct ast *el, struct ast *tail);
-struct ast *ast_funcall(struct ast *ident, struct ast *args);
-struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body);
-struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els);
-struct ast *ast_int(int integer);
-struct ast *ast_identc(char *ident);
-struct ast *ast_ident(struct ast *ident, struct ast *fields);
-struct ast *ast_list(struct ast *llist);
-struct ast *ast_nil();
-struct ast *ast_return(struct ast *rtrn);
-struct ast *ast_stmt_expr(struct ast *expr);
-struct ast *ast_tuple(struct ast *left, struct ast *right);
-struct ast *ast_unop(enum unop op, struct ast *l);
-struct ast *ast_vardecl(struct ast *ident, struct ast *l);
-struct ast *ast_while(struct ast *pred, struct ast *body);
+struct ast *ast(struct list *decls);
+
+struct vardecl vardecl(char *ident, struct expr *expr);
+
+struct decl *decl_fun(char *ident, struct list *args, struct list *body);
+struct decl *decl_var(struct vardecl vardecl);
+
+struct stmt *stmt_assign(char *ident, struct expr *expr);
+struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els);
+struct stmt *stmt_return(struct expr *rtrn);
+struct stmt *stmt_expr(struct expr *expr);
+struct stmt *stmt_vardecl(struct vardecl vardecl);
+struct stmt *stmt_while(struct expr *pred, struct list *body);
+
+struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r);
+struct expr *expr_bool(bool b);
+struct expr *expr_char(const char *c);
+struct expr *expr_funcall(char *ident, struct list *args);
+struct expr *expr_int(int integer);
+struct expr *expr_ident(char *ident, struct list *fields);
+struct expr *expr_nil();
+struct expr *expr_tuple(struct expr *left, struct expr *right);
+struct expr *expr_unop(enum unop op, struct expr *l);
 
-int ast_llistlength(struct ast *list);
+void ast_print(struct ast *, FILE *out);
+void decl_print(struct decl *ast, int indent, FILE *out);
+void stmt_print(struct stmt *ast, int indent, FILE *out);
+void expr_print(struct expr *ast, FILE *out);
 
-void ast_print(struct ast * ast, int indent, FILE *out);
-void ast_free(struct ast *ast);
+void ast_free(struct ast *);
+void decl_free(struct decl *ast);
+void stmt_free(struct stmt *ast);
+void expr_free(struct expr *ast);
 
 #endif
diff --git a/expr.c b/expr.c
index b5c8228..725e397 100644 (file)
--- a/expr.c
+++ b/expr.c
@@ -7,12 +7,17 @@ extern int yylex_destroy(void);
 
 int main()
 {
-       struct ast *result = NULL;
+       fprintf(stderr, "sizeof(struct ast): %lu\n", sizeof(struct ast));
+       fprintf(stderr, "sizeof(struct vardecl): %lu\n", sizeof(struct vardecl));
+       fprintf(stderr, "sizeof(struct decl): %lu\n", sizeof(struct decl));
+       fprintf(stderr, "sizeof(struct stmt): %lu\n", sizeof(struct stmt));
+       fprintf(stderr, "sizeof(struct expr): %lu\n", sizeof(struct expr));
+       struct ast *result;
         int r = yyparse(&result);
        if (r != 0)
                return r;
        yylex_destroy();
-       ast_print(result, 0, stdout);
+       ast_print(result, stdout);
        ast_free(result);
        return 0;
 }
diff --git a/parse.y b/parse.y
index 4311833..3cc1d44 100644 (file)
--- a/parse.y
+++ b/parse.y
@@ -4,11 +4,9 @@
 #include <stdlib.h>
 
 #include "ast.h"
-#define YYSTYPE struct ast *
 
 #include "y.tab.h"
 
-int yylex_debug = 1;
 int yylex(void);
 
 void yyerror(struct ast **result, const char *str)
@@ -24,10 +22,21 @@ int yywrap()
 
 %}
 
+%union
+{
+       struct expr *expr;
+       struct stmt *stmt;
+       struct list *list;
+       struct vardecl vardecl;
+       struct decl *decl;
+       char *ident;
+}
+
 //%define parse.error verbose
-%token ASSIGN BCLOSE BINAND BINOR BOOL BOPEN CCLOSE CHAR COMMA CONS COPEN
-%token DIVIDE DOT ELSE IDENT IF INTEGER INVERSE MINUS MODULO NIL PLUS POWER
-%token RETURN SEMICOLON TIMES VAR WHILE
+%token <ident> IDENT
+%token <expr> BOOL CHAR INTEGER
+%token ASSIGN BCLOSE BINAND BINOR BOPEN CCLOSE COMMA CONS COPEN DIVIDE DOT ELSE
+%token IF INVERSE MINUS MODULO NIL PLUS POWER RETURN SEMICOLON TIMES VAR WHILE
 
 %parse-param { struct ast **result }
 
@@ -39,87 +48,86 @@ int yywrap()
 %left TIMES DIVIDE MODULO
 %right POWER
 
-%%
+%type <ast> start
+%type <expr> expr
+%type <list> args body decls fargs field fnargs nargs
+%type <decl> fundecl
+%type <vardecl> vardecl
+%type <stmt> stmt
 
-start : decls { *result = ast_list($1); } ;
+%%
 
+start : decls { *result = ast($1); } ;
 decls
        : { $$ = NULL; }
-       | decls vardecl SEMICOLON { $$ = ast_cons($2, $1); }
-       | decls fundecl { $$ = ast_cons($2, $1); }
+       | decls vardecl SEMICOLON { $$ = list_cons(decl_var($2), $1); }
+       | decls fundecl { $$ = list_cons($2, $1); }
        ;
-
 vardecl
-       : VAR IDENT ASSIGN expr { $$ = ast_vardecl($2, $4); }
+       : VAR IDENT ASSIGN expr { $$ = vardecl($2, $4); }
        ;
-
 fundecl
        : IDENT BOPEN args BCLOSE COPEN body CCLOSE
-               { $$ = ast_fundecl($1, ast_list($3), ast_list($6)); }
+               { $$ = decl_fun($1, $3, $6); }
        ;
-
 args
        : { $$ = NULL; }
        | nargs
        ;
 nargs
-       : nargs COMMA IDENT { $$ = ast_cons($3, $1); }
-       | IDENT { $$ = ast_cons($1, NULL); }
+       : nargs COMMA IDENT { $$ = list_cons($3, $1); }
+       | IDENT { $$ = list_cons($1, NULL); }
        ;
-
 fargs
        : { $$ = NULL; }
        | fnargs
        ;
 fnargs
-       : fnargs COMMA expr { $$ = ast_cons($3, $1); }
-       | expr { $$ = ast_cons($1, NULL); }
+       : fnargs COMMA expr { $$ = list_cons($3, $1); }
+       | expr { $$ = list_cons($1, NULL); }
        ;
 body
        : { $$ = NULL; }
-       | body vardecl SEMICOLON { $$ = ast_cons($2, $1); }
-       | body stmt { $$ = ast_cons($2, $1); }
+       | body stmt { $$ = list_cons($2, $1); }
        ;
-
 stmt
        : IF BOPEN expr BCLOSE COPEN body CCLOSE ELSE COPEN body CCLOSE
-               { $$ = ast_if($3, ast_list($6), ast_list($10)); }
+               { $$ = stmt_if($3, $6, $10); }
        | WHILE BOPEN expr BCLOSE COPEN body CCLOSE
-               { $$ = ast_while($3, ast_list($6)); }
-       | expr SEMICOLON { $$ = ast_stmt_expr($1); }
-       | IDENT ASSIGN expr SEMICOLON { $$ = ast_assign($1, $3); }
-       | RETURN expr SEMICOLON { $$ = ast_return($2); }
-       | RETURN SEMICOLON { $$ = ast_return(NULL); }
+               { $$ = stmt_while($3, $6); }
+       | expr SEMICOLON { $$ = stmt_expr($1); }
+       | IDENT ASSIGN expr SEMICOLON { $$ = stmt_assign($1, $3); }
+       | RETURN expr SEMICOLON { $$ = stmt_return($2); }
+       | RETURN SEMICOLON { $$ = stmt_return(NULL); }
+       | vardecl SEMICOLON { $$ = stmt_vardecl($1); }
        ;
-
 field
        : { $$ = NULL; }
-       | field DOT IDENT { $$ = ast_cons($3, $1); }
-
+       | field DOT IDENT { $$ = list_cons($3, $1); }
 expr
-       : expr BINOR expr { $$ = ast_binop($1, binor, $3); }
-       | expr BINAND expr { $$ = ast_binop($1, binand, $3); }
-       | expr EQ expr { $$ = ast_binop($1, eq, $3); }
-       | expr NEQ expr { $$ = ast_binop($1, neq, $3); }
-       | expr LEQ expr { $$ = ast_binop($1, leq, $3); }
-       | expr LE expr { $$ = ast_binop($1, le, $3); }
-       | expr GEQ expr { $$ = ast_binop($1, geq, $3); }
-       | expr GE expr { $$ = ast_binop($1, ge, $3); }
-       | expr CONS expr { $$ = ast_binop($1, cons, $3); }
-       | expr PLUS expr { $$ = ast_binop($1, plus, $3); }
-       | expr MINUS expr { $$ = ast_binop($1, minus, $3); }
-       | expr TIMES expr { $$ = ast_binop($1, times, $3); }
-       | expr DIVIDE expr { $$ = ast_binop($1, divide, $3); }
-       | expr MODULO expr { $$ = ast_binop($1, modulo, $3); }
-       | expr POWER expr { $$ = ast_binop($1, power, $3); }
-       | MINUS expr { $$ = ast_unop(negate, $2); }
-       | INVERSE expr %prec TIMES { $$ = ast_unop(inverse, $2); }
-       | BOPEN expr COMMA expr BCLOSE { $$ = ast_tuple($2, $4); }
+       : expr BINOR expr { $$ = expr_binop($1, binor, $3); }
+       | expr BINAND expr { $$ = expr_binop($1, binand, $3); }
+       | expr EQ expr { $$ = expr_binop($1, eq, $3); }
+       | expr NEQ expr { $$ = expr_binop($1, neq, $3); }
+       | expr LEQ expr { $$ = expr_binop($1, leq, $3); }
+       | expr LE expr { $$ = expr_binop($1, le, $3); }
+       | expr GEQ expr { $$ = expr_binop($1, geq, $3); }
+       | expr GE expr { $$ = expr_binop($1, ge, $3); }
+       | expr CONS expr { $$ = expr_binop($1, cons, $3); }
+       | expr PLUS expr { $$ = expr_binop($1, plus, $3); }
+       | expr MINUS expr { $$ = expr_binop($1, minus, $3); }
+       | expr TIMES expr { $$ = expr_binop($1, times, $3); }
+       | expr DIVIDE expr { $$ = expr_binop($1, divide, $3); }
+       | expr MODULO expr { $$ = expr_binop($1, modulo, $3); }
+       | expr POWER expr { $$ = expr_binop($1, power, $3); }
+       | MINUS expr { $$ = expr_unop(negate, $2); }
+       | INVERSE expr %prec TIMES { $$ = expr_unop(inverse, $2); }
+       | BOPEN expr COMMA expr BCLOSE { $$ = expr_tuple($2, $4); }
        | BOPEN expr BCLOSE { $$ = $2; }
-       | IDENT BOPEN fargs BCLOSE { $$ = ast_funcall($1, ast_list($3)); }
+       | IDENT BOPEN fargs BCLOSE { $$ = expr_funcall($1, $3); }
        | INTEGER
        | BOOL
        | CHAR
-       | IDENT field { $$ = ast_ident($1, ast_list($2)); }
-       | NIL { $$ = ast_nil(); }
+       | IDENT field { $$ = expr_ident($1, $2); }
+       | NIL { $$ = expr_nil(); }
        ;
diff --git a/scan.l b/scan.l
index c6ebb87..048d104 100644 (file)
--- a/scan.l
+++ b/scan.l
@@ -4,7 +4,6 @@
 
 #include <stdio.h>
 #include "ast.h"
-#define YYSTYPE struct ast *
 #include "y.tab.h"
 extern YYSTYPE yylval;
 
@@ -16,8 +15,8 @@ if          return IF;
 else        return ELSE;
 while       return WHILE;
 var         return VAR;
-true        { yylval = ast_bool(true); return BOOL; }
-false       { yylval = ast_bool(false); return BOOL; }
+true        { yylval.expr = expr_bool(true); return BOOL; }
+false       { yylval.expr = expr_bool(false); return BOOL; }
 return      return RETURN;
 =           return ASSIGN;
 !           return INVERSE;
@@ -45,11 +44,11 @@ return      return RETURN;
 \.          return DOT;
 ,           return COMMA;
 '([^']|\\[abtnvfr]|\\x[0-9a-fA-F]{2})' {
-       yylval = ast_char(yytext); return CHAR; }
+       yylval.expr = expr_char(yytext); return CHAR; }
 [0-9]+ {
-       yylval = ast_int(atoi(yytext)); return INTEGER; }
+       yylval.expr = expr_int(atoi(yytext)); return INTEGER; }
 [_a-zA-Z][_a-zA-Z0-9]* {
-       yylval = ast_identc(yytext); return IDENT; }
+       yylval.ident = safe_strdup(yytext); return IDENT; }
 [ \n\t]  ;
 
 %%
diff --git a/util.c b/util.c
index 1a89f94..4b77f53 100644 (file)
--- a/util.c
+++ b/util.c
@@ -3,6 +3,53 @@
 #include <stdio.h>
 #include <string.h>
 
+#include "util.h"
+
+struct list *list_cons(void *el, struct list *tail)
+{
+       struct list *res = safe_malloc(sizeof(struct list));
+       res->el = el;
+       res->tail = tail;
+       return res;
+}
+
+void list_free(struct list *head, void (*freefun)(void *))
+{
+       while (head != NULL) {
+               freefun(head->el);
+               head = head->tail;
+       }
+}
+
+void **list_to_array(struct list *list, int *num, bool reverse)
+{
+       int i = list_length(list);
+       *num = i;
+       void **ptr = safe_malloc(i*sizeof(void *));
+
+       struct list *r = list;
+       while(i > 0) {
+               if (reverse)
+                       ptr[--i] = r->el;
+               else
+                       ptr[*num-(--i)] = r->el;
+               struct list *t = r;
+               r = r->tail;
+               free(t);
+       }
+       return ptr;
+}
+
+int list_length(struct list *r)
+{
+       int i = 0;
+       while(r != NULL) {
+               i++;
+               r = r->tail;
+       }
+       return i;
+}
+
 void pdie(const char *msg)
 {
        perror(msg);
diff --git a/util.h b/util.h
index 57dd17b..d8b40ba 100644 (file)
--- a/util.h
+++ b/util.h
@@ -2,6 +2,16 @@
 #define UTIL_H
 
 #include <stdarg.h>
+#include <stdbool.h>
+
+struct list {
+       void *el;
+       struct list *tail;
+};
+struct list *list_cons(void *el, struct list *tail);
+void list_free(struct list *head, void (*freefun)(void *));
+void **list_to_array(struct list *list, int *num, bool reverse);
+int list_length(struct list *head);
 
 void die(const char *msg, ...);
 void pdie(const char *msg);