work on type inference some more
[ccc.git] / sem / hm.c
1 #include <stdlib.h>
2 #include <string.h>
3
4 #include "hm.h"
5 #include "hm/subst.h"
6 #include "hm/gamma.h"
7 #include "hm/scheme.h"
8 #include "../ast.h"
9
/*
 * Occurs check used by unify(): report whether the type variable named `var`
 * is one of the free type variables of `r`.
 *
 * NOTE(review): the array that type_ftv() fills in is never released here.
 * If type_ftv() allocates it with malloc/realloc, every call leaks it --
 * confirm type_ftv()'s ownership contract.
 */
bool occurs_check(char *var, struct type *r)
{
	char **vars = NULL;
	int count = 0;
	int idx = 0;

	type_ftv(r, &count, &vars);

	/* Advance until the first free variable whose name equals `var`. */
	while (idx < count && strcmp(vars[idx], var) != 0)
		idx++;

	/* Stopping before the end means a match was found. */
	return idx < count;
}
20
21 struct subst *unify(struct type *l, struct type *r)
22 {
23 if (l == NULL || r == NULL)
24 return NULL;
25 if (r->type == tvar && l->type != tvar)
26 return unify(r, l);
27 struct subst *s1, *s2;
28 switch (l->type) {
29 case tarrow:
30 if (r->type == tarrow) {
31 s1 = unify(l->data.tarrow.l, r->data.tarrow.l);
32 s2 = unify(subst_apply_t(s1, l->data.tarrow.l),
33 subst_apply_t(s1, r->data.tarrow.l));
34 return subst_union(s1, s2);
35 }
36 break;
37 case tbasic:
38 if (r->type == tbasic && l->data.tbasic == r->data.tbasic)
39 return subst_id();
40 break;
41 case tlist:
42 if (r->type == tlist)
43 return unify(l->data.tlist, r->data.tlist);
44 break;
45 case ttuple:
46 if (r->type == ttuple) {
47 s1 = unify(l->data.ttuple.l, r->data.ttuple.l);
48 s2 = unify(subst_apply_t(s1, l->data.ttuple.l),
49 subst_apply_t(s1, r->data.ttuple.l));
50 return subst_union(s1, s2);
51 }
52 break;
53 case tvar:
54 if (r->type == tvar && strcmp(l->data.tvar, r->data.tvar) == 0)
55 return subst_id();
56 else if (occurs_check(l->data.tvar, r))
57 fprintf(stderr, "Infinite type %s\n", l->data.tvar);
58 else
59 return subst_singleton(l->data.tvar, r);
60 break;
61 }
62 fprintf(stderr, "cannot unify ");
63 type_print(l, stderr);
64 fprintf(stderr, " with ");
65 type_print(r, stderr);
66 fprintf(stderr, "\n");
67 return NULL;
68 }
69
70 struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type)
71 {
72 fprintf(stderr, "infer expr: ");
73 expr_print(expr, stderr);
74 fprintf(stderr, "\ngamma: ");
75 gamma_print(gamma, stderr);
76 fprintf(stderr, "\ntype: ");
77 type_print(type, stderr);
78 fprintf(stderr, "\n");
79
80 #define infbop(l, r, a1, a2, rt, sigma) {\
81 s1 = infer_expr(gamma, l, a1);\
82 s2 = subst_union(s1, infer_expr(subst_apply_g(s1, gamma), r, a2));\
83 return subst_union(s2, unify(subst_apply_t(s2, sigma), rt));\
84 }
85 #define infbinop(e, a1, a2, rt, sigma)\
86 infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma)
87 struct subst *s1, *s2;
88 struct type *f1, *f2;
89 struct scheme *s;
90 switch (expr->type) {
91 case ebool:
92 return unify(type_basic(btbool), type);
93 case ebinop:
94 switch (expr->data.ebinop.op) {
95 case binor:
96 case binand:
97 infbinop(expr, type_basic(btbool), type_basic(btbool),
98 type_basic(btbool), type);
99 case eq:
100 case neq:
101 case leq:
102 case le:
103 case geq:
104 case ge:
105 f1 = gamma_fresh(gamma);
106 infbinop(expr, f1, f1, type_basic(btbool), type);
107 case cons:
108 f1 = gamma_fresh(gamma);
109 infbinop(expr, f1, type_list(f1), type_list(f1), type);
110 case plus:
111 case minus:
112 case times:
113 case divide:
114 case modulo:
115 case power:
116 infbinop(expr, type_basic(btint), type_basic(btint),
117 type_basic(btint), type);
118 }
119 break;
120 case echar:
121 return unify(type_basic(btchar), type);
122 case efuncall:
123 if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL)
124 die("Unbound function: %s\n", expr->data.efuncall.ident);
125 //TODO
126 //TODO fields
127 return NULL;
128 case eint:
129 return unify(type_basic(btint), type);
130 case eident:
131 if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL)
132 die("Unbound variable: %s\n", expr->data.eident.ident);
133 //TODO fields
134 return unify(scheme_instantiate(gamma, s), type);
135 case enil:
136 f1 = gamma_fresh(gamma);
137 return unify(type_list(f1), type);
138 case etuple:
139 f1 = gamma_fresh(gamma);
140 f2 = gamma_fresh(gamma);
141 infbop(expr->data.etuple.left, expr->data.etuple.right,
142 f1, f2, type_tuple(f1, f2), type);
143 case estring:
144 return unify(type_list(type_basic(btchar)), type);
145 case eunop:
146 switch(expr->data.eunop.op) {
147 case negate:
148 s1 = infer_expr(gamma,
149 expr->data.eunop.l, type_basic(btint));
150 if (s1 == NULL)
151 return NULL;
152 return subst_union(s1,
153 unify(subst_apply_t(s1, type),
154 type_basic(btint)));
155 case inverse:
156 s1 = infer_expr(gamma,
157 expr->data.eunop.l, type_basic(btbool));
158 return subst_union(s1,
159 unify(subst_apply_t(s1, type),
160 type_basic(btbool)));
161 }
162 }
163 return NULL;
164 }