start with type inference
[ccc.git] / sem / scc.c
1 #include <string.h>
2 #include <stdlib.h>
3 #include <stddef.h>
4
5 #include "../ast.h"
6 #include "../list.h"
7
/*
 * Smaller of x and y. The selected argument is evaluated twice, so do not
 * pass expressions with side effects. Only defined when no min macro is
 * already in scope.
 */
#ifndef min
#define min(x, y) ((x)<(y) ? (x) : (y))
#endif
11
/*
 * Directed edge running from `from` to `to`.
 *
 * Both ends are opaque pointers: callers of tarjans() store node data
 * pointers in them, while tarjans() reuses the type internally with
 * pointers to its own struct node entries.
 */
struct edge {
	void *from, *to;
};
16
/* One strongly connected component; components form a singly linked list */
struct components {
	int nnodes;              /* number of entries in nodes */
	void **nodes;            /* data pointers of the member nodes */
	struct components *next; /* next component, NULL at the end */
};
22
/*
 * Per-node bookkeeping for Tarjan's algorithm.
 *
 * The field order is relied upon by positional initializers in this file.
 */
struct node {
	int index, lowlink; /* visit number (-1: unvisited), lowest index seen */
	bool onStack;       /* set while the node sits on the work stack */
	void *data;         /* caller supplied payload identifying the node */
};
29
/*
 * State shared by all strongconnect() invocations of one tarjans() run.
 *
 * The field order is relied upon by a positional initializer in this file.
 */
struct tjstate {
	int index, sp, nedges;          /* next visit number, stack height,
	                                   length of edges */
	struct edge *edges;             /* edges whose ends are struct node * */
	struct node **stack;            /* work stack, sp entries in use */
	struct components *head, *tail; /* result list and its last element */
};
39
40 static int nodecmp(const void *l, const void *r)
41 {
42 return (ptrdiff_t)l -(ptrdiff_t)((struct node *)r)->data;
43 }
44
45 static int strongconnect(struct node *v, struct tjstate *tj)
46 {
47 struct node *w;
48
49 /* Set the depth index for v to the smallest unused index */
50 v->index = tj->index;
51 v->lowlink = tj->index;
52 tj->index++;
53 tj->stack[tj->sp] = v;
54 tj->sp++;
55 v->onStack = true;
56
57 for (int i = 0; i<tj->nedges; i++) {
58 /* Only consider nodes reachable from v */
59 if (tj->edges[i].from != v)
60 continue;
61 w = tj->edges[i].to;
62 /* Successor w has not yet been visited; recurse on it */
63 if (w->index == -1) {
64 int r = strongconnect(w, tj);
65 if (r != 0)
66 return r;
67 v->lowlink = min(v->lowlink, w->lowlink);
68 /* Successor w is in stack S and hence in the current SCC */
69 } else if (w->onStack) {
70 v->lowlink = min(v->lowlink, w->index);
71 }
72 }
73
74 /* If v is a root node, pop the stack and generate an SCC */
75 if (v->lowlink == v->index) {
76 struct components *ng = safe_malloc(sizeof(struct components));
77 if (tj->tail == NULL)
78 tj->head = ng;
79 else
80 tj->tail->next = ng;
81 tj->tail = ng;
82 ng->next = NULL;
83 ng->nnodes = 0;
84 do {
85 tj->sp--;
86 w = tj->stack[tj->sp];
87 w->onStack = false;
88 ng->nnodes++;
89 } while (w != v);
90 ng->nodes = safe_malloc(ng->nnodes*sizeof(void *));
91 for (int i = 0; i<ng->nnodes; i++)
92 ng->nodes[i] = tj->stack[tj->sp+i]->data;
93 }
94 return 0;
95 }
96
97 static int ptrcmp(const void *l, const void *r)
98 {
99 return (ptrdiff_t)((struct node *)l)->data
100 - (ptrdiff_t)((struct node *)r)->data;
101 }
102
103 /**
104 * Calculate the strongly connected components using Tarjan's algorithm:
105 * en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
106 *
107 * Returns NULL when there are invalid edges
108 *
109 * @param number of nodes
110 * @param data of the nodes
111 * @param number of edges
112 * @param data of edges
113 */
114 struct components *tarjans(
115 int nnodes, void *nodedata[],
116 int nedges, struct edge *edgedata[])
117 {
118 struct node nodes[nnodes];
119 struct edge edges[nedges];
120 struct node *stack[nnodes];
121 struct node *from, *to;
122 struct tjstate tj = {0, 0, nedges, edges, stack, NULL, .tail=NULL};
123
124 // Populate the nodes
125 for (int i = 0; i<nnodes; i++)
126 nodes[i] = (struct node){-1, -1, false, nodedata[i]};
127 qsort(nodes, nnodes, sizeof(struct node), ptrcmp);
128
129 // Populate the edges
130 for (int i = 0; i<nedges; i++) {
131 from = bsearch(edgedata[i]->from, nodes, nnodes,
132 sizeof(struct node), nodecmp);
133 if (from == NULL)
134 die("malformed from component of edge\n");
135 to = bsearch(edgedata[i]->to, nodes, nnodes,
136 sizeof(struct node), nodecmp);
137 if (to == NULL)
138 die("malformed to component of edge\n");
139 edges[i] = (struct edge){.from=from, .to=to};
140 }
141
142 //Tarjan's
143 for (int i = 0; i < nnodes; i++)
144 if (nodes[i].index == -1)
145 strongconnect(&nodes[i], &tj);
146 return tj.head;
147 }
148
149 int iddeclcmp(const void *l, const void *r)
150 {
151 return strcmp((char *)l, (*(struct decl **)r)->data.dfun->ident);
152 }
153
154 struct list *edges_expr(int ndecls, struct decl **decls, void *parent,
155 struct expr *expr, struct list *l)
156 {
157 if (expr == NULL)
158 return l;
159 switch(expr->type) {
160 case ebinop:
161 l = edges_expr(ndecls, decls, parent, expr->data.ebinop.l, l);
162 l = edges_expr(ndecls, decls, parent, expr->data.ebinop.r, l);
163 break;
164 case ebool:
165 break;
166 case echar:
167 break;
168 case efuncall:
169 for(int i = 0; i<expr->data.efuncall.nargs; i++)
170 l = edges_expr(ndecls, decls, parent,
171 expr->data.efuncall.args[i], l);
172 struct decl **to = bsearch(expr->data.efuncall.ident, decls,
173 ndecls, sizeof(struct decl *), iddeclcmp);
174 if (to == NULL) {
175 die("calling an unknown function\n");
176 } else {
177 struct edge *edge = safe_malloc(sizeof(struct edge));
178 edge->from = parent;
179 edge->to = (void *)*to;
180 l = list_cons(edge, l);
181 }
182 break;
183 case eint:
184 break;
185 case eident:
186 break;
187 case enil:
188 break;
189 case etuple:
190 l = edges_expr(ndecls, decls, parent,
191 expr->data.etuple.left, l);
192 l = edges_expr(ndecls, decls, parent,
193 expr->data.etuple.right, l);
194 break;
195 case estring:
196 break;
197 case eunop:
198 l = edges_expr(ndecls, decls, parent, expr->data.eunop.l, l);
199 break;
200 default:
201 die("Unsupported expr node\n");
202 }
203 return l;
204 }
205
206 struct list *edges_stmt(int ndecls, struct decl **decls, void *parent,
207 struct stmt *stmt, struct list *l)
208 {
209 switch(stmt->type) {
210 case sassign:
211 l = edges_expr(ndecls, decls, parent,
212 stmt->data.sassign.expr, l);
213 break;
214 case sif:
215 l = edges_expr(ndecls, decls, parent, stmt->data.sif.pred, l);
216 for (int i = 0; i<stmt->data.sif.nthen; i++)
217 l = edges_stmt(ndecls, decls, parent,
218 stmt->data.sif.then[i], l);
219 for (int i = 0; i<stmt->data.sif.nels; i++)
220 l = edges_stmt(ndecls, decls, parent,
221 stmt->data.sif.els[i], l);
222 break;
223 case sreturn:
224 l = edges_expr(ndecls, decls, parent, stmt->data.sreturn, l);
225 break;
226 case sexpr:
227 l = edges_expr(ndecls, decls, parent, stmt->data.sexpr, l);
228 break;
229 case svardecl:
230 l = edges_expr(ndecls, decls, parent,
231 stmt->data.svardecl->expr, l);
232 break;
233 case swhile:
234 l = edges_expr(ndecls, decls, parent,
235 stmt->data.swhile.pred, l);
236 for (int i = 0; i<stmt->data.swhile.nbody; i++)
237 l = edges_stmt(ndecls, decls, parent,
238 stmt->data.swhile.body[i], l);
239 break;
240 default:
241 die("Unsupported stmt node\n");
242 }
243 return l;
244 }
245
246 int declcmp(const void *l, const void *r)
247 {
248 return (*(struct decl **)l)->type - (*(struct decl **)r)->type;
249 }
250
251 struct ast *ast_scc(struct ast *ast)
252 {
253 //Sort so that the functions are at the end
254 qsort(ast->decls, ast->ndecls, sizeof(struct decl *), declcmp);
255 //Index of the first function
256 int ffun;
257 for (ffun = 0; ffun<ast->ndecls; ffun++)
258 if (ast->decls[ffun]->type == dfundecl)
259 break;
260 //Number of functions
261 int nfun = ast->ndecls-ffun;
262
263 //Calculate the edges
264 struct decl **fundecls = ast->decls+ffun;
265 struct list *edges = NULL;
266 for (int i = 0; i<nfun; i++)
267 for (int j = 0; j<fundecls[i]->data.dfun->nbody; j++)
268 edges = edges_stmt(nfun, fundecls, fundecls[i],
269 fundecls[i]->data.dfun->body[j], edges);
270 int nedges;
271 struct edge **edata = (struct edge **)
272 list_to_array(edges, &nedges, false);
273
274 // Do tarjan's and convert back into the declaration list
275 struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata);
276 if (cs == NULL)
277 die("malformed edges in tarjan's????");
278
279 int i = ffun;
280 for (struct components *c = cs; c != NULL; c = c->next) {
281 struct decl *d = safe_malloc(sizeof(struct decl));
282 if (c->nnodes > 1) {
283 d->type = dcomp;
284 d->data.dcomp.ndecls = c->nnodes;
285 d->data.dcomp.decls = safe_malloc(
286 c->nnodes*sizeof(struct fundecl *));
287 for (int i = 0; i<c->nnodes; i++)
288 d->data.dcomp.decls[i] =
289 ((struct decl *)c->nodes[i])->data.dfun;
290 } else {
291 d->type = dfundecl;
292 d->data.dfun = ((struct decl *)c->nodes[0])->data.dfun;
293 }
294 ast->decls[i++] = d;
295 }
296 ast->ndecls = i;
297
298 //Cleanup
299 for (int i = 0; i<nedges; i++)
300 free(edata[i]);
301 free(edata);
302
303 struct components *t;
304 while (cs != NULL) {
305 for (int i = 0; i<cs->nnodes; i++)
306 free(cs->nodes[i]);
307 free(cs->nodes);
308 t = cs->next;
309 free(cs);
310 cs = t;
311 }
312 return ast;
313 }