{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE PolyKinds #-}
-- | Shared functions for dependent-sum-template
module Data.Dependent.Sum.TH.Internal where

import Control.Monad
import Language.Haskell.TH
import Language.Haskell.TH.Extras

-- | Split a (possibly applied) class head such as @C a b c@ into the name of
-- its head, as computed by 'headOfType', and its arguments in source order,
-- here @[a, b, c]@.
--
-- The arguments are collected by walking the left spine of nested 'AppT'
-- nodes, which visits them last-to-first; the accumulated list is reversed
-- exactly once at the end.
classHeadToParams :: Type -> (Name, [Type])
classHeadToParams t = (h, reverse reversedParams)
  where
    (h, reversedParams) = go t

    -- Recurse with 'go', not 'classHeadToParams': the outer function reverses
    -- its result, so calling it here would reverse the partial list at every
    -- level and scramble heads with three or more arguments.
    go :: Type -> (Name, [Type])
    go (AppT f x) = let (h', rest) = go f in (h', x : rest)
    go other = (headOfType other, [])

-- | Invoke the deriver @f@ for the given declaration and wrap its result in a
-- single class instance.
--
-- Three kinds of declaration are accepted:
--
-- * An 'InstanceD' prototype. Its class must be @className@. We assume that
--   the type we are deriving for is always the first class parameter, if
--   there are several. That parameter must name an algebraic data type
--   ('DataD'), whose binders and constructors are obtained with 'reify'. The
--   overlap pragma, context and head of the prototype are reused unchanged,
--   and @makeClassHead@ is not used.
--
-- * A 'DataD' declaration. The instance head is @makeClassHead@ applied to
--   the type constructor, and the instance context is the data type context.
--
-- * A 'DataInstD' declaration. The instance head is @makeClassHead@ applied
--   to the instance type with its last argument removed, and the instance
--   context is the data instance context.
--
-- The following are reported with 'fail' instead of crashing on an
-- incomplete pattern:
--
-- * any other declaration;
-- * a prototype for the wrong class;
-- * a prototype without class parameters;
-- * a data instance without type arguments.
deriveForDec
  :: Name                              -- ^ class being derived (checked against prototypes, used in errors)
  -> (Q Type -> Q Type)                -- ^ builds the instance head from the data type
  -> ([TyVarBndr] -> [Con] -> Q Dec)   -- ^ the deriver: binders and constructors to the instance body
  -> Dec                               -- ^ prototype instance, data, or data instance declaration
  -> Q [Dec]                           -- ^ the single generated instance
deriveForDec className _ f (InstanceD overlaps instCxt classHead _) =
  case classHeadToParams classHead of
    (givenClassName, _)
      | givenClassName /= className ->
          fail $ "while deriving " ++ show className
            ++ ": wrong class name in prototype declaration: " ++ show givenClassName
    (_, []) ->
      fail $ "while deriving " ++ show className
        ++ ": the prototype instance has no class parameters"
    (_, firstParam : _) -> do
      dataTypeInfo <- reify (headOfType firstParam)
      case dataTypeInfo of
        TyConI (DataD _ _ bndrs _ cons _) -> do
          dec <- f bndrs cons
          return [InstanceD overlaps instCxt classHead [dec]]
        _ ->
          fail $ "while deriving " ++ show className
            ++ ": the name of an algebraic data type constructor is required"
deriveForDec _ makeClassHead f (DataD dataCxt name bndrs _ cons _) =
  (: []) <$> instanceD (cxt (map return dataCxt)) (makeClassHead (conT name)) [f bndrs cons]
#if __GLASGOW_HASKELL__ >= 808
deriveForDec className makeClassHead f (DataInstD dataCxt tvBndrs ty _ cons _) =
  case ty of
    -- Drop the last type argument; the remaining application is the type
    -- handed to makeClassHead.
    AppT familyPrefix _ ->
      (: []) <$> instanceD (cxt (map return dataCxt)) (makeClassHead (return familyPrefix)) [f bndrs cons]
    _ ->
      fail $ "while deriving " ++ show className
        ++ ": the data instance must have at least one type argument"
  where
    -- Only plain binders are forwarded; kinded binders are dropped.
    bndrs = [PlainTV v | PlainTV v <- maybe [] id tvBndrs]
#else
deriveForDec className makeClassHead f (DataInstD dataCxt name tyArgs _ cons _)
  | null tyArgs =
      fail $ "while deriving " ++ show className
        ++ ": the data instance must have at least one type argument"
  | otherwise =
      (: []) <$> instanceD (cxt (map return dataCxt)) clhead [f bndrs cons]
  where
    clhead = makeClassHead $ foldl1 appT (map return (ConT name : init tyArgs))
    -- TODO: figure out proper number of family parameters vs instance parameters
    bndrs = [PlainTV v | VarT v <- tail tyArgs]
#endif
deriveForDec className _ _ _ =
  fail $ "while deriving " ++ show className
    ++ ": expected an instance, data, or data instance declaration"