Type inference for functions works YAAAY
[cc1516.git] / sem.icl
1 implementation module sem
2
3 import qualified Data.Map as Map
4
5 from Data.Func import $
6 from StdFunc import o, flip, const, id
7
8 import Control.Monad
9 import Control.Monad.Trans
10 import Control.Monad.State
11 import Data.Either
12 import Data.Maybe
13 import Data.Monoid
14 import Data.List
15 import Data.Functor
16 import Data.Tuple
17
18 import StdString
19 import StdTuple
20 import StdList
21 import StdMisc
22 import StdEnum
23 import GenEq
24
25 from Text import class Text(concat), instance Text String
26
27 import AST
28
29
// A type scheme: the listed type variables are universally quantified over
// the type, e.g. Forall ["a"] (ListType (IdType "a")) describes a list of any
// element type. An empty list of variables gives a monomorphic type.
30 :: Scheme = Forall [TVar] Type
// The typing environment: maps identifiers (variables, function arguments and
// functions) to their type schemes.
31 :: Gamma :== 'Map'.Map String Scheme //identifier -> type scheme
// The inference monad: the state is the environment plus the remaining supply
// of unused type-variable names (consumed by fresh); failure is a SemError.
32 :: Typing a :== StateT (Gamma, [TVar]) (Either SemError) a
// Binds type variables to types.
33 :: Substitution :== 'Map'.Map TVar Type
// A list of type equations. NOTE(review): appears unused by the live code in
// this module (only the commented-out constraint approach refers to it).
34 :: Constraints :== [(Type, Type)]
// Errors reported by the semantic analysis. Errors raised from unify and
// lookup currently carry the zero position.
35 :: SemError
36 = ParseError Pos String
37 | UnifyError Pos Type Type // the two types that could not be unified
38 | InfiniteTypeError Pos Type // occurs-check failure on this type
39 | FieldSelectorError Pos Type FieldSelector
40 | OperatorError Pos Op2 Type
41 | UndeclaredVariableError Pos String // the unknown identifier
42 | ArgumentMisMatchError Pos String
43 | SanityError Pos String // structural checks (duplicates, main)
44 | Error String // anything without a more specific constructor
45
// The empty typing environment: no identifier is bound.
instance zero Gamma where zero = 'Map'.fromList []
48
// An infinite supply of type-variable names "1", "2", "3", ...; fresh takes
// names from the front of this list.
variableStream :: [TVar]
variableStream = [toString i \\ i <- [1..]]
51
// sem performs the semantic analysis of a parsed program:
//   1. every function name is declared exactly once (hasNoDups),
//   2. a declaration named main has arity 0 and, if annotated, returns Void
//      (isNiceMain),
//   3. a main function exists (hasMain),
//   4. all function declarations are type-inferred (type), starting from the
//      empty environment and the fresh-variable supply variableStream.
// The first failing step stops the analysis and its error is returned as a
// singleton list. On success the AST is returned with the inferred types
// filled into the declarations.
sem :: AST -> Either [SemError] AST
sem (AST fd) = case sanityChecks >>| evalStateT (type fd) (zero, variableStream) of
	Left e = Left [e]
	Right fds = Right (AST fds)
where
	// Structural checks that need no type information; each one
	// short-circuits on its first error.
	sanityChecks = foldM (const $ hasNoDups fd) () fd
		>>| foldM (const isNiceMain) () fd
		>>| hasMain fd
81
// hasNoDups fds fd checks that the name of fd is declared only once in fds.
// fds is expected to contain fd itself, so finding no declaration at all
// indicates a bug in the caller. When the name is declared more than once,
// the error lists the positions of every declaration with that name, not
// only the position of fd.
hasNoDups :: [FunDecl] FunDecl -> Either SemError ()
hasNoDups fds (FunDecl p n _ _ _ _)
	= case [p` \\ (FunDecl p` n` _ _ _ _) <- fds | n == n`] of
		[] = Left $ SanityError p "HUH THIS SHOULDN'T HAPPEN"
		[_] = pure ()
		ps = Left $ SanityError p (concat
			[n, " multiply defined at ", concat (intersperse ", " (map toString ps))])
90
// hasMain fds succeeds when some declaration in fds is named "main", and
// fails with a SanityError otherwise.
hasMain :: [FunDecl] -> Either SemError ()
hasMain fds
	| any isMain fds = pure ()
	| otherwise = Left $ SanityError zero "no main function defined"
where
	isMain (FunDecl _ name _ _ _ _) = name == "main"
95
// isNiceMain fd checks the shape of a declaration named "main": it must take
// no arguments and, when it has a type annotation, that annotation must be
// Void. Declarations with any other name always pass.
isNiceMain :: FunDecl -> Either SemError ()
isNiceMain (FunDecl p "main" [_:_] _ _ _) = Left $ SanityError p "main must have arity 0"
isNiceMain (FunDecl _ "main" [] Nothing _ _) = pure ()
isNiceMain (FunDecl _ "main" [] (Just VoidType) _ _) = pure ()
isNiceMain (FunDecl p "main" [] _ _ _) = Left $ SanityError p "main has to return Void"
isNiceMain _ = pure ()
104
// Values that may mention type variables (types, schemes, environments and
// lists of these have instances below).
105 class Typeable a where
106 ftv :: a -> [TVar] // the free type variables of the value
107 subst :: Substitution a -> a // replace free type variables by the types they are bound to
108
instance Typeable Scheme where
	// The quantified variables of a scheme are bound, so they are not free.
	ftv (Forall bound t) = difference (ftv t) bound
	// Quantified variables shadow the substitution: their bindings are
	// removed before substituting in the body.
	subst s (Forall bound t) = let unbound = 'Map'.filterWithKey (\v _ -> not (elem v bound)) s in Forall bound (subst unbound t)
113
instance Typeable [a] | Typeable a where
	// The free variables of a list are those of its elements, in list order.
	ftv xs = concatMap ftv xs
	// Substitute in every element.
	subst s xs = [subst s x \\ x <- xs]
117
instance Typeable Type where
	// Free type variables in order of first occurrence, without duplicates.
	// Removing repeats keeps list differences over variables (as used for
	// schemes and generalization) from depending on how often a variable
	// occurs in a type.
	ftv (TupleType (t1, t2)) = removeDup (ftv t1 ++ ftv t2)
	ftv (ListType t) = ftv t
	ftv (IdType tvar) = [tvar]
	ftv (t1 ->> t2) = removeDup (ftv t1 ++ ftv t2)
	ftv _ = []
	// Variables bound by s are replaced; compound types are rebuilt around
	// their substituted components; base types are returned unchanged.
	subst s (TupleType (t1, t2)) = TupleType (subst s t1, subst s t2)
	subst s (ListType t1) = ListType (subst s t1)
	subst s (t1 ->> t2) = (subst s t1) ->> (subst s t2)
	subst s t1=:(IdType tvar) = 'Map'.findWithDefault t1 tvar s
	subst s t = t
129
instance Typeable Gamma where
	// A variable is free in the environment when it is free in any of the
	// schemes the environment binds.
	ftv env = concatMap ftv ('Map'.elems env)
	// Substitute in every bound scheme; the keys are unchanged.
	subst s env = Mapmap (\scheme -> subst s scheme) env
133
// extend name scheme env binds name to scheme in env using 'Map'.put, so a
// previous binding of name is overwritten.
extend :: String Scheme Gamma -> Gamma
extend name scheme env = 'Map'.put name scheme env
136
137 //// ------------------------
138 //// algorithm U, Unification
139 //// ------------------------
// The empty (identity) substitution.
instance zero Substitution where zero = 'Map'.fromList []
141
// compose s1 s2 behaves like applying s2 first and then s1, in the same way
// as function composition (s1 o s2). s1 is applied to every type bound by s2,
// and that result is unioned with s1.
compose :: Substitution Substitution -> Substitution
compose s1 s2 = 'Map'.union applied s1
where
	applied = Mapmap (subst s1) s2
145
// occurs tvar a holds when tvar is a free type variable of a (the occurs
// check that prevents infinite types during unification).
occurs :: TVar a -> Bool | Typeable a
occurs tvar a = any ((==) tvar) (ftv a)
148
// unify t1 t2 computes a substitution s such that subst s t1 equals
// subst s t2 (algorithm U). It fails with
//   InfiniteTypeError when a variable would have to contain itself, and
//   UnifyError when the two types have incompatible shapes.
// Source positions are not tracked here; errors carry the zero position.
unify :: Type Type -> Either SemError Substitution
unify t1 t2=:(IdType tv)
	| t1 == (IdType tv) = Right zero
	| occurs tv t1 = Left $ InfiniteTypeError zero t1
	| otherwise = Right $ 'Map'.singleton tv t1
unify t1=:(IdType tv) t2 = unify t2 t1
unify (ta1 ->> ta2) (tb1 ->> tb2) = unifyBoth (ta1, tb1) (ta2, tb2)
unify (TupleType (ta1, ta2)) (TupleType (tb1, tb2)) = unifyBoth (ta1, tb1) (ta2, tb2)
unify (ListType t1) (ListType t2) = unify t1 t2
unify t1 t2
	| t1 == t2 = Right zero
	| otherwise = Left $ UnifyError zero t1 t2

// unifyBoth (a1, b1) (a2, b2) unifies a1 with b1 and then a2 with b2.
// The substitution found for the first pair is applied to the second pair
// before it is unified. Without that step, a variable could be bound to two
// conflicting types: a ->> a against Int ->> Bool must fail.
// The result applies the first substitution and then the second.
unifyBoth :: (Type, Type) (Type, Type) -> Either SemError Substitution
unifyBoth (a1, b1) (a2, b2) =
	unify a1 b1 >>= \s1 ->
	unify (subst s1 a2) (subst s1 b2) >>= \s2 ->
	Right (compose s2 s1)
163
164 //// ------------------------
165 //// Algorithm M, Inference and Solving
166 //// ------------------------
// Accessors for the inference state: (environment, unused type variables).

// The current typing environment.
gamma :: Typing Gamma
gamma = gets (\(env, _) -> env)
// Replace the environment, keeping the variable supply.
putGamma :: Gamma -> Typing ()
putGamma g = modify (\(_, vars) -> (g, vars)) >>| pure ()
// Update the environment with f and return the updated environment.
changeGamma :: (Gamma -> Gamma) -> Typing Gamma
changeGamma f = modify (\(env, vars) -> (f env, vars)) >>| gamma
// Read the environment through f.
withGamma :: (Gamma -> a) -> Typing a
withGamma f = gamma >>= \env -> pure (f env)
// Take the next unused type variable from the supply (assumed non-empty;
// the supply started from variableStream is infinite).
fresh :: Typing Type
fresh = gets snd >>= \supply ->
	modify (\(env, _) -> (env, tail supply)) >>|
	pure (IdType (head supply))
179
// lift turns the result of a pure check into an inference step: Left stops
// inference with that error, Right continues with the value.
lift :: (Either SemError a) -> Typing a
lift m = case m of
	Left err = liftT (Left err)
	Right val = pure val
183
// instantiate replaces the quantified variables of a scheme by fresh type
// variables and drops the quantifier, e.g. forall a,b. a -> [b] becomes
// c -> [d] for fresh c and d. lookup calls this at every use of an
// identifier, which is what makes let-bound identifiers polymorphic.
instantiate :: Scheme -> Typing Type
instantiate (Forall bound t) =
	mapM (\_ -> fresh) bound >>= \freshVars ->
	pure (subst ('Map'.fromList (zip (bound, freshVars))) t)
191
// generalize t quantifies the type variables that are free in t but not free
// in the current environment. Environment variables are filtered out
// individually, rather than via a list difference. A variable that occurs in
// t more often than in the environment therefore can never end up
// quantified while it is still free in the environment.
generalize :: Type -> Typing Scheme
generalize t = gamma >>= \g ->
	let envVars = ftv g in
	pure (Forall (filter (\v -> not (elem v envVars)) (ftv t)) t)
196
// lookup k returns a fresh instance of the scheme bound to identifier k in
// the current environment, or fails with UndeclaredVariableError when k is
// not bound.
lookup :: String -> Typing Type
lookup k = gamma >>= \g ->
	if ('Map'.member k g)
		(instantiate ('Map'.find k g))
		(liftT (Left $ UndeclaredVariableError zero k))
201
202 //The inference class
203 //When tying it all together we will treat the program as a big
204 //let x=e1 in let y=e2 in ....
// infer a yields the substitution found while inferring a, together with the
// inferred type of a.
205 class infer a :: a -> Typing (Substitution, Type)
206
207 ////---- Inference for Expressions ----
208
// Infers the type of an expression. The result pairs the substitution found
// on the way with the type. The substitution of each sub-expression is
// applied to the environment before the next sub-expression is inferred.
// Types inferred earlier are updated with later substitutions. Together
// these steps keep one type variable from silently being bound to two
// different types.
instance infer Expr where
	infer e = case e of
		// lookup instantiates the scheme: this is what gives let polymorphism
		VarExpr _ (VarDef k fs) = (\t -> (zero, t)) <$> lookup k
		//TODO: field selectors

		Op2Expr _ e1 op e2 =
			infer e1 >>= \(s1, t1) ->
			applySubst s1 >>|
			infer e2 >>= \(s2, t2) ->
			applySubst s2 >>|
			fresh >>= \tv ->
			// t1 was inferred before s2 existed, so bring it up to date
			let given = subst s2 t1 ->> t2 ->> tv in
			op2Type op >>= \expected ->
			lift (unify expected given) >>= \s3 ->
			pure (compose s3 (compose s2 s1), subst s3 tv)

		Op1Expr _ op e1 =
			infer e1 >>= \(s1, t1) ->
			fresh >>= \tv ->
			let given = t1 ->> tv in
			op1Type op >>= \expected ->
			lift (unify expected given) >>= \s2 ->
			pure (compose s2 s1, subst s2 tv)

		// [] is a list whose element type is not known yet
		EmptyListExpr _ = (\tv -> (zero, ListType tv)) <$> fresh

		TupleExpr _ (e1, e2) =
			infer e1 >>= \(s1, t1) ->
			applySubst s1 >>|
			infer e2 >>= \(s2, t2) ->
			pure (compose s2 s1, TupleType (subst s2 t1, t2))

		FunExpr _ f args fs = //TODO: field selectors
			lookup f >>= \expected ->
			// infer the arguments left to right, updating earlier argument
			// types with each new substitution
			let accST = (\(s, ts) arg -> infer arg >>= \(s_, et) -> applySubst s_ >>| pure (compose s_ s, map (subst s_) ts ++ [et])) in
			foldM accST (zero, []) args >>= \(s1, argTs) ->
			fresh >>= \tv ->
			let given = foldr (->>) tv argTs in
			// the callee's type may mention variables refined by the arguments
			lift (unify (subst s1 expected) given) >>= \s2 ->
			pure (compose s2 s1, subst s2 tv)

		IntExpr _ _ = pure (zero, IntType)
		BoolExpr _ _ = pure (zero, BoolType)
		CharExpr _ _ = pure (zero, CharType)
254
255
// op2Type op gives the type of binary operator op as a curried function type.
// Equality and inequality are polymorphic, but both operands must have the
// same type, so a single fresh variable is shared between them. Cons takes
// an element and a list of that element type.
op2Type :: Op2 -> Typing Type
op2Type op
	| elem op [BiPlus, BiMinus, BiTimes, BiDivide, BiMod]
		= pure (IntType ->> IntType ->> IntType)
	| elem op [BiEquals, BiUnEqual]
		= fresh >>= \t -> pure (t ->> t ->> BoolType)
	| elem op [BiLesser, BiGreater, BiLesserEq, BiGreaterEq]
		= pure (IntType ->> IntType ->> BoolType)
	| elem op [BiAnd, BiOr]
		= pure (BoolType ->> BoolType ->> BoolType)
	| op == BiCons
		= fresh >>= \t -> pure (t ->> ListType t ->> ListType t)
	// reached only if Op2 gains a constructor that is not listed above
	| otherwise
		= liftT (Left $ Error (concat ["no type known for operator ", toString op]))
268
// op1Type op gives the type of unary operator op as a function type:
// negation works on booleans and unary minus on integers.
op1Type :: Op1 -> Typing Type
op1Type op = case op of
	UnNegation = pure (BoolType ->> BoolType)
	UnMinus = pure (IntType ->> IntType)
272
273 ////----- Inference for Statements -----
// applySubst s applies s to every scheme in the environment and returns the
// updated environment.
applySubst :: Substitution -> Typing Gamma
applySubst s = changeGamma (\env -> subst s env)
276
// Infers a statement. The returned type is the type of the value this
// statement returns from the enclosing function, or VoidType when it returns
// no value.
instance infer Stmt where
	infer s = case s of
		IfStmt e th el =
			infer e >>= \(s1, et) ->
			lift (unify et BoolType) >>= \s2 ->
			applySubst (compose s2 s1) >>|
			infer th >>= \(s3, tht) ->
			applySubst s3 >>|
			infer el >>= \(s4, elt) ->
			applySubst s4 >>|
			let sIf = compose s4 (compose s3 (compose s2 s1)) in
			// the then-type was inferred before s4 existed
			let tht_ = subst s4 tht in
			// a branch that does not return places no constraint on the other
			case (tht_, elt) of
				(VoidType, _) = pure (sIf, elt)
				(_, VoidType) = pure (sIf, tht_)
				_ = lift (unify tht_ elt) >>= \s5 ->
					applySubst s5 >>|
					pure (compose s5 sIf, subst s5 tht_)

		WhileStmt e wh =
			infer e >>= \(s1, et) ->
			lift (unify et BoolType) >>= \s2 ->
			applySubst (compose s2 s1) >>|
			infer wh >>= \(s3, wht) ->
			pure (compose s3 (compose s2 s1), subst s3 wht)

		// Assignment does not change a variable's type: the assigned value
		// must have the type the variable already has.
		AssStmt (VarDef k fs) e = //todo: fieldselectors
			lookup k >>= \vt ->
			infer e >>= \(s1, et) ->
			applySubst s1 >>|
			lift (unify (subst s1 vt) et) >>= \s2 ->
			applySubst s2 >>|
			pure (compose s2 s1, VoidType)

		// A function call used as a statement: check it like a call
		// expression and discard its result, since it does not return from
		// the enclosing function.
		FunStmt f es =
			infer (FunExpr zero f es []) >>= \(s1, _) ->
			pure (s1, VoidType)

		ReturnStmt Nothing = pure (zero, VoidType)
		ReturnStmt (Just e) = infer e
307
// The type of a list of statements is the type of the values it returns:
// VoidType when no statement returns a value, otherwise the unified type of
// the statements that do. A return type found early is updated with the
// substitutions found later in the list before the types are unified.
instance infer [a] | infer a where
	infer [] = pure (zero, VoidType)
	infer [stmt:ss] =
		infer stmt >>= \(s1, t1) ->
		applySubst s1 >>|
		infer ss >>= \(s2, t2) ->
		applySubst s2 >>|
		let t1_ = subst s2 t1 in
		case t1_ of
			VoidType = pure (compose s2 s1, t2)
			_ = case t2 of
				VoidType = pure (compose s2 s1, t1_)
				_ = lift (unify t1_ t2) >>= \s3 ->
					applySubst s3 >>|
					pure (compose s3 (compose s2 s1), subst s3 t1_)
323
324 //The type class infers the type of an AST element (VarDecl or FunDecl),
325 //records it in the environment and returns the element annotated with it.
326 class type a :: a -> Typing a
327
// Infers a variable declaration. The initializer is inferred and, when the
// declaration has a type annotation, unified with it. The substitution from
// that unification is applied to the environment and to the inferred type.
// The resulting type is generalized, bound to the variable's name, and
// stored in the returned declaration.
instance type VarDecl where
	type (VarDecl p expected k e) =
		infer e >>= \(s1, given) ->
		applySubst s1 >>|
		(case expected of
			Nothing = pure zero
			Just expected_ = lift (unify expected_ given)) >>= \s2 ->
		applySubst s2 >>|
		let vtype = subst s2 given in
		generalize vtype >>= \t ->
		changeGamma (extend k t) >>|
		pure (VarDecl p (Just vtype) k e)
339
// Infers a function declaration.
// The function name (for recursive calls) and its arguments are bound to
// fresh type variables. Then the local variable declarations and the body
// are inferred. The resulting function type is unified with the type implied
// by recursive uses, and with the annotation if there is one. Afterwards the
// environment from before the declaration is restored, so arguments and
// locals do not leak into later declarations and do not block
// generalization. Finally the function is added with its generalized type.
// The returned declaration carries the inferred, ungeneralized type.
instance type FunDecl where
	type (FunDecl p f args expected vds stmts) =
		gamma >>= \outer ->
		introduce f >>= \self ->
		mapM introduce args >>= \argTs ->
		type vds >>= \tVds ->
		infer stmts >>= \(s1, result) ->
		let given = subst s1 (foldr (->>) result argTs) in
		// recursive calls in the body constrained the variable bound to f
		lift (unify (subst s1 self) given) >>= \s2 ->
		let inferred = subst s2 given in
		(case expected of
			Nothing = pure zero
			Just expected_ = lift (unify expected_ inferred)) >>= \s3 ->
		let ftype = subst s3 inferred in
		putGamma (subst (compose s3 (compose s2 s1)) outer) >>|
		generalize ftype >>= \t ->
		changeGamma (extend f t) >>|
		pure (FunDecl p f args (Just ftype) tVds stmts)
356
// Renders Nothing as "Nothing" and Just x as "Just " followed by x.
instance toString (Maybe a) | toString a where
	toString m = case m of
		Nothing = "Nothing"
		Just x = "Just " +++ toString x
360
// Types the elements left to right; because the state is threaded, each
// element sees the environment extended by the previous ones.
instance type [a] | type a where
	type [] = pure []
	type [d:ds] = type d >>= \d_ -> type ds >>= \ds_ -> pure [d_:ds_]
363
// introduce k binds k to a fresh, unquantified type variable in the
// environment and returns that variable.
introduce :: String -> Typing Type
introduce k = fresh >>= \var -> changeGamma (extend k (Forall [] var)) >>| pure var
369
// Renders a scheme as "Forall a,b. <type>".
instance toString Scheme where
	toString (Forall vars t) = concat ["Forall " : intersperse "," vars ++ [". ", toString t]]
373
// Renders the environment one "identifier: scheme" binding per line.
instance toString Gamma where
	toString mp = concat (map (\(k, v) -> concat [k, ": ", toString v, "\n"]) ('Map'.toList mp))
377
// Renders the substitution one "variable: type" binding per line.
instance toString Substitution where
	toString subs = concat (map (\(k, t) -> concat [k, ": ", toString t, "\n"]) ('Map'.toList subs))
381
// Renders a semantic error for the user, prefixed with its source position
// when the error carries one.
instance toString SemError where
	toString (SanityError p e) = concat [toString p,
		"SemError: SanityError: ", e]
	toString (ParseError p s) = concat [toString p,
		"ParseError: ", s]
	toString (UnifyError p t1 t2) = concat [toString p,
		"Can not unify types, expected|given:\n", toString t1,
		"\n", toString t2]
	toString (InfiniteTypeError p t) = concat [toString p,
		"Infinite type: ", toString t]
	toString (FieldSelectorError p t fs) = concat [toString p,
		"Can not run fieldselector '", toString fs, "' on type: ",
		toString t]
	// the pieces are concatenated, so the separating space must be explicit
	toString (OperatorError p op t) = concat [toString p,
		"Operator error, operator '", toString op, "' can not be ",
		"used on type: ", toString t]
	toString (UndeclaredVariableError p k) = concat [toString p,
		"Undeclared identifier: ", k]
	toString (ArgumentMisMatchError p str) = concat [toString p,
		"Argument mismatch: ", str]
	toString (Error e) = concat ["Unknown error during semantical ",
		"analysis: ", e]
404
// Lifts a computation of the underlying monad into the inference monad,
// leaving the state untouched.
instance MonadTrans (StateT (Gamma, [TVar])) where
	liftT inner = StateT (\st -> inner >>= \x -> return (x, st))
407
// Mapmap f m applies f to every value of m. Keys, sizes and tree shape are
// kept, because it is written directly against Data.Map's Tip/Bin
// constructors.
Mapmap :: (a->b) ('Map'.Map k a) -> ('Map'.Map k b)
Mapmap f m = case m of
	'Map'.Tip = 'Map'.Tip
	'Map'.Bin size key val l r = 'Map'.Bin size key (f val) (Mapmap f l) (Mapmap f r)
413
414 //// ------------------------
415 //// First step: Inference
416 //// ------------------------//
417
418 //unify :: Type Type -> Infer ()
419 //unify t1 t2 = tell [(t1, t2)]//
420
421 //fresh :: Infer Type
422 //fresh = (gets id) >>= \vars-> (put $ tail vars) >>| (pure $ IdType $ head vars)//
423
424 //op2Type :: Op2 -> Infer Type
425 //op2Type op
426 //| elem op [BiPlus, BiMinus, BiTimes, BiDivide, BiMod]
427 // = pure (IntType ->> IntType ->> IntType)
428 //| elem op [BiEquals, BiUnEqual]
429 // = fresh >>= \t1-> fresh >>= \t2-> pure (t1 ->> t2 ->> BoolType)
430 //| elem op [BiLesser, BiGreater, BiLesserEq, BiGreaterEq]
431 // = pure (IntType ->> IntType ->> BoolType)
432 //| elem op [BiAnd, BiOr]
433 // = pure (BoolType ->> BoolType ->> BoolType)
434 //| op == BiCons
435 // = fresh >>= \t1-> pure (t1 ->> ListType t1 ->> ListType t1)//
436
437 //op1Type :: Op1 -> Infer Type
438 //op1Type UnNegation = pure $ (BoolType ->> BoolType)
439 //op1Type UnMinus = pure $ (IntType ->> IntType)//
440
441 ////instantiate :: Scheme -> Infer Type
442 ////instantiate (Forall as t) = mapM (const fresh) as//
443
444 //lookupEnv :: String -> Infer Type
445 //lookupEnv ident = asks ('Map'.get ident)
446 // >>= \m->case m of
447 // Nothing = liftT $ Left $ UndeclaredVariableError zero ident
448 // Just (Forall as t) = pure t //instantiate ???//
449
450 //class infer a :: a -> Infer Type
451 //instance infer Expr where
452 // infer (VarExpr _ (VarDef ident fs)) = lookupEnv ident
453 // infer (Op2Expr _ e1 op e2) =
454 // infer e1 >>= \t1 ->
455 // infer e2 >>= \t2 ->
456 // fresh >>= \frsh ->
457 // let given = t1 ->> (t2 ->> frsh) in
458 // op2Type op >>= \expected ->
459 // unify expected given >>|
460 // return frsh
461 // infer (Op1Expr _ op e) =
462 // infer e >>= \t1 ->
463 // fresh >>= \frsh ->
464 // let given = t1 ->> frsh in
465 // op1Type op >>= \expected ->
466 // unify expected given >>|
467 // pure frsh
468 // infer (IntExpr _ _) = pure IntType
469 // infer (CharExpr _ _) = pure CharType
470 // infer (BoolExpr _ _) = pure BoolType
471 // infer (FunExpr _ f args sels) = //todo, iets met field selectors
472 // lookupEnv f >>= \expected ->
473 // fresh >>= \frsh ->
474 // mapM infer args >>= \argTypes ->
475 // let given = foldr (->>) frsh argTypes in
476 // unify expected given >>|
477 // pure frsh
478 // infer (EmptyListExpr _) = ListType <$> fresh
479 // infer (TupleExpr _ (e1, e2)) =
480 // infer e1 >>= \et1->infer e2 >>= \et2->pure $ TupleType (et1, et2)//
481
482 ////:: VarDef = VarDef String [FieldSelector]
483 ////:: FieldSelector = FieldHd | FieldTl | FieldFst | FieldSnd
484 ////:: Op1 = UnNegation | UnMinus
485 ////:: Op2 = BiPlus | BiMinus | BiTimes | BiDivide | BiMod | BiEquals | BiLesser |
486 //// BiGreater | BiLesserEq | BiGreaterEq | BiUnEqual | BiAnd | BiOr | BiCons
487 ////:: FunDecl = FunDecl Pos String [String] (Maybe Type) [VarDecl] [Stmt]
488 ////:: FunCall = FunCall String [Expr]
489 ////:: Stmt
490 //// = IfStmt Expr [Stmt] [Stmt]
491 //// | WhileStmt Expr [Stmt]
492 //// | AssStmt VarDef Expr
493 //// | FunStmt FunCall
494 //// | ReturnStmt (Maybe Expr)
495 ////:: Pos = {line :: Int, col :: Int}
496 ////:: AST = AST [VarDecl] [FunDecl]
497 ////:: VarDecl = VarDecl Pos Type String Expr
498 ////:: Type
499 //// = TupleType (Type, Type)
500 //// | ListType Type
501 //// | IdType String
502 //// | IntType
503 //// | BoolType
504 //// | CharType
505 //// | VarType
506 //// | VoidType
507 //// | (->>) infixl 7 Type Type