add util header, characters, stmts, etc
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7
8 static const char *ast_type_str[] = {
9 [an_bool] = "bool", [an_binop] = "binop", [an_char] = "char",
10 [an_cons] = "cons", [an_fundecl] = "fundecl", [an_ident] = "ident",
11 [an_if] = "if", [an_int] = "int", [an_list] = "list",
12 [an_stmt_expr] = "stmt_expr", [an_unop] = "unop",
13 [an_vardecl] = "vardecl", [an_while] = "while",
14 };
15 static const char *binop_str[] = {
16 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
17 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
18 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
19 [modulo] = "%", [power] = "^",
20 };
21 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
22
23
/*
 * Print "<msg> can't be <type name>" to stderr and exit(1) unless node has
 * type ntype. Every argument is parenthesised, so an expression such as
 * `p + 1` can be passed as node. The expansion is a plain braced block
 * (not do { } while (0)) so call sites that omit the trailing semicolon
 * still compile. node is evaluated more than once: pass an expression
 * without side effects.
 */
#define must_be(node, ntype, msg) {\
	if ((node)->type != (ntype)) {\
		fprintf(stderr, "%s can't be %s\n",\
			(msg), ast_type_str[(node)->type]);\
		exit(1);\
	}\
}
/* Allocate one AST node via safe_malloc; callers must set every field they use. */
#define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
32
33 struct ast *ast_bool(bool b)
34 {
35 struct ast *res = ast_alloc();
36 res->type = an_bool;
37 res->data.an_bool = b;
38 return res;
39 }
40
41 struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
42 {
43 struct ast *res = ast_alloc();
44 res->type = an_binop;
45 res->data.an_binop.l = l;
46 res->data.an_binop.op = op;
47 res->data.an_binop.r = r;
48 return res;
49 }
50
/* Value (0-15) of the hexadecimal digit c, in either case; -1 if c is not a hex digit. */
int fromHex(char c)
{
	int digit = -1;

	if (c >= '0' && c <= '9')
		digit = c - '0';
	else if (c >= 'a' && c <= 'f')
		digit = 10 + (c - 'a');
	else if (c >= 'A' && c <= 'F')
		digit = 10 + (c - 'A');

	return digit;
}
61
62 struct ast *ast_char(const char *c)
63 {
64 struct ast *res = ast_alloc();
65 res->type = an_char;
66 //regular char
67 if (strlen(c) == 3)
68 res->data.an_char = c[1];
69 //escape
70 if (strlen(c) == 4)
71 switch(c[2]) {
72 case '0': res->data.an_char = '\0'; break;
73 case 'a': res->data.an_char = '\a'; break;
74 case 'b': res->data.an_char = '\b'; break;
75 case 't': res->data.an_char = '\t'; break;
76 case 'v': res->data.an_char = '\v'; break;
77 case 'f': res->data.an_char = '\f'; break;
78 case 'r': res->data.an_char = '\r'; break;
79 }
80 //hex escape
81 if (strlen(c) == 6)
82 res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
83 return res;
84 }
85
86 struct ast *ast_cons(struct ast *el, struct ast *tail)
87 {
88 struct ast *res = ast_alloc();
89 res->type = an_cons;
90 res->data.an_cons.el = el;
91 res->data.an_cons.tail = tail;
92 return res;
93 }
94
95 struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
96 {
97 struct ast *res = ast_alloc();
98 res->type = an_fundecl;
99
100 //ident
101 must_be(ident, an_ident, "ident of a fundecl");
102 res->data.an_fundecl.ident = ident->data.an_ident;
103 free(ident);
104
105 //args
106 must_be(args, an_list, "args of a fundecl");
107 res->data.an_fundecl.nargs = args->data.an_list.n;
108 res->data.an_fundecl.args = (char **)safe_malloc(
109 args->data.an_list.n*sizeof(char *));
110 for (int i = 0; i<args->data.an_list.n; i++) {
111 struct ast *e = args->data.an_list.ptr[i];
112 must_be(e, an_ident, "arg of a fundecl")
113 res->data.an_fundecl.args[i] = e->data.an_ident;
114 free(e);
115 }
116 free(args);
117
118 //body
119 must_be(body, an_list, "body of a fundecl");
120 res->data.an_fundecl.body = body;
121
122 return res;
123 }
124
125 struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
126 {
127 struct ast *res = ast_alloc();
128 res->type = an_if;
129 res->data.an_if.pred = pred;
130 res->data.an_if.then = then;
131 res->data.an_if.els = els;
132 return res;
133 }
134
135 struct ast *ast_int(int integer)
136 {
137 struct ast *res = ast_alloc();
138 res->type = an_int;
139 res->data.an_int = integer;
140 return res;
141 }
142
143 struct ast *ast_ident(char *ident)
144 {
145 struct ast *res = ast_alloc();
146 res->type = an_ident;
147 res->data.an_ident = safe_strdup(ident);
148 return res;
149 }
150
151 struct ast *ast_list(struct ast *llist)
152 {
153 struct ast *res = ast_alloc();
154 res->type = an_list;
155 res->data.an_list.n = 0;
156
157 int i = ast_llistlength(llist);
158
159 //Allocate array
160 res->data.an_list.n = i;
161 res->data.an_list.ptr = (struct ast **)safe_malloc(
162 res->data.an_list.n*sizeof(struct ast *));
163
164 struct ast *r = llist;
165 while(i > 0) {
166 res->data.an_list.ptr[--i] = r->data.an_cons.el;
167 struct ast *t = r;
168 r = r->data.an_cons.tail;
169 free(t);
170 }
171 return res;
172 }
173
174 struct ast *ast_stmt_expr(struct ast *expr)
175 {
176 struct ast *res = ast_alloc();
177 res->type = an_stmt_expr;
178 res->data.an_stmt_expr = expr;
179 return res;
180 }
181
182 struct ast *ast_unop(enum unop op, struct ast *l)
183 {
184 struct ast *res = ast_alloc();
185 res->type = an_unop;
186 res->data.an_unop.op = op;
187 res->data.an_unop.l = l;
188 return res;
189 }
190
191 struct ast *ast_vardecl(struct ast *ident, struct ast *l)
192 {
193 struct ast *res = ast_alloc();
194 res->type = an_vardecl;
195 must_be(ident, an_ident, "ident of a vardecl");
196
197 res->data.an_vardecl.ident = ident->data.an_ident;
198 free(ident);
199 res->data.an_vardecl.l = l;
200 return res;
201 }
202
203 struct ast *ast_while(struct ast *pred, struct ast *body)
204 {
205 struct ast *res = ast_alloc();
206 res->type = an_while;
207 res->data.an_while.pred = pred;
208 res->data.an_while.body = body;
209 return res;
210 }
211
212 int ast_llistlength(struct ast *r)
213 {
214 int i = 0;
215 while(r != NULL) {
216 i++;
217 if (r->type != an_cons) {
218 return 1;
219 }
220 r = r->data.an_cons.tail;
221 }
222 return i;
223 }
224
/*
 * C escape spelling of the characters that cannot be printed as themselves:
 * the control characters 0x00-0x1F and DEL (0x7F). Indexed by character
 * code; the table has 128 entries and every printable character maps to NULL.
 */
const char *cescapes[] = {
	[0x00] = "\\0",   [0x01] = "\\x01", [0x02] = "\\x02", [0x03] = "\\x03",
	[0x04] = "\\x04", [0x05] = "\\x05", [0x06] = "\\x06", [0x07] = "\\a",
	[0x08] = "\\b",   [0x09] = "\\t",   [0x0A] = "\\n",   [0x0B] = "\\v",
	[0x0C] = "\\f",   [0x0D] = "\\r",   [0x0E] = "\\x0E", [0x0F] = "\\x0F",
	[0x10] = "\\x10", [0x11] = "\\x11", [0x12] = "\\x12", [0x13] = "\\x13",
	[0x14] = "\\x14", [0x15] = "\\x15", [0x16] = "\\x16", [0x17] = "\\x17",
	[0x18] = "\\x18", [0x19] = "\\x19", [0x1A] = "\\x1A", [0x1B] = "\\x1B",
	[0x1C] = "\\x1C", [0x1D] = "\\x1D", [0x1E] = "\\x1E", [0x1F] = "\\x1F",
	[0x7F] = "\\x7F",
};
236
237 void ast_print(struct ast * ast, int indent, FILE *out)
238 {
239 if (ast == NULL)
240 return;
241 switch(ast->type) {
242 case an_bool:
243 safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
244 break;
245 case an_binop:
246 safe_fprintf(out, "(");
247 ast_print(ast->data.an_binop.l, indent, out);
248 safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
249 ast_print(ast->data.an_binop.r, indent, out);
250 safe_fprintf(out, ")");
251 break;
252 case an_char:
253 if (ast->data.an_char < 0)
254 safe_fprintf(out, "'?'");
255 if (ast->data.an_char < ' ' || ast->data.an_char == 127)
256 safe_fprintf(out, "'%s'",
257 cescapes[(int)ast->data.an_char]);
258 else
259 safe_fprintf(out, "'%c'", ast->data.an_char);
260 break;
261 case an_fundecl:
262 pindent(indent, out);
263 safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
264 for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
265 safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
266 if (i < ast->data.an_fundecl.nargs - 1)
267 safe_fprintf(out, ", ");
268 }
269 safe_fprintf(out, ") {\n");
270 ast_print(ast->data.an_fundecl.body, indent+1, out);
271 pindent(indent, out);
272 safe_fprintf(out, "}\n");
273 break;
274 case an_if:
275 pindent(indent, out);
276 safe_fprintf(out, "if (");
277 ast_print(ast->data.an_if.pred, indent, out);
278 safe_fprintf(out, ") {\n");
279 if (ast->data.an_if.then->data.an_list.n > 0) {
280 pindent(indent, out);
281 ast_print(ast->data.an_if.then, indent+1, out);
282 }
283 pindent(indent, out);
284 safe_fprintf(out, "} else {\n");
285 if (ast->data.an_if.els->data.an_list.n > 0) {
286 pindent(indent, out);
287 ast_print(ast->data.an_if.els, indent+1, out);
288 }
289 pindent(indent, out);
290 safe_fprintf(out, "}\n");
291 break;
292 case an_int:
293 safe_fprintf(out, "%d", ast->data.an_int);
294 break;
295 case an_ident:
296 safe_fprintf(out, "%s", ast->data.an_ident);
297 break;
298 case an_cons:
299 ast_print(ast->data.an_cons.el, indent, out);
300 ast_print(ast->data.an_cons.tail, indent, out);
301 break;
302 case an_list:
303 for (int i = 0; i<ast->data.an_list.n; i++)
304 ast_print(ast->data.an_list.ptr[i], indent, out);
305 break;
306 case an_stmt_expr:
307 pindent(indent, out);
308 ast_print(ast->data.an_stmt_expr, indent, out);
309 safe_fprintf(out, ";\n");
310 break;
311 case an_unop:
312 safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
313 ast_print(ast->data.an_unop.l, indent, out);
314 safe_fprintf(out, ")");
315 break;
316 case an_vardecl:
317 pindent(indent, out);
318 safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
319 ast_print(ast->data.an_vardecl.l, indent, out);
320 safe_fprintf(out, ";\n");
321 break;
322 case an_while:
323 pindent(indent, out);
324 safe_fprintf(out, "while (");
325 ast_print(ast->data.an_while.pred, indent, out);
326 safe_fprintf(out, ") {\n");
327 if (ast->data.an_while.body->data.an_list.n > 0) {
328 pindent(indent, out);
329 ast_print(ast->data.an_while.body, indent+1, out);
330 }
331 pindent(indent, out);
332 safe_fprintf(out, "}\n");
333 break;
334 default:
335 die("Unsupported AST node\n");
336 }
337 }
338
339 void ast_free(struct ast *ast)
340 {
341 if (ast == NULL)
342 return;
343 switch(ast->type) {
344 case an_bool:
345 break;
346 case an_binop:
347 ast_free(ast->data.an_binop.l);
348 ast_free(ast->data.an_binop.r);
349 break;
350 case an_char:
351 break;
352 case an_cons:
353 ast_free(ast->data.an_cons.el);
354 ast_free(ast->data.an_cons.tail);
355 break;
356 case an_fundecl:
357 free(ast->data.an_fundecl.ident);
358 for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
359 free(ast->data.an_fundecl.args[i]);
360 free(ast->data.an_fundecl.args);
361 ast_free(ast->data.an_fundecl.body);
362 break;
363 case an_if:
364 ast_free(ast->data.an_if.pred);
365 ast_free(ast->data.an_if.then);
366 ast_free(ast->data.an_if.els);
367 case an_int:
368 break;
369 case an_ident:
370 free(ast->data.an_ident);
371 break;
372 case an_list:
373 for (int i = 0; i<ast->data.an_list.n; i++)
374 ast_free(ast->data.an_list.ptr[i]);
375 free(ast->data.an_list.ptr);
376 break;
377 case an_stmt_expr:
378 ast_free(ast->data.an_stmt_expr);
379 break;
380 case an_unop:
381 ast_free(ast->data.an_unop.l);
382 break;
383 case an_vardecl:
384 free(ast->data.an_vardecl.ident);
385 ast_free(ast->data.an_vardecl.l);
386 break;
387 case an_while:
388 ast_free(ast->data.an_while.pred);
389 ast_free(ast->data.an_while.body);
390 break;
391 default:
392 die("Unsupported AST node: %d\n", ast->type);
393 }
394 free(ast);
395 }