{-# LANGUAGE TemplateHaskellQuotes #-}
module Optics.TH.Internal.Sum
( makePrisms
, makePrismLabels
, makeClassyPrisms
, makeDecPrisms
) where
import Data.Char
import Data.List
import Data.Maybe
import Data.Traversable
import Language.Haskell.TH
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Language.Haskell.TH.Datatype as D
import Data.Set.Optics
import Language.Haskell.TH.Optics.Internal
import Optics.Core hiding (cons)
import Optics.TH.Internal.Utils
-- | Generate, for every constructor of the named data type, a top-level
-- optic named by 'prismName' (e.g. @_Foo@ for constructor @Foo@). The optic
-- is an 'Iso', 'Prism' or 'Review' depending on the constructor.
makePrisms :: Name -> DecsQ
makePrisms typeName = makePrisms' True typeName
-- | Generate a class @As<Type>@ for the named data type, plus an instance of
-- that class for the data type itself. The class has one method for the
-- type and one method per constructor.
makeClassyPrisms :: Name -> DecsQ
makeClassyPrisms typeName = makePrisms' False typeName
-- | Generate a 'LabelOptic' instance for every constructor of the named data
-- type, labelled with the constructor's 'prismName' (e.g. @"_Foo"@ for
-- @Foo@). Constructors that only admit a 'Review' get no instance. The @a@
-- and @b@ parameters of each instance head, and two extra context entries,
-- come from 'eqSubst'.
makePrismLabels :: Name -> DecsQ
makePrismLabels typeName = do
  info <- D.reifyDatatype typeName
  let cons = map normalizeCon (D.datatypeCons info)
  mbInsts <- for cons (labelInstance info cons)
  pure (catMaybes mbInsts)
  where
    -- Instance for one constructor, or 'Nothing' if it is review-only.
    labelInstance :: D.DatatypeInfo -> [NCon] -> NCon -> Q (Maybe Dec)
    labelInstance info cons con = do
      stab <- computeOpticType labelConfig (D.datatypeType info) cons con
      let Stab cx otype s t a b = stab
      case opticTag otype of
        Nothing -> pure Nothing
        Just tag -> do
          (a', cxtA) <- eqSubst a "a"
          (b', cxtB) <- eqSubst b "b"
          let label = nameBase (prismName (view nconName con))
              instHead = conAppsT ''LabelOptic
                [LitT (StrTyLit label), ConT tag, s, t, a', b']
              -- Newtypes reuse 'coerced'; other types get an explicit optic.
              body
                | D.datatypeVariant info == D.Newtype = varE 'coerced
                | otherwise = makeConOpticExp stab cons con
              method = valD (varP 'labelOptic) (normalB body) []
          Just <$> instanceD (pure (cx ++ [cxtA, cxtB])) (pure instHead)
                             (method : inlinePragma 'labelOptic)

    -- Optic-kind tag placed in the instance head; reviews get no instance.
    opticTag :: OpticType -> Maybe Name
    opticTag IsoType = Just ''An_Iso
    opticTag PrismType = Just ''A_Prism
    opticTag ReviewType = Nothing
-- | Shared driver for 'makePrisms' and 'makeClassyPrisms'.
--
-- With @True@, the optics become top-level definitions. With @False@, they
-- are generated as a class named after the reified data type.
makePrisms' :: Bool -> Name -> DecsQ
makePrisms' normal typeName =
  D.reifyDatatype typeName >>= \info ->
    makeConsPrisms info
                   (normalizeCon <$> D.datatypeCons info)
                   (if normal then Nothing else Just (D.datatypeName info))
-- | Like 'makePrisms' (given @True@) or 'makeClassyPrisms' (given @False@),
-- but starting from a data declaration rather than reifying a name.
makeDecPrisms :: Bool -> Dec -> DecsQ
makeDecPrisms normal dec =
  D.normalizeDec dec >>= \info ->
    makeConsPrisms info
                   (normalizeCon <$> D.datatypeCons info)
                   (if normal then Nothing else Just (D.datatypeName info))
-- | Emit declarations for the normalized constructors of a data type.
--
-- * 'Nothing': for every constructor, emit a closed type signature, a
--   definition and the declarations from 'inlinePragma'. Newtypes use
--   'coerced' as the body.
-- * @Just typeName@: emit a class @As<typeName>@ and its instance for the
--   data type instead.
makeConsPrisms :: D.DatatypeInfo -> [NCon] -> Maybe Name -> DecsQ
makeConsPrisms info cons Nothing = concat <$> traverse topLevel cons
  where
    ty = D.datatypeType info
    isNewtype = D.datatypeVariant info == D.Newtype

    -- Signature, binding and pragmas for a single constructor's optic.
    topLevel con = do
      stab <- computeOpticType defaultConfig ty cons con
      let optName = prismName (view nconName con)
          optBody
            | isNewtype = varE 'coerced
            | otherwise = makeConOpticExp stab cons con
          signature = sigD optName (close (stabToType stab))
          binding = valD (varP optName) (normalB optBody) []
      sequenceA (signature : binding : inlinePragma optName)
makeConsPrisms info cons (Just typeName) = do
  classDec <- makeClassyPrismClass ty className methodName cons
  instDec <- makeClassyPrismInstance ty className methodName cons
  pure [classDec, instDec]
  where
    ty = D.datatypeType info
    className = mkName ("As" ++ nameBase typeName)
    methodName = prismName typeName
-- | Options controlling which optic 'computePrismType' describes.
data StabConfig = StabConfig
  { scAllowPhantomsChange :: Bool
    -- ^ Whether data-type variables that occur in no constructor may be
    -- renamed (and so changed) by the optic.
  , scAllowIsos :: Bool
    -- ^ Whether a constructor with no sibling constructors may get an 'Iso'
    -- instead of a 'Prism'.
  }

-- | Used for top-level prisms: phantoms may change and isos are allowed.
defaultConfig :: StabConfig
defaultConfig = StabConfig
  { scAllowIsos           = True
  , scAllowPhantomsChange = True
  }

-- | Used for classy prisms: like 'defaultConfig' but never produces an 'Iso'.
classyConfig :: StabConfig
classyConfig = defaultConfig { scAllowIsos = False }

-- | Used for prism labels: like 'defaultConfig' but phantom variables stay
-- fixed.
labelConfig :: StabConfig
labelConfig = defaultConfig { scAllowPhantomsChange = False }
-- | The kind of optic generated for a constructor.
data OpticType = IsoType | PrismType | ReviewType

-- | Description of a generated optic. Fields, in order: the context, the optic
-- kind, and the @s@, @t@, @a@, @b@ type parameters (as used by 'stabToType').
data Stab = Stab Cxt OpticType Type Type Type Type
-- | Make a description non-type-changing by using @t@ for @s@ and @b@ for @a@.
simplifyStab :: Stab -> Stab
simplifyStab stab = case stab of
  Stab context otype _s t _a b -> Stab context otype t t b b
-- | Whether a description is non-type-changing, i.e. @s == t@ and @a == b@.
stabSimple :: Stab -> Bool
stabSimple (Stab _ _ s t a b) = and [s == t, a == b]
-- | Render an optic description as a type.
--
-- Only the type variables of the context are quantified here, in order of
-- first appearance; other free variables are left unquantified. Non-type-
-- changing isos and prisms use the primed synonyms.
stabToType :: Stab -> Type
stabToType stab@(Stab cx otype s t a b) = ForallT binders cx opticTy
  where
    simple = stabSimple stab
    binders = PlainTV <$> nub (toListOf typeVars cx)
    opticTy = case otype of
      IsoType
        | simple -> ''Iso' `conAppsT` [s, a]
        | otherwise -> ''Iso `conAppsT` [s, t, a, b]
      PrismType
        | simple -> ''Prism' `conAppsT` [s, a]
        | otherwise -> ''Prism `conAppsT` [s, t, a, b]
      ReviewType -> ''Review `conAppsT` [t, b]
-- | The optic kind stored in a description.
stabType :: Stab -> OpticType
stabType stab = case stab of
  Stab _ kind _ _ _ _ -> kind
-- | Describe the optic for one constructor of type @t@.
--
-- A constructor that binds its own type variables only gets a 'Review'.
-- Otherwise a prism/iso is computed, with @con@ removed from @cons@ so that
-- only its siblings are passed on.
computeOpticType :: StabConfig -> Type -> [NCon] -> NCon -> Q Stab
computeOpticType conf t cons con
  | null (_nconVars con) =
      computePrismType conf t cx (delete con cons) con
  | otherwise =
      computeReviewType t cx (view nconTypes con)
  where
    cx = view nconCxt con
-- | Describe a 'Review' into type @t@. The field types are tupled into @b@;
-- @s@ and @a@ are simply set to @t@ and @b@, since 'stabToType' only uses
-- @t@ and @b@ for reviews.
computeReviewType :: Type -> Cxt -> [Type] -> Q Stab
computeReviewType t cx tys =
  (\b -> Stab cx ReviewType t t b b) <$> toTupleT (map pure tys)
-- | Describe a 'Prism' (or 'Iso') for a constructor of type @t@.
--
-- Type variables occurring in @cons@ (normally the sibling constructors)
-- stay fixed. The remaining variables of @t@ get fresh names in @s@ and
-- @a@, which lets the optic change them. Variables occurring in no
-- constructor at all (phantoms) are renamed only when
-- 'scAllowPhantomsChange' is set. An 'IsoType' is produced when @cons@ is
-- empty and 'scAllowIsos' is set.
computePrismType :: StabConfig -> Type -> Cxt -> [NCon] -> NCon -> Q Stab
computePrismType conf t cx cons con = do
  sub <- sequenceA (M.fromSet (newName . nameBase) unbound)
  b <- toTupleT (map pure fieldTys)
  a <- toTupleT (map pure (substTypeVars sub fieldTys))
  pure (Stab cx opticKind (substTypeVars sub t) t a b)
  where
    fieldTys = view nconTypes con
    typeTyVars = setOf typeVars t
    fixed = setOf typeVars cons
    phantoms = typeTyVars S.\\ (setOf typeVars con `S.union` fixed)
    -- Variables that must keep their original names.
    frozen
      | scAllowPhantomsChange conf = fixed
      | otherwise = fixed `S.union` phantoms
    unbound = typeTyVars S.\\ frozen
    opticKind
      | null cons, scAllowIsos conf = IsoType
      | otherwise = PrismType
-- | Build the optic expression for a constructor, choosing the builder that
-- matches the optic kind in the description.
makeConOpticExp :: Stab -> [NCon] -> NCon -> ExpQ
makeConOpticExp stab cons con = build (stabType stab)
  where
    build IsoType = makeConIsoExp con
    build PrismType = makeConPrismExp stab cons con
    build ReviewType = makeConReviewExp con
-- | Build a 'prism' expression for the given constructor.
--
-- The reviewer rebuilds the constructor from a tuple of its fields. The
-- remitter depends on the optic and the number of constructors:
--
-- * For a non-type-changing optic ('stabSimple') on a type with more than
--   one constructor, the other constructors are passed through unchanged via
--   a wildcard match.
-- * Otherwise every constructor is matched explicitly. Non-target
--   constructors are rebuilt at the result type. This branch is also taken
--   for a single-constructor type: a wildcard match there would be redundant
--   and trigger GHC's overlapping-pattern warning in the splicing module.
makeConPrismExp ::
  Stab ->   -- ^ description of the optic being built
  [NCon] -> -- ^ constructors of the data type
  NCon ->   -- ^ the constructor the prism focuses on
  ExpQ
makeConPrismExp stab cons con = appsE [varE 'prism, reviewer, remitter]
  where
    ts = view nconTypes con
    fields = length ts
    conName = view nconName con
    reviewer = makeReviewer conName fields
    remitter
      | stabSimple stab, hasOtherCons = makeSimpleRemitter conName fields
      | otherwise = makeFullRemitter cons conName
    -- Callers in this module pass every constructor (including 'con'), so a
    -- one-element list means there is nothing for a wildcard to catch.
    hasOtherCons = case cons of
      [_] -> False
      _ -> True
-- | Build an 'iso' expression for a constructor. The forward direction
-- pattern-matches the constructor into a tuple of its fields; the backward
-- direction rebuilds it.
makeConIsoExp :: NCon -> ExpQ
makeConIsoExp con =
  appsE [varE 'iso, makeIsoRemitter name arity, makeReviewer name arity]
  where
    name = view nconName con
    arity = length (view nconTypes con)
-- | Build a review-only optic (via 'unto') that constructs the constructor
-- from a tuple of its fields.
makeConReviewExp :: NCon -> ExpQ
makeConReviewExp con = varE 'unto `appE` makeReviewer name arity
  where
    name = view nconName con
    arity = length (view nconTypes con)
-- | @\\(x1, ..., xn) -> Con x1 ... xn@, using 'toTupleP' for the pattern.
makeReviewer :: Name -> Int -> ExpQ
makeReviewer conName fields = do
  vars <- newNames "x" fields
  let pat = toTupleP (varP <$> vars)
      rebuilt = conE conName `appsE1` (varE <$> vars)
  lam1E pat rebuilt
-- | Non-type-changing remitter:
-- @\\x -> case x of { Con y1 ... yn -> Right (y1, ..., yn); _ -> Left x }@.
makeSimpleRemitter :: Name -> Int -> ExpQ
makeSimpleRemitter conName fields = do
  scrut <- newName "x"
  ys <- newNames "y" fields
  let hit = match (conP conName (varP <$> ys))
                  (normalB (conE 'Right `appE` toTupleE (varE <$> ys)))
                  []
      miss = match wildP (normalB (conE 'Left `appE` varE scrut)) []
  lam1E (varP scrut) (caseE (varE scrut) [hit, miss])
-- | Remitter that matches every constructor in @cons@. The target yields
-- @Right@ of its fields tupled; every other constructor is rebuilt from its
-- own fields and returned in @Left@.
makeFullRemitter :: [NCon] -> Name -> ExpQ
makeFullRemitter cons target = do
  scrut <- newName "x"
  lam1E (varP scrut) (caseE (varE scrut) (branch <$> cons))
  where
    branch (NCon conName _ _ fieldTys) = do
      ys <- newNames "y" (length fieldTys)
      let result
            | conName == target =
                conE 'Right `appE` toTupleE (varE <$> ys)
            | otherwise =
                conE 'Left `appE` (conE conName `appsE1` (varE <$> ys))
      match (conP conName (varP <$> ys)) (normalB result) []
-- | @\\(Con x1 ... xn) -> (x1, ..., xn)@, using 'toTupleE' for the body.
makeIsoRemitter :: Name -> Int -> ExpQ
makeIsoRemitter conName fields = do
  xs <- newNames "x" fields
  let pat = conP conName (varP <$> xs)
  lam1E pat (toTupleE (varE <$> xs))
-- | Build the class declaration used by classy prisms.
--
-- The class is parameterised by a fresh variable @r@ followed by the type
-- variables of the data type. It has a functional dependency @r -> vs@ when
-- the data type has type variables. It declares:
--
-- * @methodName :: Prism' r t@; and
-- * one method per constructor, named by 'prismName'. Each has a default
--   definition @methodName % <constructor method>@. Presumably the inner
--   method resolves to the data type's own instance — confirm against the
--   generated code.
makeClassyPrismClass ::
  Type ->   -- ^ the data type @t@
  Name ->   -- ^ class name
  Name ->   -- ^ name of the method focusing on the whole data type
  [NCon] -> -- ^ constructors of the data type
  DecQ
makeClassyPrismClass t className methodName cons =
  do r <- newName "r"
     let methodType = appsT (conT ''Prism') [varT r,return t]
     methodss <- traverse (mkMethod (VarT r)) cons'
     classD (cxt[]) className (map PlainTV (r : vs)) (fds r)
       ( sigD methodName methodType
       : map return (concat methodss)
       )
  where
    -- Signature and default definition for one constructor's method. Only the
    -- context, the optic kind and @b@ of the computed description are kept;
    -- both source and target types become @r@.
    -- NOTE: 'con' here carries the renamed ('prismName') constructor name, so
    -- the 'delete' inside 'computeOpticType' does not remove it from 'cons'.
    -- This looks harmless since @s@, @t@ and @a@ are discarded — confirm if
    -- this code changes.
    mkMethod r con =
      do Stab cx o _ _ _ b <- computeOpticType classyConfig t cons con
         let stab' = Stab cx o r r b b
             defName = view nconName con
             body    = appsE [varE '(%), varE methodName, varE defName]
         sequenceA
           [ sigD defName        (return (stabToType stab'))
           , valD (varP defName) (normalB body) []
           ]
    -- Constructors renamed to their method names.
    cons'         = map (over nconName prismName) cons
    vs            = S.toList (setOf typeVars t)
    fds r
      | null vs   = []
      | otherwise = [FunDep [r] vs]
-- | Build the instance of the classy-prism class for the data type @s@
-- itself.
--
-- The whole-type method is defined as 'castOptic' applied to 'equality'.
-- Each constructor method is the constructor's optic, computed with
-- 'classyConfig' and made non-type-changing by 'simplifyStab'.
makeClassyPrismInstance ::
  Type ->   -- ^ the data type @s@
  Name ->   -- ^ class name
  Name ->   -- ^ name of the method focusing on the whole data type
  [NCon] -> -- ^ constructors of the data type
  DecQ
makeClassyPrismInstance s className methodName cons =
  instanceD (cxt []) (pure instHead) (wholeMethod : map conMethod cons)
  where
    instHead = className `conAppsT` (s : map VarT (S.toList (setOf typeVars s)))
    wholeMethod =
      valD (varP methodName) (normalB (varE 'castOptic `appE` varE 'equality)) []
    conMethod con = do
      stab <- computeOpticType classyConfig s cons con
      valD (varP (prismName (view nconName con)))
           (normalB (makeConOpticExp (simplifyStab stab) cons con))
           []
-- | A data constructor in the normalized form used by this module (see
-- 'normalizeCon').
data NCon = NCon
  { _nconName :: Name
    -- ^ Constructor name.
  , _nconVars :: [Name]
    -- ^ Type variables bound by the constructor itself; when non-empty,
    -- 'computeOpticType' only produces a 'Review'.
  , _nconCxt :: Cxt
    -- ^ Constructor context.
  , _nconTypes :: [Type]
    -- ^ Field types, in order.
  }
  -- Eq is needed by 'delete' in 'computeOpticType'.
  deriving (Eq)
-- | Type variables of a constructor's context and fields. Variables bound by
-- the constructor itself are added to the set passed to 'typeVarsEx' for the
-- context and fields.
instance HasTypeVars NCon where
  typeVarsEx bound = traversalVL $ \f (NCon name vars context fields) ->
    let bound' = S.fromList vars `S.union` bound
        go xs = traverseOf (typeVarsEx bound') f xs
    in NCon name vars <$> go context <*> go fields
-- | Lens onto a constructor's name.
nconName :: Lens' NCon Name
nconName = lens _nconName (\c n -> c { _nconName = n })

-- | Lens onto a constructor's context.
nconCxt :: Lens' NCon Cxt
nconCxt = lens _nconCxt (\c cx -> c { _nconCxt = cx })

-- | Lens onto a constructor's field types.
nconTypes :: Lens' NCon [Type]
nconTypes = lens _nconTypes (\c tys -> c { _nconTypes = tys })
-- | Convert th-abstraction's constructor description into an 'NCon'.
normalizeCon :: D.ConstructorInfo -> NCon
normalizeCon info = NCon
  { _nconName  = D.constructorName info
  , _nconVars  = map D.tvName (D.constructorVars info)
  , _nconCxt   = D.constructorContext info
  , _nconTypes = D.constructorFields info
  }
-- | Name of the optic generated for a constructor (or type) name.
--
-- Names starting with an upper-case letter get a @_@ prefix (@Foo@ becomes
-- @_Foo@); anything else (e.g. an operator) gets a @.@ prefix. Errors on an
-- empty name base.
prismName :: Name -> Name
prismName n = case nameBase n of
  [] -> error "prismName: empty name base?"
  base@(c:_) -> mkName (prefix c : base)
  where
    prefix c
      | isUpper c = '_'
      | otherwise = '.'
-- | Universally quantify every type variable found in the type, with an
-- empty context.
close :: Type -> TypeQ
close t = forallT binders (cxt []) (pure t)
  where
    binders = PlainTV <$> S.toList (setOf typeVars t)