mutual recursion type inference
[minfp.git] / check.icl
index a264430..5b0cc18 100644 (file)
--- a/check.icl
+++ b/check.icl
@@ -5,6 +5,7 @@ import StdEnv
 import Control.Monad => qualified join
 import Control.Monad.State
 import Control.Monad.Trans
+import Control.Monad.Writer
 import Data.Either
 import Data.Func
 import Data.List
@@ -18,7 +19,7 @@ import ast, scc
 import Text.GenPrint
 import StdDebug
 
-check :: [Function] -> Either [String] (Expression, Scheme)
+check :: [Function] -> Either [String] (Expression, [([Char], Scheme)])
 check fs
        # dups = filter (\x->length x > 1) (groupBy (\(Function i _ _) (Function j _ _)->i == j) fs)
        | length dups > 0 = Left ["Duplicate functions: ":[toString n\\[(Function n _ _):_]<-dups]]
@@ -26,7 +27,6 @@ check fs
                ([], _) = Left ["No start function defined"]
                ([Function _ [] e], fs)
                        = (\x->(e, x)) <$> runInfer (infer preamble (makeExpression fs e))
-//                     = pure (makeExpression fs e, undef)
                ([Function _ _ _], _) = Left ["Start cannot have arguments"]
 
 makeExpression :: [Function] Expression -> Expression
@@ -40,9 +40,7 @@ where
        nicefuns = [(l, foldr ((o) o Lambda) id i e)\\(Function l i e)<-fs]
 
        vars :: Expression [[Char]] -> [[Char]]
-       vars (Var v=:[m:_]) c
-               = [v:c]
-//             | m <> '_' = [v:c]
+       vars (Var v=:[m:_]) c = [v:c]
        vars (App l r) c = vars l $ vars r c
        vars (Lambda l e) c = [v\\v<-vars e c | v <> l]
        vars (Let ns e) c = flatten
@@ -72,10 +70,11 @@ preamble = fromList
        ]
 :: Subst :== Map [Char] Type
 
-:: Infer a :== StateT [Int] (Either [String]) a
-runInfer :: (Infer (Subst, Type)) -> Either [String] Scheme
-runInfer i = uncurry ((o) (generalize newMap) o apply)
-       <$> evalStateT i [0..]
+:: Infer a :== StateT [Int] (WriterT [([Char], Scheme)] (Either [String])) a
+runInfer :: (Infer (Subst, Type)) -> Either [String] [([Char], Scheme)]
+runInfer i = case runWriterT (evalStateT i [0..]) of
+       Left e = Left e
+       Right ((s, t), w) = pure [(['start'], generalize newMap (apply s t)):w]
 
 fresh :: Infer Type
 fresh = getState >>= \[s:ss]->put ss >>| pure (TVar (['v':[c\\c<-:toString s]]))
@@ -111,6 +110,9 @@ instance Substitutable [a] | Substitutable a where
 occursCheck :: [Char] -> (a -> Bool) | Substitutable a
 occursCheck a = isMember a o ftv
 
+err :: [String] -> Infer a
+err e = liftT (liftT (Left e))
+
 unify :: Type Type -> Infer Subst
 unify (l --> r) (l` --> r`)
        =        unify l l`
@@ -119,12 +121,16 @@ unify (l --> r) (l` --> r`)
 unify (TVar a) (TVar t)
        | a == t = pure newMap
 unify (TVar a) t
-       | occursCheck a t = liftT (Left ["Infinite type: ", toString a, " to ", toString t])
+       | occursCheck a t = err ["Infinite type: ", toString a, " to ", toString t]
        = pure (singleton a t)
 unify t (TVar a) = unify (TVar a) t
 unify TInt TInt = pure newMap
 unify TBool TBool = pure newMap
-unify t1 t2 = liftT (Left ["Cannot unify: ", toString t1, " with ", toString t2])
+unify t1 t2 = err ["Cannot unify: ", toString t1, " with ", toString t2]
+
+unifyl :: [Type] -> Infer Subst
+unifyl [t1,t2:ts] = unify t1 t2 >>= \s->unifyl (map (apply s) [t2:ts])
+unifyl _ = pure newMap
 
 instantiate :: Scheme -> Infer Type
 instantiate (Forall as t)
@@ -138,7 +144,7 @@ infer :: TypeEnv Expression -> Infer (Subst, Type)
 infer env (Lit (Int _)) = pure (newMap, TInt)
 infer env (Lit (Bool _)) = pure (newMap, TBool)
 infer env (Var x) = case get x env of
-       Nothing = liftT (Left ["Unbound variable: ", toString x])
+       Nothing = err ["Unbound variable: ", toString x]
        Just s = (\x->(newMap, x)) <$> instantiate s
 infer env (App e1 e2)
        =              fresh
@@ -150,32 +156,26 @@ infer env (Lambda x b)
        =              fresh
        >>= \tv->      infer ('Data.Map'.put x (Forall [] tv) env) b
        >>= \(s1, t1)->pure (s1, apply s1 tv --> t1)
+//Non recursion
 //infer env (Let [(x, e1)] e2)
 //     =              infer env e1
 //     >>= \(s1, t1)->infer ('Data.Map'.put x (generalize (apply s1 env) t1) env) e2
 //     >>= \(s2, t2)->pure (s1 oo s2, t2)
-infer env (Let [(x, e1)] e2)
-       =              fresh
-       >>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
-                      in infer env` e1
-       >>= \(s1,t1)-> unify t1 tv
-       >>= \t->infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
-       >>= \(s2, t2)->pure (s1 oo s2, t2)
-infer env (Let _ _)
-       = liftT (Left ["Mutual recursion typechecking not implemented yet"])
-//infer env (Let xs e2)
-//     # (ns, bs) = unzip xs
-//     =              sequence [fresh\\_<-ns]
-//     >>= \tvs->     let env` = foldr (uncurry putenv) env (zip2 ns tvs)
-//                    in  unzip <$> sequence (map infer env`) bs
-//     >>= \(ss,ts)-> let s = foldr (oo) newMap ss
-//                    in  //unify t1 tv
-//     >>= \t->infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
+//Single recursion
+//infer env (Let [(x, e1)] e2)
+//     =              fresh
+//     >>= \tv->      let env` = 'Data.Map'.put x (Forall [] tv) env
+//                    in infer env` e1
+//     >>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2
 //     >>= \(s2, t2)->pure (s1 oo s2, t2)
-where
-       putenv :: [Char] -> (Type TypeEnv -> TypeEnv)
-       putenv k = 'Data.Map'.put k o Forall []
-
-unifyl :: [Type] -> Infer Subst
-unifyl [t1,t2:ts] = unify t1 t2 >>= \s->unifyl [t2:map (apply s) ts]
-unifyl _ = pure newMap
+//Multiple recursion
+infer env (Let xs e2)
+       # (ns, bs) = unzip xs
+       =              sequence [fresh\\_<-ns]
+       >>= \tvs->     let env` = foldr (\(k, v)->'Data.Map'.put k (Forall [] v)) env (zip2 ns tvs)
+                      in  unzip <$> sequence (map (infer env`) bs)
+       >>= \(ss,ts)-> unifyl ts
+       >>= \s->       liftT (tell [(n, generalize (apply s env`) t)\\t<-ts & n<-ns])
+       >>|            let env`` = foldr (\(n, t) m->'Data.Map'.put n (generalize (apply s env`) t) m) env` (zip2 ns ts)
+                      in infer env`` e2
+       >>= \(s2, t2)->pure (s oo s2, t2)