add datatype generation DSL stuff
[clean-tests.git] / datatype / Printer.hs
diff --git a/datatype/Printer.hs b/datatype/Printer.hs
new file mode 100644 (file)
index 0000000..8668b91
--- /dev/null
@@ -0,0 +1,136 @@
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE GeneralizedNewtypeDeriving #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+module Printer where
+
+import Control.Monad.RWS
+import Language
+
+newtype Printer a = P { runPrinter :: RWS Ctx [String] PS a }
+  deriving
+    ( Functor
+    , Applicative
+    , Monad
+    , MonadWriter [String]
+    , MonadState PS
+    , MonadReader Ctx
+    )
+data PS = PS {fresh :: [Int]}
+data Ctx = CtxNo | CtxNullary | CtxNonfix | CtxInfix {assoc :: CtxAssoc, prio :: Int, branch :: CtxAssoc}
+  deriving Eq
+
+leftctx,rightctx,nonectx :: Int -> Ctx
+leftctx p = CtxInfix {assoc=CtxLeft, prio=p, branch=CtxNone}
+rightctx p = CtxInfix {assoc=CtxRight, prio=p, branch=CtxNone}
+nonectx p = CtxInfix {assoc=CtxNone, prio=p, branch=CtxNone}
+
+setBranch :: Ctx -> CtxAssoc -> Ctx
+setBranch ctx@(CtxInfix _ _ _) b = ctx { branch=b }
+setBranch ctx _ = ctx
+
+data CtxAssoc = CtxLeft | CtxRight | CtxNone
+  deriving Eq
+
+runPrint :: Printer a -> String
+runPrint p = concat $ snd $ execRWS (runPrinter p) CtxNo $ PS {fresh=[0..]}
+
+--printString :: Show a => a -> Printer a
+--printString = pure . shows
+--
+printLit :: String -> Printer a
+printLit a = tell [a] *> pure undefined
+--
+--printcc :: Printer a -> Printer b -> Printer c
+--printcc a b = a >>= bkkkkkkkkkkP $ \ps->runPrinter a ps . runPrinter b ps
+--
+--printcs :: Printer a -> Printer b -> Printer c
+--printcs a b = P $ \ps->runPrinter a ps . (' ':) . runPrinter b ps
+
+paren :: Printer a -> Printer a
+paren p = printLit "(" *> p <* printLit ")"
+
+accol :: Printer a -> Printer a
+accol p = printLit "{" *> p <* printLit "}"
+
+paren' :: Ctx -> Printer a -> Printer a
+paren' this p = ask >>= \outer->if needsParen this outer then paren p else p
+
+needsParen :: Ctx -> Ctx -> Bool
+needsParen CtxNo _ = False
+needsParen CtxNullary _ = False
+needsParen CtxNonfix CtxNo = False
+needsParen CtxNonfix CtxNonfix = True
+needsParen CtxNonfix (CtxInfix _ _ _) = False
+needsParen (CtxInfix _ _ _) CtxNo = False
+needsParen (CtxInfix _ _ _) CtxNonfix = True
+needsParen (CtxInfix thisassoc thisprio _) (CtxInfix outerassoc outerprio outerbranch)
+    | outerprio > thisprio = True
+    | outerprio == thisprio
+        = thisassoc /= outerassoc || thisassoc /= outerbranch
+    | otherwise = False
+needsParen _ CtxNullary = error "shouldn't occur"
+
+instance Expression Printer where
+    lit = printLit . show
+    (+.) = printBinOp (leftctx 6) "+"
+    (-.) = printBinOp (leftctx 6) "-"
+    (*.) = printBinOp (leftctx 7) "*"
+    (/.) = printBinOp (leftctx 7) "/"
+    (^.) = printBinOp (rightctx 8) "^"
+    neg = printUnOp (nonectx 7) "!"
+    (&.) = printBinOp (rightctx 3) "&"
+    (|.) = printBinOp (rightctx 2) "|"
+    not = printUnOp (nonectx 7) "!"
+    (==.) = printBinOp (nonectx 4) "=="
+    (/=.) = printBinOp (nonectx 4) "/="
+    (<.) = printBinOp (nonectx 4) "<"
+    (>.) = printBinOp (nonectx 4) ">"
+    (<=.) = printBinOp (nonectx 4) "<"
+    (>=.) = printBinOp (nonectx 4) ">"
+    if' p t e = printLit "if" >> p >> printLit "then" >> t >> printLit "else" >> e
+
+freshLabel :: MonadState PS m => String -> m String
+freshLabel prefix = gets fresh >>= \(f:fs)->modify (\s->s {fresh=fs}) >> pure (prefix ++ show f)
+
+instance Function () Printer where
+    fun def = Main $ freshLabel "f" >>= \f->
+        let g :- m = def (\()->printLit (f ++ " ()"))
+        in  printLit ("let " ++ f ++ " () = ") >> g () >> printLit "\n in " >> unmain m
+instance Function (Printer a) Printer where
+    fun def = Main $ freshLabel "f" >>= \f->freshLabel "a" >>= \a->
+        let g :- m = def (\arg->printLit (f ++ " ") >>> arg)
+        in  printLit (concat ["let ", f, " ", a, " = "]) >> g (printLit a) >> printLit " in\n" >> unmain m
+instance Function (Printer a, Printer b) Printer where
+    fun def = Main $ freshLabel "f" >>= \f->freshLabel "a" >>= \a1->freshLabel "a" >>= \a2->
+        let g :- m = def (\(arg1, arg2)->printLit (f ++ " ") >> arg1 >> printLit " " >>> arg2)
+        in  printLit (concat ["let ", f, " ", a1, " ", a2, " = "]) >> g (printLit a1, printLit a2) >> printLit " in\n" >> unmain m
+instance Function (Printer a, Printer b, Printer c) Printer where
+    fun def = Main $
+        freshLabel "f" >>= \f->
+        freshLabel "a" >>= \a1->
+        freshLabel "a" >>= \a2->
+        freshLabel "a" >>= \a3->
+        let g :- m = def (\(arg1, arg2, arg3)->printLit (f ++ " ") >> arg1 >> printLit " " >> arg2 >> printLit " " >>> arg3)
+        in  printLit (concat ["let ", f, " ", a1, " ", a2, " ", a3, " = "]) >> g (printLit a1, printLit a2, printLit a3) >> printLit " in\n" >> unmain m
+
+(>>>) :: Printer a1 -> Printer a2 -> Printer a3
+l >>> r = l >> r >> pure undefined
+
+printBinOp :: Ctx -> String -> Printer a1 -> Printer a2 -> Printer a3
+printBinOp thisctx op l r = paren' thisctx $
+       local (\_->setBranch thisctx CtxLeft) l
+    >> printLit op
+    >> local (\_->setBranch thisctx CtxRight) r
+    >> pure undefined
+
+printUnOp :: Ctx -> String -> Printer a -> Printer a
+printUnOp thisctx op l = paren' thisctx $
+       printLit op
+    >> local (\_->setBranch thisctx CtxRight) l
+
+printCons :: String -> Printer a -> Printer a
+printCons = printUnOp CtxNonfix . (++" ")
+
+printRec :: String -> Printer a -> Printer a
+printRec op l = printUnOp CtxNo (op++" ") (accol l)