make parser more robust, add string literals and escapes
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7 #include "y.tab.h"
8
9 static const char *binop_str[] = {
10 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
11 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
12 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
13 [modulo] = "%", [power] = "^",
14 };
15 static const char *fieldspec_str[] = {
16 [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
17 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
18 static const char *basictype_str[] = {
19 [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
20 [btvoid] = "Void",
21 };
22
23 struct ast *ast(struct list *decls)
24 {
25 struct ast *res = safe_malloc(sizeof(struct ast));
26 res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true);
27 return res;
28 }
29
30 struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr)
31 {
32 struct vardecl *res = safe_malloc(sizeof(struct vardecl));
33 res->type = type;
34 res->ident = ident;
35 res->expr = expr;
36 return res;
37 }
38
39 struct decl *decl_fun(char *ident, struct list *args, struct list *atypes,
40 struct type *rtype, struct list *vars, struct list *body)
41 {
42 struct decl *res = safe_malloc(sizeof(struct decl));
43 res->type = dfundecl;
44 res->data.dfun.ident = ident;
45 res->data.dfun.args = (char **)
46 list_to_array(args, &res->data.dfun.nargs, true);
47 res->data.dfun.atypes = (struct type **)
48 list_to_array(atypes, &res->data.dfun.natypes, true);
49 res->data.dfun.rtype = rtype;
50 res->data.dfun.vars = (struct vardecl **)
51 list_to_array(vars, &res->data.dfun.nvar, true);
52 res->data.dfun.body = (struct stmt **)
53 list_to_array(body, &res->data.dfun.nbody, true);
54 return res;
55 }
56
57 struct decl *decl_var(struct vardecl *vardecl)
58 {
59 struct decl *res = safe_malloc(sizeof(struct decl));
60 res->type = dvardecl;
61 res->data.dvar = vardecl;
62 return res;
63 }
64
65 struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr)
66 {
67 struct stmt *res = safe_malloc(sizeof(struct stmt));
68 res->type = sassign;
69 res->data.sassign.ident = ident;
70 res->data.sassign.fields = (char **)
71 list_to_array(fields, &res->data.sassign.nfield, true);
72 res->data.sassign.expr = expr;
73 return res;
74 }
75
76 struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els)
77 {
78 struct stmt *res = safe_malloc(sizeof(struct stmt));
79 res->type = sif;
80 res->data.sif.pred = pred;
81 res->data.sif.then = (struct stmt **)
82 list_to_array(then, &res->data.sif.nthen, true);
83 res->data.sif.els = (struct stmt **)
84 list_to_array(els, &res->data.sif.nels, true);
85 return res;
86 }
87
88 struct stmt *stmt_return(struct expr *rtrn)
89 {
90 struct stmt *res = safe_malloc(sizeof(struct stmt));
91 res->type = sreturn;
92 res->data.sreturn = rtrn;
93 return res;
94 }
95
96 struct stmt *stmt_expr(struct expr *expr)
97 {
98 struct stmt *res = safe_malloc(sizeof(struct stmt));
99 res->type = sexpr;
100 res->data.sexpr = expr;
101 return res;
102 }
103
104 struct stmt *stmt_vardecl(struct vardecl *vardecl)
105 {
106 struct stmt *res = safe_malloc(sizeof(struct stmt));
107 res->type = svardecl;
108 res->data.svardecl = vardecl;
109 return res;
110 }
111
112 struct stmt *stmt_while(struct expr *pred, struct list *body)
113 {
114 struct stmt *res = safe_malloc(sizeof(struct stmt));
115 res->type = swhile;
116 res->data.swhile.pred = pred;
117 res->data.swhile.body = (struct stmt **)
118 list_to_array(body, &res->data.swhile.nbody, true);
119 return res;
120 }
121
122 struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r)
123 {
124 struct expr *res = safe_malloc(sizeof(struct expr));
125 res->type = ebinop;
126 res->data.ebinop.l = l;
127 res->data.ebinop.op = op;
128 res->data.ebinop.r = r;
129 return res;
130 }
131
132 struct expr *expr_bool(bool b)
133 {
134 struct expr *res = safe_malloc(sizeof(struct expr));
135 res->type = ebool;
136 res->data.ebool = b;
137 return res;
138 }
139
140 struct expr *expr_char(char *c)
141 {
142 struct expr *res = safe_malloc(sizeof(struct expr));
143 res->type = echar;
144 res->data.echar = unescape_char(c)[0];
145 return res;
146 }
147
148 static void set_fields(enum fieldspec **farray, int *n, struct list *fields)
149 {
150 void **els = list_to_array(fields, n, true);
151 *farray = (enum fieldspec *)safe_malloc(*n*sizeof(enum fieldspec));
152 for (int i = 0; i<*n; i++) {
153 char *t = els[i];
154 if (strcmp(t, "fst") == 0)
155 (*farray)[i] = fst;
156 else if (strcmp(t, "snd") == 0)
157 (*farray)[i] = snd;
158 else if (strcmp(t, "hd") == 0)
159 (*farray)[i] = hd;
160 else if (strcmp(t, "tl") == 0)
161 (*farray)[i] = tl;
162 free(t);
163 }
164 free(els);
165 }
166
167
168 struct expr *expr_funcall(char *ident, struct list *args, struct list *fields)
169 {
170 struct expr *res = safe_malloc(sizeof(struct expr));
171 res->type = efuncall;
172 res->data.efuncall.ident = ident;
173 res->data.efuncall.args = (struct expr **)
174 list_to_array(args, &res->data.efuncall.nargs, true);
175 set_fields(&res->data.efuncall.fields,
176 &res->data.efuncall.nfields, fields);
177 return res;
178 }
179
180 struct expr *expr_int(int integer)
181 {
182 struct expr *res = safe_malloc(sizeof(struct expr));
183 res->type = eint;
184 res->data.eint = integer;
185 return res;
186 }
187
188 struct expr *expr_ident(char *ident, struct list *fields)
189 {
190 struct expr *res = safe_malloc(sizeof(struct expr));
191 res->type = eident;
192 res->data.eident.ident = ident;
193 set_fields(&res->data.eident.fields, &res->data.eident.nfields, fields);
194 return res;
195 }
196
197 struct expr *expr_nil()
198 {
199 struct expr *res = safe_malloc(sizeof(struct expr));
200 res->type = enil;
201 return res;
202 }
203
204 struct expr *expr_tuple(struct expr *left, struct expr *right)
205 {
206 struct expr *res = safe_malloc(sizeof(struct expr));
207 res->type = etuple;
208 res->data.etuple.left = left;
209 res->data.etuple.right = right;
210 return res;
211 }
212
213 struct expr *expr_string(char *str)
214 {
215 struct expr *res = safe_malloc(sizeof(struct expr));
216 res->type = estring;
217 res->data.estring.nchar = 0;
218 res->data.estring.chars = safe_malloc(strlen(str)+1);
219 char *p = res->data.estring.chars;
220 while(*str != '\0') {
221 str = unescape_char(str);
222 *p++ = *str++;
223 res->data.estring.nchar++;
224 }
225 *p = '\0';
226 return res;
227 }
228
229 struct expr *expr_unop(enum unop op, struct expr *l)
230 {
231 struct expr *res = safe_malloc(sizeof(struct expr));
232 res->type = eunop;
233 res->data.eunop.op = op;
234 res->data.eunop.l = l;
235 return res;
236 }
237
238 struct type *type_basic(enum basictype type)
239 {
240 struct type *res = safe_malloc(sizeof(struct type));
241 res->type = tbasic;
242 res->data.tbasic = type;
243 return res;
244 }
245
246 struct type *type_list(struct type *type)
247 {
248 struct type *res = safe_malloc(sizeof(struct type));
249 res->type = tlist;
250 res->data.tlist = type;
251 return res;
252 }
253
254 struct type *type_tuple(struct type *l, struct type *r)
255 {
256 struct type *res = safe_malloc(sizeof(struct type));
257 res->type = ttuple;
258 res->data.ttuple.l = l;
259 res->data.ttuple.r = r;
260 return res;
261 }
262
263 struct type *type_var(char *ident)
264 {
265 struct type *res = safe_malloc(sizeof(struct type));
266 if (strcmp(ident, "Int") == 0) {
267 res->type = tbasic;
268 res->data.tbasic = btint;
269 free(ident);
270 } else if (strcmp(ident, "Char") == 0) {
271 res->type = tbasic;
272 res->data.tbasic = btchar;
273 free(ident);
274 } else if (strcmp(ident, "Bool") == 0) {
275 res->type = tbasic;
276 res->data.tbasic = btbool;
277 free(ident);
278 } else if (strcmp(ident, "Void") == 0) {
279 res->type = tbasic;
280 res->data.tbasic = btvoid;
281 free(ident);
282 } else {
283 res->type = tvar;
284 res->data.tvar = ident;
285 }
286 return res;
287 }
288
289 void ast_print(struct ast *ast, FILE *out)
290 {
291 if (ast == NULL)
292 return;
293 for (int i = 0; i<ast->ndecls; i++)
294 decl_print(ast->decls[i], 0, out);
295 }
296
297 void vardecl_print(struct vardecl *decl, int indent, FILE *out)
298 {
299 pindent(indent, out);
300 if (decl->type == NULL)
301 safe_fprintf(out, "var");
302 else
303 type_print(decl->type, out);
304 safe_fprintf(out, " %s = ", decl->ident);
305 expr_print(decl->expr, out);
306 safe_fprintf(out, ";\n");
307 }
308
309 void decl_print(struct decl *decl, int indent, FILE *out)
310 {
311 if (decl == NULL)
312 return;
313 switch(decl->type) {
314 case dfundecl:
315 pindent(indent, out);
316 safe_fprintf(out, "%s (", decl->data.dfun.ident);
317 for (int i = 0; i<decl->data.dfun.nargs; i++) {
318 safe_fprintf(out, "%s", decl->data.dfun.args[i]);
319 if (i < decl->data.dfun.nargs - 1)
320 safe_fprintf(out, ", ");
321 }
322 safe_fprintf(out, ")");
323 if (decl->data.dfun.rtype != NULL) {
324 safe_fprintf(out, " :: ");
325 for (int i = 0; i<decl->data.dfun.natypes; i++) {
326 type_print(decl->data.dfun.atypes[i], out);
327 safe_fprintf(out, " ");
328 }
329 safe_fprintf(out, "-> ");
330 type_print(decl->data.dfun.rtype, out);
331 }
332 safe_fprintf(out, " {\n");
333 for (int i = 0; i<decl->data.dfun.nvar; i++)
334 vardecl_print(decl->data.dfun.vars[i], indent+1, out);
335 for (int i = 0; i<decl->data.dfun.nbody; i++)
336 stmt_print(decl->data.dfun.body[i], indent+1, out);
337 pindent(indent, out);
338 safe_fprintf(out, "}\n");
339 break;
340 case dvardecl:
341 vardecl_print(decl->data.dvar, indent, out);
342 break;
343 default:
344 die("Unsupported decl node\n");
345 }
346 }
347
348 void stmt_print(struct stmt *stmt, int indent, FILE *out)
349 {
350 if (stmt == NULL)
351 return;
352 switch(stmt->type) {
353 case sassign:
354 pindent(indent, out);
355 fprintf(out, "%s", stmt->data.sassign.ident);
356 for (int i = 0; i<stmt->data.sassign.nfield; i++)
357 fprintf(out, ".%s", stmt->data.sassign.fields[i]);
358 safe_fprintf(out, " = ");
359 expr_print(stmt->data.sassign.expr, out);
360 safe_fprintf(out, ";\n");
361 break;
362 case sif:
363 pindent(indent, out);
364 safe_fprintf(out, "if (");
365 expr_print(stmt->data.sif.pred, out);
366 safe_fprintf(out, ") {\n");
367 for (int i = 0; i<stmt->data.sif.nthen; i++)
368 stmt_print(stmt->data.sif.then[i], indent+1, out);
369 pindent(indent, out);
370 safe_fprintf(out, "} else {\n");
371 for (int i = 0; i<stmt->data.sif.nels; i++)
372 stmt_print(stmt->data.sif.els[i], indent+1, out);
373 pindent(indent, out);
374 safe_fprintf(out, "}\n");
375 break;
376 case sreturn:
377 pindent(indent, out);
378 safe_fprintf(out, "return ");
379 expr_print(stmt->data.sreturn, out);
380 safe_fprintf(out, ";\n");
381 break;
382 case sexpr:
383 pindent(indent, out);
384 expr_print(stmt->data.sexpr, out);
385 safe_fprintf(out, ";\n");
386 break;
387 case svardecl:
388 vardecl_print(stmt->data.svardecl, indent, out);
389 break;
390 case swhile:
391 pindent(indent, out);
392 safe_fprintf(out, "while (");
393 expr_print(stmt->data.swhile.pred, out);
394 safe_fprintf(out, ") {\n");
395 for (int i = 0; i<stmt->data.swhile.nbody; i++) {
396 stmt_print(stmt->data.swhile.body[i], indent+1, out);
397 }
398 pindent(indent, out);
399 safe_fprintf(out, "}\n");
400 break;
401 default:
402 die("Unsupported stmt node\n");
403 }
404 }
405
406 void expr_print(struct expr *expr, FILE *out)
407 {
408 if (expr == NULL)
409 return;
410 char buf[] = "\\xff";
411 switch(expr->type) {
412 case ebinop:
413 safe_fprintf(out, "(");
414 expr_print(expr->data.ebinop.l, out);
415 safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
416 expr_print(expr->data.ebinop.r, out);
417 safe_fprintf(out, ")");
418 break;
419 case ebool:
420 safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
421 break;
422 case echar:
423 safe_fprintf(out, "'%s'",
424 escape_char(expr->data.echar, buf, false));
425 break;
426 case efuncall:
427 safe_fprintf(out, "%s(", expr->data.efuncall.ident);
428 for(int i = 0; i<expr->data.efuncall.nargs; i++) {
429 expr_print(expr->data.efuncall.args[i], out);
430 if (i+1 < expr->data.efuncall.nargs)
431 safe_fprintf(out, ", ");
432 }
433 safe_fprintf(out, ")");
434 for (int i = 0; i<expr->data.efuncall.nfields; i++)
435 fprintf(out, ".%s",
436 fieldspec_str[expr->data.efuncall.fields[i]]);
437 break;
438 case eint:
439 safe_fprintf(out, "%d", expr->data.eint);
440 break;
441 case eident:
442 fprintf(out, "%s", expr->data.eident.ident);
443 for (int i = 0; i<expr->data.eident.nfields; i++)
444 fprintf(out, ".%s",
445 fieldspec_str[expr->data.eident.fields[i]]);
446 break;
447 case enil:
448 safe_fprintf(out, "[]");
449 break;
450 case etuple:
451 safe_fprintf(out, "(");
452 expr_print(expr->data.etuple.left, out);
453 safe_fprintf(out, ", ");
454 expr_print(expr->data.etuple.right, out);
455 safe_fprintf(out, ")");
456 break;
457 case estring:
458 safe_fprintf(out, "\"");
459 for (int i = 0; i<expr->data.estring.nchar; i++)
460 safe_fprintf(out, "%s", escape_char(
461 expr->data.estring.chars[i], buf, true));
462 safe_fprintf(out, "\"");
463 break;
464 case eunop:
465 safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
466 expr_print(expr->data.eunop.l, out);
467 safe_fprintf(out, ")");
468 break;
469 default:
470 die("Unsupported expr node\n");
471 }
472 }
473
474 void type_print(struct type *type, FILE *out)
475 {
476 if (type == NULL)
477 return;
478 switch (type->type) {
479 case tbasic:
480 safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
481 break;
482 case tlist:
483 safe_fprintf(out, "[");
484 type_print(type->data.tlist, out);
485 safe_fprintf(out, "]");
486 break;
487 case ttuple:
488 safe_fprintf(out, "(");
489 type_print(type->data.ttuple.l, out);
490 safe_fprintf(out, ",");
491 type_print(type->data.ttuple.r, out);
492 safe_fprintf(out, ")");
493 break;
494 case tvar:
495 safe_fprintf(out, "%s", type->data.tvar);
496 break;
497 default:
498 die("Unsupported type node\n");
499 }
500 }
501
502 void ast_free(struct ast *ast)
503 {
504 if (ast == NULL)
505 return;
506 for (int i = 0; i<ast->ndecls; i++)
507 decl_free(ast->decls[i]);
508 free(ast->decls);
509 free(ast);
510 }
511
512 void vardecl_free(struct vardecl *decl)
513 {
514 type_free(decl->type);
515 free(decl->ident);
516 expr_free(decl->expr);
517 free(decl);
518 }
519
520 void decl_free(struct decl *decl)
521 {
522 if (decl == NULL)
523 return;
524 switch(decl->type) {
525 case dfundecl:
526 free(decl->data.dfun.ident);
527 for (int i = 0; i<decl->data.dfun.nargs; i++)
528 free(decl->data.dfun.args[i]);
529 free(decl->data.dfun.args);
530 for (int i = 0; i<decl->data.dfun.natypes; i++)
531 type_free(decl->data.dfun.atypes[i]);
532 free(decl->data.dfun.atypes);
533 type_free(decl->data.dfun.rtype);
534 for (int i = 0; i<decl->data.dfun.nvar; i++)
535 vardecl_free(decl->data.dfun.vars[i]);
536 free(decl->data.dfun.vars);
537 for (int i = 0; i<decl->data.dfun.nbody; i++)
538 stmt_free(decl->data.dfun.body[i]);
539 free(decl->data.dfun.body);
540 break;
541 case dvardecl:
542 vardecl_free(decl->data.dvar);
543 break;
544 default:
545 die("Unsupported decl node\n");
546 }
547 free(decl);
548 }
549
550 void stmt_free(struct stmt *stmt)
551 {
552 if (stmt == NULL)
553 return;
554 switch(stmt->type) {
555 case sassign:
556 free(stmt->data.sassign.ident);
557 for (int i = 0; i<stmt->data.sassign.nfield; i++)
558 free(stmt->data.sassign.fields[i]);
559 free(stmt->data.sassign.fields);
560 expr_free(stmt->data.sassign.expr);
561 break;
562 case sif:
563 expr_free(stmt->data.sif.pred);
564 for (int i = 0; i<stmt->data.sif.nthen; i++)
565 stmt_free(stmt->data.sif.then[i]);
566 free(stmt->data.sif.then);
567 for (int i = 0; i<stmt->data.sif.nels; i++)
568 stmt_free(stmt->data.sif.els[i]);
569 free(stmt->data.sif.els);
570 break;
571 case sreturn:
572 expr_free(stmt->data.sreturn);
573 break;
574 case sexpr:
575 expr_free(stmt->data.sexpr);
576 break;
577 case swhile:
578 expr_free(stmt->data.swhile.pred);
579 for (int i = 0; i<stmt->data.swhile.nbody; i++)
580 stmt_free(stmt->data.swhile.body[i]);
581 free(stmt->data.swhile.body);
582 break;
583 case svardecl:
584 vardecl_free(stmt->data.svardecl);
585 break;
586 default:
587 die("Unsupported stmt node\n");
588 }
589 free(stmt);
590 }
591
592 void expr_free(struct expr *expr)
593 {
594 if (expr == NULL)
595 return;
596 switch(expr->type) {
597 case ebinop:
598 expr_free(expr->data.ebinop.l);
599 expr_free(expr->data.ebinop.r);
600 break;
601 case ebool:
602 break;
603 case echar:
604 break;
605 case efuncall:
606 free(expr->data.efuncall.ident);
607 for (int i = 0; i<expr->data.efuncall.nargs; i++)
608 expr_free(expr->data.efuncall.args[i]);
609 free(expr->data.efuncall.fields);
610 free(expr->data.efuncall.args);
611 break;
612 case eint:
613 break;
614 case eident:
615 free(expr->data.eident.ident);
616 free(expr->data.eident.fields);
617 break;
618 case enil:
619 break;
620 case etuple:
621 expr_free(expr->data.etuple.left);
622 expr_free(expr->data.etuple.right);
623 break;
624 case estring:
625 free(expr->data.estring.chars);
626 break;
627 case eunop:
628 expr_free(expr->data.eunop.l);
629 break;
630 default:
631 die("Unsupported expr node\n");
632 }
633 free(expr);
634 }
635
636 void type_free(struct type *type)
637 {
638 if (type == NULL)
639 return;
640 switch (type->type) {
641 case tbasic:
642 break;
643 case tlist:
644 type_free(type->data.tlist);
645 break;
646 case ttuple:
647 type_free(type->data.ttuple.l);
648 type_free(type->data.ttuple.r);
649 break;
650 case tvar:
651 free(type->data.tvar);
652 break;
653 default:
654 die("Unsupported type node\n");
655 }
656 free(type);
657 }