module Contravariant.Extras.TH
  ( opContrazipDecs,
    contrazipDecs,
    contrazipExp,
  )
where

import Contravariant.Extras.Prelude
import Language.Haskell.TH.Syntax
import qualified TemplateHaskell.Compat.V0208 as Compat

-- |
-- Generates declarations in the spirit of the following:
--
-- @
-- tuple3 :: Monoid a => Op a b1 -> Op a b2 -> Op a b3 -> Op a ( b1 , b2 , b3 )
-- tuple3 ( Op op1 ) ( Op op2 ) ( Op op3 ) =
--   Op $ \( v1 , v2 , v3 ) -> mconcat [ op1 v1 , op2 v2 , op3 v3 ]
-- @
--
-- The declared name is @baseName@ with the arity appended
-- (@opContrazipDecs "tuple" 3@ declares @tuple3@). The returned list holds
-- the type signature followed by the single-clause function definition.
opContrazipDecs :: String -> Int -> [Dec]
opContrazipDecs baseName arity =
  [signature, value]
  where
    name :: Name
    name = mkName (showString baseName (show arity))

    -- One position per tuple component: 1 .. arity.
    indices :: [Int]
    indices = enumFromTo 1 arity

    -- @indexedName "b" 3@ is @b3@. Used for the type variables (@b<i>@),
    -- the unwrapped functions (@op<i>@) and the tuple components (@v<i>@),
    -- so all three share one numbering scheme.
    indexedName :: String -> Int -> Name
    indexedName prefix index = mkName (showString prefix (show index))

    aName :: Name
    aName = mkName "a"

    -- The partially applied constructor @Op a@.
    opOfA :: Type
    opOfA = AppT (ConT ''Op) (VarT aName)

    -- forall a b1 .. bN. Monoid a => Op a b1 -> .. -> Op a bN -> Op a (b1, .., bN)
    signature :: Dec
    signature = SigD name (ForallT vars [monoidA] sigType)
      where
        vars = map Compat.specifiedPlainTV (aName : map (indexedName "b") indices)
        monoidA = Compat.classP ''Monoid [VarT aName]
        bTypes = map (VarT . indexedName "b") indices
        sigType = foldr arrow resultType (map (AppT opOfA) bTypes)
        resultType = AppT opOfA (foldl AppT (TupleT arity) bTypes)
        arrow from to = AppT (AppT ArrowT from) to

    -- name (Op op1) .. (Op opN) = Op (\(v1, .., vN) -> mconcat [op1 v1, .., opN vN])
    value :: Dec
    value = FunD name [Clause argPats (NormalB (AppE (ConE 'Op) lambda)) []]
      where
        argPats = map (\index -> Compat.conP 'Op [VarP (indexedName "op" index)]) indices
        lambda = LamE [TupP (map (VarP . indexedName "v") indices)] combined
        combined = AppE (VarE 'mconcat) (ListE (map apply indices))
        apply index = AppE (VarE (indexedName "op" index)) (VarE (indexedName "v" index))

-- |
-- Generates declarations in the spirit of the following:
--
-- @
-- contrazip4 :: Divisible f => f a1 -> f a2 -> f a3 -> f a4 -> f ( a1 , a2 , a3 , a4 )
-- contrazip4 f1 f2 f3 f4 =
--   divide $(TupleTH.splitTupleAt 4 1) f1 $
--   divide $(TupleTH.splitTupleAt 3 1) f2 $
--   divide $(TupleTH.splitTupleAt 2 1) f3 $
--   f4
-- @
--
-- The declared name is @baseName@ with the arity appended. The returned list
-- holds the signature (from 'contrazipType') followed by an argument-less
-- definition whose body is the lambda built by 'contrazipExp'.
contrazipDecs :: String -> Int -> [Dec]
contrazipDecs baseName arity = [signature, value]
  where
    name :: Name
    name = mkName (showString baseName (show arity))

    signature :: Dec
    signature = SigD name (contrazipType arity)

    -- Point-free binding: @name = <contrazipExp arity>@.
    value :: Dec
    value = FunD name [Clause [] (NormalB (contrazipExp arity)) []]

-- | The type of a contrazip function of the given arity:
--
-- @
-- forall f a1 .. aN. Divisible f => f a1 -> .. -> f aN -> f (a1, .., aN)
-- @
contrazipType :: Int -> Type
contrazipType arity = ForallT vars [divisibleF] type_
  where
    fName :: Name
    fName = mkName "f"

    fVar :: Type
    fVar = VarT fName

    -- a1 .. aN, one element type per tuple component.
    aNames :: [Name]
    aNames = map (\index -> mkName (showString "a" (show index))) (enumFromTo 1 arity)

    -- Binders are built through the compat layer, whose binder type
    -- differs between template-haskell versions, so no local signature here.
    vars = map Compat.specifiedPlainTV (fName : aNames)

    divisibleF :: Type
    divisibleF = Compat.classP ''Divisible [fVar]

    type_ :: Type
    type_ = foldr arrow result (map (AppT fVar . VarT) aNames)

    result :: Type
    result = AppT fVar (foldl AppT (TupleT arity) (map VarT aNames))

    arrow :: Type -> Type -> Type
    arrow from to = AppT (AppT ArrowT from) to

-- |
-- Contrazip lambda expression of specified arity.
--
-- Allows to create contrazip expressions of any arity:
--
-- >>>:t $(return (contrazipExp 2))
-- \$(return (contrazipExp 2))
--   :: Data.Functor.Contravariant.Divisible.Divisible f =>
--      f a1 -> f a2 -> f (a1, a2)
--
-- The lambda takes @f1 .. fN@ and nests them as
-- @divide split1 f1 (divide split2 f2 (.. fN))@. Each split peels the first
-- component off the remaining tuple (see 'splitTupleAtExp'). The whole
-- expression is annotated with 'contrazipType'.
--
-- An arity below 1 has no expansion. It is rejected with 'error' when the
-- result is evaluated, instead of recursing forever.
contrazipExp :: Int -> Exp
contrazipExp arity
  | arity < 1 =
      error
        ( showString
            "Contravariant.Extras.TH.contrazipExp: arity must be at least 1, got "
            (show arity)
        )
  | otherwise = SigE (LamE pats (chain arity)) (contrazipType arity)
  where
    fNameAt :: Int -> Name
    fNameAt index = mkName (showString "f" (show index))

    pats :: [Pat]
    pats = map (VarP . fNameAt) (enumFromTo 1 arity)

    -- @chain n@ combines the last @n@ functions, f(arity-n+1) .. f(arity),
    -- over an @n@-tuple. Terminates because @arity >= 1@ is guaranteed above.
    chain :: Int -> Exp
    chain index = case index of
      1 -> VarE (fNameAt arity)
      _ ->
        foldl1
          AppE
          [ VarE 'divide,
            splitTupleAtExp index 1,
            VarE (fNameAt (arity - index + 1)),
            chain (index - 1)
          ]

-- | @splitTupleAtExp arity position@ builds a lambda over an @arity@-tuple
-- @(_0, .., _(arity-1))@. The lambda returns a pair of 'Compat.tupE' applied
-- to the first @position@ components and 'Compat.tupE' applied to the rest.
splitTupleAtExp :: Int -> Int -> Exp
splitTupleAtExp arity position = LamE [TupP (map VarP names)] body
  where
    -- _0 .. _(arity-1); 'mkName' yields the same unqualified (NameS) names
    -- the constructor-built version did, since they contain no dots.
    names :: [Name]
    names = map (\index -> mkName ('_' : show index)) (enumFromTo 0 (pred arity))

    body :: Exp
    body = case splitAt position (map VarE names) of
      (front, back) -> Compat.tupE [Compat.tupE front, Compat.tupE back]