update todo
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7
#ifdef DEBUG
// Printable name for each AST node type, indexed by the node's type tag.
// Only present in DEBUG builds, where the diagnostics in this file use it.
static const char *ast_type_str[] = {
	[an_assign] = "assign",
	[an_bool] = "bool",
	[an_binop] = "binop",
	[an_char] = "char",
	[an_cons] = "cons",
	[an_funcall] = "funcall",
	[an_fundecl] = "fundecl",
	[an_ident] = "ident",
	[an_if] = "if",
	[an_int] = "int",
	[an_nil] = "nil",
	[an_list] = "list",
	[an_return] = "return",
	[an_stmt_expr] = "stmt_expr",
	[an_unop] = "unop",
	[an_vardecl] = "vardecl",
	[an_while] = "while",
};
#endif
18 static const char *binop_str[] = {
19 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
20 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
21 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
22 [modulo] = "%", [power] = "^",
23 };
24 static const char *fieldspec_str[] = {
25 [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
26 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
27
#ifdef DEBUG
// must_be(node, ntype, msg): DEBUG-only sanity check on a node's type tag.
// If node->type differs from ntype, prints "<msg> can't be <actual type>" to
// stderr and terminates the process with exit status 1.
//
// Every macro parameter is parenthesised in the expansion so that an
// argument that is itself an expression is evaluated as a unit.
//
// The expansion is deliberately a bare brace block (and a lone ';' when
// DEBUG is off) rather than do { } while (0), so that it is valid both with
// and without a trailing semicolon at the call site.
#define must_be(node, ntype, msg) {\
	if ((node)->type != (ntype)) {\
		fprintf(stderr, "%s can't be %s\n",\
			(msg), ast_type_str[(node)->type]);\
		exit(1);\
	}\
}
#else
// Without DEBUG the check compiles away to an empty statement.
#define must_be(node, ntype, msg) ;
#endif
39
// Allocate storage for one struct ast through safe_malloc.
// The node's contents are not initialised here; each constructor below fills
// in the type tag and the fields it uses.
#define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
41
42 struct ast *ast_assign(struct ast *ident, struct ast *expr)
43 {
44 struct ast *res = ast_alloc();
45 res->type = an_assign;
46 res->data.an_assign.ident = ident;
47 res->data.an_assign.expr = expr;
48 return res;
49 }
50
51 struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
52 {
53 struct ast *res = ast_alloc();
54 res->type = an_binop;
55 res->data.an_binop.l = l;
56 res->data.an_binop.op = op;
57 res->data.an_binop.r = r;
58 return res;
59 }
60
61 struct ast *ast_bool(bool b)
62 {
63 struct ast *res = ast_alloc();
64 res->type = an_bool;
65 res->data.an_bool = b;
66 return res;
67 }
68
// Convert one hexadecimal digit character to its numeric value.
// Accepts '0'-'9', 'a'-'f' and 'A'-'F'; returns 0..15 for those and -1 for
// any other character.
int fromHex(char c)
{
	int value = -1;

	if ('0' <= c && c <= '9')
		value = c - '0';
	else if ('a' <= c && c <= 'f')
		value = 10 + (c - 'a');
	else if ('A' <= c && c <= 'F')
		value = 10 + (c - 'A');
	return value;
}
79
80 struct ast *ast_char(const char *c)
81 {
82 struct ast *res = ast_alloc();
83 res->type = an_char;
84 //regular char
85 if (strlen(c) == 3)
86 res->data.an_char = c[1];
87 //escape
88 if (strlen(c) == 4)
89 switch(c[2]) {
90 case '0': res->data.an_char = '\0'; break;
91 case 'a': res->data.an_char = '\a'; break;
92 case 'b': res->data.an_char = '\b'; break;
93 case 't': res->data.an_char = '\t'; break;
94 case 'v': res->data.an_char = '\v'; break;
95 case 'f': res->data.an_char = '\f'; break;
96 case 'r': res->data.an_char = '\r'; break;
97 }
98 //hex escape
99 if (strlen(c) == 6)
100 res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
101 return res;
102 }
103
104 struct ast *ast_cons(struct ast *el, struct ast *tail)
105 {
106 struct ast *res = ast_alloc();
107 res->type = an_cons;
108 res->data.an_cons.el = el;
109 res->data.an_cons.tail = tail;
110 return res;
111 }
112
113 struct ast *ast_funcall(struct ast *ident, struct ast *args)
114 {
115 struct ast *res = ast_alloc();
116 res->type = an_funcall;
117
118 //ident
119 must_be(ident, an_ident, "ident of a funcall");
120 res->data.an_funcall.ident = ident->data.an_ident.ident;
121 free(ident->data.an_ident.fields);
122 free(ident);
123
124 //args
125 must_be(args, an_list, "args of a funcall");
126 res->data.an_funcall.nargs = args->data.an_list.n;
127 res->data.an_funcall.args = args->data.an_list.ptr;
128 free(args);
129 return res;
130 }
131
132 struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
133 {
134 struct ast *res = ast_alloc();
135 res->type = an_fundecl;
136
137 //ident
138 must_be(ident, an_ident, "ident of a fundecl");
139 res->data.an_fundecl.ident = ident->data.an_ident.ident;
140 free(ident->data.an_ident.fields);
141 free(ident);
142
143 //args
144 must_be(args, an_list, "args of a fundecl");
145 res->data.an_fundecl.nargs = args->data.an_list.n;
146 res->data.an_fundecl.args = (char **)args->data.an_list.ptr;
147 for (int i = 0; i<args->data.an_list.n; i++) {
148 struct ast *e = args->data.an_list.ptr[i];
149 must_be(e, an_ident, "arg of a fundecl")
150 res->data.an_fundecl.args[i] = e->data.an_ident.ident;
151 free(e->data.an_ident.fields);
152 free(e);
153 }
154 free(args);
155
156 //body
157 must_be(body, an_list, "body of a fundecl");
158 res->data.an_fundecl.nbody = body->data.an_list.n;
159 res->data.an_fundecl.body = body->data.an_list.ptr;
160 free(body);
161
162 return res;
163 }
164
165 struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
166 {
167 struct ast *res = ast_alloc();
168 res->type = an_if;
169 res->data.an_if.pred = pred;
170
171 must_be(then, an_list, "body of a then");
172 res->data.an_if.nthen = then->data.an_list.n;
173 res->data.an_if.then = then->data.an_list.ptr;
174 free(then);
175
176 must_be(els, an_list, "body of a els");
177 res->data.an_if.nels = els->data.an_list.n;
178 res->data.an_if.els = els->data.an_list.ptr;
179 free(els);
180
181 return res;
182 }
183
184 struct ast *ast_int(int integer)
185 {
186 struct ast *res = ast_alloc();
187 res->type = an_int;
188 res->data.an_int = integer;
189 return res;
190 }
191
192 struct ast *ast_identc(char *ident)
193 {
194 struct ast *res = ast_alloc();
195 res->type = an_ident;
196 res->data.an_ident.ident = safe_strdup(ident);
197 res->data.an_ident.nfields = 0;
198 res->data.an_ident.fields = NULL;
199 return res;
200 }
201
202 struct ast *ast_ident(struct ast *ident, struct ast *fields)
203 {
204 struct ast *res = ast_alloc();
205 res->type = an_ident;
206 must_be(fields, an_ident, "ident of an ident");
207 res->data.an_ident.ident = ident->data.an_ident.ident;
208 free(ident);
209
210 must_be(fields, an_list, "fields of an ident");
211 res->data.an_ident.nfields = fields->data.an_list.n;
212 res->data.an_ident.fields = (enum fieldspec *)safe_malloc(
213 fields->data.an_list.n*sizeof(enum fieldspec));
214 for (int i = 0; i<fields->data.an_list.n; i++) {
215 struct ast *t = fields->data.an_list.ptr[i];
216 must_be(t, an_ident, "field of an ident");
217 if (strcmp(t->data.an_ident.ident, "fst") == 0)
218 res->data.an_ident.fields[i] = fst;
219 else if (strcmp(t->data.an_ident.ident, "snd") == 0)
220 res->data.an_ident.fields[i] = snd;
221 else if (strcmp(t->data.an_ident.ident, "hd") == 0)
222 res->data.an_ident.fields[i] = hd;
223 else if (strcmp(t->data.an_ident.ident, "tl") == 0)
224 res->data.an_ident.fields[i] = tl;
225 free(t->data.an_ident.ident);
226 free(t);
227 }
228 free(fields->data.an_list.ptr);
229 free(fields);
230 return res;
231 }
232
233 struct ast *ast_list(struct ast *llist)
234 {
235 struct ast *res = ast_alloc();
236 res->type = an_list;
237 res->data.an_list.n = 0;
238
239 int i = ast_llistlength(llist);
240
241 //Allocate array
242 res->data.an_list.n = i;
243 res->data.an_list.ptr = (struct ast **)safe_malloc(
244 res->data.an_list.n*sizeof(struct ast *));
245
246 struct ast *r = llist;
247 while(i > 0) {
248 res->data.an_list.ptr[--i] = r->data.an_cons.el;
249 struct ast *t = r;
250 r = r->data.an_cons.tail;
251 free(t);
252 }
253 return res;
254 }
255
256 struct ast *ast_nil()
257 {
258 struct ast *res = ast_alloc();
259 res->type = an_nil;
260 return res;
261 }
262
263 struct ast *ast_return(struct ast *r)
264 {
265 struct ast *res = ast_alloc();
266 res->type = an_return;
267 res->data.an_return = r;
268 return res;
269 }
270
271 struct ast *ast_stmt_expr(struct ast *expr)
272 {
273 struct ast *res = ast_alloc();
274 res->type = an_stmt_expr;
275 res->data.an_stmt_expr = expr;
276 return res;
277 }
278
279 struct ast *ast_unop(enum unop op, struct ast *l)
280 {
281 struct ast *res = ast_alloc();
282 res->type = an_unop;
283 res->data.an_unop.op = op;
284 res->data.an_unop.l = l;
285 return res;
286 }
287
288 struct ast *ast_vardecl(struct ast *ident, struct ast *l)
289 {
290 struct ast *res = ast_alloc();
291 res->type = an_vardecl;
292 must_be(ident, an_ident, "ident of a vardecl");
293
294 res->data.an_vardecl.ident = ident->data.an_ident.ident;
295 free(ident->data.an_ident.fields);
296 free(ident);
297 res->data.an_vardecl.l = l;
298 return res;
299 }
300
301 struct ast *ast_while(struct ast *pred, struct ast *body)
302 {
303 struct ast *res = ast_alloc();
304 res->type = an_while;
305 res->data.an_while.pred = pred;
306 must_be(body, an_list, "body of a while");
307 res->data.an_while.nbody = body->data.an_list.n;
308 res->data.an_while.body = body->data.an_list.ptr;
309 free(body);
310 return res;
311 }
312
313 int ast_llistlength(struct ast *r)
314 {
315 int i = 0;
316 while(r != NULL) {
317 i++;
318 if (r->type != an_cons) {
319 return 1;
320 }
321 r = r->data.an_cons.tail;
322 }
323 return i;
324 }
325
// Escaped spelling of the non-printable ASCII characters, indexed by
// character code. Codes 0..31 and 127 have an entry; the slots for codes
// 32..126 are left NULL.
const char *cescapes[] = {
	/*  0.. 7 */ "\\0", "\\x01", "\\x02", "\\x03",
	             "\\x04", "\\x05", "\\x06", "\\a",
	/*  8..15 */ "\\b", "\\t", "\\n", "\\v",
	             "\\f", "\\r", "\\x0E", "\\x0F",
	/* 16..23 */ "\\x10", "\\x11", "\\x12", "\\x13",
	             "\\x14", "\\x15", "\\x16", "\\x17",
	/* 24..31 */ "\\x18", "\\x19", "\\x1A", "\\x1B",
	             "\\x1C", "\\x1D", "\\x1E", "\\x1F",
	[127] = "\\x7F",
};
337
338 void ast_print(struct ast *ast, int indent, FILE *out)
339 {
340 if (ast == NULL)
341 return;
342 #ifdef DEBUG
343 fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
344 #endif
345 switch(ast->type) {
346 case an_assign:
347 pindent(indent, out);
348 ast_print(ast->data.an_assign.ident, indent, out);
349 safe_fprintf(out, " = ");
350 ast_print(ast->data.an_assign.expr, indent, out);
351 safe_fprintf(out, ";\n");
352 break;
353 case an_binop:
354 safe_fprintf(out, "(");
355 ast_print(ast->data.an_binop.l, indent, out);
356 safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
357 ast_print(ast->data.an_binop.r, indent, out);
358 safe_fprintf(out, ")");
359 break;
360 case an_bool:
361 safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
362 break;
363 case an_char:
364 if (ast->data.an_char < 0)
365 safe_fprintf(out, "'?'");
366 if (ast->data.an_char < ' ' || ast->data.an_char == 127)
367 safe_fprintf(out, "'%s'",
368 cescapes[(int)ast->data.an_char]);
369 else
370 safe_fprintf(out, "'%c'", ast->data.an_char);
371 break;
372 case an_funcall:
373 safe_fprintf(out, "%s(", ast->data.an_funcall.ident);
374 for(int i = 0; i<ast->data.an_fundecl.nargs; i++) {
375 ast_print(ast->data.an_funcall.args[i], indent, out);
376 if (i+1 < ast->data.an_fundecl.nargs)
377 safe_fprintf(out, ", ");
378 }
379 safe_fprintf(out, ")");
380 break;
381 case an_fundecl:
382 pindent(indent, out);
383 safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
384 for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
385 safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
386 if (i < ast->data.an_fundecl.nargs - 1)
387 safe_fprintf(out, ", ");
388 }
389 safe_fprintf(out, ") {\n");
390 for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
391 ast_print(ast->data.an_fundecl.body[i], indent+1, out);
392 pindent(indent, out);
393 safe_fprintf(out, "}\n");
394 break;
395 case an_if:
396 pindent(indent, out);
397 safe_fprintf(out, "if (");
398 ast_print(ast->data.an_if.pred, indent, out);
399 safe_fprintf(out, ") {\n");
400 for (int i = 0; i<ast->data.an_if.nthen; i++)
401 ast_print(ast->data.an_if.then[i], indent+1, out);
402 pindent(indent, out);
403 safe_fprintf(out, "} else {\n");
404 for (int i = 0; i<ast->data.an_if.nels; i++)
405 ast_print(ast->data.an_if.els[i], indent+1, out);
406 pindent(indent, out);
407 safe_fprintf(out, "}\n");
408 break;
409 case an_int:
410 safe_fprintf(out, "%d", ast->data.an_int);
411 break;
412 case an_ident:
413 fprintf(out, "%s", ast->data.an_ident.ident);
414 for (int i = 0; i<ast->data.an_ident.nfields; i++)
415 fprintf(out, ".%s",
416 fieldspec_str[ast->data.an_ident.fields[i]]);
417 break;
418 case an_cons:
419 ast_print(ast->data.an_cons.el, indent, out);
420 ast_print(ast->data.an_cons.tail, indent, out);
421 break;
422 case an_list:
423 for (int i = 0; i<ast->data.an_list.n; i++)
424 ast_print(ast->data.an_list.ptr[i], indent, out);
425 break;
426 case an_nil:
427 safe_fprintf(out, "[]");
428 break;
429 case an_return:
430 pindent(indent, out);
431 safe_fprintf(out, "return ");
432 ast_print(ast->data.an_return, indent, out);
433 safe_fprintf(out, ";\n");
434 break;
435 case an_stmt_expr:
436 pindent(indent, out);
437 ast_print(ast->data.an_stmt_expr, indent, out);
438 safe_fprintf(out, ";\n");
439 break;
440 case an_unop:
441 safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
442 ast_print(ast->data.an_unop.l, indent, out);
443 safe_fprintf(out, ")");
444 break;
445 case an_vardecl:
446 pindent(indent, out);
447 safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
448 ast_print(ast->data.an_vardecl.l, indent, out);
449 safe_fprintf(out, ";\n");
450 break;
451 case an_while:
452 pindent(indent, out);
453 safe_fprintf(out, "while (");
454 ast_print(ast->data.an_while.pred, indent, out);
455 safe_fprintf(out, ") {\n");
456 for (int i = 0; i<ast->data.an_while.nbody; i++) {
457 ast_print(ast->data.an_while.body[i], indent+1, out);
458 }
459 pindent(indent, out);
460 safe_fprintf(out, "}\n");
461 break;
462 default:
463 die("Unsupported AST node\n");
464 }
465 }
466
467 void ast_free(struct ast *ast)
468 {
469 if (ast == NULL)
470 return;
471 #ifdef DEBUG
472 fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
473 #endif
474 switch(ast->type) {
475 case an_assign:
476 ast_free(ast->data.an_assign.ident);
477 ast_free(ast->data.an_assign.expr);
478 break;
479 case an_binop:
480 ast_free(ast->data.an_binop.l);
481 ast_free(ast->data.an_binop.r);
482 break;
483 case an_bool:
484 break;
485 case an_char:
486 break;
487 case an_cons:
488 ast_free(ast->data.an_cons.el);
489 ast_free(ast->data.an_cons.tail);
490 break;
491 case an_funcall:
492 free(ast->data.an_funcall.ident);
493 for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
494 ast_free(ast->data.an_funcall.args[i]);
495 free(ast->data.an_funcall.args);
496 break;
497 case an_fundecl:
498 free(ast->data.an_fundecl.ident);
499 for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
500 free(ast->data.an_fundecl.args[i]);
501 free(ast->data.an_fundecl.args);
502 for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
503 ast_free(ast->data.an_fundecl.body[i]);
504 free(ast->data.an_fundecl.body);
505 break;
506 case an_if:
507 ast_free(ast->data.an_if.pred);
508 for (int i = 0; i<ast->data.an_if.nthen; i++)
509 ast_free(ast->data.an_if.then[i]);
510 free(ast->data.an_if.then);
511 for (int i = 0; i<ast->data.an_if.nels; i++)
512 ast_free(ast->data.an_if.els[i]);
513 free(ast->data.an_if.els);
514 break;
515 case an_int:
516 break;
517 case an_ident:
518 free(ast->data.an_ident.ident);
519 free(ast->data.an_ident.fields);
520 break;
521 case an_list:
522 for (int i = 0; i<ast->data.an_list.n; i++)
523 ast_free(ast->data.an_list.ptr[i]);
524 free(ast->data.an_list.ptr);
525 break;
526 case an_nil:
527 break;
528 case an_return:
529 ast_free(ast->data.an_return);
530 break;
531 case an_stmt_expr:
532 ast_free(ast->data.an_stmt_expr);
533 break;
534 case an_unop:
535 ast_free(ast->data.an_unop.l);
536 break;
537 case an_vardecl:
538 free(ast->data.an_vardecl.ident);
539 ast_free(ast->data.an_vardecl.l);
540 break;
541 case an_while:
542 ast_free(ast->data.an_while.pred);
543 for (int i = 0; i<ast->data.an_while.nbody; i++)
544 ast_free(ast->data.an_while.body[i]);
545 free(ast->data.an_while.body);
546 break;
547 default:
548 die("Unsupported AST node: %d\n", ast->type);
549 }
550 free(ast);
551 }