module Data.DeepPrisms where

import Control.Lens (Prism', makeClassyPrisms)
import qualified Control.Lens as Lens (preview, review)
import Language.Haskell.TH
import Language.Haskell.TH.Datatype (
  ConstructorInfo(constructorName, constructorFields),
  DatatypeInfo(datatypeCons),
  reifyDatatype,
  )
import Language.Haskell.TH.Syntax (
  ModName(..),
  Name(Name),
  NameFlavour(NameQ, NameS, NameG),
  NameSpace(VarName),
  OccName(..),
  )

-- | Relation between an outer type and a type it can contain, possibly through
-- several layers of single-field constructors (or the outer type itself, via
-- the identity instance), witnessed by a prism.
class DeepPrisms outer inner where
  -- | Prism focusing from the outer type onto the contained type.
  prism :: Prism' outer inner

-- | Embed a contained value into the outer type by running 'prism' in reverse.
hoist :: DeepPrisms e e' => e' -> e
hoist =
  Lens.review prism

-- | Try to extract a contained value from the outer type through 'prism';
-- 'Nothing' when the prism does not match the given value.
retrieve :: DeepPrisms e e' => e -> Maybe e'
retrieve =
  Lens.preview prism

-- | A data constructor with exactly one field whose type is a plain type
-- constructor (see 'ctor').
data Ctor =
  Ctor {
    -- | Name of the data constructor.
    ctorName :: Name,
    -- | Name of the type of the constructor's single field.
    ctorType :: Name
  }
  deriving (Eq, Show)

-- | A pair of names.
--
-- NOTE(review): not referenced anywhere else in this module; the field names
-- suggest a constructor and the type it wraps, but that is unconfirmed —
-- check for external users before relying on or removing it.
data SubError =
  SubError {
    seCtor :: Name,
    seWrapped :: Name
  }
  deriving (Eq, Show)

-- | A generated 'DeepPrisms' instance declaration, tagged with the name of the
-- contained type it targets so that duplicates can be recognised
-- ('filterDuplicates').
data PrismsInstance =
  PrismsInstance {
    -- | Name of the target (contained) type of the instance.
    prismInstanceName :: Name,
    -- | The instance declaration itself.
    prismInstanceDec :: Dec
  }
  deriving (Eq, Show)

-- | Describe a constructor whose only field is a plain type constructor (a
-- 'ConT', e.g. @Foo@ but not @Maybe Foo@ or a type variable).
-- Constructors of any other shape yield 'Nothing'.
ctor :: ConstructorInfo -> Maybe Ctor
ctor info =
  case constructorFields info of
    [ConT tpe] -> Just (Ctor (constructorName info) tpe)
    _ -> Nothing

-- | The single-field constructors (as recognised by 'ctor') of the named data
-- type; all other constructors are skipped.
dataType :: Name -> Q [Ctor]
dataType name =
  mapMaybe ctor . datatypeCons <$> reifyDatatype name

-- | The @prism = body@ method binding of a 'DeepPrisms' instance.
--
-- The two type arguments are ignored. The method name is taken directly from
-- the name quote @'prism@ instead of pattern-matching the result of a @[|prism|]@
-- expression quote, which needed a failable (MonadFail) pattern bind.
mkHoist :: TypeQ -> TypeQ -> BodyQ -> DecQ
mkHoist _ _ body =
  funD 'prism [clause [] body []]

-- | Declaration @instance DeepPrisms top local' where prism = body@, with an
-- empty instance context.
deepPrismsInstance :: TypeQ -> TypeQ -> BodyQ -> DecQ
deepPrismsInstance top local' body =
  instanceD (cxt []) instanceHead [mkHoist top local' body]
  where
    -- backticked application is infixl 9: ((DeepPrisms `appT` top) `appT` local')
    instanceHead = [t|DeepPrisms|] `appT` top `appT` local'

-- | Identity instance @DeepPrisms T T@ for the named type, with @prism = id@.
idInstance :: Name -> DecQ
idInstance name =
  deepPrismsInstance self self (normalB [|id|])
  where
    self = conT name

-- | Whether the constructor's field type @t@ has a @DeepPrisms t t@ instance in
-- scope, such as the identity instance produced by 'idInstance'.
--
-- The class name comes from the name quote @''DeepPrisms@ rather than from
-- pattern-matching a @[t|DeepPrisms|]@ type quote, which needed a failable
-- (MonadFail) pattern bind.
typeHasDeepPrisms :: Ctor -> Q Bool
typeHasDeepPrisms (Ctor _ tpe) =
  isInstance ''DeepPrisms [ConT tpe, ConT tpe]

-- | The module component of a name flavour.
--
-- Only qualified ('NameQ') and global ('NameG') flavours carry a module; every
-- other flavour yields 'Nothing'.
modName :: NameFlavour -> Maybe ModName
modName (NameQ mod') =
  Just mod'
modName (NameG _ _ mod') =
  Just mod'
modName _ =
  Nothing

-- | Whether both name flavours carry a module and it is the same one.
-- A flavour without a module (see 'modName') never matches, not even another
-- module-less flavour.
sameModule :: NameFlavour -> NameFlavour -> Bool
sameModule f1 f2 =
  case (modName f1, modName f2) of
    (Just a, Just b) -> a == b
    _ -> False

-- | Convert a constructor's 'NameFlavour' into the flavour for its prism.
--
-- A global constructor name lives in the @DataName@ namespace, but the prism is
-- a value binding and must use 'VarName'. All non-global flavours are returned
-- unchanged.
--
-- Curiously, the wrong namespace only surfaced as a bug at a certain nesting
-- depth across modules.
prismFlavour :: NameFlavour -> NameFlavour
prismFlavour (NameG _ pkg mod') =
  NameG VarName pkg mod'
prismFlavour other =
  other

-- | Expression referring to the prism for constructor @local@ (second argument),
-- named by prefixing the constructor's name with @_@.
--
-- When the constructor shares a module with @top@ (first argument), the prism
-- is referenced unqualified ('NameS'). Otherwise the constructor's own flavour
-- is reused, with its namespace corrected by 'prismFlavour'.
prismName :: Name -> Name -> ExpQ
prismName (Name _ topFlavour) (Name (OccName n) localFlavour) =
  varE (Name (OccName ('_' : n)) flavour)
  where
    flavour
      | sameModule topFlavour localFlavour = NameS
      | otherwise = prismFlavour localFlavour

-- | Build the @DeepPrisms top tpe@ instance for constructor @name@ (wrapping
-- @tpe@), reached from @top@ through the constructors in @intermediate@.
--
-- @intermediate@ is innermost-first ('prismsForData' pushes each newly entered
-- constructor onto the front). It is reversed before folding, so the body
-- becomes @_outermost . ... . _innermost . _name@.
constructorPrism :: Name -> [Name] -> Ctor -> Q PrismsInstance
constructorPrism top intermediate (Ctor name tpe) =
  PrismsInstance tpe <$> deepPrismsInstance (conT top) (conT tpe) (normalB body)
  where
    -- @compose outer inner@ builds the expression @(.) _outer inner@
    compose outer inner = appE (appE [|(.)|] (prismName top outer)) inner
    body = foldr compose (prismName top name) (reverse intermediate)

-- | Drop instances whose target type is already the field type of one of the
-- constructors in @created@. This keeps the same @DeepPrisms top t@ instance
-- from being emitted both for a direct constructor and for a deeper path.
filterDuplicates :: [Ctor] -> [PrismsInstance] -> [PrismsInstance]
filterDuplicates created =
  filter (not . (`elem` createdTypes) . prismInstanceName)
  where
    createdTypes = ctorType <$> created

-- | Single-field constructors of the named data type whose field type @t@ has
-- a @DeepPrisms t t@ instance (see 'typeHasDeepPrisms'), i.e. the constructors
-- the instance search can descend into.
deepPrismCtors :: Name -> Q [Ctor]
deepPrismCtors =
  filterM typeHasDeepPrisms <=< dataType

-- | Classy prisms ('makeClassyPrisms') for the named type. They are generated
-- only when the type has at least two constructors recognised by 'dataType';
-- otherwise nothing is generated.
--
-- NOTE(review): the count covers only the filtered single-field constructors,
-- not all constructors. A type with exactly one such constructor and other
-- multi-field ones gets no prisms here — confirm this is intended.
basicPrisms :: Name -> DecsQ
basicPrisms name = do
  ctors <- dataType name
  -- matching two conses avoids walking the whole list just to compare a length
  case ctors of
    _ : _ : _ -> makeClassyPrisms name
    _ -> return []

-- | All @DeepPrisms top t@ instances reachable from the data type @local'@.
--
-- @intermediate@ lists the constructors already traversed from @top@ to
-- @local'@, innermost first. Each qualifying constructor of @local'@ (see
-- 'deepPrismCtors') yields an instance. The search then recurses into that
-- constructor's field type, with the constructor pushed onto the path.
-- Deeper instances whose target is also a direct field type at this level are
-- dropped in favour of the shorter path ('filterDuplicates').
--
-- NOTE(review): instances reached through two different deeper paths are not
-- filtered against each other here, which would presumably produce duplicate
-- instance declarations — confirm.
prismsForData :: Name -> [Name] -> Name -> Q [PrismsInstance]
prismsForData top intermediate local' = do
  cons <- deepPrismCtors local'
  localInstances <- traverse (constructorPrism top intermediate) cons
  deepInstances <- traverse descend cons
  return (localInstances ++ (deepInstances >>= filterDuplicates cons))
  where
    descend (Ctor name tpe) =
      prismsForData top (name : intermediate) tpe

-- | The identity instance for the named type ('idInstance'), followed by every
-- instance found by 'prismsForData' starting from an empty constructor path.
prismsForMainData :: Name -> DecsQ
prismsForMainData name = do
  idInst <- idInstance name
  insts <- prismsForData name [] name
  return (idInst : fmap prismInstanceDec insts)

-- | Template Haskell entry point for a data type. Generates its classy prisms
-- ('basicPrisms') followed by its 'DeepPrisms' instances
-- ('prismsForMainData').
deepPrisms :: Name -> DecsQ
deepPrisms name =
  -- the two steps do not depend on each other; effects still run left to right
  (++) <$> basicPrisms name <*> prismsForMainData name