From 5c5743a6026a91cda10878ceb80539796d9df725 Mon Sep 17 00:00:00 2001 From: Mart Lubbers Date: Tue, 23 Feb 2021 15:33:57 +0100 Subject: [PATCH] more type checking --- sem.c | 12 ++- sem/hm.c | 244 +++++++++++++++++++++++++++++++++++++++--------- sem/hm.h | 2 + sem/hm/gamma.c | 7 ++ sem/hm/scheme.c | 1 + sem/hm/subst.c | 16 ++-- sem/hm/subst.h | 2 +- sem/scc.c | 1 - util.c | 8 ++ util.h | 1 + 10 files changed, 236 insertions(+), 58 deletions(-) diff --git a/sem.c b/sem.c index 3404fb4..5c9a528 100644 --- a/sem.c +++ b/sem.c @@ -38,14 +38,14 @@ void check_expr_constant(struct expr *expr) struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl) { struct type *t = vardecl->type == NULL - ? gamma_fresh(gamma) : type_dup(vardecl->type); + ? gamma_fresh(gamma) : vardecl->type; struct subst *s = infer_expr(gamma, vardecl->expr, t); if (s == NULL) die("error inferring variable\n"); vardecl->type = subst_apply_t(s, t); - //subst_free(s); + subst_free(s); return vardecl; } @@ -66,9 +66,11 @@ struct ast *sem(struct ast *ast) type_vardecl(gamma, ast->decls[i]->data.dvar); break; case dfundecl: { -// struct type *f1 = gamma_fresh(gamma); -// gamma_insert(gamma, ast->decls[i]->data.dfun->ident -// , scheme_create(f1)); + struct type *f1 = gamma_fresh(gamma); + gamma_insert(gamma, ast->decls[i]->data.dfun->ident + , scheme_create(f1)); + struct subst *s = infer_fundecl(gamma, ast->decls[i]->data.dfun); + subst_free(s); //infer env (Let [(x, e1)] e2) // = fresh // >>= \tv-> let env` = 'Data.Map'.put x (Forall [] tv) env diff --git a/sem/hm.c b/sem/hm.c index fb3d2ce..4b88076 100644 --- a/sem/hm.c +++ b/sem/hm.c @@ -11,11 +11,16 @@ bool occurs_check(char *var, struct type *r) { int nftv = 0; char **ftv = NULL; + bool res = false; type_ftv(r, &nftv, &ftv); - for (int i = 0; itype == tvar && l->type != tvar) return unify(r, l); - struct subst *s1, *s2; + struct subst *s1, *s2, *s3; switch (l->type) { case tarrow: if (r->type == tarrow) { s1 = unify(l->data.tarrow.l, r->data.tarrow.l); s2 = unify(subst_apply_t(s1, l->data.tarrow.l), subst_apply_t(s1, r->data.tarrow.l)); - return subst_union(s1, s2); + s3 = subst_union(s1, s2); + return s3; } break; case tbasic: @@ -47,7 +53,10 @@ struct subst *unify(struct type *l, struct type *r) s1 = unify(l->data.ttuple.l, r->data.ttuple.l); s2 = unify(subst_apply_t(s1, l->data.ttuple.l), subst_apply_t(s1, r->data.ttuple.l)); - return subst_union(s1, s2); + s3 = subst_union(s1, s2); + subst_free(s1); + subst_free(s2); + return s3; } break; case tvar: @@ -67,6 +76,47 @@ struct subst *unify(struct type *l, struct type *r) return NULL; } +struct subst *unifyfree(struct type *l, struct type *r, bool freel, bool freer) +{ + struct subst *s = unify(l, r); + if (freel) + type_free(l); + if (freer) + type_free(r); + return s; +} + +struct subst *infer_binop(struct gamma *gamma, struct expr *l, struct expr *r, + struct type *a1, struct type *a2, struct type *rt, struct type *sigma) +{ + struct subst *s1 = infer_expr(gamma, l, a1); + struct subst *s2 = infer_expr(subst_apply_g(s1, gamma), r, a2); + struct subst *s3 = subst_union(s1, s2); + struct subst *s4 = unify(subst_apply_t(s3, sigma), rt); + struct subst *s5 = subst_union(s3, s4); + subst_free(s1); + subst_free(s2); + subst_free(s3); + subst_free(s4); + return s5; +} + +struct subst *infer_unop(struct gamma *gamma, struct expr *e, + struct type *a, struct type *rt, struct type *sigma) +{ + struct subst *s1 = infer_expr(gamma, e, a); + struct subst *s2 = unify(subst_apply_t(s1, sigma), rt); + struct subst *s3 = subst_union(s1, s2); + subst_free(s1); + subst_free(s2); + return s3; +} + +static struct type tybool = {.type=tbasic, .data={.tbasic=btbool}}; +static struct type tychar = {.type=tbasic, .data={.tbasic=btchar}}; +static struct type tyint = {.type=tbasic, .data={.tbasic=btint}}; +static struct type tystring = {.type=tlist, .data={.tlist=&tychar}}; + struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type) { fprintf(stderr, "infer expr: "); @@ -77,25 +127,19 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty type_print(type, stderr); fprintf(stderr, "\n"); -#define infbop(l, r, a1, a2, rt, sigma) {\ - s1 = infer_expr(gamma, l, a1);\ - s2 = subst_union(s1, infer_expr(subst_apply_g(s1, gamma), r, a2));\ - return subst_union(s2, unify(subst_apply_t(s2, sigma), rt));\ -} -#define infbinop(e, a1, a2, rt, sigma)\ - infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma) - struct subst *s1, *s2; - struct type *f1, *f2; +#define infbinop(e, a1, a2, rt) infer_binop(\ + gamma, e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, type) + struct subst *s1; + struct type *f1, *f2, *f3; struct scheme *s; switch (expr->type) { case ebool: - return unify(type_basic(btbool), type); + return unify(&tybool, type); case ebinop: switch (expr->data.ebinop.op) { case binor: case binand: - infbinop(expr, type_basic(btbool), type_basic(btbool), - type_basic(btbool), type); + return infbinop(expr, &tybool, &tybool, &tybool); case eq: case neq: case leq: @@ -103,22 +147,26 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty case geq: case ge: f1 = gamma_fresh(gamma); - infbinop(expr, f1, f1, type_basic(btbool), type); + s1 = infbinop(expr, f1, f1, &tybool); + type_free(f1); + return s1; case cons: f1 = gamma_fresh(gamma); - infbinop(expr, f1, type_list(f1), type_list(f1), type); + f2 = type_list(f1); + s1 = infbinop(expr, f1, f2, f2); + type_free(f2); + return s1; case plus: case minus: case times: case divide: case modulo: case power: - infbinop(expr, type_basic(btint), type_basic(btint), - type_basic(btint), type); + return infbinop(expr, &tyint, &tyint, &tyint); } break; case echar: - return unify(type_basic(btchar), type); + return unify(&tychar, type); case efuncall: if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL) die("Unbound function: %s\n", expr->data.efuncall.ident); @@ -126,39 +174,147 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty //TODO fields return NULL; case eint: - return unify(type_basic(btint), type); + return unify(&tyint, type); case eident: if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL) die("Unbound variable: %s\n", expr->data.eident.ident); - //TODO fields - return unify(scheme_instantiate(gamma, s), type); + f1 = scheme_instantiate(gamma, s); + s1 = unify(f1, type); + type_free(f1); + //TODO field + return s1; case enil: f1 = gamma_fresh(gamma); - return unify(type_list(f1), type); + return unifyfree(type_list(f1), type, true, false); case etuple: f1 = gamma_fresh(gamma); f2 = gamma_fresh(gamma); - infbop(expr->data.etuple.left, expr->data.etuple.right, - f1, f2, type_tuple(f1, f2), type); + f3 = type_tuple(f1, f2); + s1 = infer_binop(gamma, expr->data.etuple.left, + expr->data.etuple.right, f1, f2, f3, type); + type_free(f3); + return s1; case estring: - return unify(type_list(type_basic(btchar)), type); + return unify(&tystring, type); case eunop: - switch(expr->data.eunop.op) { + switch (expr->data.eunop.op) { case negate: - s1 = infer_expr(gamma, - expr->data.eunop.l, type_basic(btint)); - if (s1 == NULL) - return NULL; - return subst_union(s1, - unify(subst_apply_t(s1, type), - type_basic(btint))); + return infer_unop(gamma, expr->data.eunop.l, + &tyint, &tyint, type); case inverse: - s1 = infer_expr(gamma, - expr->data.eunop.l, type_basic(btbool)); - return subst_union(s1, - unify(subst_apply_t(s1, type), - type_basic(btbool))); + return infer_unop(gamma, expr->data.eunop.l, + &tybool, &tybool, type); } } return NULL; } + +struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type) +{ + fprintf(stderr, "infer stmt: "); + stmt_print(stmt, 0, stderr); + fprintf(stderr, "\ngamma: "); + gamma_print(gamma, stderr); + fprintf(stderr, "\ntype: "); + type_print(type, stderr); + fprintf(stderr, "\n"); + +// struct subst *s1, *s2; +// struct type *f1, *f2, *f3; +// struct scheme *s; + switch (stmt->type) { + case sassign: + break; + case sif: + break; + case sreturn: + return infer_expr(gamma, stmt->data.sreturn, type); + case sexpr: + break; + case svardecl: + break; + case swhile: + break; + } + return subst_id(); +// union { +// struct { +// char *ident; +// int nfields; +// char **fields; +// struct expr *expr; +// } sassign; +// struct { +// struct expr *pred; +// int nthen; +// struct stmt **then; +// int nels; +// struct stmt **els; +// } sif; +// struct vardecl *svardecl; +// struct expr *sreturn; +// struct expr *sexpr; +// struct { +// struct expr *pred; +// int nbody; +// struct stmt **body; +// } swhile; +// } data; +} + +struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl) +{ + //struct type *t; + if (fundecl->rtype == NULL || fundecl->atypes == NULL) { + fundecl->rtype = gamma_fresh(gamma); + fundecl->atypes = safe_realloc(fundecl->atypes, fundecl->nargs*sizeof(struct type)); + for (int i = 0; inargs; i++) + fundecl->atypes[i] = gamma_fresh(gamma); + } + fprintf(stderr, "fundecl with type: "); + for (int i = 0; inargs; i++) { + type_print(fundecl->atypes[i], stderr); + fprintf(stderr, " "); + } + fprintf(stderr, "-> "); + type_print(fundecl->rtype, stderr); + fprintf(stderr, "\n"); + + for (int i = 0; inargs; i++) + gamma_insert(gamma, fundecl->args[i], + scheme_create(fundecl->atypes[i])); + + struct subst *s = subst_id(); + for (int i = 0; inbody; i++) { + struct subst *s1 = infer_stmt(gamma, fundecl->body[i], fundecl->rtype); + struct subst *s2 = s; + s = subst_union(s2, s1); + subst_free(s1); + subst_free(s2); + } + + fprintf(stderr, "inferred function substitution: "); + subst_print(s, stderr); + + for (int i = 0; inargs; i++) + fundecl->atypes[i] = subst_apply_t(s, fundecl->atypes[i]); + fundecl->rtype = subst_apply_t(s, fundecl->rtype); + fprintf(stderr, "fundecl with type: "); + for (int i = 0; inargs; i++) { + type_print(fundecl->atypes[i], stderr); + fprintf(stderr, " "); + } + fprintf(stderr, "-> "); + type_print(fundecl->rtype, stderr); + fprintf(stderr, "\n"); + //char *ident; + //int nargs; + //char **args; + //int natypes; + //struct type **atypes; + //struct type *rtype; + //int nbody; + //struct stmt **body; + + return s; +} diff --git a/sem/hm.h b/sem/hm.h index 6792c3b..106655b 100644 --- a/sem/hm.h +++ b/sem/hm.h @@ -8,5 +8,7 @@ struct ast *infer(struct ast *ast); struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type); +struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type); +struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl); #endif diff --git a/sem/hm/gamma.c b/sem/hm/gamma.c index 9a97c27..4202154 100644 --- a/sem/hm/gamma.c +++ b/sem/hm/gamma.c @@ -15,6 +15,13 @@ struct gamma *gamma_init() void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme) { + for (int i = 0; inschemes; i++) { + if(strcmp(gamma->vars[i], ident) == 0) { + scheme_free(gamma->schemes[i]); + gamma->schemes[i] = scheme; + return; + } + } gamma->nschemes++; gamma->vars = realloc(gamma->vars, gamma->nschemes*sizeof(char *)); gamma->schemes = realloc(gamma->schemes, diff --git a/sem/hm/scheme.c b/sem/hm/scheme.c index 177bb64..41d5401 100644 --- a/sem/hm/scheme.c +++ b/sem/hm/scheme.c @@ -22,6 +22,7 @@ struct scheme *scheme_create(struct type *t) s->type = t; s->nvar = 0; s->var = NULL; + return s; } struct scheme *scheme_generalise(struct gamma *gamma, struct type *t) diff --git a/sem/hm/subst.c b/sem/hm/subst.c index 55d6cfe..e8fb35c 100644 --- a/sem/hm/subst.c +++ b/sem/hm/subst.c @@ -32,12 +32,12 @@ struct subst *subst_union(struct subst *l, struct subst *r) res->vars = safe_malloc(res->nvar*sizeof(char *)); res->types = safe_malloc(res->nvar*sizeof(struct type *)); for (int i = 0; invar; i++) { - res->vars[i] = l->vars[i]; - res->types[i] = l->types[i]; + res->vars[i] = safe_strdup(l->vars[i]); + res->types[i] = type_dup(l->types[i]); } for (int i = 0; invar; i++) { - res->vars[l->nvar+i] = r->vars[i]; - res->types[l->nvar+i] = subst_apply_t(l, r->types[i]); + res->vars[l->nvar+i] = safe_strdup(r->vars[i]); + res->types[l->nvar+i] = subst_apply_t(l, type_dup(r->types[i])); } return res; } @@ -112,13 +112,15 @@ void subst_print(struct subst *s, FILE *out) } } -void subst_free(struct subst *s, bool type) +void subst_free(struct subst *s) { if (s != NULL) { for (int i = 0; invar; i++) { free(s->vars[i]); - if (type) - type_free(s->types[i]); + type_free(s->types[i]); } + free(s->vars); + free(s->types); + free(s); } } diff --git a/sem/hm/subst.h b/sem/hm/subst.h index ffec9a7..0fba436 100644 --- a/sem/hm/subst.h +++ b/sem/hm/subst.h @@ -19,6 +19,6 @@ struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme); struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma); void subst_print(struct subst *s, FILE *out); -void subst_free(struct subst *s, bool type); +void subst_free(struct subst *s); #endif diff --git a/sem/scc.c b/sem/scc.c index 6d7b02e..f728f4f 100644 --- a/sem/scc.c +++ b/sem/scc.c @@ -271,7 +271,6 @@ struct ast *ast_scc(struct ast *ast) struct edge **edata = (struct edge **) list_to_array(edges, &nedges, false); - fprintf(stderr, "nfun: %d, ffun: %d, nedges: %d\n", nfun, ffun, nedges); // Do tarjan's and convert back into the declaration list struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata); // if (cs == NULL) diff --git a/util.c b/util.c index fa60ddc..6c70c63 100644 --- a/util.c +++ b/util.c @@ -152,6 +152,14 @@ void *safe_malloc(size_t size) return res; } +void *safe_realloc(void *ptr, size_t size) +{ + void *res = realloc(ptr, size); + if (res == NULL) + pdie("realloc"); + return res; +} + void *safe_strdup(const char *c) { size_t nchar = strlen(c); diff --git a/util.h b/util.h index ac47fa6..6704cf4 100644 --- a/util.h +++ b/util.h @@ -20,6 +20,7 @@ void pindent(int indent, FILE *out); void safe_fprintf(FILE *out, const char *msg, ...); void *safe_malloc(size_t size); +void *safe_realloc(void *ptr, size_t size); void *safe_strdup(const char *c); FILE *safe_fopen(const char *path, const char *mode); void safe_fclose(FILE *file); -- 2.20.1