work on type inference some more
[ccc.git] / sem / hm / subst.c
1 #include <string.h>
2 #include <stdlib.h>
3
4 #include "../hm.h"
5
6 struct subst *subst_id()
7 {
8 struct subst *res = safe_malloc(sizeof(struct subst));
9 res->nvar = 0;
10 res->vars = NULL;
11 res->types = NULL;
12 return res;
13 }
14
15 struct subst *subst_singleton(char *ident, struct type *t)
16 {
17 struct subst *res = safe_malloc(sizeof(struct subst));
18 res->nvar = 1;
19 res->vars = safe_malloc(sizeof(char *));
20 res->vars[0] = safe_strdup(ident);
21 res->types = safe_malloc(sizeof(struct type *));
22 res->types[0] = type_dup(t);
23 return res;
24 }
25
26 struct subst *subst_union(struct subst *l, struct subst *r)
27 {
28 if (l == NULL || r == NULL)
29 return NULL;
30 struct subst *res = safe_malloc(sizeof(struct subst));
31 res->nvar = l->nvar+r->nvar;
32 res->vars = safe_malloc(res->nvar*sizeof(char *));
33 res->types = safe_malloc(res->nvar*sizeof(struct type *));
34 for (int i = 0; i<l->nvar; i++) {
35 res->vars[i] = l->vars[i];
36 res->types[i] = l->types[i];
37 }
38 for (int i = 0; i<r->nvar; i++) {
39 res->vars[l->nvar+i] = r->vars[i];
40 res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
41 }
42 return res;
43 }
44
45 struct type *subst_apply_t(struct subst *subst, struct type *l)
46 {
47 if (subst == NULL)
48 return l;
49 switch (l->type) {
50 case tarrow:
51 l->data.tarrow.l = subst_apply_t(subst, l->data.tarrow.l);
52 l->data.tarrow.r = subst_apply_t(subst, l->data.tarrow.r);
53 break;
54 case tbasic:
55 break;
56 case tlist:
57 l->data.tlist = subst_apply_t(subst, l->data.tlist);
58 break;
59 case ttuple:
60 l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
61 l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
62 break;
63 case tvar:
64 for (int i = 0; i<subst->nvar; i++) {
65 if (strcmp(subst->vars[i], l->data.tvar) == 0) {
66 free(l->data.tvar);
67 struct type *r = type_dup(subst->types[i]);
68 *l = *r;
69 free(r);
70 break;
71 }
72 }
73 break;
74 }
75 return l;
76 }
77
78 struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme)
79 {
80 for (int i = 0; i<scheme->nvar; i++) {
81 for (int j = 0; j<subst->nvar; j++) {
82 if (strcmp(scheme->var[i], subst->vars[j]) != 0) {
83 struct subst *t = subst_singleton(
84 subst->vars[j], subst->types[j]);
85 scheme->type = subst_apply_t(t, scheme->type);
86 }
87 }
88 }
89 return scheme;
90 }
91
92 struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma)
93 {
94 for (int i = 0; i<gamma->nschemes; i++)
95 subst_apply_s(subst, gamma->schemes[i]);
96 return gamma;
97 }
98
99 void subst_print(struct subst *s, FILE *out)
100 {
101 if (s == NULL) {
102 fprintf(out, "no subst\n");
103 } else {
104 fprintf(out, "[");
105 for (int i = 0; i<s->nvar; i++) {
106 fprintf(out, "%s->", s->vars[i]);
107 type_print(s->types[i], out);
108 if (i + 1 < s->nvar)
109 fprintf(out, ", ");
110 }
111 fprintf(out, "]\n");
112 }
113 }
114
115 void subst_free(struct subst *s, bool type)
116 {
117 if (s != NULL) {
118 for (int i = 0; i<s->nvar; i++) {
119 free(s->vars[i]);
120 if (type)
121 type_free(s->types[i]);
122 }
123 }
124 }