fix printing parenthesis and fields
authorMart Lubbers <mart@martlubbers.net>
Wed, 10 Mar 2021 14:08:34 +0000 (15:08 +0100)
committerMart Lubbers <mart@martlubbers.net>
Wed, 10 Mar 2021 14:08:34 +0000 (15:08 +0100)
ast.c
sem/hm.c
sem/hm/gamma.c
sem/hm/scheme.c
sem/hm/subst.c
sem/hm/subst.h

diff --git a/ast.c b/ast.c
index 2a2780c..5f9e26e 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -400,18 +400,65 @@ void stmt_print(struct stmt *stmt, int indent, FILE *out)
        }
 }
 
-void expr_print(struct expr *expr, FILE *out)
+struct ctx {
+       enum {nonfix,infix} type;
+       enum assoc {left, right, none} assoc;
+       enum assoc branch;
+       int prec;
+};
+static struct ctx nfctx = {.type=nonfix, .prec=-1, .assoc=none};
+static const struct ctx unop_ctx[] = {
+       [inverse] = {.type=infix, .prec=7, .assoc=left},
+       [negate]  = {.type=infix, .prec=7, .assoc=left}
+};
+static const struct ctx binop_ctx[] = {
+       [binor]  = {.type=infix, .prec=2, .assoc=right},
+       [binand] = {.type=infix, .prec=3, .assoc=right},
+       [eq]     = {.type=infix, .prec=4, .assoc=none},
+       [neq]    = {.type=infix, .prec=4, .assoc=none},
+       [leq]    = {.type=infix, .prec=4, .assoc=none},
+       [le]     = {.type=infix, .prec=4, .assoc=none},
+       [geq]    = {.type=infix, .prec=4, .assoc=none},
+       [ge]     = {.type=infix, .prec=4, .assoc=none},
+       [cons]   = {.type=infix, .prec=5, .assoc=right},
+       [plus]   = {.type=infix, .prec=6, .assoc=left},
+       [minus]  = {.type=infix, .prec=6, .assoc=left},
+       [times]  = {.type=infix, .prec=7, .assoc=left},
+       [divide] = {.type=infix, .prec=7, .assoc=left},
+       [modulo] = {.type=infix, .prec=7, .assoc=left},
+       [power]  = {.type=infix, .prec=8, .assoc=right},
+};
+
+static inline bool brace(struct ctx this, struct ctx outer)
+{
+       if (this.type == infix && outer.type == infix) {
+               if (outer.prec > this.prec)
+                       return true;
+               else if (outer.prec == this.prec)
+                       return this.assoc != outer.assoc
+                               || this.assoc != outer.branch;
+       }
+       return false;
+}
+
+static void expr_print2(struct expr *expr, FILE *out, struct ctx ctx)
 {
        if (expr == NULL)
                return;
        char buf[] = "\\xff";
+       struct ctx this;
        switch(expr->type) {
        case ebinop:
-               safe_fprintf(out, "(");
-               expr_print(expr->data.ebinop.l, out);
-               safe_fprintf(out, "%s", binop_str[expr->data.ebinop.op]);
-               expr_print(expr->data.ebinop.r, out);
-               safe_fprintf(out, ")");
+               this = binop_ctx[expr->data.ebinop.op];
+               if (brace(this, ctx))
+                       fprintf(out, "(");
+               this.branch = left;
+               expr_print2(expr->data.ebinop.l, out, this);
+               safe_fprintf(out, " %s ", binop_str[expr->data.ebinop.op]);
+               this.branch = right;
+               expr_print2(expr->data.ebinop.r, out, this);
+               if (brace(this, ctx))
+                       fprintf(out, ")");
                break;
        case ebool:
                safe_fprintf(out, "%s", expr->data.ebool ? "true" : "false");
@@ -423,7 +470,7 @@ void expr_print(struct expr *expr, FILE *out)
        case efuncall:
                safe_fprintf(out, "%s(", expr->data.efuncall.ident);
                ARRAY_ITER(struct expr *, e, i, expr->data.efuncall.args) {
-                       expr_print(e, out);
+                       expr_print2(e, out, nfctx);
                        if (i+1 < ARRAY_SIZE(expr->data.efuncall.args))
                                safe_fprintf(out, ", ");
                } AIEND
@@ -440,9 +487,9 @@ void expr_print(struct expr *expr, FILE *out)
                break;
        case etuple:
                safe_fprintf(out, "(");
-               expr_print(expr->data.etuple.left, out);
+               expr_print2(expr->data.etuple.left, out, nfctx);
                safe_fprintf(out, ", ");
-               expr_print(expr->data.etuple.right, out);
+               expr_print2(expr->data.etuple.right, out, nfctx);
                safe_fprintf(out, ")");
                break;
        case estring:
@@ -453,15 +500,25 @@ void expr_print(struct expr *expr, FILE *out)
                safe_fprintf(out, "\"");
                break;
        case eunop:
-               safe_fprintf(out, "(%s", unop_str[expr->data.eunop.op]);
-               expr_print(expr->data.eunop.l, out);
-               safe_fprintf(out, ")");
+               this = unop_ctx[expr->data.eunop.op];
+               if (brace(this, ctx))
+                       fprintf(out, "(");
+               safe_fprintf(out, "%s", unop_str[expr->data.eunop.op]);
+               this.branch = right;
+               expr_print2(expr->data.eunop.l, out, this);
+               if (brace(this, ctx))
+                       fprintf(out, ")");
                break;
        default:
                die("Unsupported expr node\n");
        }
 }
 
+void expr_print(struct expr *expr, FILE *out)
+{
+       expr_print2(expr, out, nfctx);
+}
+
 void ast_free(struct ast *ast)
 {
        if (ast == NULL)
index eb3f5c6..fbdfc4d 100644 (file)
--- a/sem/hm.c
+++ b/sem/hm.c
@@ -136,7 +136,7 @@ struct subst *infer_expr(struct gamma *gamma, struct expr *expr, struct type *ty
                                , expr->data.efuncall.ident);
                struct type *ft = scheme_instantiate(gamma, s);
                struct type *t = ft;
-               struct subst *s0 = subst_id();
+               s0 = subst_id();
                //Infer args
                ARRAY_ITER(struct expr *, a, i, expr->data.efuncall.args) {
                        if (t->type != tarrow)
@@ -211,10 +211,68 @@ struct subst *infer_body(struct gamma *gamma, struct array stmts, struct type *t
 struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *type)
 {
        struct subst *s0, *s1;
+       struct scheme *s;
        struct type *f1;
        switch (stmt->type) {
        case sassign:
-               break;
+               if ((s = gamma_lookup(gamma, ident_str(stmt->data.sassign.ident))) == NULL)
+                       type_error(stmt->loc, true, "Unbound variable: %s\n"
+                               , stmt->data.sassign.ident);
+               f1 = scheme_instantiate(gamma, s);
+
+               s0 = subst_id();
+               ARRAY_ITER(char *, f, i, stmt->data.sassign.fields) {
+                       if (strcmp(f, "fst") == 0) {
+                               //f1 is the type of the variable in gamma
+                               struct type *fl = gamma_fresh(gamma);
+                               //fl is the type of the lhs of the tuple
+                               struct type *fr = gamma_fresh(gamma);
+                               //fr is the type of the rhs of the tuple
+                               struct type *ft = type_tuple(fl, fr);
+                               //ft is the type of the tuple, this must be
+                               //unified with f1
+                               s1 = unify(stmt->loc, f1, ft);
+
+                               s0 = subst_union(s1, s0);
+                               type_free(f1);
+                               f1 = subst_apply_t(s0, fl);
+                               type_free(fr);
+                               free(ft);
+                       } else if (strcmp(f, "snd") == 0) {
+                               struct type *fl = gamma_fresh(gamma);
+                               struct type *fr = gamma_fresh(gamma);
+                               struct type *ft = type_tuple(fl, fr);
+                               s1 = unify(stmt->loc, ft, f1);
+                               s0 = subst_union(s0, s1);
+                               type_free(f1);
+                               f1 = subst_apply_t(s0, fr);
+                               type_free(fl);
+                               free(ft);
+                       } else if (strcmp(f, "hd") == 0) {
+                               struct type *fe = gamma_fresh(gamma);
+                               struct type *fl = type_list(fe);
+                               s1 = unify(stmt->loc, fl, f1);
+                               s0 = subst_union(s0, s1);
+                               type_free(f1);
+                               f1 = subst_apply_t(s0, fe);
+                               free(fl);
+                       } else if (strcmp(f, "tl") == 0) {
+                               struct type *fe = gamma_fresh(gamma);
+                               struct type *fl = type_list(fe);
+                               s1 = unify(stmt->loc, fl, f1);
+                               s0 = subst_union(s0, s1);
+                               type_free(f1);
+                               f1 = subst_apply_t(s0, fl);
+                       } else {
+                               type_error(stmt->loc, true,
+                                       "Unknown field selector: %s\n", f);
+                       }
+               } AIEND
+               s1 = infer_expr(gamma, stmt->data.sassign.expr, f1);
+               s0 = subst_union(s0, s1);
+               type_free(f1);
+               subst_apply_g(s0, gamma);
+               return s0;
        case sif:
                s0 = infer_expr(gamma, stmt->data.sif.pred, &tybool);
                //subst_apply_g(s0, gamma);
@@ -242,8 +300,6 @@ struct subst *infer_stmt(struct gamma *gamma, struct stmt *stmt, struct type *ty
                else
                        s1 = subst_id();
                s0 = subst_union(s1, s0);
-               //TODO fielsd
-               //TODO
                gamma_insert(gamma, ident_str(stmt->data.svardecl->ident),
                        scheme_create(subst_apply_t(s0, f1)));
                type_free(f1);
index db69ec1..1dcad70 100644 (file)
@@ -103,19 +103,19 @@ struct type *gamma_fresh(struct gamma *gamma)
        return type_var_int(gamma->fresh++);
 }
 
-//void gamma_print(struct gamma *gamma, FILE *out)
-//{
-//     fprintf(out, "{");
-//     for (int i = 0; i<gamma->nentries; i++) {
-//             ident_print(gamma->entries[i].var, out);
-//             fprintf(out, "(%d) = ", gamma->entries[i].scope);
-//             scheme_print(gamma->entries[i].scheme, out);
-//             if (i + 1 < gamma->nentries)
-//                     fprintf(out, ", ");
-//     }
-//     fprintf(out, "}");
-//}
-//
+void gamma_print(struct gamma *gamma, FILE *out)
+{
+       fprintf(out, "{");
+       for (int i = 0; i<gamma->nentries; i++) {
+               ident_print(gamma->entries[i].var, out);
+               fprintf(out, "(%d) = ", gamma->entries[i].scope);
+               scheme_print(gamma->entries[i].scheme, out);
+               if (i + 1 < gamma->nentries)
+                       fprintf(out, ", ");
+       }
+       fprintf(out, "}");
+}
+
 void gamma_free(struct gamma *gamma)
 {
        for (int i = 0; i<gamma->nentries; i++) {
index eed2e32..ee10c7d 100644 (file)
@@ -47,24 +47,24 @@ struct scheme *scheme_generalise(struct gamma *gamma, struct type *t)
        return s;
 }
 
-//void scheme_print(struct scheme *scheme, FILE *out)
-//{
-//     if (scheme == NULL) {
-//             fprintf(out, "NULLSCHEME");
-//             return;
-//     }
-//     if (scheme->nvar > 0) {
-//             fprintf(out, "A.");
-//             for (int i = 0; i<scheme->nvar; i++) {
-//                     if (i > 0)
-//                             fprintf(out, " ");
-//                     ident_print(scheme->var[i], stderr);
-//             }
-//             fprintf(out, ": ");
-//     }
-//     type_print(scheme->type, out);
-//}
-//
+void scheme_print(struct scheme *scheme, FILE *out)
+{
+       if (scheme == NULL) {
+               fprintf(out, "NULLSCHEME");
+               return;
+       }
+       if (scheme->nvar > 0) {
+               fprintf(out, "A.");
+               for (int i = 0; i<scheme->nvar; i++) {
+                       if (i > 0)
+                               fprintf(out, " ");
+                       ident_print(scheme->var[i], stderr);
+               }
+               fprintf(out, ": ");
+       }
+       type_print(scheme->type, out);
+}
+
 void scheme_free(struct scheme *scheme)
 {
        type_free(scheme->type);
index da07746..aae1d5d 100644 (file)
@@ -115,10 +115,14 @@ struct type *subst_apply_t(struct subst *subst, struct type *l)
                struct subst_entry *e = bsearch(&l->data.tvar, subst->entries,
                        subst->nvar, sizeof(struct subst_entry), ident_cmpv);
                if (e != NULL) {
+                       if (e->type->type == tvar && ident_cmp(l->data.tvar,
+                                       e->type->data.tvar) == 0)
+                               break;
                        ident_free(l->data.tvar);
                        struct type *r = type_dup(e->type);
                        *l = *r;
                        free(r);
+                       return subst_apply_t(subst, l);
                }
                break;
        }}
@@ -143,10 +147,9 @@ struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme)
        return scheme;
 }
 
-
 static void giter(struct ident ident, struct scheme *s, void *st)
 {
-       subst_apply_s((struct subst *)st, s);
+       s = subst_apply_s((struct subst *)st, s);
        (void)ident;
 }
 
@@ -156,22 +159,22 @@ struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma)
        return gamma;
 }
 
-//void subst_print(struct subst *s, FILE *out)
-//{
-//     if (s == NULL) {
-//             fprintf(out, "(nil)");
-//     } else {
-//             fprintf(out, "[");
-//             for (size_t i = 0; i<s->nvar; i++) {
-//                     ident_print(s->entries[i].var, out);
-//                     fprintf(out, "->");
-//                     type_print(s->entries[i].type, out);
-//                     if (i + 1 < s->nvar)
-//                             fprintf(out, ", ");
-//             }
-//             fprintf(out, "]");
-//     }
-//}
+void subst_print(struct subst *s, FILE *out)
+{
+       if (s == NULL) {
+               fprintf(out, "(nil)");
+       } else {
+               fprintf(out, "[");
+               for (size_t i = 0; i<s->nvar; i++) {
+                       ident_print(s->entries[i].var, out);
+                       fprintf(out, "->");
+                       type_print(s->entries[i].type, out);
+                       if (i + 1 < s->nvar)
+                               fprintf(out, ", ");
+               }
+               fprintf(out, "]");
+       }
+}
 
 void subst_free(struct subst *s)
 {
index 96ed3ee..1e3cb8e 100644 (file)
@@ -24,7 +24,7 @@ struct type *subst_apply_t(struct subst *subst, struct type *l);
 struct scheme *subst_apply_s(struct subst *subst, struct scheme *scheme);
 struct gamma *subst_apply_g(struct subst *subst, struct gamma *gamma);
 
-//void subst_print(struct subst *s, FILE *out);
+void subst_print(struct subst *s, FILE *out);
 void subst_free(struct subst *s);
 
 #endif