add util header, characters, stmts, etc
authorMart Lubbers <mart@martlubbers.net>
Sun, 7 Feb 2021 11:14:00 +0000 (12:14 +0100)
committerMart Lubbers <mart@martlubbers.net>
Sun, 7 Feb 2021 11:14:00 +0000 (12:14 +0100)
Makefile
ast.c
ast.h
expr.c
parse.y
scan.l
util.c [new file with mode: 0644]
util.h [new file with mode: 0644]

index 53a8ced..ba01e93 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -1,13 +1,14 @@
-CFLAGS:=-Wall -std=c99 -D_XOPEN_SOURCE=700
+CFLAGS:=-Wall -std=c99 -D_XOPEN_SOURCE=700 -ggdb
 YFLAGS:=-d
 LFLAGS:=-t
 
-OBJECTS:=scan.o parse.o expr.o ast.c
+OBJECTS:=scan.o parse.o ast.o util.o
 
 all: expr
 
 scan.c: scan.l y.tab.h
 y.tab.h: parse.c
+expr.c: y.tab.h
 
 expr: $(OBJECTS)
 
diff --git a/ast.c b/ast.c
index 4d89925..a757079 100644 (file)
--- a/ast.c
+++ b/ast.c
@@ -1,15 +1,85 @@
 #include <stdlib.h>
 #include <stdio.h>
+#include <string.h>
 
+#include "util.h"
 #include "ast.h"
 
-struct ast *ast_alloc()
+static const char *ast_type_str[] = {
+       [an_bool] = "bool", [an_binop] = "binop", [an_char] = "char",
+       [an_cons] = "cons", [an_fundecl] = "fundecl", [an_ident] = "ident",
+       [an_if] = "if", [an_int] = "int", [an_list] = "list",
+       [an_stmt_expr] = "stmt_expr", [an_unop] = "unop",
+       [an_vardecl] = "vardecl", [an_while] = "while",
+};
+static const char *binop_str[] = {
+       [binor] = "||", [binand] = "&&", [eq] = "==", [neq] = "!=",
+       [leq] = "<=", [le] = "<", [geq] = ">=", [ge] = ">", [cons] = ":",
+       [plus] = "+", [minus] = "-", [times] = "*", [divide] = "/",
+       [modulo] = "%", [power] = "^",
+};
+static const char *unop_str[] = { [inverse] = "!", [negate] = "-", };
+
+
+#define must_be(node, ntype, msg) {\
+       if ((node)->type != (ntype)) {\
+               fprintf(stderr, "%s can't be %s\n",\
+                       msg, ast_type_str[node->type]);\
+               exit(1);\
+       }\
+}
+#define ast_alloc() ((struct ast *)safe_malloc(sizeof(struct ast)))
+
+struct ast *ast_bool(bool b)
 {
-       struct ast *res = malloc(sizeof(struct ast));
-       if (res == NULL) {
-               perror("malloc");
-               exit(1);
-       }
+       struct ast *res = ast_alloc();
+       res->type = an_bool;
+       res->data.an_bool = b;
+       return res;
+}
+
+struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_binop;
+       res->data.an_binop.l = l;
+       res->data.an_binop.op = op;
+       res->data.an_binop.r = r;
+       return res;
+}
+
+int fromHex(char c)
+{
+       if (c >= '0' && c <= '9')
+               return c-'0';
+       if (c >= 'a' && c <= 'f')
+               return c-'a'+10;
+       if (c >= 'A' && c <= 'F')
+               return c-'A'+10;
+       return -1;
+}
+
+struct ast *ast_char(const char *c)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_char;
+       //regular char
+       if (strlen(c) == 3)
+               res->data.an_char = c[1];
+       //escape
+       if (strlen(c) == 4)
+               switch(c[2]) {
+               case '0': res->data.an_char = '\0'; break;
+               case 'a': res->data.an_char = '\a'; break;
+               case 'b': res->data.an_char = '\b'; break;
+               case 't': res->data.an_char = '\t'; break;
+               case 'v': res->data.an_char = '\v'; break;
+               case 'f': res->data.an_char = '\f'; break;
+               case 'r': res->data.an_char = '\r'; break;
+               }
+       //hex escape
+       if (strlen(c) == 6)
+               res->data.an_char = (fromHex(c[3])<<4)+fromHex(c[4]);
        return res;
 }
 
@@ -22,13 +92,43 @@ struct ast *ast_cons(struct ast *el, struct ast *tail)
        return res;
 }
 
-struct ast *ast_binop(struct ast *l, enum binop op, struct ast *r)
+struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body)
 {
        struct ast *res = ast_alloc();
-       res->type = an_binop;
-       res->data.an_binop.l = l;
-       res->data.an_binop.op = op;
-       res->data.an_binop.r = r;
+       res->type = an_fundecl;
+
+       //ident
+       must_be(ident, an_ident, "ident of a fundecl");
+       res->data.an_fundecl.ident = ident->data.an_ident;
+       free(ident);
+
+       //args
+       must_be(args, an_list, "args of a fundecl");
+       res->data.an_fundecl.nargs = args->data.an_list.n;
+       res->data.an_fundecl.args = (char **)safe_malloc(
+               args->data.an_list.n*sizeof(char *));
+       for (int i = 0; i<args->data.an_list.n; i++) {
+               struct ast *e = args->data.an_list.ptr[i];
+               must_be(e, an_ident, "arg of a fundecl")
+               res->data.an_fundecl.args[i] = e->data.an_ident;
+               free(e);
+       }
+       free(args);
+
+       //body
+       must_be(body, an_list, "body of a fundecl");
+       res->data.an_fundecl.body = body;
+
+       return res;
+}
+
+struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_if;
+       res->data.an_if.pred = pred;
+       res->data.an_if.then = then;
+       res->data.an_if.els = els;
        return res;
 }
 
@@ -40,6 +140,45 @@ struct ast *ast_int(int integer)
        return res;
 }
 
+struct ast *ast_ident(char *ident)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_ident;
+       res->data.an_ident = safe_strdup(ident);
+       return res;
+}
+
+struct ast *ast_list(struct ast *llist)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_list;
+       res->data.an_list.n = 0;
+
+       int i = ast_llistlength(llist);
+
+       //Allocate array
+       res->data.an_list.n = i;
+       res->data.an_list.ptr = (struct ast **)safe_malloc(
+               res->data.an_list.n*sizeof(struct ast *));
+
+       struct ast *r = llist;
+       while(i > 0) {
+               res->data.an_list.ptr[--i] = r->data.an_cons.el;
+               struct ast *t = r;
+               r = r->data.an_cons.tail;
+               free(t);
+       }
+       return res;
+}
+
+struct ast *ast_stmt_expr(struct ast *expr)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_stmt_expr;
+       res->data.an_stmt_expr = expr;
+       return res;
+}
+
 struct ast *ast_unop(enum unop op, struct ast *l)
 {
        struct ast *res = ast_alloc();
@@ -48,56 +187,152 @@ struct ast *ast_unop(enum unop op, struct ast *l)
        res->data.an_unop.l = l;
        return res;
 }
-static const char *binop_str[] = {
-       [binor] = "||",
-       [binand] = "&&",
-       [eq] = "==",
-       [neq] = "!=",
-       [leq] = "<=",
-       [le] = "<",
-       [geq] = ">=",
-       [ge] = ">",
-       [cons] = ":",
-       [plus] = "+",
-       [minus] = "-",
-       [times] = "*",
-       [divide] = "/",
-       [modulo] = "%",
-       [power] = "^",
-};
-static const char *unop_str[] = {
-       [inverse] = "!",
-       [negate] = "-",
+
+struct ast *ast_vardecl(struct ast *ident, struct ast *l)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_vardecl;
+       must_be(ident, an_ident, "ident of a vardecl");
+
+       res->data.an_vardecl.ident = ident->data.an_ident;
+       free(ident);
+       res->data.an_vardecl.l = l;
+       return res;
+}
+
+struct ast *ast_while(struct ast *pred, struct ast *body)
+{
+       struct ast *res = ast_alloc();
+       res->type = an_while;
+       res->data.an_while.pred = pred;
+       res->data.an_while.body = body;
+       return res;
+}
+
+int ast_llistlength(struct ast *r)
+{
+       int i = 0;
+       while(r != NULL) {
+               i++;
+               if (r->type != an_cons) {
+                       return 1;
+               }
+               r = r->data.an_cons.tail;
+       }
+       return i;
+}
+
+const char *cescapes[] = {
+       [0] = "\\0", [1] = "\\x01", [2] = "\\x02", [3] = "\\x03", 
+       [4] = "\\x04", [5] = "\\x05", [6] = "\\x06", [7] = "\\a", [8] = "\\b",
+       [9] = "\\t", [10] = "\\n", [11] = "\\v", [12] = "\\f", [13] = "\\r",
+       [14] = "\\x0E", [15] = "\\x0F", [16] = "\\x10", [17] = "\\x11",
+       [18] = "\\x12", [19] = "\\x13", [20] = "\\x14", [21] = "\\x15",
+       [22] = "\\x16", [23] = "\\x17", [24] = "\\x18", [25] = "\\x19",
+       [26] = "\\x1A", [27] = "\\x1B", [28] = "\\x1C", [29] = "\\x1D",
+       [30] = "\\x1E", [31] = "\\x1F",
+       [127] = "\\x7F"
 };
 
-void ast_print(struct ast * ast, FILE *out)
+void ast_print(struct ast * ast, int indent, FILE *out)
 {
        if (ast == NULL)
                return;
        switch(ast->type) {
+       case an_bool:
+               safe_fprintf(out, "%s", ast->data.an_bool ? "true" : "false");
+               break;
        case an_binop:
-               fprintf(out, "(");
-               ast_print(ast->data.an_binop.l, out);
-               fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
-               ast_print(ast->data.an_binop.r, out);
-               fprintf(out, ")");
+               safe_fprintf(out, "(");
+               ast_print(ast->data.an_binop.l, indent, out);
+               safe_fprintf(out, "%s", binop_str[ast->data.an_binop.op]);
+               ast_print(ast->data.an_binop.r, indent, out);
+               safe_fprintf(out, ")");
                break;
-       case an_cons:
-               ast_print(ast->data.an_cons.el, out);
-               fprintf(out, ";\n");
-               ast_print(ast->data.an_cons.tail, out);
+       case an_char:
+               if (ast->data.an_char < 0)
+                       safe_fprintf(out, "'?'");
+               if (ast->data.an_char < ' ' || ast->data.an_char == 127)
+                       safe_fprintf(out, "'%s'",
+                               cescapes[(int)ast->data.an_char]);
+               else
+                       safe_fprintf(out, "'%c'", ast->data.an_char);
+               break;
+       case an_fundecl:
+               pindent(indent, out);
+               safe_fprintf(out, "%s (", ast->data.an_fundecl.ident);
+               for (int i = 0; i<ast->data.an_fundecl.nargs; i++) {
+                       safe_fprintf(out, "%s", ast->data.an_fundecl.args[i]);
+                       if (i < ast->data.an_fundecl.nargs - 1)
+                               safe_fprintf(out, ", ");
+               }
+               safe_fprintf(out, ") {\n");
+               ast_print(ast->data.an_fundecl.body, indent+1, out);
+               pindent(indent, out);
+               safe_fprintf(out, "}\n");
+               break;
+       case an_if:
+               pindent(indent, out);
+               safe_fprintf(out, "if (");
+               ast_print(ast->data.an_if.pred, indent, out);
+               safe_fprintf(out, ") {\n");
+               if (ast->data.an_if.then->data.an_list.n > 0) {
+                       pindent(indent, out);
+                       ast_print(ast->data.an_if.then, indent+1, out);
+               }
+               pindent(indent, out);
+               safe_fprintf(out, "} else {\n");
+               if (ast->data.an_if.els->data.an_list.n > 0) {
+                       pindent(indent, out);
+                       ast_print(ast->data.an_if.els, indent+1, out);
+               }
+               pindent(indent, out);
+               safe_fprintf(out, "}\n");
                break;
        case an_int:
-               fprintf(out, "%d", ast->data.an_int);
+               safe_fprintf(out, "%d", ast->data.an_int);
+               break;
+       case an_ident:
+               safe_fprintf(out, "%s", ast->data.an_ident);
+               break;
+       case an_cons:
+               ast_print(ast->data.an_cons.el, indent, out);
+               ast_print(ast->data.an_cons.tail, indent, out);
+               break;
+       case an_list:
+               for (int i = 0; i<ast->data.an_list.n; i++)
+                       ast_print(ast->data.an_list.ptr[i], indent, out);
+               break;
+       case an_stmt_expr:
+               pindent(indent, out);
+               ast_print(ast->data.an_stmt_expr, indent, out);
+               safe_fprintf(out, ";\n");
                break;
        case an_unop:
-               fprintf(out, "(-");
-               ast_print(ast->data.an_unop.l, out);
-               fprintf(out, ")");
+               safe_fprintf(out, "(%s", unop_str[ast->data.an_unop.op]);
+               ast_print(ast->data.an_unop.l, indent, out);
+               safe_fprintf(out, ")");
+               break;
+       case an_vardecl:
+               pindent(indent, out);
+               safe_fprintf(out, "var %s = ", ast->data.an_vardecl.ident);
+               ast_print(ast->data.an_vardecl.l, indent, out);
+               safe_fprintf(out, ";\n");
+               break;
+       case an_while:
+               pindent(indent, out);
+               safe_fprintf(out, "while (");
+               ast_print(ast->data.an_while.pred, indent, out);
+               safe_fprintf(out, ") {\n");
+               if (ast->data.an_while.body->data.an_list.n > 0) {
+                       pindent(indent, out);
+                       ast_print(ast->data.an_while.body, indent+1, out);
+               }
+               pindent(indent, out);
+               safe_fprintf(out, "}\n");
                break;
        default:
-               fprintf(stderr, "Unsupported AST node\n");
-               exit(1);
+               die("Unsupported AST node\n");
        }
 }
 
@@ -106,19 +341,55 @@ void ast_free(struct ast *ast)
        if (ast == NULL)
                return;
        switch(ast->type) {
+       case an_bool:
+               break;
        case an_binop:
                ast_free(ast->data.an_binop.l);
                ast_free(ast->data.an_binop.r);
                break;
+       case an_char:
+               break;
        case an_cons:
                ast_free(ast->data.an_cons.el);
                ast_free(ast->data.an_cons.tail);
                break;
+       case an_fundecl:
+               free(ast->data.an_fundecl.ident);
+               for (int i = 0; i<ast->data.an_fundecl.nargs; i++)
+                       free(ast->data.an_fundecl.args[i]);
+               free(ast->data.an_fundecl.args);
+               ast_free(ast->data.an_fundecl.body);
+               break;
+       case an_if:
+               ast_free(ast->data.an_if.pred);
+               ast_free(ast->data.an_if.then);
+               ast_free(ast->data.an_if.els);
        case an_int:
                break;
+       case an_ident:
+               free(ast->data.an_ident);
+               break;
+       case an_list:
+               for (int i = 0; i<ast->data.an_list.n; i++)
+                       ast_free(ast->data.an_list.ptr[i]);
+               free(ast->data.an_list.ptr);
+               break;
+       case an_stmt_expr:
+               ast_free(ast->data.an_stmt_expr);
+               break;
+       case an_unop:
+               ast_free(ast->data.an_unop.l);
+               break;
+       case an_vardecl:
+               free(ast->data.an_vardecl.ident);
+               ast_free(ast->data.an_vardecl.l);
+               break;
+       case an_while:
+               ast_free(ast->data.an_while.pred);
+               ast_free(ast->data.an_while.body);
+               break;
        default:
-               fprintf(stderr, "Unsupported AST node\n");
-               exit(1);
+               die("Unsupported AST node: %d\n", ast->type);
        }
        free(ast);
 }
diff --git a/ast.h b/ast.h
index 872e25c..66a23fe 100644 (file)
--- a/ast.h
+++ b/ast.h
@@ -2,6 +2,7 @@
 #define AST_H
 
 #include <stdio.h>
+#include <stdbool.h>
 
 enum binop {
        binor,binand,
@@ -9,34 +10,74 @@ enum binop {
        cons,plus,minus,times,divide,modulo,power
 };
 enum unop {negate,inverse};
-enum ast_type {an_binop, an_cons, an_int, an_unop};
+enum ast_type {
+       an_binop, an_bool, an_char, an_cons, an_fundecl, an_ident, an_if,
+       an_int, an_list, an_stmt_expr, an_unop, an_vardecl, an_while
+};
 struct ast {
        enum ast_type type;
        union {
+               bool an_bool;
                struct {
                        struct ast *l;
                        enum binop op;
                        struct ast *r;
                } an_binop;
+               char an_char;
                struct {
                        struct ast *el;
                        struct ast *tail;
                } an_cons;
+               struct {
+                       char *ident;
+                       int nargs;
+                       char **args;
+                       struct ast *body; // make struct ast **?
+               } an_fundecl;
+               struct {
+                       struct ast *pred;
+                       struct ast *then;
+                       struct ast *els;
+               } an_if;
                int an_int;
+               char *an_ident;
+               struct {
+                       int n;
+                       struct ast **ptr;
+               } an_list;
+               struct ast *an_stmt_expr;
                struct {
                        enum unop op;
                        struct ast *l;
                } an_unop;
-
+               struct {
+                       char *ident;
+                       struct ast *l;
+               } an_vardecl;
+               struct {
+                       struct ast *pred;
+                       struct ast *body;
+               } an_while;
        } data;
 };
 
-struct ast *ast_cons(struct ast *el, struct ast *tail);
+struct ast *ast_bool(bool b);
 struct ast *ast_binop(struct ast *l, enum binop op, struct ast *tail);
+struct ast *ast_char(const char *c);
+struct ast *ast_cons(struct ast *el, struct ast *tail);
+struct ast *ast_fundecl(struct ast *ident, struct ast *args, struct ast *body);
+struct ast *ast_if(struct ast *pred, struct ast *then, struct ast *els);
 struct ast *ast_int(int integer);
+struct ast *ast_ident(char *ident);
+struct ast *ast_list(struct ast *llist);
+struct ast *ast_stmt_expr(struct ast *expr);
 struct ast *ast_unop(enum unop op, struct ast *l);
+struct ast *ast_vardecl(struct ast *ident, struct ast *l);
+struct ast *ast_while(struct ast *pred, struct ast *body);
+
+int ast_llistlength(struct ast *list);
 
-void ast_print(struct ast * ast, FILE *out);
+void ast_print(struct ast * ast, int indent, FILE *out);
 void ast_free(struct ast *ast);
 
 #endif
diff --git a/expr.c b/expr.c
index 455fa71..9cde378 100644 (file)
--- a/expr.c
+++ b/expr.c
@@ -10,6 +10,7 @@ int main()
         int r = yyparse(&result);
        if (r != 0)
                return r;
-       ast_print(result, stdout);
+       ast_print(result, 0, stdout);
+       ast_free(result);
        return 0;
 }
diff --git a/parse.y b/parse.y
index 6600e5b..76bbdd0 100644 (file)
--- a/parse.y
+++ b/parse.y
@@ -13,18 +13,21 @@ int yylex(void);
 
 void yyerror(struct ast **result, const char *str)
 {
-        fprintf(stderr, "error: %s\n", str);
+       ast_free(*result);
+       fprintf(stderr, "error: %s\n", str);
 }
 
 int yywrap()
 {
-        return 1;
+       return 1;
 }
 
 %}
 
 %define parse.error verbose
-%token INTEGER PLUS MINUS TIMES DIVIDE BOPEN BCLOSE SEMICOLON POWER CONS MODULO BINOR BINAND INVERSE
+%token INTEGER PLUS MINUS TIMES DIVIDE BOPEN BCLOSE SEMICOLON POWER CONS MODULO
+%token BINOR BINAND INVERSE VAR ASSIGN IDENT COMMA COPEN CCLOSE IF ELSE WHILE
+%token BOOL CHAR
 
 %parse-param { struct ast **result }
 
@@ -38,11 +41,43 @@ int yywrap()
 
 %%
 
-start : exprs { *result = $1; } ;
+start : decls { *result = ast_list($1); } ;
 
-exprs
+decls
        : { $$ = NULL; }
-       | exprs expr SEMICOLON { $$ = ast_cons($2, $1); }
+       | decls vardecl SEMICOLON { $$ = ast_cons($2, $1); }
+       | decls fundecl { $$ = ast_cons($2, $1); }
+       ;
+
+vardecl
+       : VAR IDENT ASSIGN expr { $$ = ast_vardecl($2, $4); }
+       ;
+
+fundecl
+       : IDENT BOPEN args BCLOSE COPEN body CCLOSE
+               { $$ = ast_fundecl($1, ast_list($3), ast_list($6)); }
+       ;
+
+args
+       : { $$ = NULL; }
+       | nargs
+       ;
+nargs
+       : nargs COMMA IDENT { $$ = ast_cons($3, $1); }
+       | IDENT { $$ = ast_cons($1, NULL); }
+       ;
+body
+       : { $$ = NULL; }
+       | body vardecl SEMICOLON { $$ = ast_cons($2, $1); }
+       | body stmt { $$ = ast_cons($2, $1); }
+       ;
+
+stmt
+       : IF BOPEN expr BCLOSE COPEN body CCLOSE ELSE COPEN body CCLOSE
+               { $$ = ast_if($3, ast_list($6), ast_list($10)); }
+       | WHILE BOPEN expr BCLOSE COPEN body CCLOSE
+               { $$ = ast_while($3, ast_list($6)); }
+       | expr SEMICOLON { $$ = ast_stmt_expr($1); }
        ;
 
 expr
@@ -65,4 +100,7 @@ expr
        | INVERSE expr { $$ = ast_unop(inverse, $2); }
        | BOPEN expr BCLOSE { $$ = $2; }
        | INTEGER
+       | BOOL
+       | CHAR
+       | IDENT
        ;
diff --git a/scan.l b/scan.l
index f92af61..dd51579 100644 (file)
--- a/scan.l
+++ b/scan.l
@@ -12,7 +12,13 @@ extern YYSTYPE yylval;
 
 %%
 
-[0-9]+      { yylval = ast_int(atoi(yytext)); return INTEGER; }
+if          return IF;
+else        return ELSE;
+while       return WHILE;
+var         return VAR;
+true        { yylval = ast_bool(true); return BOOL; }
+false       { yylval = ast_bool(false); return BOOL; }
+=           return ASSIGN;
 !           return INVERSE;
 \|\|        return BINOR;
 &&          return BINAND;
@@ -31,7 +37,22 @@ extern YYSTYPE yylval;
 \^          return POWER;
 \(          return BOPEN;
 \)          return BCLOSE;
+\{          return COPEN;
+\}          return CCLOSE;
 \;          return SEMICOLON;
+,           return COMMA;
+'([^']|\\[abtnvfr]|\\x[0-9a-fA-F]{2})' {
+       yylval = ast_char(yytext);
+       return CHAR;
+}
+[0-9]+      {
+               yylval = ast_int(atoi(yytext));
+               return INTEGER;
+}
+[_a-zA-Z][_a-zA-Z0-9]* {
+               yylval = ast_ident(yytext);
+               return IDENT;
+}
 [ \n\t]  ;
 
 %%
diff --git a/util.c b/util.c
new file mode 100644 (file)
index 0000000..1a89f94
--- /dev/null
+++ b/util.c
@@ -0,0 +1,52 @@
+#include <stdarg.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+void pdie(const char *msg)
+{
+       perror(msg);
+       exit(1);
+}
+
+void die(const char *msg, ...)
+{
+       va_list ap;
+       va_start(ap, msg);
+       vfprintf(stderr, msg, ap);
+       va_end(ap);
+       exit(1);
+}
+
+void pindent(int indent, FILE *out)
+{
+       for (int i = 0; i<indent; i++)
+               if (fputc('\t', out) == EOF)
+                       pdie("fputc");
+}
+
+void safe_fprintf(FILE *out, const char *msg, ...)
+{
+       va_list ap;
+       va_start(ap, msg);
+       int r = vfprintf(out, msg, ap);
+       va_end(ap);
+       if (r < 0)
+               pdie("fprintf");
+}
+
+void *safe_malloc(size_t size)
+{
+       void *res = malloc(size);
+       if (res == NULL)
+               pdie("malloc");
+       return res;
+}
+
+void *safe_strdup(const char *c)
+{
+       char *res = strdup(c);
+       if (res == NULL)
+               pdie("strdup");
+       return res;
+}
diff --git a/util.h b/util.h
new file mode 100644 (file)
index 0000000..57dd17b
--- /dev/null
+++ b/util.h
@@ -0,0 +1,14 @@
+#ifndef UTIL_H
+#define UTIL_H
+
+#include <stdarg.h>
+
+void die(const char *msg, ...);
+void pdie(const char *msg);
+
+void pindent(int indent, FILE *out);
+void safe_fprintf(FILE *out, const char *msg, ...);
+void *safe_malloc(size_t size);
+void *safe_strdup(const char *c);
+
+#endif