From f9c160e32ee3b50e71de32b91ed225b7059001bd Mon Sep 17 00:00:00 2001 From: Mart Lubbers Date: Wed, 24 Feb 2021 09:40:17 +0100 Subject: [PATCH] locations, type checking --- ast.c | 74 +++++++++++++++++--------- ast.h | 65 +++++++++++----------- genc.c | 3 ++ in.txt | 1 + parse.y | 72 ++++++++++++------------- scan.l | 10 ++-- sem.c | 25 +++++---- sem.h | 3 ++ sem/hm.c | 139 ++++++++++++++++++++++++++---------------------- sem/hm.h | 2 +- sem/hm/scheme.c | 8 ++- type.c | 12 +++-- type.h | 5 +- 13 files changed, 237 insertions(+), 182 deletions(-) create mode 100644 in.txt diff --git a/ast.c b/ast.c index 1e93a71..88083bd 100644 --- a/ast.c +++ b/ast.c @@ -18,25 +18,29 @@ const char *fieldspec_str[] = { [fst] = "fst", [snd] = "snd", [hd] = "hd", [tl] = "tl"}; const char *unop_str[] = { [inverse] = "!", [negate] = "-", }; -struct ast *ast(struct list *decls) +struct ast *ast(struct list *decls, YYLTYPE l) { struct ast *res = safe_malloc(sizeof(struct ast)); + res->loc = l; + res->loc = l; res->decls = (struct decl **)list_to_array(decls, &res->ndecls, true); return res; } -struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr) +struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr, YYLTYPE l) { struct vardecl *res = safe_malloc(sizeof(struct vardecl)); + res->loc = l; res->type = type; res->ident = ident; res->expr = expr; return res; } struct fundecl *fundecl(char *ident, struct list *args, struct list *atypes, - struct type *rtype, struct list *body) + struct type *rtype, struct list *body, YYLTYPE l) { struct fundecl *res = safe_malloc(sizeof(struct fundecl)); + res->loc = l; res->ident = ident; res->args = (char **)list_to_array(args, &res->nargs, true); res->atypes = (struct type **)list_to_array(atypes, &res->natypes, true); @@ -45,25 +49,30 @@ struct fundecl *fundecl(char *ident, struct list *args, struct list *atypes, return res; } -struct decl *decl_fun(struct fundecl *fundecl) +struct decl *decl_fun(struct fundecl *fundecl, YYLTYPE l) { struct decl *res = safe_malloc(sizeof(struct decl)); + res->loc = l; + res->loc = l; res->type = dfundecl; res->data.dfun = fundecl; return res; } -struct decl *decl_var(struct vardecl *vardecl) +struct decl *decl_var(struct vardecl *vardecl, YYLTYPE l) { struct decl *res = safe_malloc(sizeof(struct decl)); + res->loc = l; + res->loc = l; res->type = dvardecl; res->data.dvar = vardecl; return res; } -struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr) +struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr, YYLTYPE l) { struct stmt *res = safe_malloc(sizeof(struct stmt)); + res->loc = l; res->type = sassign; res->data.sassign.ident = ident; res->data.sassign.fields = (char **) @@ -72,9 +81,10 @@ struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr) return res; } -struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els) +struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els, YYLTYPE l) { struct stmt *res = safe_malloc(sizeof(struct stmt)); + res->loc = l; res->type = sif; res->data.sif.pred = pred; res->data.sif.then = (struct stmt **) @@ -84,33 +94,37 @@ struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els) return res; } -struct stmt *stmt_return(struct expr *rtrn) +struct stmt *stmt_return(struct expr *rtrn, YYLTYPE l) { struct stmt *res = safe_malloc(sizeof(struct stmt)); + res->loc = l; res->type = sreturn; res->data.sreturn = rtrn; return res; } -struct stmt *stmt_expr(struct expr *expr) +struct stmt *stmt_expr(struct expr *expr, YYLTYPE l) { struct stmt *res = safe_malloc(sizeof(struct stmt)); + res->loc = l; res->type = sexpr; res->data.sexpr = expr; return res; } -struct stmt *stmt_vardecl(struct vardecl *vardecl) +struct stmt *stmt_vardecl(struct vardecl *vardecl, YYLTYPE l) { struct stmt *res = safe_malloc(sizeof(struct stmt)); + res->loc = l; res->type = svardecl; res->data.svardecl = vardecl; return res; } -struct stmt *stmt_while(struct expr *pred, struct list *body) +struct stmt *stmt_while(struct expr *pred, struct list *body, YYLTYPE l) { struct stmt *res = safe_malloc(sizeof(struct stmt)); + res->loc = l; res->type = swhile; res->data.swhile.pred = pred; res->data.swhile.body = (struct stmt **) @@ -118,27 +132,30 @@ struct stmt *stmt_while(struct expr *pred, struct list *body) return res; } -struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r) +struct expr *expr_binop(struct expr *left, enum binop op, struct expr *right, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = ebinop; - res->data.ebinop.l = l; + res->data.ebinop.l = left; res->data.ebinop.op = op; - res->data.ebinop.r = r; + res->data.ebinop.r = right; return res; } -struct expr *expr_bool(bool b) +struct expr *expr_bool(bool b, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = ebool; res->data.ebool = b; return res; } -struct expr *expr_char(char *c) +struct expr *expr_char(char *c, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = echar; res->data.echar = unescape_char(c)[0]; return res; @@ -163,10 +180,10 @@ static void set_fields(enum fieldspec **farray, int *n, struct list *fields) free(els); } - -struct expr *expr_funcall(char *ident, struct list *args, struct list *fields) +struct expr *expr_funcall(char *ident, struct list *args, struct list *fields, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = efuncall; res->data.efuncall.ident = ident; res->data.efuncall.args = (struct expr **) @@ -176,45 +193,51 @@ struct expr *expr_funcall(char *ident, struct list *args, struct list *fields) return res; } -struct expr *expr_int(int integer) +struct expr *expr_int(int integer, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = eint; res->data.eint = integer; return res; } -struct expr *expr_ident(char *ident, struct list *fields) +struct expr *expr_ident(char *ident, struct list *fields, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = eident; res->data.eident.ident = ident; set_fields(&res->data.eident.fields, &res->data.eident.nfields, fields); return res; } -struct expr *expr_nil() +struct expr *expr_nil(YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = enil; return res; } -struct expr *expr_tuple(struct expr *left, struct expr *right) +struct expr *expr_tuple(struct expr *left, struct expr *right, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = etuple; res->data.etuple.left = left; res->data.etuple.right = right; return res; } -struct expr *expr_string(char *str) +struct expr *expr_string(char *str, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = estring; res->data.estring.nchars = 0; res->data.estring.chars = safe_malloc(strlen(str)+1); + res->loc = l; char *p = res->data.estring.chars; while(*str != '\0') { str = unescape_char(str); @@ -225,12 +248,13 @@ struct expr *expr_string(char *str) return res; } -struct expr *expr_unop(enum unop op, struct expr *l) +struct expr *expr_unop(enum unop op, struct expr *e, YYLTYPE l) { struct expr *res = safe_malloc(sizeof(struct expr)); + res->loc = l; res->type = eunop; res->data.eunop.op = op; - res->data.eunop.l = l; + res->data.eunop.l = e; return res; } diff --git a/ast.h b/ast.h index 635d0c8..5b6c506 100644 --- a/ast.h +++ b/ast.h @@ -13,16 +13,19 @@ extern const char *fieldspec_str[]; extern const char *binop_str[]; extern const char *unop_str[]; struct ast { + YYLTYPE loc; int ndecls; struct decl **decls; }; struct vardecl { + YYLTYPE loc; struct type *type; char *ident; struct expr *expr; }; struct fundecl { + YYLTYPE loc; char *ident; int nargs; char **args; @@ -34,6 +37,7 @@ struct fundecl { }; struct decl { + YYLTYPE loc; //NOTE: DON'T CHANGE THIS ORDER enum {dcomp, dvardecl, dfundecl} type; union { @@ -47,6 +51,7 @@ struct decl { }; struct stmt { + YYLTYPE loc; enum {sassign, sif, sreturn, sexpr, svardecl, swhile} type; union { struct { @@ -80,6 +85,7 @@ enum binop { enum fieldspec {fst,snd,hd,tl}; enum unop {negate,inverse}; struct expr { + YYLTYPE loc; enum {ebinop, ebool, echar, efuncall, eident, eint, enil, etuple, estring, eunop} type; union { @@ -118,45 +124,44 @@ struct expr { } data; }; -struct ast *ast(struct list *decls); +struct ast *ast(struct list *decls, YYLTYPE l); -struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr); -struct fundecl *fundecl(char *ident, struct list *args, struct list *atypes, - struct type *rtype, struct list *body); +struct vardecl *vardecl(struct type *type, char *ident, struct expr *expr, YYLTYPE l); +struct fundecl *fundecl(char *ident, struct list *args, struct list *atypes, struct type *rtype, struct list *body, YYLTYPE l); -struct decl *decl_fun(struct fundecl *fundecl); -struct decl *decl_var(struct vardecl *vardecl); +struct decl *decl_fun(struct fundecl *fundecl, YYLTYPE l); +struct decl *decl_var(struct vardecl *vardecl, YYLTYPE l); -struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr); -struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els); -struct stmt *stmt_return(struct expr *rtrn); -struct stmt *stmt_expr(struct expr *expr); -struct stmt *stmt_vardecl(struct vardecl *vardecl); -struct stmt *stmt_while(struct expr *pred, struct list *body); +struct stmt *stmt_assign(char *ident, struct list *fields, struct expr *expr, YYLTYPE l); +struct stmt *stmt_if(struct expr *pred, struct list *then, struct list *els, YYLTYPE l); +struct stmt *stmt_return(struct expr *rtrn, YYLTYPE l); +struct stmt *stmt_expr(struct expr *expr, YYLTYPE l); +struct stmt *stmt_vardecl(struct vardecl *vardecl, YYLTYPE l); +struct stmt *stmt_while(struct expr *pred, struct list *body, YYLTYPE l); -struct expr *expr_binop(struct expr *l, enum binop op, struct expr *r); -struct expr *expr_bool(bool b); -struct expr *expr_char(char *c); -struct expr *expr_funcall(char *ident, struct list *args, struct list *fields); -struct expr *expr_int(int integer); -struct expr *expr_ident(char *ident, struct list *fields); -struct expr *expr_nil(); -struct expr *expr_tuple(struct expr *left, struct expr *right); -struct expr *expr_string(char *str); -struct expr *expr_unop(enum unop op, struct expr *l); +struct expr *expr_binop(struct expr *left, enum binop op, struct expr *right, YYLTYPE l); +struct expr *expr_bool(bool b, YYLTYPE l); +struct expr *expr_char(char *c, YYLTYPE l); +struct expr *expr_funcall(char *ident, struct list *args, struct list *fields, YYLTYPE l); +struct expr *expr_int(int integer, YYLTYPE l); +struct expr *expr_ident(char *ident, struct list *fields, YYLTYPE l); +struct expr *expr_nil(YYLTYPE l); +struct expr *expr_tuple(struct expr *left, struct expr *right, YYLTYPE l); +struct expr *expr_string(char *str, YYLTYPE l); +struct expr *expr_unop(enum unop op, struct expr *e, YYLTYPE l); void ast_print(struct ast *, FILE *out); void vardecl_print(struct vardecl *decl, int indent, FILE *out); void fundecl_print(struct fundecl *decl, FILE *out); -void decl_print(struct decl *ast, FILE *out); -void stmt_print(struct stmt *ast, int indent, FILE *out); -void expr_print(struct expr *ast, FILE *out); +void decl_print(struct decl *decl, FILE *out); +void stmt_print(struct stmt *stmt, int indent, FILE *out); +void expr_print(struct expr *expr, FILE *out); -void ast_free(struct ast *); +void ast_free(struct ast *ast); void vardecl_free(struct vardecl *decl); -void fundecl_free(struct fundecl *fundecl); -void decl_free(struct decl *ast); -void stmt_free(struct stmt *ast); -void expr_free(struct expr *ast); +void fundecl_free(struct fundecl *decl); +void decl_free(struct decl *decl); +void stmt_free(struct stmt *stmt); +void expr_free(struct expr *expr); #endif diff --git a/genc.c b/genc.c index 5ebfa29..d808c36 100644 --- a/genc.c +++ b/genc.c @@ -103,6 +103,9 @@ void type_genc(struct type *type, FILE *cout) case tvar: fprintf(cout, "WORD "); break; + case tarrow: + die("Arrows cannot be generated\n"); + break; default: die("Unsupported type node\n"); } diff --git a/in.txt b/in.txt new file mode 100644 index 0000000..564415e --- /dev/null +++ b/in.txt @@ -0,0 +1 @@ +fun (x) { return fun(x); } diff --git a/parse.y b/parse.y index 29fa0ca..2c4271b 100644 --- a/parse.y +++ b/parse.y @@ -61,21 +61,21 @@ int yywrap() %% -start : decls { *result = ast($1); } ; +start : decls { *result = ast($1, @$); } ; decls : /* empty */ { $$ = NULL; } - | decls vardecl { $$ = list_cons(decl_var($2), $1); } - | decls fundecl { $$ = list_cons(decl_fun($2), $1); } + | decls vardecl { $$ = list_cons(decl_var($2, @2), $1); } + | decls fundecl { $$ = list_cons(decl_fun($2, @2), $1); } ; vardecl - : VAR IDENT ASSIGN expr SEMICOLON { $$ = vardecl(NULL, $2, $4); } - | type IDENT ASSIGN expr SEMICOLON { $$ = vardecl($1, $2, $4); } + : VAR IDENT ASSIGN expr SEMICOLON { $$ = vardecl(NULL, $2, $4, @$); } + | type IDENT ASSIGN expr SEMICOLON { $$ = vardecl($1, $2, $4, @$); } ; fundecl : IDENT BOPEN args BCLOSE COPEN body CCLOSE - { $$ = fundecl($1, $3, NULL, NULL, $6); } + { $$ = fundecl($1, $3, NULL, NULL, $6, @$); } | IDENT BOPEN args BCLOSE CONS CONS funtype ARROW ftype COPEN body CCLOSE - { $$ = fundecl($1, $3, $7, $9, $11); } + { $$ = fundecl($1, $3, $7, $9, $11, @$); } ; funtype : /* empty */ { $$ = NULL; } @@ -123,40 +123,40 @@ bbody | stmt { $$ = list_cons($1, NULL); } ; stmt - : IF BOPEN expr BCLOSE bbody { $$ = stmt_if($3, $5, NULL); } - | IF BOPEN expr BCLOSE bbody ELSE bbody { $$ = stmt_if($3, $5, $7); } - | WHILE BOPEN expr BCLOSE bbody { $$ = stmt_while($3, $5); } - | IDENT field ASSIGN expr SEMICOLON { $$ = stmt_assign($1, $2, $4); } - | RETURN expr SEMICOLON { $$ = stmt_return($2); } - | RETURN SEMICOLON { $$ = stmt_return(NULL); } - | vardecl { $$ = stmt_vardecl($1); } - | expr SEMICOLON { $$ = stmt_expr($1); } + : IF BOPEN expr BCLOSE bbody { $$ = stmt_if($3, $5, NULL, @$); } + | IF BOPEN expr BCLOSE bbody ELSE bbody { $$ = stmt_if($3, $5, $7, @$); } + | WHILE BOPEN expr BCLOSE bbody { $$ = stmt_while($3, $5, @$); } + | IDENT field ASSIGN expr SEMICOLON { $$ = stmt_assign($1, $2, $4, @$); } + | RETURN expr SEMICOLON { $$ = stmt_return($2, @$); } + | RETURN SEMICOLON { $$ = stmt_return(NULL, @$); } + | vardecl { $$ = stmt_vardecl($1, @$); } + | expr SEMICOLON { $$ = stmt_expr($1, @$); } ; expr - : expr BINOR expr { $$ = expr_binop($1, binor, $3); } - | expr BINAND expr { $$ = expr_binop($1, binand, $3); } - | expr EQ expr { $$ = expr_binop($1, eq, $3); } - | expr NEQ expr { $$ = expr_binop($1, neq, $3); } - | expr LEQ expr { $$ = expr_binop($1, leq, $3); } - | expr LE expr { $$ = expr_binop($1, le, $3); } - | expr GEQ expr { $$ = expr_binop($1, geq, $3); } - | expr GE expr { $$ = expr_binop($1, ge, $3); } - | expr CONS expr { $$ = expr_binop($1, cons, $3); } - | expr PLUS expr { $$ = expr_binop($1, plus, $3); } - | expr MINUS expr { $$ = expr_binop($1, minus, $3); } - | expr TIMES expr { $$ = expr_binop($1, times, $3); } - | expr DIVIDE expr { $$ = expr_binop($1, divide, $3); } - | expr MODULO expr { $$ = expr_binop($1, modulo, $3); } - | expr POWER expr { $$ = expr_binop($1, power, $3); } - | MINUS expr %prec TIMES { $$ = expr_unop(negate, $2); } - | INVERSE expr %prec TIMES { $$ = expr_unop(inverse, $2); } - | IDENT BOPEN fargs BCLOSE field { $$ = expr_funcall($1, $3, $5); } - | BOPEN expr COMMA expr BCLOSE { $$ = expr_tuple($2, $4); } + : expr BINOR expr { $$ = expr_binop($1, binor, $3, @$); } + | expr BINAND expr { $$ = expr_binop($1, binand, $3, @$); } + | expr EQ expr { $$ = expr_binop($1, eq, $3, @$); } + | expr NEQ expr { $$ = expr_binop($1, neq, $3, @$); } + | expr LEQ expr { $$ = expr_binop($1, leq, $3, @$); } + | expr LE expr { $$ = expr_binop($1, le, $3, @$); } + | expr GEQ expr { $$ = expr_binop($1, geq, $3, @$); } + | expr GE expr { $$ = expr_binop($1, ge, $3, @$); } + | expr CONS expr { $$ = expr_binop($1, cons, $3, @$); } + | expr PLUS expr { $$ = expr_binop($1, plus, $3, @$); } + | expr MINUS expr { $$ = expr_binop($1, minus, $3, @$); } + | expr TIMES expr { $$ = expr_binop($1, times, $3, @$); } + | expr DIVIDE expr { $$ = expr_binop($1, divide, $3, @$); } + | expr MODULO expr { $$ = expr_binop($1, modulo, $3, @$); } + | expr POWER expr { $$ = expr_binop($1, power, $3, @$); } + | MINUS expr %prec TIMES { $$ = expr_unop(negate, $2, @$); } + | INVERSE expr %prec TIMES { $$ = expr_unop(inverse, $2, @$); } + | IDENT BOPEN fargs BCLOSE field { $$ = expr_funcall($1, $3, $5, @$); } + | BOPEN expr COMMA expr BCLOSE { $$ = expr_tuple($2, $4, @$); } | BOPEN expr BCLOSE { $$ = $2; } | INTEGER | BOOL | CHAR | STRING - | IDENT field { $$ = expr_ident($1, $2); } - | NIL { $$ = expr_nil(); } + | IDENT field { $$ = expr_ident($1, $2, @$); } + | NIL { $$ = expr_nil(@$); } ; diff --git a/scan.l b/scan.l index 664868c..31620af 100644 --- a/scan.l +++ b/scan.l @@ -39,8 +39,8 @@ if return IF; else return ELSE; while return WHILE; var return VAR; -true { yylval.expr = expr_bool(true); return BOOL; } -false { yylval.expr = expr_bool(false); return BOOL; } +true { yylval.expr = expr_bool(true, yylloc); return BOOL; } +false { yylval.expr = expr_bool(false, yylloc); return BOOL; } return return RETURN; Int return TINT; Bool return TBOOL; @@ -75,11 +75,11 @@ Void return TVOID; \. return DOT; , return COMMA; \"([^\\"]|\\(\"|{E}))*\" { - yylval.expr = expr_string(trimquotes(yytext)); return STRING; } + yylval.expr = expr_string(trimquotes(yytext), yylloc); return STRING; } '([^\\']|\\('|{E}))' { - yylval.expr = expr_char(trimquotes(yytext)); return CHAR; } + yylval.expr = expr_char(trimquotes(yytext), yylloc); return CHAR; } {D}+ { - yylval.expr = expr_int(atoi(yytext)); return INTEGER; } + yylval.expr = expr_int(atoi(yytext), yylloc); return INTEGER; } {I}({I}|{D})* { yylval.ident = safe_strdup(yytext); return IDENT; } } diff --git a/sem.c b/sem.c index 5c9a528..7fd0c7a 100644 --- a/sem.c +++ b/sem.c @@ -6,14 +6,15 @@ #include "sem/hm.h" #include "ast.h" -void type_error(const char *msg, ...) +void type_error(YYLTYPE l, bool d, const char *msg, ...) { va_list ap; va_start(ap, msg); - fprintf(stderr, "type error: "); + fprintf(stderr, "Type error\n%d-%d: ", l.first_line, l.first_column); vfprintf(stderr, msg, ap); va_end(ap); - die(""); + if (d) + die(""); } void check_expr_constant(struct expr *expr) @@ -28,7 +29,8 @@ void check_expr_constant(struct expr *expr) break; case efuncall: case eident: - type_error("Initialiser is not constant\n"); + type_error(expr->loc, true, + "Initialiser is not constant (identifier used)\n"); break; default: break; @@ -67,16 +69,13 @@ struct ast *sem(struct ast *ast) break; case dfundecl: { struct type *f1 = gamma_fresh(gamma); - gamma_insert(gamma, ast->decls[i]->data.dfun->ident - , scheme_create(f1)); - struct subst *s = infer_fundecl(gamma, ast->decls[i]->data.dfun); + struct subst *s = infer_fundecl(gamma, ast->decls[i]->data.dfun, f1); + f1 = subst_apply_t(s, f1); + gamma_insert(gamma, ast->decls[i]->data.dfun->ident, scheme_generalise(gamma, subst_apply_t(s, f1))); +// type_free(f1); subst_free(s); -//infer env (Let [(x, e1)] e2) -// = fresh -// >>= \tv-> let env` = 'Data.Map'.put x (Forall [] tv) env -// in infer env` e1 -// >>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2 -// >>= \(s2, t2)->pure (s1 oo s2, t2) + gamma_print(gamma, stderr); + fprintf(stderr, "done\n"); break; } case dcomp: diff --git a/sem.h b/sem.h index 3785a9d..3c1a929 100644 --- a/sem.h +++ b/sem.h @@ -1,8 +1,11 @@ #ifndef SEM_H #define SEM_H +#include + #include "ast.h" +void type_error(YYLTYPE l, bool d, const char *msg, ...); struct ast *sem(struct ast *ast); #endif diff --git a/sem/hm.c b/sem/hm.c index 4b88076..ebc21dd 100644 --- a/sem/hm.c +++ b/sem/hm.c @@ -5,6 +5,7 @@ #include "hm/subst.h" #include "hm/gamma.h" #include "hm/scheme.h" +#include "../sem.h" #include "../ast.h" bool occurs_check(char *var, struct type *r) @@ -23,18 +24,18 @@ bool occurs_check(char *var, struct type *r) return res; } -struct subst *unify(struct type *l, struct type *r) +struct subst *unify(YYLTYPE loc, struct type *l, struct type *r) { if (l == NULL || r == NULL) return NULL; if (r->type == tvar && l->type != tvar) - return unify(r, l); + return unify(loc, r, l); struct subst *s1, *s2, *s3; switch (l->type) { case tarrow: if (r->type == tarrow) { - s1 = unify(l->data.tarrow.l, r->data.tarrow.l); - s2 = unify(subst_apply_t(s1, l->data.tarrow.l), + s1 = unify(loc, l->data.tarrow.l, r->data.tarrow.l); + s2 = unify(loc, subst_apply_t(s1, l->data.tarrow.l), subst_apply_t(s1, r->data.tarrow.l)); s3 = subst_union(s1, s2); return s3; @@ -46,12 +47,12 @@ struct subst *unify(struct type *l, struct type *r) break; case tlist: if (r->type == tlist) - return unify(l->data.tlist, r->data.tlist); + return unify(loc, l->data.tlist, r->data.tlist); break; case ttuple: if (r->type == ttuple) { - s1 = unify(l->data.ttuple.l, r->data.ttuple.l); - s2 = unify(subst_apply_t(s1, l->data.ttuple.l), + s1 = unify(loc, l->data.ttuple.l, r->data.ttuple.l); + s2 = unify(loc, subst_apply_t(s1, l->data.ttuple.l), subst_apply_t(s1, r->data.ttuple.l)); s3 = subst_union(s1, s2); subst_free(s1); @@ -63,22 +64,24 @@ struct subst *unify(struct type *l, struct type *r) if (r->type == tvar && strcmp(l->data.tvar, r->data.tvar) == 0) return subst_id(); else if (occurs_check(l->data.tvar, r)) - fprintf(stderr, "Infinite type %s\n", l->data.tvar); + type_error(loc, true, "Infinite type %s\n", + l->data.tvar); else return subst_singleton(l->data.tvar, r); break; } - fprintf(stderr, "cannot unify "); + type_error(loc, false, "cannot unify "); type_print(l, stderr); fprintf(stderr, " with "); type_print(r, stderr); - fprintf(stderr, "\n"); + die("\n"); return NULL; } -struct subst *unifyfree(struct type *l, struct type *r, bool freel, bool freer) +struct subst *unifyfree(YYLTYPE loc, + struct type *l, struct type *r, bool freel, bool freer) { - struct subst *s = unify(l, r); + struct subst *s = unify(loc, l, r); if (freel) type_free(l); if (freer) @@ -92,7 +95,7 @@ struct subst *infer_binop(struct gamma *gamma, struct expr *l, struct expr *r, struct subst *s1 = infer_expr(gamma, l, a1); struct subst *s2 = infer_expr(subst_apply_g(s1, gamma), r, a2); struct subst *s3 = subst_union(s1, s2); - struct subst *s4 = unify(subst_apply_t(s3, sigma), rt); + struct subst *s4 = unify(l->loc, subst_apply_t(s3, sigma), rt); struct subst *s5 = subst_union(s3, s4); subst_free(s1); subst_free(s2); @@ -105,7 +108,7 @@ struct subst *infer_unop(struct gamma *gamma, struct expr *e, struct type *a, struct type *rt, struct type *sigma) { struct subst *s1 = infer_expr(gamma, e, a); - struct subst *s2 = unify(subst_apply_t(s1, sigma), rt); + struct subst *s2 = unify(e->loc, subst_apply_t(s1, sigma), rt); struct subst *s3 = subst_union(s1, s2); subst_free(s1); subst_free(s2); @@ -129,12 +132,12 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty #define infbinop(e, a1, a2, rt) infer_binop(\ gamma, e->data.ebinop.l, e->data.ebinop.r, a1, a2, rt, type) - struct subst *s1; + struct subst *s0; struct type *f1, *f2, *f3; struct scheme *s; switch (expr->type) { case ebool: - return unify(&tybool, type); + return unify(expr->loc, &tybool, type); case ebinop: switch (expr->data.ebinop.op) { case binor: @@ -147,15 +150,15 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty case geq: case ge: f1 = gamma_fresh(gamma); - s1 = infbinop(expr, f1, f1, &tybool); + s0 = infbinop(expr, f1, f1, &tybool); type_free(f1); - return s1; + return s0; case cons: f1 = gamma_fresh(gamma); f2 = type_list(f1); - s1 = infbinop(expr, f1, f2, f2); + s0 = infbinop(expr, f1, f2, f2); type_free(f2); - return s1; + return s0; case plus: case minus: case times: @@ -166,36 +169,57 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty } break; case echar: - return unify(&tychar, type); + return unify(expr->loc, &tychar, type); case efuncall: if ((s = gamma_lookup(gamma, expr->data.efuncall.ident)) == NULL) - die("Unbound function: %s\n", expr->data.efuncall.ident); - //TODO + type_error(expr->loc, "Unbound function: %s\n" + , expr->data.efuncall.ident); + struct type *t = scheme_instantiate(gamma, s); + struct subst *s0 = subst_id(); + for (int i = 0; idata.efuncall.nargs; i++) { + if (t->type != tarrow) + type_error(expr->loc, true, + "too many arguments to %s\n", + expr->data.efuncall.ident); + struct subst *s1 = infer_expr(gamma, + expr->data.efuncall.args[i], t->data.tarrow.l); + struct subst *s2 = s0; + s0 = subst_union(s2, s1); + subst_free(s1); + subst_free(s2); + t = t->data.tarrow.r; + } + if (t->type == tarrow) + type_error(expr->loc, true, + "not enough arguments to %s\n", + expr->data.efuncall.ident); + type_free(t); //TODO fields - return NULL; + return s0; case eint: - return unify(&tyint, type); + return unify(expr->loc, &tyint, type); case eident: if ((s = gamma_lookup(gamma, expr->data.eident.ident)) == NULL) - die("Unbound variable: %s\n", expr->data.eident.ident); + type_error(expr->loc, true, "Unbound variable: %s\n" + , expr->data.eident.ident); f1 = scheme_instantiate(gamma, s); - s1 = unify(f1, type); + s0 = unify(expr->loc, f1, type); type_free(f1); //TODO field - return s1; + return s0; case enil: f1 = gamma_fresh(gamma); - return unifyfree(type_list(f1), type, true, false); + return unifyfree(expr->loc, type_list(f1), type, true, false); case etuple: f1 = gamma_fresh(gamma); f2 = gamma_fresh(gamma); f3 = type_tuple(f1, f2); - s1 = infer_binop(gamma, expr->data.etuple.left, + s0 = infer_binop(gamma, expr->data.etuple.left, expr->data.etuple.right, f1, f2, f3, type); type_free(f3); - return s1; + return s0; case estring: - return unify(&tystring, type); + return unify(expr->loc, &tystring, type); case eunop: switch (expr->data.eunop.op) { case negate: @@ -262,59 +286,44 @@ struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *ty // } data; } -struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl) +struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct type *fty) { - //struct type *t; + fprintf(stderr, "inferring function to type "); + type_print(fty, stderr); + fprintf(stderr, " with gamma "); + gamma_print(gamma, stderr); + fprintf(stderr, "\n"); if (fundecl->rtype == NULL || fundecl->atypes == NULL) { fundecl->rtype = gamma_fresh(gamma); - fundecl->atypes = safe_realloc(fundecl->atypes, fundecl->nargs*sizeof(struct type)); + fundecl->atypes = safe_realloc(fundecl->atypes, + fundecl->nargs*sizeof(struct type)); for (int i = 0; inargs; i++) fundecl->atypes[i] = gamma_fresh(gamma); } - fprintf(stderr, "fundecl with type: "); - for (int i = 0; inargs; i++) { - type_print(fundecl->atypes[i], stderr); - fprintf(stderr, " "); - } - fprintf(stderr, "-> "); - type_print(fundecl->rtype, stderr); - fprintf(stderr, "\n"); - for (int i = 0; inargs; i++) + struct type *ftype = type_dup(fundecl->rtype); + for (int i = 0; inargs; i++) { + ftype = type_arrow(type_dup(fundecl->atypes[i]), ftype); gamma_insert(gamma, fundecl->args[i], scheme_create(fundecl->atypes[i])); + } + gamma_insert(gamma, fundecl->ident, scheme_create(ftype)); struct subst *s = subst_id(); for (int i = 0; inbody; i++) { - struct subst *s1 = infer_stmt(gamma, fundecl->body[i], fundecl->rtype); + struct subst *s1 = infer_stmt(gamma, + fundecl->body[i], fundecl->rtype); struct subst *s2 = s; s = subst_union(s2, s1); subst_free(s1); subst_free(s2); } - fprintf(stderr, "inferred function substitution: "); - subst_print(s, stderr); - for (int i = 0; inargs; i++) fundecl->atypes[i] = subst_apply_t(s, fundecl->atypes[i]); fundecl->rtype = subst_apply_t(s, fundecl->rtype); - fprintf(stderr, "fundecl with type: "); - for (int i = 0; inargs; i++) { - type_print(fundecl->atypes[i], stderr); - fprintf(stderr, " "); - } - fprintf(stderr, "-> "); - type_print(fundecl->rtype, stderr); - fprintf(stderr, "\n"); - //char *ident; - //int nargs; - //char **args; - //int natypes; - //struct type **atypes; - //struct type *rtype; - //int nbody; - //struct stmt **body; - return s; + struct subst *r = unify(fundecl->loc, fty, ftype); + type_free(ftype); + return r; } diff --git a/sem/hm.h b/sem/hm.h index 106655b..423ce70 100644 --- a/sem/hm.h +++ b/sem/hm.h @@ -9,6 +9,6 @@ struct ast *infer(struct ast *ast); struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *type); struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type); -struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl); +struct subst *infer_fundecl(struct gamma *gamma, struct fundecl *fundecl, struct type *ftype); #endif diff --git a/sem/hm/scheme.c b/sem/hm/scheme.c index 41d5401..33ff121 100644 --- a/sem/hm/scheme.c +++ b/sem/hm/scheme.c @@ -19,7 +19,7 @@ struct type *scheme_instantiate(struct gamma *gamma, struct scheme *sch) struct scheme *scheme_create(struct type *t) { struct scheme *s = safe_malloc(sizeof(struct scheme)); - s->type = t; + s->type = type_dup(t); s->nvar = 0; s->var = NULL; return s; @@ -32,7 +32,7 @@ struct scheme *scheme_generalise(struct gamma *gamma, struct type *t) char **ftv = NULL; type_ftv(t, &nftv, &ftv); - s->type = t; + s->type = type_dup(t); s->nvar = 0; s->var = safe_malloc(nftv*sizeof(char *)); for (int i = 0; invar > 0) { fprintf(out, "A."); for (int i = 0; invar; i++) diff --git a/type.c b/type.c index 8fc6102..39d67b5 100644 --- a/type.c +++ b/type.c @@ -16,7 +16,6 @@ struct type *type_arrow(struct type *l, struct type *r) res->data.tarrow.l = l; res->data.tarrow.r = r; return res; - } struct type *type_basic(enum basictype type) @@ -101,7 +100,7 @@ void type_print(struct type *type, FILE *out) safe_fprintf(out, "%s", type->data.tvar); break; default: - die("Unsupported type node\n"); + die("Unsupported type node: %d\n", type->type); } } @@ -127,7 +126,7 @@ void type_free(struct type *type) free(type->data.tvar); break; default: - die("Unsupported type node\n"); + die("Unsupported type node: %d\n", type->type); } free(type); } @@ -153,6 +152,8 @@ struct type *type_dup(struct type *r) case tvar: res->data.tvar = safe_strdup(r->data.tvar); break; + default: + die("Unsupported type node: %d\n", r->type); } return res; } @@ -174,10 +175,15 @@ void type_ftv(struct type *r, int *nftv, char ***ftv) type_ftv(r->data.ttuple.r, nftv, ftv); break; case tvar: + for (int i = 0; i<*nftv; i++) + if (strcmp((*ftv)[i], r->data.tvar) == 0) + return; *ftv = realloc(*ftv, (*nftv+1)*sizeof(char *)); if (*ftv == NULL) perror("realloc"); (*ftv)[(*nftv)++] = r->data.tvar; break; + default: + die("Unsupported type node: %d\n", r->type); } } diff --git a/type.h b/type.h index 0bbb2e0..65135aa 100644 --- a/type.h +++ b/type.h @@ -2,6 +2,7 @@ #define TYPE_H #include +#include "ast.h" enum basictype {btbool, btchar, btint, btvoid}; struct type { @@ -21,10 +22,10 @@ struct type { } data; }; -struct type *type_arrow(struct type *l, struct type *r); +struct type *type_arrow(struct type *left, struct type *right); struct type *type_basic(enum basictype type); struct type *type_list(struct type *type); -struct type *type_tuple(struct type *l, struct type *r); +struct type *type_tuple(struct type *left, struct type *right); struct type *type_var(char *ident); void type_print(struct type *type, FILE *out); -- 2.20.1