From ea9d370cc62d7ba66213d4d284de7c379f34d1f8 Mon Sep 17 00:00:00 2001
From: Mart Lubbers <mart@martlubbers.net>
Date: Wed, 17 Feb 2021 11:43:04 +0100
Subject: [PATCH] work on type inference some more

---
 Makefile                 |  15 ++-
 ast.c                    | 107 +------------------
 ast.h                    |  22 +---
 sem.c                    |  57 ++++++----
 sem/hm.c                 | 223 ++++++++++-----------------------------
 sem/hm.h                 |  16 +--
 sem/hm/gamma.c           |  62 +++++++++++
 sem/hm/gamma.h           |  24 +++++
 sem/hm/scheme.c          |  68 ++++++++++++
 sem/hm/scheme.h          |  19 ++++
 sem/hm/subst.c           | 124 ++++++++++++++++++++++
 sem/hm/subst.h           |  24 +++++
 sem/scc.c                |   5 +-
 test/Makefile            |  19 ++++
 test/test_sem_hm_gamma.c |  97 +++++++++++++++++
 type.c                   | 183 ++++++++++++++++++++++++++++++++
 type.h                   |  36 +++++++
 util.c                   |   4 +-
 18 files changed, 771 insertions(+), 334 deletions(-)
 create mode 100644 sem/hm/gamma.c
 create mode 100644 sem/hm/gamma.h
 create mode 100644 sem/hm/scheme.c
 create mode 100644 sem/hm/scheme.h
 create mode 100644 sem/hm/subst.c
 create mode 100644 sem/hm/subst.h
 create mode 100644 test/Makefile
 create mode 100644 test/test_sem_hm_gamma.c
 create mode 100644 type.c
 create mode 100644 type.h

diff --git a/Makefile b/Makefile
index 74272aa..5328897 100644
--- a/Makefile
+++ b/Makefile
@@ -1,9 +1,10 @@
-CFLAGS+=-Wall -Wextra -std=c99 -pedantic -D_XOPEN_SOURCE=700 -ggdb
+CFLAGS+=-Wall -Wextra -std=c99 -pedantic -ggdb
 YFLAGS+=-d --locations -v --defines=parse.h
 LFLAGS+=--header-file=scan.h
 
-OBJECTS:=scan.o parse.o ast.o util.o list.o sem.o genc.o \
-	sem/scc.o sem/hm.o
+OBJECTS:=scan.o parse.o ast.o type.o util.o list.o sem.o genc.o \
+	sem/scc.o\
+	$(addprefix sem/hm, .o /gamma.o /subst.o /scheme.o)
 
 all: splc
 splc: $(OBJECTS)
@@ -11,5 +12,13 @@ scan.c: scan.l parse.h
 parse.h: parse.c
 expr.c: y.tab.h
 
+scan.o: CFLAGS+=-D_XOPEN_SOURCE=700
+
+.PHONY: test
+
+test:
+	CFLAGS="$(CFLAGS)" $(MAKE) -C test test
+
 clean:
 	$(RM) $(OBJECTS) y.output parse.h scan.h scan.c parse.c expr a.c
+	$(MAKE) -C test clean
diff --git a/ast.c b/ast.c
index bb8a541..1e93a71 100644
--- a/ast.c
+++ b/ast.c
@@ -4,6 +4,7 @@
 
 #include "util.h"
 #include "ast.h"
+#include "type.h"
 #include "list.h"
 #include "parse.h"
 
@@ -16,10 +17,6 @@ const char *binop_str[] = {
 const char *fieldspec_str[] = {
 	[fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"};
 const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
-static const char *basictype_str[] = {
-	[btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
-	[btvoid] = "Void",
-};
 
 struct ast *ast(struct list *decls)
 {
@@ -237,57 +234,6 @@ struct expr *expr_unop(enum unop op, struct expr *l)
 	return res;
 }
 
-struct type *type_basic(enum basictype type)
-{
-	struct type *res = safe_malloc(sizeof(struct type));
-	res->type = tbasic;
-	res->data.tbasic = type;
-	return res;
-}
-
-struct type *type_list(struct type *type)
-{
-	struct type *res = safe_malloc(sizeof(struct type));
-	res->type = tlist;
-	res->data.tlist = type;
-	return res;
-}
-
-struct type *type_tuple(struct type *l, struct type *r)
-{
-	struct type *res = safe_malloc(sizeof(struct type));
-	res->type = ttuple;
-	res->data.ttuple.l = l;
-	res->data.ttuple.r = r;
-	return res;
-}
-
-struct type *type_var(char *ident)
-{
-	struct type *res = safe_malloc(sizeof(struct type));
-	if (strcmp(ident, "Int") == 0) {
-		res->type = tbasic;
-		res->data.tbasic = btint;
-		free(ident);
-	} else if (strcmp(ident, "Char") == 0) {
-		res->type = tbasic;
-		res->data.tbasic = btchar;
-		free(ident);
-	} else if (strcmp(ident, "Bool") == 0) {
-		res->type = tbasic;
-		res->data.tbasic = btbool;
-		free(ident);
-	} else if (strcmp(ident, "Void") == 0) {
-		res->type = tbasic;
-		res->data.tbasic = btvoid;
-		free(ident);
-	} else {
-		res->type = tvar;
-		res->data.tvar = ident;
-	}
-	return res;
-}
-
 void ast_print(struct ast *ast, FILE *out)
 {
 	if (ast == NULL)
@@ -479,34 +425,6 @@ void expr_print(struct expr *expr, FILE *out)
 	}
 }
 
-void type_print(struct type *type, FILE *out)
-{
-	if (type == NULL)
-		return;
-	switch (type->type) {
-	case tbasic:
-		safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
-		break;
-	case tlist:
-		safe_fprintf(out, "[");
-		type_print(type->data.tlist, out);
-		safe_fprintf(out, "]");
-		break;
-	case ttuple:
-		safe_fprintf(out, "(");
-		type_print(type->data.ttuple.l, out);
-		safe_fprintf(out, ",");
-		type_print(type->data.ttuple.r, out);
-		safe_fprintf(out, ")");
-		break;
-	case tvar:
-		safe_fprintf(out, "%s", type->data.tvar);
-		break;
-	default:
-		die("Unsupported type node\n");
-	}
-}
-
 void ast_free(struct ast *ast)
 {
 	if (ast == NULL)
@@ -648,26 +566,3 @@ void expr_free(struct expr *expr)
 	}
 	free(expr);
 }
-
-void type_free(struct type *type)
-{
-	if (type == NULL)
-		return;
-	switch (type->type) {
-	case tbasic:
-		break;
-	case tlist:
-		type_free(type->data.tlist);
-		break;
-	case ttuple:
-		type_free(type->data.ttuple.l);
-		type_free(type->data.ttuple.r);
-		break;
-	case tvar:
-		free(type->data.tvar);
-		break;
-	default:
-		die("Unsupported type node\n");
-	}
-	free(type);
-}
diff --git a/ast.h b/ast.h
index db69333..635d0c8 100644
--- a/ast.h
+++ b/ast.h
@@ -5,6 +5,7 @@
 #include <stdbool.h>
 
 #include "util.h"
+#include "type.h"
 struct ast;
 #include "parse.h"
 
@@ -32,20 +33,6 @@ struct fundecl {
 	struct stmt **body;
 };
 
-enum basictype {btbool, btchar, btint, btvoid};
-struct type {
-	enum {tbasic,tlist,ttuple,tvar} type;
-	union {
-		enum basictype tbasic;
-		struct type *tlist;
-		struct {
-			struct type *l;
-			struct type *r;
-		} ttuple;
-		char *tvar;
-	} data;
-};
-
 struct decl {
 	//NOTE: DON'T CHANGE THIS ORDER
 	enum {dcomp, dvardecl, dfundecl} type;
@@ -158,18 +145,12 @@ struct expr *expr_tuple(struct expr *left, struct expr *right);
 struct expr *expr_string(char *str);
 struct expr *expr_unop(enum unop op, struct expr *l);
 
-struct type *type_basic(enum basictype type);
-struct type *type_list(struct type *type);
-struct type *type_tuple(struct type *l, struct type *r);
-struct type *type_var(char *ident);
-
 void ast_print(struct ast *, FILE *out);
 void vardecl_print(struct vardecl *decl, int indent, FILE *out);
 void fundecl_print(struct fundecl *decl, FILE *out);
 void decl_print(struct decl *ast, FILE *out);
 void stmt_print(struct stmt *ast, int indent, FILE *out);
 void expr_print(struct expr *ast, FILE *out);
-void type_print(struct type *type, FILE *out);
 
 void ast_free(struct ast *);
 void vardecl_free(struct vardecl *decl);
@@ -177,6 +158,5 @@ void fundecl_free(struct fundecl *fundecl);
 void decl_free(struct decl *ast);
 void stmt_free(struct stmt *ast);
 void expr_free(struct expr *ast);
-void type_free(struct type *type);
 
 #endif
diff --git a/sem.c b/sem.c
index 1cea48f..3404fb4 100644
--- a/sem.c
+++ b/sem.c
@@ -3,6 +3,7 @@
 
 #include "list.h"
 #include "sem/scc.h"
+#include "sem/hm.h"
 #include "ast.h"
 
 void type_error(const char *msg, ...)
@@ -34,38 +35,54 @@ void check_expr_constant(struct expr *expr)
 	}
 }
 
-struct vardecl *type_vardecl(struct vardecl *vardecl)
+struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl)
 {
-	return vardecl;
-}
+	struct type *t = vardecl->type == NULL
+		? gamma_fresh(gamma) : type_dup(vardecl->type);
+	struct subst *s = infer_expr(gamma, vardecl->expr, t);
 
-struct decl *type_decl(struct decl *decl)
-{
-	switch (decl->type) {
-	case dcomp:
-		fprintf(stderr, "type_decl:component unsupported\n");
-		break;
-	case dfundecl:
-		fprintf(stderr, "type_decl:fundecl unsupported\n");
-		break;
-	case dvardecl:
-		decl->data.dvar = type_vardecl(decl->data.dvar);
-		break;
-	}
-	return decl;
+	if (s == NULL)
+		die("error inferring variable\n");
+	vardecl->type = subst_apply_t(s, t);
+
+	//subst_free(s);
+
+	return vardecl;
 }
 
 struct ast *sem(struct ast *ast)
 {
 	ast = ast_scc(ast);
 
-	//Check that all globals are constant
+	struct gamma *gamma = gamma_init();
+
+	//Check all vardecls
 	for (int i = 0; i<ast->ndecls; i++) {
-		if (ast->decls[i]->type == dvardecl) {
-			//Check globals
+		switch(ast->decls[i]->type) {
+		case dvardecl:
+			//Check if constant
 			check_expr_constant(ast->decls[i]->data.dvar->expr);
+			//Infer if necessary
+			type_vardecl(gamma, ast->decls[i]->data.dvar);
+			break;
+		case dfundecl: {
+//			struct type *f1 = gamma_fresh(gamma);
+//			gamma_insert(gamma, ast->decls[i]->data.dfun->ident
+//				, scheme_create(f1));
+//infer env (Let [(x, e1)] e2)
+//	=              fresh
+//	>>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
+//	               in infer env` e1
+//	>>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
+//	>>= \(s2, t2)->pure (s1 oo s2, t2)
+			break;
+		}
+		case dcomp:
 			break;
 		}
 	}
+
+	gamma_free(gamma);
+
 	return ast;
 }
diff --git a/sem/hm.c b/sem/hm.c
index d91715e..fb3d2ce 100644
--- a/sem/hm.c
+++ b/sem/hm.c
@@ -2,173 +2,38 @@
 #include <string.h>
 
 #include "hm.h"
-#include "../util.h"
+#include "hm/subst.h"
+#include "hm/gamma.h"
+#include "hm/scheme.h"
 #include "../ast.h"
 
-struct gamma {
-	int fresh;
-	int nschemes;
-	char **vars;
-	struct scheme *schemes;
-};
-
-struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
-{
-	for (int i = 0; i<nschemes; i++) {
-		if (strcmp(ident, gamma->vars[i]) == 0) {
-			//TODO
-		}
-	}
-	return NULL;
-}
-
-struct type *fresh(struct gamma *gamma)
-{
-	char *buf = safe_malloc(10);
-	sprintf(buf, "%d", gamma->fresh);
-	gamma->fresh++;
-	return type_var(buf);
-}
-
-void ftv_type(struct type *r, int *nftv, char **ftv)
-{
-	switch (r->type) {
-	case tbasic:
-		break;
-	case tlist:
-		ftv_type(r->data.tlist, nftv, ftv);
-		break;
-	case ttuple:
-		ftv_type(r->data.ttuple.l, nftv, ftv);
-		ftv_type(r->data.ttuple.r, nftv, ftv);
-		break;
-	case tvar:
-		*ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
-		if (*ftv == NULL)
-			perror("realloc");
-		ftv[(*nftv)++] = r->data.tvar;
-		break;
-	}
-}
-
 bool occurs_check(char *var, struct type *r)
 {
 	int nftv = 0;
-	char *ftv = NULL;
-	ftv_type(r, &nftv, &ftv);
+	char **ftv = NULL;
+	type_ftv(r, &nftv, &ftv);
 	for (int i = 0; i<nftv; i++)
-		if (strcmp(ftv+i, var) == 0)
+		if (strcmp(ftv[i], var) == 0)
 			return true;
 	return false;
 }
 
-struct type *dup_type(struct type *r)
-{
-	struct type *res = safe_malloc(sizeof(struct type));
-	*res = *r;
-	switch (r->type) {
-	case tbasic:
-		break;
-	case tlist:
-		res->data.tlist = dup_type(r->data.tlist);
-		break;
-	case ttuple:
-		res->data.ttuple.l = dup_type(r->data.ttuple.l);
-		res->data.ttuple.r = dup_type(r->data.ttuple.r);
-		break;
-	case tvar:
-		res->data.tvar = strdup(r->data.tvar);
-		break;
-	}
-	return res;
-}
-
-struct type *subst_apply_t(struct substitution *subst, struct type *l)
+struct subst *unify(struct type *l, struct type *r)
 {
+	if (l == NULL || r == NULL)
+		return NULL;
+	if (r->type == tvar && l->type != tvar)
+		return unify(r, l);
+	struct subst *s1, *s2;
 	switch (l->type) {
-	case tbasic:
-		break;
-	case tlist:
-		l->data.tlist = subst_apply_t(subst, l->data.tlist);
-		break;
-	case ttuple:
-		l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
-		l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
-		break;
-	case tvar:
-		for (int i = 0; i<subst->nvar; i++) {
-			if (strcmp(subst->vars[i], l->data.tvar) == 0) {
-				free(l->data.tvar);
-				struct type *r = subst->types[i];
-				*l = *r;
-				free(r);
-			}
+	case tarrow:
+		if (r->type == tarrow) {
+			s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
+			s2 = unify(subst_apply_t(s1, l->data.tarrow.l),
+				subst_apply_t(s1, r->data.tarrow.l));
+			return subst_union(s1, s2);
 		}
 		break;
-	}
-	return l;
-}
-struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma)
-{
-	//TODO
-	return gamma;
-}
-
-void subst_print(struct substitution *s, FILE *out)
-{
-	if (s == NULL) {
-		fprintf(out, "no substitution\n");
-	} else {
-		fprintf(out, "[");
-		for (int i = 0; i<s->nvar; i++) {
-			fprintf(out, "%s->", s->vars[i]);
-			type_print(s->types[i], out);
-			if (i + 1 < s->nvar)
-				fprintf(out, ", ");
-		}
-		fprintf(out, "]\n");
-	}
-}
-
-struct substitution *subst_id()
-{
-	struct substitution *res = safe_malloc(sizeof(struct substitution));
-	res->nvar = 0;
-	res->vars = NULL;
-	res->types = NULL;
-	return res;
-}
-
-struct substitution *subst_singleton(char *ident, struct type *t)
-{
-	struct substitution *res = safe_malloc(sizeof(struct substitution));
-	res->nvar = 1;
-	res->vars = safe_malloc(sizeof(char *));
-	res->vars[0] = safe_strdup(ident);
-	res->types = safe_malloc(sizeof(struct type *));
-	res->types[0] = dup_type(t);
-	return res;
-}
-
-struct substitution *subst_union(struct substitution *l, struct substitution *r)
-{
-	struct substitution *res = safe_malloc(sizeof(struct substitution));
-	res->nvar = l->nvar+r->nvar;
-	res->vars = safe_malloc(res->nvar*sizeof(char *));
-	res->types = safe_malloc(res->nvar*sizeof(struct type *));
-	for (int i = 0; i<l->nvar; i++) {
-		res->vars[i] = l->vars[i];
-		res->types[i] = l->types[i];
-	}
-	for (int i = 0; i<r->nvar; i++) {
-		res->vars[l->nvar+i] = l->vars[i];
-		res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
-	}
-	return res;
-}
-
-struct substitution *unify(struct type *l, struct type *r) {
-	switch (l->type) {
 	case tbasic:
 		if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
 			return subst_id();
@@ -179,11 +44,8 @@ struct substitution *unify(struct type *l, struct type *r) {
 		break;
 	case ttuple:
 		if (r->type == ttuple) {
-			struct substitution *s1 = unify(
-				l->data.ttuple.l,
-				r->data.ttuple.l);
-			struct substitution *s2 = unify(
-				subst_apply_t(s1, l->data.ttuple.l),
+			s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
+			s2 = unify(subst_apply_t(s1, l->data.ttuple.l),
 				subst_apply_t(s1, r->data.ttuple.l));
 			return subst_union(s1, s2);
 		}
@@ -197,11 +59,23 @@ struct substitution *unify(struct type *l, struct type *r) {
 			return subst_singleton(l->data.tvar, r);
 		break;
 	}
+	fprintf(stderr, "cannot unify ");
+	type_print(l, stderr);
+	fprintf(stderr, " with ");
+	type_print(r, stderr);
+	fprintf(stderr, "\n");
 	return NULL;
 }
 
-struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
+struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
 {
+	fprintf(stderr, "infer expr: ");
+	expr_print(expr, stderr);
+	fprintf(stderr, "\ngamma: ");
+	gamma_print(gamma, stderr);
+	fprintf(stderr, "\ntype: ");
+	type_print(type, stderr);
+	fprintf(stderr, "\n");
 
 #define infbop(l, r, a1, a2, rt, sigma) {\
 	s1 = infer_expr(gamma, l, a1);\
@@ -210,13 +84,14 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 }
 #define infbinop(e, a1, a2, rt, sigma)\
 	infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
-	struct substitution *s1, *s2;
+	struct subst *s1, *s2;
 	struct type *f1, *f2;
+	struct scheme *s;
 	switch (expr->type) {
 	case ebool:
 		return unify(type_basic(btbool), type);
 	case ebinop:
-		switch(expr->data.ebinop.op) {
+		switch (expr->data.ebinop.op) {
 		case binor:
 		case binand:
 			infbinop(expr, type_basic(btbool), type_basic(btbool),
@@ -227,10 +102,10 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 		case le:
 		case geq:
 		case ge:
-			f1 = fresh(gamma);
+			f1 = gamma_fresh(gamma);
 			infbinop(expr, f1, f1, type_basic(btbool), type);
 		case cons:
-			f1 = fresh(gamma);
+			f1 = gamma_fresh(gamma);
 			infbinop(expr, f1, type_list(f1), type_list(f1), type);
 		case plus:
 		case minus:
@@ -241,20 +116,28 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 			infbinop(expr, type_basic(btint), type_basic(btint),
 				type_basic(btint), type);
 		}
+		break;
 	case echar:
 		return unify(type_basic(btchar), type);
 	case efuncall:
+		if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
+			die("Unbound function: %s\n", expr->data.efuncall.ident);
 		//TODO
+		//TODO fields
 		return NULL;
 	case eint:
 		return unify(type_basic(btint), type);
 	case eident:
-
-		//TODO
-		return NULL;
+		if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
+			die("Unbound variable: %s\n", expr->data.eident.ident);
+		//TODO fields
+		return unify(scheme_instantiate(gamma, s), type);
+	case enil:
+		f1 = gamma_fresh(gamma);
+		return unify(type_list(f1), type);
 	case etuple:
-		f1 = fresh(gamma);
-		f2 = fresh(gamma);
+		f1 = gamma_fresh(gamma);
+		f2 = gamma_fresh(gamma);
 		infbop(expr->data.etuple.left, expr->data.etuple.right,
 		       f1, f2, type_tuple(f1, f2), type);
 	case estring:
@@ -264,6 +147,8 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 		case negate:
 			s1 = infer_expr(gamma,
 				expr->data.eunop.l, type_basic(btint));
+			if (s1 == NULL)
+				return NULL;
 			return subst_union(s1,
 				unify(subst_apply_t(s1, type),
 				type_basic(btint)));
@@ -275,5 +160,5 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t
 				type_basic(btbool)));
 		}
 	}
-
+	return NULL;
 }
diff --git a/sem/hm.h b/sem/hm.h
index efb1afa..6792c3b 100644
--- a/sem/hm.h
+++ b/sem/hm.h
@@ -2,19 +2,11 @@
 #define SEM_HM_C
 
 #include "../ast.h"
-
-struct scheme {
-	struct type *type;
-	int nvar;
-	char **var;
-};
-
-struct substitution {
-	int nvar;
-	char **vars;
-	struct type **types;
-};
+#include "hm/gamma.h"
+#include "hm/subst.h"
+#include "hm/scheme.h"
 
 struct ast *infer(struct ast *ast);
+struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type);
 
 #endif
diff --git a/sem/hm/gamma.c b/sem/hm/gamma.c
new file mode 100644
index 0000000..9a97c27
--- /dev/null
+++ b/sem/hm/gamma.c
@@ -0,0 +1,62 @@
+#include <stdlib.h>
+#include <string.h>
+
+#include "../hm.h"
+
+struct gamma *gamma_init()
+{
+	struct gamma *gamma = safe_malloc(sizeof(struct gamma));
+	gamma->fresh = 0;
+	gamma->nschemes = 0;
+	gamma->vars = NULL;
+	gamma->schemes = NULL;
+	return gamma;
+}
+
+void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme)
+{
+	gamma->nschemes++;
+	gamma->vars = realloc(gamma->vars, gamma->nschemes*sizeof(char *));
+	gamma->schemes = realloc(gamma->schemes,
+		gamma->nschemes*sizeof(struct scheme *));
+	gamma->vars[gamma->nschemes-1] = safe_strdup(ident);
+	gamma->schemes[gamma->nschemes-1] = scheme;
+}
+
+struct scheme *gamma_lookup(struct gamma *gamma, char *ident)
+{
+	for (int i = 0; i<gamma->nschemes; i++)
+		if (strcmp(ident, gamma->vars[i]) == 0)
+			return gamma->schemes[i];
+	return NULL;
+}
+
+struct type *gamma_fresh(struct gamma *gamma)
+{
+	char buf[10] = {0};
+	sprintf(buf, "%d", gamma->fresh++);
+	return type_var(safe_strdup(buf));
+}
+
+void gamma_print(struct gamma *gamma, FILE *out)
+{
+	fprintf(out, "{");
+	for (int i = 0; i<gamma->nschemes; i++) {
+		fprintf(out, "%s=", gamma->vars[i]);
+		scheme_print(gamma->schemes[i], out);
+		if (i + 1 < gamma->nschemes)
+			fprintf(out, ", ");
+	}
+	fprintf(out, "}");
+}
+
+void gamma_free(struct gamma *gamma)
+{
+	for (int i = 0; i<gamma->nschemes; i++) {
+		free(gamma->vars[i]);
+		scheme_free(gamma->schemes[i]);
+	}
+	free(gamma->vars);
+	free(gamma->schemes);
+	free(gamma);
+}
diff --git a/sem/hm/gamma.h b/sem/hm/gamma.h
new file mode 100644
index 0000000..8499144
--- /dev/null
+++ b/sem/hm/gamma.h
@@ -0,0 +1,24 @@
+#ifndef SEM_HM_GAMMA_H
+#define SEM_HM_GAMMA_H
+
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct gamma {
+	int fresh;
+	int nschemes;
+	char **vars;
+	struct scheme **schemes;
+};
+
+struct gamma *gamma_init();
+void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme);
+
+struct scheme *gamma_lookup(struct gamma *gamma, char *ident);
+struct type *gamma_fresh(struct gamma *gamma);
+
+void gamma_print(struct gamma *gamma, FILE *out);
+void gamma_free(struct gamma *gamma);
+
+#endif
diff --git a/sem/hm/scheme.c b/sem/hm/scheme.c
new file mode 100644
index 0000000..177bb64
--- /dev/null
+++ b/sem/hm/scheme.c
@@ -0,0 +1,68 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct type *scheme_instantiate(struct gamma *gamma, struct scheme *sch)
+{
+	struct subst *s = subst_id();
+	for (int i = 0; i<sch->nvar; i++) {
+		s = subst_union(s, subst_singleton(sch->var[i], gamma_fresh(gamma)));
+	}
+	struct type *t = subst_apply_t(s, type_dup(sch->type));
+	for (int i = 0; i<s->nvar; i++)
+		free(s->vars[i]);
+	free(s);
+	return t;
+}
+
+struct scheme *scheme_create(struct type *t)
+{
+	struct scheme *s = safe_malloc(sizeof(struct scheme));
+	s->type = t;
+	s->nvar = 0;
+	s->var = NULL;
+}
+
+struct scheme *scheme_generalise(struct gamma *gamma, struct type *t)
+{
+	struct scheme *s = safe_malloc(sizeof(struct scheme));
+	int nftv = 0;
+	char **ftv = NULL;
+	type_ftv(t, &nftv, &ftv);
+
+	s->type = t;
+	s->nvar = 0;
+	s->var = safe_malloc(nftv*sizeof(char *));
+	for (int i = 0; i<nftv; i++) {
+		bool skip = false;
+		for (int j = 0; j<gamma->nschemes; j++)
+			if (strcmp(gamma->vars[j], ftv[i]) == 0)
+				skip = true;
+		if (skip)
+			continue;
+		s->nvar++;
+		s->var[i] = ftv[i];
+	}
+	return s;
+}
+
+void scheme_print(struct scheme *scheme, FILE *out)
+{
+	if (scheme->nvar > 0) {
+		fprintf(out, "A.");
+		for (int i = 0; i<scheme->nvar; i++)
+			fprintf(out, "%s", scheme->var[i]);
+		fprintf(out, ": ");
+	}
+	type_print(scheme->type, out);
+}
+
+void scheme_free(struct scheme *scheme)
+{
+	type_free(scheme->type);
+	for (int i = 0; i<scheme->nvar; i++)
+		free(scheme->var[i]);
+	free(scheme->var);
+	free(scheme);
+}
diff --git a/sem/hm/scheme.h b/sem/hm/scheme.h
new file mode 100644
index 0000000..eaab3fc
--- /dev/null
+++ b/sem/hm/scheme.h
@@ -0,0 +1,19 @@
+#ifndef SEM_HM_SCHEME_H
+#define SEM_HM_SCHEME_H
+
+#include "../hm.h"
+
+struct scheme {
+	struct type *type;
+	int nvar;
+	char **var;
+};
+
+struct type *scheme_instantiate(struct gamma *gamma, struct scheme *s);
+struct scheme *scheme_create(struct type *t);
+struct scheme *scheme_generalise(struct gamma *gamma, struct type *t);
+
+void scheme_print(struct scheme *scheme, FILE *out);
+void scheme_free(struct scheme *scheme);
+
+#endif
diff --git a/sem/hm/subst.c b/sem/hm/subst.c
new file mode 100644
index 0000000..55d6cfe
--- /dev/null
+++ b/sem/hm/subst.c
@@ -0,0 +1,124 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct subst *subst_id()
+{
+	struct subst *res = safe_malloc(sizeof(struct subst));
+	res->nvar = 0;
+	res->vars = NULL;
+	res->types = NULL;
+	return res;
+}
+
+struct subst *subst_singleton(char *ident, struct type *t)
+{
+	struct subst *res = safe_malloc(sizeof(struct subst));
+	res->nvar = 1;
+	res->vars = safe_malloc(sizeof(char *));
+	res->vars[0] = safe_strdup(ident);
+	res->types = safe_malloc(sizeof(struct type *));
+	res->types[0] = type_dup(t);
+	return res;
+}
+
+struct subst *subst_union(struct subst *l, struct subst *r)
+{
+	if (l == NULL || r == NULL)
+		return NULL;
+	struct subst *res = safe_malloc(sizeof(struct subst));
+	res->nvar = l->nvar+r->nvar;
+	res->vars = safe_malloc(res->nvar*sizeof(char *));
+	res->types = safe_malloc(res->nvar*sizeof(struct type *));
+	for (int i = 0; i<l->nvar; i++) {
+		res->vars[i] = l->vars[i];
+		res->types[i] = l->types[i];
+	}
+	for (int i = 0; i<r->nvar; i++) {
+		res->vars[l->nvar+i] = r->vars[i];
+		res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
+	}
+	return res;
+}
+
+struct type *subst_apply_t(struct subst *subst, struct type *l)
+{
+	if (subst == NULL)
+		return l;
+	switch (l->type) {
+	case tarrow:
+		l->data.tarrow.l = subst_apply_t(subst, l->data.tarrow.l);
+		l->data.tarrow.r = subst_apply_t(subst, l->data.tarrow.r);
+		break;
+	case tbasic:
+		break;
+	case tlist:
+		l->data.tlist = subst_apply_t(subst, l->data.tlist);
+		break;
+	case ttuple:
+		l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
+		l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
+		break;
+	case tvar:
+		for (int i = 0; i<subst->nvar; i++) {
+			if (strcmp(subst->vars[i], l->data.tvar) == 0) {
+				free(l->data.tvar);
+				struct type *r = type_dup(subst->types[i]);
+				*l = *r;
+				free(r);
+				break;
+			}
+		}
+		break;
+	}
+	return l;
+}
+
+struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme)
+{
+	for (int i = 0; i<scheme->nvar; i++) {
+		for (int j = 0; j<subst->nvar; j++) {
+			if (strcmp(scheme->var[i], subst->vars[j]) != 0) {
+				struct subst *t = subst_singleton(
+					subst->vars[j], subst->types[j]);
+				scheme->type = subst_apply_t(t, scheme->type);
+			}
+		}
+	}
+	return scheme;
+}
+
+struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma)
+{
+	for (int i = 0; i<gamma->nschemes; i++)
+		subst_apply_s(subst, gamma->schemes[i]);
+	return gamma;
+}
+
+void subst_print(struct subst *s, FILE *out)
+{
+	if (s == NULL) {
+		fprintf(out, "no subst\n");
+	} else {
+		fprintf(out, "[");
+		for (int i = 0; i<s->nvar; i++) {
+			fprintf(out, "%s->", s->vars[i]);
+			type_print(s->types[i], out);
+			if (i + 1 < s->nvar)
+				fprintf(out, ", ");
+		}
+		fprintf(out, "]\n");
+	}
+}
+
+void subst_free(struct subst *s, bool type)
+{
+	if (s != NULL) {
+		for (int i = 0; i<s->nvar; i++) {
+			free(s->vars[i]);
+			if (type)
+				type_free(s->types[i]);
+		}
+	}
+}
diff --git a/sem/hm/subst.h b/sem/hm/subst.h
new file mode 100644
index 0000000..ffec9a7
--- /dev/null
+++ b/sem/hm/subst.h
@@ -0,0 +1,24 @@
+#ifndef SEM_HM_SUBST_H
+#define SEM_HM_SUBST_H
+
+#include "../../ast.h"
+#include "../hm.h"
+
+struct subst {
+	int nvar;
+	char **vars;
+	struct type **types;
+};
+
+struct subst *subst_id();
+struct subst *subst_singleton(char *ident, struct type *t);
+struct subst *subst_union(struct subst *l, struct subst *r);
+
+struct type *subst_apply_t(struct subst *subst, struct type *l);
+struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme);
+struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma);
+
+void subst_print(struct subst *s, FILE *out);
+void subst_free(struct subst *s, bool type);
+
+#endif
diff --git a/sem/scc.c b/sem/scc.c
index fa47590..6d7b02e 100644
--- a/sem/scc.c
+++ b/sem/scc.c
@@ -271,10 +271,11 @@ struct ast *ast_scc(struct ast *ast)
 	struct edge **edata = (struct edge **)
 		list_to_array(edges, &nedges, false);
 
+	fprintf(stderr, "nfun: %d, ffun: %d, nedges: %d\n", nfun, ffun, nedges);
 	// Do tarjan's and convert back into the declaration list
 	struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata);
-	if (cs == NULL)
-		die("malformed edges in tarjan's????");
+//	if (cs == NULL)
+//		die("malformed edges in tarjan's????");
 
 	int i = ffun;
 	for (struct components *c = cs; c != NULL; c = c->next) {
diff --git a/test/Makefile b/test/Makefile
new file mode 100644
index 0000000..faaff6a
--- /dev/null
+++ b/test/Makefile
@@ -0,0 +1,19 @@
+TESTOBJECTS:=$(patsubst %.c,%.o,$(wildcard *.c))
+TESTS:=$(patsubst %.o,%,$(TESTOBJECTS))
+
+LDLIBS+=$(shell pkg-config --libs check)
+
+.PHONY: test
+
+test_sem_hm_gamma.o: CFLAGS+=$(shell pkg-config --cflags check)
+
+test_sem_hm_gamma: test_sem_hm_gamma.o $(addprefix ../sem/hm/,gamma.o scheme.o subst.o) ../util.o ../type.o
+
+test: $(TESTS)
+	$(foreach f,$^,./$(f);)
+
+clean:
+	$(RM) $(TESTOBJECTS) $(TESTS)
+ifeq ($(MAKELEVEL), 0)
+	$(MAKE) -C ../ clean
+endif
diff --git a/test/test_sem_hm_gamma.c b/test/test_sem_hm_gamma.c
new file mode 100644
index 0000000..1665031
--- /dev/null
+++ b/test/test_sem_hm_gamma.c
@@ -0,0 +1,97 @@
+#include <stdbool.h>
+#include <stdlib.h>
+#include <check.h>
+
+#include "../sem/hm/gamma.h"
+#include "../sem/hm/subst.h"
+#include "../sem/hm/scheme.h"
+
+START_TEST(test_gamma_lookup)
+{
+	struct gamma *gamma = gamma_init();
+
+	ck_assert_ptr_null(gamma_lookup(gamma, "fun"));
+	ck_assert_ptr_null(gamma_lookup(gamma, "fun2"));
+
+	gamma_insert(gamma, "fun", scheme_generalise(gamma, type_basic(btint)));
+
+	ck_assert_ptr_nonnull(gamma_lookup(gamma, "fun"));
+	ck_assert_ptr_null(gamma_lookup(gamma, "fun2"));
+
+	struct type *t1 = gamma_fresh(gamma);
+	ck_assert(t1->type == tvar);
+	struct type *t2 = gamma_fresh(gamma);
+	ck_assert(t2->type == tvar);
+	struct type *t3 = gamma_fresh(gamma);
+	ck_assert(t3->type == tvar);
+	struct type *t4 = gamma_fresh(gamma);
+	ck_assert(t4->type == tvar);
+
+	ck_assert_str_ne(t1->data.tvar, t2->data.tvar);
+	ck_assert_str_ne(t2->data.tvar, t3->data.tvar);
+	ck_assert_str_ne(t3->data.tvar, t4->data.tvar);
+}
+END_TEST
+
+START_TEST(test_scheme)
+{
+	struct gamma *gamma = gamma_init();
+
+	char **var = malloc(sizeof(char *));
+	var[0] = safe_strdup("a");
+	struct scheme scheme = {.type=type_var("a"), .nvar=1, .var=var};
+
+	struct type *t = scheme_instantiate(gamma, &scheme);
+	ck_assert(t->type == tvar);
+	ck_assert_str_eq(t->data.tvar, "0");
+
+	scheme.type = type_list(type_var("a"));
+	t = scheme_instantiate(gamma, &scheme);
+	ck_assert(t->type == tlist);
+	ck_assert(t->data.tlist->type == tvar);
+	ck_assert_str_eq(t->data.tlist->data.tvar, "1");
+}
+END_TEST
+
+START_TEST(test_subst)
+{
+	struct subst *s1 = subst_id();
+	ck_assert_int_eq(0, s1->nvar);
+	s1 = subst_singleton("i1", type_basic(btint));
+	ck_assert_int_eq(1, s1->nvar);
+	s1 = subst_union(subst_id(), subst_singleton("i1", type_basic(btint)));
+	ck_assert_int_eq(1, s1->nvar);
+	s1 = subst_union(subst_singleton("i2", type_basic(btbool)),
+		subst_singleton("i1", type_basic(btint)));
+	ck_assert_int_eq(2, s1->nvar);
+
+}
+END_TEST
+
+Suite *util_suite(void)
+{
+	Suite *s = suite_create("List");
+
+	TCase *tc_gamma = tcase_create("Gamma lookup");
+	tcase_add_test(tc_gamma, test_gamma_lookup);
+	tcase_add_test(tc_gamma, test_scheme);
+	tcase_add_test(tc_gamma, test_subst);
+	suite_add_tcase(s, tc_gamma);
+
+	return s;
+}
+
+int main(void)
+{
+	int failed;
+	Suite *s;
+	SRunner *sr;
+
+	s = util_suite();
+	sr = srunner_create(s);
+
+	srunner_run_all(sr, CK_NORMAL);
+	failed = srunner_ntests_failed(sr);
+	srunner_free(sr);
+	return (failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE;
+}
diff --git a/type.c b/type.c
new file mode 100644
index 0000000..8fc6102
--- /dev/null
+++ b/type.c
@@ -0,0 +1,183 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "util.h"
+#include "type.h"
+
+static const char *basictype_str[] = {
+	[btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
+	[btvoid] = "Void",
+};
+
+struct type *type_arrow(struct type *l, struct type *r)
+{
+	struct type *res = safe_malloc(sizeof(struct type));
+	res->type = tarrow;
+	res->data.tarrow.l = l;
+	res->data.tarrow.r = r;
+	return res;
+
+}
+
+struct type *type_basic(enum basictype type)
+{
+	struct type *res = safe_malloc(sizeof(struct type));
+	res->type = tbasic;
+	res->data.tbasic = type;
+	return res;
+}
+
+struct type *type_list(struct type *type)
+{
+	struct type *res = safe_malloc(sizeof(struct type));
+	res->type = tlist;
+	res->data.tlist = type;
+	return res;
+}
+
+struct type *type_tuple(struct type *l, struct type *r)
+{
+	struct type *res = safe_malloc(sizeof(struct type));
+	res->type = ttuple;
+	res->data.ttuple.l = l;
+	res->data.ttuple.r = r;
+	return res;
+}
+
+struct type *type_var(char *ident)
+{
+	struct type *res = safe_malloc(sizeof(struct type));
+	if (strcmp(ident, "Int") == 0) {
+		res->type = tbasic;
+		res->data.tbasic = btint;
+		free(ident);
+	} else if (strcmp(ident, "Char") == 0) {
+		res->type = tbasic;
+		res->data.tbasic = btchar;
+		free(ident);
+	} else if (strcmp(ident, "Bool") == 0) {
+		res->type = tbasic;
+		res->data.tbasic = btbool;
+		free(ident);
+	} else if (strcmp(ident, "Void") == 0) {
+		res->type = tbasic;
+		res->data.tbasic = btvoid;
+		free(ident);
+	} else {
+		res->type = tvar;
+		res->data.tvar = ident;
+	}
+	return res;
+}
+
+void type_print(struct type *type, FILE *out)
+{
+	if (type == NULL)
+		return;
+	switch (type->type) {
+	case tarrow:
+		safe_fprintf(out, "(");
+		type_print(type->data.tarrow.l, out);
+		safe_fprintf(out, "->");
+		type_print(type->data.tarrow.r, out);
+		safe_fprintf(out, ")");
+		break;
+	case tbasic:
+		safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
+		break;
+	case tlist:
+		safe_fprintf(out, "[");
+		type_print(type->data.tlist, out);
+		safe_fprintf(out, "]");
+		break;
+	case ttuple:
+		safe_fprintf(out, "(");
+		type_print(type->data.ttuple.l, out);
+		safe_fprintf(out, ",");
+		type_print(type->data.ttuple.r, out);
+		safe_fprintf(out, ")");
+		break;
+	case tvar:
+		safe_fprintf(out, "%s", type->data.tvar);
+		break;
+	default:
+		die("Unsupported type node\n");
+	}
+}
+
+void type_free(struct type *type)
+{
+	if (type == NULL)
+		return;
+	switch (type->type) {
+	case tarrow:
+		type_free(type->data.tarrow.l);
+		type_free(type->data.tarrow.r);
+		break;
+	case tbasic:
+		break;
+	case tlist:
+		type_free(type->data.tlist);
+		break;
+	case ttuple:
+		type_free(type->data.ttuple.l);
+		type_free(type->data.ttuple.r);
+		break;
+	case tvar:
+		free(type->data.tvar);
+		break;
+	default:
+		die("Unsupported type node\n");
+	}
+	free(type);
+}
+
+struct type *type_dup(struct type *r)
+{
+	struct type *res = safe_malloc(sizeof(struct type));
+	*res = *r;
+	switch (r->type) {
+	case tarrow:
+		res->data.tarrow.l = type_dup(r->data.tarrow.l);
+		res->data.tarrow.r = type_dup(r->data.tarrow.r);
+		break;
+	case tbasic:
+		break;
+	case tlist:
+		res->data.tlist = type_dup(r->data.tlist);
+		break;
+	case ttuple:
+		res->data.ttuple.l = type_dup(r->data.ttuple.l);
+		res->data.ttuple.r = type_dup(r->data.ttuple.r);
+		break;
+	case tvar:
+		res->data.tvar = safe_strdup(r->data.tvar);
+		break;
+	}
+	return res;
+}
+
+void type_ftv(struct type *r, int *nftv, char ***ftv)
+{
+	switch (r->type) {
+	case tarrow:
+		type_ftv(r->data.ttuple.l, nftv, ftv);
+		type_ftv(r->data.ttuple.r, nftv, ftv);
+		break;
+	case tbasic:
+		break;
+	case tlist:
+		type_ftv(r->data.tlist, nftv, ftv);
+		break;
+	case ttuple:
+		type_ftv(r->data.ttuple.l, nftv, ftv);
+		type_ftv(r->data.ttuple.r, nftv, ftv);
+		break;
+	case tvar:
+		*ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
+		if (*ftv == NULL)
+			perror("realloc");
+		(*ftv)[(*nftv)++] = r->data.tvar;
+		break;
+	}
+}
diff --git a/type.h b/type.h
new file mode 100644
index 0000000..0bbb2e0
--- /dev/null
+++ b/type.h
@@ -0,0 +1,36 @@
+#ifndef TYPE_H
+#define TYPE_H
+
+#include <stdio.h>
+
+enum basictype {btbool, btchar, btint, btvoid};
+struct type {
+	enum {tarrow,tbasic,tlist,ttuple,tvar} type;
+	union {
+		struct {
+			struct type *l;
+			struct type *r;
+		} tarrow;
+		enum basictype tbasic;
+		struct type *tlist;
+		struct {
+			struct type *l;
+			struct type *r;
+		} ttuple;
+		char *tvar;
+	} data;
+};
+
+struct type *type_arrow(struct type *l, struct type *r);
+struct type *type_basic(enum basictype type);
+struct type *type_list(struct type *type);
+struct type *type_tuple(struct type *l, struct type *r);
+struct type *type_var(char *ident);
+
+void type_print(struct type *type, FILE *out);
+void type_free(struct type *type);
+
+struct type *type_dup(struct type *t);
+void type_ftv(struct type *r, int *nftv, char ***ftv);
+
+#endif
diff --git a/util.c b/util.c
index efe22c7..fa60ddc 100644
--- a/util.c
+++ b/util.c
@@ -154,9 +154,11 @@ void *safe_malloc(size_t size)
 
 void *safe_strdup(const char *c)
 {
-	char *res = strdup(c);
+	size_t nchar = strlen(c);
+	char *res = malloc((nchar+1)*sizeof(char));
 	if (res == NULL)
 		pdie("strdup");
+	memcpy(res, c, nchar+1);
 	return res;
 }
 
-- 
2.20.1