+#include <string.h>
+#include <stdlib.h>
+
+#include "util.h"
+#include "type.h"
+
+static const char *basictype_str[] = {
+ [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
+ [btvoid] = "Void",
+};
+
+struct type *type_arrow(struct type *l, struct type *r)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ res->type = tarrow;
+ res->data.tarrow.l = l;
+ res->data.tarrow.r = r;
+ return res;
+
+}
+
+struct type *type_basic(enum basictype type)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ res->type = tbasic;
+ res->data.tbasic = type;
+ return res;
+}
+
+struct type *type_list(struct type *type)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ res->type = tlist;
+ res->data.tlist = type;
+ return res;
+}
+
+struct type *type_tuple(struct type *l, struct type *r)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ res->type = ttuple;
+ res->data.ttuple.l = l;
+ res->data.ttuple.r = r;
+ return res;
+}
+
+struct type *type_var(char *ident)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ if (strcmp(ident, "Int") == 0) {
+ res->type = tbasic;
+ res->data.tbasic = btint;
+ free(ident);
+ } else if (strcmp(ident, "Char") == 0) {
+ res->type = tbasic;
+ res->data.tbasic = btchar;
+ free(ident);
+ } else if (strcmp(ident, "Bool") == 0) {
+ res->type = tbasic;
+ res->data.tbasic = btbool;
+ free(ident);
+ } else if (strcmp(ident, "Void") == 0) {
+ res->type = tbasic;
+ res->data.tbasic = btvoid;
+ free(ident);
+ } else {
+ res->type = tvar;
+ res->data.tvar = ident;
+ }
+ return res;
+}
+
+void type_print(struct type *type, FILE *out)
+{
+ if (type == NULL)
+ return;
+ switch (type->type) {
+ case tarrow:
+ safe_fprintf(out, "(");
+ type_print(type->data.tarrow.l, out);
+ safe_fprintf(out, "->");
+ type_print(type->data.tarrow.r, out);
+ safe_fprintf(out, ")");
+ break;
+ case tbasic:
+ safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
+ break;
+ case tlist:
+ safe_fprintf(out, "[");
+ type_print(type->data.tlist, out);
+ safe_fprintf(out, "]");
+ break;
+ case ttuple:
+ safe_fprintf(out, "(");
+ type_print(type->data.ttuple.l, out);
+ safe_fprintf(out, ",");
+ type_print(type->data.ttuple.r, out);
+ safe_fprintf(out, ")");
+ break;
+ case tvar:
+ safe_fprintf(out, "%s", type->data.tvar);
+ break;
+ default:
+ die("Unsupported type node\n");
+ }
+}
+
+void type_free(struct type *type)
+{
+ if (type == NULL)
+ return;
+ switch (type->type) {
+ case tarrow:
+ type_free(type->data.tarrow.l);
+ type_free(type->data.tarrow.r);
+ break;
+ case tbasic:
+ break;
+ case tlist:
+ type_free(type->data.tlist);
+ break;
+ case ttuple:
+ type_free(type->data.ttuple.l);
+ type_free(type->data.ttuple.r);
+ break;
+ case tvar:
+ free(type->data.tvar);
+ break;
+ default:
+ die("Unsupported type node\n");
+ }
+ free(type);
+}
+
+struct type *type_dup(struct type *r)
+{
+ struct type *res = safe_malloc(sizeof(struct type));
+ *res = *r;
+ switch (r->type) {
+ case tarrow:
+ res->data.tarrow.l = type_dup(r->data.tarrow.l);
+ res->data.tarrow.r = type_dup(r->data.tarrow.r);
+ break;
+ case tbasic:
+ break;
+ case tlist:
+ res->data.tlist = type_dup(r->data.tlist);
+ break;
+ case ttuple:
+ res->data.ttuple.l = type_dup(r->data.ttuple.l);
+ res->data.ttuple.r = type_dup(r->data.ttuple.r);
+ break;
+ case tvar:
+ res->data.tvar = safe_strdup(r->data.tvar);
+ break;
+ }
+ return res;
+}
+
+void type_ftv(struct type *r, int *nftv, char ***ftv)
+{
+ switch (r->type) {
+ case tarrow:
+ type_ftv(r->data.ttuple.l, nftv, ftv);
+ type_ftv(r->data.ttuple.r, nftv, ftv);
+ break;
+ case tbasic:
+ break;
+ case tlist:
+ type_ftv(r->data.tlist, nftv, ftv);
+ break;
+ case ttuple:
+ type_ftv(r->data.ttuple.l, nftv, ftv);
+ type_ftv(r->data.ttuple.r, nftv, ftv);
+ break;
+ case tvar:
+ *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
+ if (*ftv == NULL)
+ perror("realloc");
+ (*ftv)[(*nftv)++] = r->data.tvar;
+ break;
+ }
+}