{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE TemplateHaskell #-}
{-# OPTIONS_GHC -cpp #-}
module Generics.MRSOP.TH (deriveFamily, genFamilyDebug) where
import Data.Function (on)
import Data.Char (ord , isAlphaNum)
import Data.List (sortBy, foldl')
import Control.Monad
import Control.Monad.State
import Control.Monad.Writer
import Control.Monad.Identity
import Language.Haskell.TH hiding (match)
import Language.Haskell.TH.Syntax (liftString)
import Generics.MRSOP.Util
import Generics.MRSOP.Opaque
import Generics.MRSOP.Base.Class
import Generics.MRSOP.Base.NS
import Generics.MRSOP.Base.NP
import Generics.MRSOP.Base.Universe hiding (match)
import qualified Generics.MRSOP.Base.Metadata as Meta
import qualified Data.Map as M
-- | Derive the generic representation of the mutually recursive family
-- reachable from the given type.
--
-- The root type and every type reachable through its fields are reified.
-- Each is given a family index, and the resulting entries are sorted by that
-- index. The result is the declarations from 'genFamilyDebug' followed by
-- those from 'genFamily'. Fails if any reached type ends up without datatype
-- information.
--
-- NOTE(review): the debug bindings are emitted on every call. Two
-- 'deriveFamily' splices in one module may therefore clash on their names;
-- confirm whether this is intended.
deriveFamily :: Q Type -> Q [Dec]
deriveFamily qty = do
  root <- convertType =<< qty
  (_ , Idxs _ table) <- runIdxsM (reifySTy root)
  entries <- mapM toEntry (M.toList table)
  let ordered = sortBy (compare `on` (\(_ , i , _) -> i)) entries
  debugDecs  <- genFamilyDebug root ordered
  familyDecs <- genFamily root ordered
  return (debugDecs ++ familyDecs)
  where
    -- Every indexed type must have been reified by now.
    toEntry (ty , (_ , Nothing)) =
      fail $ "Type " ++ show ty ++ " has no datatype information."
    toEntry (ty , (i , Just info)) = return (ty , i , info)
-- | Name of a reified datatype.
type DataName = Name
-- | Name of a data constructor.
type ConName = Name
-- | Name of a record field.
type FieldName = Name
-- | Type variables a datatype declaration is parameterised over.
type Args = [Name]
-- | Datatype information extracted from a reified declaration.
--
-- The parameter is the representation of field types: 'decInfo' produces
-- @DTI STy@, and reification later converts it to @DTI IK@.
data DTI ty
  = ADT DataName Args [ CI ty ]  -- ^ a @data@ declaration and its constructors
  | New DataName Args (CI ty)    -- ^ a @newtype@ declaration and its constructor
  deriving (Eq , Show , Functor)
-- | Constructor information, parameterised over the representation of
-- field types.
data CI ty
  = Normal ConName [ty]                  -- ^ prefix constructor, positional fields
  | Infix ConName Fixity ty ty           -- ^ infix constructor with its fixity and operands
  | Record ConName [ (FieldName , ty) ]  -- ^ record constructor with named fields
  deriving (Eq , Show , Functor)
-- | Monadic map over every field type of a constructor.
-- Constructor names, fixities and field names are preserved.
ciMapM :: (Monad m) => (ty -> m tw) -> CI ty -> m (CI tw)
ciMapM f ci = case ci of
  Normal name tys ->
    Normal name <$> traverse f tys
  Infix name fx lhs rhs ->
    Infix name fx <$> f lhs <*> f rhs
  Record name fields ->
    Record name <$> traverse (\(fld , t) -> (,) fld <$> f t) fields
-- | Monadic map over every field type of every constructor of a datatype.
dtiMapM :: (Monad m) => (ty -> m tw) -> DTI ty -> m (DTI tw)
dtiMapM f dti = case dti of
  ADT name args cons -> ADT name args <$> traverse (ciMapM f) cons
  New name args con  -> New name args <$> ciMapM f con
-- | All constructors of a datatype. A newtype has exactly one.
dti2ci :: DTI ty -> [CI ty]
dti2ci dti = case dti of
  ADT _ _ cons -> cons
  New _ _ con  -> [con]
-- | Field types of a constructor, left to right. Record field names are
-- dropped.
ci2ty :: CI ty -> [ty]
ci2ty ci = case ci of
  Normal _ tys     -> tys
  Infix _ _ lhs rhs -> [lhs , rhs]
  Record _ fields  -> [ t | (_ , t) <- fields ]
-- | Name of a constructor, whatever its shape.
ciName :: CI ty -> Name
ciName ci = case ci of
  Normal name _    -> name
  Infix name _ _ _ -> name
  Record name _    -> name
-- | Build a pattern @Con x1 .. xn@ for a constructor, using one fresh
-- variable per field. Returns the variables together with the pattern.
ci2Pat :: CI ty -> Q ([Name] , Pat)
ci2Pat ci = do
  vars <- replicateM (length (ci2ty ci)) (newName "x")
  return (vars , ConP (ciName ci) (VarP <$> vars))
-- | Build an expression @Con y1 .. yn@ for a constructor, using one fresh
-- variable per field. Returns the variables together with the expression.
ci2Exp :: CI ty -> Q ([Name], Exp)
ci2Exp ci = do
  vars <- replicateM (length (ci2ty ci)) (newName "y")
  let applied = foldl' (\acc v -> AppE acc (VarE v)) (ConE (ciName ci)) vars
  return (vars , applied)
-- | Simplified type representation: applications, type variables and type
-- constructors. 'convertType' builds it from a TH 'Type' and 'trevnocType'
-- converts it back. 'Ord' is required because it is used as a 'M.Map' key.
data STy
  = AppST STy STy
  | VarST Name
  | ConST Name
  deriving (Eq , Show, Ord)
-- | Catamorphism over 'STy': combine applications with @app@ and interpret
-- variables with @var@ and constructors with @con@.
styFold :: (a -> a -> a) -> (Name -> a) -> (Name -> a) -> STy -> a
styFold app var con = go
  where
    go (AppST f x) = app (go f) (go x)
    go (VarST n)   = var n
    go (ConST n)   = con n
-- | A type is closed when it mentions no type variables.
isClosed :: STy -> Bool
isClosed = styFold (&&) (\_ -> False) (\_ -> True)
-- | Convert a TH 'Type' into an 'STy'. Kind signatures and parentheses are
-- dropped, and the list and tuple type constructors become named
-- constructors. Any other form of type fails.
convertType :: (Monad m) => Type -> m STy
convertType ty = case ty of
  AppT f x      -> AppST <$> convertType f <*> convertType x
  SigT inner _  -> convertType inner
  VarT n        -> return (VarST n)
  ConT n        -> return (ConST n)
  ParensT inner -> convertType inner
  ListT         -> return (ConST (mkName "[]"))
  TupleT arity  -> return (ConST (mkName ('(' : replicate (arity - 1) ',' ++ ")")))
  _             -> fail ("convertType: Unsupported Type: " ++ show ty)
-- | Inverse of 'convertType': turn an 'STy' back into a TH 'Type'.
-- The names produced for lists and tuples are mapped back to 'ListT' and
-- 'TupleT'.
trevnocType :: STy -> Type
trevnocType sty = case sty of
  AppST f x -> AppT (trevnocType f) (trevnocType x)
  VarST v   -> VarT v
  ConST c
    | c == mkName "[]"          -> ListT
    | take 2 (show c) == "(,"   -> TupleT (length (show c) - 1)
    | otherwise                 -> ConT c
-- | @stySubst ty v r@ replaces every occurrence of type variable @v@ in @ty@
-- with @r@.
stySubst :: STy -> Name -> STy -> STy
stySubst ty var replacement = go ty
  where
    go (AppST f x) = AppST (go f) (go x)
    go c@(ConST _) = c
    go v@(VarST x)
      | x == var  = replacement
      | otherwise = v
-- | Apply a list of variable substitutions to a type. Substitutions are
-- applied right to last-first order, as by 'foldr'.
styReduce :: [(Name , STy)] -> STy -> STy
styReduce substs ty = foldr step ty substs
  where
    step (var , val) acc = stySubst acc var val
-- | Split a type application into its head and its arguments in order,
-- e.g. @T a b@ becomes @(T, [a, b])@. A non-application is its own head.
styFlatten :: STy -> (STy , [STy])
styFlatten = go []
  where
    -- Arguments are peeled off from the right, so consing keeps them in order.
    go acc (AppST f x) = go (x : acc) f
    go acc hd          = (hd , acc)
-- | Reify a name and return its type-constructor declaration. Fails for
-- anything that is not a 'TyConI'.
reifyDec :: Name -> Q Dec
reifyDec name = reify name >>= \info -> case info of
  TyConI dec -> return dec
  _          -> fail $ show name ++ " is not a declaration"
-- | Name bound by a type variable binder, with any kind annotation ignored.
argInfo :: TyVarBndr -> Name
argInfo bndr = case bndr of
  PlainTV n    -> n
  KindedTV n _ -> n
-- | Extract 'DTI' from a @data@ or @newtype@ declaration. Type synonyms and
-- all other declarations are rejected.
decInfo :: Dec -> Q (DTI STy)
decInfo dec = case dec of
  TySynD {} ->
    fail "Type Synonyms not supported"
  DataD _ name args _ cons _ ->
    ADT name (argInfo <$> args) <$> traverse conInfo cons
  NewtypeD _ name args _ con _ ->
    New name (argInfo <$> args) <$> conInfo con
  _ ->
    fail "Only type declarations are supported"
-- | Extract 'CI' from a TH constructor. Infix constructors also record the
-- fixity reported by 'reifyFixity', or 'defaultFixity' when none is
-- declared. Existential and GADT constructors are rejected.
conInfo :: Con -> Q (CI STy)
conInfo con = case con of
  NormalC name fields ->
    Normal name <$> traverse (convertType . snd) fields
  RecC name fields ->
    Record name <$> traverse (\(fld , _ , t) -> (,) fld <$> convertType t) fields
  InfixC (_ , lhs) name (_ , rhs) -> do
    found <- reifyFixity name
    let fx = maybe defaultFixity id found
    Infix name fx <$> convertType lhs <*> convertType rhs
  ForallC {} ->
    fail "Existentials not supported"
#if MIN_VERSION_template_haskell(2,11,0)
  GadtC {} ->
    fail "GADTs not supported"
  RecGadtC {} ->
    fail "GADTs not supported"
#endif
-- | Instantiate a declaration's type parameters with concrete argument types.
-- Parameters are paired positionally with arguments, substituted into every
-- field, and the parameter list is then emptied.
dtiReduce :: DTI STy -> [STy] -> DTI STy
dtiReduce dti parms = case dti of
  ADT name args cons -> ADT name [] (ciReduce (zip args parms) <$> cons)
  New name args con  -> New name [] (ciReduce (zip args parms) con)
-- | Apply the given substitutions to every field type of a constructor.
-- The derived 'Functor' instance of 'CI' visits exactly the field types.
ciReduce :: [(Name , STy)] -> CI STy -> CI STy
ciReduce parms = fmap (styReduce parms)
-- | A field type after reification: either a reference to another member of
-- the family (by index) or an opaque type, named by its K-constructor as
-- listed in the @consTable@ of 'reifySTy' (e.g. @KInt@).
data IK
  = AtomI Int   -- ^ recursive position: index of a family member
  | AtomK Name  -- ^ opaque type: name of its K-constructor
  deriving (Eq , Show)
-- | Eliminator for 'IK': handle a family index or an opaque-type name.
ikElim :: (Int -> a) -> (Name -> a) -> IK -> a
ikElim onIdx onOpaque ik = case ik of
  AtomI i -> onIdx i
  AtomK k -> onOpaque k
-- | State threaded through reification.
data Idxs
  = Idxs { idxsNext :: Int
           -- ^ next unused family index
         , idxsMap  :: M.Map STy (Int , Maybe (DTI IK))
           -- ^ index of every type seen so far, plus its datatype
           --   information once it has been reified
         }
  deriving (Show)
-- | Modify the index map of an 'Idxs', leaving the counter untouched.
onMap :: (M.Map STy (Int , Maybe (DTI IK)) -> M.Map STy (Int , Maybe (DTI IK)))
      -> Idxs -> Idxs
onMap f idxs = idxs { idxsMap = f (idxsMap idxs) }
-- | State monad transformer carrying the index table.
type IdxsM = StateT Idxs

-- | Run an 'IdxsM' computation from an empty table with counter 0, and
-- return its result along with the final table.
runIdxsM :: (Monad m) => IdxsM m a -> m (a , Idxs)
runIdxsM action = runStateT action (Idxs { idxsNext = 0 , idxsMap = M.empty })

-- | The monad reification runs in: 'IdxsM' over 'Q'.
type M = IdxsM Q
-- | Index of a type. If the type has not been seen yet, the next free index
-- is allocated and the type is recorded without datatype information.
indexOf :: (Monad m) => STy -> IdxsM m Int
indexOf name = do
  Idxs next table <- get
  case M.lookup name table of
    Just (i , _) -> return i
    Nothing -> do
      put (Idxs (next + 1) (M.insert name (next , Nothing) table))
      return next
-- | Attach datatype information to a type, allocating an index first if
-- needed. Any previous information for that type is overwritten.
register :: (Monad m) => STy -> DTI IK -> IdxsM m ()
register ty info = do
  _ <- indexOf ty
  modify (onMap (M.adjust (\(i , _) -> (i , Just info)) ty))
-- | Look up the table entry (index and optional datatype information) of a
-- type.
lkup :: (Monad m) => STy -> IdxsM m (Maybe (Int , Maybe (DTI IK)))
lkup ty = gets (M.lookup ty . idxsMap)
-- | Look up only the index of a type, if it has one.
lkupInfo :: (Monad m) => STy -> IdxsM m (Maybe Int)
lkupInfo ty = do
  entry <- lkup ty
  return (fst <$> entry)
-- | Look up the datatype information of a type, if it has been reified.
lkupData :: (Monad m) => STy -> IdxsM m (Maybe (DTI IK))
lkupData ty = (>>= snd) <$> lkup ty
-- | Whether a type already has datatype information registered.
hasData :: (Monad m) => STy -> IdxsM m Bool
hasData ty = do
  found <- lkupData ty
  return $ case found of
    Just _  -> True
    Nothing -> False
-- | Reify the datatype behind a closed, fully applied type and register it,
-- together with every type reachable from its fields, in the 'Idxs' state.
--
-- Field types without datatype information are collected through a
-- 'WriterT' and reified afterwards. A type that already has information is
-- not reified again.
reifySTy :: STy -> M ()
reifySTy sty
  = do _ <- indexOf sty  -- make sure 'sty' owns an index
       -- A type can be queued more than once (it may occur in several fields
       -- before being processed). Reifying it again would only recompute and
       -- re-register identical information, so skip it.
       done <- hasData sty
       unless done $ uncurry go (styFlatten sty)
  where
    go :: STy -> [STy] -> M ()
    go (ConST name) args
      = do dec <- lift (reifyDec name >>= decInfo)
           -- instantiate the declaration's parameters with 'args'
           let res = dtiReduce dec args
           (final , todo) <- runWriterT $ dtiMapM convertSTy res
           register sty final
           mapM_ reifySTy todo
    -- 'styFlatten' never yields an application as head, so this catches a
    -- type-variable head. Report it instead of crashing on a missing pattern.
    go hd _
      = fail $ "reifySTy: cannot reify " ++ show sty
            ++ ", its head " ++ show hd ++ " is not a type constructor"

    -- Convert one field type. A self-reference or a non-opaque closed type
    -- becomes 'AtomI'. A known opaque type becomes 'AtomK'. A closed type
    -- not yet reified is also emitted to the writer so it gets reified later.
    convertSTy :: STy -> WriterT [STy] M IK
    convertSTy ty
      | ty == sty = AtomI <$> lift (indexOf ty)
      | isClosed ty
      = case makeCons ty of
          Just k  -> return (AtomK k)
          Nothing -> do ix <- lift (indexOf ty)
                        hasDti <- lift (hasData ty)
                        when (not hasDti) (tell [ty])
                        return (AtomI ix)
      | otherwise
      = fail $ "I can't convert type variable " ++ show ty
            ++ " when converting " ++ show sty

    -- Opaque types are recognised by constructor name via 'consTable'.
    makeCons :: STy -> Maybe Name
    makeCons (ConST n) = M.lookup n consTable
    makeCons _ = Nothing

    -- Opaque base types and the K-constructor names representing them.
    consTable :: M.Map Name Name
    consTable = M.fromList . map (id *** mkName)
      $ [ ( ''Int     , "KInt")
        , ( ''Char    , "KChar")
        , ( ''Integer , "KInteger")
        , ( ''Float   , "KFloat")
        , ( ''Bool    , "KBool")
        , ( ''String  , "KString")
        , ( ''Double  , "KDouble")
        ]
-- | Every family member: its type, its family index and its converted
-- datatype information. 'deriveFamily' passes this list sorted by index.
type Input = [(STy , Int , DTI IK)]
-- | Build a promoted type-level list @'[f x1, .., f xn]@ from a list.
tlListOf :: (a -> Type) -> [a] -> Type
tlListOf render = go
  where
    go []       = PromotedNilT
    go (x : xs) = AppT (AppT PromotedConsT (render x)) (go xs)
-- | Type-level Peano natural for a non-negative Int: @S (S .. Z)@.
int2Type :: Int -> Type
int2Type n
  | n == 0    = tyZ
  | otherwise = AppT tyS (int2Type (n - 1))
-- | Name of the type synonym for index @i@: @D\<i\>_@.
int2TySynName :: Int -> Name
int2TySynName i = mkName (concat ["D", show i, "_"])
-- | Singleton-natural pattern for a non-negative Int: @SS (SS .. SZ)@.
int2SNatPat :: Int -> Pat
int2SNatPat n
  | n == 0    = ConP (mkName "SZ") []
  | otherwise = ConP (mkName "SS") [int2SNatPat (n - 1)]
-- | Promoted constructors used in the generated type-level codes.
-- 'tyS' and 'tyZ' build Peano naturals (see 'int2Type'). 'tyI' and 'tyK'
-- tag recursive and opaque atoms (see 'ik2Codes').
tyS, tyZ, tyI, tyK :: Type
tyS = PromotedT (mkName "S")
tyZ = PromotedT (mkName "Z")
tyI = PromotedT (mkName "I")
tyK = PromotedT (mkName "K")
-- | The family's codes: a type-level list of every member's code, in input
-- order.
inputToCodes :: Input -> Q Type
inputToCodes input = return (tlListOf dti2Codes (map (\(_ , _ , dti) -> dti) input))
-- | Code of one datatype: a type-level list with one entry per constructor.
dti2Codes :: DTI IK -> Type
dti2Codes dti = tlListOf ci2Codes (dti2ci dti)
-- | Code of one constructor: a type-level list of its field atoms.
ci2Codes :: CI IK -> Type
ci2Codes ci = tlListOf ik2Codes (ci2ty ci)
-- | Code of one field atom: @I n@ for a family index, @K k@ for an opaque
-- type.
ik2Codes :: IK -> Type
ik2Codes = ikElim (AppT tyI . int2Type) (AppT tyK . PromotedT)
-- | Generate type synonyms @D0_@, @D1_@, ... (see 'int2TySynName'), each
-- equal to the corresponding Peano natural. Generation runs up to the
-- largest family index ('AtomI') referenced by any constructor field.
inputToTySynNums :: Input -> Q [Dec]
inputToTySynNums input
  = return $ map genTySynNum [0..maxI]
  where
    -- Seeding with -1 makes 'maximum' total: empty input yields no synonyms
    -- instead of throwing "empty list". For non-empty input every
    -- 'localMax' is >= 0, so the result is unchanged.
    maxI = maximum ((-1) : map (localMax . third) input)
    third (_ , _ , x) = x
    localMax :: DTI IK -> Int
    localMax = foldr (\ci aux -> aux `max` getMaxIdx (ci2ty ci)) 0 . dti2ci
    getMaxIdx :: [IK] -> Int
    getMaxIdx = foldr (ikElim max (const id)) 0
    genTySynNum i = TySynD (int2TySynName i) [] (int2Type i)
-- | The family itself: a type-level list of the member types, in input order.
inputToFam :: Input -> Q Type
inputToFam input = return (tlListOf trevnocType [ ty | (ty , _ , _) <- input ])
-- | Derive an identifier from a type by concatenating the base names of all
-- its parts. The list constructor is spelled @List@. A tuple constructor
-- becomes @Tup@ followed by its comma count, so a pair gives @Tup1@.
styToName :: STy -> Name
styToName sty = mkName (styFold (++) nameBase (rename . nameBase) sty)
  where
    rename :: String -> String
    rename ('(' : ',' : rest) = "Tup" ++ show (length rest)
    rename other
      | other == "[]" = "List"
      | otherwise     = other
-- | Transform the unqualified part of a name, producing a fresh 'mkName'.
onBaseName :: (String -> String) -> Name -> Name
onBaseName f name = mkName (f (nameBase name))
-- | Name of the generated codes synonym for a root type: @Codes\<Type\>@.
codesName :: STy -> Q Name
codesName sty = return (onBaseName ("Codes" ++) (styToName sty))
-- | Name of the generated family synonym for a root type: @Fam\<Type\>@.
familyName :: STy -> Q Name
familyName sty = return (onBaseName ("Fam" ++) (styToName sty))
-- | Piece 1: type synonyms for the family and its codes.
-- Returns the family synonym first, then the codes synonym.
genPiece1 :: STy -> Input -> Q [Dec]
genPiece1 first ls = do
  codesN <- codesName first
  codesT <- inputToCodes ls
  famN   <- familyName first
  famT   <- inputToFam ls
  return [ TySynD famN [] famT , TySynD codesN [] codesT ]
-- | Name of the index pattern synonym for a type: @Idx\<Type\>@.
idxPatSynName :: STy -> Name
idxPatSynName sty = styToName (AppST (ConST (mkName "Idx")) sty)
-- | Pattern using the index pattern synonym of a type (no arguments).
idxPatSyn :: STy -> Pat
idxPatSyn sty = ConP (idxPatSynName sty) []
-- | Name of the constructor pattern synonym:
-- @Pat\<dtiIx\>\<ConstructorName\>@. Characters in the constructor name that
-- are not alphanumeric (e.g. in operator constructors) are replaced by their
-- decimal code points.
htPatSynName :: Int -> CI IK -> Name
htPatSynName dtiIx ci =
  mkName ("Pat" ++ show dtiIx ++ concatMap escape (nameBase (ciName ci)))
  where
    escape c
      | isAlphaNum c = [c]
      | otherwise    = show (ord c)
-- | The constructor pattern synonym used as an expression (it is bidirectional).
htPatSynExp :: Int -> CI IK -> Q Exp
htPatSynExp dtiIx ci = return (ConE (htPatSynName dtiIx ci))
-- | Declare @pattern Idx\<Type\> = SS (.. SZ)@, with the singleton natural
-- matching the type's family index.
genIdxPatSyn :: STy -> Int -> Q Dec
genIdxPatSyn sty ix = return decl
  where
    decl = PatSynD (idxPatSynName sty) (PrefixPatSyn []) ImplBidir (int2SNatPat ix)
-- | Generate, for every constructor of every family member, a bidirectional
-- pattern synonym (signature plus definition). It injects the constructor's
-- product of fields into the 'NS' sum for its datatype, at the constructor's
-- position, using nested 'There's ending in 'Here'.
genHereTherePatSyn :: STy -> Input -> Q [Dec]
genHereTherePatSyn first ls
  = flat . concat <$> mapM (\(_ , ix , dti) -> genHereThereFor ix dti) ls
  where
    -- Flatten (signature, definition) pairs. The foldl' prepends, so pairs
    -- end up in reverse order, but each signature stays right before its
    -- definition.
    flat = foldl' (\ac (x , y) -> x:y:ac) []
    third (_ , _, x) = x
    famName = ConT <$> familyName first
    -- @inj n p@ is @There (There .. (Here p))@ with @n@ 'There's.
    inj :: Int -> Q Pat -> Q Pat
    inj 0 p = [p| Here $p |]
    inj n p = [p| There ( $(inj (n-1) p) ) |]
    -- One (signature, definition) pair per constructor; 'ix' is the
    -- constructor's position within its datatype.
    genHereThereFor :: Int -> DTI IK -> Q [(Dec , Dec)]
    genHereThereFor dtiIx dti
      = do let dtiCode = dti2Codes dti
           let cisIx = zip [0..] (dti2ci dti)
           forM cisIx $ \ (ix , ci)
             -> (,) <$> genHT_decl dtiCode dtiIx ix ci
                    <*> genHT_def dtiIx ix ci
    -- Pattern-synonym signature: from the constructor's product
    -- (@PoA Singl (El Fam) code@) to the datatype's sum.
    genHT_decl dtiCode dtiIx ix ci
      = PatSynSigD (htPatSynName dtiIx ci)
          <$> [t| PoA Singl (El $famName) $(return $ ci2Codes ci)
                  -> NS (PoA Singl (El $famName)) $(return dtiCode) |]
    -- Pattern-synonym definition: @pattern PatN d = There (.. (Here d))@.
    genHT_def dtiIx ix ci
      = do var <- newName "d"
           PatSynD (htPatSynName dtiIx ci) (PrefixPatSyn [var]) ImplBidir
             <$> inj ix (return $ VarP var)
-- | Piece 2: pattern synonyms. First one index synonym per family member,
-- then the per-constructor injection synonyms.
genPiece2 :: STy -> Input -> Q [Dec]
genPiece2 first ls = do
  idxSyns <- forM ls (\(sty , ix , _) -> genIdxPatSyn sty ix)
  htSyns  <- genHereTherePatSyn first ls
  return (idxSyns ++ htSyns)
-- | Piece 3: the @Family Singl Fam Codes@ instance. 'sfrom'' comes from
-- 'genPiece3_1' and 'sto'' from 'genPiece3_2'.
--
-- The declaration quote yields a list. It is pattern-matched instead of
-- using partial 'head', so an unexpected shape fails with a message rather
-- than an opaque "empty list" error.
genPiece3 :: STy -> Input -> Q Dec
genPiece3 first ls
  = do decs <- [d| instance Family Singl
                     $(ConT <$> familyName first)
                     $(ConT <$> codesName first)
                     where sfrom' = $(genPiece3_1 ls)
                           sto'   = $(genPiece3_2 ls) |]
       case decs of
         [inst] -> return inst
         _      -> fail "genPiece3: expected exactly one instance declaration"
-- | One @sfrom'@ alternative for a constructor. The pattern is
-- @El (Con x1 .. xn)@ and the body is @Rep (PatN (f1 :* .. :* NP0))@:
--   * recursive fields become @NA_I (El x)@;
--   * opaque fields become @NA_K (S.. x)@.
-- 'dtiIx' selects the datatype's constructor pattern synonyms.
ci2PatExp :: Int -> CI IK -> Q (Pat , Exp)
ci2PatExp dtiIx ci
  = do (vars , pat) <- ci2Pat ci
       bdy <- [e| Rep $(inj $ genBdy (zip vars (ci2ty ci))) |]
       return (ConP (mkName "El") [pat] , bdy)
  where
    -- Wrap the product in the constructor's injection pattern synonym.
    inj :: Q Exp -> Q Exp
    inj e = [e| $(htPatSynExp dtiIx ci) $e |]
    -- Build the n-ary product of fields, terminated by NP0.
    genBdy :: [(Name , IK)] -> Q Exp
    genBdy [] = [e| NP0 |]
    genBdy (x : xs) = [e| $(mkHead x) :* ( $(genBdy xs) ) |]
    mkHead (x , AtomI _) = [e| NA_I (El $(return (VarE x))) |]
    mkHead (x , AtomK k) = [e| NA_K $(return (AppE (ConE (mkK k)) (VarE x))) |]
    -- Replace the first character of the K-constructor name with 'S'
    -- (e.g. KInt becomes SInt).
    mkK k = mkName $ 'S':tail (nameBase k)
-- | One @sto'@ alternative for a constructor; the mirror image of
-- 'ci2PatExp'. The pattern is @Rep (PatN (f1 :* .. :* NP0))@ and the body
-- rebuilds @El (Con y1 .. yn)@.
ci2ExpPat :: Int -> CI IK -> Q (Pat , Exp)
ci2ExpPat dtiIx ci
  = do (vars , exp) <- ci2Exp ci
       pat <- [p| Rep $(inj $ genBdy (zip vars (ci2ty ci))) |]
       return (pat , AppE (ConE $ mkName "El") exp)
  where
    -- Match through the constructor's injection pattern synonym.
    inj :: Q Pat -> Q Pat
    inj e = ConP (htPatSynName dtiIx ci) . (:[]) <$> e
    -- Match the n-ary product of fields, terminated by NP0.
    genBdy :: [(Name , IK)] -> Q Pat
    genBdy [] = [p| NP0 |]
    genBdy (x : xs) = [p| $(mkHead x) :* ( $(genBdy xs) ) |]
    mkHead (x , AtomI _) = [p| NA_I (El $(return (VarP x))) |]
    mkHead (x , AtomK k) = [p| NA_K $(return (ConP (mkK k) [VarP x])) |]
    -- Replace the first character of the K-constructor name with 'S'
    -- (e.g. KInt becomes SInt).
    mkK k = mkName $ 'S':tail (nameBase k)
-- | A case alternative @pat -> bdy@ with no guards and no local bindings.
match :: Pat -> Exp -> Match
match p body = Match p unguarded noDecls
  where
    unguarded = NormalB body
    noDecls   = []
-- | Append a final catch-all alternative @_ -> error "matchAll"@.
matchAll :: [Match] -> [Match]
matchAll ms = ms ++ [fallback]
  where
    fallback = match WildP (AppE (VarE (mkName "error")) (LitE (StringL "matchAll")))
-- | Body of @sfrom'@: a @\\case@ on the type index. For each index it is an
-- inner @\\case@ with one alternative per constructor (see 'ci2PatExp').
-- No catch-all alternatives are added.
genPiece3_1 :: Input -> Q Exp
genPiece3_1 input = LamCaseE <$> forM input perType
  where
    perType :: (STy , Int , DTI IK) -> Q Match
    perType (sty , ix , dti) = do
      alts <- traverse (ci2PatExp ix) (dti2ci dti)
      return (match (idxPatSyn sty) (LamCaseE [ match p e | (p , e) <- alts ]))
-- | Body of @sto'@: like 'genPiece3_1' but built with 'ci2ExpPat'. Both the
-- outer and each inner @\\case@ end with a 'matchAll' catch-all.
genPiece3_2 :: Input -> Q Exp
genPiece3_2 input = LamCaseE . matchAll <$> forM input perType
  where
    perType :: (STy , Int , DTI IK) -> Q Match
    perType (sty , ix , dti) = do
      alts <- traverse (ci2ExpPat ix) (dti2ci dti)
      let inner = LamCaseE (matchAll [ match p e | (p , e) <- alts ])
      return (match (idxPatSyn sty) inner)
-- | Piece 4: one @Meta.HasDatatypeInfo@ instance per family member. Each
-- instance is indexed by that member's type-level natural (see 'int2Type').
genPiece4 :: STy -> Input -> Q [Dec]
genPiece4 first ls = concat <$> mapM genDatatypeInfoInstance ls
  where
    -- The instance's 'datatypeInfo' ignores both arguments and returns the
    -- metadata expression built by 'genInfo'.
    genDatatypeInfoInstance :: (STy , Int , DTI IK) -> Q [Dec]
    genDatatypeInfoInstance (sty , idx , dti)
      = [d| instance Meta.HasDatatypeInfo Singl $(ConT <$> familyName first)
                       $(ConT <$> codesName first)
                       $(return (int2Type idx))
              where datatypeInfo _ _ = $(genInfo sty dti) |]
-- | String literal with the module a name belongs to, or @""@ if unknown.
genMod :: Name -> Q Exp
genMod name = case nameModule name of
  Just modName -> strlit modName
  Nothing      -> strlit ""
-- | A string literal expression.
strlit :: String -> Q Exp
strlit s = return (LitE (StringL s))
-- | Metadata expression for a (possibly applied) type name. Applications
-- become @Meta.:\@:@ and variables and constructors both become @Meta.Name@.
genDatatypeName :: STy -> Q Exp
genDatatypeName = styFold applied named named
  where
    applied lhs rhs = [e| ( $lhs Meta.:@: $rhs ) |]
    named n = [e| Meta.Name $(strlit (nameBase n)) |]
-- | Metadata expression for a datatype: @Meta.ADT@ with all its constructors,
-- or @Meta.New@ with its single constructor.
genInfo :: STy -> DTI IK -> Q Exp
genInfo sty dti = case dti of
  ADT name _ cis ->
    [e| Meta.ADT $(genMod name) $(genDatatypeName sty) $(genConInfoNP cis) |]
  New name _ ci ->
    [e| Meta.New $(genMod name) $(genDatatypeName sty) $(genConInfo ci) |]
-- | Metadata expression for a constructor: @Meta.Record@ with field names,
-- @Meta.Constructor@, or @Meta.Infix@ with associativity and precedence.
genConInfo :: CI IK -> Q Exp
genConInfo ci = case ci of
    Record conname fields ->
      [e| Meta.Record $(nameLit conname) $(genFieldInfo (map fst fields)) |]
    Normal conname _ ->
      [e| Meta.Constructor $(nameLit conname) |]
    Infix conname fixity _ _ ->
      [e| Meta.Infix $(nameLit conname) $(assocExp fixity) $(precExp fixity) |]
  where
    nameLit n = strlit (nameBase n)
    assocExp (Fixity _ dir) = case dir of
      InfixL -> [e| Meta.LeftAssociative |]
      InfixR -> [e| Meta.RightAssociative |]
      InfixN -> [e| Meta.NotAssociative |]
    precExp (Fixity prec _) = return (LitE (IntegerL (fromIntegral prec)))
-- | N-ary product of @Meta.FieldInfo@ entries, one per record field name,
-- terminated by @NP0@.
genFieldInfo :: [ FieldName ] -> Q Exp
genFieldInfo = foldr cons [e| NP0 |]
  where
    cons fld rest = [e| Meta.FieldInfo $(strlit (nameBase fld)) :* ( $rest ) |]
-- | N-ary product of constructor metadata, terminated by @NP0@.
genConInfoNP :: [ CI IK ] -> Q Exp
genConInfoNP = foldr cons [e| NP0 |]
  where
    cons ci rest = [e| $(genConInfo ci) :* ( $rest ) |]
-- | All generated declarations for a family, in this order: type synonyms,
-- pattern synonyms, the 'Family' instance, and the datatype-info instances.
genFamily :: STy -> Input -> Q [Dec]
genFamily first ls = do
  synonyms   <- genPiece1 first ls
  patterns   <- genPiece2 first ls
  familyInst <- genPiece3 first ls
  metaInsts  <- genPiece4 first ls
  return (concat [synonyms , patterns , [familyInst] , metaInsts])
-- | Debug declarations: for every entry, a top-level binding
-- @tyInfo_\<index\>@ holding the 'show'n datatype information as a string.
-- The root type argument is ignored.
genFamilyDebug :: STy -> [(STy , Int , DTI IK)] -> Q [Dec]
genFamilyDebug _ entries = concat <$> traverse debugDecl entries
  where
    debugDecl :: (STy , Int , DTI IK) -> Q [Dec]
    debugDecl (_ , ix , dti) =
      [d| $(return (VarP (mkName ("tyInfo_" ++ show ix)))) = $(liftString (show dti)) |]