79ee6b7799b12c87a1ae5c57cff9814c55a2ff1b
[clean-tests.git] / datatype / Compiler.hs
1 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE LambdaCase #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 module Compiler where
7
8 import Language
9 import Serialise
10 --
11 import qualified Data.Map as DM
12 import Control.Monad.Writer
13 import Control.Monad.State
14 import Control.Monad.ST
15 --import Debug.Trace
16 import Data.Array
17 import Data.Array.ST
18 import Data.Function
19
-- | Code-generation monad for the DSL.
--
-- Layers, outermost first:
--
-- * 'StateT' 'CS': the fresh-label supply and the bodies of lifted functions;
-- * 'WriterT' @['Instr']@: the instruction stream of the code being compiled
--   (see 'instr');
-- * 'Either' 'String': compilation failure, raised through the 'MonadFail'
--   instance.
newtype Compiler a = Compiler { unCompiler :: StateT CS (WriterT [Instr] (Either String)) a }
  deriving
    ( Functor
    , Applicative
    , Monad
    , MonadWriter [Instr]
    , MonadState CS
    )
-- | 'fail' aborts compilation: the message is lifted through the state and
-- writer layers into the underlying 'Left'.
instance MonadFail Compiler where
  fail = Compiler . lift . lift . Left
-- | Mutable state of the compiler.
data CS = CS
  { fresh :: [Int]
  -- ^ Supply of unused labels; 'runCompiler' starts it at @[0..]@.
  , functions :: DM.Map Int [Instr]
  -- ^ Compiled function bodies keyed by entry label; 'runCompiler' appends
  -- them after the main program in ascending key order.
  }
33
-- | Compile a DSL program to a flat instruction list.
--
-- The output is the main program terminated by 'Halt', followed by the
-- bodies of all lifted functions in ascending label order. A 'fail' during
-- compilation is returned as 'Left'.
runCompiler :: Compiler a -> Either String [Instr]
runCompiler c = execWriterT (evalStateT (unCompiler whole) initial)
  where
    whole :: Compiler ()
    whole = do
      _ <- c
      _ <- instr [Halt]
      bodies <- gets (DM.elems . functions)
      tell (concat bodies)

    initial :: CS
    initial = CS {fresh = [0 ..], functions = DM.empty}
41
-- | Emit instructions into the current instruction stream.
--
-- Code built with 'instr' has no compile-time Haskell value; the returned
-- value is only a placeholder that this module never inspects (results are
-- discarded with '>>'). It is an explicit 'error' rather than a bare
-- 'undefined' so that, if anything ever forces it, the failure names its
-- origin.
instr :: [Instr] -> Compiler a
instr i =
  tell i
    >> pure (error "Compiler.instr: a compiled expression has no compile-time value")
44
-- | Take the next unused label from the supply in the compiler state.
--
-- The pattern match is exhaustive. If the supply is ever empty, compilation
-- fails through the 'MonadFail' instance ('Left' from 'runCompiler')
-- instead of crashing with an irrefutable-pattern error.
freshLabel :: Compiler Int
freshLabel = gets fresh >>= \case
  next : rest -> next <$ modify (\s -> s { fresh = rest })
  [] -> fail "Compiler.freshLabel: label supply exhausted"
47
-- | Compile a binary operator: the left operand's code, then the right
-- operand's code, then the operator instruction.
binop :: Instr -> Compiler a1 -> Compiler a2 -> Compiler b
binop op lhs rhs = do
  _ <- lhs
  _ <- rhs
  instr [op]
50
-- | Compile a unary operator: the operand's code followed by the operator
-- instruction.
unop :: Instr -> Compiler a -> Compiler b
unop op operand = operand >>= \_ -> instr [op]
53
-- | No 'DSL' methods are defined here (presumably 'DSL' only bundles the
-- other classes; confirm in "Language").
instance DSL Compiler

-- | Expressions compile to stack code: operands are emitted left to right,
-- each leaving its value on the stack, followed by the operator instruction.
instance Expression Compiler where
  lit value = instr [Push (serialise value)]
  l +. r = binop Add l r
  l -. r = binop Sub l r
  l /. r = binop Div l r
  l *. r = binop Mul l r
  neg x = unop Neg x
  l &. r = binop And l r
  l |. r = binop Or l r
  not x = unop Not x
  l ==. r = binop Eq l r
  l /=. r = binop Neq l r
  -- In the interpreter Le/Ge are the strict comparisons (<, >) and
  -- Leq/Geq the non-strict ones (<=, >=).
  l <. r = binop Le l r
  l >. r = binop Ge l r
  l <=. r = binop Leq l r
  l >=. r = binop Geq l r
  -- Layout: <p>; Brf else; <t>; Bra end; Lbl else; <e>; Lbl end
  if' p t e = do
    elseLabel <- freshLabel
    endLabel <- freshLabel
    _ <- p
    _ <- instr [Brf elseLabel]
    _ <- t
    _ <- instr [Bra endLabel, Lbl elseLabel]
    _ <- e
    instr [Lbl endLabel]
  bottom msg = instr [Error msg]
76
-- | Nullary functions: a call is just a 'Jsr'; the body is lifted out with
-- @Ret 0@.
instance Function () Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    -- 'def' receives the call-site generator and yields the function body
    -- 'g' and the remaining program 'm'.
    let g :- m = def (\() -> instr [Jsr funlabel])
    liftFunction funlabel 0 (g ())
    unmain m
82
-- | Unary functions: a call pushes its argument, then 'Jsr'; inside the body
-- the argument is read back with @Arg 0@.
instance Function (Compiler a) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    -- 'def' receives the call-site generator and yields the function body
    -- 'g' and the remaining program 'm'.
    let g :- m = def (\arg -> arg >> instr [Jsr funlabel])
    liftFunction funlabel 1 (g (instr [Arg 0]))
    unmain m
88
-- | Binary functions: arguments are pushed left to right, so the last one
-- is @Arg 0@ and the first is @Arg 1@.
instance Function (Compiler a, Compiler b) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    -- 'def' receives the call-site generator and yields the function body
    -- 'g' and the remaining program 'm'.
    let g :- m = def (\(x, y) -> x >> y >> instr [Jsr funlabel])
    liftFunction funlabel 2 (g (instr [Arg 1], instr [Arg 0]))
    unmain m
94
-- | Ternary functions: arguments are pushed left to right, so the last one
-- is @Arg 0@ and the first is @Arg 2@.
instance Function (Compiler a, Compiler b, Compiler c) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    -- 'def' receives the call-site generator and yields the function body
    -- 'g' and the remaining program 'm'.
    let g :- m = def (\(x, y, z) -> x >> y >> z >> instr [Jsr funlabel])
    liftFunction funlabel 3 (g (instr [Arg 2], instr [Arg 1], instr [Arg 0]))
    unmain m
100
-- | Four-argument functions: arguments are pushed left to right, so the last
-- one is @Arg 0@ and the first is @Arg 3@.
instance Function (Compiler a, Compiler b, Compiler c, Compiler d) Compiler where
  fun def = Main $ do
    funlabel <- freshLabel
    -- 'def' receives the call-site generator and yields the function body
    -- 'g' and the remaining program 'm'.
    let g :- m = def (\(w, x, y, z) -> w >> x >> y >> z >> instr [Jsr funlabel])
    liftFunction funlabel 4
      (g (instr [Arg 3], instr [Arg 2], instr [Arg 1], instr [Arg 0]))
    unmain m
106
-- | Compile @body@ as a separate function with entry label @lbl@.
--
-- The body's instructions are captured instead of being emitted into the
-- current stream. They are wrapped as @Lbl lbl : body ++ [Ret nargs]@ and
-- stored in the 'functions' map, from which 'runCompiler' later appends them.
liftFunction :: Int -> Int -> Compiler a -> Compiler ()
liftFunction lbl nargs body = do
  ~(_, bodyInstrs) <- censor (const []) (listen body)
  let entry = Lbl lbl : bodyInstrs ++ [Ret nargs]
  modify $ \cs -> cs { functions = DM.insert lbl entry (functions cs) }
112
-- | Instructions of the stack machine run by 'interpret'.
--
-- Binary operators pop the right operand, then the left one, and push
-- @l `op` r@. Logical and comparison results are 1 (true) or 0 (false);
-- any non-zero value counts as true.
data Instr
  = Push Int        -- ^ Push a literal.
  | Pop Int         -- ^ Discard the top @n@ stack values.
  | Add
  | Sub
  | Mul
  | Div             -- ^ Integer division ('div').
  | Neg             -- ^ Negate the top value.
  | And
  | Or
  | Not
  | Eq
  | Neq
  | Le              -- ^ Strict less-than (@<@), despite the name.
  | Ge              -- ^ Strict greater-than (@>@), despite the name.
  | Leq             -- ^ @<=@
  | Geq             -- ^ @>=@
  | Lbl Int         -- ^ Jump target; a no-op when executed.
  | Bra Int         -- ^ Unconditional jump to a label.
  | Brf Int         -- ^ Pop a value; jump to the label if it is 0.
  | Str Int         -- ^ Pop a value into global register @r@.
  | Ldr Int         -- ^ Push global register @r@ (0 if never written).
  | Sth Int         -- ^ Move the top @n@ values to the heap and push a pointer.
  | Ldh Int         -- ^ Pop a heap pointer @p@ and push @memory[p - n - 1]@.
  | Jsr Int         -- ^ Call the function at a label (saves return address and mp).
  | Ret Int         -- ^ Return from a function taking @n@ arguments.
  | Arg Int         -- ^ Push argument @n@ of the current call (0 = last pushed).
  | Halt            -- ^ Stop execution.
  | Error String    -- ^ Abort execution with the message (via 'fail').
  deriving Show
124
-- | Register file of the interpreter.
data Registers = Registers
  { pc :: Int
  -- ^ Program counter: index of the next instruction to fetch.
  , hp :: Int
  -- ^ Heap pointer: next free heap cell (the heap grows upward from 0).
  , sp :: Int
  -- ^ Stack pointer: next free stack slot. The stack grows downward, so the
  -- top value lives at @sp + 1@.
  , mp :: Int
  -- ^ Mark pointer of the current call frame; set by 'Jsr', read by 'Arg'.
  , gp :: DM.Map Int Int
  -- ^ Global registers written by 'Str' and read by 'Ldr'.
  }
  deriving Show
133
-- | Run a compiled program on a machine with @memsize@ memory cells and
-- return the value left in the bottom stack cell (index @memsize - 1@).
--
-- A single array holds both the heap, which grows upward from cell 0, and the
-- stack, which grows downward from cell @memsize - 1@. Nothing checks that
-- the two do not collide.
--
-- The program array has bounds @(0, length prog)@, one slot more than there
-- are instructions, and that last slot receives no instruction. 'branch'
-- wraps around to @upperBound - 1@, so the two must be changed together.
interpret :: Int -> [Instr] -> Int
interpret memsize prog = runSTArray machine ! (memsize - 1)
  where
    machine :: ST s (STArray s Int Int)
    machine = do
      program <- newListArray (0, length prog) prog
      memory <- newArray (0, memsize - 1) 0
      int program memory initial

    initial = Registers {pc = 0, sp = memsize - 1, mp = 0, hp = 0, gp = DM.empty}
142
-- | Store a value in the next free heap cell and advance the heap pointer.
pushh :: STArray s Int Int -> Int -> Registers -> ST s Registers
pushh memory value reg =
  reg { hp = hp reg + 1 } <$ writeArray memory (hp reg) value
147
-- | Read the heap cell at @hptr@ and push its value onto the stack.
loadh :: STArray s Int Int -> Int -> Registers -> ST s Registers
loadh memory hptr registers = do
  value <- readArray memory hptr
  push memory value registers
150
-- | Write a value into the free stack slot at @sp@; the stack grows
-- downward, so @sp@ is then decremented.
push :: STArray s Int Int -> Int -> Registers -> ST s Registers
push memory value reg = do
  let slot = sp reg
  writeArray memory slot value
  pure reg { sp = slot - 1 }
155
-- | Remove the top stack value (at @sp + 1@) and return it along with the
-- updated registers.
pop :: STArray s Int Int -> Registers -> ST s (Registers, Int)
pop memory reg = do
  let slot = sp reg + 1
  value <- readArray memory slot
  pure (reg { sp = slot }, value)
160
-- | Pop @count@ values; the result list starts with the former top of stack.
popn :: STArray s Int Int -> Int -> Registers -> ST s (Registers, [Int])
popn memory count reg
  | count == 0 = pure (reg, [])
  | otherwise = do
      (afterOne, top) <- pop memory reg
      (afterAll, rest) <- popn memory (count - 1) afterOne
      pure (afterAll, top : rest)
167
-- | Apply a binary operator: pop the right operand, then the left one, and
-- push @l `op` r@.
bop :: (Int -> Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
bop op memory reg =
  pop memory reg >>= \(afterRhs, rhs) -> uop (`op` rhs) memory afterRhs
172
-- | Apply a unary operator to the top of the stack in place.
uop :: (Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
uop op memory reg =
  pop memory reg >>= \(afterPop, operand) -> push memory (op operand) afterPop
177
-- | Fetch-execute loop: run instructions from the current program counter
-- until 'Halt', then return the memory array.
--
-- Call frames: 'Jsr' pushes the return address and then the caller's @mp@,
-- and sets @mp@ to the resulting @sp@. Within a call, @memory[mp+1]@ is the
-- saved @mp@, @memory[mp+2]@ the return address, and @memory[mp+3+k]@
-- argument @k@ counted from the last one pushed (see 'Arg').
int :: STArray s Int Instr -> STArray s Int Int -> Registers -> ST s (STArray s Int Int)
int program memory registers = do
  instruction <- readArray program $ pc registers
  -- pc is advanced before executing, so Jsr saves the address of the
  -- following instruction and branch searches start just after this one.
  let reg = registers { pc = pc registers + 1 }
  case instruction of
    Str r -> do
      (reg', v) <- pop memory reg
      int program memory $ reg' { gp = DM.insert r v (gp reg')}
    -- Unwritten global registers read as 0.
    Ldr r -> push memory (DM.findWithDefault 0 r $ gp reg) reg >>= int program memory
    Pop n -> popn memory n reg >>= int program memory . fst
    Push v -> push memory v reg >>= int program memory
    Add -> bop (+) memory reg >>= int program memory
    Sub -> bop (-) memory reg >>= int program memory
    Mul -> bop (*) memory reg >>= int program memory
    Div -> bop div memory reg >>= int program memory
    Neg -> uop negate memory reg >>= int program memory
    -- Logical and comparison results are encoded as 1/0 via b2i/i2b.
    And -> bop ((b2i .) . on (&&) i2b) memory reg >>= int program memory
    Or -> bop ((b2i .) . on (||) i2b) memory reg >>= int program memory
    Not -> uop (b2i . Prelude.not . i2b) memory reg >>= int program memory
    Eq -> bop ((b2i .) . (==)) memory reg >>= int program memory
    Neq -> bop ((b2i .) . (/=)) memory reg >>= int program memory
    Le -> bop ((b2i .) . (<)) memory reg >>= int program memory
    Ge -> bop ((b2i .) . (>)) memory reg >>= int program memory
    Leq -> bop ((b2i .) . (<=)) memory reg >>= int program memory
    Geq -> bop ((b2i .) . (>=)) memory reg >>= int program memory
    Lbl _ -> int program memory reg
    Bra l -> branch l program reg >>= int program memory
    -- Branch only when the popped condition is false (0).
    Brf l -> do
      (reg', v) <- pop memory reg
      reg'' <- if i2b v then pure reg' else branch l program reg'
      int program memory reg''
    -- Moves the top n values to the heap (top of stack first) and pushes
    -- hp + n - 1, the cell holding the first-pushed value.
    Sth n ->
      popn memory n reg
      >>= uncurry (foldM $ flip $ pushh memory)
      >>= push memory (hp reg + n - 1)
      >>= int program memory
    -- NOTE(review): after Sth the k-th pushed field sits at hptr - k, but
    -- this reads hptr - n - 1, one cell lower (for a 1-field record it reads
    -- outside the record). Confirm that the code generator emitting Ldh
    -- compensates; otherwise this is an off-by-one.
    Ldh n -> pop memory reg >>= \(reg', hptr)->loadh memory (hptr - n - 1) reg'
      >>= int program memory
    Jsr i -> push memory (pc reg) reg
      >>= push memory (mp reg)
      >>= branch i program
      >>= \r->int program memory (r { mp = sp r})
    -- Pops the return value, the saved mp, the return address and the n
    -- arguments, then re-pushes the return value. Assumes the body left
    -- exactly one value on top of its frame.
    Ret n -> do
      (reg1, rval:omp:ra:_) <- popn memory (3+n) reg
      reg2 <- push memory rval reg1
      int program memory $ reg2 { pc=ra, mp=omp }
    Arg n -> do
      v <- readArray memory (mp reg + 3 + n)
      push memory v reg >>= int program memory
    Halt -> pure memory
    Error msg -> fail msg
229
-- | Set the program counter to the 'Lbl' carrying the given number.
--
-- Indices are searched in the same order as before: backwards from the
-- current pc down to 0, then wrapping to the last instruction
-- (@upperBound - 1@; the program array's upper slot holds no instruction)
-- and continuing down. The search is bounded to one lap, so each index is
-- visited at most once. If the label does not exist, this now fails with an
-- error instead of looping forever. The returned pc points at the label
-- itself, which executes as a no-op.
branch :: Int -> STArray s Int Instr -> Registers -> ST s Registers
branch label program reg = do
  (_, upper) <- getBounds program
  let start = pc reg
      candidates = [start, start - 1 .. 0] ++ [upper - 1, upper - 2 .. start + 1]
      search [] = fail ("branch: label " ++ show label ++ " not found")
      search (i : rest) = readArray program i >>= \case
        Lbl l | l == label -> pure reg { pc = i }
        _ -> search rest
  search candidates
236
-- | Encode a boolean as an integer: 'True' is 1, 'False' is 0.
b2i :: Bool -> Int
b2i = fromEnum
240
-- | Decode an integer as a boolean: 0 is 'False', anything else is 'True'.
i2b :: Int -> Bool
i2b = (/= 0)