+#include <stdlib.h>
+#include <string.h>
+
+#include "hm.h"
+#include "../util.h"
+#include "../ast.h"
+
+struct gamma {
+ int fresh;
+ int nschemes;
+ char **vars;
+ struct scheme *schemes;
+};
+
+struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
+{
+ for (int i = 0; i<nschemes; i++) {
+ if (strcmp(ident, gamma->vars[i]) == 0) {
+ //TODO
+ }
+ }
+ return NULL;
+}
+
+struct type *fresh(struct gamma *gamma)
+{
+ char *buf = safe_malloc(10);
+ sprintf(buf, "%d", gamma->fresh);
+ gamma->fresh++;
+ return type_var(buf);
+}
+
+void ftv_type(struct type *r, int *nftv, char **ftv)
+{
+ switch (r->type) {
+ case tbasic:
+ break;
+ case tlist:
+ ftv_type(r->data.tlist, nftv, ftv);
+ break;
+ case ttuple:
+ ftv_type(r->data.ttuple.l, nftv, ftv);
+ ftv_type(r->data.ttuple.r, nftv, ftv);
+ break;
+ case tvar:
+ *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
+ if (*ftv == NULL)
+ perror("realloc");
+ ftv[(*nftv)++] = r->data.tvar;
+ break;
+ }
+}
+
+bool occurs_check(char *var, struct type *r)
+{
+ int nftv = 0;
+ char *ftv = NULL;
+ ftv_type(r, &nftv, &ftv);
+ for (int i = 0; i<nftv; i++)
+ if (strcmp(ftv+i, var) == 0)
+ return true;
+ return false;
+}
+
+struct type *dup_type(struct type *r)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ *res = *r;
+ switch (r->type) {
+ case tbasic:
+ break;
+ case tlist:
+ res->data.tlist = dup_type(r->data.tlist);
+ break;
+ case ttuple:
+ res->data.ttuple.l = dup_type(r->data.ttuple.l);
+ res->data.ttuple.r = dup_type(r->data.ttuple.r);
+ break;
+ case tvar:
+ res->data.tvar = strdup(r->data.tvar);
+ break;
+ }
+ return res;
+}
+
+struct type *subst_apply_t(struct substitution *subst, struct type *l)
+{
+ switch (l->type) {
+ case tbasic:
+ break;
+ case tlist:
+ l->data.tlist = subst_apply_t(subst, l->data.tlist);
+ break;
+ case ttuple:
+ l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
+ l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
+ break;
+ case tvar:
+ for (int i = 0; i<subst->nvar; i++) {
+ if (strcmp(subst->vars[i], l->data.tvar) == 0) {
+ free(l->data.tvar);
+ struct type *r = subst->types[i];
+ *l = *r;
+ free(r);
+ }
+ }
+ break;
+ }
+ return l;
+}
+struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma)
+{
+ //TODO
+ return gamma;
+}
+
+void subst_print(struct substitution *s, FILE *out)
+{
+ if (s == NULL) {
+ fprintf(out, "no substitution\n");
+ } else {
+ fprintf(out, "[");
+ for (int i = 0; i<s->nvar; i++) {
+ fprintf(out, "%s->", s->vars[i]);
+ type_print(s->types[i], out);
+ if (i + 1 < s->nvar)
+ fprintf(out, ", ");
+ }
+ fprintf(out, "]\n");
+ }
+}
+
+struct substitution *subst_id()
+{
+ struct substitution *res = safe_malloc(sizeof(struct substitution));
+ res->nvar = 0;
+ res->vars = NULL;
+ res->types = NULL;
+ return res;
+}
+
+struct substitution *subst_singleton(char *ident, struct type *t)
+{
+ struct substitution *res = safe_malloc(sizeof(struct substitution));
+ res->nvar = 1;
+ res->vars = safe_malloc(sizeof(char *));
+ res->vars[0] = safe_strdup(ident);
+ res->types = safe_malloc(sizeof(struct type *));
+ res->types[0] = dup_type(t);
+ return res;
+}
+
+struct substitution *subst_union(struct substitution *l, struct substitution *r)
+{
+ struct substitution *res = safe_malloc(sizeof(struct substitution));
+ res->nvar = l->nvar+r->nvar;
+ res->vars = safe_malloc(res->nvar*sizeof(char *));
+ res->types = safe_malloc(res->nvar*sizeof(struct type *));
+ for (int i = 0; i<l->nvar; i++) {
+ res->vars[i] = l->vars[i];
+ res->types[i] = l->types[i];
+ }
+ for (int i = 0; i<r->nvar; i++) {
+ res->vars[l->nvar+i] = l->vars[i];
+ res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
+ }
+ return res;
+}
+
+struct substitution *unify(struct type *l, struct type *r) {
+ switch (l->type) {
+ case tbasic:
+ if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
+ return subst_id();
+ break;
+ case tlist:
+ if (r->type == tlist)
+ return unify(l->data.tlist, r->data.tlist);
+ break;
+ case ttuple:
+ if (r->type == ttuple) {
+ struct substitution *s1 = unify(
+ l->data.ttuple.l,
+ r->data.ttuple.l);
+ struct substitution *s2 = unify(
+ subst_apply_t(s1, l->data.ttuple.l),
+ subst_apply_t(s1, r->data.ttuple.l));
+ return subst_union(s1, s2);
+ }
+ break;
+ case tvar:
+ if (r->type == tvar && strcmp(l->data.tvar, r->data.tvar) == 0)
+ return subst_id();
+ else if (occurs_check(l->data.tvar, r))
+ fprintf(stderr, "Infinite type %s\n", l->data.tvar);
+ else
+ return subst_singleton(l->data.tvar, r);
+ break;
+ }
+ return NULL;
+}
+
+struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
+{
+
+#define infbop(l, r, a1, a2, rt, sigma) {\
+ s1 = infer_expr(gamma, l, a1);\
+ s2 = subst_union(s1, infer_expr(subst_apply_g(s1, gamma), r, a2));\
+ return subst_union(s2, unify(subst_apply_t(s2, sigma), rt));\
+}
+#define infbinop(e, a1, a2, rt, sigma)\
+ infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
+ struct substitution *s1, *s2;
+ struct type *f1, *f2;
+ switch (expr->type) {
+ case ebool:
+ return unify(type_basic(btbool), type);
+ case ebinop:
+ switch(expr->data.ebinop.op) {
+ case binor:
+ case binand:
+ infbinop(expr, type_basic(btbool), type_basic(btbool),
+ type_basic(btbool), type);
+ case eq:
+ case neq:
+ case leq:
+ case le:
+ case geq:
+ case ge:
+ f1 = fresh(gamma);
+ infbinop(expr, f1, f1, type_basic(btbool), type);
+ case cons:
+ f1 = fresh(gamma);
+ infbinop(expr, f1, type_list(f1), type_list(f1), type);
+ case plus:
+ case minus:
+ case times:
+ case divide:
+ case modulo:
+ case power:
+ infbinop(expr, type_basic(btint), type_basic(btint),
+ type_basic(btint), type);
+ }
+ case echar:
+ return unify(type_basic(btchar), type);
+ case efuncall:
+ //TODO
+ return NULL;
+ case eint:
+ return unify(type_basic(btint), type);
+ case eident:
+
+ //TODO
+ return NULL;
+ case etuple:
+ f1 = fresh(gamma);
+ f2 = fresh(gamma);
+ infbop(expr->data.etuple.left, expr->data.etuple.right,
+ f1, f2, type_tuple(f1, f2), type);
+ case estring:
+ return unify(type_list(type_basic(btchar)), type);
+ case eunop:
+ switch(expr->data.eunop.op) {
+ case negate:
+ s1 = infer_expr(gamma,
+ expr->data.eunop.l, type_basic(btint));
+ return subst_union(s1,
+ unify(subst_apply_t(s1, type),
+ type_basic(btint)));
+ case inverse:
+ s1 = infer_expr(gamma,
+ expr->data.eunop.l, type_basic(btbool));
+ return subst_union(s1,
+ unify(subst_apply_t(s1, type),
+ type_basic(btbool)));
+ }
+ }
+
+}