cleanup
[clean-tests.git] / datatype / Compiler.hs
1 {-# LANGUAGE GeneralizedNewtypeDeriving #-}
2 {-# LANGUAGE FlexibleContexts #-}
3 {-# LANGUAGE FlexibleInstances #-}
4 {-# LANGUAGE LambdaCase #-}
5 {-# LANGUAGE MultiParamTypeClasses #-}
6 module Compiler where
7
8 import Language
9
10 import qualified Data.Map as DM
11 import Control.Monad.ST
12 import Control.Monad.State
13 import Control.Monad.Writer
14 import Data.Array
15 import Data.Array.ST
16 import Data.Function
17
-- | The code-generation monad.
--
-- * 'StateT' 'CS' hands out fresh labels and collects function bodies.
-- * 'WriterT' @[Instr]@ accumulates the emitted instruction stream.
-- * 'Either' 'String' carries compilation failures raised with 'fail'.
newtype Compiler a = Compiler
  { unCompiler :: StateT CS (WriterT [Instr] (Either String)) a }
  deriving (Functor, Applicative, Monad, MonadWriter [Instr], MonadState CS)

-- | 'fail' aborts compilation by injecting a 'Left' into the base monad.
instance MonadFail Compiler where
  fail = Compiler . lift . lift . Left
-- | State threaded through compilation.
data CS = CS
  { fresh     :: [Int]
    -- ^ Supply of unused label numbers; 'freshLabel' takes from the front.
  , functions :: DM.Map Int [Instr]
    -- ^ Code of every function registered so far, keyed by its entry label.
  }
31
-- | Compile a DSL program to a flat instruction list.
--
-- The output is the code emitted by the program, then 'Halt', then the bodies
-- of all functions stored in 'functions' (concatenated in ascending label
-- order, as returned by 'DM.elems'). A 'fail' during compilation yields 'Left'.
runCompiler :: Compiler a -> Either String [Instr]
runCompiler c =
    execWriterT (evalStateT (unCompiler whole) initialState)
  where
    whole = c >> instr [Halt] >> appendFunctions
    initialState = CS { fresh = [0..], functions = DM.empty }

    -- Function bodies live after 'Halt' so straight-line execution never
    -- falls into them.
    appendFunctions :: Compiler ()
    appendFunctions = do
      bodies <- gets functions
      tell (concat (DM.elems bodies))
39
-- | Append instructions to the output stream.
--
-- The result type is polymorphic so emitted code can stand in for a DSL value
-- of any type. The returned value carries nothing at compile time and must
-- never be inspected. If it is forced anyway, the error names this function
-- rather than being an anonymous 'undefined'.
instr :: [Instr] -> Compiler a
instr i = tell i >> pure (error "Compiler.instr: the result of emitted code has no compile-time value")
42
-- | Take the next unused label number from the supply in 'CS'.
--
-- 'runCompiler' seeds the supply with the infinite list @[0..]@, so the empty
-- case should be unreachable. It is still handled explicitly with 'fail'
-- (surfacing as a 'Left' from 'runCompiler') rather than through a
-- non-exhaustive lambda pattern.
freshLabel :: Compiler Int
freshLabel = gets fresh >>= \case
  next : rest -> next <$ modify (\s -> s { fresh = rest })
  [] -> fail "freshLabel: label supply exhausted"
45
-- | Compile a binary operation in postfix order: the left operand, then the
-- right operand, then the operator instruction @i@.
binop :: Instr -> Compiler a1 -> Compiler a2 -> Compiler b
binop i l r = l *> r *> instr [i]
48
-- | Compile a unary operation: the operand, then the operator instruction @i@.
unop :: Instr -> Compiler a -> Compiler b
unop i operand = operand *> instr [i]
51
-- | Empty instance. Presumably everything comes from DSL's superclasses and
-- defaults in "Language" — confirm there.
instance DSL Compiler

-- | Expressions compile to postfix stack code: operands first (left before
-- right), then the operator instruction.
instance Expression Compiler where
  lit value = instr [Push (serialise value)]

  -- arithmetic
  (+.) = binop Add
  (-.) = binop Sub
  (/.) = binop Div
  (*.) = binop Mul
  neg  = unop Neg

  -- logic
  (&.) = binop And
  (|.) = binop Or
  not  = unop Not

  -- comparison
  (==.) = binop Eq
  (/=.) = binop Neq
  (<.)  = binop Le
  (>.)  = binop Ge
  (<=.) = binop Leq
  (>=.) = binop Geq

  -- Evaluate the predicate and branch to the else-label when it is false.
  -- The then-branch jumps over the else-branch to the end label.
  -- The else label is allocated before the end label.
  if' p t e = do
    elseLabel <- freshLabel
    endLabel  <- freshLabel
    p >> instr [Brf elseLabel]
    t >> instr [Bra endLabel, Lbl elseLabel]
    e >> instr [Lbl endLabel]

  bottom msg = instr [Error msg]
74
-- | Nullary functions. Every call site simply jumps to the entry label.
instance Function () Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\() -> instr [Jsr entry])
    liftFunction entry 0 (body ())
    unmain rest
80
-- | Unary functions. A call site compiles its argument, then jumps to the
-- entry label. Inside the body the parameter is read back with @Arg 0@.
instance Function (Compiler a) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\arg -> arg >> instr [Jsr entry])
    liftFunction entry 1 (body (instr [Arg 0]))
    unmain rest
86
-- | Binary functions. Arguments are pushed left to right before the jump.
-- The last pushed argument is @Arg 0@, so the first parameter is @Arg 1@.
instance Function (Compiler a, Compiler b) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\(x, y) -> x >> y >> instr [Jsr entry])
    liftFunction entry 2 (body (instr [Arg 1], instr [Arg 0]))
    unmain rest
92
-- | Ternary functions. Arguments are pushed left to right before the jump.
-- Parameter @k@ (0-based from the left) is read with @Arg (2 - k)@.
instance Function (Compiler a, Compiler b, Compiler c) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\(x, y, z) -> x >> y >> z >> instr [Jsr entry])
    liftFunction entry 3 (body (instr [Arg 2], instr [Arg 1], instr [Arg 0]))
    unmain rest
98
-- | Four-argument functions. Arguments are pushed left to right before the
-- jump. Parameter @k@ (0-based from the left) is read with @Arg (3 - k)@.
instance Function (Compiler a, Compiler b, Compiler c, Compiler d) Compiler where
  fun def = Main $ do
    entry <- freshLabel
    let body :- rest = def (\(w, x, y, z) -> w >> x >> y >> z >> instr [Jsr entry])
    liftFunction entry 4 (body (instr [Arg 3], instr [Arg 2], instr [Arg 1], instr [Arg 0]))
    unmain rest
104
-- | Compile @body@ out of line as a function with entry label @lbl@ taking
-- @nargs@ arguments.
--
-- The instructions @body@ emits are captured and removed from the current
-- output stream. They are wrapped as @Lbl lbl : code ++ [Ret nargs]@ and
-- stored in 'functions' under @lbl@; 'runCompiler' emits them after 'Halt'.
liftFunction :: Int -> Int -> Compiler a -> Compiler ()
liftFunction lbl nargs body = do
  (_, code) <- censor (const []) (listen body)
  let wrapped = [Lbl lbl] ++ code ++ [Ret nargs]
  modify $ \s -> s { functions = DM.insert lbl wrapped (functions s) }
110
-- | Instructions executed by the stack machine in 'int'.
data Instr
  = Push Int     -- ^ push a constant
  | Pop Int      -- ^ discard the top @n@ stack values
  -- arithmetic on the top two values (left operand below the right one)
  | Add
  | Sub
  | Mul
  | Div
  | Neg          -- ^ negate the top value
  -- logic on truth values (0 = false)
  | And
  | Or
  | Not
  -- comparisons, pushing 1 (true) or 0 (false)
  | Eq
  | Neq
  | Le           -- ^ less than
  | Ge           -- ^ greater than
  | Leq          -- ^ less than or equal
  | Geq          -- ^ greater than or equal
  -- control flow
  | Lbl Int      -- ^ jump target; a no-op when executed
  | Bra Int      -- ^ unconditional jump to a label
  | Brf Int      -- ^ pop a value and jump to the label when it is 0
  -- global registers
  | Str Int      -- ^ pop a value into global register @r@
  | Ldr Int      -- ^ push global register @r@ (0 if never written)
  -- heap
  | Sth Int      -- ^ move the top @n@ values to the heap, push the address of the last cell written
  | Ldh Int      -- ^ pop a heap address @p@ and push the cell at @p - n - 1@
  -- procedures
  | Jsr Int      -- ^ push return address and mark pointer, then jump to the label
  | Ret Int      -- ^ return from a function of @n@ arguments, keeping its result
  | Arg Int      -- ^ push argument @n@ of the current call (0 = last pushed)
  | Halt         -- ^ stop execution
  | Error String -- ^ abort execution with the message
  deriving Show
122
-- | Register file of the stack machine run by 'interpret'.
data Registers = Registers
  { pc :: Int
    -- ^ index of the next instruction to fetch
  , hp :: Int
    -- ^ next free heap cell; starts at 0 and grows upwards
  , sp :: Int
    -- ^ next free stack slot; starts at the top of memory and grows downwards
  , mp :: Int
    -- ^ mark pointer set by 'Jsr'; argument @n@ is read at @mp + 3 + n@
  , gp :: DM.Map Int Int
    -- ^ global registers written by 'Str' and read by 'Ldr'
  }
  deriving Show
131
-- | Run a program on a machine with @memsize@ memory cells.
--
-- The result is the final content of cell @memsize - 1@, the bottom of the
-- stack, which is the first stack slot ever written.
interpret :: Int -> [Instr] -> Int
interpret memsize prog = runSTArray machine ! (memsize - 1)
  where
    machine :: ST s (STArray s Int Int)
    machine = do
      -- NOTE(review): the bounds give the code array @length prog + 1@ slots.
      -- The last one is never initialised. 'branch' wraps to @upper - 1@,
      -- i.e. the last real instruction, so both must change together.
      code <- newListArray (0, length prog) prog
      mem  <- newArray (0, memsize - 1) 0
      int code mem initialRegisters

    initialRegisters =
      Registers { pc = 0, sp = memsize - 1, mp = 0, hp = 0, gp = DM.empty }
140
-- | Write @value@ to the next free heap cell and advance the heap pointer.
pushh :: STArray s Int Int -> Int -> Registers -> ST s Registers
pushh memory value reg =
  reg { hp = hp reg + 1 } <$ writeArray memory (hp reg) value
145
-- | Push the content of memory cell @hptr@ onto the stack.
loadh :: STArray s Int Int -> Int -> Registers -> ST s Registers
loadh memory hptr registers = do
  cell <- readArray memory hptr
  push memory cell registers
148
-- | Store @value@ in the free slot at @sp@, then move @sp@ down by one
-- (the stack grows towards lower addresses).
push :: STArray s Int Int -> Int -> Registers -> ST s Registers
push memory value reg = do
  let slot = sp reg
  writeArray memory slot value
  pure reg { sp = slot - 1 }
153
-- | Remove the top stack value, returning the updated registers and the value.
pop :: STArray s Int Int -> Registers -> ST s (Registers, Int)
pop memory reg = (,) reg' <$> readArray memory (sp reg')
  where
    -- The top value sits just above the free slot that @sp@ points at.
    reg' = reg { sp = sp reg + 1 }
158
-- | Pop @n@ values. They are returned in pop order: the former top of the
-- stack comes first.
popn :: STArray s Int Int -> Int -> Registers -> ST s (Registers, [Int])
popn memory count start = go count start
  where
    go 0 reg = pure (reg, [])
    go k reg = do
      (reg', v) <- pop memory reg
      fmap (v :) <$> go (k - 1) reg'
165
-- | Apply a binary operator to the top two stack values and push the result.
-- The top value is the right operand, the one below it the left operand.
bop :: (Int -> Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
bop op memory reg = do
  (reg', rhs) <- pop memory reg
  uop (`op` rhs) memory reg'
170
-- | Replace the top stack value @x@ with @op x@.
uop :: (Int -> Int) -> STArray s Int Int -> Registers -> ST s Registers
uop op memory reg =
  pop memory reg >>= \(reg', x) -> push memory (op x) reg'
175
-- | Fetch-decode-execute loop of the stack machine.
--
-- Each step reads the instruction at @pc@, advances @pc@ by one and executes
-- the instruction. The loop ends at 'Halt', which returns the final memory;
-- 'Error' aborts the run with 'fail'.
--
-- Truth values are plain 'Int's. Comparisons and logical operators produce
-- @0@ (false) or @1@ (true), and operands are decoded with @/= 0@, so any
-- non-zero value counts as true. The partial @toEnum :: Int -> Bool@ would
-- crash on values other than 0 and 1.
int :: STArray s Int Instr -> STArray s Int Int -> Registers -> ST s (STArray s Int Int)
int program memory registers = do
  instruction <- readArray program $ pc registers
  let reg = registers { pc = pc registers + 1 }
  case instruction of
    -- global registers
    Str r -> do
      (reg', v) <- pop memory reg
      next $ reg' { gp = DM.insert r v (gp reg') }
    Ldr r -> push memory (DM.findWithDefault 0 r $ gp reg) reg >>= next

    -- stack
    Pop n -> popn memory n reg >>= next . fst
    Push v -> push memory v reg >>= next

    -- arithmetic
    Add -> bop (+) memory reg >>= next
    Sub -> bop (-) memory reg >>= next
    Mul -> bop (*) memory reg >>= next
    Div -> bop div memory reg >>= next
    Neg -> uop negate memory reg >>= next

    -- logic
    And -> bop (logical (&&)) memory reg >>= next
    Or -> bop (logical (||)) memory reg >>= next
    Not -> uop (fromEnum . Prelude.not . truthy) memory reg >>= next

    -- comparison
    Eq -> bop (relation (==)) memory reg >>= next
    Neq -> bop (relation (/=)) memory reg >>= next
    Le -> bop (relation (<)) memory reg >>= next
    Ge -> bop (relation (>)) memory reg >>= next
    Leq -> bop (relation (<=)) memory reg >>= next
    Geq -> bop (relation (>=)) memory reg >>= next

    -- control flow
    Lbl _ -> next reg
    Bra l -> branch l program reg >>= next
    Brf l -> do
      (reg', v) <- pop memory reg
      reg'' <- if truthy v then pure reg' else branch l program reg'
      next reg''

    -- heap: Sth pops n values, writes them to the heap in pop order and
    -- pushes the address of the last cell written (hp + n - 1).
    Sth n ->
      popn memory n reg
        >>= uncurry (foldM $ flip $ pushh memory)
        >>= push memory (hp reg + n - 1)
        >>= next
    -- NOTE(review): Ldh reads @hptr - n - 1@. With the pointer Sth pushes
    -- (the last cell written), @Ldh 0@ reads the cell *before* it. This looks
    -- like a possible off-by-one; confirm against the intended layout.
    Ldh n -> pop memory reg >>= \(reg', hptr) -> loadh memory (hptr - n - 1) reg' >>= next

    -- procedures: Jsr pushes the return address and the caller's mp, jumps,
    -- and sets mp to the new sp, so argument n is found at mp + 3 + n.
    Jsr i ->
      push memory (pc reg) reg
        >>= push memory (mp reg)
        >>= branch i program
        >>= \r -> next (r { mp = sp r })
    -- Pops the return value, saved mp, return address and the n arguments,
    -- then pushes the return value back. Assumes the body left exactly one
    -- value above the frame.
    Ret n -> do
      (reg1, rval : omp : ra : _) <- popn memory (3 + n) reg
      reg2 <- push memory rval reg1
      next $ reg2 { pc = ra, mp = omp }
    Arg n -> do
      v <- readArray memory (mp reg + 3 + n)
      push memory v reg >>= next

    Halt -> pure memory
    Error msg -> fail msg
  where
    next = int program memory

    truthy :: Int -> Bool
    truthy = (/= 0)

    logical :: (Bool -> Bool -> Bool) -> Int -> Int -> Int
    logical f a b = fromEnum (f (truthy a) (truthy b))

    relation :: (Int -> Int -> Bool) -> Int -> Int -> Int
    relation f a b = fromEnum (f a b)
227
-- | Move the program counter to the 'Lbl' numbered @label@.
--
-- The search scans backwards from the current @pc@. On running off the front
-- of the program it wraps around to the last instruction, so labels before
-- and after the jump site are both found. Only @pc@ changes; it is left on
-- the 'Lbl' itself, which 'int' executes as a no-op.
--
-- Each instruction is inspected at most once. If no matching label exists,
-- the search stops with 'fail' instead of looping forever. A start @pc@
-- outside the program is treated like running off the front, so it restarts
-- from the last instruction instead of reading an uninitialised slot.
branch :: Int -> STArray s Int Instr -> Registers -> ST s Registers
branch label program reg = do
  -- The code array has one slot more than the program, so the last real
  -- instruction is at @upper - 1@; the slot at @upper@ is never read.
  (_, upper) <- getBounds program
  let search seen i
        | seen >= upper = fail ("branch: label " ++ show label ++ " not found")
        | i < 0 || i >= upper = search seen (upper - 1)
        | otherwise = readArray program i >>= \case
            Lbl l | l == label -> pure reg { pc = i }
            _ -> search (seen + 1) (i - 1)
  search 0 (pc reg)