From: Mart Lubbers Date: Wed, 17 Feb 2021 10:43:04 +0000 (+0100) Subject: work on type inference some more X-Git-Url: https://git.martlubbers.net/?a=commitdiff_plain;ds=sidebyside;p=ccc.git work on type inference some more --- diff --git a/Makefile b/Makefile index 74272aa..5328897 100644 --- a/Makefile +++ b/Makefile @@ -1,9 +1,10 @@ -CFLAGS+=-Wall -Wextra -std=c99 -pedantic -D_XOPEN_SOURCE=700 -ggdb +CFLAGS+=-Wall -Wextra -std=c99 -pedantic -ggdb YFLAGS+=-d --locations -v --defines=parse.h LFLAGS+=--header-file=scan.h -OBJECTS:=scan.o parse.o ast.o util.o list.o sem.o genc.o \ - sem/scc.o sem/hm.o +OBJECTS:=scan.o parse.o ast.o type.o util.o list.o sem.o genc.o \ + sem/scc.o\ + $(addprefix sem/hm, .o /gamma.o /subst.o /scheme.o) all: splc splc: $(OBJECTS) @@ -11,5 +12,13 @@ scan.c: scan.l parse.h parse.h: parse.c expr.c: y.tab.h +scan.o: CFLAGS+=-D_XOPEN_SOURCE=700 + +.PHONY: test + +test: + CFLAGS="$(CFLAGS)" $(MAKE) -C test test + clean: $(RM) $(OBJECTS) y.output parse.h scan.h scan.c parse.c expr a.c + $(MAKE) -C test clean diff --git a/ast.c b/ast.c index bb8a541..1e93a71 100644 --- a/ast.c +++ b/ast.c @@ -4,6 +4,7 @@ #include "util.h" #include "ast.h" +#include "type.h" #include "list.h" #include "parse.h" @@ -16,10 +17,6 @@ const char *binop_str[] = { const char *fieldspec_str[] = { [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"}; const char *unop_str[] = { [inverse] = "!", [negate] = "-", }; -static const char *basictype_str[] = { - [btbool] = "Bool", [btchar] = "Char", [btint] = "Int", - [btvoid] = "Void", -}; struct ast *ast(struct list *decls) { @@ -237,57 +234,6 @@ struct expr *expr_unop(enum unop op, struct expr *l) return res; } -struct type *type_basic(enum basictype type) -{ - struct type *res = safe_malloc(sizeof(struct type)); - res->type = tbasic; - res->data.tbasic = type; - return res; -} - -struct type *type_list(struct type *type) -{ - struct type *res = safe_malloc(sizeof(struct type)); - res->type = tlist; - res->data.tlist = type; - return res; -} - -struct type *type_tuple(struct type *l, struct type *r) -{ - struct type *res = safe_malloc(sizeof(struct type)); - res->type = ttuple; - res->data.ttuple.l = l; - res->data.ttuple.r = r; - return res; -} - -struct type *type_var(char *ident) -{ - struct type *res = safe_malloc(sizeof(struct type)); - if (strcmp(ident, "Int") == 0) { - res->type = tbasic; - res->data.tbasic = btint; - free(ident); - } else if (strcmp(ident, "Char") == 0) { - res->type = tbasic; - res->data.tbasic = btchar; - free(ident); - } else if (strcmp(ident, "Bool") == 0) { - res->type = tbasic; - res->data.tbasic = btbool; - free(ident); - } else if (strcmp(ident, "Void") == 0) { - res->type = tbasic; - res->data.tbasic = btvoid; - free(ident); - } else { - res->type = tvar; - res->data.tvar = ident; - } - return res; -} - void ast_print(struct ast *ast, FILE *out) { if (ast == NULL) @@ -479,34 +425,6 @@ void expr_print(struct expr *expr, FILE *out) } } -void type_print(struct type *type, FILE *out) -{ - if (type == NULL) - return; - switch (type->type) { - case tbasic: - safe_fprintf(out, "%s", basictype_str[type->data.tbasic]); - break; - case tlist: - safe_fprintf(out, "["); - type_print(type->data.tlist, out); - safe_fprintf(out, "]"); - break; - case ttuple: - safe_fprintf(out, "("); - type_print(type->data.ttuple.l, out); - safe_fprintf(out, ","); - type_print(type->data.ttuple.r, out); - safe_fprintf(out, ")"); - break; - case tvar: - safe_fprintf(out, "%s", type->data.tvar); - break; - default: - die("Unsupported type node\n"); - } -} - void ast_free(struct ast *ast) { if (ast == NULL) @@ -648,26 +566,3 @@ void expr_free(struct expr *expr) } free(expr); } - -void type_free(struct type *type) -{ - if (type == NULL) - return; - switch (type->type) { - case tbasic: - break; - case tlist: - type_free(type->data.tlist); - break; - case ttuple: - type_free(type->data.ttuple.l); - type_free(type->data.ttuple.r); - break; - case tvar: - free(type->data.tvar); - break; - default: - die("Unsupported type node\n"); - } - free(type); -} diff --git a/ast.h b/ast.h index db69333..635d0c8 100644 --- a/ast.h +++ b/ast.h @@ -5,6 +5,7 @@ #include #include "util.h" +#include "type.h" struct ast; #include "parse.h" @@ -32,20 +33,6 @@ struct fundecl { struct stmt **body; }; -enum basictype {btbool, btchar, btint, btvoid}; -struct type { - enum {tbasic,tlist,ttuple,tvar} type; - union { - enum basictype tbasic; - struct type *tlist; - struct { - struct type *l; - struct type *r; - } ttuple; - char *tvar; - } data; -}; - struct decl { //NOTE: DON'T CHANGE THIS ORDER enum {dcomp, dvardecl, dfundecl} type; @@ -158,18 +145,12 @@ struct expr *expr_tuple(struct expr *left, struct expr *right); struct expr *expr_string(char *str); struct expr *expr_unop(enum unop op, struct expr *l); -struct type *type_basic(enum basictype type); -struct type *type_list(struct type *type); -struct type *type_tuple(struct type *l, struct type *r); -struct type *type_var(char *ident); - void ast_print(struct ast *, FILE *out); void vardecl_print(struct vardecl *decl, int indent, FILE *out); void fundecl_print(struct fundecl *decl, FILE *out); void decl_print(struct decl *ast, FILE *out); void stmt_print(struct stmt *ast, int indent, FILE *out); void expr_print(struct expr *ast, FILE *out); -void type_print(struct type *type, FILE *out); void ast_free(struct ast *); void vardecl_free(struct vardecl *decl); @@ -177,6 +158,5 @@ void fundecl_free(struct fundecl *fundecl); void decl_free(struct decl *ast); void stmt_free(struct stmt *ast); void expr_free(struct expr *ast); -void type_free(struct type *type); #endif diff --git a/sem.c b/sem.c index 1cea48f..3404fb4 100644 --- a/sem.c +++ b/sem.c @@ -3,6 +3,7 @@ #include "list.h" #include "sem/scc.h" +#include "sem/hm.h" #include "ast.h" void type_error(const char *msg, ...) @@ -34,38 +35,54 @@ void check_expr_constant(struct expr *expr) } } -struct vardecl *type_vardecl(struct vardecl *vardecl) +struct vardecl *type_vardecl(struct gamma *gamma, struct vardecl *vardecl) { - return vardecl; -} + struct type *t = vardecl->type == NULL + ? gamma_fresh(gamma) : type_dup(vardecl->type); + struct subst *s = infer_expr(gamma, vardecl->expr, t); -struct decl *type_decl(struct decl *decl) -{ - switch (decl->type) { - case dcomp: - fprintf(stderr, "type_decl:component unsupported\n"); - break; - case dfundecl: - fprintf(stderr, "type_decl:fundecl unsupported\n"); - break; - case dvardecl: - decl->data.dvar = type_vardecl(decl->data.dvar); - break; - } - return decl; + if (s == NULL) + die("error inferring variable\n"); + vardecl->type = subst_apply_t(s, t); + + //subst_free(s); + + return vardecl; } struct ast *sem(struct ast *ast) { ast = ast_scc(ast); - //Check that all globals are constant + struct gamma *gamma = gamma_init(); + + //Check all vardecls for (int i = 0; indecls; i++) { - if (ast->decls[i]->type == dvardecl) { - //Check globals + switch(ast->decls[i]->type) { + case dvardecl: + //Check if constant check_expr_constant(ast->decls[i]->data.dvar->expr); + //Infer if necessary + type_vardecl(gamma, ast->decls[i]->data.dvar); + break; + case dfundecl: { +// struct type *f1 = gamma_fresh(gamma); +// gamma_insert(gamma, ast->decls[i]->data.dfun->ident +// , scheme_create(f1)); +//infer env (Let [(x, e1)] e2) +// = fresh +// >>= \tv-> let env` = 'Data.Map'.put x (Forall [] tv) env +// in infer env` e1 +// >>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2 +// >>= \(s2, t2)->pure (s1 oo s2, t2) + break; + } + case dcomp: break; } } + + gamma_free(gamma); + return ast; } diff --git a/sem/hm.c b/sem/hm.c index d91715e..fb3d2ce 100644 --- a/sem/hm.c +++ b/sem/hm.c @@ -2,173 +2,38 @@ #include #include "hm.h" -#include "../util.h" +#include "hm/subst.h" +#include "hm/gamma.h" +#include "hm/scheme.h" #include "../ast.h" -struct gamma { - int fresh; - int nschemes; - char **vars; - struct scheme *schemes; -}; - -struct scheme *gamma_lookup(struct gamma *gamma, char *ident) -{ - for (int i = 0; ivars[i]) == 0) { - //TODO - } - } - return NULL; -} - -struct type *fresh(struct gamma *gamma) -{ - char *buf = safe_malloc(10); - sprintf(buf, "%d", gamma->fresh); - gamma->fresh++; - return type_var(buf); -} - -void ftv_type(struct type *r, int *nftv, char **ftv) -{ - switch (r->type) { - case tbasic: - break; - case tlist: - ftv_type(r->data.tlist, nftv, ftv); - break; - case ttuple: - ftv_type(r->data.ttuple.l, nftv, ftv); - ftv_type(r->data.ttuple.r, nftv, ftv); - break; - case tvar: - *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *)); - if (*ftv == NULL) - perror("realloc"); - ftv[(*nftv)++] = r->data.tvar; - break; - } -} - bool occurs_check(char *var, struct type *r) { int nftv = 0; - char *ftv = NULL; - ftv_type(r, &nftv, &ftv); + char **ftv = NULL; + type_ftv(r, &nftv, &ftv); for (int i = 0; itype) { - case tbasic: - break; - case tlist: - res->data.tlist = dup_type(r->data.tlist); - break; - case ttuple: - res->data.ttuple.l = dup_type(r->data.ttuple.l); - res->data.ttuple.r = dup_type(r->data.ttuple.r); - break; - case tvar: - res->data.tvar = strdup(r->data.tvar); - break; - } - return res; -} - -struct type *subst_apply_t(struct substitution *subst, struct type *l) +struct subst *unify(struct type *l, struct type *r) { + if (l == NULL || r == NULL) + return NULL; + if (r->type == tvar && l->type != tvar) + return unify(r, l); + struct subst *s1, *s2; switch (l->type) { - case tbasic: - break; - case tlist: - l->data.tlist = subst_apply_t(subst, l->data.tlist); - break; - case ttuple: - l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l); - l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r); - break; - case tvar: - for (int i = 0; invar; i++) { - if (strcmp(subst->vars[i], l->data.tvar) == 0) { - free(l->data.tvar); - struct type *r = subst->types[i]; - *l = *r; - free(r); - } + case tarrow: + if (r->type == tarrow) { + s1 = unify(l->data.tarrow.l, r->data.tarrow.l); + s2 = unify(subst_apply_t(s1, l->data.tarrow.l), + subst_apply_t(s1, r->data.tarrow.l)); + return subst_union(s1, s2); } break; - } - return l; -} -struct gamma *subst_apply_g(struct substitution *subst, struct gamma *gamma) -{ - //TODO - return gamma; -} - -void subst_print(struct substitution *s, FILE *out) -{ - if (s == NULL) { - fprintf(out, "no substitution\n"); - } else { - fprintf(out, "["); - for (int i = 0; invar; i++) { - fprintf(out, "%s->", s->vars[i]); - type_print(s->types[i], out); - if (i + 1 < s->nvar) - fprintf(out, ", "); - } - fprintf(out, "]\n"); - } -} - -struct substitution *subst_id() -{ - struct substitution *res = safe_malloc(sizeof(struct substitution)); - res->nvar = 0; - res->vars = NULL; - res->types = NULL; - return res; -} - -struct substitution *subst_singleton(char *ident, struct type *t) -{ - struct substitution *res = safe_malloc(sizeof(struct substitution)); - res->nvar = 1; - res->vars = safe_malloc(sizeof(char *)); - res->vars[0] = safe_strdup(ident); - res->types = safe_malloc(sizeof(struct type *)); - res->types[0] = dup_type(t); - return res; -} - -struct substitution *subst_union(struct substitution *l, struct substitution *r) -{ - struct substitution *res = safe_malloc(sizeof(struct substitution)); - res->nvar = l->nvar+r->nvar; - res->vars = safe_malloc(res->nvar*sizeof(char *)); - res->types = safe_malloc(res->nvar*sizeof(struct type *)); - for (int i = 0; invar; i++) { - res->vars[i] = l->vars[i]; - res->types[i] = l->types[i]; - } - for (int i = 0; invar; i++) { - res->vars[l->nvar+i] = l->vars[i]; - res->types[l->nvar+i] = subst_apply_t(l, r->types[i]); - } - return res; -} - -struct substitution *unify(struct type *l, struct type *r) { - switch (l->type) { case tbasic: if (r->type == tbasic && l->data.tbasic == r->data.tbasic) return subst_id(); @@ -179,11 +44,8 @@ struct substitution *unify(struct type *l, struct type *r) { break; case ttuple: if (r->type == ttuple) { - struct substitution *s1 = unify( - l->data.ttuple.l, - r->data.ttuple.l); - struct substitution *s2 = unify( - subst_apply_t(s1, l->data.ttuple.l), + s1 = unify(l->data.ttuple.l, r->data.ttuple.l); + s2 = unify(subst_apply_t(s1, l->data.ttuple.l), subst_apply_t(s1, r->data.ttuple.l)); return subst_union(s1, s2); } @@ -197,11 +59,23 @@ struct substitution *unify(struct type *l, struct type *r) { return subst_singleton(l->data.tvar, r); break; } + fprintf(stderr, "cannot unify "); + type_print(l, stderr); + fprintf(stderr, " with "); + type_print(r, stderr); + fprintf(stderr, "\n"); return NULL; } -struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type) +struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type) { + fprintf(stderr, "infer expr: "); + expr_print(expr, stderr); + fprintf(stderr, "\ngamma: "); + gamma_print(gamma, stderr); + fprintf(stderr, "\ntype: "); + type_print(type, stderr); + fprintf(stderr, "\n"); #define infbop(l, r, a1, a2, rt, sigma) {\ s1 = infer_expr(gamma, l, a1);\ @@ -210,13 +84,14 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t } #define infbinop(e, a1, a2, rt, sigma)\ infbop(e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, sigma) - struct substitution *s1, *s2; + struct subst *s1, *s2; struct type *f1, *f2; + struct scheme *s; switch (expr->type) { case ebool: return unify(type_basic(btbool), type); case ebinop: - switch(expr->data.ebinop.op) { + switch (expr->data.ebinop.op) { case binor: case binand: infbinop(expr, type_basic(btbool), type_basic(btbool), @@ -227,10 +102,10 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t case le: case geq: case ge: - f1 = fresh(gamma); + f1 = gamma_fresh(gamma); infbinop(expr, f1, f1, type_basic(btbool), type); case cons: - f1 = fresh(gamma); + f1 = gamma_fresh(gamma); infbinop(expr, f1, type_list(f1), type_list(f1), type); case plus: case minus: @@ -241,20 +116,28 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t infbinop(expr, type_basic(btint), type_basic(btint), type_basic(btint), type); } + break; case echar: return unify(type_basic(btchar), type); case efuncall: + if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL) + die("Unbound function: %s\n", expr->data.efuncall.ident); //TODO + //TODO fields return NULL; case eint: return unify(type_basic(btint), type); case eident: - - //TODO - return NULL; + if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL) + die("Unbound variable: %s\n", expr->data.eident.ident); + //TODO fields + return unify(scheme_instantiate(gamma, s), type); + case enil: + f1 = gamma_fresh(gamma); + return unify(type_list(f1), type); case etuple: - f1 = fresh(gamma); - f2 = fresh(gamma); + f1 = gamma_fresh(gamma); + f2 = gamma_fresh(gamma); infbop(expr->data.etuple.left, expr->data.etuple.right, f1, f2, type_tuple(f1, f2), type); case estring: @@ -264,6 +147,8 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t case negate: s1 = infer_expr(gamma, expr->data.eunop.l, type_basic(btint)); + if (s1 == NULL) + return NULL; return subst_union(s1, unify(subst_apply_t(s1, type), type_basic(btint))); @@ -275,5 +160,5 @@ struct substitution *infer_expr(struct gamma *gamma, struct expr *expr, struct t type_basic(btbool))); } } - + return NULL; } diff --git a/sem/hm.h b/sem/hm.h index efb1afa..6792c3b 100644 --- a/sem/hm.h +++ b/sem/hm.h @@ -2,19 +2,11 @@ #define SEM_HM_C #include "../ast.h" - -struct scheme { - struct type *type; - int nvar; - char **var; -}; - -struct substitution { - int nvar; - char **vars; - struct type **types; -}; +#include "hm/gamma.h" +#include "hm/subst.h" +#include "hm/scheme.h" struct ast *infer(struct ast *ast); +struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type); #endif diff --git a/sem/hm/gamma.c b/sem/hm/gamma.c new file mode 100644 index 0000000..9a97c27 --- /dev/null +++ b/sem/hm/gamma.c @@ -0,0 +1,62 @@ +#include +#include + +#include "../hm.h" + +struct gamma *gamma_init() +{ + struct gamma *gamma = safe_malloc(sizeof(struct gamma)); + gamma->fresh = 0; + gamma->nschemes = 0; + gamma->vars = NULL; + gamma->schemes = NULL; + return gamma; +} + +void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme) +{ + gamma->nschemes++; + gamma->vars = realloc(gamma->vars, gamma->nschemes*sizeof(char *)); + gamma->schemes = realloc(gamma->schemes, + gamma->nschemes*sizeof(struct scheme *)); + gamma->vars[gamma->nschemes-1] = safe_strdup(ident); + gamma->schemes[gamma->nschemes-1] = scheme; +} + +struct scheme *gamma_lookup(struct gamma *gamma, char *ident) +{ + for (int i = 0; inschemes; i++) + if (strcmp(ident, gamma->vars[i]) == 0) + return gamma->schemes[i]; + return NULL; +} + +struct type *gamma_fresh(struct gamma *gamma) +{ + char buf[10] = {0}; + sprintf(buf, "%d", gamma->fresh++); + return type_var(safe_strdup(buf)); +} + +void gamma_print(struct gamma *gamma, FILE *out) +{ + fprintf(out, "{"); + for (int i = 0; inschemes; i++) { + fprintf(out, "%s=", gamma->vars[i]); + scheme_print(gamma->schemes[i], out); + if (i + 1 < gamma->nschemes) + fprintf(out, ", "); + } + fprintf(out, "}"); +} + +void gamma_free(struct gamma *gamma) +{ + for (int i = 0; inschemes; i++) { + free(gamma->vars[i]); + scheme_free(gamma->schemes[i]); + } + free(gamma->vars); + free(gamma->schemes); + free(gamma); +} diff --git a/sem/hm/gamma.h b/sem/hm/gamma.h new file mode 100644 index 0000000..8499144 --- /dev/null +++ b/sem/hm/gamma.h @@ -0,0 +1,24 @@ +#ifndef SEM_HM_GAMMA_H +#define SEM_HM_GAMMA_H + +#include + +#include "../hm.h" + +struct gamma { + int fresh; + int nschemes; + char **vars; + struct scheme **schemes; +}; + +struct gamma *gamma_init(); +void gamma_insert(struct gamma *gamma, char *ident, struct scheme *scheme); + +struct scheme *gamma_lookup(struct gamma *gamma, char *ident); +struct type *gamma_fresh(struct gamma *gamma); + +void gamma_print(struct gamma *gamma, FILE *out); +void gamma_free(struct gamma *gamma); + +#endif diff --git a/sem/hm/scheme.c b/sem/hm/scheme.c new file mode 100644 index 0000000..177bb64 --- /dev/null +++ b/sem/hm/scheme.c @@ -0,0 +1,68 @@ +#include +#include + +#include "../hm.h" + +struct type *scheme_instantiate(struct gamma *gamma, struct scheme *sch) +{ + struct subst *s = subst_id(); + for (int i = 0; invar; i++) { + s = subst_union(s, subst_singleton(sch->var[i], gamma_fresh(gamma))); + } + struct type *t = subst_apply_t(s, type_dup(sch->type)); + for (int i = 0; invar; i++) + free(s->vars[i]); + free(s); + return t; +} + +struct scheme *scheme_create(struct type *t) +{ + struct scheme *s = safe_malloc(sizeof(struct scheme)); + s->type = t; + s->nvar = 0; + s->var = NULL; +} + +struct scheme *scheme_generalise(struct gamma *gamma, struct type *t) +{ + struct scheme *s = safe_malloc(sizeof(struct scheme)); + int nftv = 0; + char **ftv = NULL; + type_ftv(t, &nftv, &ftv); + + s->type = t; + s->nvar = 0; + s->var = safe_malloc(nftv*sizeof(char *)); + for (int i = 0; inschemes; j++) + if (strcmp(gamma->vars[j], ftv[i]) == 0) + skip = true; + if (skip) + continue; + s->nvar++; + s->var[i] = ftv[i]; + } + return s; +} + +void scheme_print(struct scheme *scheme, FILE *out) +{ + if (scheme->nvar > 0) { + fprintf(out, "A."); + for (int i = 0; invar; i++) + fprintf(out, "%s", scheme->var[i]); + fprintf(out, ": "); + } + type_print(scheme->type, out); +} + +void scheme_free(struct scheme *scheme) +{ + type_free(scheme->type); + for (int i = 0; invar; i++) + free(scheme->var[i]); + free(scheme->var); + free(scheme); +} diff --git a/sem/hm/scheme.h b/sem/hm/scheme.h new file mode 100644 index 0000000..eaab3fc --- /dev/null +++ b/sem/hm/scheme.h @@ -0,0 +1,19 @@ +#ifndef SEM_HM_SCHEME_H +#define SEM_HM_SCHEME_H + +#include "../hm.h" + +struct scheme { + struct type *type; + int nvar; + char **var; +}; + +struct type *scheme_instantiate(struct gamma *gamma, struct scheme *s); +struct scheme *scheme_create(struct type *t); +struct scheme *scheme_generalise(struct gamma *gamma, struct type *t); + +void scheme_print(struct scheme *scheme, FILE *out); +void scheme_free(struct scheme *scheme); + +#endif diff --git a/sem/hm/subst.c b/sem/hm/subst.c new file mode 100644 index 0000000..55d6cfe --- /dev/null +++ b/sem/hm/subst.c @@ -0,0 +1,124 @@ +#include +#include + +#include "../hm.h" + +struct subst *subst_id() +{ + struct subst *res = safe_malloc(sizeof(struct subst)); + res->nvar = 0; + res->vars = NULL; + res->types = NULL; + return res; +} + +struct subst *subst_singleton(char *ident, struct type *t) +{ + struct subst *res = safe_malloc(sizeof(struct subst)); + res->nvar = 1; + res->vars = safe_malloc(sizeof(char *)); + res->vars[0] = safe_strdup(ident); + res->types = safe_malloc(sizeof(struct type *)); + res->types[0] = type_dup(t); + return res; +} + +struct subst *subst_union(struct subst *l, struct subst *r) +{ + if (l == NULL || r == NULL) + return NULL; + struct subst *res = safe_malloc(sizeof(struct subst)); + res->nvar = l->nvar+r->nvar; + res->vars = safe_malloc(res->nvar*sizeof(char *)); + res->types = safe_malloc(res->nvar*sizeof(struct type *)); + for (int i = 0; invar; i++) { + res->vars[i] = l->vars[i]; + res->types[i] = l->types[i]; + } + for (int i = 0; invar; i++) { + res->vars[l->nvar+i] = r->vars[i]; + res->types[l->nvar+i] = subst_apply_t(l, r->types[i]); + } + return res; +} + +struct type *subst_apply_t(struct subst *subst, struct type *l) +{ + if (subst == NULL) + return l; + switch (l->type) { + case tarrow: + l->data.tarrow.l = subst_apply_t(subst, l->data.tarrow.l); + l->data.tarrow.r = subst_apply_t(subst, l->data.tarrow.r); + break; + case tbasic: + break; + case tlist: + l->data.tlist = subst_apply_t(subst, l->data.tlist); + break; + case ttuple: + l->data.ttuple.l = subst_apply_t(subst, l->data.ttuple.l); + l->data.ttuple.r = subst_apply_t(subst, l->data.ttuple.r); + break; + case tvar: + for (int i = 0; invar; i++) { + if (strcmp(subst->vars[i], l->data.tvar) == 0) { + free(l->data.tvar); + struct type *r = type_dup(subst->types[i]); + *l = *r; + free(r); + break; + } + } + break; + } + return l; +} + +struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme) +{ + for (int i = 0; invar; i++) { + for (int j = 0; jnvar; j++) { + if (strcmp(scheme->var[i], subst->vars[j]) != 0) { + struct subst *t = subst_singleton( + subst->vars[j], subst->types[j]); + scheme->type = subst_apply_t(t, scheme->type); + } + } + } + return scheme; +} + +struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma) +{ + for (int i = 0; inschemes; i++) + subst_apply_s(subst, gamma->schemes[i]); + return gamma; +} + +void subst_print(struct subst *s, FILE *out) +{ + if (s == NULL) { + fprintf(out, "no subst\n"); + } else { + fprintf(out, "["); + for (int i = 0; invar; i++) { + fprintf(out, "%s->", s->vars[i]); + type_print(s->types[i], out); + if (i + 1 < s->nvar) + fprintf(out, ", "); + } + fprintf(out, "]\n"); + } +} + +void subst_free(struct subst *s, bool type) +{ + if (s != NULL) { + for (int i = 0; invar; i++) { + free(s->vars[i]); + if (type) + type_free(s->types[i]); + } + } +} diff --git a/sem/hm/subst.h b/sem/hm/subst.h new file mode 100644 index 0000000..ffec9a7 --- /dev/null +++ b/sem/hm/subst.h @@ -0,0 +1,24 @@ +#ifndef SEM_HM_SUBST_H +#define SEM_HM_SUBST_H + +#include "../../ast.h" +#include "../hm.h" + +struct subst { + int nvar; + char **vars; + struct type **types; +}; + +struct subst *subst_id(); +struct subst *subst_singleton(char *ident, struct type *t); +struct subst *subst_union(struct subst *l, struct subst *r); + +struct type *subst_apply_t(struct subst *subst, struct type *l); +struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme); +struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma); + +void subst_print(struct subst *s, FILE *out); +void subst_free(struct subst *s, bool type); + +#endif diff --git a/sem/scc.c b/sem/scc.c index fa47590..6d7b02e 100644 --- a/sem/scc.c +++ b/sem/scc.c @@ -271,10 +271,11 @@ struct ast *ast_scc(struct ast *ast) struct edge **edata = (struct edge **) list_to_array(edges, &nedges, false); + fprintf(stderr, "nfun: %d, ffun: %d, nedges: %d\n", nfun, ffun, nedges); // Do tarjan's and convert back into the declaration list struct components *cs = tarjans(nfun, (void **)fundecls, nedges, edata); - if (cs == NULL) - die("malformed edges in tarjan's????"); +// if (cs == NULL) +// die("malformed edges in tarjan's????"); int i = ffun; for (struct components *c = cs; c != NULL; c = c->next) { diff --git a/test/Makefile b/test/Makefile new file mode 100644 index 0000000..faaff6a --- /dev/null +++ b/test/Makefile @@ -0,0 +1,19 @@ +TESTOBJECTS:=$(patsubst %.c,%.o,$(wildcard *.c)) +TESTS:=$(patsubst %.o,%,$(TESTOBJECTS)) + +LDLIBS+=$(shell pkg-config --libs check) + +.PHONY: test + +test_sem_hm_gamma.o: CFLAGS+=$(shell pkg-config --cflags check) + +test_sem_hm_gamma: test_sem_hm_gamma.o $(addprefix ../sem/hm/,gamma.o scheme.o subst.o) ../util.o ../type.o + +test: $(TESTS) + $(foreach f,$^,./$(f);) + +clean: + $(RM) $(TESTOBJECTS) $(TESTS) +ifeq ($(MAKELEVEL), 0) + $(MAKE) -C ../ clean +endif diff --git a/test/test_sem_hm_gamma.c b/test/test_sem_hm_gamma.c new file mode 100644 index 0000000..1665031 --- /dev/null +++ b/test/test_sem_hm_gamma.c @@ -0,0 +1,97 @@ +#include +#include +#include + +#include "../sem/hm/gamma.h" +#include "../sem/hm/subst.h" +#include "../sem/hm/scheme.h" + +START_TEST(test_gamma_lookup) +{ + struct gamma *gamma = gamma_init(); + + ck_assert_ptr_null(gamma_lookup(gamma, "fun")); + ck_assert_ptr_null(gamma_lookup(gamma, "fun2")); + + gamma_insert(gamma, "fun", scheme_generalise(gamma, type_basic(btint))); + + ck_assert_ptr_nonnull(gamma_lookup(gamma, "fun")); + ck_assert_ptr_null(gamma_lookup(gamma, "fun2")); + + struct type *t1 = gamma_fresh(gamma); + ck_assert(t1->type == tvar); + struct type *t2 = gamma_fresh(gamma); + ck_assert(t2->type == tvar); + struct type *t3 = gamma_fresh(gamma); + ck_assert(t3->type == tvar); + struct type *t4 = gamma_fresh(gamma); + ck_assert(t4->type == tvar); + + ck_assert_str_ne(t1->data.tvar, t2->data.tvar); + ck_assert_str_ne(t2->data.tvar, t3->data.tvar); + ck_assert_str_ne(t3->data.tvar, t4->data.tvar); +} +END_TEST + +START_TEST(test_scheme) +{ + struct gamma *gamma = gamma_init(); + + char **var = malloc(sizeof(char *)); + var[0] = safe_strdup("a"); + struct scheme scheme = {.type=type_var("a"), .nvar=1, .var=var}; + + struct type *t = scheme_instantiate(gamma, &scheme); + ck_assert(t->type == tvar); + ck_assert_str_eq(t->data.tvar, "0"); + + scheme.type = type_list(type_var("a")); + t = scheme_instantiate(gamma, &scheme); + ck_assert(t->type == tlist); + ck_assert(t->data.tlist->type == tvar); + ck_assert_str_eq(t->data.tlist->data.tvar, "1"); +} +END_TEST + +START_TEST(test_subst) +{ + struct subst *s1 = subst_id(); + ck_assert_int_eq(0, s1->nvar); + s1 = subst_singleton("i1", type_basic(btint)); + ck_assert_int_eq(1, s1->nvar); + s1 = subst_union(subst_id(), subst_singleton("i1", type_basic(btint))); + ck_assert_int_eq(1, s1->nvar); + s1 = subst_union(subst_singleton("i2", type_basic(btbool)), + subst_singleton("i1", type_basic(btint))); + ck_assert_int_eq(2, s1->nvar); + +} +END_TEST + +Suite *util_suite(void) +{ + Suite *s = suite_create("List"); + + TCase *tc_gamma = tcase_create("Gamma lookup"); + tcase_add_test(tc_gamma, test_gamma_lookup); + tcase_add_test(tc_gamma, test_scheme); + tcase_add_test(tc_gamma, test_subst); + suite_add_tcase(s, tc_gamma); + + return s; +} + +int main(void) +{ + int failed; + Suite *s; + SRunner *sr; + + s = util_suite(); + sr = srunner_create(s); + + srunner_run_all(sr, CK_NORMAL); + failed = srunner_ntests_failed(sr); + srunner_free(sr); + return (failed == 0) ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/type.c b/type.c new file mode 100644 index 0000000..8fc6102 --- /dev/null +++ b/type.c @@ -0,0 +1,183 @@ +#include +#include + +#include "util.h" +#include "type.h" + +static const char *basictype_str[] = { + [btbool] = "Bool", [btchar] = "Char", [btint] = "Int", + [btvoid] = "Void", +}; + +struct type *type_arrow(struct type *l, struct type *r) +{ + struct type *res = safe_malloc(sizeof(struct type)); + res->type = tarrow; + res->data.tarrow.l = l; + res->data.tarrow.r = r; + return res; + +} + +struct type *type_basic(enum basictype type) +{ + struct type *res = safe_malloc(sizeof(struct type)); + res->type = tbasic; + res->data.tbasic = type; + return res; +} + +struct type *type_list(struct type *type) +{ + struct type *res = safe_malloc(sizeof(struct type)); + res->type = tlist; + res->data.tlist = type; + return res; +} + +struct type *type_tuple(struct type *l, struct type *r) +{ + struct type *res = safe_malloc(sizeof(struct type)); + res->type = ttuple; + res->data.ttuple.l = l; + res->data.ttuple.r = r; + return res; +} + +struct type *type_var(char *ident) +{ + struct type *res = safe_malloc(sizeof(struct type)); + if (strcmp(ident, "Int") == 0) { + res->type = tbasic; + res->data.tbasic = btint; + free(ident); + } else if (strcmp(ident, "Char") == 0) { + res->type = tbasic; + res->data.tbasic = btchar; + free(ident); + } else if (strcmp(ident, "Bool") == 0) { + res->type = tbasic; + res->data.tbasic = btbool; + free(ident); + } else if (strcmp(ident, "Void") == 0) { + res->type = tbasic; + res->data.tbasic = btvoid; + free(ident); + } else { + res->type = tvar; + res->data.tvar = ident; + } + return res; +} + +void type_print(struct type *type, FILE *out) +{ + if (type == NULL) + return; + switch (type->type) { + case tarrow: + safe_fprintf(out, "("); + type_print(type->data.tarrow.l, out); + safe_fprintf(out, "->"); + type_print(type->data.tarrow.r, out); + safe_fprintf(out, ")"); + break; + case tbasic: + safe_fprintf(out, "%s", basictype_str[type->data.tbasic]); + break; + case tlist: + safe_fprintf(out, "["); + type_print(type->data.tlist, out); + safe_fprintf(out, "]"); + break; + case ttuple: + safe_fprintf(out, "("); + type_print(type->data.ttuple.l, out); + safe_fprintf(out, ","); + type_print(type->data.ttuple.r, out); + safe_fprintf(out, ")"); + break; + case tvar: + safe_fprintf(out, "%s", type->data.tvar); + break; + default: + die("Unsupported type node\n"); + } +} + +void type_free(struct type *type) +{ + if (type == NULL) + return; + switch (type->type) { + case tarrow: + type_free(type->data.tarrow.l); + type_free(type->data.tarrow.r); + break; + case tbasic: + break; + case tlist: + type_free(type->data.tlist); + break; + case ttuple: + type_free(type->data.ttuple.l); + type_free(type->data.ttuple.r); + break; + case tvar: + free(type->data.tvar); + break; + default: + die("Unsupported type node\n"); + } + free(type); +} + +struct type *type_dup(struct type *r) +{ + struct type *res = safe_malloc(sizeof(struct type)); + *res = *r; + switch (r->type) { + case tarrow: + res->data.tarrow.l = type_dup(r->data.tarrow.l); + res->data.tarrow.r = type_dup(r->data.tarrow.r); + break; + case tbasic: + break; + case tlist: + res->data.tlist = type_dup(r->data.tlist); + break; + case ttuple: + res->data.ttuple.l = type_dup(r->data.ttuple.l); + res->data.ttuple.r = type_dup(r->data.ttuple.r); + break; + case tvar: + res->data.tvar = safe_strdup(r->data.tvar); + break; + } + return res; +} + +void type_ftv(struct type *r, int *nftv, char ***ftv) +{ + switch (r->type) { + case tarrow: + type_ftv(r->data.ttuple.l, nftv, ftv); + type_ftv(r->data.ttuple.r, nftv, ftv); + break; + case tbasic: + break; + case tlist: + type_ftv(r->data.tlist, nftv, ftv); + break; + case ttuple: + type_ftv(r->data.ttuple.l, nftv, ftv); + type_ftv(r->data.ttuple.r, nftv, ftv); + break; + case tvar: + *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *)); + if (*ftv == NULL) + perror("realloc"); + (*ftv)[(*nftv)++] = r->data.tvar; + break; + } +} diff --git a/type.h b/type.h new file mode 100644 index 0000000..0bbb2e0 --- /dev/null +++ b/type.h @@ -0,0 +1,36 @@ +#ifndef TYPE_H +#define TYPE_H + +#include + +enum basictype {btbool, btchar, btint, btvoid}; +struct type { + enum {tarrow,tbasic,tlist,ttuple,tvar} type; + union { + struct { + struct type *l; + struct type *r; + } tarrow; + enum basictype tbasic; + struct type *tlist; + struct { + struct type *l; + struct type *r; + } ttuple; + char *tvar; + } data; +}; + +struct type *type_arrow(struct type *l, struct type *r); +struct type *type_basic(enum basictype type); +struct type *type_list(struct type *type); +struct type *type_tuple(struct type *l, struct type *r); +struct type *type_var(char *ident); + +void type_print(struct type *type, FILE *out); +void type_free(struct type *type); + +struct type *type_dup(struct type *t); +void type_ftv(struct type *r, int *nftv, char ***ftv); + +#endif diff --git a/util.c b/util.c index efe22c7..fa60ddc 100644 --- a/util.c +++ b/util.c @@ -154,9 +154,11 @@ void *safe_malloc(size_t size) void *safe_strdup(const char *c) { - char *res = strdup(c); + size_t nchar = strlen(c); + char *res = malloc((nchar+1)*sizeof(char)); if (res == NULL) pdie("strdup"); + memcpy(res, c, nchar+1); return res; }