if (s == NULL)
die("error inferring variable\n");
vardecl->type = subst_apply_t(s, t);
+ gamma_insert(gamma, vardecl->ident, scheme_create(vardecl->type));
subst_free(s);
return vardecl;
}
+/*
+ * Infer the type of one function declaration and record it in gamma.
+ *
+ * A type obtained from gamma_fresh is handed to infer_fundecl, the returned
+ * substitution is applied to it, and the result of scheme_generalise on that
+ * type is inserted into gamma under the function's identifier.
+ */
+void type_fundecl(struct gamma *gamma, struct fundecl *decl)
+{
+	struct type *fty = gamma_fresh(gamma);
+	struct subst *sub = infer_fundecl(gamma, decl, fty);
+
+	//Resolve the placeholder with what inference produced
+	struct type *resolved = subst_apply_t(sub, fty);
+	gamma_insert(gamma, decl->ident, scheme_generalise(gamma, resolved));
+
+	//The substitution and the resolved type are released here
+	type_free(resolved);
+	subst_free(sub);
+}
+
+/*
+ * Infer the types of the ndecls functions in decl as one group.
+ *
+ * 1. Every function gets a placeholder arrow type that is inserted into
+ *    gamma through scheme_create, so the bodies can refer to each other.
+ *    A declaration without rtype or without atypes gets types from
+ *    gamma_fresh, otherwise its annotations are used.
+ * 2. Each body is inferred with infer_fundecl; the substitutions are
+ *    accumulated with subst_union and applied to gamma after every function.
+ * 3. Each resolved type is passed through scheme_generalise, inserted into
+ *    gamma and written back to the declaration (atypes, natypes, rtype).
+ *
+ * NOTE(review): for an annotated declaration decl[i]->rtype is linked into
+ * fs[i] without type_dup, and the previous atypes array is overwritten in
+ * step 3 without being freed -- confirm the intended ownership.
+ */
+void type_comp(struct gamma *gamma, int ndecls, struct fundecl **decl)
+{
+	//Create a placeholder type for every function in the component
+	struct type **fs = safe_malloc(ndecls*sizeof(struct type *));
+	for (int i = 0; i<ndecls; i++) {
+		bool fresh = decl[i]->rtype == NULL || decl[i]->atypes == NULL;
+		fs[i] = fresh ? gamma_fresh(gamma) : decl[i]->rtype;
+		//Wrap from the last argument to the first so the outermost arrow
+		//holds argument 0: infer_fundecl and the write-back loop below
+		//both read the arguments from the outside in.
+		for (int j = decl[i]->nargs-1; j >= 0; j--) {
+			struct type *a = fresh ? gamma_fresh(gamma)
+				: type_dup(decl[i]->atypes[j]);
+			fs[i] = type_arrow(a, fs[i]);
+		}
+		gamma_insert(gamma, decl[i]->ident, scheme_create(fs[i]));
+	}
+
+	//Infer each function
+	struct subst *s0 = subst_id();
+	for (int i = 0; i<ndecls; i++) {
+		//Keep the pointer returned by subst_apply_t, as the other
+		//callers in this file do, so fs[i] never refers to a stale type
+		fs[i] = subst_apply_t(s0, fs[i]);
+		struct subst *s1 = infer_fundecl(gamma, decl[i], fs[i]);
+		s0 = subst_union(s1, s0);
+		subst_apply_g(s0, gamma);
+	}
+
+	//Generalise all functions and put in gamma
+	for (int i = 0; i<ndecls; i++) {
+		fs[i] = subst_apply_t(s0, fs[i]);
+		struct type *t = fs[i];
+		gamma_insert(gamma, decl[i]->ident, scheme_generalise(gamma, t));
+
+		//Write the resolved argument and return types back to the decl
+		decl[i]->atypes = safe_malloc(decl[i]->nargs*sizeof(struct type *));
+		decl[i]->natypes = decl[i]->nargs;
+		for (int j = 0; j<decl[i]->nargs; j++) {
+			decl[i]->atypes[j] = type_dup(t->data.tarrow.l);
+			t = t->data.tarrow.r;
+		}
+		decl[i]->rtype = type_dup(t);
+	}
+
+	//Free all types
+	for (int i = 0; i<ndecls; i++)
+		type_free(fs[i]);
+	free(fs);
+	subst_free(s0);
+}
+
struct ast *sem(struct ast *ast)
{
ast = ast_scc(ast);
//Infer if necessary
type_vardecl(gamma, ast->decls[i]->data.dvar);
break;
- case dfundecl: {
- struct type *f1 = gamma_fresh(gamma);
- struct subst *s = infer_fundecl(gamma,
- ast->decls[i]->data.dfun, f1);
- f1 = subst_apply_t(s, f1);
- gamma_insert(gamma, ast->decls[i]->data.dfun->ident,
- scheme_generalise(gamma, f1));
- gamma_print(gamma, stderr);
- fprintf(stderr, "done\n");
- subst_free(s);
- type_free(f1);
+ case dfundecl:
+ //Infer function as singleton component
+ type_comp(gamma, 1, &ast->decls[i]->data.dfun);
break;
- }
case dcomp:
+ //Infer function as singleton component
+ type_comp(gamma, ast->decls[i]->data.dcomp.ndecls,
+ ast->decls[i]->data.dcomp.decls);
break;
}
}
-
gamma_free(gamma);
return ast;
return NULL;
if (r->type == tvar && l->type != tvar)
return unify(loc, r, l);
- struct subst *s1, *s2, *s3;
- switch (l->type) {
- case tarrow:
- if (r->type == tarrow) {
- s1 = unify(loc, l->data.tarrow.l, r->data.tarrow.l);
- s2 = unify(loc, subst_apply_t(s1, l->data.tarrow.l),
- subst_apply_t(s1, r->data.tarrow.l));
- s3 = subst_union(s1, s2);
- return s3;
- }
- break;
- case tbasic:
- if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
- return subst_id();
- break;
- case tlist:
- if (r->type == tlist)
- return unify(loc, l->data.tlist, r->data.tlist);
- break;
- case ttuple:
- if (r->type == ttuple) {
- s1 = unify(loc, l->data.ttuple.l, r->data.ttuple.l);
- s2 = unify(loc, subst_apply_t(s1, l->data.ttuple.l),
- subst_apply_t(s1, r->data.ttuple.l));
- s3 = subst_union(s1, s2);
- subst_free(s1);
- subst_free(s2);
- return s3;
- }
- break;
- case tvar:
+ struct subst *s1, *s2;
+ if (l->type == tarrow && r->type == tarrow) {
+ s1 = unify(loc, l->data.tarrow.l, r->data.tarrow.l);
+ s2 = unify(loc, subst_apply_t(s1, l->data.tarrow.l),
+ subst_apply_t(s1, r->data.tarrow.l));
+ return subst_union(s2, s1);
+ } else if (l->type == tbasic && r->type == tbasic
+ && l->data.tbasic == r->data.tbasic) {
+ return subst_id();
+ } else if (l->type == tlist && r->type == tlist) {
+ return unify(loc, l->data.tlist, r->data.tlist);
+ } else if (l->type == ttuple && r->type == ttuple) {
+ s1 = unify(loc, l->data.ttuple.l, r->data.ttuple.l);
+ s2 = unify(loc, subst_apply_t(s1, l->data.ttuple.r),
+ subst_apply_t(s1, r->data.ttuple.r));
+ return subst_union(s2, s1);
+ } else if (l->type == tvar) {
if (r->type == tvar && strcmp(l->data.tvar, r->data.tvar) == 0)
return subst_id();
else if (occurs_check(l->data.tvar, r))
l->data.tvar);
else
return subst_singleton(l->data.tvar, r);
- break;
+ } else {
+ type_error(loc, false, "cannot unify ");
+ type_print(l, stderr);
+ fprintf(stderr, " with ");
+ type_print(r, stderr);
+ die("\n");
}
- type_error(loc, false, "cannot unify ");
- type_print(l, stderr);
- fprintf(stderr, " with ");
- type_print(r, stderr);
- die("\n");
return NULL;
}
{
struct subst *s1 = infer_expr(gamma, l, a1);
struct subst *s2 = infer_expr(subst_apply_g(s1, gamma), r, a2);
- struct subst *s3 = subst_union(s1, s2);
+ struct subst *s3 = subst_union(s2, s1);
struct subst *s4 = unify(l->loc, subst_apply_t(s3, sigma), rt);
- struct subst *s5 = subst_union(s3, s4);
- subst_free(s1);
- subst_free(s2);
- subst_free(s3);
- subst_free(s4);
- return s5;
+ return subst_union(s4, s3);
}
struct subst *infer_unop(struct gamma *gamma, struct expr *e,
{
struct subst *s1 = infer_expr(gamma, e, a);
struct subst *s2 = unify(e->loc, subst_apply_t(s1, sigma), rt);
- struct subst *s3 = subst_union(s1, s2);
- subst_free(s1);
- subst_free(s2);
- return s3;
+ return subst_union(s2, s1);
}
static struct type tybool = {.type=tbasic, .data={.tbasic=btbool}};
struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
{
- fprintf(stderr, "infer expr: ");
- expr_print(expr, stderr);
- fprintf(stderr, "\ngamma: ");
- gamma_print(gamma, stderr);
- fprintf(stderr, "\ntype: ");
- type_print(type, stderr);
- fprintf(stderr, "\n");
-
-#define infbinop(e, a1, a2, rt) infer_binop(\
+#define infbinop(e, a1, a2, rt, type) infer_binop(\
gamma, e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, type)
- struct subst *s0;
+ struct subst *s0, *s1;
struct type *f1, *f2, *f3;
struct scheme *s;
switch (expr->type) {
switch (expr->data.ebinop.op) {
case binor:
case binand:
- return infbinop(expr, &tybool, &tybool, &tybool);
+ return infbinop(expr, &tybool, &tybool, &tybool, type);
case eq:
case neq:
case leq:
case geq:
case ge:
f1 = gamma_fresh(gamma);
- s0 = infbinop(expr, f1, f1, &tybool);
+ s0 = infbinop(expr, f1, f1, &tybool, type);
type_free(f1);
return s0;
case cons:
f1 = gamma_fresh(gamma);
f2 = type_list(f1);
- s0 = infbinop(expr, f1, f2, f2);
+ s0 = infbinop(expr, f1, f2, f2, type);
type_free(f2);
return s0;
case plus:
case divide:
case modulo:
case power:
- return infbinop(expr, &tyint, &tyint, &tyint);
+ return infbinop(expr, &tyint, &tyint, &tyint, type);
}
break;
case echar:
type_error(expr->loc, "Unbound function: %s\n"
, expr->data.efuncall.ident);
struct type *t = scheme_instantiate(gamma, s);
+
struct subst *s0 = subst_id();
+ //Infer args
for (int i = 0; i<expr->data.efuncall.nargs; i++) {
if (t->type != tarrow)
type_error(expr->loc, true,
"too many arguments to %s\n",
expr->data.efuncall.ident);
- struct subst *s1 = infer_expr(gamma,
+ s1 = infer_expr(gamma,
expr->data.efuncall.args[i], t->data.tarrow.l);
- struct subst *s2 = s0;
- s0 = subst_union(s2, s1);
- subst_free(s1);
- subst_free(s2);
- t = t->data.tarrow.r;
+ s0 = subst_union(s1, s0);
+ subst_apply_g(s0, gamma);
+ t = subst_apply_t(s0, t->data.tarrow.r);
}
if (t->type == tarrow)
type_error(expr->loc, true,
"not enough arguments to %s\n",
expr->data.efuncall.ident);
+
+ //Infer return type
+ s1 = unify(expr->loc, t, type);
+ s0 = subst_union(s1, s0);
type_free(t);
//TODO fields
return s0;
+/*
+ * Infer one statement and return the resulting substitution.
+ *
+ * type is forwarded to infer_expr for the expression of a return statement
+ * and to the recursive calls for statements nested in if/while.
+ *  - sif, swhile: the predicate is inferred against tybool; then every
+ *    nested statement is inferred in order, each result being merged with
+ *    subst_union and applied to gamma.
+ *  - sexpr: the expression is inferred against a type from gamma_fresh,
+ *    which is freed afterwards.
+ *  - sassign, svardecl: no inference here; they break out of the switch
+ *    and subst_id() is returned.
+ *
+ * NOTE(review): s2 is declared but not used in this function.
+ */
struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type)
{
- fprintf(stderr, "infer stmt: ");
- stmt_print(stmt, 0, stderr);
- fprintf(stderr, "\ngamma: ");
- gamma_print(gamma, stderr);
- fprintf(stderr, "\ntype: ");
- type_print(type, stderr);
- fprintf(stderr, "\n");
-
- struct subst *s1;//, *s2;
-// struct type *f1, *f2, *f3;
-// struct scheme *s;
+ struct subst *s0, *s1, *s2;
+ struct type *f1;
switch (stmt->type) {
case sassign:
break;
case sif:
- s1 = infer_expr(gamma, stmt->data.sif.pred, &tybool);
- subst_apply_g(s1, gamma);
- subst_free(s1);
- break;
+ s0 = infer_expr(gamma, stmt->data.sif.pred, &tybool);
+ subst_apply_g(s0, gamma);
+
+ for (int i = 0; i<stmt->data.sif.nthen; i++) {
+ s1 = infer_stmt(gamma, stmt->data.sif.then[i], type);
+ s0 = subst_union(s1, s0);
+ subst_apply_g(s0, gamma);
+ }
+
+ for (int i = 0; i<stmt->data.sif.nels; i++) {
+ s1 = infer_stmt(gamma, stmt->data.sif.els[i], type);
+ s0 = subst_union(s1, s0);
+ subst_apply_g(s0, gamma);
+ }
+ return s0;
case sreturn:
return infer_expr(gamma, stmt->data.sreturn, type);
case sexpr:
- break;
+ f1 = gamma_fresh(gamma);
+ s0 = infer_expr(gamma, stmt->data.sexpr, f1);
+ type_free(f1);
+ return s0;
case svardecl:
break;
case swhile:
- break;
+ s0 = infer_expr(gamma, stmt->data.swhile.pred, &tybool);
+ subst_apply_g(s0, gamma);
+
+ for (int i = 0; i<stmt->data.swhile.nbody; i++) {
+ s1 = infer_stmt(gamma, stmt->data.swhile.body[i], type);
+ s0 = subst_union(s1, s0);
+ subst_apply_g(s0, gamma);
+ }
+
+ return s0;
}
return subst_id();
-// union {
-// struct {
-// char *ident;
-// int nfields;
-// char **fields;
-// struct expr *expr;
-// } sassign;
-// struct {
-// struct expr *pred;
-// int nthen;
-// struct stmt **then;
-// int nels;
-// struct stmt **els;
-// } sif;
-// struct vardecl *svardecl;
-// struct expr *sreturn;
-// struct expr *sexpr;
-// struct {
-// struct expr *pred;
-// int nbody;
-// struct stmt **body;
-// } swhile;
-// } data;
}
-struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct type *fty)
+/*
+ * Infer the body of fundecl against the function type ftype.
+ *
+ * ftype is walked as a chain of arrows: the left side of the i-th arrow is
+ * inserted into gamma (through scheme_create) under args[i], and whatever
+ * remains after nargs arrows is the type passed to infer_stmt for every
+ * statement of the body. The function dies with "malformed ftype" when an
+ * arrow is missing for an argument or when the remainder is still an arrow.
+ *
+ * Returns the accumulated substitution; it has also been applied to gamma
+ * after each statement.
+ */
+struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct type *ftype)
{
- fprintf(stderr, "inferring function to type ");
- type_print(fty, stderr);
- fprintf(stderr, " with gamma ");
- gamma_print(gamma, stderr);
- fprintf(stderr, "\n");
- if (fundecl->rtype == NULL || fundecl->atypes == NULL) {
- fundecl->rtype = gamma_fresh(gamma);
- fundecl->atypes = safe_realloc(fundecl->atypes,
- fundecl->nargs*sizeof(struct type));
- for (int i = 0; i<fundecl->nargs; i++)
- fundecl->atypes[i] = gamma_fresh(gamma);
- }
-
- struct type *ftype = type_dup(fundecl->rtype);
+ // Put arguments in gamma
+ struct type *at = ftype;
for (int i = 0; i<fundecl->nargs; i++) {
- ftype = type_arrow(type_dup(fundecl->atypes[i]), ftype);
+ if (at->type != tarrow)
+ die("malformed ftype\n");
gamma_insert(gamma, fundecl->args[i],
- scheme_create(fundecl->atypes[i]));
+ scheme_create(at->data.tarrow.l));
+ at = at->data.tarrow.r;
}
- gamma_insert(gamma, fundecl->ident, scheme_create(ftype));
+ if (at->type == tarrow)
+ die("malformed ftype\n");
struct subst *s = subst_id();
for (int i = 0; i<fundecl->nbody; i++) {
struct subst *s1 = infer_stmt(gamma,
- fundecl->body[i], fundecl->rtype);
- struct subst *s2 = s;
- s = subst_union(s2, s1);
- subst_free(s1);
- subst_free(s2);
+ fundecl->body[i], at);
+ s = subst_union(s1, s);
+ subst_apply_g(s, gamma);
}
- for (int i = 0; i<fundecl->nargs; i++) {
- struct type *t = subst_apply_t(s, fundecl->atypes[i]);
- type_free(fundecl->atypes[i]);
- fundecl->atypes[i] = t;
- }
- fundecl->rtype = subst_apply_t(s, fundecl->rtype);
-
- ftype = subst_apply_t(s, ftype);
- struct subst *r = unify(fundecl->loc, fty, ftype);
- type_free(ftype);
- subst_free(s);
- return r;
+ return s;
}
#include "../hm.h"
+#define INCAP 10
+
+/*
+ * Allocate an empty substitution: no bindings, with the vars and types
+ * arrays pre-allocated for INCAP entries.
+ */
struct subst *subst_id()
{
struct subst *res = safe_malloc(sizeof(struct subst));
res->nvar = 0;
- res->vars = NULL;
- res->types = NULL;
+ res->capacity = INCAP;
+ res->vars = safe_malloc(INCAP*sizeof(char *));
+ res->types = safe_malloc(INCAP*sizeof(struct type *));
return res;
}
+/*
+ * Bind ident to a copy of t in s and return s.
+ *
+ * If s already has a binding with the same name, only its type is replaced;
+ * otherwise a new entry is appended, with the name copied by safe_strdup and
+ * the type copied by type_dup. The vars and types arrays are doubled together
+ * when they are full.
+ */
+struct subst *subst_insert(struct subst *s, char *ident, struct type *t)
+{
+	//Overwrite an existing binding in place
+	for (int i = 0; i<s->nvar; i++) {
+		if (strcmp(s->vars[i], ident) == 0) {
+			//The stored name already equals ident, so it is kept.
+			//Duplicate before freeing so t may alias the old type.
+			struct type *replacement = type_dup(t);
+			type_free(s->types[i]);
+			s->types[i] = replacement;
+			return s;
+		}
+	}
+
+	//Grow both arrays, each from its own old buffer
+	if (s->nvar >= s->capacity) {
+		s->capacity += s->capacity;
+		s->vars = safe_realloc(s->vars, s->capacity*sizeof(char *));
+		s->types = safe_realloc(s->types, s->capacity*sizeof(struct type *));
+	}
+	s->vars[s->nvar] = safe_strdup(ident);
+	s->types[s->nvar] = type_dup(t);
+	s->nvar++;
+	return s;
+}
+
+/*
+ * Build a substitution containing the single binding ident -> t, by
+ * inserting it into a new substitution from subst_id.
+ */
struct subst *subst_singleton(char *ident, struct type *t)
{
- struct subst *res = safe_malloc(sizeof(struct subst));
- res->nvar = 1;
- res->vars = safe_malloc(sizeof(char *));
- res->vars[0] = safe_strdup(ident);
- res->types = safe_malloc(sizeof(struct type *));
- res->types[0] = type_dup(t);
+ struct subst *res = subst_id();
+ subst_insert(res, ident, t);
return res;
}
-struct subst *subst_union(struct subst *l, struct subst *r)
+/*
+ * Merge s1 into s2 and return s2.
+ *
+ * Every type already bound in s2 is first rewritten with s1, then each
+ * binding of s1 is inserted into s2 with subst_insert. s2 is modified in
+ * place; s1 is freed here, so the caller must not use s1 afterwards.
+ */
+struct subst *subst_union(struct subst *s1, struct subst *s2)
{
- if (l == NULL || r == NULL)
- return NULL;
- struct subst *res = safe_malloc(sizeof(struct subst));
- res->nvar = l->nvar+r->nvar;
- res->vars = safe_malloc(res->nvar*sizeof(char *));
- res->types = safe_malloc(res->nvar*sizeof(struct type *));
- for (int i = 0; i<l->nvar; i++) {
- res->vars[i] = safe_strdup(l->vars[i]);
- res->types[i] = type_dup(l->types[i]);
- }
- for (int i = 0; i<r->nvar; i++) {
- res->vars[l->nvar+i] = safe_strdup(r->vars[i]);
- res->types[l->nvar+i] = subst_apply_t(l, type_dup(r->types[i]));
- }
- return res;
+ //Apply s1 on s2
+ for (int i = 0; i<s2->nvar; i++)
+ s2->types[i] = subst_apply_t(s1, s2->types[i]);
+ //Insert s1 into s2
+ for (int i = 0; i<s1->nvar; i++)
+ subst_insert(s2, s1->vars[i], s1->types[i]);
+ subst_free(s1);
+ return s2;
}
struct type *subst_apply_t(struct subst *subst, struct type *l)
+/*
+ * Apply subst to the type of scheme, skipping the variables the scheme
+ * lists in scheme->var. A temporary substitution holding only the
+ * remaining bindings is built, applied to scheme->type and then freed.
+ * scheme is modified in place and returned.
+ */
struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme)
{
- for (int i = 0; i<scheme->nvar; i++) {
- for (int j = 0; j<subst->nvar; j++) {
- if (strcmp(scheme->var[i], subst->vars[j]) != 0) {
- struct subst *t = subst_singleton(
- subst->vars[j], subst->types[j]);
- scheme->type = subst_apply_t(t, scheme->type);
+ //remove the quantified ones from the subst
+ struct subst *s = subst_id();
+ for (int j = 0; j<subst->nvar; j++) {
+ bool found = false;
+ for (int i = 0; i<scheme->nvar; i++) {
+ if (strcmp(scheme->var[i], subst->vars[j]) == 0) {
+ found = true;
}
}
+ if (!found)
+ subst_insert(s, subst->vars[j], subst->types[j]);
}
+ scheme->type = subst_apply_t(s, scheme->type);
+ //s holds its own copies made by subst_insert; release it
+ subst_free(s);
return scheme;
}