rewrite to union type, much better
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7
8 static const char *binop_str[] = {
9 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
10 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
11 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
12 [modulo] = "%", [power] = "^",
13 };
14 static const char *fieldspec_str[] = {
15 [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
16 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
17
18 struct ast *ast(struct list *decls)
19 {
20 struct ast *res = safe_malloc(sizeof(struct ast));
21 res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true);
22 return res;
23 }
24
25 struct vardecl vardecl(char *ident, struct expr *expr)
26 {
27 return (struct vardecl) {.ident=ident, .expr=expr};
28 }
29
30 struct decl *decl_fun(char *ident, struct list *args, struct list *body)
31 {
32 struct decl *res = safe_malloc(sizeof(struct decl));
33 res->type = dfundecl;
34 res->data.dfun.ident = ident;
35 res->data.dfun.args = (char **)
36 list_to_array(args, &res->data.dfun.nargs, true);
37 res->data.dfun.body = (struct stmt **)
38 list_to_array(body, &res->data.dfun.nbody, true);
39 return res;
40 }
41
42 struct decl *decl_var(struct vardecl vardecl)
43 {
44 struct decl *res = safe_malloc(sizeof(struct decl));
45 res->type = dvardecl;
46 res->data.dvar = vardecl;
47 return res;
48 }
49
50 struct stmt *stmt_assign(char *ident, struct expr *expr)
51 {
52 struct stmt *res = safe_malloc(sizeof(struct stmt));
53 res->type = sassign;
54 res->data.sassign.ident = ident;
55 res->data.sassign.expr = expr;
56 return res;
57 }
58
59 struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els)
60 {
61 struct stmt *res = safe_malloc(sizeof(struct stmt));
62 res->type = sif;
63 res->data.sif.pred = pred;
64 res->data.sif.then = (struct stmt **)
65 list_to_array(then, &res->data.sif.nthen, true);
66 res->data.sif.els = (struct stmt **)
67 list_to_array(els, &res->data.sif.nels, true);
68 return res;
69 }
70
71 struct stmt *stmt_return(struct expr *rtrn)
72 {
73 struct stmt *res = safe_malloc(sizeof(struct stmt));
74 res->type = sreturn;
75 res->data.sreturn = rtrn;
76 return res;
77 }
78
79 struct stmt *stmt_expr(struct expr *expr)
80 {
81 struct stmt *res = safe_malloc(sizeof(struct stmt));
82 res->type = sexpr;
83 res->data.sexpr = expr;
84 return res;
85 }
86
87 struct stmt *stmt_vardecl(struct vardecl vardecl)
88 {
89 struct stmt *res = safe_malloc(sizeof(struct stmt));
90 res->type = svardecl;
91 res->data.svardecl = vardecl;
92 return res;
93 }
94
95 struct stmt *stmt_while(struct expr *pred, struct list *body)
96 {
97 struct stmt *res = safe_malloc(sizeof(struct stmt));
98 res->type = swhile;
99 res->data.swhile.pred = pred;
100 res->data.swhile.body = (struct stmt **)
101 list_to_array(body, &res->data.swhile.nbody, true);
102 return res;
103 }
104
105 struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r)
106 {
107 struct expr *res = safe_malloc(sizeof(struct expr));
108 res->type = ebinop;
109 res->data.ebinop.l = l;
110 res->data.ebinop.op = op;
111 res->data.ebinop.r = r;
112 return res;
113 }
114
115 struct expr *expr_bool(bool b)
116 {
117 struct expr *res = safe_malloc(sizeof(struct expr));
118 res->type = ebool;
119 res->data.ebool = b;
120 return res;
121 }
// Convert one hexadecimal digit character to its numeric value.
//
// Accepts '0'-'9', 'a'-'f' and 'A'-'F'.
// Returns 0..15 for a valid digit and -1 for any other character.
int fromHex(char c)
{
	int value = -1;

	if ('0' <= c && c <= '9')
		value = c - '0';
	else if ('a' <= c && c <= 'f')
		value = 10 + (c - 'a');
	else if ('A' <= c && c <= 'F')
		value = 10 + (c - 'A');
	return value;
}
132
133 struct expr *expr_char(const char *c)
134 {
135 struct expr *res = safe_malloc(sizeof(struct expr));
136 res->type = echar;
137 //regular char
138 if (strlen(c) == 3)
139 res->data.echar = c[1];
140 //escape
141 if (strlen(c) == 4)
142 switch(c[2]) {
143 case '0': res->data.echar = '\0'; break;
144 case 'a': res->data.echar = '\a'; break;
145 case 'b': res->data.echar = '\b'; break;
146 case 't': res->data.echar = '\t'; break;
147 case 'v': res->data.echar = '\v'; break;
148 case 'f': res->data.echar = '\f'; break;
149 case 'r': res->data.echar = '\r'; break;
150 }
151 //hex escape
152 if (strlen(c) == 6)
153 res->data.echar = (fromHex(c[3])<<4)+fromHex(c[4]);
154 return res;
155 }
156
157 struct expr *expr_funcall(char *ident, struct list *args)
158 {
159 struct expr *res = safe_malloc(sizeof(struct expr));
160 res->type = efuncall;
161 res->data.efuncall.ident = ident;
162 res->data.efuncall.args = (struct expr **)
163 list_to_array(args, &res->data.efuncall.nargs, true);
164 return res;
165 }
166
167 struct expr *expr_int(int integer)
168 {
169 struct expr *res = safe_malloc(sizeof(struct expr));
170 res->type = eint;
171 res->data.eint = integer;
172 return res;
173 }
174
175 struct expr *expr_ident(char *ident, struct list *fields)
176 {
177 struct expr *res = safe_malloc(sizeof(struct expr));
178 res->type = eident;
179 res->data.eident.ident = ident;
180
181 void **els = list_to_array(fields, &res->data.eident.nfields, true);
182 res->data.eident.fields = (enum fieldspec *)safe_malloc(
183 res->data.eident.nfields*sizeof(enum fieldspec));
184 for (int i = 0; i<res->data.eident.nfields; i++) {
185 char *t = els[i];
186 if (strcmp(t, "fst") == 0)
187 res->data.eident.fields[i] = fst;
188 else if (strcmp(t, "snd") == 0)
189 res->data.eident.fields[i] = snd;
190 else if (strcmp(t, "hd") == 0)
191 res->data.eident.fields[i] = hd;
192 else if (strcmp(t, "tl") == 0)
193 res->data.eident.fields[i] = tl;
194 free(t);
195 }
196 free(els);
197 return res;
198 }
199
200 struct expr *expr_nil()
201 {
202 struct expr *res = safe_malloc(sizeof(struct expr));
203 res->type = enil;
204 return res;
205 }
206
207 struct expr *expr_tuple(struct expr *left, struct expr *right)
208 {
209 struct expr *res = safe_malloc(sizeof(struct expr));
210 res->type = etuple;
211 res->data.etuple.left = left;
212 res->data.etuple.right = right;
213 return res;
214 }
215
216 struct expr *expr_unop(enum unop op, struct expr *l)
217 {
218 struct expr *res = safe_malloc(sizeof(struct expr));
219 res->type = eunop;
220 res->data.eunop.op = op;
221 res->data.eunop.l = l;
222 return res;
223 }
224
// Escape spelling for each non-printable character code: 0-31 and 127.
// Slots 32-126 are left NULL; expr_print() only consults the table for
// characters below ' ' or equal to 127.
const char *cescapes[] = {
	[0x00] = "\\0",
	[0x01] = "\\x01",
	[0x02] = "\\x02",
	[0x03] = "\\x03",
	[0x04] = "\\x04",
	[0x05] = "\\x05",
	[0x06] = "\\x06",
	[0x07] = "\\a",
	[0x08] = "\\b",
	[0x09] = "\\t",
	[0x0A] = "\\n",
	[0x0B] = "\\v",
	[0x0C] = "\\f",
	[0x0D] = "\\r",
	[0x0E] = "\\x0E",
	[0x0F] = "\\x0F",
	[0x10] = "\\x10",
	[0x11] = "\\x11",
	[0x12] = "\\x12",
	[0x13] = "\\x13",
	[0x14] = "\\x14",
	[0x15] = "\\x15",
	[0x16] = "\\x16",
	[0x17] = "\\x17",
	[0x18] = "\\x18",
	[0x19] = "\\x19",
	[0x1A] = "\\x1A",
	[0x1B] = "\\x1B",
	[0x1C] = "\\x1C",
	[0x1D] = "\\x1D",
	[0x1E] = "\\x1E",
	[0x1F] = "\\x1F",
	[0x7F] = "\\x7F"
};
236
237 void ast_print(struct ast *ast, FILE *out)
238 {
239 if (ast == NULL)
240 return;
241 for (int i = 0; i<ast->ndecls; i++)
242 decl_print(ast->decls[i], 0, out);
243 }
244
245 void decl_print(struct decl *decl, int indent, FILE *out)
246 {
247 if (decl == NULL)
248 return;
249 switch(decl->type) {
250 case dfundecl:
251 pindent(indent, out);
252 safe_fprintf(out, "%s (", decl->data.dfun.ident);
253 for (int i = 0; i<decl->data.dfun.nargs; i++) {
254 safe_fprintf(out, "%s", decl->data.dfun.args[i]);
255 if (i < decl->data.dfun.nargs - 1)
256 safe_fprintf(out, ", ");
257 }
258 safe_fprintf(out, ") {\n");
259 for (int i = 0; i<decl->data.dfun.nbody; i++)
260 stmt_print(decl->data.dfun.body[i], indent+1, out);
261 pindent(indent, out);
262 safe_fprintf(out, "}\n");
263 break;
264 case dvardecl:
265 pindent(indent, out);
266 safe_fprintf(out, "var %s = ", decl->data.dvar.ident);
267 expr_print(decl->data.dvar.expr, out);
268 safe_fprintf(out, ";\n");
269 break;
270 default:
271 die("Unsupported decl node\n");
272 }
273 }
274
275 void stmt_print(struct stmt *stmt, int indent, FILE *out)
276 {
277 if (stmt == NULL)
278 return;
279 switch(stmt->type) {
280 case sassign:
281 pindent(indent, out);
282 fprintf(out, "%s", stmt->data.sassign.ident);
283 safe_fprintf(out, " = ");
284 expr_print(stmt->data.sassign.expr, out);
285 safe_fprintf(out, ";\n");
286 break;
287 case sif:
288 pindent(indent, out);
289 safe_fprintf(out, "if (");
290 expr_print(stmt->data.sif.pred, out);
291 safe_fprintf(out, ") {\n");
292 for (int i = 0; i<stmt->data.sif.nthen; i++)
293 stmt_print(stmt->data.sif.then[i], indent+1, out);
294 pindent(indent, out);
295 safe_fprintf(out, "} else {\n");
296 for (int i = 0; i<stmt->data.sif.nels; i++)
297 stmt_print(stmt->data.sif.els[i], indent+1, out);
298 pindent(indent, out);
299 safe_fprintf(out, "}\n");
300 break;
301 case sreturn:
302 pindent(indent, out);
303 safe_fprintf(out, "return ");
304 expr_print(stmt->data.sreturn, out);
305 safe_fprintf(out, ";\n");
306 break;
307 case sexpr:
308 pindent(indent, out);
309 expr_print(stmt->data.sexpr, out);
310 safe_fprintf(out, ";\n");
311 break;
312 case svardecl:
313 pindent(indent, out);
314 safe_fprintf(out, "var %s = ", stmt->data.svardecl.ident);
315 expr_print(stmt->data.svardecl.expr, out);
316 safe_fprintf(out, ";\n");
317 break;
318 case swhile:
319 pindent(indent, out);
320 safe_fprintf(out, "while (");
321 expr_print(stmt->data.swhile.pred, out);
322 safe_fprintf(out, ") {\n");
323 for (int i = 0; i<stmt->data.swhile.nbody; i++) {
324 stmt_print(stmt->data.swhile.body[i], indent+1, out);
325 }
326 pindent(indent, out);
327 safe_fprintf(out, "}\n");
328 break;
329 default:
330 die("Unsupported stmt node\n");
331 }
332 }
333
334 void expr_print(struct expr *expr, FILE *out)
335 {
336 if (expr == NULL)
337 return;
338 switch(expr->type) {
339 case ebinop:
340 safe_fprintf(out, "(");
341 expr_print(expr->data.ebinop.l, out);
342 safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
343 expr_print(expr->data.ebinop.r, out);
344 safe_fprintf(out, ")");
345 break;
346 case ebool:
347 safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
348 break;
349 case echar:
350 if (expr->data.echar < 0)
351 safe_fprintf(out, "'?'");
352 if (expr->data.echar < ' ' || expr->data.echar == 127)
353 safe_fprintf(out, "'%s'",
354 cescapes[(int)expr->data.echar]);
355 else
356 safe_fprintf(out, "'%c'", expr->data.echar);
357 break;
358 case efuncall:
359 safe_fprintf(out, "%s(", expr->data.efuncall.ident);
360 for(int i = 0; i<expr->data.efuncall.nargs; i++) {
361 expr_print(expr->data.efuncall.args[i], out);
362 if (i+1 < expr->data.efuncall.nargs)
363 safe_fprintf(out, ", ");
364 }
365 safe_fprintf(out, ")");
366 break;
367 case eint:
368 safe_fprintf(out, "%d", expr->data.eint);
369 break;
370 case eident:
371 fprintf(out, "%s", expr->data.eident.ident);
372 for (int i = 0; i<expr->data.eident.nfields; i++)
373 fprintf(out, ".%s",
374 fieldspec_str[expr->data.eident.fields[i]]);
375 break;
376 case enil:
377 safe_fprintf(out, "[]");
378 break;
379 case etuple:
380 safe_fprintf(out, "(");
381 expr_print(expr->data.etuple.left, out);
382 safe_fprintf(out, ", ");
383 expr_print(expr->data.etuple.right, out);
384 safe_fprintf(out, ")");
385 break;
386 case eunop:
387 safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
388 expr_print(expr->data.eunop.l, out);
389 safe_fprintf(out, ")");
390 break;
391 default:
392 die("Unsupported expr node\n");
393 }
394 }
395
396 void ast_free(struct ast *ast)
397 {
398 if (ast == NULL)
399 return;
400 for (int i = 0; i<ast->ndecls; i++)
401 decl_free(ast->decls[i]);
402 free(ast);
403 }
404
405 void decl_free(struct decl *decl)
406 {
407 if (decl == NULL)
408 return;
409 switch(decl->type) {
410 case dfundecl:
411 free(decl->data.dfun.ident);
412 for (int i = 0; i<decl->data.dfun.nargs; i++)
413 free(decl->data.dfun.args[i]);
414 free(decl->data.dfun.args);
415 for (int i = 0; i<decl->data.dfun.nbody; i++)
416 stmt_free(decl->data.dfun.body[i]);
417 free(decl->data.dfun.body);
418 break;
419 case dvardecl:
420 free(decl->data.dvar.ident);
421 expr_free(decl->data.dvar.expr);
422 break;
423 default:
424 die("Unsupported decl node\n");
425 }
426 free(decl);
427 }
428
429 void stmt_free(struct stmt *stmt)
430 {
431 if (stmt == NULL)
432 return;
433 switch(stmt->type) {
434 case sassign:
435 free(stmt->data.sassign.ident);
436 expr_free(stmt->data.sassign.expr);
437 break;
438 case sif:
439 expr_free(stmt->data.sif.pred);
440 for (int i = 0; i<stmt->data.sif.nthen; i++)
441 stmt_free(stmt->data.sif.then[i]);
442 free(stmt->data.sif.then);
443 for (int i = 0; i<stmt->data.sif.nels; i++)
444 stmt_free(stmt->data.sif.els[i]);
445 free(stmt->data.sif.els);
446 break;
447 case sreturn:
448 expr_free(stmt->data.sreturn);
449 break;
450 case sexpr:
451 expr_free(stmt->data.sexpr);
452 break;
453 case svardecl:
454 free(stmt->data.svardecl.ident);
455 expr_free(stmt->data.svardecl.expr);
456 break;
457 case swhile:
458 expr_free(stmt->data.swhile.pred);
459 for (int i = 0; i<stmt->data.swhile.nbody; i++)
460 stmt_free(stmt->data.swhile.body[i]);
461 free(stmt->data.swhile.body);
462 break;
463 default:
464 die("Unsupported stmt node\n");
465 }
466 free(stmt);
467 }
468
469 void expr_free(struct expr *expr)
470 {
471 switch(expr->type) {
472 case ebinop:
473 expr_free(expr->data.ebinop.l);
474 expr_free(expr->data.ebinop.r);
475 break;
476 case ebool:
477 break;
478 case echar:
479 break;
480 case efuncall:
481 free(expr->data.efuncall.ident);
482 for (int i = 0; i<expr->data.efuncall.nargs; i++)
483 expr_free(expr->data.efuncall.args[i]);
484 free(expr->data.efuncall.args);
485 break;
486 case eint:
487 break;
488 case eident:
489 free(expr->data.eident.ident);
490 free(expr->data.eident.fields);
491 break;
492 case enil:
493 break;
494 case etuple:
495 expr_free(expr->data.etuple.left);
496 expr_free(expr->data.etuple.right);
497 break;
498 case eunop:
499 expr_free(expr->data.eunop.l);
500 break;
501 default:
502 die("Unsupported expr node\n");
503 }
504 free(expr);
505 }