From: Mart Lubbers Date: Mon, 18 Mar 2019 15:16:59 +0000 (+0100) Subject: mutual recursion type inference X-Git-Url: https://git.martlubbers.net/?a=commitdiff_plain;h=1f5d6dc255faec4d27c0a11efed6612f2583a3aa;p=minfp.git mutual recursion type inference --- diff --git a/Makefile b/Makefile index d532f0d..8bd8989 100644 --- a/Makefile +++ b/Makefile @@ -12,4 +12,3 @@ clean: %: %.o rts.o $(CC) $(CFLAGS) $(LDFLAGS) $(TARGET_ARCH) $^ $(OUTPUT_OPTION) - diff --git a/check.dcl b/check.dcl index 91059a7..da6c1f9 100644 --- a/check.dcl +++ b/check.dcl @@ -9,4 +9,4 @@ from ast import :: Function, :: Expression instance toString Scheme, Type -check :: [Function] -> Either [String] (Expression, Scheme) +check :: [Function] -> Either [String] (Expression, [([Char], Scheme)]) diff --git a/check.icl b/check.icl index a264430..5b0cc18 100644 --- a/check.icl +++ b/check.icl @@ -5,6 +5,7 @@ import StdEnv import Control.Monad => qualified join import Control.Monad.State import Control.Monad.Trans +import Control.Monad.Writer import Data.Either import Data.Func import Data.List @@ -18,7 +19,7 @@ import ast, scc import Text.GenPrint import StdDebug -check :: [Function] -> Either [String] (Expression, Scheme) +check :: [Function] -> Either [String] (Expression, [([Char], Scheme)]) check fs # dups = filter (\x->length x > 1) (groupBy (\(Function i _ _) (Function j _ _)->i == j) fs) | length dups > 0 = Left ["Duplicate functions: ":[toString n\\[(Function n _ _):_]<-dups]] @@ -26,7 +27,6 @@ check fs ([], _) = Left ["No start function defined"] ([Function _ [] e], fs) = (\x->(e, x)) <$> runInfer (infer preamble (makeExpression fs e)) -// = pure (makeExpression fs e, undef) ([Function _ _ _], _) = Left ["Start cannot have arguments"] makeExpression :: [Function] Expression -> Expression @@ -40,9 +40,7 @@ where nicefuns = [(l, foldr ((o) o Lambda) id i e)\\(Function l i e)<-fs] vars :: Expression [[Char]] -> [[Char]] - vars (Var v=:[m:_]) c - = [v:c] -// | m <> '_' = [v:c] + vars (Var v=:[m:_]) c = [v:c] vars (App l r) c = vars l $ vars r c vars (Lambda l e) c = [v\\v<-vars e c | v <> l] vars (Let ns e) c = flatten @@ -72,10 +70,11 @@ preamble = fromList ] :: Subst :== Map [Char] Type -:: Infer a :== StateT [Int] (Either [String]) a -runInfer :: (Infer (Subst, Type)) -> Either [String] Scheme -runInfer i = uncurry ((o) (generalize newMap) o apply) - <$> evalStateT i [0..] +:: Infer a :== StateT [Int] (WriterT [([Char], Scheme)] (Either [String])) a +runInfer :: (Infer (Subst, Type)) -> Either [String] [([Char], Scheme)] +runInfer i = case runWriterT (evalStateT i [0..]) of + Left e = Left e + Right ((s, t), w) = pure [(['start'], generalize newMap (apply s t)):w] fresh :: Infer Type fresh = getState >>= \[s:ss]->put ss >>| pure (TVar (['v':[c\\c<-:toString s]])) @@ -111,6 +110,9 @@ instance Substitutable [a] | Substitutable a where occursCheck :: [Char] -> (a -> Bool) | Substitutable a occursCheck a = isMember a o ftv +err :: [String] -> Infer a +err e = liftT (liftT (Left e)) + unify :: Type Type -> Infer Subst unify (l --> r) (l` --> r`) = unify l l` @@ -119,12 +121,16 @@ unify (l --> r) (l` --> r`) unify (TVar a) (TVar t) | a == t = pure newMap unify (TVar a) t - | occursCheck a t = liftT (Left ["Infinite type: ", toString a, " to ", toString t]) + | occursCheck a t = err ["Infinite type: ", toString a, " to ", toString t] = pure (singleton a t) unify t (TVar a) = unify (TVar a) t unify TInt TInt = pure newMap unify TBool TBool = pure newMap -unify t1 t2 = liftT (Left ["Cannot unify: ", toString t1, " with ", toString t2]) +unify t1 t2 = err ["Cannot unify: ", toString t1, " with ", toString t2] + +unifyl :: [Type] -> Infer Subst +unifyl [t1,t2:ts] = unify t1 t2 >>= \s->unifyl (map (apply s) [t2:ts]) +unifyl _ = pure newMap instantiate :: Scheme -> Infer Type instantiate (Forall as t) @@ -138,7 +144,7 @@ infer :: TypeEnv Expression -> Infer (Subst, Type) infer env (Lit (Int _)) = pure (newMap, TInt) infer env (Lit (Bool _)) = pure (newMap, TBool) infer env (Var x) = case get x env of - Nothing = liftT (Left ["Unbound variable: ", toString x]) + Nothing = err ["Unbound variable: ", toString x] Just s = (\x->(newMap, x)) <$> instantiate s infer env (App e1 e2) = fresh @@ -150,32 +156,26 @@ infer env (Lambda x b) = fresh >>= \tv-> infer ('Data.Map'.put x (Forall [] tv) env) b >>= \(s1, t1)->pure (s1, apply s1 tv --> t1) +//Non recursion //infer env (Let [(x, e1)] e2) // = infer env e1 // >>= \(s1, t1)->infer ('Data.Map'.put x (generalize (apply s1 env) t1) env) e2 // >>= \(s2, t2)->pure (s1 oo s2, t2) -infer env (Let [(x, e1)] e2) - = fresh - >>= \tv-> let env` = 'Data.Map'.put x (Forall [] tv) env - in infer env` e1 - >>= \(s1,t1)-> unify t1 tv - >>= \t->infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2 - >>= \(s2, t2)->pure (s1 oo s2, t2) -infer env (Let _ _) - = liftT (Left ["Mutual recursion typechecking not implemented yet"]) -//infer env (Let xs e2) -// # (ns, bs) = unzip xs -// = sequence [fresh\\_<-ns] -// >>= \tvs-> let env` = foldr (uncurry putenv) env (zip2 ns tvs) -// in unzip <$> sequence (map infer env`) bs -// >>= \(ss,ts)-> let s = foldr (oo) newMap ss -// in //unify t1 tv -// >>= \t->infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2 +//Single recursion +//infer env (Let [(x, e1)] e2) +// = fresh +// >>= \tv-> let env` = 'Data.Map'.put x (Forall [] tv) env +// in infer env` e1 +// >>= \(s1,t1)-> infer ('Data.Map'.put x (generalize (apply s1 env`) t1) env`) e2 // >>= \(s2, t2)->pure (s1 oo s2, t2) -where - putenv :: [Char] -> (Type TypeEnv -> TypeEnv) - putenv k = 'Data.Map'.put k o Forall [] - -unifyl :: [Type] -> Infer Subst -unifyl [t1,t2:ts] = unify t1 t2 >>= \s->unifyl [t2:map (apply s) ts] -unifyl _ = pure newMap +//Multiple recursion +infer env (Let xs e2) + # (ns, bs) = unzip xs + = sequence [fresh\\_<-ns] + >>= \tvs-> let env` = foldr (\(k, v)->'Data.Map'.put k (Forall [] v)) env (zip2 ns tvs) + in unzip <$> sequence (map (infer env`) bs) + >>= \(ss,ts)-> unifyl ts + >>= \s-> liftT (tell [(n, generalize (apply s env`) t)\\t<-ts & n<-ns]) + >>| let env`` = foldr (\(n, t) m->'Data.Map'.put n (generalize (apply s env`) t) m) env` (zip2 ns ts) + in infer env`` e2 + >>= \(s2, t2)->pure (s oo s2, t2) diff --git a/main.icl b/main.icl index 6b181af..69d66e6 100644 --- a/main.icl +++ b/main.icl @@ -3,7 +3,8 @@ module main import StdEnv import Data.Either import Data.Functor -import Control.Monad => qualified join +import Data.List +import Control.Monad import System.GetOpt import System.CommandLine @@ -46,9 +47,13 @@ Start w # (cs, io) = chars io # mstr = case mode of MHelp = Left [usageInfo ("Usage: " +++ argv0 +++ " [opts]\n") opts] - MLex = map (\x->toString x +++ "\n") <$> lex cs - MParse = map (\x->toString x +++ "\n") <$> (lex cs >>= parse) - MType = (\(e, x)->["type: ",toString x, "\n", toString e]) <$> (lex cs >>= parse >>= check) - MInterpret = (\x->[toString x]) <$> (lex cs >>= parse >>= check >>= int o fst) + MLex = map (nl o toString) <$> lex cs + MParse = map (nl o toString) <$> (lex cs >>= parse) + MType = map (\(t, s)->nl (toString t +++ " :: " +++ toString s)) o snd <$> (lex cs >>= parse >>= check) + MInterpret = pure o toString <$> (lex cs >>= parse >>= check >>= int o fst) MGen = lex cs >>= parse >>= check >>= gen o fst = exit (either (\_->1) (\_->0) mstr) (either id id mstr) io w + +nl x = x +++ "\n" + + diff --git a/tests/preamble.mfp b/tests/preamble.mfp index 5fd27a7..cae72d6 100644 --- a/tests/preamble.mfp +++ b/tests/preamble.mfp @@ -18,4 +18,4 @@ id x = x; even i = if (i == 0) True (odd (i - 1)); odd i = if (i == 0) False (even (i - 1)); -start = odd 5; +start = odd;