work on type inference some more
[ccc.git] / type.c
1 #include <string.h>
2 #include <stdlib.h>
3
4 #include "util.h"
5 #include "type.h"
6
7 static const char *basictype_str[] = {
8 [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
9 [btvoid] = "Void",
10 };
11
12 struct type *type_arrow(struct type *l, struct type *r)
13 {
14 struct type *res = safe_malloc(sizeof(struct type));
15 res->type = tarrow;
16 res->data.tarrow.l = l;
17 res->data.tarrow.r = r;
18 return res;
19
20 }
21
22 struct type *type_basic(enum basictype type)
23 {
24 struct type *res = safe_malloc(sizeof(struct type));
25 res->type = tbasic;
26 res->data.tbasic = type;
27 return res;
28 }
29
30 struct type *type_list(struct type *type)
31 {
32 struct type *res = safe_malloc(sizeof(struct type));
33 res->type = tlist;
34 res->data.tlist = type;
35 return res;
36 }
37
38 struct type *type_tuple(struct type *l, struct type *r)
39 {
40 struct type *res = safe_malloc(sizeof(struct type));
41 res->type = ttuple;
42 res->data.ttuple.l = l;
43 res->data.ttuple.r = r;
44 return res;
45 }
46
47 struct type *type_var(char *ident)
48 {
49 struct type *res = safe_malloc(sizeof(struct type));
50 if (strcmp(ident, "Int") == 0) {
51 res->type = tbasic;
52 res->data.tbasic = btint;
53 free(ident);
54 } else if (strcmp(ident, "Char") == 0) {
55 res->type = tbasic;
56 res->data.tbasic = btchar;
57 free(ident);
58 } else if (strcmp(ident, "Bool") == 0) {
59 res->type = tbasic;
60 res->data.tbasic = btbool;
61 free(ident);
62 } else if (strcmp(ident, "Void") == 0) {
63 res->type = tbasic;
64 res->data.tbasic = btvoid;
65 free(ident);
66 } else {
67 res->type = tvar;
68 res->data.tvar = ident;
69 }
70 return res;
71 }
72
73 void type_print(struct type *type, FILE *out)
74 {
75 if (type == NULL)
76 return;
77 switch (type->type) {
78 case tarrow:
79 safe_fprintf(out, "(");
80 type_print(type->data.tarrow.l, out);
81 safe_fprintf(out, "->");
82 type_print(type->data.tarrow.r, out);
83 safe_fprintf(out, ")");
84 break;
85 case tbasic:
86 safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
87 break;
88 case tlist:
89 safe_fprintf(out, "[");
90 type_print(type->data.tlist, out);
91 safe_fprintf(out, "]");
92 break;
93 case ttuple:
94 safe_fprintf(out, "(");
95 type_print(type->data.ttuple.l, out);
96 safe_fprintf(out, ",");
97 type_print(type->data.ttuple.r, out);
98 safe_fprintf(out, ")");
99 break;
100 case tvar:
101 safe_fprintf(out, "%s", type->data.tvar);
102 break;
103 default:
104 die("Unsupported type node\n");
105 }
106 }
107
108 void type_free(struct type *type)
109 {
110 if (type == NULL)
111 return;
112 switch (type->type) {
113 case tarrow:
114 type_free(type->data.tarrow.l);
115 type_free(type->data.tarrow.r);
116 break;
117 case tbasic:
118 break;
119 case tlist:
120 type_free(type->data.tlist);
121 break;
122 case ttuple:
123 type_free(type->data.ttuple.l);
124 type_free(type->data.ttuple.r);
125 break;
126 case tvar:
127 free(type->data.tvar);
128 break;
129 default:
130 die("Unsupported type node\n");
131 }
132 free(type);
133 }
134
135 struct type *type_dup(struct type *r)
136 {
137 struct type *res = safe_malloc(sizeof(struct type));
138 *res = *r;
139 switch (r->type) {
140 case tarrow:
141 res->data.tarrow.l = type_dup(r->data.tarrow.l);
142 res->data.tarrow.r = type_dup(r->data.tarrow.r);
143 break;
144 case tbasic:
145 break;
146 case tlist:
147 res->data.tlist = type_dup(r->data.tlist);
148 break;
149 case ttuple:
150 res->data.ttuple.l = type_dup(r->data.ttuple.l);
151 res->data.ttuple.r = type_dup(r->data.ttuple.r);
152 break;
153 case tvar:
154 res->data.tvar = safe_strdup(r->data.tvar);
155 break;
156 }
157 return res;
158 }
159
160 void type_ftv(struct type *r, int *nftv, char ***ftv)
161 {
162 switch (r->type) {
163 case tarrow:
164 type_ftv(r->data.ttuple.l, nftv, ftv);
165 type_ftv(r->data.ttuple.r, nftv, ftv);
166 break;
167 case tbasic:
168 break;
169 case tlist:
170 type_ftv(r->data.tlist, nftv, ftv);
171 break;
172 case ttuple:
173 type_ftv(r->data.ttuple.l, nftv, ftv);
174 type_ftv(r->data.ttuple.r, nftv, ftv);
175 break;
176 case tvar:
177 *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
178 if (*ftv == NULL)
179 perror("realloc");
180 (*ftv)[(*nftv)++] = r->data.tvar;
181 break;
182 }
183 }