From 8cadc28fcce91c8b0323af1522685eb0d42b0242 Mon Sep 17 00:00:00 2001 From: Mart Lubbers Date: Mon, 15 Feb 2021 14:07:16 +0100 Subject: [PATCH] start with type inference --- Makefile | 3 +- type.c => sem.c | 5 +- sem.h | 8 ++ sem/hm.c | 279 +++++++++++++++++++++++++++++++++++++++++++++ sem/hm.h | 20 ++++ scc.c => sem/scc.c | 179 ++++++++++++++++++++++++++--- scc.h => sem/scc.h | 6 +- splc.c | 4 +- tarjan.c | 183 ----------------------------- tarjan.h | 39 ------- type.h | 8 -- 11 files changed, 479 insertions(+), 255 deletions(-) rename type.c => sem.c (92%) create mode 100644 sem.h create mode 100644 sem/hm.c create mode 100644 sem/hm.h rename scc.c => sem/scc.c (51%) rename scc.h => sem/scc.h (64%) delete mode 100644 tarjan.c delete mode 100644 tarjan.h delete mode 100644 type.h diff --git a/Makefile b/Makefile index fe1880a..74272aa 100644 --- a/Makefile +++ b/Makefile @@ -2,7 +2,8 @@ CFLAGS+=-Wall -Wextra -std=c99 -pedantic -D_XOPEN_SOURCE=700 -ggdb YFLAGS+=-d --locations -v --defines=parse.h LFLAGS+=--header-file=scan.h -OBJECTS:=scan.o parse.o ast.o util.o list.o type.o genc.o scc.o tarjan.o +OBJECTS:=scan.o parse.o ast.o util.o list.o sem.o genc.o \ + sem/scc.o sem/hm.o all: splc splc: $(OBJECTS) diff --git a/type.c b/sem.c similarity index 92% rename from type.c rename to sem.c index 0cf22e1..1cea48f 100644 --- a/type.c +++ b/sem.c @@ -2,7 +2,7 @@ #include #include "list.h" -#include "scc.h" +#include "sem/scc.h" #include "ast.h" void type_error(const char *msg, ...) @@ -55,10 +55,11 @@ struct decl *type_decl(struct decl *decl) return decl; } -struct ast *type(struct ast *ast) +struct ast *sem(struct ast *ast) { ast = ast_scc(ast); + //Check that all globals are constant for (int i = 0; indecls; i++) { if (ast->decls[i]->type == dvardecl) { //Check globals diff --git a/sem.h b/sem.h new file mode 100644 index 0000000..3785a9d --- /dev/null +++ b/sem.h @@ -0,0 +1,8 @@ +#ifndef SEM_H +#define SEM_H + +#include "ast.h" + +struct ast *sem(struct ast *ast); + +#endif diff --git a/sem/hm.c b/sem/hm.c new file mode 100644 index 0000000..d91715e --- /dev/null +++ b/sem/hm.c @@ -0,0 +1,279 @@ +#include +#include + +#include "hm.h" +#include "../util.h" +#include "../ast.h" + +struct gamma { + int fresh; + int nschemes; + char **vars; + struct scheme *schemes; +}; + +struct scheme *gamma_lookup(struct gamma *gamma, char *ident) +{ + for (int i = 0; ivars[i]) == 0) { + //TODO + } + } + return NULL; +} + +struct type *fresh(struct gamma *gamma) +{ + char *buf = safe_malloc(10); + sprintf(buf, "%d", gamma->fresh); + gamma->fresh++; + return type_var(buf); +} + +void ftv_type(struct type *r, int *nftv, char **ftv) +{ + switch (r->type) { + case tbasic: + break; + case tlist: + ftv_type(r->data.tlist, nftv, ftv); + break; + case ttuple: + ftv_type(r->data.ttuple.l, nftv, ftv); + ftv_type(r->data.ttuple.r, nftv, ftv); + break; + case tvar: + *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *)); + if (*ftv == NULL) + perror("realloc"); + ftv[(*nftv)++] = r->data.tvar; + break; + } +} + +bool occurs_check(char *var, struct type *r) +{ + int nftv = 0; + char *ftv = NULL; + ftv_type(r, &nftv, &ftv); + for (int i = 0; itype) { + case tbasic: + break; + case tlist: + res->data.tlist = dup_type(r->data.tlist); + break; + case ttuple: + res->data.ttuple.l = dup_type(r->data.ttuple.l); + res->data.ttuple.r = dup_type(r->data.ttuple.r); + break; + case tvar: + res->data.tvar = strdup(r->data.tvar); + break; + } + return res; +} + +struct type *subst_apply_t(struct substitution *subst, struct type *l) +{ + switch (l->type) { + case tbasic: + break; + case tlist: + l->data.tlist = subst_apply_t(subst, l->data.tlist); + break; + case ttuple: + l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l); + l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r); + break; + case tvar: + for (int i = 0; invar; i++) { + if (strcmp(subst->vars[i], l->data.tvar) == 0) { + free(l->data.tvar); + struct type *r = subst->types[i]; + *l = *r; + free(r); + } + } + break; + } + return l; +} +struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma) +{ + //TODO + return gamma; +} + +void subst_print(struct substitution *s, FILE *out) +{ + if (s == NULL) { + fprintf(out, "no substitution\n"); + } else { + fprintf(out, "["); + for (int i = 0; invar; i++) { + fprintf(out, "%s->", s->vars[i]); + type_print(s->types[i], out); + if (i + 1 < s->nvar) + fprintf(out, ", "); + } + fprintf(out, "]\n"); + } +} + +struct substitution *subst_id() +{ + struct substitution *res = safe_malloc(sizeof(struct substitution)); + res->nvar = 0; + res->vars = NULL; + res->types = NULL; + return res; +} + +struct substitution *subst_singleton(char *ident, struct type *t) +{ + struct substitution *res = safe_malloc(sizeof(struct substitution)); + res->nvar = 1; + res->vars = safe_malloc(sizeof(char *)); + res->vars[0] = safe_strdup(ident); + res->types = safe_malloc(sizeof(struct type *)); + res->types[0] = dup_type(t); + return res; +} + +struct substitution *subst_union(struct substitution *l, struct substitution *r) +{ + struct substitution *res = safe_malloc(sizeof(struct substitution)); + res->nvar = l->nvar+r->nvar; + res->vars = safe_malloc(res->nvar*sizeof(char *)); + res->types = safe_malloc(res->nvar*sizeof(struct type *)); + for (int i = 0; invar; i++) { + res->vars[i] = l->vars[i]; + res->types[i] = l->types[i]; + } + for (int i = 0; invar; i++) { + res->vars[l->nvar+i] = l->vars[i]; + res->types[l->nvar+i] = subst_apply_t(l, r->types[i]); + } + return res; +} + +struct substitution *unify(struct type *l, struct type *r) { + switch (l->type) { + case tbasic: + if (r->type == tbasic && l->data.tbasic == r->data.tbasic) + return subst_id(); + break; + case tlist: + if (r->type == tlist) + return unify(l->data.tlist, r->data.tlist); + break; + case ttuple: + if (r->type == ttuple) { + struct substitution *s1 = unify( + l->data.ttuple.l, + r->data.ttuple.l); + struct substitution *s2 = unify( + subst_apply_t(s1, l->data.ttuple.l), + subst_apply_t(s1, r->data.ttuple.l)); + return subst_union(s1, s2); + } + break; + case tvar: + if (r->type == tvar && strcmp(l->data.tvar, r->data.tvar) == 0) + return subst_id(); + else if (occurs_check(l->data.tvar, r)) + fprintf(stderr, "Infinite type %s\n", l->data.tvar); + else + return subst_singleton(l->data.tvar, r); + break; + } + return NULL; +} + +struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type) +{ + +#define infbop(l, r, a1, a2, rt, sigma) {\ + s1 = infer_expr(gamma, l, a1);\ + s2 = subst_union(s1, infer_expr(subst_apply_g(s1, gamma), r, a2));\ + return subst_union(s2, unify(subst_apply_t(s2, sigma), rt));\ +} +#define infbinop(e, a1, a2, rt, sigma)\ + infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma) + struct substitution *s1, *s2; + struct type *f1, *f2; + switch (expr->type) { + case ebool: + return unify(type_basic(btbool), type); + case ebinop: + switch(expr->data.ebinop.op) { + case binor: + case binand: + infbinop(expr, type_basic(btbool), type_basic(btbool), + type_basic(btbool), type); + case eq: + case neq: + case leq: + case le: + case geq: + case ge: + f1 = fresh(gamma); + infbinop(expr, f1, f1, type_basic(btbool), type); + case cons: + f1 = fresh(gamma); + infbinop(expr, f1, type_list(f1), type_list(f1), type); + case plus: + case minus: + case times: + case divide: + case modulo: + case power: + infbinop(expr, type_basic(btint), type_basic(btint), + type_basic(btint), type); + } + case echar: + return unify(type_basic(btchar), type); + case efuncall: + //TODO + return NULL; + case eint: + return unify(type_basic(btint), type); + case eident: + + //TODO + return NULL; + case etuple: + f1 = fresh(gamma); + f2 = fresh(gamma); + infbop(expr->data.etuple.left, expr->data.etuple.right, + f1, f2, type_tuple(f1, f2), type); + case estring: + return unify(type_list(type_basic(btchar)), type); + case eunop: + switch(expr->data.eunop.op) { + case negate: + s1 = infer_expr(gamma, + expr->data.eunop.l, type_basic(btint)); + return subst_union(s1, + unify(subst_apply_t(s1, type), + type_basic(btint))); + case inverse: + s1 = infer_expr(gamma, + expr->data.eunop.l, type_basic(btbool)); + return subst_union(s1, + unify(subst_apply_t(s1, type), + type_basic(btbool))); + } + } + +} diff --git a/sem/hm.h b/sem/hm.h new file mode 100644 index 0000000..efb1afa --- /dev/null +++ b/sem/hm.h @@ -0,0 +1,20 @@ +#ifndef SEM_HM_C +#define SEM_HM_C + +#include "../ast.h" + +struct scheme { + struct type *type; + int nvar; + char **var; +}; + +struct substitution { + int nvar; + char **vars; + struct type **types; +}; + +struct ast *infer(struct ast *ast); + +#endif diff --git a/scc.c b/sem/scc.c similarity index 51% rename from scc.c rename to sem/scc.c index 64e9ad0..fa47590 100644 --- a/scc.c +++ b/sem/scc.c @@ -1,9 +1,150 @@ #include #include +#include -#include "ast.h" -#include "list.h" -#include "tarjan.h" +#include "../ast.h" +#include "../list.h" + +#ifndef min +#define min(x, y) ((x)<(y) ? (x) : (y)) +#endif + +struct edge { + void *from; + void *to; +}; + +struct components { + int nnodes; + void **nodes; + struct components *next; +}; + +struct node { + int index; + int lowlink; + bool onStack; + void *data; +}; + +struct tjstate { + int index; + int sp; + int nedges; + struct edge *edges; + struct node **stack; + struct components *head; + struct components *tail; +}; + +static int nodecmp(const void *l, const void *r) +{ + return (ptrdiff_t)l -(ptrdiff_t)((struct node *)r)->data; +} + +static int strongconnect(struct node *v, struct tjstate *tj) +{ + struct node *w; + + /* Set the depth index for v to the smallest unused index */ + v->index = tj->index; + v->lowlink = tj->index; + tj->index++; + tj->stack[tj->sp] = v; + tj->sp++; + v->onStack = true; + + for (int i = 0; inedges; i++) { + /* Only consider nodes reachable from v */ + if (tj->edges[i].from != v) + continue; + w = tj->edges[i].to; + /* Successor w has not yet been visited; recurse on it */ + if (w->index == -1) { + int r = strongconnect(w, tj); + if (r != 0) + return r; + v->lowlink = min(v->lowlink, w->lowlink); + /* Successor w is in stack S and hence in the current SCC */ + } else if (w->onStack) { + v->lowlink = min(v->lowlink, w->index); + } + } + + /* If v is a root node, pop the stack and generate an SCC */ + if (v->lowlink == v->index) { + struct components *ng = safe_malloc(sizeof(struct components)); + if (tj->tail == NULL) + tj->head = ng; + else + tj->tail->next = ng; + tj->tail = ng; + ng->next = NULL; + ng->nnodes = 0; + do { + tj->sp--; + w = tj->stack[tj->sp]; + w->onStack = false; + ng->nnodes++; + } while (w != v); + ng->nodes = safe_malloc(ng->nnodes*sizeof(void *)); + for (int i = 0; innodes; i++) + ng->nodes[i] = tj->stack[tj->sp+i]->data; + } + return 0; +} + +static int ptrcmp(const void *l, const void *r) +{ + return (ptrdiff_t)((struct node *)l)->data + - (ptrdiff_t)((struct node *)r)->data; +} + +/** + * Calculate the strongly connected components using Tarjan's algorithm: + * en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm + * + * Returns NULL when there are invalid edges + * + * @param number of nodes + * @param data of the nodes + * @param number of edges + * @param data of edges + */ +struct components *tarjans( + int nnodes, void *nodedata[], + int nedges, struct edge *edgedata[]) +{ + struct node nodes[nnodes]; + struct edge edges[nedges]; + struct node *stack[nnodes]; + struct node *from, *to; + struct tjstate tj = {0, 0, nedges, edges, stack, NULL, .tail=NULL}; + + // Populate the nodes + for (int i = 0; ifrom, nodes, nnodes, + sizeof(struct node), nodecmp); + if (from == NULL) + die("malformed from component of edge\n"); + to = bsearch(edgedata[i]->to, nodes, nnodes, + sizeof(struct node), nodecmp); + if (to == NULL) + die("malformed to component of edge\n"); + edges[i] = (struct edge){.from=from, .to=to}; + } + + //Tarjan's + for (int i = 0; i < nnodes; i++) + if (nodes[i].index == -1) + strongconnect(&nodes[i], &tj); + return tj.head; +} int iddeclcmp(const void *l, const void *r) { @@ -31,7 +172,7 @@ struct list *edges_expr(int ndecls, struct decl **decls, void *parent, struct decl **to = bsearch(expr->data.efuncall.ident, decls, ndecls, sizeof(struct decl *), iddeclcmp); if (to == NULL) { - fprintf(stderr, "calling an unknown function\n"); + die("calling an unknown function\n"); } else { struct edge *edge = safe_malloc(sizeof(struct edge)); edge->from = parent; @@ -127,21 +268,16 @@ struct ast *ast_scc(struct ast *ast) edges = edges_stmt(nfun, fundecls, fundecls[i], fundecls[i]->data.dfun->body[j], edges); int nedges; - struct edge **edgedata = (struct edge **) + struct edge **edata = (struct edge **) list_to_array(edges, &nedges, false); // Do tarjan's and convert back into the declaration list - int err; - struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edgedata, &err); - if (cs == NULL) { - if (err == 1) - die("malformed edges in tarjan's????"); - else - pdie("malloc"); - } + struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata); + if (cs == NULL) + die("malformed edges in tarjan's????"); int i = ffun; - FOREACHCOMP(c, cs) { + for (struct components *c = cs; c != NULL; c = c->next) { struct decl *d = safe_malloc(sizeof(struct decl)); if (c->nnodes > 1) { d->type = dcomp; @@ -161,8 +297,17 @@ struct ast *ast_scc(struct ast *ast) //Cleanup for (int i = 0; innodes; i++) + free(cs->nodes[i]); + free(cs->nodes); + t = cs->next; + free(cs); + cs = t; + } return ast; } diff --git a/scc.h b/sem/scc.h similarity index 64% rename from scc.h rename to sem/scc.h index 919e440..784ed30 100644 --- a/scc.h +++ b/sem/scc.h @@ -1,7 +1,7 @@ -#ifndef SCC_H -#define SCC_H +#ifndef SEM_SCC_H +#define SEM_SCC_H -#include "ast.h" +#include "../ast.h" // Split up the AST in strongly connected components struct ast *ast_scc(struct ast *ast); diff --git a/splc.c b/splc.c index aeb2e89..6e5c5ab 100644 --- a/splc.c +++ b/splc.c @@ -6,7 +6,7 @@ #include "genc.h" #include "parse.h" #include "scan.h" -#include "type.h" +#include "sem.h" extern int yylex_destroy(void); void usage(FILE *out, char *arg0) @@ -78,7 +78,7 @@ int main(int argc, char *argv[]) ast_print(result, stdout); //Typecheck - if ((result = type(result)) == NULL) { + if ((result = sem(result)) == NULL) { return 1; } if (ptype) diff --git a/tarjan.c b/tarjan.c deleted file mode 100644 index f399b84..0000000 --- a/tarjan.c +++ /dev/null @@ -1,183 +0,0 @@ -#include -#include -#include - -#ifndef min -#define min(x, y) ((x)<(y) ? (x) : (y)) -#endif - -struct edge { - void *from; - void *to; -}; - -struct components { - int nnodes; - void **nodes; - struct components *next; -}; - -struct node { - int index; - int lowlink; - bool onStack; - void *data; -}; - -struct tjstate { - int index; - int sp; - int nedges; - struct edge *edges; - struct node **stack; - struct components *head; - struct components *tail; -}; - -static int nodecmp(const void *l, const void *r) -{ - return (ptrdiff_t)l -(ptrdiff_t)((struct node *)r)->data; -} - -static int strongconnect(struct node *v, struct tjstate *tj) -{ - struct node *w; - - /* Set the depth index for v to the smallest unused index */ - v->index = tj->index; - v->lowlink = tj->index; - tj->index++; - tj->stack[tj->sp] = v; - tj->sp++; - v->onStack = true; - - for (int i = 0; inedges; i++) { - /* Only consider nodes reachable from v */ - if (tj->edges[i].from != v) { - continue; - } - w = tj->edges[i].to; - /* Successor w has not yet been visited; recurse on it */ - if (w->index == -1) { - int r = strongconnect(w, tj); - if (r != 0) - return r; - v->lowlink = min(v->lowlink, w->lowlink); - /* Successor w is in stack S and hence in the current SCC */ - } else if (w->onStack) { - v->lowlink = min(v->lowlink, w->index); - } - } - - /* If v is a root node, pop the stack and generate an SCC */ - if (v->lowlink == v->index) { - struct components *ng = malloc(sizeof(struct components)); - if (ng == NULL) { - return 2; - } - if (tj->tail == NULL) { - tj->head = ng; - } else { - tj->tail->next = ng; - } - tj->tail = ng; - ng->next = NULL; - ng->nnodes = 0; - do { - tj->sp--; - w = tj->stack[tj->sp]; - w->onStack = false; - ng->nnodes++; - } while (w != v); - ng->nodes = malloc(ng->nnodes*sizeof(void *)); - if (ng == NULL) { - return 2; - } - for (int i = 0; innodes; i++) { - ng->nodes[i] = tj->stack[tj->sp+i]->data; - } - } - return 0; -} - -static int ptrcmp(const void *l, const void *r) -{ - return (ptrdiff_t)((struct node *)l)->data - - (ptrdiff_t)((struct node *)r)->data; -} - -/** - * Calculate the strongly connected components using Tarjan's algorithm: - * en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm - * - * Returns NULL when there are invalid edges and sets the error to: - * 1 if there was a malformed edge - * 2 if malloc failed - * - * @param number of nodes - * @param data of the nodes - * @param number of edges - * @param data of edges - * @param pointer to error code - */ -struct components *tarjans( - int nnodes, void *nodedata[], - int nedges, struct edge *edgedata[], - int *error) -{ - struct node nodes[nnodes]; - struct edge edges[nedges]; - struct node *stack[nnodes]; - struct node *from, *to; - struct tjstate tj = {0, 0, nedges, edges, stack, NULL, .tail=NULL}; - - // Populate the nodes - for (int i = 0; ifrom, nodes, nnodes, - sizeof(struct node), nodecmp); - if (from == NULL) { - *error = 1; - return NULL; - } - to = bsearch(edgedata[i]->to, nodes, nnodes, - sizeof(struct node), nodecmp); - if (to == NULL) { - *error = 1; - return NULL; - } - edges[i] = (struct edge){.from=from, .to=to}; - } - - //Tarjan's - for (int i = 0; i < nnodes; i++) { - if (nodes[i].index == -1) { - *error = strongconnect(&nodes[i], &tj); - if (*error != 0) - return NULL; - } - } - return tj.head; -} - -void components_free(struct components *cs, void (*freefun)(void *)) -{ - struct components *t; - - while (cs != NULL) { - if (freefun != NULL) { - for (int i = 0; innodes; i++) { - freefun(cs->nodes[i]); - } - } - free(cs->nodes); - t = cs->next; - free(cs); - cs = t; - } -} diff --git a/tarjan.h b/tarjan.h deleted file mode 100644 index 3bcf0b8..0000000 --- a/tarjan.h +++ /dev/null @@ -1,39 +0,0 @@ -#ifndef SCC_H -#define SCC_H - -struct edge { - void *from; - void *to; -}; - -struct components { - int nnodes; - void **nodes; - struct components *next; -}; - -#define FOREACHCOMP(x, l) for(struct components *x = l; x != NULL; x = x->next) - -/** - * Calculate the strongly connected components using Tarjan's algorithm: - * en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm - * - * Returns NULL when there are invalid edges - * - * @param number of nodes - * @param data of the nodes - * @param number of edges - * @param data of edges - */ -struct components *tarjans(int nnodes, void *nodedata[], int nedges, - struct edge *edgedata[], int *error); - -/** - * Free a list of components - * - * @param cs components - * @param freefun function to free the data with, if NULL, data isn't freed - */ -void components_free(struct components *cs, void (*freefun)(void *)); - -#endif diff --git a/type.h b/type.h deleted file mode 100644 index 2cdaca7..0000000 --- a/type.h +++ /dev/null @@ -1,8 +0,0 @@ -#ifndef TYPE_H -#define TYPE_H - -#include "ast.h" - -struct ast *type(struct ast *ast); - -#endif -- 2.20.1