add assign, nil
[ccc.git] / ast.c
1 #include <stdlib.h>
2 #include <stdio.h>
3 #include <string.h>
4
5 #include "util.h"
6 #include "ast.h"
7
#ifdef DEBUG
/*
 * Printable name of every AST node type, indexed by the node's type tag.
 * Only built in DEBUG builds, where must_be() and the trace lines in
 * ast_print()/ast_free() look names up in it.
 *
 * Every tag that can reach those lookups needs an entry; a missing one
 * would hand a NULL (or out-of-range) pointer to "%s".
 */
static const char *ast_type_str[] = {
	[an_assign] = "assign",
	[an_bool] = "bool",
	[an_binop] = "binop",
	[an_char] = "char",
	[an_cons] = "cons",
	[an_funcall] = "funcall",
	[an_fundecl] = "fundecl",
	[an_ident] = "ident",
	[an_if] = "if",
	[an_int] = "int",
	[an_list] = "list",
	[an_nil] = "nil",
	[an_return] = "return",
	[an_stmt_expr] = "stmt_expr",
	[an_unop] = "unop",
	[an_vardecl] = "vardecl",
	[an_while] = "while",
};
#endif
18 static const char *binop_str[] = {
19 [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
20 [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
21 [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
22 [modulo] = "%", [power] = "^",
23 };
24 static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
25
/*
 * must_be(node, ntype, msg): DEBUG-only sanity check on a node's type.
 *
 * In DEBUG builds, prints "<msg> can't be <actual type>" to stderr and
 * exits with status 1 when node->type differs from ntype. In other builds
 * it expands to an empty statement.
 *
 * Every macro argument is parenthesised so that expression arguments
 * (e.g. a dereference or a conditional) expand correctly.
 *
 * The expansion is deliberately a plain brace block / lone semicolon: a
 * call site in this file invokes the macro without a trailing semicolon,
 * and both forms must keep compiling with or without one.
 */
#ifdef DEBUG
#define must_be(node, ntype, msg) {\
	if ((node)->type != (ntype)) {\
		fprintf(stderr, "%s can't be %s\n",\
			(msg), ast_type_str[(node)->type]);\
		exit(1);\
	}\
}
#else
#define must_be(node, ntype, msg) ;
#endif
37
/* Obtain storage for one AST node from safe_malloc(); contents are not set. */
#define ast_alloc() \
	((struct ast *)safe_malloc(sizeof (struct ast)))
39
40 struct ast *ast_assign(struct ast *ident, struct ast *expr)
41 {
42 struct ast *res = ast_alloc();
43 res->type = an_assign;
44
45 must_be(ident, an_ident, "ident of an assign");
46 res->data.an_assign.ident = ident->data.an_ident;
47 free(ident);
48
49 res->data.an_assign.expr = expr;
50 return res;
51 }
52
53 struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
54 {
55 struct ast *res = ast_alloc();
56 res->type = an_binop;
57 res->data.an_binop.l = l;
58 res->data.an_binop.op = op;
59 res->data.an_binop.r = r;
60 return res;
61 }
62
63 struct ast *ast_bool(bool b)
64 {
65 struct ast *res = ast_alloc();
66 res->type = an_bool;
67 res->data.an_bool = b;
68 return res;
69 }
70
/*
 * Convert one hexadecimal digit character to its value.
 *
 * Returns 0..15 for '0'-'9', 'a'-'f' and 'A'-'F', and -1 for anything else.
 */
int fromHex(char c)
{
	int value = -1;

	if ('0' <= c && c <= '9')
		value = c - '0';
	else if ('a' <= c && c <= 'f')
		value = 10 + (c - 'a');
	else if ('A' <= c && c <= 'F')
		value = 10 + (c - 'A');

	return value;
}
81
82 struct ast *ast_char(const char *c)
83 {
84 struct ast *res = ast_alloc();
85 res->type = an_char;
86 //regular char
87 if (strlen(c) == 3)
88 res->data.an_char = c[1];
89 //escape
90 if (strlen(c) == 4)
91 switch(c[2]) {
92 case '0': res->data.an_char = '\0'; break;
93 case 'a': res->data.an_char = '\a'; break;
94 case 'b': res->data.an_char = '\b'; break;
95 case 't': res->data.an_char = '\t'; break;
96 case 'v': res->data.an_char = '\v'; break;
97 case 'f': res->data.an_char = '\f'; break;
98 case 'r': res->data.an_char = '\r'; break;
99 }
100 //hex escape
101 if (strlen(c) == 6)
102 res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
103 return res;
104 }
105
106 struct ast *ast_cons(struct ast *el, struct ast *tail)
107 {
108 struct ast *res = ast_alloc();
109 res->type = an_cons;
110 res->data.an_cons.el = el;
111 res->data.an_cons.tail = tail;
112 return res;
113 }
114
115 struct ast *ast_funcall(struct ast *ident, struct ast *args)
116 {
117 struct ast *res = ast_alloc();
118 res->type = an_funcall;
119
120 //ident
121 must_be(ident, an_ident, "ident of a funcall");
122 res->data.an_funcall.ident = ident->data.an_ident;
123 free(ident);
124
125 //args
126 must_be(args, an_list, "args of a funcall");
127 res->data.an_funcall.nargs = args->data.an_list.n;
128 res->data.an_funcall.args = args->data.an_list.ptr;
129 free(args);
130 return res;
131 }
132
133 struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
134 {
135 struct ast *res = ast_alloc();
136 res->type = an_fundecl;
137
138 //ident
139 must_be(ident, an_ident, "ident of a fundecl");
140 res->data.an_fundecl.ident = ident->data.an_ident;
141 free(ident);
142
143 //args
144 must_be(args, an_list, "args of a fundecl");
145 res->data.an_fundecl.nargs = args->data.an_list.n;
146 res->data.an_fundecl.args = (char **)safe_malloc(
147 args->data.an_list.n*sizeof(char *));
148 for (int i = 0; i<args->data.an_list.n; i++) {
149 struct ast *e = args->data.an_list.ptr[i];
150 must_be(e, an_ident, "arg of a fundecl")
151 res->data.an_fundecl.args[i] = e->data.an_ident;
152 free(e);
153 }
154 free(args);
155
156 //body
157 must_be(body, an_list, "body of a fundecl");
158 res->data.an_fundecl.nbody = body->data.an_list.n;
159 res->data.an_fundecl.body = body->data.an_list.ptr;
160 free(body);
161
162 return res;
163 }
164
165 struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
166 {
167 struct ast *res = ast_alloc();
168 res->type = an_if;
169 res->data.an_if.pred = pred;
170
171 must_be(body, an_list, "body of a then");
172 res->data.an_if.nthen = then->data.an_list.n;
173 res->data.an_if.then = then->data.an_list.ptr;
174 free(then);
175
176 must_be(body, an_list, "body of a els");
177 res->data.an_if.nels = els->data.an_list.n;
178 res->data.an_if.els = els->data.an_list.ptr;
179 free(els);
180
181 return res;
182 }
183
184 struct ast *ast_int(int integer)
185 {
186 struct ast *res = ast_alloc();
187 res->type = an_int;
188 res->data.an_int = integer;
189 return res;
190 }
191
192 struct ast *ast_ident(char *ident)
193 {
194 struct ast *res = ast_alloc();
195 res->type = an_ident;
196 res->data.an_ident = safe_strdup(ident);
197 return res;
198 }
199
200 struct ast *ast_list(struct ast *llist)
201 {
202 struct ast *res = ast_alloc();
203 res->type = an_list;
204 res->data.an_list.n = 0;
205
206 int i = ast_llistlength(llist);
207
208 //Allocate array
209 res->data.an_list.n = i;
210 res->data.an_list.ptr = (struct ast **)safe_malloc(
211 res->data.an_list.n*sizeof(struct ast *));
212
213 struct ast *r = llist;
214 while(i > 0) {
215 res->data.an_list.ptr[--i] = r->data.an_cons.el;
216 struct ast *t = r;
217 r = r->data.an_cons.tail;
218 free(t);
219 }
220 return res;
221 }
222
223 struct ast *ast_nil()
224 {
225 struct ast *res = ast_alloc();
226 res->type = an_nil;
227 return res;
228 }
229
230 struct ast *ast_return(struct ast *r)
231 {
232 struct ast *res = ast_alloc();
233 res->type = an_return;
234 res->data.an_return = r;
235 return res;
236 }
237
238 struct ast *ast_stmt_expr(struct ast *expr)
239 {
240 struct ast *res = ast_alloc();
241 res->type = an_stmt_expr;
242 res->data.an_stmt_expr = expr;
243 return res;
244 }
245
246 struct ast *ast_unop(enum unop op, struct ast *l)
247 {
248 struct ast *res = ast_alloc();
249 res->type = an_unop;
250 res->data.an_unop.op = op;
251 res->data.an_unop.l = l;
252 return res;
253 }
254
255 struct ast *ast_vardecl(struct ast *ident, struct ast *l)
256 {
257 struct ast *res = ast_alloc();
258 res->type = an_vardecl;
259 must_be(ident, an_ident, "ident of a vardecl");
260
261 res->data.an_vardecl.ident = ident->data.an_ident;
262 free(ident);
263 res->data.an_vardecl.l = l;
264 return res;
265 }
266
267 struct ast *ast_while(struct ast *pred, struct ast *body)
268 {
269 struct ast *res = ast_alloc();
270 res->type = an_while;
271 res->data.an_while.pred = pred;
272 must_be(body, an_list, "body of a while");
273 res->data.an_while.nbody = body->data.an_list.n;
274 res->data.an_while.body = body->data.an_list.ptr;
275 free(body);
276 return res;
277 }
278
279 int ast_llistlength(struct ast *r)
280 {
281 int i = 0;
282 while(r != NULL) {
283 i++;
284 if (r->type != an_cons) {
285 return 1;
286 }
287 r = r->data.an_cons.tail;
288 }
289 return i;
290 }
291
/*
 * Escaped spelling of the non-printable ASCII characters, indexed by
 * character code. Codes 0-31 and 127 have an entry; every slot in between
 * is left NULL by the designated initialisers.
 */
const char *cescapes[] = {
	[0]   = "\\0",
	[1]   = "\\x01", [2]  = "\\x02", [3]  = "\\x03",
	[4]   = "\\x04", [5]  = "\\x05", [6]  = "\\x06",
	/* escapes that have a one-letter form */
	[7]   = "\\a",   [8]  = "\\b",   [9]  = "\\t",   [10] = "\\n",
	[11]  = "\\v",   [12] = "\\f",   [13] = "\\r",
	[14]  = "\\x0E", [15] = "\\x0F",
	[16]  = "\\x10", [17] = "\\x11", [18] = "\\x12", [19] = "\\x13",
	[20]  = "\\x14", [21] = "\\x15", [22] = "\\x16", [23] = "\\x17",
	[24]  = "\\x18", [25] = "\\x19", [26] = "\\x1A", [27] = "\\x1B",
	[28]  = "\\x1C", [29] = "\\x1D", [30] = "\\x1E", [31] = "\\x1F",
	/* DEL */
	[127] = "\\x7F"
};
303
304 void ast_print(struct ast *ast, int indent, FILE *out)
305 {
306 if (ast == NULL)
307 return;
308 #ifdef DEBUG
309 fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
310 #endif
311 switch(ast->type) {
312 case an_assign:
313 pindent(indent, out);
314 safe_fprintf(out, "%s = ", ast->data.an_assign.ident);
315 ast_print(ast->data.an_assign.expr, indent, out);
316 safe_fprintf(out, ";\n");
317 break;
318 case an_binop:
319 safe_fprintf(out, "(");
320 ast_print(ast->data.an_binop.l, indent, out);
321 safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
322 ast_print(ast->data.an_binop.r, indent, out);
323 safe_fprintf(out, ")");
324 break;
325 case an_bool:
326 safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
327 break;
328 case an_char:
329 if (ast->data.an_char < 0)
330 safe_fprintf(out, "'?'");
331 if (ast->data.an_char < ' ' || ast->data.an_char == 127)
332 safe_fprintf(out, "'%s'",
333 cescapes[(int)ast->data.an_char]);
334 else
335 safe_fprintf(out, "'%c'", ast->data.an_char);
336 break;
337 case an_funcall:
338 safe_fprintf(out, "%s(", ast->data.an_funcall.ident);
339 for(int i = 0; i<ast->data.an_fundecl.nargs; i++) {
340 ast_print(ast->data.an_funcall.args[i], indent, out);
341 if (i+1 < ast->data.an_fundecl.nargs)
342 safe_fprintf(out, ", ");
343 }
344 safe_fprintf(out, ")");
345 break;
346 case an_fundecl:
347 pindent(indent, out);
348 safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
349 for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
350 safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
351 if (i < ast->data.an_fundecl.nargs - 1)
352 safe_fprintf(out, ", ");
353 }
354 safe_fprintf(out, ") {\n");
355 for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
356 ast_print(ast->data.an_fundecl.body[i], indent+1, out);
357 pindent(indent, out);
358 safe_fprintf(out, "}\n");
359 break;
360 case an_if:
361 pindent(indent, out);
362 safe_fprintf(out, "if (");
363 ast_print(ast->data.an_if.pred, indent, out);
364 safe_fprintf(out, ") {\n");
365 for (int i = 0; i<ast->data.an_if.nthen; i++)
366 ast_print(ast->data.an_if.then[i], indent+1, out);
367 pindent(indent, out);
368 safe_fprintf(out, "} else {\n");
369 for (int i = 0; i<ast->data.an_if.nels; i++)
370 ast_print(ast->data.an_if.els[i], indent+1, out);
371 pindent(indent, out);
372 safe_fprintf(out, "}\n");
373 break;
374 case an_int:
375 safe_fprintf(out, "%d", ast->data.an_int);
376 break;
377 case an_ident:
378 safe_fprintf(out, "%s", ast->data.an_ident);
379 break;
380 case an_cons:
381 ast_print(ast->data.an_cons.el, indent, out);
382 ast_print(ast->data.an_cons.tail, indent, out);
383 break;
384 case an_list:
385 for (int i = 0; i<ast->data.an_list.n; i++)
386 ast_print(ast->data.an_list.ptr[i], indent, out);
387 break;
388 case an_nil:
389 safe_fprintf(out, "[]");
390 break;
391 case an_return:
392 pindent(indent, out);
393 safe_fprintf(out, "return ");
394 ast_print(ast->data.an_return, indent, out);
395 safe_fprintf(out, ";\n");
396 break;
397 case an_stmt_expr:
398 pindent(indent, out);
399 ast_print(ast->data.an_stmt_expr, indent, out);
400 safe_fprintf(out, ";\n");
401 break;
402 case an_unop:
403 safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
404 ast_print(ast->data.an_unop.l, indent, out);
405 safe_fprintf(out, ")");
406 break;
407 case an_vardecl:
408 pindent(indent, out);
409 safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
410 ast_print(ast->data.an_vardecl.l, indent, out);
411 safe_fprintf(out, ";\n");
412 break;
413 case an_while:
414 pindent(indent, out);
415 safe_fprintf(out, "while (");
416 ast_print(ast->data.an_while.pred, indent, out);
417 safe_fprintf(out, ") {\n");
418 for (int i = 0; i<ast->data.an_while.nbody; i++) {
419 ast_print(ast->data.an_while.body[i], indent+1, out);
420 }
421 pindent(indent, out);
422 safe_fprintf(out, "}\n");
423 break;
424 default:
425 die("Unsupported AST node\n");
426 }
427 }
428
429 void ast_free(struct ast *ast)
430 {
431 if (ast == NULL)
432 return;
433 #ifdef DEBUG
434 fprintf(stderr, "ast_free(%s)\n", ast_type_str[ast->type]);
435 #endif
436 switch(ast->type) {
437 case an_assign:
438 free(ast->data.an_assign.ident);
439 ast_free(ast->data.an_assign.expr);
440 break;
441 case an_binop:
442 ast_free(ast->data.an_binop.l);
443 ast_free(ast->data.an_binop.r);
444 break;
445 case an_bool:
446 break;
447 case an_char:
448 break;
449 case an_cons:
450 ast_free(ast->data.an_cons.el);
451 ast_free(ast->data.an_cons.tail);
452 break;
453 case an_funcall:
454 free(ast->data.an_funcall.ident);
455 for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
456 free(ast->data.an_fundecl.args[i]);
457 free(ast->data.an_funcall.args);
458 break;
459 case an_fundecl:
460 free(ast->data.an_fundecl.ident);
461 for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
462 free(ast->data.an_fundecl.args[i]);
463 free(ast->data.an_fundecl.args);
464 for (int i = 0; i<ast->data.an_fundecl.nbody; i++)
465 ast_free(ast->data.an_fundecl.body[i]);
466 free(ast->data.an_fundecl.body);
467 break;
468 case an_if:
469 ast_free(ast->data.an_if.pred);
470 for (int i = 0; i<ast->data.an_if.nthen; i++)
471 ast_free(ast->data.an_if.then[i]);
472 free(ast->data.an_if.then);
473 for (int i = 0; i<ast->data.an_if.nels; i++)
474 ast_free(ast->data.an_if.els[i]);
475 free(ast->data.an_if.els);
476 break;
477 case an_int:
478 break;
479 case an_ident:
480 free(ast->data.an_ident);
481 break;
482 case an_list:
483 for (int i = 0; i<ast->data.an_list.n; i++)
484 ast_free(ast->data.an_list.ptr[i]);
485 free(ast->data.an_list.ptr);
486 break;
487 case an_nil:
488 break;
489 case an_return:
490 ast_free(ast->data.an_return);
491 break;
492 case an_stmt_expr:
493 ast_free(ast->data.an_stmt_expr);
494 break;
495 case an_unop:
496 ast_free(ast->data.an_unop.l);
497 break;
498 case an_vardecl:
499 free(ast->data.an_vardecl.ident);
500 ast_free(ast->data.an_vardecl.l);
501 break;
502 case an_while:
503 ast_free(ast->data.an_while.pred);
504 for (int i = 0; i<ast->data.an_while.nbody; i++)
505 ast_free(ast->data.an_while.body[i]);
506 free(ast->data.an_while.body);
507 break;
508 default:
509 die("Unsupported AST node: %d\n", ast->type);
510 }
511 free(ast);
512 }