#include <stdlib.h>
#include <stdio.h>
+#include <string.h>
+#include "util.h"
#include "ast.h"
-struct ast *ast_alloc()
+static const char *ast_type_str[] = {
+ [an_bool] = "bool", [an_binop] = "binop", [an_char] = "char",
+ [an_cons] = "cons", [an_fundecl] = "fundecl", [an_ident] = "ident",
+ [an_if] = "if", [an_int] = "int", [an_list] = "list",
+ [an_stmt_expr] = "stmt_expr", [an_unop] = "unop",
+ [an_vardecl] = "vardecl", [an_while] = "while",
+};
+static const char *binop_str[] = {
+ [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
+ [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
+ [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
+ [modulo] = "%", [power] = "^",
+};
+static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
+
+
+#define must_be(node, ntype, msg) {\
+ if ((node)->type != (ntype)) {\
+ fprintf(stderr, "%s can't be %s\n",\
+ msg, ast_type_str[node->type]);\
+ exit(1);\
+ }\
+}
+#define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
+
+struct ast *ast_bool(bool b)
{
- struct ast *res = malloc(sizeof(struct ast));
- if (res == NULL) {
- perror("malloc");
- exit(1);
- }
+ struct ast *res = ast_alloc();
+ res->type = an_bool;
+ res->data.an_bool = b;
+ return res;
+}
+
+struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_binop;
+ res->data.an_binop.l = l;
+ res->data.an_binop.op = op;
+ res->data.an_binop.r = r;
+ return res;
+}
+
+int fromHex(char c)
+{
+ if (c >= '0' && c <= '9')
+ return c-'0';
+ if (c >= 'a' && c <= 'f')
+ return c-'a'+10;
+ if (c >= 'A' && c <= 'F')
+ return c-'A'+10;
+ return -1;
+}
+
+struct ast *ast_char(const char *c)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_char;
+ //regular char
+ if (strlen(c) == 3)
+ res->data.an_char = c[1];
+ //escape
+ if (strlen(c) == 4)
+ switch(c[2]) {
+ case '0': res->data.an_char = '\0'; break;
+ case 'a': res->data.an_char = '\a'; break;
+ case 'b': res->data.an_char = '\b'; break;
+ case 't': res->data.an_char = '\t'; break;
+ case 'v': res->data.an_char = '\v'; break;
+ case 'f': res->data.an_char = '\f'; break;
+ case 'r': res->data.an_char = '\r'; break;
+ }
+ //hex escape
+ if (strlen(c) == 6)
+ res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
return res;
}
return res;
}
-struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
+struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
{
struct ast *res = ast_alloc();
- res->type = an_binop;
- res->data.an_binop.l = l;
- res->data.an_binop.op = op;
- res->data.an_binop.r = r;
+ res->type = an_fundecl;
+
+ //ident
+ must_be(ident, an_ident, "ident of a fundecl");
+ res->data.an_fundecl.ident = ident->data.an_ident;
+ free(ident);
+
+ //args
+ must_be(args, an_list, "args of a fundecl");
+ res->data.an_fundecl.nargs = args->data.an_list.n;
+ res->data.an_fundecl.args = (char **)safe_malloc(
+ args->data.an_list.n*sizeof(char *));
+ for (int i = 0; i<args->data.an_list.n; i++) {
+ struct ast *e = args->data.an_list.ptr[i];
+ must_be(e, an_ident, "arg of a fundecl")
+ res->data.an_fundecl.args[i] = e->data.an_ident;
+ free(e);
+ }
+ free(args);
+
+ //body
+ must_be(body, an_list, "body of a fundecl");
+ res->data.an_fundecl.body = body;
+
+ return res;
+}
+
+struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_if;
+ res->data.an_if.pred = pred;
+ res->data.an_if.then = then;
+ res->data.an_if.els = els;
return res;
}
return res;
}
+struct ast *ast_ident(char *ident)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_ident;
+ res->data.an_ident = safe_strdup(ident);
+ return res;
+}
+
+struct ast *ast_list(struct ast *llist)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_list;
+ res->data.an_list.n = 0;
+
+ int i = ast_llistlength(llist);
+
+ //Allocate array
+ res->data.an_list.n = i;
+ res->data.an_list.ptr = (struct ast **)safe_malloc(
+ res->data.an_list.n*sizeof(struct ast *));
+
+ struct ast *r = llist;
+ while(i > 0) {
+ res->data.an_list.ptr[--i] = r->data.an_cons.el;
+ struct ast *t = r;
+ r = r->data.an_cons.tail;
+ free(t);
+ }
+ return res;
+}
+
+struct ast *ast_stmt_expr(struct ast *expr)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_stmt_expr;
+ res->data.an_stmt_expr = expr;
+ return res;
+}
+
struct ast *ast_unop(enum unop op, struct ast *l)
{
struct ast *res = ast_alloc();
res->data.an_unop.l = l;
return res;
}
-static const char *binop_str[] = {
- [binor] = "||",
- [binand] = "&&",
- [eq] = "==",
- [neq] = "!=",
- [leq] = "<=",
- [le] = "<",
- [geq] = ">=",
- [ge] = ">",
- [cons] = ":",
- [plus] = "+",
- [minus] = "-",
- [times] = "*",
- [divide] = "/",
- [modulo] = "%",
- [power] = "^",
-};
-static const char *unop_str[] = {
- [inverse] = "!",
- [negate] = "-",
+
+struct ast *ast_vardecl(struct ast *ident, struct ast *l)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_vardecl;
+ must_be(ident, an_ident, "ident of a vardecl");
+
+ res->data.an_vardecl.ident = ident->data.an_ident;
+ free(ident);
+ res->data.an_vardecl.l = l;
+ return res;
+}
+
+struct ast *ast_while(struct ast *pred, struct ast *body)
+{
+ struct ast *res = ast_alloc();
+ res->type = an_while;
+ res->data.an_while.pred = pred;
+ res->data.an_while.body = body;
+ return res;
+}
+
+int ast_llistlength(struct ast *r)
+{
+ int i = 0;
+ while(r != NULL) {
+ i++;
+ if (r->type != an_cons) {
+ return 1;
+ }
+ r = r->data.an_cons.tail;
+ }
+ return i;
+}
+
+const char *cescapes[] = {
+ [0] = "\\0", [1] = "\\x01", [2] = "\\x02", [3] = "\\x03",
+ [4] = "\\x04", [5] = "\\x05", [6] = "\\x06", [7] = "\\a", [8] = "\\b",
+ [9] = "\\t", [10] = "\\n", [11] = "\\v", [12] = "\\f", [13] = "\\r",
+ [14] = "\\x0E", [15] = "\\x0F", [16] = "\\x10", [17] = "\\x11",
+ [18] = "\\x12", [19] = "\\x13", [20] = "\\x14", [21] = "\\x15",
+ [22] = "\\x16", [23] = "\\x17", [24] = "\\x18", [25] = "\\x19",
+ [26] = "\\x1A", [27] = "\\x1B", [28] = "\\x1C", [29] = "\\x1D",
+ [30] = "\\x1E", [31] = "\\x1F",
+ [127] = "\\x7F"
};
-void ast_print(struct ast * ast, FILE *out)
+void ast_print(struct ast * ast, int indent, FILE *out)
{
if (ast == NULL)
return;
switch(ast->type) {
+ case an_bool:
+ safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
+ break;
case an_binop:
- fprintf(out, "(");
- ast_print(ast->data.an_binop.l, out);
- fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
- ast_print(ast->data.an_binop.r, out);
- fprintf(out, ")");
+ safe_fprintf(out, "(");
+ ast_print(ast->data.an_binop.l, indent, out);
+ safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
+ ast_print(ast->data.an_binop.r, indent, out);
+ safe_fprintf(out, ")");
break;
- case an_cons:
- ast_print(ast->data.an_cons.el, out);
- fprintf(out, ";\n");
- ast_print(ast->data.an_cons.tail, out);
+ case an_char:
+ if (ast->data.an_char < 0)
+ safe_fprintf(out, "'?'");
+ if (ast->data.an_char < ' ' || ast->data.an_char == 127)
+ safe_fprintf(out, "'%s'",
+ cescapes[(int)ast->data.an_char]);
+ else
+ safe_fprintf(out, "'%c'", ast->data.an_char);
+ break;
+ case an_fundecl:
+ pindent(indent, out);
+ safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
+ for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
+ safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
+ if (i < ast->data.an_fundecl.nargs - 1)
+ safe_fprintf(out, ", ");
+ }
+ safe_fprintf(out, ") {\n");
+ ast_print(ast->data.an_fundecl.body, indent+1, out);
+ pindent(indent, out);
+ safe_fprintf(out, "}\n");
+ break;
+ case an_if:
+ pindent(indent, out);
+ safe_fprintf(out, "if (");
+ ast_print(ast->data.an_if.pred, indent, out);
+ safe_fprintf(out, ") {\n");
+ if (ast->data.an_if.then->data.an_list.n > 0) {
+ pindent(indent, out);
+ ast_print(ast->data.an_if.then, indent+1, out);
+ }
+ pindent(indent, out);
+ safe_fprintf(out, "} else {\n");
+ if (ast->data.an_if.els->data.an_list.n > 0) {
+ pindent(indent, out);
+ ast_print(ast->data.an_if.els, indent+1, out);
+ }
+ pindent(indent, out);
+ safe_fprintf(out, "}\n");
break;
case an_int:
- fprintf(out, "%d", ast->data.an_int);
+ safe_fprintf(out, "%d", ast->data.an_int);
+ break;
+ case an_ident:
+ safe_fprintf(out, "%s", ast->data.an_ident);
+ break;
+ case an_cons:
+ ast_print(ast->data.an_cons.el, indent, out);
+ ast_print(ast->data.an_cons.tail, indent, out);
+ break;
+ case an_list:
+ for (int i = 0; i<ast->data.an_list.n; i++)
+ ast_print(ast->data.an_list.ptr[i], indent, out);
+ break;
+ case an_stmt_expr:
+ pindent(indent, out);
+ ast_print(ast->data.an_stmt_expr, indent, out);
+ safe_fprintf(out, ";\n");
break;
case an_unop:
- fprintf(out, "(-");
- ast_print(ast->data.an_unop.l, out);
- fprintf(out, ")");
+ safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
+ ast_print(ast->data.an_unop.l, indent, out);
+ safe_fprintf(out, ")");
+ break;
+ case an_vardecl:
+ pindent(indent, out);
+ safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
+ ast_print(ast->data.an_vardecl.l, indent, out);
+ safe_fprintf(out, ";\n");
+ break;
+ case an_while:
+ pindent(indent, out);
+ safe_fprintf(out, "while (");
+ ast_print(ast->data.an_while.pred, indent, out);
+ safe_fprintf(out, ") {\n");
+ if (ast->data.an_while.body->data.an_list.n > 0) {
+ pindent(indent, out);
+ ast_print(ast->data.an_while.body, indent+1, out);
+ }
+ pindent(indent, out);
+ safe_fprintf(out, "}\n");
break;
default:
- fprintf(stderr, "Unsupported AST node\n");
- exit(1);
+ die("Unsupported AST node\n");
}
}
if (ast == NULL)
return;
switch(ast->type) {
+ case an_bool:
+ break;
case an_binop:
ast_free(ast->data.an_binop.l);
ast_free(ast->data.an_binop.r);
break;
+ case an_char:
+ break;
case an_cons:
ast_free(ast->data.an_cons.el);
ast_free(ast->data.an_cons.tail);
break;
+ case an_fundecl:
+ free(ast->data.an_fundecl.ident);
+ for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
+ free(ast->data.an_fundecl.args[i]);
+ free(ast->data.an_fundecl.args);
+ ast_free(ast->data.an_fundecl.body);
+ break;
+ case an_if:
+ ast_free(ast->data.an_if.pred);
+ ast_free(ast->data.an_if.then);
+ ast_free(ast->data.an_if.els);
case an_int:
break;
+ case an_ident:
+ free(ast->data.an_ident);
+ break;
+ case an_list:
+ for (int i = 0; i<ast->data.an_list.n; i++)
+ ast_free(ast->data.an_list.ptr[i]);
+ free(ast->data.an_list.ptr);
+ break;
+ case an_stmt_expr:
+ ast_free(ast->data.an_stmt_expr);
+ break;
+ case an_unop:
+ ast_free(ast->data.an_unop.l);
+ break;
+ case an_vardecl:
+ free(ast->data.an_vardecl.ident);
+ ast_free(ast->data.an_vardecl.l);
+ break;
+ case an_while:
+ ast_free(ast->data.an_while.pred);
+ ast_free(ast->data.an_while.body);
+ break;
default:
- fprintf(stderr, "Unsupported AST node\n");
- exit(1);
+ die("Unsupported AST node: %d\n", ast->type);
}
free(ast);
}