inference from haskell writing
[fp.git] / infer.icl
1 implementation module infer
2
3 import StdEnv
4
5 import qualified Data.Map as DM
6 from Data.Map import instance Functor (Map k)
7 import qualified Data.Set as DS
8 import Data.Functor
9 import Data.Func
10 import Data.Either
11 import Data.List
12 import Data.Maybe
13 import Data.GenEq
14 import Control.Applicative
15 import Control.Monad
16 import Control.Monad.Trans
17 import qualified Control.Monad.State as MS
18 import Control.Monad.State => qualified gets, put, modify
19 import Control.Monad.RWST => qualified put
20
21 import ast
22
derive gEq Type

// Errors produced by environment lookup and by constraint solving.
:: TypeError
	= UnboundVariable String            // name missing from the typing environment (lookupEnv)
	| UnificationFail Type Type         // two types whose shapes cannot be unified (unifies)
	| UnificationMismatch [Type] [Type] // type lists of different lengths (unifyMany)
	| InfiniteType String Type          // occurs-check failure (tbind)

// State threaded through inference: a counter used by `fresh` to name new
// type variables.
:: InferState = { count :: Int }

// Typing environment: variable name to its type scheme.
:: TypeEnv :== 'DM'.Map String Scheme
// A recorded equality requirement between two types.
:: Constraint :== (Type, Type)
// Substitution: type-variable name to the type it is replaced by.
:: Subst :== 'DM'.Map String Type
// Solver state: accumulated substitution plus pending constraints.
:: Unifier :== (Subst, [Constraint])

// Inference monad: reads a TypeEnv, writes Constraints, threads InferState,
// and may fail with a TypeError.
:: Infer a :== RWST TypeEnv [Constraint] InferState (Either TypeError) a
// Solver monad: threads a Unifier and may fail with a TypeError.
:: Solve a :== 'MS'.StateT Unifier (Either TypeError) a
36
// The empty substitution: maps no type variable to anything.
nullSubst :: Subst
nullSubst = 'DM'.fromList []
39
// Record that two types must be equal. Nothing is unified here; the pair is
// appended to the writer output and handled later by `solver`.
uni :: Type Type -> Infer ()
uni l r = tell [(l, r)]
42
// Run `m` with `x` bound to scheme `sc` in the typing environment.
// Any existing binding for `x` is shadowed for the duration of `m`.
// 'DM'.put already replaces the value of an existing key, so no explicit
// delete is needed first.
inEnv :: (String, Scheme) (Infer a) -> Infer a
inEnv (x, sc) m = local (\env -> 'DM'.put x sc env) m
45
// Infinite supply of type-variable names: "a".."z", then "aa", "ab", ...,
// i.e. every word over 'a'..'z' in order of increasing length.
letters :: [String]
letters = [toString word \\ len <- [1..], word <- replicateM len ['a'..'z']]
48
// Produce a new type variable named after the current counter value
// (via `letters`), then increment the counter so the next call gets a
// different name.
fresh :: Infer Type
fresh
	= get
	>>= \st=:{count=n} -> 'Control.Monad.RWST'.put {st & count = n + 1}
	>>| pure (VarType (letters !! n))
51
// Look up a variable's scheme in the environment and instantiate it with
// fresh type variables. Fails with UnboundVariable if the name is absent.
lookupEnv :: String -> Infer Type
lookupEnv x
	= asks ('DM'.get x)
	>>= \found -> case found of
		Nothing = liftT (Left (UnboundVariable x))
		Just sc = instantiate sc
55
// Values that mention type variables: a substitution can be applied to
// them, and their free type variables can be collected.
class Substitutable a
where
	// Set of type-variable names occurring free in the value.
	ftv :: a -> 'DS'.Set String
	// Replace type variables according to the substitution.
	apply :: Subst a -> a
60
// Substitution on types: only VarType and Arrow are inspected; every other
// constructor is treated as containing no type variables.
// NOTE(review): if `Type` in module ast has other compound constructors,
// they would need cases here too -- confirm.
instance Substitutable Type
where
	apply s t=:(VarType v) = case 'DM'.get v s of
		Just bound = bound
		_ = t
	apply s (Arrow l r) = Arrow (apply s l) (apply s r)
	apply _ other = other

	ftv (VarType v) = 'DS'.singleton v
	ftv (Arrow l r) = 'DS'.union (ftv l) (ftv r)
	ftv _ = 'DS'.newSet
70
// Substitution on schemes: quantified variables are bound, so they are
// removed from the substitution before applying it, and excluded from ftv.
instance Substitutable Scheme
where
	apply s (Forall vs t) = let unbound = foldr 'DM'.del s vs in Forall vs (apply unbound t)
	ftv (Forall vs t) = 'DS'.difference (ftv t) ('DS'.fromList vs)
75
// Lists: apply element-wise; free variables are the union over elements.
instance Substitutable [a] | Substitutable a
where
	apply s xs = [apply s x \\ x <- xs]
	ftv xs = foldr (\x acc -> 'DS'.union (ftv x) acc) 'DS'.newSet xs
80
// Environments: apply to every scheme; free variables are those of all
// schemes in the environment.
instance Substitutable TypeEnv
where
	apply s env = apply s <$> env
	ftv env = ftv ('DM'.elems env)
84
// Constraints: apply to both sides; free variables of both sides.
instance Substitutable Constraint
where
	apply s (l, r) = (apply s l, apply s r)
	ftv (l, r) = 'DS'.union (ftv l) (ftv r)
88
// Instantiate a scheme: replace each quantified variable with a fresh
// type variable.
instantiate :: Scheme -> Infer Type
instantiate (Forall vs t)
	= mapM (\_ -> fresh) vs
	>>= \fs -> let renaming = 'DM'.fromList (zip2 vs fs) in pure (apply renaming t)
// Generalize a type: quantify over the type variables that are free in `t`
// but not free in the environment.
generalize :: TypeEnv Type -> Scheme
generalize env t = Forall quantified t
where
	quantified = 'DS'.toList ('DS'.difference (ftv t) (ftv env))
94
// Infer the type of an expression.
// Equality requirements are not unified here. They are recorded with `uni`
// in the writer output of the Infer monad and solved afterwards by `solver`.
infer :: Expression -> Infer Type
infer (Literal v) = case v of
	Int _ = pure IntType
	Bool _ = pure BoolType
	Char _ = pure CharType
infer (Variable s) = lookupEnv s
// Application: the function's type must be an arrow from the argument's
// type to a fresh result type.
infer (Apply e1 e2)
	= infer e1
	>>= \t1->infer e2
	>>= \t2->fresh
	>>= \tv->uni t1 (Arrow t2 tv)
	>>| pure tv
// Lambda: the parameter gets a fresh, unquantified type variable.
infer (Lambda s e)
	= fresh
	>>= \tv->inEnv (s, Forall [] tv) (infer e)
	>>= \t->pure (Arrow tv t)
// Non-recursive let: `x` is not in scope in `e1`.
// NOTE(review): t1 is generalized before the constraints emitted while
// inferring e1 have been solved, so variables those constraints would fix
// may end up quantified. Consider solving e1's constraints locally (and
// applying the result to the environment) before calling generalize.
infer (Let x e1 e2)
	= ask
	>>= \en->infer e1
	>>= \t1->inEnv (x, generalize en t1) (infer e2)
// Built-in operations, each typed with a single fresh variable. An unknown
// code name is reported as UnboundVariable instead of causing a
// pattern-match failure at runtime.
// NOTE(review): "and"/"or" do not require their operands to be BoolType --
// confirm this is intended.
infer (Code c)
	| isMember c ["add", "sub", "mul", "div"]
		= fresh >>= \v->pure (Arrow v (Arrow v v))
	| isMember c ["and", "or"]
		= fresh >>= \v->pure (Arrow v (Arrow v BoolType))
	= liftT (Left (UnboundVariable c))
122
// Unify two types, producing a substitution and any residual constraints.
// Equal types (by generic equality) need no substitution. A variable on
// either side is bound via tbind. Arrows are unified component-wise. Any
// other pair is a UnificationFail.
unifies :: Type Type -> Solve Unifier
unifies t1 t2
	| t1 === t2 = pure ('DM'.newMap, [])
	= case (t1, t2) of
		(VarType v, t) = tbind v t
		(t, VarType v) = tbind v t
		(Arrow a b, Arrow c d) = unifyMany [a, b] [c, d]
		_ = liftT (Left (UnificationFail t1 t2))
130
// Unify two lists of types pairwise, left to right. The substitution from
// each head is applied to the remaining elements before they are unified.
// Lists of different lengths fail with UnificationMismatch.
unifyMany :: [Type] [Type] -> Solve Unifier
unifyMany [] [] = pure ('DM'.newMap, [])
unifyMany [x:xs] [y:ys]
	= unifies x y
	>>= \(s1, c1) -> unifyMany (apply s1 xs) (apply s1 ys)
	>>= \(s2, c2) -> pure (s2 `compose` s1, c1 ++ c2)
unifyMany xs ys = liftT (Left (UnificationMismatch xs ys))
137
// Compose substitutions: every entry of `inner` is rewritten by `outer`,
// and `outer`'s own entries are added. For keys present in both, the
// rewritten `inner` entry is kept, assuming 'DM'.union is left-biased as in
// Haskell's Data.Map -- confirm against Clean Platform.
(`compose`) infix 1 :: Subst Subst -> Subst
(`compose`) outer inner = 'DM'.union rewritten outer
where
	rewritten = fmap (apply outer) inner
140
// Bind type variable `a` to type `t`. Binding a variable to itself is a
// no-op. A binding where `a` occurs inside `t` is rejected as InfiniteType.
tbind :: String Type -> Solve Unifier
tbind a t
	| t === VarType a = pure ('DM'.newMap, [])
	| occursCheck a t = liftT (Left (InfiniteType a t))
	| otherwise = pure ('DM'.singleton a t, [])
146
// True when type variable `a` occurs free in `t`.
occursCheck :: String a -> Bool | Substitutable a
occursCheck a t = 'DS'.member a (ftv t)
149
// Solve the pending constraints held in the solver state.
// Repeatedly unifies the first constraint, composes its substitution into
// the accumulated one, and applies it to the remaining constraints. Returns
// the accumulated substitution once no constraints remain.
solver :: Solve Subst
solver
	= getState
	>>= \(subst, pending) -> case pending of
		[] = pure subst
		[(l, r):rest]
			= unifies l r
			>>= \(s1, c1) -> 'MS'.put (s1 `compose` subst, c1 ++ apply s1 rest)
			>>| solver