support all other patterns and nested patterns
[clean-tests.git] / datatype / Compiler.hs
1 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE LambdaCase #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 module Compiler where
7
8 import Language
9 import Serialise
10
11 import qualified Data.Map as DM
12 import Control.Monad.Writer
13 import Control.Monad.State
14 import Control.Monad.ST
15 import Debug.Trace
16 import Data.Array
17 import Data.Array.ST
18 import Data.Function
19
-- | The code-generation monad.
--
-- * 'StateT' 'CS' threads the label supply and the table of lifted functions.
-- * 'WriterT' @[Instr]@ collects the emitted instructions.
-- * 'Either' 'String' carries compile errors raised through 'fail'.
newtype Compiler a = Compiler
  { unCompiler :: StateT CS (WriterT [Instr] (Either String)) a }
  deriving (Functor, Applicative, Monad, MonadWriter [Instr], MonadState CS)

-- | A failing compilation aborts the whole run with 'Left' the message.
instance MonadFail Compiler where
  fail = Compiler . lift . lift . Left
-- | Mutable state of the compiler.
data CS = CS
  { fresh :: [Int]
    -- ^ supply of unused label numbers; 'freshLabel' takes from the front
  , functions :: DM.Map Int [Instr]
    -- ^ code of every lifted function, keyed by its entry label
  }
33
-- | Run a compilation and return the complete program.
--
-- The resulting instruction list is laid out as follows:
--
-- * the code emitted by @c@;
-- * a 'Halt';
-- * the bodies of all lifted functions, in ascending order of their entry label.
--
-- A 'fail' in the 'Compiler' monad yields 'Left'.
runCompiler :: Compiler a -> Either String [Instr]
runCompiler c = execWriterT (evalStateT (unCompiler whole) initial)
  where
    whole :: Compiler ()
    whole = c >> instr [Halt] >> writeFunctions

    initial :: CS
    initial = CS { fresh = [0 ..], functions = DM.empty }

    -- append the code of all lifted functions after the main program
    writeFunctions :: Compiler ()
    writeFunctions = gets functions >>= tell . concat . DM.elems
41
-- | Emit the given instructions.
--
-- The result is 'undefined'; the value exists only to make the type fit
-- any expected result type. Callers must never inspect it.
instr :: [Instr] -> Compiler a
instr is = undefined <$ tell is
44
-- | Draw a label number that has not been handed out before.
--
-- The supply in 'fresh' is infinite when seeded by 'runCompiler'. An
-- exhausted supply is still reported through 'fail' (so 'runCompiler'
-- returns 'Left') instead of crashing with a lambda pattern-match failure.
freshLabel :: Compiler Int
freshLabel = gets fresh >>= \case
  (label : rest) -> label <$ modify (\s -> s { fresh = rest })
  [] -> fail "freshLabel: label supply exhausted"
47
-- | Compile a binary operator.
--
-- The code emits the left operand, then the right operand, then the
-- instruction that combines them.
binop :: Instr -> Compiler a1 -> Compiler a2 -> Compiler b
binop op lhs rhs = lhs *> rhs *> instr [op]
50
-- | Compile a unary operator: emit the operand, then the instruction.
unop :: Instr -> Compiler a -> Compiler b
unop op operand = operand *> instr [op]
53
-- | Expressions compile to stack code.
--
-- Every operand leaves its value on the stack, and the operator instruction
-- emitted afterwards consumes the operands.
instance Expression Compiler where
  lit v = instr [Push (serialise v)]

  -- arithmetic
  (+.) = binop Add
  (-.) = binop Sub
  (/.) = binop Div
  (*.) = binop Mul
  neg = unop Neg

  -- logic
  (&.) = binop And
  (|.) = binop Or
  not = unop Not

  -- comparison
  (==.) = binop Eq
  (/=.) = binop Neq
  (<.) = binop Le
  (>.) = binop Ge
  (<=.) = binop Leq
  (>=.) = binop Geq

  -- Conditional: branch to the else-label when the predicate is false.
  -- The then-branch jumps over the else-branch to the end label.
  if' p t e = do
    elseLabel <- freshLabel
    endLabel <- freshLabel
    _ <- p
    _ <- instr [Brf elseLabel]
    _ <- t
    _ <- instr [Bra endLabel, Lbl elseLabel]
    _ <- e
    instr [Lbl endLabel]

  bottom msg = instr [Error msg]
75
-- | Nullary functions.
--
-- @def@ receives a stub that emits a call ('Jsr'). It returns the function
-- body paired (via ':-') with the rest of the program. The body is lifted
-- out under a fresh entry label, and the rest is compiled in place.
instance Function () Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\() -> instr [Jsr entry])
    liftFunction entry 0 (body ())
    unmain rest
81
-- | Unary functions.
--
-- The caller pushes the argument and then emits 'Jsr'. Inside the body the
-- argument is read back with @Arg 0@.
instance Function (Compiler a) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\arg -> arg *> instr [Jsr entry])
    liftFunction entry 1 (body (instr [Arg 0]))
    unmain rest
87
-- | Binary functions.
--
-- Arguments are pushed left to right, so the last argument is @Arg 0@
-- and the first argument is @Arg 1@.
instance Function (Compiler a, Compiler b) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\(x, y) -> x *> y *> instr [Jsr entry])
    liftFunction entry 2 (body (instr [Arg 1], instr [Arg 0]))
    unmain rest
93
-- | Ternary functions.
--
-- Arguments are pushed left to right, so argument @i@ of @n@ is read back
-- as @Arg (n - 1 - i)@.
instance Function (Compiler a, Compiler b, Compiler c) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\(x, y, z) -> x *> y *> z *> instr [Jsr entry])
    liftFunction entry 3 (body (instr [Arg 2], instr [Arg 1], instr [Arg 0]))
    unmain rest
99
-- | Quaternary functions.
--
-- Arguments are pushed left to right, so argument @i@ of @n@ is read back
-- as @Arg (n - 1 - i)@.
instance Function (Compiler a, Compiler b, Compiler c, Compiler d) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest =
          def (\(w, x, y, z) -> w *> x *> y *> z *> instr [Jsr entry])
    liftFunction entry 4
      (body (instr [Arg 3], instr [Arg 2], instr [Arg 1], instr [Arg 0]))
    unmain rest
105
-- | Compile a function body out of line.
--
-- The body's instructions are captured instead of being emitted at the
-- current position. They are wrapped in @Lbl lbl ... Ret nargs@ and stored
-- in 'functions' under @lbl@, and 'runCompiler' appends them after 'Halt'.
liftFunction :: Int -> Int -> Compiler a -> Compiler ()
liftFunction lbl nargs body = do
  emitted <- censor (const []) (snd <$> listen body)
  let code = Lbl lbl : emitted ++ [Ret nargs]
  modify $ \cs -> cs { functions = DM.insert lbl code (functions cs) }
111
-- | Instructions of the stack machine executed by 'interpret'.
--
-- Truth values are integers: 0 is false and anything else is true.
-- Boolean results are pushed as 0 or 1.
data Instr
  = Push Int      -- ^ push a literal
  | Pop Int       -- ^ discard that many stack cells
  | Add           -- ^ binary arithmetic on the two topmost cells
  | Sub
  | Mul
  | Div           -- ^ integer division ('div')
  | Neg           -- ^ negate the top cell
  | And           -- ^ logical connectives on truth values
  | Or
  | Not
  | Eq            -- ^ comparisons; 'Le' and 'Ge' are the strict @<@ and @>@
  | Neq
  | Le
  | Ge
  | Leq
  | Geq
  | Lbl Int       -- ^ jump target; a no-op when executed
  | Bra Int       -- ^ unconditional jump to the matching 'Lbl'
  | Brf Int       -- ^ pop; jump to the matching 'Lbl' if the value is 0
  | Str Int       -- ^ pop into the numbered register
  | Ldr Int       -- ^ push the numbered register (0 when never stored)
  | Sth Int       -- ^ move that many top cells to the heap, push a pointer
  | Ldh Int       -- ^ pop a heap pointer @p@, push the heap cell at @p - n - 1@
  | Jsr Int       -- ^ call: push return address and old frame pointer, jump
  | Ret Int       -- ^ return from a function taking that many arguments
  | Arg Int       -- ^ push an argument, 0 being the one pushed last
  | Halt          -- ^ stop execution
  | Error String  -- ^ abort with the given message
  deriving Show
123
-- | Register file of the stack machine.
data Registers = Registers
  { pc :: Int
    -- ^ index of the next instruction to fetch
  , hp :: Int
    -- ^ next free heap cell; the heap grows upwards from 0
  , sp :: Int
    -- ^ next free stack cell; the stack grows downwards from the top of memory
  , mp :: Int
    -- ^ frame pointer set by 'Jsr'; 'Arg' reads relative to it and 'Ret'
    --   restores the caller's value
  , gp :: DM.Map Int Int
    -- ^ general-purpose registers accessed by 'Str' and 'Ldr'
  }
  deriving Show
132
-- | Run a program on a machine with @memsize@ memory cells.
--
-- The result is the last memory cell, which is the first stack slot and
-- therefore holds the value left by the main program.
interpret :: Int -> [Instr] -> Int
interpret memsize prog = runSTArray resultStack ! (memsize - 1)
  where
    resultStack :: ST s (STArray s Int Int)
    resultStack = do
      -- NOTE: the upper bound is one past the last instruction. 'branch'
      -- wraps around to (upper bound - 1), so this bound must stay as is.
      program <- newListArray (0, length prog) prog
      memory <- newArray (0, memsize - 1) 0
      int program memory initialRegisters

    initialRegisters =
      Registers { pc = 0, sp = memsize - 1, mp = 0, hp = 0, gp = DM.empty }
141
-- | Write a value to the next free heap cell and bump the heap pointer.
pushh :: STArray s Int Int -> Int -> Registers -> ST s Registers
pushh memory value reg =
  reg { hp = hp reg + 1 } <$ writeArray memory (hp reg) value
146
-- | Push the contents of memory cell @hptr@ onto the stack.
loadh :: STArray s Int Int -> Int -> Registers -> ST s Registers
loadh memory hptr registers = do
  cell <- readArray memory hptr
  push memory cell registers
149
-- | Write a value to the free stack slot at @sp@.
--
-- The stack grows downwards, so @sp@ is then decremented.
push :: STArray s Int Int -> Int -> Registers -> ST s Registers
push memory value reg =
  reg { sp = sp reg - 1 } <$ writeArray memory (sp reg) value
154
-- | Remove the topmost stack value.
--
-- The value is read from @sp + 1@, and the updated registers are returned
-- together with it.
pop :: STArray s Int Int -> Registers -> ST s (Registers, Int)
pop memory reg = do
  let popped = reg { sp = sp reg + 1 }
  top <- readArray memory (sp popped)
  pure (popped, top)
159
-- | Pop @n@ values.
--
-- The result list starts with the former top of the stack. @n@ is expected
-- to be non-negative: only 0 stops the recursion.
popn :: STArray s Int Int -> Int -> Registers -> ST s (Registers, [Int])
popn memory n reg
  | n == 0 = pure (reg, [])
  | otherwise = do
      (afterTop, top) <- pop memory reg
      (final, rest) <- popn memory (n - 1) afterTop
      pure (final, top : rest)
166
-- | Apply a binary operator to the two topmost stack values.
--
-- The top of the stack is the right operand. The result replaces both
-- operands.
bop :: (Int -> Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
bop op memory reg = do
  (rest, rhs) <- pop memory reg
  uop (`op` rhs) memory rest
171
-- | Replace the topmost stack value by the result of applying @op@ to it.
uop :: (Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
uop op memory reg =
  pop memory reg >>= \(rest, x) -> push memory (op x) rest
176
-- | Fetch-decode-execute loop of the stack machine.
--
-- Each step reads the instruction at @pc@, advances @pc@ by one and
-- executes the instruction. The loop recurses until 'Halt', which returns
-- the final memory, or 'Error', which calls 'fail' in 'ST'.
int :: STArray s Int Instr -> STArray s Int Int -> Registers -> ST s (STArray s Int Int)
int program memory registers = do
  current <- readArray program (pc registers)
  execute current (registers { pc = pc registers + 1 })
  where
    -- continue the loop with the given registers
    next = int program memory

    -- integer operator on the two topmost cells
    arith f reg = bop f memory reg >>= next
    -- comparison; the Bool result is pushed as 0/1
    relational f = arith (\l r -> b2i (f l r))
    -- connective on truth values (non-zero is true); result pushed as 0/1
    logical f = arith (\l r -> b2i ((f `on` i2b) l r))
    -- operator on the topmost cell
    unary f reg = uop f memory reg >>= next

    execute ins reg = case ins of
      Str r -> do
        (reg', v) <- pop memory reg
        next (reg' { gp = DM.insert r v (gp reg') })
      Ldr r -> push memory (DM.findWithDefault 0 r (gp reg)) reg >>= next
      Pop n -> popn memory n reg >>= next . fst
      Push v -> push memory v reg >>= next
      Add -> arith (+) reg
      Sub -> arith (-) reg
      Mul -> arith (*) reg
      Div -> arith div reg
      Neg -> unary negate reg
      And -> logical (&&) reg
      Or -> logical (||) reg
      Not -> unary (b2i . Prelude.not . i2b) reg
      Eq -> relational (==) reg
      Neq -> relational (/=) reg
      Le -> relational (<) reg
      Ge -> relational (>) reg
      Leq -> relational (<=) reg
      Geq -> relational (>=) reg
      Lbl _ -> next reg
      Bra l -> branch l program reg >>= next
      Brf l -> do
        (reg', v) <- pop memory reg
        target <- if i2b v then pure reg' else branch l program reg'
        next target
      -- Move the top n cells to the heap; the top of the stack lands at the
      -- lowest heap address. Then push a pointer to the last cell written.
      Sth n ->
        popn memory n reg
          >>= uncurry (foldM (flip (pushh memory)))
          >>= push memory (hp reg + n - 1)
          >>= next
      -- Field n lives at pointer - n - 1. The cell at the pointer itself
      -- (the first value pushed before Sth) is skipped by this offset.
      Ldh n -> do
        (reg', hptr) <- pop memory reg
        loadh memory (hptr - n - 1) reg' >>= next
      -- Frame layout after Jsr: mp+1 = caller's mp, mp+2 = return address,
      -- mp+3.. = arguments (last pushed first).
      Jsr i ->
        push memory (pc reg) reg
          >>= push memory (mp reg)
          >>= branch i program
          >>= \r -> next (r { mp = sp r })
      -- Pop the result, the saved frame and the n arguments, then push the
      -- result back for the caller.
      Ret n -> do
        (reg1, rval : omp : ra : _) <- popn memory (3 + n) reg
        reg2 <- push memory rval reg1
        next (reg2 { pc = ra, mp = omp })
      Arg n -> do
        v <- readArray memory (mp reg + 3 + n)
        push memory v reg >>= next
      Halt -> pure memory
      Error msg -> fail msg
228
-- | Set @pc@ to the position of @Lbl label@.
--
-- The search walks backwards from the current @pc@. At index -1 it wraps
-- around to (array upper bound - 1), the last instruction. Execution then
-- resumes at the label itself, which is a no-op.
-- Loops forever if the label does not occur in the program.
branch :: Int -> STArray s Int Instr -> Registers -> ST s Registers
branch label program reg
  | pc reg == (-1) = do
      (_, upper) <- getBounds program
      branch label program (reg { pc = upper - 1 })
  | otherwise = do
      ins <- readArray program (pc reg)
      case ins of
        Lbl l | l == label -> pure reg
        _ -> branch label program (reg { pc = pc reg - 1 })
235
-- | Encode a truth value as a machine integer: 'True' is 1, 'False' is 0.
b2i :: Bool -> Int
b2i = fromEnum
239
-- | Decode a machine integer as a truth value: only 0 is false.
i2b :: Int -> Bool
i2b = (/= 0)