From: Mart Lubbers Date: Thu, 25 Feb 2021 12:41:26 +0000 (+0100) Subject: fix inference and scc X-Git-Url: https://git.martlubbers.net/?a=commitdiff_plain;h=95c7ee6dcbf37a05d860934818f76848f4057ac3;p=ccc.git fix inference and scc --- diff --git a/sem.c b/sem.c index 9df5bbf..ffb7019 100644 --- a/sem.c +++ b/sem.c @@ -46,12 +46,69 @@ struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl) if (s == NULL) die("error inferring variable\n"); vardecl->type = subst_apply_t(s, t); + gamma_insert(gamma, vardecl->ident, scheme_create(vardecl->type)); subst_free(s); return vardecl; } +void type_fundecl(struct gamma *gamma, struct fundecl *decl) +{ + struct type *f1 = gamma_fresh(gamma); + struct subst *s1 = infer_fundecl(gamma, decl, f1); + f1 = subst_apply_t(s1, f1); + + gamma_insert(gamma, decl->ident, scheme_generalise(gamma, f1)); + subst_free(s1); + type_free(f1); +} + +void type_comp(struct gamma *gamma, int ndecls, struct fundecl **decl) +{ + //Create a fresh variable for every function in the component + struct type **fs = safe_malloc(ndecls*sizeof(struct type *)); + for (int i = 0; irtype == NULL || decl[i]->atypes == NULL; + fs[i] = fresh ? gamma_fresh(gamma) : decl[i]->rtype; + for (int j = 0; jnargs; j++) { + struct type *a = fresh ? gamma_fresh(gamma) + : type_dup(decl[i]->atypes[j]); + fs[i] = type_arrow(a, fs[i]); + } + gamma_insert(gamma, decl[i]->ident, scheme_create(fs[i])); + } + + //Infer each function + struct subst *s0 = subst_id(); + for (int i = 0; iident, scheme_generalise(gamma, t)); + + decl[i]->atypes = safe_malloc(decl[i]->nargs*sizeof(struct type *)); + decl[i]->natypes = decl[i]->nargs; + for (int j = 0; jnargs; j++) { + decl[i]->atypes[j] = type_dup(t->data.tarrow.l); + t = t->data.tarrow.r; + } + decl[i]->rtype = type_dup(t); + } + + //Free all types + for (int i = 0; idecls[i]->data.dvar); break; - case dfundecl: { - struct type *f1 = gamma_fresh(gamma); - struct subst *s = infer_fundecl(gamma, - ast->decls[i]->data.dfun, f1); - f1 = subst_apply_t(s, f1); - gamma_insert(gamma, ast->decls[i]->data.dfun->ident, - scheme_generalise(gamma, f1)); - gamma_print(gamma, stderr); - fprintf(stderr, "done\n"); - subst_free(s); - type_free(f1); + case dfundecl: + //Infer function as singleton component + type_comp(gamma, 1, &ast->decls[i]->data.dfun); break; - } case dcomp: + //Infer function as singleton component + type_comp(gamma, ast->decls[i]->data.dcomp.ndecls, + ast->decls[i]->data.dcomp.decls); break; } } - gamma_free(gamma); return ast; diff --git a/sem/hm.c b/sem/hm.c index 8de1088..45104e5 100644 --- a/sem/hm.c +++ b/sem/hm.c @@ -30,37 +30,23 @@ struct subst *unify(YYLTYPE loc, struct type *l, struct type *r) return NULL; if (r->type == tvar && l->type != tvar) return unify(loc, r, l); - struct subst *s1, *s2, *s3; - switch (l->type) { - case tarrow: - if (r->type == tarrow) { - s1 = unify(loc, l->data.tarrow.l, r->data.tarrow.l); - s2 = unify(loc, subst_apply_t(s1, l->data.tarrow.l), - subst_apply_t(s1, r->data.tarrow.l)); - s3 = subst_union(s1, s2); - return s3; - } - break; - case tbasic: - if (r->type == tbasic && l->data.tbasic == r->data.tbasic) - return subst_id(); - break; - case tlist: - if (r->type == tlist) - return unify(loc, l->data.tlist, r->data.tlist); - break; - case ttuple: - if (r->type == ttuple) { - s1 = unify(loc, l->data.ttuple.l, r->data.ttuple.l); - s2 = unify(loc, subst_apply_t(s1, l->data.ttuple.l), - subst_apply_t(s1, r->data.ttuple.l)); - s3 = subst_union(s1, s2); - subst_free(s1); - subst_free(s2); - return s3; - } - break; - case tvar: + struct subst *s1, *s2; + if (l->type == tarrow && r->type == tarrow) { + s1 = unify(loc, l->data.tarrow.l, r->data.tarrow.l); + s2 = unify(loc, subst_apply_t(s1, l->data.tarrow.l), + subst_apply_t(s1, r->data.tarrow.l)); + return subst_union(s2, s1); + } else if (l->type == tbasic && r->type == tbasic + && l->data.tbasic == r->data.tbasic) { + return subst_id(); + } else if (l->type == tlist && r->type == tlist) { + return unify(loc, l->data.tlist, r->data.tlist); + } else if (l->type == ttuple && r->type == ttuple) { + s1 = unify(loc, l->data.ttuple.l, r->data.ttuple.l); + s2 = unify(loc, subst_apply_t(s1, l->data.ttuple.r), + subst_apply_t(s1, r->data.ttuple.r)); + return subst_union(s2, s1); + } else if (l->type == tvar) { if (r->type == tvar && strcmp(l->data.tvar, r->data.tvar) == 0) return subst_id(); else if (occurs_check(l->data.tvar, r)) @@ -68,13 +54,13 @@ struct subst *unify(YYLTYPE loc, struct type *l, struct type *r) l->data.tvar); else return subst_singleton(l->data.tvar, r); - break; + } else { + type_error(loc, false, "cannot unify "); + type_print(l, stderr); + fprintf(stderr, " with "); + type_print(r, stderr); + die("\n"); } - type_error(loc, false, "cannot unify "); - type_print(l, stderr); - fprintf(stderr, " with "); - type_print(r, stderr); - die("\n"); return NULL; } @@ -94,14 +80,9 @@ struct subst *infer_binop(struct gamma *gamma, struct expr *l, struct expr *r, { struct subst *s1 = infer_expr(gamma, l, a1); struct subst *s2 = infer_expr(subst_apply_g(s1, gamma), r, a2); - struct subst *s3 = subst_union(s1, s2); + struct subst *s3 = subst_union(s2, s1); struct subst *s4 = unify(l->loc, subst_apply_t(s3, sigma), rt); - struct subst *s5 = subst_union(s3, s4); - subst_free(s1); - subst_free(s2); - subst_free(s3); - subst_free(s4); - return s5; + return subst_union(s4, s3); } struct subst *infer_unop(struct gamma *gamma, struct expr *e, @@ -109,10 +90,7 @@ struct subst *infer_unop(struct gamma *gamma, struct expr *e, { struct subst *s1 = infer_expr(gamma, e, a); struct subst *s2 = unify(e->loc, subst_apply_t(s1, sigma), rt); - struct subst *s3 = subst_union(s1, s2); - subst_free(s1); - subst_free(s2); - return s3; + return subst_union(s2, s1); } static struct type tybool = {.type=tbasic, .data={.tbasic=btbool}}; @@ -122,17 +100,9 @@ static struct type tystring = {.type=tlist, .data={.tlist=&tychar}}; struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type) { - fprintf(stderr, "infer expr: "); - expr_print(expr, stderr); - fprintf(stderr, "\ngamma: "); - gamma_print(gamma, stderr); - fprintf(stderr, "\ntype: "); - type_print(type, stderr); - fprintf(stderr, "\n"); - -#define infbinop(e, a1, a2, rt) infer_binop(\ +#define infbinop(e, a1, a2, rt, type) infer_binop(\ gamma, e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, type) - struct subst *s0; + struct subst *s0, *s1; struct type *f1, *f2, *f3; struct scheme *s; switch (expr->type) { @@ -142,7 +112,7 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty switch (expr->data.ebinop.op) { case binor: case binand: - return infbinop(expr, &tybool, &tybool, &tybool); + return infbinop(expr, &tybool, &tybool, &tybool, type); case eq: case neq: case leq: @@ -150,13 +120,13 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty case geq: case ge: f1 = gamma_fresh(gamma); - s0 = infbinop(expr, f1, f1, &tybool); + s0 = infbinop(expr, f1, f1, &tybool, type); type_free(f1); return s0; case cons: f1 = gamma_fresh(gamma); f2 = type_list(f1); - s0 = infbinop(expr, f1, f2, f2); + s0 = infbinop(expr, f1, f2, f2, type); type_free(f2); return s0; case plus: @@ -165,7 +135,7 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty case divide: case modulo: case power: - return infbinop(expr, &tyint, &tyint, &tyint); + return infbinop(expr, &tyint, &tyint, &tyint, type); } break; case echar: @@ -175,24 +145,28 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty type_error(expr->loc, "Unbound function: %s\n" , expr->data.efuncall.ident); struct type *t = scheme_instantiate(gamma, s); + struct subst *s0 = subst_id(); + //Infer args for (int i = 0; idata.efuncall.nargs; i++) { if (t->type != tarrow) type_error(expr->loc, true, "too many arguments to %s\n", expr->data.efuncall.ident); - struct subst *s1 = infer_expr(gamma, + s1 = infer_expr(gamma, expr->data.efuncall.args[i], t->data.tarrow.l); - struct subst *s2 = s0; - s0 = subst_union(s2, s1); - subst_free(s1); - subst_free(s2); - t = t->data.tarrow.r; + s0 = subst_union(s1, s0); + subst_apply_g(s0, gamma); + t = subst_apply_t(s0, t->data.tarrow.r); } if (t->type == tarrow) type_error(expr->loc, true, "not enough arguments to %s\n", expr->data.efuncall.ident); + + //Infer return type + s1 = unify(expr->loc, t, type); + s0 = subst_union(s1, s0); type_free(t); //TODO fields return s0; @@ -235,103 +209,72 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type) { - fprintf(stderr, "infer stmt: "); - stmt_print(stmt, 0, stderr); - fprintf(stderr, "\ngamma: "); - gamma_print(gamma, stderr); - fprintf(stderr, "\ntype: "); - type_print(type, stderr); - fprintf(stderr, "\n"); - - struct subst *s1;//, *s2; -// struct type *f1, *f2, *f3; -// struct scheme *s; + struct subst *s0, *s1, *s2; + struct type *f1; switch (stmt->type) { case sassign: break; case sif: - s1 = infer_expr(gamma, stmt->data.sif.pred, &tybool); - subst_apply_g(s1, gamma); - subst_free(s1); - break; + s0 = infer_expr(gamma, stmt->data.sif.pred, &tybool); + subst_apply_g(s0, gamma); + + for (int i = 0; idata.sif.nthen; i++) { + s1 = infer_stmt(gamma, stmt->data.sif.then[i], type); + s0 = subst_union(s1, s0); + subst_apply_g(s0, gamma); + } + + for (int i = 0; idata.sif.nels; i++) { + s1 = infer_stmt(gamma, stmt->data.sif.els[i], type); + s0 = subst_union(s1, s0); + subst_apply_g(s0, gamma); + } + return s0; case sreturn: return infer_expr(gamma, stmt->data.sreturn, type); case sexpr: - break; + f1 = gamma_fresh(gamma); + s0 = infer_expr(gamma, stmt->data.sexpr, f1); + type_free(f1); + return s0; case svardecl: break; case swhile: - break; + s0 = infer_expr(gamma, stmt->data.swhile.pred, &tybool); + subst_apply_g(s0, gamma); + + for (int i = 0; idata.swhile.nbody; i++) { + s1 = infer_stmt(gamma, stmt->data.swhile.body[i], type); + s0 = subst_union(s1, s0); + subst_apply_g(s0, gamma); + } + + return s0; } return subst_id(); -// union { -// struct { -// char *ident; -// int nfields; -// char **fields; -// struct expr *expr; -// } sassign; -// struct { -// struct expr *pred; -// int nthen; -// struct stmt **then; -// int nels; -// struct stmt **els; -// } sif; -// struct vardecl *svardecl; -// struct expr *sreturn; -// struct expr *sexpr; -// struct { -// struct expr *pred; -// int nbody; -// struct stmt **body; -// } swhile; -// } data; } -struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct type *fty) +struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct type *ftype) { - fprintf(stderr, "inferring function to type "); - type_print(fty, stderr); - fprintf(stderr, " with gamma "); - gamma_print(gamma, stderr); - fprintf(stderr, "\n"); - if (fundecl->rtype == NULL || fundecl->atypes == NULL) { - fundecl->rtype = gamma_fresh(gamma); - fundecl->atypes = safe_realloc(fundecl->atypes, - fundecl->nargs*sizeof(struct type)); - for (int i = 0; inargs; i++) - fundecl->atypes[i] = gamma_fresh(gamma); - } - - struct type *ftype = type_dup(fundecl->rtype); + // Put arguments in gamma + struct type *at = ftype; for (int i = 0; inargs; i++) { - ftype = type_arrow(type_dup(fundecl->atypes[i]), ftype); + if (at->type != tarrow) + die("malformed ftype\n"); gamma_insert(gamma, fundecl->args[i], - scheme_create(fundecl->atypes[i])); + scheme_create(at->data.tarrow.l)); + at = at->data.tarrow.r; } - gamma_insert(gamma, fundecl->ident, scheme_create(ftype)); + if (at->type == tarrow) + die("malformed ftype\n"); struct subst *s = subst_id(); for (int i = 0; inbody; i++) { struct subst *s1 = infer_stmt(gamma, - fundecl->body[i], fundecl->rtype); - struct subst *s2 = s; - s = subst_union(s2, s1); - subst_free(s1); - subst_free(s2); + fundecl->body[i], at); + s = subst_union(s1, s); + subst_apply_g(s, gamma); } - for (int i = 0; inargs; i++) { - struct type *t = subst_apply_t(s, fundecl->atypes[i]); - type_free(fundecl->atypes[i]); - fundecl->atypes[i] = t; - } - fundecl->rtype = subst_apply_t(s, fundecl->rtype); - - ftype = subst_apply_t(s, ftype); - struct subst *r = unify(fundecl->loc, fty, ftype); - type_free(ftype); - subst_free(s); - return r; + return s; } diff --git a/sem/hm/scheme.c b/sem/hm/scheme.c index 8d65075..a374f6b 100644 --- a/sem/hm/scheme.c +++ b/sem/hm/scheme.c @@ -6,9 +6,9 @@ struct type *scheme_instantiate(struct gamma *gamma, struct scheme *sch) { struct subst *s = subst_id(); - for (int i = 0; invar; i++) { - s = subst_union(s, subst_singleton(sch->var[i], gamma_fresh(gamma))); - } + for (int i = 0; invar; i++) + subst_insert(s, safe_strdup(sch->var[i]), gamma_fresh(gamma)); + struct type *t = subst_apply_t(s, type_dup(sch->type)); for (int i = 0; invar; i++) free(s->vars[i]); @@ -57,8 +57,11 @@ void scheme_print(struct scheme *scheme, FILE *out) } if (scheme->nvar > 0) { fprintf(out, "A."); - for (int i = 0; invar; i++) + for (int i = 0; invar; i++) { + if (i > 0) + fprintf(out, " "); fprintf(out, "%s", scheme->var[i]); + } fprintf(out, ": "); } type_print(scheme->type, out); diff --git a/sem/hm/subst.c b/sem/hm/subst.c index e8fb35c..f232c25 100644 --- a/sem/hm/subst.c +++ b/sem/hm/subst.c @@ -3,43 +3,59 @@ #include "../hm.h" +#define INCAP 10 + struct subst *subst_id() { struct subst *res = safe_malloc(sizeof(struct subst)); res->nvar = 0; - res->vars = NULL; - res->types = NULL; + res->capacity = INCAP; + res->vars = safe_malloc(INCAP*sizeof(char *)); + res->types = safe_malloc(INCAP*sizeof(struct type *)); return res; } +struct subst *subst_insert(struct subst *s, char *ident, struct type *t) +{ + int i = 0; + while (i < s->nvar) { + if (strcmp(s->vars[i], ident) == 0) { + free(s->vars[i]); + s->vars[i] = safe_strdup(ident); + type_free(s->types[i]); + s->types[i] = type_dup(t); + return s; + } + i++; + } + s->nvar++; + if (s->nvar > s->capacity) { + s->capacity += s->capacity; + s->vars = safe_realloc(s->vars, s->capacity*sizeof(char *)); + s->types = safe_realloc(s->vars, s->capacity*sizeof(struct type *)); + } + s->vars[i] = safe_strdup(ident); + s->types[i] = type_dup(t); + return s; +} + struct subst *subst_singleton(char *ident, struct type *t) { - struct subst *res = safe_malloc(sizeof(struct subst)); - res->nvar = 1; - res->vars = safe_malloc(sizeof(char *)); - res->vars[0] = safe_strdup(ident); - res->types = safe_malloc(sizeof(struct type *)); - res->types[0] = type_dup(t); + struct subst *res = subst_id(); + subst_insert(res, ident, t); return res; } -struct subst *subst_union(struct subst *l, struct subst *r) +struct subst *subst_union(struct subst *s1, struct subst *s2) { - if (l == NULL || r == NULL) - return NULL; - struct subst *res = safe_malloc(sizeof(struct subst)); - res->nvar = l->nvar+r->nvar; - res->vars = safe_malloc(res->nvar*sizeof(char *)); - res->types = safe_malloc(res->nvar*sizeof(struct type *)); - for (int i = 0; invar; i++) { - res->vars[i] = safe_strdup(l->vars[i]); - res->types[i] = type_dup(l->types[i]); - } - for (int i = 0; invar; i++) { - res->vars[l->nvar+i] = safe_strdup(r->vars[i]); - res->types[l->nvar+i] = subst_apply_t(l, type_dup(r->types[i])); - } - return res; + //Apply s1 on s2 + for (int i = 0; invar; i++) + s2->types[i] = subst_apply_t(s1, s2->types[i]); + //Insert s1 into s2 + for (int i = 0; invar; i++) + subst_insert(s2, s1->vars[i], s1->types[i]); + subst_free(s1); + return s2; } struct type *subst_apply_t(struct subst *subst, struct type *l) @@ -77,15 +93,19 @@ struct type *subst_apply_t(struct subst *subst, struct type *l) struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme) { - for (int i = 0; invar; i++) { - for (int j = 0; jnvar; j++) { - if (strcmp(scheme->var[i], subst->vars[j]) != 0) { - struct subst *t = subst_singleton( - subst->vars[j], subst->types[j]); - scheme->type = subst_apply_t(t, scheme->type); + //remove the quantified ones from the subst + struct subst *s = subst_id(); + for (int j = 0; jnvar; j++) { + bool found = false; + for (int i = 0; invar; i++) { + if (strcmp(scheme->var[i], subst->vars[j]) == 0) { + found = true; } } + if (!found) + subst_insert(s, subst->vars[j], subst->types[j]); } + scheme->type = subst_apply_t(s, scheme->type); return scheme; } diff --git a/sem/hm/subst.h b/sem/hm/subst.h index 0fba436..ff9e71f 100644 --- a/sem/hm/subst.h +++ b/sem/hm/subst.h @@ -6,11 +6,13 @@ struct subst { int nvar; + int capacity; char **vars; struct type **types; }; struct subst *subst_id(); +struct subst *subst_insert(struct subst *s, char *ident, struct type *t); struct subst *subst_singleton(char *ident, struct type *t); struct subst *subst_union(struct subst *l, struct subst *r); diff --git a/sem/scc.c b/sem/scc.c index f728f4f..a490b4c 100644 --- a/sem/scc.c +++ b/sem/scc.c @@ -148,6 +148,7 @@ struct components *tarjans( int iddeclcmp(const void *l, const void *r) { + fprintf(stderr, "iddeclcmp: %s %s\n", (char *)l, (*(struct decl **)r)->data.dfun->ident); return strcmp((char *)l, (*(struct decl **)r)->data.dfun->ident); } @@ -169,16 +170,19 @@ struct list *edges_expr(int ndecls, struct decl **decls, void *parent, for(int i = 0; idata.efuncall.nargs; i++) l = edges_expr(ndecls, decls, parent, expr->data.efuncall.args[i], l); - struct decl **to = bsearch(expr->data.efuncall.ident, decls, - ndecls, sizeof(struct decl *), iddeclcmp); - if (to == NULL) { - die("calling an unknown function\n"); - } else { - struct edge *edge = safe_malloc(sizeof(struct edge)); - edge->from = parent; - edge->to = (void *)*to; - l = list_cons(edge, l); + bool found = false; + for (int i = 0; idata.dfun->ident, + expr->data.efuncall.ident) == 0) { + struct edge *edge = safe_malloc(sizeof(struct edge)); + edge->from = parent; + edge->to = (void *)decls[i]; + l = list_cons(edge, l); + found = true; + } } + if (!found) + die("Malformed function call\n"); break; case eint: break; @@ -195,8 +199,7 @@ struct list *edges_expr(int ndecls, struct decl **decls, void *parent, case estring: break; case eunop: - l = edges_expr(ndecls, decls, parent, expr->data.eunop.l, l); - break; + return edges_expr(ndecls, decls, parent, expr->data.eunop.l, l); default: die("Unsupported expr node\n"); }