{
int nftv = 0;
char **ftv = NULL;
+ bool res = false;
type_ftv(r, &nftv, &ftv);
- for (int i = 0; i<nftv; i++)
- if (strcmp(ftv[i], var) == 0)
- return true;
- return false;
+ for (int i = 0; i<nftv; i++) {
+ if (strcmp(ftv[i], var) == 0) {
+ res = true;
+ break;
+ }
+ }
+ free(ftv);
+ return res;
}
struct subst *unify(struct type *l, struct type *r)
return NULL;
if (r->type == tvar && l->type != tvar)
return unify(r, l);
- struct subst *s1, *s2;
+ struct subst *s1, *s2, *s3;
switch (l->type) {
case tarrow:
if (r->type == tarrow) {
s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
s2 = unify(subst_apply_t(s1, l->data.tarrow.l),
subst_apply_t(s1, r->data.tarrow.l));
- return subst_union(s1, s2);
+ s3 = subst_union(s1, s2);
+ return s3;
}
break;
case tbasic:
s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
s2 = unify(subst_apply_t(s1, l->data.ttuple.l),
subst_apply_t(s1, r->data.ttuple.l));
- return subst_union(s1, s2);
+ s3 = subst_union(s1, s2);
+ subst_free(s1);
+ subst_free(s2);
+ return s3;
}
break;
case tvar:
return NULL;
}
+struct subst *unifyfree(struct type *l, struct type *r, bool freel, bool freer)
+{
+ struct subst *s = unify(l, r);
+ if (freel)
+ type_free(l);
+ if (freer)
+ type_free(r);
+ return s;
+}
+
+struct subst *infer_binop(struct gamma *gamma, struct expr *l, struct expr *r,
+ struct type *a1, struct type *a2, struct type *rt, struct type *sigma)
+{
+ struct subst *s1 = infer_expr(gamma, l, a1);
+ struct subst *s2 = infer_expr(subst_apply_g(s1, gamma), r, a2);
+ struct subst *s3 = subst_union(s1, s2);
+ struct subst *s4 = unify(subst_apply_t(s3, sigma), rt);
+ struct subst *s5 = subst_union(s3, s4);
+ subst_free(s1);
+ subst_free(s2);
+ subst_free(s3);
+ subst_free(s4);
+ return s5;
+}
+
+struct subst *infer_unop(struct gamma *gamma, struct expr *e,
+ struct type *a, struct type *rt, struct type *sigma)
+{
+ struct subst *s1 = infer_expr(gamma, e, a);
+ struct subst *s2 = unify(subst_apply_t(s1, sigma), rt);
+ struct subst *s3 = subst_union(s1, s2);
+ subst_free(s1);
+ subst_free(s2);
+ return s3;
+}
+
+static struct type tybool = {.type=tbasic, .data={.tbasic=btbool}};
+static struct type tychar = {.type=tbasic, .data={.tbasic=btchar}};
+static struct type tyint = {.type=tbasic, .data={.tbasic=btint}};
+static struct type tystring = {.type=tlist, .data={.tlist=&tychar}};
+
struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
{
fprintf(stderr, "infer expr: ");
type_print(type, stderr);
fprintf(stderr, "\n");
-#define infbop(l, r, a1, a2, rt, sigma) {\
- s1 = infer_expr(gamma, l, a1);\
- s2 = subst_union(s1, infer_expr(subst_apply_g(s1, gamma), r, a2));\
- return subst_union(s2, unify(subst_apply_t(s2, sigma), rt));\
-}
-#define infbinop(e, a1, a2, rt, sigma)\
- infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
- struct subst *s1, *s2;
- struct type *f1, *f2;
+#define infbinop(e, a1, a2, rt) infer_binop(\
+ gamma, e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, type)
+ struct subst *s1;
+ struct type *f1, *f2, *f3;
struct scheme *s;
switch (expr->type) {
case ebool:
- return unify(type_basic(btbool), type);
+ return unify(&tybool, type);
case ebinop:
switch (expr->data.ebinop.op) {
case binor:
case binand:
- infbinop(expr, type_basic(btbool), type_basic(btbool),
- type_basic(btbool), type);
+ return infbinop(expr, &tybool, &tybool, &tybool);
case eq:
case neq:
case leq:
case geq:
case ge:
f1 = gamma_fresh(gamma);
- infbinop(expr, f1, f1, type_basic(btbool), type);
+ s1 = infbinop(expr, f1, f1, &tybool);
+ type_free(f1);
+ return s1;
case cons:
f1 = gamma_fresh(gamma);
- infbinop(expr, f1, type_list(f1), type_list(f1), type);
+ f2 = type_list(f1);
+ s1 = infbinop(expr, f1, f2, f2);
+ type_free(f2);
+ return s1;
case plus:
case minus:
case times:
case divide:
case modulo:
case power:
- infbinop(expr, type_basic(btint), type_basic(btint),
- type_basic(btint), type);
+ return infbinop(expr, &tyint, &tyint, &tyint);
}
break;
case echar:
- return unify(type_basic(btchar), type);
+ return unify(&tychar, type);
case efuncall:
if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
die("Unbound function: %s\n", expr->data.efuncall.ident);
//TODO fields
return NULL;
case eint:
- return unify(type_basic(btint), type);
+ return unify(&tyint, type);
case eident:
if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
die("Unbound variable: %s\n", expr->data.eident.ident);
- //TODO fields
- return unify(scheme_instantiate(gamma, s), type);
+ f1 = scheme_instantiate(gamma, s);
+ s1 = unify(f1, type);
+ type_free(f1);
+ //TODO field
+ return s1;
case enil:
f1 = gamma_fresh(gamma);
- return unify(type_list(f1), type);
+ return unifyfree(type_list(f1), type, true, false);
case etuple:
f1 = gamma_fresh(gamma);
f2 = gamma_fresh(gamma);
- infbop(expr->data.etuple.left, expr->data.etuple.right,
- f1, f2, type_tuple(f1, f2), type);
+ f3 = type_tuple(f1, f2);
+ s1 = infer_binop(gamma, expr->data.etuple.left,
+ expr->data.etuple.right, f1, f2, f3, type);
+ type_free(f3);
+ return s1;
case estring:
- return unify(type_list(type_basic(btchar)), type);
+ return unify(&tystring, type);
case eunop:
- switch(expr->data.eunop.op) {
+ switch (expr->data.eunop.op) {
case negate:
- s1 = infer_expr(gamma,
- expr->data.eunop.l, type_basic(btint));
- if (s1 == NULL)
- return NULL;
- return subst_union(s1,
- unify(subst_apply_t(s1, type),
- type_basic(btint)));
+ return infer_unop(gamma, expr->data.eunop.l,
+ &tyint, &tyint, type);
case inverse:
- s1 = infer_expr(gamma,
- expr->data.eunop.l, type_basic(btbool));
- return subst_union(s1,
- unify(subst_apply_t(s1, type),
- type_basic(btbool)));
+ return infer_unop(gamma, expr->data.eunop.l,
+ &tybool, &tybool, type);
}
}
return NULL;
}
+
+struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type)
+{
+ fprintf(stderr, "infer stmt: ");
+ stmt_print(stmt, 0, stderr);
+ fprintf(stderr, "\ngamma: ");
+ gamma_print(gamma, stderr);
+ fprintf(stderr, "\ntype: ");
+ type_print(type, stderr);
+ fprintf(stderr, "\n");
+
+// struct subst *s1, *s2;
+// struct type *f1, *f2, *f3;
+// struct scheme *s;
+ switch (stmt->type) {
+ case sassign:
+ break;
+ case sif:
+ break;
+ case sreturn:
+ return infer_expr(gamma, stmt->data.sreturn, type);
+ case sexpr:
+ break;
+ case svardecl:
+ break;
+ case swhile:
+ break;
+ }
+ return subst_id();
+// union {
+// struct {
+// char *ident;
+// int nfields;
+// char **fields;
+// struct expr *expr;
+// } sassign;
+// struct {
+// struct expr *pred;
+// int nthen;
+// struct stmt **then;
+// int nels;
+// struct stmt **els;
+// } sif;
+// struct vardecl *svardecl;
+// struct expr *sreturn;
+// struct expr *sexpr;
+// struct {
+// struct expr *pred;
+// int nbody;
+// struct stmt **body;
+// } swhile;
+// } data;
+}
+
+struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl)
+{
+ //struct type *t;
+ if (fundecl->rtype == NULL || fundecl->atypes == NULL) {
+ fundecl->rtype = gamma_fresh(gamma);
+ fundecl->atypes = safe_realloc(fundecl->atypes, fundecl->nargs*sizeof(struct type));
+ for (int i = 0; i<fundecl->nargs; i++)
+ fundecl->atypes[i] = gamma_fresh(gamma);
+ }
+ fprintf(stderr, "fundecl with type: ");
+ for (int i = 0; i<fundecl->nargs; i++) {
+ type_print(fundecl->atypes[i], stderr);
+ fprintf(stderr, " ");
+ }
+ fprintf(stderr, "-> ");
+ type_print(fundecl->rtype, stderr);
+ fprintf(stderr, "\n");
+
+ for (int i = 0; i<fundecl->nargs; i++)
+ gamma_insert(gamma, fundecl->args[i],
+ scheme_create(fundecl->atypes[i]));
+
+ struct subst *s = subst_id();
+ for (int i = 0; i<fundecl->nbody; i++) {
+ struct subst *s1 = infer_stmt(gamma, fundecl->body[i], fundecl->rtype);
+ struct subst *s2 = s;
+ s = subst_union(s2, s1);
+ subst_free(s1);
+ subst_free(s2);
+ }
+
+ fprintf(stderr, "inferred function substitution: ");
+ subst_print(s, stderr);
+
+ for (int i = 0; i<fundecl->nargs; i++)
+ fundecl->atypes[i] = subst_apply_t(s, fundecl->atypes[i]);
+ fundecl->rtype = subst_apply_t(s, fundecl->rtype);
+ fprintf(stderr, "fundecl with type: ");
+ for (int i = 0; i<fundecl->nargs; i++) {
+ type_print(fundecl->atypes[i], stderr);
+ fprintf(stderr, " ");
+ }
+ fprintf(stderr, "-> ");
+ type_print(fundecl->rtype, stderr);
+ fprintf(stderr, "\n");
+ //char *ident;
+ //int nargs;
+ //char **args;
+ //int natypes;
+ //struct type **atypes;
+ //struct type *rtype;
+ //int nbody;
+ //struct stmt **body;
+
+ return s;
+}