mutual recursion type inference
authorMart Lubbers <mart@martlubbers.net>
Mon, 18 Mar 2019 15:16:59 +0000 (16:16 +0100)
committerMart Lubbers <mart@martlubbers.net>
Mon, 18 Mar 2019 15:16:59 +0000 (16:16 +0100)
Makefile
check.dcl
check.icl
main.icl
tests/preamble.mfp

index d532f0d..8bd8989 100644 (file)
--- a/Makefile
+++ b/Makefile
@@ -12,4 +12,3 @@ clean:
 
 %: %.o rts.o
        $(CC) $(CFLAGS) $(LDFLAGS) $(TARGET_ARCH) $^ $(OUTPUT_OPTION)
-
index 91059a7..da6c1f9 100644 (file)
--- a/check.dcl
+++ b/check.dcl
@@ -9,4 +9,4 @@ from ast import :: Function, :: Expression
 
 instance toString Scheme, Type
 
-check :: [Function] -> Either [String] (Expression, Scheme)
+check :: [Function] -> Either [String] (Expression, [([Char], Scheme)])
index a264430..5b0cc18 100644 (file)
--- a/check.icl
+++ b/check.icl
@@ -5,6 +5,7 @@ import StdEnv
 import Control.Monad => qualified join
 import Control.Monad.State
 import Control.Monad.Trans
+import Control.Monad.Writer
 import Data.Either
 import Data.Func
 import Data.List
@@ -18,7 +19,7 @@ import ast, scc
 import Text.GenPrint
 import StdDebug
 
-check :: [Function] -> Either [String] (Expression, Scheme)
+check :: [Function] -> Either [String] (Expression, [([Char], Scheme)])
 check fs
        # dups = filter (\x->length x > 1) (groupBy (\(Function i _ _) (Function j _ _)->i == j) fs)
        | length dups > 0 = Left ["Duplicate functions: ":[toString n\\[(Function n _ _):_]<-dups]]
@@ -26,7 +27,6 @@ check fs
                ([], _) = Left ["No start function defined"]
                ([Function _ [] e], fs)
                        = (\x->(e, x)) <$> runInfer (infer preamble (makeExpression fs e))
-//                     = pure (makeExpression fs e, undef)
                ([Function _ _ _], _) = Left ["Start cannot have arguments"]
 
 makeExpression :: [Function] Expression -> Expression
@@ -40,9 +40,7 @@ where
        nicefuns = [(l, foldr ((o) o Lambda) id i e)\\(Function l i e)<-fs]
 
        vars :: Expression [[Char]] -> [[Char]]
-       vars (Var v=:[m:_]) c
-               = [v:c]
-//             | m <> '_' = [v:c]
+       vars (Var v=:[m:_]) c = [v:c]
        vars (App l r) c = vars l $ vars r c
        vars (Lambda l e) c = [v\\v<-vars e c | v <> l]
        vars (Let ns e) c = flatten
@@ -72,10 +70,11 @@ preamble = fromList
        ]
 :: Subst :== Map [Char] Type
 
-:: Infer a :== StateT [Int] (Either [String]) a
-runInfer :: (Infer (Subst, Type)) -> Either [String] Scheme
-runInfer i = uncurry ((o) (generalize newMap) o apply)
-       <$> evalStateT i [0..]
+:: Infer a :== StateT [Int] (WriterT [([Char], Scheme)] (Either [String])) a
+runInfer :: (Infer (Subst, Type)) -> Either [String] [([Char], Scheme)]
+runInfer i = case runWriterT (evalStateT i [0..]) of
+       Left e = Left e
+       Right ((s, t), w) = pure [(['start'], generalize newMap (apply s t)):w]
 
 fresh :: Infer Type
 fresh = getState >>= \[s:ss]->put ss >>| pure (TVar (['v':[c\\c<-:toString s]]))
@@ -111,6 +110,9 @@ instance Substitutable [a] | Substitutable a where
 occursCheck :: [Char] -> (a -> Bool) | Substitutable a
 occursCheck a = isMember a o ftv
 
+err :: [String] -> Infer a
+err e = liftT (liftT (Left e))
+
 unify :: Type Type -> Infer Subst
 unify (l --> r) (l` --> r`)
        =        unify l l`
@@ -119,12 +121,16 @@ unify (l --> r) (l` --> r`)
 unify (TVar a) (TVar t)
        | a == t = pure newMap
 unify (TVar a) t
-       | occursCheck a t = liftT (Left ["Infinite type: ", toString a, " to ", toString t])
+       | occursCheck a t = err ["Infinite type: ", toString a, " to ", toString t]
        = pure (singleton a t)
 unify t (TVar a) = unify (TVar a) t
 unify TInt TInt = pure newMap
 unify TBool TBool = pure newMap
-unify t1 t2 = liftT (Left ["Cannot unify: ", toString t1, " with ", toString t2])
+unify t1 t2 = err ["Cannot unify: ", toString t1, " with ", toString t2]
+
+unifyl :: [Type] -> Infer Subst
+unifyl [t1,t2:ts] = unify t1 t2 >>= \s->unifyl (map (apply s) [t2:ts])
+unifyl _ = pure newMap
 
 instantiate :: Scheme -> Infer Type
 instantiate (Forall as t)
@@ -138,7 +144,7 @@ infer :: TypeEnv Expression -> Infer (Subst, Type)
 infer env (Lit (Int _)) = pure (newMap, TInt)
 infer env (Lit (Bool _)) = pure (newMap, TBool)
 infer env (Var x) = case get x env of
-       Nothing = liftT (Left ["Unbound variable: ", toString x])
+       Nothing = err ["Unbound variable: ", toString x]
        Just s = (\x->(newMap, x)) <$> instantiate s
 infer env (App e1 e2)
        =              fresh
@@ -150,32 +156,26 @@ infer env (Lambda x b)
        =              fresh
        >>= \tv->      infer ('Data.Map'.put x (Forall [] tv) env) b
        >>= \(s1, t1)->pure (s1, apply s1 tv --> t1)
+//Non recursion
 //infer env (Let [(x, e1)] e2)
 //     =              infer env e1
 //     >>= \(s1, t1)->infer ('Data.Map'.put x (generalize (apply s1 env) t1) env) e2
 //     >>= \(s2, t2)->pure (s1 oo s2, t2)
-infer env (Let [(x, e1)] e2)
-       =              fresh
-       >>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
-                      in infer env` e1
-       >>= \(s1,t1)-> unify t1 tv
-       >>= \t->infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
-       >>= \(s2, t2)->pure (s1 oo s2, t2)
-infer env (Let _ _)
-       = liftT (Left ["Mutual recursion typechecking not implemented yet"])
-//infer env (Let xs e2)
-//     # (ns, bs) = unzip xs
-//     =              sequence [fresh\\_<-ns]
-//     >>= \tvs->     let env` = foldr (uncurry putenv) env (zip2 ns tvs)
-//                    in  unzip <$> sequence (map infer env`) bs
-//     >>= \(ss,ts)-> let s = foldr (oo) newMap ss
-//                    in  //unify t1 tv
-//     >>= \t->infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
+//Single recursion
+//infer env (Let [(x, e1)] e2)
+//     =              fresh
+//     >>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
+//                    in infer env` e1
+//     >>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
 //     >>= \(s2, t2)->pure (s1 oo s2, t2)
-where
-       putenv :: [Char] -> (Type TypeEnv -> TypeEnv)
-       putenv k = 'Data.Map'.put k o Forall []
-
-unifyl :: [Type] -> Infer Subst
-unifyl [t1,t2:ts] = unify t1 t2 >>= \s->unifyl [t2:map (apply s) ts]
-unifyl _ = pure newMap
+//Multiple recursion
+infer env (Let xs e2)
+       # (ns, bs) = unzip xs
+       =              sequence [fresh\\_<-ns]
+       >>= \tvs->     let env` = foldr (\(k, v)->'Data.Map'.put k (Forall [] v)) env (zip2 ns tvs)
+                      in  unzip <$> sequence (map (infer env`) bs)
+       >>= \(ss,ts)-> unifyl ts
+       >>= \s->       liftT (tell [(n, generalize (apply s env`) t)\\t<-ts & n<-ns])
+       >>|            let env`` = foldr (\(n, t) m->'Data.Map'.put n (generalize (apply s env`) t) m) env` (zip2 ns ts)
+                      in infer env`` e2
+       >>= \(s2, t2)->pure (s oo s2, t2)
index 6b181af..69d66e6 100644 (file)
--- a/main.icl
+++ b/main.icl
@@ -3,7 +3,8 @@ module main
 import StdEnv
 import Data.Either
 import Data.Functor
-import Control.Monad => qualified join
+import Data.List
+import Control.Monad
 import System.GetOpt
 import System.CommandLine
 
@@ -46,9 +47,13 @@ Start w
        # (cs, io) = chars io
        # mstr = case mode of
                MHelp = Left [usageInfo ("Usage: " +++ argv0 +++ " [opts]\n") opts]
-               MLex = map (\x->toString x +++ "\n") <$> lex cs
-               MParse = map (\x->toString x +++ "\n") <$> (lex cs >>= parse)
-               MType = (\(e, x)->["type: ",toString x, "\n", toString e]) <$> (lex cs >>= parse >>= check)
-               MInterpret = (\x->[toString x]) <$> (lex cs >>= parse >>= check >>= int o fst)
+               MLex = map (nl o toString) <$> lex cs
+               MParse = map (nl o toString) <$> (lex cs >>= parse)
+               MType = map (\(t, s)->nl (toString t +++ " :: " +++ toString s)) o snd <$> (lex cs >>= parse >>= check)
+               MInterpret = pure o toString <$> (lex cs >>= parse >>= check >>= int o fst)
                MGen = lex cs >>= parse >>= check >>= gen o fst
        = exit (either (\_->1) (\_->0) mstr) (either id id mstr) io w
+
+nl x = x +++ "\n"
+
+
index 5fd27a7..cae72d6 100644 (file)
@@ -18,4 +18,4 @@ id x = x;
 even i = if (i == 0) True (odd (i - 1));
 odd i = if (i == 0) False (even (i - 1));
 
-start = odd 5;
+start = odd;