work on type inference some more
[ccc.git] / sem / hm / subst.c
diff --git a/sem/hm/subst.c b/sem/hm/subst.c
new file mode 100644 (file)
index 0000000..55d6cfe
--- /dev/null
@@ -0,0 +1,124 @@
+#include <string.h>
+#include <stdlib.h>
+
+#include "../hm.h"
+
+struct subst *subst_id()
+{
+       struct subst *res = safe_malloc(sizeof(struct subst));
+       res->nvar = 0;
+       res->vars = NULL;
+       res->types = NULL;
+       return res;
+}
+
+struct subst *subst_singleton(char *ident, struct type *t)
+{
+       struct subst *res = safe_malloc(sizeof(struct subst));
+       res->nvar = 1;
+       res->vars = safe_malloc(sizeof(char *));
+       res->vars[0] = safe_strdup(ident);
+       res->types = safe_malloc(sizeof(struct type *));
+       res->types[0] = type_dup(t);
+       return res;
+}
+
+struct subst *subst_union(struct subst *l, struct subst *r)
+{
+       if (l == NULL || r == NULL)
+               return NULL;
+       struct subst *res = safe_malloc(sizeof(struct subst));
+       res->nvar = l->nvar+r->nvar;
+       res->vars = safe_malloc(res->nvar*sizeof(char *));
+       res->types = safe_malloc(res->nvar*sizeof(struct type *));
+       for (int i = 0; i<l->nvar; i++) {
+               res->vars[i] = l->vars[i];
+               res->types[i] = l->types[i];
+       }
+       for (int i = 0; i<r->nvar; i++) {
+               res->vars[l->nvar+i] = r->vars[i];
+               res->types[l->nvar+i] = subst_apply_t(l, r->types[i]);
+       }
+       return res;
+}
+
+struct type *subst_apply_t(struct subst *subst, struct type *l)
+{
+       if (subst == NULL)
+               return l;
+       switch (l->type) {
+       case tarrow:
+               l->data.tarrow.l = subst_apply_t(subst, l->data.tarrow.l);
+               l->data.tarrow.r = subst_apply_t(subst, l->data.tarrow.r);
+               break;
+       case tbasic:
+               break;
+       case tlist:
+               l->data.tlist = subst_apply_t(subst, l->data.tlist);
+               break;
+       case ttuple:
+               l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l);
+               l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r);
+               break;
+       case tvar:
+               for (int i = 0; i<subst->nvar; i++) {
+                       if (strcmp(subst->vars[i], l->data.tvar) == 0) {
+                               free(l->data.tvar);
+                               struct type *r = type_dup(subst->types[i]);
+                               *l = *r;
+                               free(r);
+                               break;
+                       }
+               }
+               break;
+       }
+       return l;
+}
+
+struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme)
+{
+       for (int i = 0; i<scheme->nvar; i++) {
+               for (int j = 0; j<subst->nvar; j++) {
+                       if (strcmp(scheme->var[i], subst->vars[j]) != 0) {
+                               struct subst *t = subst_singleton(
+                                       subst->vars[j], subst->types[j]);
+                               scheme->type = subst_apply_t(t, scheme->type);
+                       }
+               }
+       }
+       return scheme;
+}
+
+struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma)
+{
+       for (int i = 0; i<gamma->nschemes; i++)
+               subst_apply_s(subst, gamma->schemes[i]);
+       return gamma;
+}
+
+void subst_print(struct subst *s, FILE *out)
+{
+       if (s == NULL) {
+               fprintf(out, "no subst\n");
+       } else {
+               fprintf(out, "[");
+               for (int i = 0; i<s->nvar; i++) {
+                       fprintf(out, "%s->", s->vars[i]);
+                       type_print(s->types[i], out);
+                       if (i + 1 < s->nvar)
+                               fprintf(out, ", ");
+               }
+               fprintf(out, "]\n");
+       }
+}
+
+void subst_free(struct subst *s, bool type)
+{
+       if (s != NULL) {
+               for (int i = 0; i<s->nvar; i++) {
+                       free(s->vars[i]);
+                       if (type)
+                               type_free(s->types[i]);
+               }
+       }
+}