.
[clean-tests.git] / datatype / Language / GenDSL.hs
1 {-# LANGUAGE TemplateHaskell #-}
2 {-# LANGUAGE ParallelListComp #-}
3 {-# LANGUAGE KindSignatures #-}
4 module Language.GenDSL where
5
6 import Language.Haskell.TH.Syntax
7 import Language.Haskell.TH
8 import Data.Char
9 import qualified Data.Set as DS
10 import Control.Monad
11 import Debug.Trace
12
13 import Printer
14 import Compiler
15 import Interpreter
16
-- | Names of the generated DSL pieces, derived from the base name of a
-- datatype, constructor or field:
--
-- * 'className':       @T@    -> @TDSL@
-- * 'constructorName': @Cons@ -> @cons@ (first character lower-cased)
-- * 'selectorName':    @f@    -> @getf@
-- * 'predicateName':   @Cons@ -> @isCons@
-- * 'setterName':      @f@    -> @setf@
className, constructorName, selectorName, predicateName, setterName :: Name -> Name
className = mkName . (++ "DSL") . stringName
constructorName = mkName . lowerFirst . stringName
  where
    -- Total version of the original lambda: an empty base name is passed
    -- through unchanged instead of raising a pattern-match failure.
    lowerFirst (c:cs) = toLower c : cs
    lowerFirst [] = []
selectorName = mkName . ("get" ++) . stringName
predicateName = mkName . ("is" ++) . stringName
setterName = mkName . ("set" ++) . stringName
23
-- | The unqualified base name of a 'Name', i.e. with any module
-- qualifier or unique dropped.
stringName :: Name -> String
stringName = nameBase
26
-- | Invented field name for a positional constructor field: the whole
-- constructor base name lower-cased, then @f@, then the index
-- (e.g. @MyCons@ and 2 give @myconsf2@).
adtFieldName :: Name -> Integer -> Name
adtFieldName consName idx = mkName (lowered ++ 'f' : show idx)
  where
    lowered = toLower <$> stringName consName
29
-- | Build the infix application @lhs `op` rhs@, where @op@ is the
-- operator's name given as a string (e.g. @"<*>"@).
ifx :: String -> ExpQ -> ExpQ -> ExpQ
ifx op lhs rhs = infixE (Just lhs) operator (Just rhs)
  where
    operator = varE (mkName op)
32
-- | A function declaration with a single clause, no guards and no
-- @where@ bindings: @name args = body@.
fun :: Name -> [PatQ] -> ExpQ -> DecQ
fun name args body = funD name [singleClause]
  where
    singleClause = clause args (normalB body) []
35
-- | Values from which DSL declarations can be generated with Template
-- Haskell.
class GenDSL a where
  -- | The declarations generated for the argument.
  genDSL :: a -> DecsQ
-- | Generate for each element, left to right, and concatenate all of the
-- resulting declarations.
instance GenDSL a => GenDSL [a] where
  genDSL xs = concat <$> traverse genDSL xs
-- | Generate the DSL for a @data@ or @newtype@ declaration.
--
-- Each constructor is normalised to record-GADT form, i.e. a triple of
-- (constructor name, named fields with types, result type), and the
-- concatenated list is handed to 'mkDSL'. Positional fields receive names
-- from 'adtFieldName'. Any other declaration is rejected with 'fail'.
instance GenDSL Dec where
  genDSL (DataD _ typeName tyVars _ constructors _) = do
      normalised <- mapM normalise constructors
      mkDSL typeName (concat normalised)
    where
      normalise :: Con -> Q [(Name, [(Name, Type)], Type)]
      -- Record GADT: already in target shape; one entry per listed name.
      normalise (RecGadtC consNames fs ty) =
        pure [ (consName, [ (n, t) | (n, _, t) <- fs ], ty) | consName <- consNames ]
      -- Positional GADT: invent field names. Operator constructors (base
      -- name starting with ':') fall through to the catch-all failure.
      normalise (GadtC consNames fs ty)
        | all ((/= ':') . head . stringName) consNames =
            concat <$> mapM (\c -> normalise (RecGadtC [c] (invent c fs) ty)) consNames
      -- Plain positional constructor: name its fields, then treat as a record.
      normalise (NormalC consName fs) = normalise (RecC consName (invent consName fs))
      -- Plain record: the result type is the datatype applied to its parameters.
      normalise (RecC consName fs) = normalise (RecGadtC [consName] fs resultType)
      -- Only quantification without a context is accepted.
      normalise (ForallC _ [] inner) = normalise inner
      normalise c = fail $ "Unsupported constructor type: " ++ show c

      -- e.g. @T a b@ for @data T a b@ (kind annotations on binders dropped).
      resultType = foldl (\acc tv -> AppT acc (VarT (binderName tv))) (ConT typeName) tyVars

      binderName (PlainTV n) = n
      binderName (KindedTV n _) = n

      -- Pair each positional field with a generated name, counting from 0.
      invent consName fs = [ (adtFieldName consName i, b, t) | (i, (b, t)) <- zip [0 ..] fs ]
  genDSL (NewtypeD ctx name tvs mk con ds) = genDSL (DataD ctx name tvs mk [con] ds)
  genDSL t = fail $ "mkConsClass only supports simple datatypes and not on: " ++ show t
-- | Reify the name and generate from what it refers to. A type
-- constructor yields its declaration; a data constructor is redirected to
-- its parent type. Anything else is rejected with 'fail'.
instance GenDSL Name where
  genDSL typeName = do
    info <- reify typeName
    case info of
      TyConI dec          -> genDSL dec
      DataConI _ _ parent -> genDSL parent
      other               -> fail $ "mkConsClass only works on types and not on: " ++ show other
63
-- | Three-argument analogue of 'uncurry': apply a curried function to the
-- components of a triple.
uncurry3 :: (a -> b -> c -> d) -> (a, b, c) -> d
uncurry3 f (x, y, z) = f x y z
65
-- | Whether a field of type @field@ may get a selector on a constructor
-- whose result type is @res@.
--
-- This holds when every type variable occurring in the field type also
-- occurs in the result type. A field mentioning any other variable (one
-- not fixed by the result type) cannot be returned from a selector.
-- Variables are compared by base name only (via 'stringName').
posSelector :: Type -> Type -> Bool
posSelector field res = vars field DS.empty `DS.isSubsetOf` vars res DS.empty
  where
    -- Add the base names of all type variables occurring in a type.
    -- Kind annotations ('SigT') and visible kind applications ('AppKindT')
    -- are traversed on both sides, so e.g. @T (a :: k)@ still contributes
    -- @a@.
    -- NOTE(review): 'ForallT' and other unlisted forms contribute nothing,
    -- so a rank-N field is treated as mentioning no variables.
    vars :: Type -> DS.Set String -> DS.Set String
    vars (AppT l r)           = vars l . vars r
    vars (AppKindT t k)       = vars t . vars k
    vars (SigT t k)           = vars t . vars k
    vars (VarT n)             = DS.insert (stringName n)
    vars (InfixT l _ r)       = vars l . vars r
    vars (UInfixT l _ r)      = vars l . vars r
    vars (ParensT t)          = vars t
    vars (ImplicitParamT _ t) = vars t
    vars _                    = id
77
-- | Generate the complete DSL for one datatype.
--
-- Each element of @constructors@ is (constructor name, named fields with
-- their types, fully applied result type). Four declarations are produced:
--
-- * a class @<Type>DSL v@ that has, per constructor, a smart-constructor
--   method, a @get<field>@ method for every field accepted by
--   'posSelector', and an @is<Cons>@ predicate;
-- * instances of that class for 'Printer', 'Compiler' and 'Interpreter'.
mkDSL :: Name -> [(Name, [(Name, Type)], Type)] -> DecsQ
mkDSL typeName constructors = sequence [mkClass, mkPrinter, mkCompiler, mkInterpreter]
  where
    (consNames, fields, types) = unzip3 constructors

    -- Per constructor: its name, its total number of fields, the selectable
    -- fields tagged with their 0-based position among ALL of its fields,
    -- and its result type. The position is recorded before filtering
    -- because the generated interpreter pattern and compiler heap load both
    -- address the constructor's full field list, not the filtered one.
    selectors :: [(Name, Int, [(Int, (Name, Type))], Type)]
    selectors =
      [ (n, length fs, [ (i, f) | (i, f@(_, ft)) <- zip [0 ..] fs, posSelector ft ty ], ty)
      | (n, fs, ty) <- constructors
      ]

    -- class <Type>DSL v where { <cons> :: ...; get<f> :: ...; is<Cons> :: ... }
    mkClass :: DecQ
    mkClass = classD (pure []) (className typeName) [PlainTV (mkName "v")] []
      (  map (uncurry3 mkConstructor) constructors
      ++ [ mkSelector ty fn ft | (_, _, sfs, ty) <- selectors, (_, (fn, ft)) <- sfs ]
      ++ zipWith mkPredicate types consNames
      )
      where
        v = varT $ mkName "v"

        -- <cons> :: v F1 -> ... -> v Fn -> v Res
        mkConstructor :: Name -> [(Name, Type)] -> Type -> DecQ
        mkConstructor n fs res = sigD (constructorName n)
          $ foldr (\x y -> [t|$x -> $y|]) [t|$v $(pure res)|]
          $ map (appT v . pure . snd) fs

        -- get<f> :: v Res -> v F
        mkSelector :: Type -> Name -> Type -> DecQ
        mkSelector res n t = sigD (selectorName n) [t|$v $(pure res) -> $v $(pure t)|]

        -- is<Cons> :: v Res -> v Bool
        mkPredicate :: Type -> Name -> DecQ
        mkPredicate res n = sigD (predicateName n) [t|$v $(pure res) -> $v Bool|]

    -- instance <Type>DSL Printer
    mkPrinter :: DecQ
    mkPrinter = instanceD (pure []) [t|$(conT $ className typeName) Printer|]
      $  zipWith mkConstructor consNames fields
      ++ [ mkSelector fn | (_, _, sfs, _) <- selectors, (_, (fn, _)) <- sfs ]
      ++ map mkPredicate consNames
      where
        pl s = [|printLit $(lift s)|]

        -- Print the constructor name followed by its arguments chained with >->.
        mkConstructor :: Name -> [(Name, Type)] -> DecQ
        mkConstructor consName fs = do
            fresh <- sequence [newName "f" | _ <- fs]
            fun (constructorName consName) (map varP fresh) (pcons `appE` pargs fresh)
          where
            pcons = [|printCons $(lift $ stringName consName)|]
            pargs vs = foldl (ifx ">->") (pl "") $ map varE vs

        -- Print the argument followed by ".<field>".
        mkSelector :: Name -> DecQ
        mkSelector n = fun (selectorName n) [] [|\x -> x >> $(pl ('.' : stringName n))|]

        -- Print "is<Cons>" followed by the argument.
        mkPredicate :: Name -> DecQ
        mkPredicate n = fun (predicateName n) []
          [|\x -> $(pl $ stringName $ predicateName n) >-> x|]

    -- instance <Type>DSL Compiler
    mkCompiler :: DecQ
    mkCompiler = instanceD (pure []) [t|$(conT $ className typeName) Compiler|]
      $  zipWith3 mkConstructor [0 ..] consNames fields
      ++ [ mkSelector (fromIntegral i) fn | (_, _, sfs, _) <- selectors, (i, (fn, _)) <- sfs ]
      ++ zipWith mkPredicate [0 ..] consNames
      where
        -- Push the constructor number, evaluate the fields, then store all
        -- 1 + length fs values with Sth.
        mkConstructor :: Integer -> Name -> [(Name, Type)] -> DecQ
        mkConstructor consNum consName fs = do
            fresh <- sequence [newName "f" | _ <- fs]
            fun (constructorName consName) (map varP fresh)
              $ ifx "*>" pushCons $ ifx "<*" (mkBody $ map varE fresh) storeHeap
          where
            storeHeap = [|instr [Sth $ 1 + $(lift $ length fs)]|]
            mkBody = foldl (ifx "<*>") [|pure $(conE consName)|]
            pushCons = [|instr [Push $(lift consNum)]|]

        -- Load field number fn (its position in the constructor) with Ldh.
        mkSelector :: Integer -> Name -> DecQ
        mkSelector fn n = fun (selectorName n) [] [|\x -> x >> instr [Ldh $(lift fn)]|]

        -- Load the stored constructor number (Ldh (-1)) and compare it.
        mkPredicate :: Integer -> Name -> DecQ
        mkPredicate consNum consName = fun (predicateName consName) []
          [|\x -> x >> instr [Ldh (-1), Push $(lift consNum), Eq]|]

    -- instance <Type>DSL Interpreter
    mkInterpreter :: DecQ
    mkInterpreter = instanceD (pure []) [t|$(conT $ className typeName) Interpreter|]
      $  zipWith mkConstructor consNames fields
      ++ [ mkSelector cn arity i fn | (cn, arity, sfs, _) <- selectors, (i, (fn, _)) <- sfs ]
      ++ zipWith mkPredicate consNames fields
      where
        singleCons = length consNames == 1

        mkConstructor :: Name -> [(Name, Type)] -> DecQ
        mkConstructor consName fs = do
          fresh <- sequence [newName "f" | _ <- fs]
          fun (constructorName consName) (map varP fresh)
            $ foldl (ifx "<*>") [|pure $(conE consName)|] (map varE fresh)

        -- Match the constructor with all `arity` fields, binding only the
        -- one at position `pos`; other constructors fail at run time.
        mkSelector :: Name -> Int -> Int -> Name -> DecQ
        mkSelector consName arity pos n = do
            arg <- newName "f"
            fun (selectorName n) [varP arg]
              [|$(varE arg) >>= $(lamCaseE $ mkMatch : wilds)|]
          where
            mkMatch = do
              e <- newName "e"
              match (conP consName [ if i == pos then varP e else wildP | i <- [0 .. arity - 1] ])
                    (normalB [|pure $(varE e)|]) []
            wilds
              | singleCons = []
              | otherwise  = [match wildP (normalB [|fail "Exhausted case"|]) []]

        mkPredicate :: Name -> [(Name, Type)] -> DecQ
        mkPredicate n fs = fun (predicateName n) [] $
          if singleCons
            then [|\_ -> true|]
            else [|\x -> x >>= \p -> case p of { $(conP n [wildP | _ <- fs]) -> true; _ -> false }|]