566972588c62c3da959872835bc70cf52f0ef325
[clean-tests.git] / datatype / Compiler.hs
1 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE LambdaCase #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 module Compiler where
7
8 import Language
9 import Serialise
10
11 import qualified Data.Map as DM
12 import Control.Monad.Writer
13 import Control.Monad.State
14 import Control.Monad.ST
15 import Debug.Trace
16 import Data.Array
17 import Data.Array.ST
18 import Data.Function
19
-- | The code-generation monad.
--
-- * state ('CS'): the supply of fresh label numbers and the table of
--   out-of-line function bodies;
-- * writer: the main instruction stream;
-- * 'Either' 'String': compilation failure (see the 'MonadFail' instance).
newtype Compiler a = Compiler
  { unCompiler :: StateT CS (WriterT [Instr] (Either String)) a
  }
  deriving
    ( Monad
    , Applicative
    , Functor
    , MonadState CS
    , MonadWriter [Instr]
    )
-- | 'fail' (including failed pattern matches in do-blocks) aborts the whole
-- compilation with the message as a 'Left'.
instance MonadFail Compiler where
  fail = Compiler . lift . lift . Left
-- | State threaded through 'Compiler'.
data CS = CS
  { fresh :: [Int]
    -- ^ Label numbers not yet handed out ('freshLabel' takes the head).
  , functions :: DM.Map Int [Instr]
    -- ^ Out-of-line function bodies keyed by their entry label.
  }
33
-- | Compile a program to a flat instruction list.
--
-- The layout is: the main program, then 'Halt', then the bodies of all
-- lifted functions (in ascending label order, as stored in 'functions').
-- A 'fail' anywhere during compilation yields 'Left'.
runCompiler :: Compiler a -> Either String [Instr]
runCompiler c = execWriterT (evalStateT (unCompiler whole) start)
  where
    start = CS {fresh = [0 ..], functions = DM.empty}

    whole = c *> instr [Halt] *> writeFunctions

    -- Append every out-of-line function body after the main program.
    writeFunctions :: Compiler ()
    writeFunctions = do
      bodies <- gets functions
      tell (concat (DM.elems bodies))
41
-- | Append instructions to the main instruction stream.
--
-- The returned value exists only to give DSL terms their (phantom) type.
-- Within this module it is only ever discarded (via '>>' or '_ <-'), so it
-- is a placeholder. Forcing it raises an error naming this function,
-- instead of an anonymous 'undefined'.
instr :: [Instr] -> Compiler a
instr i = tell i >> pure placeholder
  where
    placeholder =
      error "Compiler.instr: the value of a compiled term is a placeholder and must not be inspected"
44
-- | Take the next unused label number from the supply in 'CS'.
-- The supply is the infinite list started by 'runCompiler'.
freshLabel :: Compiler Int
freshLabel = state (\cs@CS {fresh = next : rest} -> (next, cs {fresh = rest}))
47
-- | Emit the code for the left operand, then the right operand, then the
-- instruction that combines the two values on top of the stack.
binop :: Instr -> Compiler a1 -> Compiler a2 -> Compiler b
binop op lhs rhs = do
  _ <- lhs
  _ <- rhs
  instr [op]
50
-- | Emit the code for the operand followed by the instruction that
-- transforms the value on top of the stack.
unop :: Instr -> Compiler a -> Compiler b
unop op operand = operand *> instr [op]
53
-- | 'Compiler' is a complete DSL backend. No methods are given here;
-- presumably 'DSL' only gathers superclasses from "Language" (TODO confirm).
instance DSL Compiler where
-- | Expression code generation: operands are compiled left to right and the
-- operator instruction follows, leaving the result on top of the stack.
instance Expression Compiler where
  lit value = instr [Push (serialise value)]

  -- Arithmetic.
  (+.) = binop Add
  (-.) = binop Sub
  (*.) = binop Mul
  (/.) = binop Div
  neg = unop Neg

  -- Boolean connectives.
  (&.) = binop And
  (|.) = binop Or
  not = unop Not

  -- Comparisons; 'Le' and 'Ge' are the strict (<, >) variants.
  (==.) = binop Eq
  (/=.) = binop Neq
  (<.) = binop Le
  (>.) = binop Ge
  (<=.) = binop Leq
  (>=.) = binop Geq

  -- The predicate is evaluated first; 'Brf' jumps to the else branch when it
  -- is false, and the then branch jumps over the else branch when it is done.
  if' p t e = do
    elseLabel <- freshLabel
    endLabel <- freshLabel
    _ <- p
    _ <- instr [Brf elseLabel]
    _ <- t
    _ <- instr [Bra endLabel, Lbl elseLabel]
    _ <- e
    instr [Lbl endLabel]

  bottom message = instr [Error message]
76
-- | Nullary functions: the body is compiled once, out of line, under a fresh
-- label. Each call site emits only a 'Jsr' to that label.
instance Function () Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    let g :- m = def (\() -> instr [Jsr funlabel])
    liftFunction funlabel 0 (g ())
    unmain m
82
-- | Unary functions: the call site evaluates the argument and then emits
-- 'Jsr'. Inside the body the parameter is read back with @Arg 0@.
instance Function (Compiler a) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    let g :- m = def (\argument -> argument >> instr [Jsr funlabel])
    liftFunction funlabel 1 (g (instr [Arg 0]))
    unmain m
88
-- | Binary functions: arguments are evaluated left to right before 'Jsr'.
instance Function (Compiler a, Compiler b) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    let g :- m = def (\(x, y) -> x >> y >> instr [Jsr funlabel])
        -- The last argument pushed is the one closest to the frame: Arg 0.
        params = (instr [Arg 1], instr [Arg 0])
    liftFunction funlabel 2 (g params)
    unmain m
94
-- | Ternary functions: arguments are evaluated left to right before 'Jsr'.
instance Function (Compiler a, Compiler b, Compiler c) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    let g :- m = def (\(x, y, z) -> x >> y >> z >> instr [Jsr funlabel])
        -- The last argument pushed is the one closest to the frame: Arg 0.
        params = (instr [Arg 2], instr [Arg 1], instr [Arg 0])
    liftFunction funlabel 3 (g params)
    unmain m
100
-- | Four-argument functions: arguments are evaluated left to right before
-- 'Jsr'.
instance Function (Compiler a, Compiler b, Compiler c, Compiler d) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    let g :- m = def (\(w, x, y, z) -> w >> x >> y >> z >> instr [Jsr funlabel])
        -- The last argument pushed is the one closest to the frame: Arg 0.
        params = (instr [Arg 3], instr [Arg 2], instr [Arg 1], instr [Arg 0])
    liftFunction funlabel 4 (g params)
    unmain m
106
-- | Compile @body@ out of line as the function at label @lbl@.
--
-- The instructions @body@ emits are captured and removed from the current
-- output stream. They are wrapped as @Lbl lbl : ... ++ [Ret nargs]@ and
-- stored in 'functions' under @lbl@, replacing any earlier entry.
liftFunction :: Int -> Int -> Compiler a -> Compiler ()
liftFunction lbl nargs body = do
  ~(_, emitted) <- censor (const []) (listen body)
  let wrapped = [Lbl lbl] ++ emitted ++ [Ret nargs]
  modify $ \cs -> cs {functions = DM.insert lbl wrapped (functions cs)}
112
-- | Instructions of the stack machine run by 'interpret'.
-- Binary operators pop the right operand first, then the left one.
data Instr
  = Push Int
    -- ^ Push a literal value.
  | Pop Int
    -- ^ Discard the given number of stack values.
  | Add
  | Sub
  | Mul
  | Div
  | Neg
  | And
  | Or
  | Not
  | Eq
  | Neq
  | Le
    -- ^ Strictly less than.
  | Ge
    -- ^ Strictly greater than.
  | Leq
  | Geq
  | Lbl Int
    -- ^ Jump target; does nothing when executed.
  | Bra Int
    -- ^ Jump unconditionally to the label.
  | Brf Int
    -- ^ Pop a value and jump to the label if it is zero (false).
  | Str Int
    -- ^ Pop a value into the numbered register.
  | Ldr Int
    -- ^ Push the numbered register (0 if it was never stored).
  | Sth Int
    -- ^ Move the top @n@ stack values to the heap and push a heap address.
  | Ldh Int
    -- ^ Pop a heap address and push a heap value found relative to it.
  | Jsr Int
    -- ^ Call the function whose entry is the label.
  | Ret Int
    -- ^ Return from a function that took the given number of arguments.
  | Arg Int
    -- ^ Push an argument of the current function (0 = last one pushed).
  | Halt
    -- ^ Stop the machine.
  | Error String
    -- ^ Abort execution with the message.
  deriving Show
124
-- | Register file of the interpreter.
data Registers = Registers
  { pc :: Int
    -- ^ Index of the next instruction to execute.
  , hp :: Int
    -- ^ Next free heap cell; the heap grows upwards.
  , sp :: Int
    -- ^ Next free stack cell; the stack grows downwards.
  , mp :: Int
    -- ^ Frame base set by 'Jsr'; 'Arg' addresses relative to it.
  , gp :: DM.Map Int Int
    -- ^ Numbered registers written by 'Str' and read by 'Ldr'.
  }
  deriving Show
133
-- | Execute a program on a machine with @memsize@ memory cells and return
-- the contents of the last cell (@memsize - 1@) once 'Halt' is reached.
--
-- The stack starts at cell @memsize - 1@ and grows downwards ('push' writes
-- at @sp@ and then decrements it), so the last cell holds the first value
-- pushed by the main program. The heap starts at cell 0 and grows upwards
-- ('pushh'). Nothing checks that the two regions do not collide.
interpret :: Int -> [Instr] -> Int
interpret memsize prog = runSTArray resultStack ! (memsize - 1)
  where
    resultStack :: ST s (STArray s Int Int)
    resultStack = do
      -- Exactly one slot per instruction. 'branch' wraps around to the
      -- upper bound, which must therefore be the last real instruction.
      program <- newListArray (0, length prog - 1) prog
      memory <- newArray (0, memsize - 1) 0
      int program memory
        Registers {pc = 0, sp = memsize - 1, mp = 0, hp = 0, gp = DM.empty}

    -- Write a value at the heap pointer and advance it.
    pushh :: STArray s Int Int -> Int -> Registers -> ST s Registers
    pushh memory value reg = do
      writeArray memory (hp reg) value
      pure (reg {hp = hp reg + 1})

    -- Push the heap cell at address @hptr@ onto the stack.
    loadh :: STArray s Int Int -> Int -> Registers -> ST s Registers
    loadh memory hptr reg = do
      value <- readArray memory hptr
      push memory value reg

    -- Write a value at the stack pointer and move it one cell down.
    push :: STArray s Int Int -> Int -> Registers -> ST s Registers
    push memory value reg = do
      writeArray memory (sp reg) value
      pure (reg {sp = sp reg - 1})

    -- Remove and return the value on top of the stack.
    pop :: STArray s Int Int -> Registers -> ST s (Registers, Int)
    pop memory reg = do
      v <- readArray memory (sp reg + 1)
      pure (reg {sp = sp reg + 1}, v)

    -- Pop @n@ values; the returned list starts with the former top.
    popn :: STArray s Int Int -> Int -> Registers -> ST s (Registers, [Int])
    popn _ 0 reg = pure (reg, [])
    popn memory n reg = do
      (reg', v) <- pop memory reg
      (reg'', vs) <- popn memory (n - 1) reg'
      pure (reg'', v : vs)

    -- Binary operator: the right operand is on top, so it is popped first;
    -- the result is @op left right@.
    bop :: (Int -> Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
    bop op memory reg = do
      (reg1, r) <- pop memory reg
      uop (flip op r) memory reg1

    -- Replace the top of the stack by @op@ applied to it.
    uop :: (Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
    uop op memory reg = do
      (reg1, r) <- pop memory reg
      push memory (op r) reg1

    -- Fetch-execute loop; returns the memory when 'Halt' is executed.
    int :: STArray s Int Instr -> STArray s Int Int -> Registers -> ST s (STArray s Int Int)
    int program memory registers = do
      instruction <- readArray program (pc registers)
      let reg = registers {pc = pc registers + 1}
          continue = int program memory
          arith f = bop f memory reg >>= continue
          unary f = uop f memory reg >>= continue
          logical f = arith ((b2i .) . on f i2b)
          compareWith f = arith ((b2i .) . f)
      case instruction of
        Str r -> do
          (reg', v) <- pop memory reg
          continue (reg' {gp = DM.insert r v (gp reg')})
        Ldr r -> push memory (DM.findWithDefault 0 r (gp reg)) reg >>= continue
        Pop n -> popn memory n reg >>= continue . fst
        Push v -> push memory v reg >>= continue
        Add -> arith (+)
        Sub -> arith (-)
        Mul -> arith (*)
        Div -> arith div
        Neg -> unary negate
        And -> logical (&&)
        Or -> logical (||)
        Not -> unary (b2i . Prelude.not . i2b)
        Eq -> compareWith (==)
        Neq -> compareWith (/=)
        Le -> compareWith (<)
        Ge -> compareWith (>)
        Leq -> compareWith (<=)
        Geq -> compareWith (>=)
        Lbl _ -> continue reg
        Bra l -> branch l program reg >>= continue
        Brf l -> do
          (reg', v) <- pop memory reg
          reg'' <- if i2b v then pure reg' else branch l program reg'
          continue reg''
        -- Values are written top-first at hp .. hp+n-1; the pushed address
        -- is that of the last cell written (the value that was deepest).
        Sth n ->
          popn memory n reg
            >>= uncurry (foldM (flip (pushh memory)))
            >>= push memory (hp reg + n - 1)
            >>= continue
        Ldh n -> do
          (reg', hptr) <- pop memory reg
          -- NOTE(review): with the address pushed by 'Sth', @Ldh 0@ reads the
          -- cell just below it (hptr - 1). Confirm that this offset convention
          -- matches the code generator that emits 'Ldh'.
          loadh memory (hptr - n - 1) reg' >>= continue
        -- Frame layout after Jsr (mp = sp): mp+1 old mp, mp+2 return
        -- address, mp+3 .. the arguments (last pushed first).
        Jsr i -> do
          reg' <- push memory (pc reg) reg >>= push memory (mp reg) >>= branch i program
          continue (reg' {mp = sp reg'})
        Ret n -> do
          (reg', rval : omp : ra : _) <- popn memory (3 + n) reg
          reg'' <- push memory rval reg'
          continue (reg'' {pc = ra, mp = omp})
        Arg n -> do
          v <- readArray memory (mp reg + 3 + n)
          push memory v reg >>= continue
        Halt -> pure memory
        Error msg -> fail msg

-- | Move @pc@ to the instruction @Lbl label@.
--
-- The search starts at the current @pc@ and walks backwards, wrapping from
-- the first instruction to the last. Every instruction is therefore
-- inspected at most once. If the label does not exist, this raises an
-- error (via 'fail') instead of searching forever.
branch :: Int -> STArray s Int Instr -> Registers -> ST s Registers
branch label program reg = do
  (lo, hi) <- getBounds program
  let -- Step backwards, wrapping around at the start of the program.
      previous i = if i <= lo then hi else i - 1
      -- A pc past either end (e.g. after a trailing 'Jsr') starts at the end.
      start = if pc reg < lo || pc reg > hi then hi else pc reg
      search remaining i
        | remaining <= 0 = fail ("branch: label " ++ show label ++ " not found")
        | otherwise = readArray program i >>= \case
            Lbl l | l == label -> pure (reg {pc = i})
            _ -> search (remaining - 1) (previous i)
  search (hi - lo + 1) start
236
-- | Encode a boolean as a machine word: 'True' is 1, 'False' is 0.
b2i :: Bool -> Int
b2i b = if b then 1 else 0
240
-- | Decode a machine word as a boolean: zero is 'False', anything else is
-- 'True'.
i2b :: Int -> Bool
i2b = (/= 0)