edgecases
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7 #include "y.tab.h"
8
9 static const char *binop_str[] = {
10 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
11 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
12 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
13 [modulo] = "%", [power] = "^",
14 };
15 static const char *fieldspec_str[] = {
16 [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
17 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
18 static const char *basictype_str[] = {
19 [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
20 [btvoid] = "Void",
21 };
22
23 struct ast *ast(struct list *decls)
24 {
25 struct ast *res = safe_malloc(sizeof(struct ast));
26 res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true);
27 return res;
28 }
29
30 struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr)
31 {
32 struct vardecl *res = safe_malloc(sizeof(struct vardecl));
33 res->type = type;
34 res->ident = ident;
35 res->expr = expr;
36 return res;
37 }
38
39 struct decl *decl_fun(char *ident, struct list *args, struct list *atypes,
40 struct type *rtype, struct list *vars, struct list *body)
41 {
42 struct decl *res = safe_malloc(sizeof(struct decl));
43 res->type = dfundecl;
44 res->data.dfun.ident = ident;
45 res->data.dfun.args = (char **)
46 list_to_array(args, &res->data.dfun.nargs, true);
47 res->data.dfun.atypes = (struct type **)
48 list_to_array(atypes, &res->data.dfun.natypes, true);
49 res->data.dfun.rtype = rtype;
50 res->data.dfun.vars = (struct vardecl **)
51 list_to_array(vars, &res->data.dfun.nvar, true);
52 res->data.dfun.body = (struct stmt **)
53 list_to_array(body, &res->data.dfun.nbody, true);
54 return res;
55 }
56
57 struct decl *decl_var(struct vardecl *vardecl)
58 {
59 struct decl *res = safe_malloc(sizeof(struct decl));
60 res->type = dvardecl;
61 res->data.dvar = vardecl;
62 return res;
63 }
64
65 struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr)
66 {
67 struct stmt *res = safe_malloc(sizeof(struct stmt));
68 res->type = sassign;
69 res->data.sassign.ident = ident;
70 res->data.sassign.fields = (char **)
71 list_to_array(fields, &res->data.sassign.nfield, true);
72 res->data.sassign.expr = expr;
73 return res;
74 }
75
76 struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els)
77 {
78 struct stmt *res = safe_malloc(sizeof(struct stmt));
79 res->type = sif;
80 res->data.sif.pred = pred;
81 res->data.sif.then = (struct stmt **)
82 list_to_array(then, &res->data.sif.nthen, true);
83 res->data.sif.els = (struct stmt **)
84 list_to_array(els, &res->data.sif.nels, true);
85 return res;
86 }
87
88 struct stmt *stmt_return(struct expr *rtrn)
89 {
90 struct stmt *res = safe_malloc(sizeof(struct stmt));
91 res->type = sreturn;
92 res->data.sreturn = rtrn;
93 return res;
94 }
95
96 struct stmt *stmt_expr(struct expr *expr)
97 {
98 struct stmt *res = safe_malloc(sizeof(struct stmt));
99 res->type = sexpr;
100 res->data.sexpr = expr;
101 return res;
102 }
103
104 struct stmt *stmt_while(struct expr *pred, struct list *body)
105 {
106 struct stmt *res = safe_malloc(sizeof(struct stmt));
107 res->type = swhile;
108 res->data.swhile.pred = pred;
109 res->data.swhile.body = (struct stmt **)
110 list_to_array(body, &res->data.swhile.nbody, true);
111 return res;
112 }
113
114 struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r)
115 {
116 struct expr *res = safe_malloc(sizeof(struct expr));
117 res->type = ebinop;
118 res->data.ebinop.l = l;
119 res->data.ebinop.op = op;
120 res->data.ebinop.r = r;
121 return res;
122 }
123
124 struct expr *expr_bool(bool b)
125 {
126 struct expr *res = safe_malloc(sizeof(struct expr));
127 res->type = ebool;
128 res->data.ebool = b;
129 return res;
130 }
/*
 * Value of a single hexadecimal digit (0-9, a-f, A-F), or -1 if `c` is not
 * a hex digit.
 */
int fromHex(char c)
{
	int value = -1;

	if ('0' <= c && c <= '9')
		value = c - '0';
	else if ('a' <= c && c <= 'f')
		value = 10 + (c - 'a');
	else if ('A' <= c && c <= 'F')
		value = 10 + (c - 'A');
	return value;
}
141
142 struct expr *expr_char(const char *c)
143 {
144 struct expr *res = safe_malloc(sizeof(struct expr));
145 res->type = echar;
146 //regular char
147 if (c[0] == '\'' && c[2] == '\'')
148 res->data.echar = c[1];
149 //escape
150 else if (c[0] == '\'' && c[1] == '\\' && c[3] == '\'')
151 switch(c[2]) {
152 case '0': res->data.echar = '\0'; break;
153 case '\'': res->data.echar = '\''; break;
154 case '\\': res->data.echar = '\\'; break;
155 case 'a': res->data.echar = '\a'; break;
156 case 'b': res->data.echar = '\b'; break;
157 case 't': res->data.echar = '\t'; break;
158 case 'v': res->data.echar = '\v'; break;
159 case 'f': res->data.echar = '\f'; break;
160 case 'r': res->data.echar = '\r'; break;
161 }
162 //hex escape
163 else if (c[0] == '\'' && c[1] == '\\' && c[2] == 'x' && c[5] == '\'')
164 res->data.echar = (fromHex(c[3])<<4)+fromHex(c[4]);
165 else
166 die("malformed character: %s\n", c);
167 return res;
168 }
169
170 struct expr *expr_funcall(char *ident, struct list *args)
171 {
172 struct expr *res = safe_malloc(sizeof(struct expr));
173 res->type = efuncall;
174 res->data.efuncall.ident = ident;
175 res->data.efuncall.args = (struct expr **)
176 list_to_array(args, &res->data.efuncall.nargs, true);
177 return res;
178 }
179
180 struct expr *expr_int(int integer)
181 {
182 struct expr *res = safe_malloc(sizeof(struct expr));
183 res->type = eint;
184 res->data.eint = integer;
185 return res;
186 }
187
188 struct expr *expr_ident(char *ident, struct list *fields)
189 {
190 struct expr *res = safe_malloc(sizeof(struct expr));
191 res->type = eident;
192 res->data.eident.ident = ident;
193
194 void **els = list_to_array(fields, &res->data.eident.nfields, true);
195 res->data.eident.fields = (enum fieldspec *)safe_malloc(
196 res->data.eident.nfields*sizeof(enum fieldspec));
197 for (int i = 0; i<res->data.eident.nfields; i++) {
198 char *t = els[i];
199 if (strcmp(t, "fst") == 0)
200 res->data.eident.fields[i] = fst;
201 else if (strcmp(t, "snd") == 0)
202 res->data.eident.fields[i] = snd;
203 else if (strcmp(t, "hd") == 0)
204 res->data.eident.fields[i] = hd;
205 else if (strcmp(t, "tl") == 0)
206 res->data.eident.fields[i] = tl;
207 free(t);
208 }
209 free(els);
210 return res;
211 }
212
213 struct expr *expr_nil()
214 {
215 struct expr *res = safe_malloc(sizeof(struct expr));
216 res->type = enil;
217 return res;
218 }
219
220 struct expr *expr_tuple(struct expr *left, struct expr *right)
221 {
222 struct expr *res = safe_malloc(sizeof(struct expr));
223 res->type = etuple;
224 res->data.etuple.left = left;
225 res->data.etuple.right = right;
226 return res;
227 }
228
229 struct expr *expr_string(char *str)
230 {
231 struct expr *res = safe_malloc(sizeof(struct expr));
232 res->type = estring;
233 res->data.estring = safe_strdup(str+1);
234 res->data.estring[strlen(res->data.estring)-1] = '\0';
235 //TODO escapes
236 return res;
237 }
238
239 struct expr *expr_unop(enum unop op, struct expr *l)
240 {
241 struct expr *res = safe_malloc(sizeof(struct expr));
242 res->type = eunop;
243 res->data.eunop.op = op;
244 res->data.eunop.l = l;
245 return res;
246 }
247
248 struct type *type_basic(enum basictype type)
249 {
250 struct type *res = safe_malloc(sizeof(struct type));
251 res->type = tbasic;
252 res->data.tbasic = type;
253 return res;
254 }
255
256 struct type *type_list(struct type *type)
257 {
258 struct type *res = safe_malloc(sizeof(struct type));
259 res->type = tlist;
260 res->data.tlist = type;
261 return res;
262 }
263
264 struct type *type_tuple(struct type *l, struct type *r)
265 {
266 struct type *res = safe_malloc(sizeof(struct type));
267 res->type = ttuple;
268 res->data.ttuple.l = l;
269 res->data.ttuple.r = r;
270 return res;
271 }
272
273 struct type *type_var(char *ident)
274 {
275 struct type *res = safe_malloc(sizeof(struct type));
276 if (strcmp(ident, "Int") == 0) {
277 res->type = tbasic;
278 res->data.tbasic = btint;
279 free(ident);
280 } else if (strcmp(ident, "Char") == 0) {
281 res->type = tbasic;
282 res->data.tbasic = btchar;
283 free(ident);
284 } else if (strcmp(ident, "Bool") == 0) {
285 res->type = tbasic;
286 res->data.tbasic = btbool;
287 free(ident);
288 } else if (strcmp(ident, "Void") == 0) {
289 res->type = tbasic;
290 res->data.tbasic = btvoid;
291 free(ident);
292 } else {
293 res->type = tvar;
294 res->data.tvar = ident;
295 }
296 return res;
297 }
298
299
/*
 * Escape sequences, without the leading backslash, that expr_print uses for
 * character literals, indexed by character code. Only control characters
 * (0x00-0x1F), DEL (0x7F), single quote and backslash have entries; every
 * other slot is NULL. A named escape is used where one exists, otherwise a
 * two-digit hex escape.
 * NOTE(review): this table is not static; confirm whether other files
 * reference it before narrowing its linkage.
 */
const char *cescapes[] = {
	/* named escapes */
	[0x00] = "0", [0x07] = "a", [0x08] = "b", [0x09] = "t",
	[0x0A] = "n", [0x0B] = "v", [0x0C] = "f", [0x0D] = "r",
	['\''] = "'", ['\\'] = "\\",
	/* remaining control characters and DEL, as hex escapes */
	[0x01] = "x01", [0x02] = "x02", [0x03] = "x03", [0x04] = "x04",
	[0x05] = "x05", [0x06] = "x06", [0x0E] = "x0E", [0x0F] = "x0F",
	[0x10] = "x10", [0x11] = "x11", [0x12] = "x12", [0x13] = "x13",
	[0x14] = "x14", [0x15] = "x15", [0x16] = "x16", [0x17] = "x17",
	[0x18] = "x18", [0x19] = "x19", [0x1A] = "x1A", [0x1B] = "x1B",
	[0x1C] = "x1C", [0x1D] = "x1D", [0x1E] = "x1E", [0x1F] = "x1F",
	[0x7F] = "x7F",
};
312
313 void ast_print(struct ast *ast, FILE *out)
314 {
315 if (ast == NULL)
316 return;
317 for (int i = 0; i<ast->ndecls; i++)
318 decl_print(ast->decls[i], 0, out);
319 }
320
321 void vardecl_print(struct vardecl *decl, int indent, FILE *out)
322 {
323 pindent(indent, out);
324 if (decl->type == NULL)
325 safe_fprintf(out, "var");
326 else
327 type_print(decl->type, out);
328 safe_fprintf(out, " %s = ", decl->ident);
329 expr_print(decl->expr, out);
330 safe_fprintf(out, ";\n");
331 }
332
333 void decl_print(struct decl *decl, int indent, FILE *out)
334 {
335 if (decl == NULL)
336 return;
337 switch(decl->type) {
338 case dfundecl:
339 pindent(indent, out);
340 safe_fprintf(out, "%s (", decl->data.dfun.ident);
341 for (int i = 0; i<decl->data.dfun.nargs; i++) {
342 safe_fprintf(out, "%s", decl->data.dfun.args[i]);
343 if (i < decl->data.dfun.nargs - 1)
344 safe_fprintf(out, ", ");
345 }
346 safe_fprintf(out, ")");
347 if (decl->data.dfun.rtype != NULL) {
348 safe_fprintf(out, " :: ");
349 for (int i = 0; i<decl->data.dfun.natypes; i++) {
350 type_print(decl->data.dfun.atypes[i], out);
351 safe_fprintf(out, " ");
352 }
353 safe_fprintf(out, "-> ");
354 type_print(decl->data.dfun.rtype, out);
355 }
356 safe_fprintf(out, " {\n");
357 for (int i = 0; i<decl->data.dfun.nvar; i++)
358 vardecl_print(decl->data.dfun.vars[i], indent+1, out);
359 for (int i = 0; i<decl->data.dfun.nbody; i++)
360 stmt_print(decl->data.dfun.body[i], indent+1, out);
361 pindent(indent, out);
362 safe_fprintf(out, "}\n");
363 break;
364 case dvardecl:
365 vardecl_print(decl->data.dvar, indent, out);
366 break;
367 default:
368 die("Unsupported decl node\n");
369 }
370 }
371
372 void stmt_print(struct stmt *stmt, int indent, FILE *out)
373 {
374 if (stmt == NULL)
375 return;
376 switch(stmt->type) {
377 case sassign:
378 pindent(indent, out);
379 fprintf(out, "%s", stmt->data.sassign.ident);
380 for (int i = 0; i<stmt->data.sassign.nfield; i++)
381 fprintf(out, ".%s", stmt->data.sassign.fields[i]);
382 safe_fprintf(out, " = ");
383 expr_print(stmt->data.sassign.expr, out);
384 safe_fprintf(out, ";\n");
385 break;
386 case sif:
387 pindent(indent, out);
388 safe_fprintf(out, "if (");
389 expr_print(stmt->data.sif.pred, out);
390 safe_fprintf(out, ") {\n");
391 for (int i = 0; i<stmt->data.sif.nthen; i++)
392 stmt_print(stmt->data.sif.then[i], indent+1, out);
393 pindent(indent, out);
394 safe_fprintf(out, "} else {\n");
395 for (int i = 0; i<stmt->data.sif.nels; i++)
396 stmt_print(stmt->data.sif.els[i], indent+1, out);
397 pindent(indent, out);
398 safe_fprintf(out, "}\n");
399 break;
400 case sreturn:
401 pindent(indent, out);
402 safe_fprintf(out, "return ");
403 expr_print(stmt->data.sreturn, out);
404 safe_fprintf(out, ";\n");
405 break;
406 case sexpr:
407 pindent(indent, out);
408 expr_print(stmt->data.sexpr, out);
409 safe_fprintf(out, ";\n");
410 break;
411 case swhile:
412 pindent(indent, out);
413 safe_fprintf(out, "while (");
414 expr_print(stmt->data.swhile.pred, out);
415 safe_fprintf(out, ") {\n");
416 for (int i = 0; i<stmt->data.swhile.nbody; i++) {
417 stmt_print(stmt->data.swhile.body[i], indent+1, out);
418 }
419 pindent(indent, out);
420 safe_fprintf(out, "}\n");
421 break;
422 default:
423 die("Unsupported stmt node\n");
424 }
425 }
426
427 void expr_print(struct expr *expr, FILE *out)
428 {
429 if (expr == NULL)
430 return;
431 switch(expr->type) {
432 case ebinop:
433 safe_fprintf(out, "(");
434 expr_print(expr->data.ebinop.l, out);
435 safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
436 expr_print(expr->data.ebinop.r, out);
437 safe_fprintf(out, ")");
438 break;
439 case ebool:
440 safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
441 break;
442 case echar:
443 if (expr->data.echar < 0) {
444 safe_fprintf(out, "'?'");
445 } else if (expr->data.echar < ' ' || expr->data.echar == 127
446 || expr->data.echar == '\\'
447 || expr->data.echar == '\'') {
448 safe_fprintf(out, "'\\%s'",
449 cescapes[(int)expr->data.echar]);
450 } else {
451 safe_fprintf(out, "'%c'", expr->data.echar);
452 }
453 break;
454 case efuncall:
455 safe_fprintf(out, "%s(", expr->data.efuncall.ident);
456 for(int i = 0; i<expr->data.efuncall.nargs; i++) {
457 expr_print(expr->data.efuncall.args[i], out);
458 if (i+1 < expr->data.efuncall.nargs)
459 safe_fprintf(out, ", ");
460 }
461 safe_fprintf(out, ")");
462 break;
463 case eint:
464 safe_fprintf(out, "%d", expr->data.eint);
465 break;
466 case eident:
467 fprintf(out, "%s", expr->data.eident.ident);
468 for (int i = 0; i<expr->data.eident.nfields; i++)
469 fprintf(out, ".%s",
470 fieldspec_str[expr->data.eident.fields[i]]);
471 break;
472 case enil:
473 safe_fprintf(out, "[]");
474 break;
475 case etuple:
476 safe_fprintf(out, "(");
477 expr_print(expr->data.etuple.left, out);
478 safe_fprintf(out, ", ");
479 expr_print(expr->data.etuple.right, out);
480 safe_fprintf(out, ")");
481 break;
482 case estring:
483 safe_fprintf(out, "\"%s\"", expr->data.estring);
484 break;
485 case eunop:
486 safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
487 expr_print(expr->data.eunop.l, out);
488 safe_fprintf(out, ")");
489 break;
490 default:
491 die("Unsupported expr node\n");
492 }
493 }
494
495 void type_print(struct type *type, FILE *out)
496 {
497 if (type == NULL)
498 return;
499 switch (type->type) {
500 case tbasic:
501 safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
502 break;
503 case tlist:
504 safe_fprintf(out, "[");
505 type_print(type->data.tlist, out);
506 safe_fprintf(out, "]");
507 break;
508 case ttuple:
509 safe_fprintf(out, "(");
510 type_print(type->data.ttuple.l, out);
511 safe_fprintf(out, ",");
512 type_print(type->data.ttuple.r, out);
513 safe_fprintf(out, ")");
514 break;
515 case tvar:
516 safe_fprintf(out, "%s", type->data.tvar);
517 break;
518 default:
519 die("Unsupported type node\n");
520 }
521 }
522
523 void ast_free(struct ast *ast)
524 {
525 if (ast == NULL)
526 return;
527 for (int i = 0; i<ast->ndecls; i++)
528 decl_free(ast->decls[i]);
529 free(ast->decls);
530 free(ast);
531 }
532
533 void vardecl_free(struct vardecl *decl)
534 {
535 type_free(decl->type);
536 free(decl->ident);
537 expr_free(decl->expr);
538 free(decl);
539 }
540
541 void decl_free(struct decl *decl)
542 {
543 if (decl == NULL)
544 return;
545 switch(decl->type) {
546 case dfundecl:
547 free(decl->data.dfun.ident);
548 for (int i = 0; i<decl->data.dfun.nargs; i++)
549 free(decl->data.dfun.args[i]);
550 free(decl->data.dfun.args);
551 for (int i = 0; i<decl->data.dfun.natypes; i++)
552 type_free(decl->data.dfun.atypes[i]);
553 free(decl->data.dfun.atypes);
554 type_free(decl->data.dfun.rtype);
555 for (int i = 0; i<decl->data.dfun.nvar; i++)
556 vardecl_free(decl->data.dfun.vars[i]);
557 free(decl->data.dfun.vars);
558 for (int i = 0; i<decl->data.dfun.nbody; i++)
559 stmt_free(decl->data.dfun.body[i]);
560 free(decl->data.dfun.body);
561 break;
562 case dvardecl:
563 vardecl_free(decl->data.dvar);
564 break;
565 default:
566 die("Unsupported decl node\n");
567 }
568 free(decl);
569 }
570
571 void stmt_free(struct stmt *stmt)
572 {
573 if (stmt == NULL)
574 return;
575 switch(stmt->type) {
576 case sassign:
577 free(stmt->data.sassign.ident);
578 for (int i = 0; i<stmt->data.sassign.nfield; i++)
579 free(stmt->data.sassign.fields[i]);
580 expr_free(stmt->data.sassign.expr);
581 break;
582 case sif:
583 expr_free(stmt->data.sif.pred);
584 for (int i = 0; i<stmt->data.sif.nthen; i++)
585 stmt_free(stmt->data.sif.then[i]);
586 free(stmt->data.sif.then);
587 for (int i = 0; i<stmt->data.sif.nels; i++)
588 stmt_free(stmt->data.sif.els[i]);
589 free(stmt->data.sif.els);
590 break;
591 case sreturn:
592 expr_free(stmt->data.sreturn);
593 break;
594 case sexpr:
595 expr_free(stmt->data.sexpr);
596 break;
597 case swhile:
598 expr_free(stmt->data.swhile.pred);
599 for (int i = 0; i<stmt->data.swhile.nbody; i++)
600 stmt_free(stmt->data.swhile.body[i]);
601 free(stmt->data.swhile.body);
602 break;
603 default:
604 die("Unsupported stmt node\n");
605 }
606 free(stmt);
607 }
608
609 void expr_free(struct expr *expr)
610 {
611 if (expr == NULL)
612 return;
613 switch(expr->type) {
614 case ebinop:
615 expr_free(expr->data.ebinop.l);
616 expr_free(expr->data.ebinop.r);
617 break;
618 case ebool:
619 break;
620 case echar:
621 break;
622 case efuncall:
623 free(expr->data.efuncall.ident);
624 for (int i = 0; i<expr->data.efuncall.nargs; i++)
625 expr_free(expr->data.efuncall.args[i]);
626 free(expr->data.efuncall.args);
627 break;
628 case eint:
629 break;
630 case eident:
631 free(expr->data.eident.ident);
632 free(expr->data.eident.fields);
633 break;
634 case enil:
635 break;
636 case etuple:
637 expr_free(expr->data.etuple.left);
638 expr_free(expr->data.etuple.right);
639 break;
640 case estring:
641 free(expr->data.estring);
642 break;
643 case eunop:
644 expr_free(expr->data.eunop.l);
645 break;
646 default:
647 die("Unsupported expr node\n");
648 }
649 free(expr);
650 }
651
652 void type_free(struct type *type)
653 {
654 if (type == NULL)
655 return;
656 switch (type->type) {
657 case tbasic:
658 break;
659 case tlist:
660 type_free(type->data.tlist);
661 break;
662 case ttuple:
663 type_free(type->data.ttuple.l);
664 type_free(type->data.ttuple.r);
665 break;
666 case tvar:
667 free(type->data.tvar);
668 break;
669 default:
670 die("Unsupported type node\n");
671 }
672 free(type);
673 }