locations
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7 #include "y.tab.h"
8
9 static const char *binop_str[] = {
10 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
11 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
12 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
13 [modulo] = "%", [power] = "^",
14 };
15 static const char *fieldspec_str[] = {
16 [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
17 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
18 static const char *basictype_str[] = {
19 [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
20 [btvoid] = "Void",
21 };
22
23 struct ast *ast(struct list *decls)
24 {
25 struct ast *res = safe_malloc(sizeof(struct ast));
26 res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true);
27 return res;
28 }
29
30 struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr)
31 {
32 struct vardecl *res = safe_malloc(sizeof(struct vardecl));
33 res->type = type;
34 res->ident = ident;
35 res->expr = expr;
36 return res;
37 }
38
39 struct decl *decl_fun(char *ident, struct list *args, struct list *atypes,
40 struct type *rtype, struct list *vars, struct list *body)
41 {
42 struct decl *res = safe_malloc(sizeof(struct decl));
43 res->type = dfundecl;
44 res->data.dfun.ident = ident;
45 res->data.dfun.args = (char **)
46 list_to_array(args, &res->data.dfun.nargs, true);
47 res->data.dfun.atypes = (struct type **)
48 list_to_array(atypes, &res->data.dfun.natypes, true);
49 res->data.dfun.rtype = rtype;
50 res->data.dfun.vars = (struct vardecl **)
51 list_to_array(vars, &res->data.dfun.nvar, true);
52 res->data.dfun.body = (struct stmt **)
53 list_to_array(body, &res->data.dfun.nbody, true);
54 return res;
55 }
56
57 struct decl *decl_var(struct vardecl *vardecl)
58 {
59 struct decl *res = safe_malloc(sizeof(struct decl));
60 res->type = dvardecl;
61 res->data.dvar = vardecl;
62 return res;
63 }
64
65 struct stmt *stmt_assign(char *ident, struct expr *expr)
66 {
67 struct stmt *res = safe_malloc(sizeof(struct stmt));
68 res->type = sassign;
69 res->data.sassign.ident = ident;
70 res->data.sassign.expr = expr;
71 return res;
72 }
73
74 struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els)
75 {
76 struct stmt *res = safe_malloc(sizeof(struct stmt));
77 res->type = sif;
78 res->data.sif.pred = pred;
79 res->data.sif.then = (struct stmt **)
80 list_to_array(then, &res->data.sif.nthen, true);
81 res->data.sif.els = (struct stmt **)
82 list_to_array(els, &res->data.sif.nels, true);
83 return res;
84 }
85
86 struct stmt *stmt_return(struct expr *rtrn)
87 {
88 struct stmt *res = safe_malloc(sizeof(struct stmt));
89 res->type = sreturn;
90 res->data.sreturn = rtrn;
91 return res;
92 }
93
94 struct stmt *stmt_expr(struct expr *expr)
95 {
96 struct stmt *res = safe_malloc(sizeof(struct stmt));
97 res->type = sexpr;
98 res->data.sexpr = expr;
99 return res;
100 }
101
102 struct stmt *stmt_while(struct expr *pred, struct list *body)
103 {
104 struct stmt *res = safe_malloc(sizeof(struct stmt));
105 res->type = swhile;
106 res->data.swhile.pred = pred;
107 res->data.swhile.body = (struct stmt **)
108 list_to_array(body, &res->data.swhile.nbody, true);
109 return res;
110 }
111
112 struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r)
113 {
114 struct expr *res = safe_malloc(sizeof(struct expr));
115 res->type = ebinop;
116 res->data.ebinop.l = l;
117 res->data.ebinop.op = op;
118 res->data.ebinop.r = r;
119 return res;
120 }
121
122 struct expr *expr_bool(bool b)
123 {
124 struct expr *res = safe_malloc(sizeof(struct expr));
125 res->type = ebool;
126 res->data.ebool = b;
127 return res;
128 }

/*
 * Value of a single hexadecimal digit.
 *
 * Accepts '0'-'9', 'a'-'f' and 'A'-'F'; returns 0..15 for those and -1
 * for every other character.
 */
int fromHex(char c)
{
	int value = -1;

	if ('0' <= c && c <= '9')
		value = c - '0';
	else if ('a' <= c && c <= 'f')
		value = 10 + (c - 'a');
	else if ('A' <= c && c <= 'F')
		value = 10 + (c - 'A');
	return value;
}
139
140 struct expr *expr_char(const char *c)
141 {
142 struct expr *res = safe_malloc(sizeof(struct expr));
143 res->type = echar;
144 //regular char
145 if (strlen(c) == 3)
146 res->data.echar = c[1];
147 //escape
148 if (strlen(c) == 4)
149 switch(c[2]) {
150 case '0': res->data.echar = '\0'; break;
151 case 'a': res->data.echar = '\a'; break;
152 case 'b': res->data.echar = '\b'; break;
153 case 't': res->data.echar = '\t'; break;
154 case 'v': res->data.echar = '\v'; break;
155 case 'f': res->data.echar = '\f'; break;
156 case 'r': res->data.echar = '\r'; break;
157 }
158 //hex escape
159 if (strlen(c) == 6)
160 res->data.echar = (fromHex(c[3])<<4)+fromHex(c[4]);
161 return res;
162 }
163
164 struct expr *expr_funcall(char *ident, struct list *args)
165 {
166 struct expr *res = safe_malloc(sizeof(struct expr));
167 res->type = efuncall;
168 res->data.efuncall.ident = ident;
169 res->data.efuncall.args = (struct expr **)
170 list_to_array(args, &res->data.efuncall.nargs, true);
171 return res;
172 }
173
174 struct expr *expr_int(int integer)
175 {
176 struct expr *res = safe_malloc(sizeof(struct expr));
177 res->type = eint;
178 res->data.eint = integer;
179 return res;
180 }
181
182 struct expr *expr_ident(char *ident, struct list *fields)
183 {
184 struct expr *res = safe_malloc(sizeof(struct expr));
185 res->type = eident;
186 res->data.eident.ident = ident;
187
188 void **els = list_to_array(fields, &res->data.eident.nfields, true);
189 res->data.eident.fields = (enum fieldspec *)safe_malloc(
190 res->data.eident.nfields*sizeof(enum fieldspec));
191 for (int i = 0; i<res->data.eident.nfields; i++) {
192 char *t = els[i];
193 if (strcmp(t, "fst") == 0)
194 res->data.eident.fields[i] = fst;
195 else if (strcmp(t, "snd") == 0)
196 res->data.eident.fields[i] = snd;
197 else if (strcmp(t, "hd") == 0)
198 res->data.eident.fields[i] = hd;
199 else if (strcmp(t, "tl") == 0)
200 res->data.eident.fields[i] = tl;
201 free(t);
202 }
203 free(els);
204 return res;
205 }
206
207 struct expr *expr_nil()
208 {
209 struct expr *res = safe_malloc(sizeof(struct expr));
210 res->type = enil;
211 return res;
212 }
213
214 struct expr *expr_tuple(struct expr *left, struct expr *right)
215 {
216 struct expr *res = safe_malloc(sizeof(struct expr));
217 res->type = etuple;
218 res->data.etuple.left = left;
219 res->data.etuple.right = right;
220 return res;
221 }
222
223 struct expr *expr_unop(enum unop op, struct expr *l)
224 {
225 struct expr *res = safe_malloc(sizeof(struct expr));
226 res->type = eunop;
227 res->data.eunop.op = op;
228 res->data.eunop.l = l;
229 return res;
230 }
231
232 struct type *type_basic(enum basictype type)
233 {
234 struct type *res = safe_malloc(sizeof(struct type));
235 res->type = tbasic;
236 res->data.tbasic = type;
237 return res;
238 }
239
240 struct type *type_list(struct type *type)
241 {
242 struct type *res = safe_malloc(sizeof(struct type));
243 res->type = tlist;
244 res->data.tlist = type;
245 return res;
246 }
247
248 struct type *type_tuple(struct type *l, struct type *r)
249 {
250 struct type *res = safe_malloc(sizeof(struct type));
251 res->type = ttuple;
252 res->data.ttuple.l = l;
253 res->data.ttuple.r = r;
254 return res;
255 }
256
257 struct type *type_var(char *ident)
258 {
259 struct type *res = safe_malloc(sizeof(struct type));
260 if (strcmp(ident, "Int") == 0) {
261 res->type = tbasic;
262 res->data.tbasic = btint;
263 free(ident);
264 } else if (strcmp(ident, "Char") == 0) {
265 res->type = tbasic;
266 res->data.tbasic = btchar;
267 free(ident);
268 } else if (strcmp(ident, "Bool") == 0) {
269 res->type = tbasic;
270 res->data.tbasic = btbool;
271 free(ident);
272 } else if (strcmp(ident, "Void") == 0) {
273 res->type = tbasic;
274 res->data.tbasic = btvoid;
275 free(ident);
276 } else {
277 res->type = tvar;
278 res->data.tvar = ident;
279 }
280 return res;
281 }
282
283
/*
 * Printable escape sequence for each control character.
 *
 * Entries 0..31 and 127 are filled in; the slots in between are NULL.
 * expr_print() only indexes this table for values below ' ' or equal
 * to 127.
 */
const char *cescapes[] = {
	/* 0x00 - 0x07 */
	"\\0", "\\x01", "\\x02", "\\x03", "\\x04", "\\x05", "\\x06", "\\a",
	/* 0x08 - 0x0F */
	"\\b", "\\t", "\\n", "\\v", "\\f", "\\r", "\\x0E", "\\x0F",
	/* 0x10 - 0x17 */
	"\\x10", "\\x11", "\\x12", "\\x13", "\\x14", "\\x15", "\\x16", "\\x17",
	/* 0x18 - 0x1F */
	"\\x18", "\\x19", "\\x1A", "\\x1B", "\\x1C", "\\x1D", "\\x1E", "\\x1F",
	/* DEL */
	[127] = "\\x7F"
};
295
296 void ast_print(struct ast *ast, FILE *out)
297 {
298 if (ast == NULL)
299 return;
300 for (int i = 0; i<ast->ndecls; i++)
301 decl_print(ast->decls[i], 0, out);
302 }
303
304 void vardecl_print(struct vardecl *decl, int indent, FILE *out)
305 {
306 pindent(indent, out);
307 if (decl->type == NULL)
308 safe_fprintf(out, "var");
309 else
310 type_print(decl->type, out);
311 safe_fprintf(out, " %s = ", decl->ident);
312 expr_print(decl->expr, out);
313 safe_fprintf(out, ";\n");
314 }
315
316 void decl_print(struct decl *decl, int indent, FILE *out)
317 {
318 if (decl == NULL)
319 return;
320 switch(decl->type) {
321 case dfundecl:
322 pindent(indent, out);
323 safe_fprintf(out, "%s (", decl->data.dfun.ident);
324 for (int i = 0; i<decl->data.dfun.nargs; i++) {
325 safe_fprintf(out, "%s", decl->data.dfun.args[i]);
326 if (i < decl->data.dfun.nargs - 1)
327 safe_fprintf(out, ", ");
328 }
329 safe_fprintf(out, ")");
330 if (decl->data.dfun.rtype != NULL) {
331 safe_fprintf(out, " :: ");
332 for (int i = 0; i<decl->data.dfun.natypes; i++) {
333 type_print(decl->data.dfun.atypes[i], out);
334 safe_fprintf(out, " ");
335 }
336 safe_fprintf(out, "-> ");
337 type_print(decl->data.dfun.rtype, out);
338 }
339 safe_fprintf(out, " {\n");
340 for (int i = 0; i<decl->data.dfun.nvar; i++)
341 vardecl_print(decl->data.dfun.vars[i], indent+1, out);
342 for (int i = 0; i<decl->data.dfun.nbody; i++)
343 stmt_print(decl->data.dfun.body[i], indent+1, out);
344 pindent(indent, out);
345 safe_fprintf(out, "}\n");
346 break;
347 case dvardecl:
348 vardecl_print(decl->data.dvar, indent, out);
349 break;
350 default:
351 die("Unsupported decl node\n");
352 }
353 }
354
355 void stmt_print(struct stmt *stmt, int indent, FILE *out)
356 {
357 if (stmt == NULL)
358 return;
359 switch(stmt->type) {
360 case sassign:
361 pindent(indent, out);
362 fprintf(out, "%s", stmt->data.sassign.ident);
363 safe_fprintf(out, " = ");
364 expr_print(stmt->data.sassign.expr, out);
365 safe_fprintf(out, ";\n");
366 break;
367 case sif:
368 pindent(indent, out);
369 safe_fprintf(out, "if (");
370 expr_print(stmt->data.sif.pred, out);
371 safe_fprintf(out, ") {\n");
372 for (int i = 0; i<stmt->data.sif.nthen; i++)
373 stmt_print(stmt->data.sif.then[i], indent+1, out);
374 pindent(indent, out);
375 safe_fprintf(out, "} else {\n");
376 for (int i = 0; i<stmt->data.sif.nels; i++)
377 stmt_print(stmt->data.sif.els[i], indent+1, out);
378 pindent(indent, out);
379 safe_fprintf(out, "}\n");
380 break;
381 case sreturn:
382 pindent(indent, out);
383 safe_fprintf(out, "return ");
384 expr_print(stmt->data.sreturn, out);
385 safe_fprintf(out, ";\n");
386 break;
387 case sexpr:
388 pindent(indent, out);
389 expr_print(stmt->data.sexpr, out);
390 safe_fprintf(out, ";\n");
391 break;
392 case swhile:
393 pindent(indent, out);
394 safe_fprintf(out, "while (");
395 expr_print(stmt->data.swhile.pred, out);
396 safe_fprintf(out, ") {\n");
397 for (int i = 0; i<stmt->data.swhile.nbody; i++) {
398 stmt_print(stmt->data.swhile.body[i], indent+1, out);
399 }
400 pindent(indent, out);
401 safe_fprintf(out, "}\n");
402 break;
403 default:
404 die("Unsupported stmt node\n");
405 }
406 }
407
408 void expr_print(struct expr *expr, FILE *out)
409 {
410 if (expr == NULL)
411 return;
412 switch(expr->type) {
413 case ebinop:
414 safe_fprintf(out, "(");
415 expr_print(expr->data.ebinop.l, out);
416 safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
417 expr_print(expr->data.ebinop.r, out);
418 safe_fprintf(out, ")");
419 break;
420 case ebool:
421 safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
422 break;
423 case echar:
424 if (expr->data.echar < 0)
425 safe_fprintf(out, "'?'");
426 if (expr->data.echar < ' ' || expr->data.echar == 127)
427 safe_fprintf(out, "'%s'",
428 cescapes[(int)expr->data.echar]);
429 else
430 safe_fprintf(out, "'%c'", expr->data.echar);
431 break;
432 case efuncall:
433 safe_fprintf(out, "%s(", expr->data.efuncall.ident);
434 for(int i = 0; i<expr->data.efuncall.nargs; i++) {
435 expr_print(expr->data.efuncall.args[i], out);
436 if (i+1 < expr->data.efuncall.nargs)
437 safe_fprintf(out, ", ");
438 }
439 safe_fprintf(out, ")");
440 break;
441 case eint:
442 safe_fprintf(out, "%d", expr->data.eint);
443 break;
444 case eident:
445 fprintf(out, "%s", expr->data.eident.ident);
446 for (int i = 0; i<expr->data.eident.nfields; i++)
447 fprintf(out, ".%s",
448 fieldspec_str[expr->data.eident.fields[i]]);
449 break;
450 case enil:
451 safe_fprintf(out, "[]");
452 break;
453 case etuple:
454 safe_fprintf(out, "(");
455 expr_print(expr->data.etuple.left, out);
456 safe_fprintf(out, ", ");
457 expr_print(expr->data.etuple.right, out);
458 safe_fprintf(out, ")");
459 break;
460 case eunop:
461 safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
462 expr_print(expr->data.eunop.l, out);
463 safe_fprintf(out, ")");
464 break;
465 default:
466 die("Unsupported expr node\n");
467 }
468 }
469
470 void type_print(struct type *type, FILE *out)
471 {
472 if (type == NULL)
473 return;
474 switch (type->type) {
475 case tbasic:
476 safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
477 break;
478 case tlist:
479 safe_fprintf(out, "[");
480 type_print(type->data.tlist, out);
481 safe_fprintf(out, "]");
482 break;
483 case ttuple:
484 safe_fprintf(out, "(");
485 type_print(type->data.ttuple.l, out);
486 safe_fprintf(out, ",");
487 type_print(type->data.ttuple.r, out);
488 safe_fprintf(out, ")");
489 break;
490 case tvar:
491 safe_fprintf(out, "%s", type->data.tvar);
492 break;
493 default:
494 die("Unsupported type node\n");
495 }
496 }
497
498 void ast_free(struct ast *ast)
499 {
500 if (ast == NULL)
501 return;
502 for (int i = 0; i<ast->ndecls; i++)
503 decl_free(ast->decls[i]);
504 free(ast->decls);
505 free(ast);
506 }
507
508 void vardecl_free(struct vardecl *decl)
509 {
510 type_free(decl->type);
511 free(decl->ident);
512 expr_free(decl->expr);
513 free(decl);
514 }
515
516 void decl_free(struct decl *decl)
517 {
518 if (decl == NULL)
519 return;
520 switch(decl->type) {
521 case dfundecl:
522 free(decl->data.dfun.ident);
523 for (int i = 0; i<decl->data.dfun.nargs; i++)
524 free(decl->data.dfun.args[i]);
525 free(decl->data.dfun.args);
526 for (int i = 0; i<decl->data.dfun.natypes; i++)
527 type_free(decl->data.dfun.atypes[i]);
528 free(decl->data.dfun.atypes);
529 type_free(decl->data.dfun.rtype);
530 for (int i = 0; i<decl->data.dfun.nvar; i++)
531 vardecl_free(decl->data.dfun.vars[i]);
532 free(decl->data.dfun.vars);
533 for (int i = 0; i<decl->data.dfun.nbody; i++)
534 stmt_free(decl->data.dfun.body[i]);
535 free(decl->data.dfun.body);
536 break;
537 case dvardecl:
538 vardecl_free(decl->data.dvar);
539 break;
540 default:
541 die("Unsupported decl node\n");
542 }
543 free(decl);
544 }
545
546 void stmt_free(struct stmt *stmt)
547 {
548 if (stmt == NULL)
549 return;
550 switch(stmt->type) {
551 case sassign:
552 free(stmt->data.sassign.ident);
553 expr_free(stmt->data.sassign.expr);
554 break;
555 case sif:
556 expr_free(stmt->data.sif.pred);
557 for (int i = 0; i<stmt->data.sif.nthen; i++)
558 stmt_free(stmt->data.sif.then[i]);
559 free(stmt->data.sif.then);
560 for (int i = 0; i<stmt->data.sif.nels; i++)
561 stmt_free(stmt->data.sif.els[i]);
562 free(stmt->data.sif.els);
563 break;
564 case sreturn:
565 expr_free(stmt->data.sreturn);
566 break;
567 case sexpr:
568 expr_free(stmt->data.sexpr);
569 break;
570 case swhile:
571 expr_free(stmt->data.swhile.pred);
572 for (int i = 0; i<stmt->data.swhile.nbody; i++)
573 stmt_free(stmt->data.swhile.body[i]);
574 free(stmt->data.swhile.body);
575 break;
576 default:
577 die("Unsupported stmt node\n");
578 }
579 free(stmt);
580 }
581
582 void expr_free(struct expr *expr)
583 {
584 if (expr == NULL)
585 return;
586 switch(expr->type) {
587 case ebinop:
588 expr_free(expr->data.ebinop.l);
589 expr_free(expr->data.ebinop.r);
590 break;
591 case ebool:
592 break;
593 case echar:
594 break;
595 case efuncall:
596 free(expr->data.efuncall.ident);
597 for (int i = 0; i<expr->data.efuncall.nargs; i++)
598 expr_free(expr->data.efuncall.args[i]);
599 free(expr->data.efuncall.args);
600 break;
601 case eint:
602 break;
603 case eident:
604 free(expr->data.eident.ident);
605 free(expr->data.eident.fields);
606 break;
607 case enil:
608 break;
609 case etuple:
610 expr_free(expr->data.etuple.left);
611 expr_free(expr->data.etuple.right);
612 break;
613 case eunop:
614 expr_free(expr->data.eunop.l);
615 break;
616 default:
617 die("Unsupported expr node\n");
618 }
619 free(expr);
620 }
621
622 void type_free(struct type *type)
623 {
624 if (type == NULL)
625 return;
626 switch (type->type) {
627 case tbasic:
628 break;
629 case tlist:
630 type_free(type->data.tlist);
631 break;
632 case ttuple:
633 type_free(type->data.ttuple.l);
634 type_free(type->data.ttuple.r);
635 break;
636 case tvar:
637 free(type->data.tvar);
638 break;
639 default:
640 die("Unsupported type node\n");
641 }
642 free(type);
643 }