work on type inference some more
[ccc.git] / type.c
diff --git a/type.c b/type.c
index f1828bc..8fc6102 100644 (file)
--- a/type.c
+++ b/type.c
-#include "ast.h"
+#include <string.h>
+#include <stdlib.h>
 
-struct vardecl *type_vardecl(struct vardecl *vardecl)
+#include "util.h"
+#include "type.h"
+
+static const char *basictype_str[] = {
+       [btbool] = "Bool", [btchar] = "Char", [btint] = "Int",
+       [btvoid] = "Void",
+};
+
+struct type *type_arrow(struct type *l, struct type *r)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tarrow;
+       res->data.tarrow.l = l;
+       res->data.tarrow.r = r;
+       return res;
+
+}
+
+struct type *type_basic(enum basictype type)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tbasic;
+       res->data.tbasic = type;
+       return res;
+}
+
+struct type *type_list(struct type *type)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = tlist;
+       res->data.tlist = type;
+       return res;
+}
+
+struct type *type_tuple(struct type *l, struct type *r)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       res->type = ttuple;
+       res->data.ttuple.l = l;
+       res->data.ttuple.r = r;
+       return res;
+}
+
+struct type *type_var(char *ident)
+{
+       struct type *res = safe_malloc(sizeof(struct type));
+       if (strcmp(ident, "Int") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btint;
+               free(ident);
+       } else if (strcmp(ident, "Char") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btchar;
+               free(ident);
+       } else if (strcmp(ident, "Bool") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btbool;
+               free(ident);
+       } else if (strcmp(ident, "Void") == 0) {
+               res->type = tbasic;
+               res->data.tbasic = btvoid;
+               free(ident);
+       } else {
+               res->type = tvar;
+               res->data.tvar = ident;
+       }
+       return res;
+}
+
+void type_print(struct type *type, FILE *out)
+{
+       if (type == NULL)
+               return;
+       switch (type->type) {
+       case tarrow:
+               safe_fprintf(out, "(");
+               type_print(type->data.tarrow.l, out);
+               safe_fprintf(out, "->");
+               type_print(type->data.tarrow.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case tbasic:
+               safe_fprintf(out, "%s", basictype_str[type->data.tbasic]);
+               break;
+       case tlist:
+               safe_fprintf(out, "[");
+               type_print(type->data.tlist, out);
+               safe_fprintf(out, "]");
+               break;
+       case ttuple:
+               safe_fprintf(out, "(");
+               type_print(type->data.ttuple.l, out);
+               safe_fprintf(out, ",");
+               type_print(type->data.ttuple.r, out);
+               safe_fprintf(out, ")");
+               break;
+       case tvar:
+               safe_fprintf(out, "%s", type->data.tvar);
+               break;
+       default:
+               die("Unsupported type node\n");
+       }
+}
+
+void type_free(struct type *type)
 {
-       return vardecl;
+       if (type == NULL)
+               return;
+       switch (type->type) {
+       case tarrow:
+               type_free(type->data.tarrow.l);
+               type_free(type->data.tarrow.r);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               type_free(type->data.tlist);
+               break;
+       case ttuple:
+               type_free(type->data.ttuple.l);
+               type_free(type->data.ttuple.r);
+               break;
+       case tvar:
+               free(type->data.tvar);
+               break;
+       default:
+               die("Unsupported type node\n");
+       }
+       free(type);
 }
 
-struct decl *type_decl(struct decl *decl)
+struct type *type_dup(struct type *r)
 {
-       switch (decl->type) {
-       case dcomponent:
-               fprintf(stderr, "type_decl:component unsupported\n");
+       struct type *res = safe_malloc(sizeof(struct type));
+       *res = *r;
+       switch (r->type) {
+       case tarrow:
+               res->data.tarrow.l = type_dup(r->data.tarrow.l);
+               res->data.tarrow.r = type_dup(r->data.tarrow.r);
+               break;
+       case tbasic:
                break;
-       case dfundecl:
-               fprintf(stderr, "type_decl:fundecl unsupported\n");
+       case tlist:
+               res->data.tlist = type_dup(r->data.tlist);
                break;
-       case dvardecl:
-               decl->data.dvar = type_vardecl(decl->data.dvar);
+       case ttuple:
+               res->data.ttuple.l = type_dup(r->data.ttuple.l);
+               res->data.ttuple.r = type_dup(r->data.ttuple.r);
+               break;
+       case tvar:
+               res->data.tvar = safe_strdup(r->data.tvar);
                break;
        }
-       return decl;
+       return res;
 }
 
-struct ast *type(struct ast *ast)
+void type_ftv(struct type *r, int *nftv, char ***ftv)
 {
-       for (int i = 0; i<ast->ndecls; i++)
-               ast->decls[i] = type_decl(ast->decls[i]);
-       return ast;
+       switch (r->type) {
+       case tarrow:
+               type_ftv(r->data.ttuple.l, nftv, ftv);
+               type_ftv(r->data.ttuple.r, nftv, ftv);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               type_ftv(r->data.tlist, nftv, ftv);
+               break;
+       case ttuple:
+               type_ftv(r->data.ttuple.l, nftv, ftv);
+               type_ftv(r->data.ttuple.r, nftv, ftv);
+               break;
+       case tvar:
+               *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *));
+               if (*ftv == NULL)
+                       perror("realloc");
+               (*ftv)[(*nftv)++] = r->data.tvar;
+               break;
+       }
 }