{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TemplateHaskellQuotes #-}
module Optics.TH.Internal.Product
( LensRules(..)
, FieldNamer
, DefName(..)
, ClassyNamer
, makeFieldOptics
, makeFieldOpticsForDec
, makeFieldOpticsForDec'
, makeFieldLabelsWith
, makeFieldLabelsForDec
, HasFieldClasses
) where
import Control.Monad
import Control.Monad.State
import Data.Either
import Data.List
import Data.Maybe
import Language.Haskell.TH
import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Traversable as T
import qualified Language.Haskell.TH.Datatype as D
import Data.Either.Optics
import Data.Tuple.Optics
import Data.Set.Optics
import Language.Haskell.TH.Optics.Internal
import Optics.Core hiding (cons)
import Optics.TH.Internal.Utils
-- | Traversal over the immediate 'Type' children of a 'Type'.
--
-- For a @forall@ it visits the kinds of kinded binders, every predicate of
-- the context and the body. For applications, kind signatures, infix and
-- unresolved infix applications and parenthesised types it visits the
-- component types. Any other constructor is returned untouched.
typeSelf :: Traversal' Type Type
typeSelf = traversalVL children
  where
    children :: Applicative f => (Type -> f Type) -> Type -> f Type
    children f ty = case ty of
      ForallT binders preds body ->
        ForallT
          <$> traverse (binderKind f) binders
          <*> traverse f preds
          <*> f body
      AppT fun arg       -> AppT <$> f fun <*> f arg
      SigT inner kind    -> SigT <$> f inner <*> f kind
      InfixT lhs op rhs  -> (\l r -> InfixT l op r) <$> f lhs <*> f rhs
      UInfixT lhs op rhs -> (\l r -> UInfixT l op r) <$> f lhs <*> f rhs
      ParensT inner      -> ParensT <$> f inner
      _                  -> pure ty

    -- Only kinded binders carry a type (their kind) to visit.
    binderKind f bndr = case bndr of
      KindedTV nam kind -> KindedTV nam <$> f kind
      PlainTV nam       -> pure (PlainTV nam)
-- | Generate field optics for the datatype with the given name.
--
-- The datatype is reified, optics are generated according to @rules@, and
-- the set of field class names tracked during generation starts out empty.
makeFieldOptics :: LensRules -> Name -> DecsQ
makeFieldOptics rules tyName = do
  info <- D.reifyDatatype tyName
  evalStateT (makeFieldOpticsForDatatype rules info) S.empty
-- | Generate field optics for a datatype given as a declaration, starting
-- from an empty set of already-seen field class names.
makeFieldOpticsForDec :: LensRules -> Dec -> DecsQ
makeFieldOpticsForDec rules dec =
  evalStateT (makeFieldOpticsForDec' rules dec) S.empty
-- | Like 'makeFieldOpticsForDec', but stays in 'HasFieldClasses' so the
-- caller controls the set of field class names seen so far.
makeFieldOpticsForDec' :: LensRules -> Dec -> HasFieldClasses [Dec]
makeFieldOpticsForDec' rules dec = do
  info <- lift (D.normalizeDec dec)
  makeFieldOpticsForDatatype rules info
-- | Generate field optics for a normalized datatype.
--
-- A scaffold is computed for every definition name produced by the rules'
-- field namer. If the rules' classy namer yields a class and method name
-- for the type, everything is delegated to 'makeClassyDriver'; otherwise
-- each definition is generated on its own with 'makeFieldOptic'.
makeFieldOpticsForDatatype :: LensRules -> D.DatatypeInfo -> HasFieldClasses [Dec]
makeFieldOpticsForDatatype rules info = do
  defs <- lift scaffoldPerDef
  case _classyLenses rules tyName of
    Nothing ->
      concat <$> traverse (makeFieldOptic rules) defs
    Just (className, methodName) ->
      makeClassyDriver rules className methodName s defs
  where
    tyName = D.datatypeName info
    s      = D.datatypeType info
    cons   = D.datatypeCons info

    -- One scaffold per definition name, in ascending key order.
    scaffoldPerDef :: Q [(DefName, (OpticStab, [(Name, Int, [Int])]))]
    scaffoldPerDef = do
      fieldCons <- traverse normalizeConstructor cons
      let allFields = toListOf (folded % _2 % folded % _1 % folded) fieldCons
          defCons   = over normFieldLabels (expandName allFields) fieldCons
          allDefs   = setOf (normFieldLabels % folded) defCons
      M.toList <$>
        T.sequenceA (M.fromSet (buildScaffold True rules s defCons) allDefs)

    -- Traverse the label attached to every field of every constructor.
    normFieldLabels :: Traversal [(Name,[(a,Type)])] [(Name,[(b,Type)])] a b
    normFieldLabels = traversed % _2 % traversed % _1

    -- Map an optional field name to the definitions the rules assign to it.
    expandName :: [Name] -> Maybe Name -> [DefName]
    expandName allFields mfield =
      concatMap (_fieldToDef rules tyName allFields) (maybeToList mfield)
-- | Generate field label instances for a datatype given as a declaration.
makeFieldLabelsForDec :: LensRules -> Dec -> DecsQ
makeFieldLabelsForDec rules dec =
  D.normalizeDec dec >>= makeFieldLabelsForDatatype rules
-- | Generate field label instances for the datatype with the given name.
makeFieldLabelsWith :: LensRules -> Name -> DecsQ
makeFieldLabelsWith rules tyName = do
  info <- D.reifyDatatype tyName
  makeFieldLabelsForDatatype rules info
-- | Generate one label instance per definition of a normalized datatype.
--
-- Scaffolds are built with phantom type changes disallowed. Definitions
-- whose 'OpticSa' scaffold is flagged as not rank-1 are skipped.
makeFieldLabelsForDatatype :: LensRules -> D.DatatypeInfo -> Q [Dec]
makeFieldLabelsForDatatype rules info = do
  fieldCons <- traverse normalizeConstructor cons
  let allFields = toListOf (folded % _2 % folded % _1 % folded) fieldCons
      defCons   = over normFieldLabels (expandName allFields) fieldCons
      allDefs   = setOf (normFieldLabels % folded) defCons
  perDef <- T.sequenceA (M.fromSet (buildScaffold False rules s defCons) allDefs)
  traverse (makeFieldLabel rules) [ def | def <- M.toList perDef, isRank1 def ]
  where
    -- Type-changing scaffolds always count as rank-1.
    isRank1 (_, (stab, _)) = case stab of
      OpticSa rank1 _ _ _ _ -> rank1
      OpticStab{}           -> True

    tyName = D.datatypeName info
    s      = D.datatypeType info
    cons   = D.datatypeCons info

    -- Traverse the label attached to every field of every constructor.
    normFieldLabels :: Traversal [(Name,[(a,Type)])] [(Name,[(b,Type)])] a b
    normFieldLabels = traversed % _2 % traversed % _1

    -- Map an optional field name to the definitions the rules assign to it.
    expandName :: [Name] -> Maybe Name -> [DefName]
    expandName allFields mfield =
      concatMap (_fieldToDef rules tyName allFields) (maybeToList mfield)
-- | Build the @LabelOptic@ instance for a single field definition.
--
-- The instance head mentions the field name as a type-level string, the
-- optic kind tag, the source and target types and two focus types. The
-- focus types come from 'eqSubst', which also yields the predicates placed
-- in the instance context (presumably equality constraints; confirm in
-- "Optics.TH.Internal.Utils").
makeFieldLabel
  :: LensRules
  -> (DefName, (OpticStab, [(Name, Int, [Int])]))
  -> Q Dec
makeFieldLabel rules (defName, (defType, cons)) = do
  (preds, headTy) <- case defType of
    -- Type-preserving optic: both focus substitutions are based on @a@.
    OpticSa _ _ otype s a -> do
      (a', cxtA) <- eqSubst a "a"
      (b', cxtB) <- eqSubst a "b"
      pure ([cxtA, cxtB], labelOpticHead otype s s a' b')
    OpticStab otype s t a b -> do
      ambiguous <- containsAmbiguousTypeFamilyApplications s a
      (a', cxtA) <- eqSubst a "a"
      (b', cxtB) <- eqSubst (if ambiguous then a else b) "b"
      -- When the focus is ambiguous, fall back to the type-preserving shape.
      let t' = if ambiguous then s else t
      pure ([cxtA, cxtB], labelOpticHead otype s t' a' b')
  instanceD (pure preds) (pure headTy) (fun 'labelOptic)
  where
    labelOpticHead otype s t a b = conAppsT ''LabelOptic
      [LitT (StrTyLit fieldName), ConT (opticTypeToTag otype), s, t, a, b]

    opticTypeToTag otype = case otype of
      AffineFoldType      -> ''An_AffineFold
      AffineTraversalType -> ''An_AffineTraversal
      FoldType            -> ''A_Fold
      GetterType          -> ''A_Getter
      IsoType             -> ''An_Iso
      LensType            -> ''A_Lens
      TraversalType       -> ''A_Traversal

    -- True when @a@ (after synonym expansion) mentions a type family and
    -- some type variable of @s@ never occurs as a bare variable visited
    -- while walking @a@.
    containsAmbiguousTypeFamilyApplications s a = do
      (hasTypeFamilies, bareVars) <-
        runStateT (go =<< lift (D.resolveTypeSynonyms a)) (setOf typeVars s)
      pure (hasTypeFamilies && not (S.null bareVars))
      where
        go (ConT nm) = has (_FamilyI % _1 % _TypeFamilyD) <$> lift (reify nm)
        go (VarT n)  = False <$ modify' (S.delete n)
        go ty        = or <$> traverse go (toListOf typeSelf ty)

    fieldName = nameBase $ case defName of
      TopName fname      -> fname
      MethodName _ fname -> fname

    fun :: Name -> [DecQ]
    fun n = funD n [funDef] : inlinePragma n

    funDef :: ClauseQ
    funDef = makeFieldClause rules (stabToOpticType defType) cons
-- | Pair a constructor's name with its fields, each tagged with an optional
-- field name.
--
-- Only record constructors contribute field names. A field whose type
-- mentions any of the constructor's own type variables
-- ('D.constructorVars') loses its name, so no definition is derived for it.
normalizeConstructor ::
  D.ConstructorInfo ->
  Q (Name, [(Maybe Name, Type)])
normalizeConstructor con =
  pure ( D.constructorName con
       , zipWith labelField fieldNames (D.constructorFields con)
       )
  where
    fieldNames = case D.constructorVariant con of
      D.RecordConstructor xs -> map Just xs
      D.NormalConstructor    -> repeat Nothing
      D.InfixConstructor     -> repeat Nothing

    unallowable = D.constructorVars con

    labelField fieldname fieldtype
      | mentionsUnallowable = (Nothing, fieldtype)
      | otherwise           = (fieldname, fieldtype)
      where
        used = setOf typeVars fieldtype
        mentionsUnallowable =
          any (\tv -> D.tvName tv `S.member` used) unallowable
-- | Compute the optic shape and per-constructor scaffolding for one
-- definition name.
--
-- The result pairs the optic's type description with, for every
-- constructor, its name, its number of fields and the indices of the
-- fields that belong to @defName@.
buildScaffold ::
  Bool ->
  LensRules ->
  Type ->
  [(Name, [([DefName], Type)])] ->
  DefName ->
  Q (OpticStab, [(Name, Int, [Int])])
buildScaffold allowPhantomsChange rules s cons defName = do
  (s', t, a, b) <- buildStab allowPhantomsChange s (concatMap snd consForDef)
  let defType
        -- A quantified field type only supports read-only access.
        | Just (tyvars, cx, a') <- preview _ForallT a
            = OpticSa (null tyvars) cx readOnlyOptic s' a'
        -- Updates were disabled by the rules.
        | not (_allowUpdates rules)
            = OpticSa True [] readOnlyOptic s' a
        -- Type-preserving optic, by request or because nothing can change.
        | _simpleLenses rules || s' == t && a == b
            = OpticSa True [] readWriteOptic s' a
        | otherwise
            = OpticStab readWriteOptic s' t a b
  pure (defType, scaffolds)
  where
    readOnlyOptic
      | lensCase   = GetterType
      | affineCase = AffineFoldType
      | otherwise  = FoldType

    readWriteOptic
      | isoCase && _allowIsos rules = IsoType
      | lensCase                    = LensType
      | affineCase                  = AffineTraversalType
      | otherwise                   = TraversalType

    -- Fields of this definition are 'Right', all others are 'Left'.
    consForDef :: [(Name, [Either Type Type])]
    consForDef = over (mapped % _2 % mapped) categorize cons

    scaffolds :: [(Name, Int, [Int])]
    scaffolds = [ (n, length ts, rightIndices ts) | (n, ts) <- consForDef ]

    rightIndices :: [Either Type Type] -> [Int]
    rightIndices = findIndices (has _Right)

    categorize :: ([DefName], Type) -> Either Type Type
    categorize (defNames, ty)
      | defName `elem` defNames = Right ty
      | otherwise               = Left ty

    -- Number of targeted fields in each constructor.
    affectedFields :: [Int]
    affectedFields = toListOf (folded % _3 % to length) scaffolds

    -- Every constructor has exactly one targeted field.
    lensCase :: Bool
    lensCase = all (== 1) affectedFields

    -- Every constructor has at most one targeted field.
    affineCase :: Bool
    affineCase = all (<= 1) affectedFields

    -- A single constructor whose only field is the targeted one.
    isoCase :: Bool
    isoCase = case scaffolds of
      [(_, 1, [0])] -> True
      _             -> False
-- | The kinds of optic the generator can emit for a field. 'opticTypeName'
-- maps each kind to the name of the corresponding optic type synonym, and
-- 'makeFieldClause' selects the clause builder from it.
data OpticType
= AffineFoldType
| AffineTraversalType
| FoldType
| GetterType
| IsoType
| LensType
| TraversalType
deriving Show
-- | Name of the optic type synonym for an optic kind.
--
-- When the first argument is 'True' the four-parameter type-changing
-- synonym is chosen, otherwise the primed two-parameter one. The read-only
-- kinds have a single synonym, so the flag does not affect them.
opticTypeName :: Bool -> OpticType -> Name
opticTypeName typeChanging = \case
  AffineTraversalType -> pick ''AffineTraversal ''AffineTraversal'
  AffineFoldType      -> ''AffineFold
  FoldType            -> ''Fold
  GetterType          -> ''Getter
  IsoType             -> pick ''Iso ''Iso'
  LensType            -> pick ''Lens ''Lens'
  TraversalType       -> pick ''Traversal ''Traversal'
  where
    pick changing simple = if typeChanging then changing else simple
-- | Type-level description of a generated optic.
data OpticStab = OpticStab OpticType Type Type Type Type
-- ^ Type-changing optic: optic kind followed by the @s@, @t@, @a@ and @b@
-- types.
| OpticSa Bool Cxt OpticType Type Type
-- ^ Type-preserving optic: a flag that 'buildScaffold' sets to whether the
-- focus type has no quantified type variables (label generation skips
-- definitions where it is 'False'), the context to attach to the optic's
-- type, the optic kind, and the @s@ and @a@ types.
-- | The full, quantified type of the optic described by an 'OpticStab'.
stabToType :: OpticStab -> Type
stabToType = \case
  OpticStab c s t a b ->
    quantifyType [] (conAppsT (opticTypeName True c) [s, t, a, b])
  OpticSa _ cx c s a ->
    quantifyType cx (conAppsT (opticTypeName False c) [s, a])
-- | The context stored in an 'OpticStab'; empty for type-changing optics.
stabToContext :: OpticStab -> Cxt
stabToContext = \case
  OpticStab{}        -> []
  OpticSa _ cx _ _ _ -> cx
-- | The optic kind stored in an 'OpticStab'.
stabToOpticType :: OpticStab -> OpticType
stabToOpticType = \case
  OpticStab c _ _ _ _ -> c
  OpticSa _ _ c _ _   -> c
-- | Name of the optic type synonym matching an 'OpticStab': the
-- type-changing synonym for 'OpticStab', the simple one for 'OpticSa'.
stabToOptic :: OpticStab -> Name
stabToOptic = \case
  OpticStab c _ _ _ _ -> opticTypeName True c
  OpticSa _ _ c _ _   -> opticTypeName False c
-- | The source type @s@ stored in an 'OpticStab'.
stabToS :: OpticStab -> Type
stabToS = \case
  OpticStab _ s _ _ _ -> s
  OpticSa _ _ _ s _   -> s
-- | The focus type @a@ stored in an 'OpticStab'.
stabToA :: OpticStab -> Type
stabToA = \case
  OpticStab _ _ _ a _ -> a
  OpticSa _ _ _ _ a   -> a
-- | Compute the @(s, t, a, b)@ types of an optic.
--
-- 'Right' fields are the targets of the optic, 'Left' fields are fixed.
-- @a@ is the type of the first target field. @t@ and @b@ are obtained from
-- @s@ and @a@ by renaming to fresh names every type variable of @s@ that
-- is allowed to change: those not occurring in a fixed field and, unless
-- @allowPhantomsChange@ is set, occurring in at least one field.
--
-- Calls 'error' if there is no target field and @a@ is demanded.
buildStab :: Bool -> Type -> [Either Type Type] -> Q (Type,Type,Type,Type)
buildStab allowPhantomsChange s categorizedFields = do
  sub <- T.traverse (newName . nameBase) (M.fromSet id unfixedTypeVars)
  let t = substTypeVars sub s
      b = substTypeVars sub a
  pure (s, t, a, b)
  where
    (fixedFields, targetFields) = partitionEithers categorizedFields

    a = case targetFields of
      field : _ -> field
      []        -> error "buildStab: unexpected empty list of fields"

    sVars = setOf typeVars s

    -- Variables of @s@ that occur in no field at all.
    phantomTypeVars =
      sVars S.\\ S.union (setOf typeVars fixedFields) (setOf typeVars targetFields)

    unfixedTypeVars
      | allowPhantomsChange = changeable
      | otherwise           = changeable S.\\ phantomTypeVars
      where
        changeable = sVars S.\\ setOf typeVars fixedFields
-- | Generate the declarations for one field definition.
--
-- For a 'TopName' this is an optional type signature plus the function
-- definition and its inline pragma. For a 'MethodName' it is an optional
-- class declaration plus an instance. The class is generated only when the
-- rules ask for classes, no type of that name is found by
-- 'lookupTypeName', and the class name has not been recorded in the state
-- before this call. The class name is recorded in the state either way.
makeFieldOptic ::
  LensRules ->
  (DefName, (OpticStab, [(Name, Int, [Int])])) ->
  HasFieldClasses [Dec]
makeFieldOptic rules (defName, (defType, cons)) = do
  knownClasses <- get
  case defName of
    MethodName c _ -> addFieldClassName c
    TopName _      -> pure ()
  lift $ do
    classDecs <- classDecls knownClasses
    T.sequenceA (classDecs ++ sigDecls ++ defDecls)
  where
    classDecls knownClasses = case defName of
      MethodName c n | _generateClasses rules -> do
        classExists <- isJust <$> lookupTypeName (show c)
        pure [ makeFieldClass defType c n
             | not (classExists || S.member c knownClasses)
             ]
      _ -> pure []

    -- Signatures are only emitted for top-level definitions.
    sigDecls = case defName of
      TopName n | _generateSigs rules -> [sigD n (pure (stabToType defType))]
      _                               -> []

    defDecls = case defName of
      TopName n      -> fun n
      MethodName c n -> [makeFieldInstance defType c (fun n)]

    fun n = funD n [funDef] : inlinePragma n

    funDef = makeFieldClause rules (stabToOpticType defType) cons
-- | Generate the classy class (when the rules ask for classes) followed by
-- its instance for the datatype.
makeClassyDriver ::
  LensRules ->
  Name ->
  Name ->
  Type ->
  [(DefName, (OpticStab, [(Name, Int, [Int])]))] ->
  HasFieldClasses [Dec]
makeClassyDriver rules className methodName s defs = do
  classDecs <-
    if _generateClasses rules
      then (: []) <$> lift (makeClassyClass className methodName s defs)
      else pure []
  instDec <- makeClassyInstance rules className methodName s defs
  pure (classDecs ++ [instDec])
-- | Build the class declaration for classy optics.
--
-- The class is parameterised by a fresh variable @c@ followed by the type
-- variables of @s@; when @s@ has type variables they are functionally
-- determined by @c@. Its members are the method @methodName@, a @Lens'@
-- from @c@ to @s@, and for every 'TopName' definition a signature, a
-- default body composing @methodName@ with that definition via '%', and an
-- inline pragma.
makeClassyClass ::
  Name ->
  Name ->
  Type ->
  [(DefName, (OpticStab, [(Name, Int, [Int])]))] ->
  DecQ
makeClassyClass className methodName s defs = do
  c <- newName "c"
  let vars      = toListOf typeVars s
      classVars = c : vars
      fundeps   = [ FunDep [c] vars | not (null vars) ]
      methodSig = sigD methodName (pure (conAppsT ''Lens' [VarT c, s]))
      -- Type of a member optic: focus @a@ viewed from the class variable.
      memberType stab =
        quantifyType' (S.fromList classVars) (stabToContext stab)
          (conAppsT (stabToOptic stab) [VarT c, stabToA stab])
      memberDecs defName stab =
        [ sigD defName (pure (memberType stab))
        , valD (varP defName)
               (normalB (infixApp (varE methodName) (varE '(%)) (varE defName)))
               []
        ] ++ inlinePragma defName
  classD (cxt []) className (map PlainTV classVars) fundeps $
    methodSig
      : concat [ memberDecs defName stab | (TopName defName, (stab, _)) <- defs ]
-- | Build the instance of the classy class for the datatype itself.
--
-- The class method is defined as @lensVL id@, and the field optics are
-- generated with signatures and class generation switched off.
makeClassyInstance ::
  LensRules ->
  Name ->
  Name ->
  Type ->
  [(DefName, (OpticStab, [(Name, Int, [Int])]))] ->
  HasFieldClasses Dec
makeClassyInstance rules className methodName s defs = do
  fieldDecs <- concat <$> traverse (makeFieldOptic instanceRules) defs
  lift (instanceD (cxt []) (pure instanceHead) (selfLens : map pure fieldDecs))
  where
    selfLens =
      valD (varP methodName) (normalB (appE (varE 'lensVL) (varE 'id))) []

    instanceHead = conAppsT className (s : map VarT (toListOf typeVars s))

    instanceRules = rules
      { _generateSigs    = False
      , _generateClasses = False
      }
-- | Build a two-parameter field class @class C s a | s -> a@ with a single
-- method whose type is the optic from @s@ to @a@.
makeFieldClass :: OpticStab -> Name -> Name -> DecQ
makeFieldClass defType className methodName =
  classD (cxt []) className (map PlainTV classVars) [FunDep [s] [a]]
    [sigD methodName (pure methodType)]
  where
    s = mkName "s"
    a = mkName "a"
    classVars = [s, a]

    -- The class variables are excluded from quantification in the method.
    methodType =
      quantifyType' (S.fromList classVars) (stabToContext defType)
        (conAppsT (stabToOptic defType) (map VarT classVars))
-- | Build an instance of a field class for the optic's source type.
--
-- If the focus type (after expanding type synonyms) mentions a type
-- family, the instance head uses a fresh type variable in its place and
-- the instance context carries the predicate built by 'D.equalPred'
-- between that variable and the focus type. Otherwise the focus type is
-- used directly in the head.
makeFieldInstance :: OpticStab -> Name -> [DecQ] -> DecQ
makeFieldInstance defType className decs = do
  hasFamilies <- containsTypeFamilies =<< D.resolveTypeSynonyms a
  if hasFamilies
    then do
      placeholder <- VarT <$> newName "a"
      mkInstanceDec [pure (D.equalPred placeholder a)] [s, placeholder]
    else mkInstanceDec [] [s, a]
  where
    s = stabToS defType
    a = stabToA defType

    containsTypeFamilies ty = case ty of
      ConT nm -> has (_FamilyI % _1 % _TypeFamilyD) <$> reify nm
      _       -> or <$> traverse containsTypeFamilies (toListOf typeSelf ty)

    mkInstanceDec context headTys =
      instanceD (cxt context) (pure (conAppsT className headTys)) decs
-- | Pick the clause builder matching the optic kind.
--
-- The builders for optics that can update are told to use irrefutable
-- patterns when the rules request lazy patterns and the datatype has
-- exactly one constructor.
makeFieldClause :: LensRules -> OpticType -> [(Name, Int, [Int])] -> ClauseQ
makeFieldClause rules = \case
  AffineFoldType      -> makeAffineFoldClause
  AffineTraversalType -> withIrref makeAffineTraversalClause
  FoldType            -> makeFoldClause
  IsoType             -> withIrref makeIsoClause
  GetterType          -> makeGetterClause
  LensType            -> withIrref makeLensClause
  TraversalType       -> withIrref makeTraversalClause
  where
    withIrref mk cons = mk cons (_lazyPatterns rules && length cons == 1)
-- | Clause of the form @afolding (\\s -> case s of ...)@.
--
-- Each constructor yields @Just x@ for its targeted field, or @Nothing@
-- when it has none. A constructor with more than one targeted field is an
-- error.
makeAffineFoldClause :: [(Name, Int, [Int])] -> ClauseQ
makeAffineFoldClause cons = do
  s <- newName "s"
  let getter = lamE [varP s] (caseE (varE s) (map matchFor cons))
  clause [] (normalB (varE 'afolding `appE` getter)) []
  where
    matchFor (conName, fieldCount, fields) = do
      xs <- newNames "x" (length fields)
      let bound = zip fields xs
          -- Bind targeted positions, wildcard everything else.
          args  = [ maybe wildP varP (lookup i bound)
                  | i <- [0 .. fieldCount - 1]
                  ]
          body  = case xs of
            []  -> conE 'Nothing
            [x] -> conE 'Just `appE` varE x
            _   -> error "AffineFold focuses on at most one field"
      match (conP conName args) (normalB body) []
-- | Clause of the form @foldVL (\\f s -> case s of ...)@.
--
-- Each constructor applies @f@ to its targeted fields and chains the
-- results with '*>'; a constructor without targeted fields yields
-- @pure ()@.
makeFoldClause :: [(Name, Int, [Int])] -> ClauseQ
makeFoldClause cons = do
  f <- newName "f"
  s <- newName "s"
  let folder = lamE [varP f, varP s] (caseE (varE s) (map (matchFor f) cons))
  clause [] (normalB (varE 'foldVL `appE` folder)) []
  where
    matchFor f (conName, fieldCount, fields) = do
      xs <- newNames "x" (length fields)
      let bound = zip fields xs
          -- Bind targeted positions, wildcard everything else.
          args  = [ maybe wildP varP (lookup i bound)
                  | i <- [0 .. fieldCount - 1]
                  ]
          visits
            | null xs   = [varE 'pure `appE` conE '()]
            | otherwise = [ varE f `appE` varE x | x <- xs ]
          body  = foldr1 (\visit rest -> infixApp visit (varE '(*>)) rest) visits
      match (conP conName args) (normalB body) []
-- | Clause of the form @to (\\s -> case s of ...)@.
--
-- Every constructor must target exactly one field, which is returned;
-- anything else is an error.
makeGetterClause :: [(Name, Int, [Int])] -> ClauseQ
makeGetterClause cons = do
  s <- newName "s"
  let getter = lamE [varP s] (caseE (varE s) (map matchFor cons))
  clause [] (normalB (varE 'to `appE` getter)) []
  where
    matchFor (conName, fieldCount, fields) = case fields of
      [field] -> do
        x <- newName "x"
        -- Bind the targeted position, wildcard everything else.
        let args = [ if i == field then varP x else wildP
                   | i <- [0 .. fieldCount - 1]
                   ]
        match (conP conName args) (normalB (varE x)) []
      _ -> error "Getter focuses on exactly one field"
-- | Clause of the form @iso (\\(Con x) -> x) Con@.
--
-- Only valid for a single constructor with a single, targeted field. The
-- unwrapping pattern is made irrefutable when @irref@ is set.
makeIsoClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeIsoClause fields irref
  | [(conName, 1, [0])] <- fields = do
      x <- newName "x"
      let unwrap = lamE [irrefP (conP conName [varP x])] (varE x)
      clause [] (normalB (appsE [varE 'iso, unwrap, conE conName])) []
  | otherwise =
      error "Iso works only for types with one constructor and one field"
  where
    irrefP = if irref then tildeP else id
-- | Clause of the form @lensVL (\\f s -> case s of ...)@ with one
-- alternative per constructor, built by 'makeLensMatch'. Constructor
-- patterns are made irrefutable when @irref@ is set.
makeLensClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeLensClause cons irref = do
  f <- newName "f"
  s <- newName "s"
  let matches =
        [ makeLensMatch irrefP f conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      vl = lamE [varP f, varP s] (caseE (varE s) matches)
  clause [] (normalB (varE 'lensVL `appE` vl)) []
  where
    irrefP = if irref then tildeP else id
-- | Case alternative for a constructor with exactly one targeted field:
--
-- > Con x1 .. xn -> fmap (\y -> Con x1 .. y .. xn) (f xi)
--
-- The first argument is applied to the constructor pattern (used by the
-- callers to make it irrefutable). Any other number of targeted fields is
-- an error.
makeLensMatch :: (PatQ -> PatQ) -> Name -> Name -> Int -> [Int] -> Q Match
makeLensMatch irrefP f conName fieldCount fields = case fields of
  [field] -> do
    xs <- newNames "x" fieldCount
    y <- newName "y"
    let rebuilt = appsE (conE conName : map varE (set (ix field) y xs))
        setter  = lamE [varP y] rebuilt
        focused = varE f `appE` varE (xs !! field)
        body    = appsE [varE 'fmap, setter, focused]
    match (irrefP (conP conName (map varP xs))) (normalB body) []
  _ -> error "Lens focuses on exactly one field"
-- | Clause of the form @atraversalVL (\\point f s -> case s of ...)@.
--
-- A constructor without a targeted field is rebuilt and passed to
-- @point@. A constructor with one targeted field is handled like a lens
-- via 'makeLensMatch'. More than one targeted field is an error.
makeAffineTraversalClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeAffineTraversalClause cons irref = do
  point <- newName "point"
  f <- newName "f"
  s <- newName "s"
  let matches = [ matchFor point f con | con <- cons ]
      vl = lamE [varP point, varP f, varP s] (caseE (varE s) matches)
  clause [] (normalB (varE 'atraversalVL `appE` vl)) []
  where
    irrefP = if irref then tildeP else id

    matchFor point f (conName, fieldCount, fields) = case fields of
      [] -> do
        xs <- newNames "x" fieldCount
        match (irrefP (conP conName (map varP xs)))
              (normalB (varE point `appE` appsE (conE conName : map varE xs)))
              []
      [_] -> makeLensMatch irrefP f conName fieldCount fields
      _   -> error "Affine traversal focuses on at most one field"
-- | Clause of the form @traversalVL (\\f s -> case s of ...)@.
--
-- A constructor without targeted fields is rebuilt under 'pure'. Otherwise
-- the alternative has the shape
--
-- > Con x1 .. xn -> pure (\y1 .. yk -> Con ..) <*> f xi1 <*> .. <*> f xik
--
-- where the lambda puts each @y@ in the position of its targeted field.
-- Constructor patterns are made irrefutable when @irref@ is set.
makeTraversalClause :: [(Name, Int, [Int])] -> Bool -> ClauseQ
makeTraversalClause cons irref = do
  f <- newName "f"
  s <- newName "s"
  clause
    []
    (normalB $ appsE
      [ varE 'traversalVL
      , lamE [varP f, varP s] $ caseE (varE s)
        [ makeTraversalMatch f conName fieldCount fields
        | (conName, fieldCount, fields) <- cons
        ]
      ])
    []
  where
    irrefP = if irref then tildeP else id

    makeTraversalMatch f conName fieldCount fields = do
      xs <- newNames "x" fieldCount
      case fields of
        [] ->
          match (irrefP . conP conName $ map varP xs)
                (normalB $ varE 'pure `appE` appsE (conE conName : map varE xs))
                []
        _ -> do
          ys <- newNames "y" $ length fields
          let -- Constructor arguments with the targeted ones replaced.
              xs' = foldr (\(i, x) -> set (ix i) x) xs (zip fields ys)

              mkFx i = varE f `appE` varE (xs !! i)

              body0 = appsE
                [ varE 'pure
                , lamE (map varP ys) $ appsE $ conE conName : map varE xs'
                ]

              -- Strict left fold: there is no reason to build the chain
              -- of applications lazily.
              body = foldl' (\acc i -> infixApp acc (varE '(<*>)) $ mkFx i)
                body0
                fields
          match (irrefP . conP conName $ map varP xs)
                (normalB body)
                []
-- | Configuration for the optics generator.
data LensRules = LensRules
{ _simpleLenses :: Bool
-- ^ When set, 'buildScaffold' always produces a type-preserving optic.
, _generateSigs :: Bool
-- ^ When set, 'makeFieldOptic' emits type signatures for top-level optics.
, _generateClasses :: Bool
-- ^ When set, class declarations are generated in addition to instances.
, _allowIsos :: Bool
-- ^ When set, an iso is produced for a single constructor whose only field
-- is the targeted one.
, _allowUpdates :: Bool
-- ^ When unset, only getters, affine folds and folds are produced.
, _lazyPatterns :: Bool
-- ^ When set, constructor patterns of single-constructor types are made
-- irrefutable.
, _fieldToDef :: FieldNamer
-- ^ Maps a field to the definitions to generate for it.
, _classyLenses :: ClassyNamer
-- ^ Applied to the type name; a 'Just' result switches to classy
-- generation with the returned class and method names.
}
-- | Naming function for field optics. The generator applies it to the name
-- of the datatype, the names of all its fields, and the field in question.
type FieldNamer = Name
-- ^ Name of the datatype
-> [Name]
-- ^ Names of all fields of the datatype, across all constructors
-> Name
-- ^ Name of the field being considered
-> [DefName]
-- ^ Definitions to generate for that field
-- | Name of a definition to generate for a field.
data DefName
= TopName Name
-- ^ A top-level function with the given name
| MethodName Name Name
-- ^ A class name and a method name; an instance is generated, and also
-- the class itself where 'makeFieldOptic' decides it is needed
deriving (Show, Eq, Ord)
-- | Naming function for classy optics. It is applied to the name of the
-- datatype; a 'Just' result carries the class name and the method name.
type ClassyNamer = Name
-- ^ Name of the datatype
-> Maybe (Name, Name)
-- ^ Class name and method name, if classy optics should be generated
-- | 'Q' extended with a set of names as state. 'addFieldClassName' inserts
-- field class names into it, and 'makeFieldOptic' consults it so that a
-- class recorded earlier in the same run is not generated again.
type HasFieldClasses = StateT (S.Set Name) Q
-- | Record a field class name in the state.
addFieldClassName :: Name -> HasFieldClasses ()
addFieldClassName = modify . S.insert
-- | Matches open type family declarations and, failing that, closed ones,
-- discarding their contents.
_TypeFamilyD :: AffineFold Dec ()
_TypeFamilyD =
  afailing (_OpenTypeFamilyD % united) (_ClosedTypeFamilyD % united)
-- | Wrap a type in a @forall@ that binds all of its type variables and
-- carries the given context.
quantifyType :: Cxt -> Type -> Type
quantifyType cx ty = quantifyType' S.empty cx ty
-- | Wrap a type in a @forall@ that carries the given context and binds the
-- type's variables, except those in the exclusion set.
--
-- Binders appear in the order 'typeVars' yields the variables, keeping
-- only the first occurrence of each.
quantifyType' :: S.Set Name -> Cxt -> Type -> Type
quantifyType' exclude c t = ForallT binders c t
  where
    binders =
      [ PlainTV v
      | v <- nub (toListOf typeVars t)
      , v `S.notMember` exclude
      ]