{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

-- | Shared functions for dependent-sum-template
module Data.GADT.TH.Internal where

import Control.Monad
import Control.Monad.Writer
import Data.List (foldl', drop)
import Data.Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr

-- | Split a class (or type) application such as @C a b c@ into the name at
-- its head and its arguments in source order, e.g. @(''C, [a, b, c])@.
--
-- The head name is obtained with 'headOfType' once no further 'AppT' is
-- left on the left spine, so a bare @C@ yields @(''C, [])@.
classHeadToParams :: Type -> (Name, [Type])
classHeadToParams t = go [] t
  where
    -- Peel arguments off the left spine, consing each one onto the
    -- accumulator. The outermost (last) argument is peeled first, so the
    -- accumulator ends up in source order without a final 'reverse'.
    go :: [Type] -> Type -> (Name, [Type])
    go params = \case
      AppT f x -> go (x : params) f
      other -> (headOfType other, params)

-- | Marker used by 'skolemize': a rigid type variable @v@ is rewritten to
-- @Skolem v@ before instance lookup ('reifyInstancesWithRigids'), as a way
-- to trick GHC into not unifying that variable with other types.
--
-- This data family must remain empty (no instances) and must not be part of
-- the library's public API. NOTE(review): this Internal module has no export
-- list, so it does export 'Skolem'; confirm the public modules do not
-- re-export it.
data family Skolem :: k -> k

-- | Replace every occurrence of a rigid type variable @v@ by @Skolem v@.
--
-- Variables bound by an inner 'ForallT' shadow outer rigid variables of the
-- same name, so they are left alone in both that forall's context and its
-- body. Kind annotations ('SigT' kinds) are not rewritten, and node types
-- not matched below are returned unchanged.
skolemize :: Set Name -> Type -> Type
skolemize rigids = \case
  ForallT bndrs ctx body ->
    let inner = skolemize (rigids `Set.difference` Set.fromList (map tvName bndrs))
    in ForallT bndrs (map inner ctx) (inner body)
  AppT l r -> AppT (go l) (go r)
  SigT ty k -> SigT (go ty) k
  ty@(VarT v)
    | v `Set.member` rigids -> AppT (ConT ''Skolem) ty
    | otherwise -> ty
  InfixT l n r -> InfixT (go l) n (go r)
  UInfixT l n r -> UInfixT (go l) n (go r)
  ParensT ty -> ParensT (go ty)
  other -> other
  where
    go :: Type -> Type
    go = skolemize rigids

-- | Like 'reifyInstances', but every argument type is first passed through
-- 'skolemize', so the given rigid variables appear as @Skolem v@ during the
-- instance lookup.
reifyInstancesWithRigids :: Set Name -> Name -> [Type] -> Q [InstanceDec]
reifyInstancesWithRigids rigids cls tys =
  reifyInstances cls (skolemize rigids <$> tys)

-- | Determine the type variables which occur freely in a type.
--
-- Variables bound by a 'ForallT' are excluded from the free variables of
-- that forall's context and body. Parentheses and infix applications are
-- looked through. Kinds ('SigT' kinds, binder kinds) are not inspected, and
-- node types not matched below contribute no variables.
freeTypeVariables :: Type -> Set Name
freeTypeVariables = \case
  ForallT bndrs ctx t' ->
    Set.unions (map freeTypeVariables (t' : ctx))
      `Set.difference` Set.fromList (map tvName bndrs)
  AppT t1 t2 -> freeTypeVariables t1 `Set.union` freeTypeVariables t2
  SigT t _ -> freeTypeVariables t
  VarT n -> Set.singleton n
  InfixT t1 _ t2 -> freeTypeVariables t1 `Set.union` freeTypeVariables t2
  UInfixT t1 _ t2 -> freeTypeVariables t1 `Set.union` freeTypeVariables t2
  ParensT t -> freeTypeVariables t
  _ -> Set.empty

-- | Apply a substitution of type variables to a type.
--
-- Only 'VarT' nodes whose name is a key of the map are replaced; parentheses
-- and infix applications are traversed. Variables bound by a 'ForallT' are
-- removed from the substitution beneath that binder (in both its context and
-- its body). No renaming is performed, so a replacement type mentioning a
-- name bound by an inner @forall@ will be captured by it. Kind annotations
-- ('SigT' kinds) are not rewritten.
subst :: Map Name Type -> Type -> Type
subst s = go
  where
    go :: Type -> Type
    go = \case
      ForallT bndrs ctx t ->
        let s' :: Map Name Type
            s' = Map.withoutKeys s (Set.fromList (map tvName bndrs))
        in ForallT bndrs (map (subst s') ctx) (subst s' t)
      AppT t t' -> AppT (go t) (go t')
      SigT t k -> SigT (go t) k
      VarT n -> Map.findWithDefault (VarT n) n s
      InfixT t x t' -> InfixT (go t) x (go t')
      UInfixT t x t' -> UInfixT (go t) x (go t')
      ParensT t -> ParensT (go t)
      x -> x

-- | Invoke the deriver @f@ for the given class and declaration.
--
-- The 'Dec' argument is either:
--
-- * an 'InstanceD' prototype: its class must be @className@. The datatype is
--   found from the first class parameter (we assume the type we're deriving
--   for is always the first typeclass parameter, if there are multiple). The
--   user's instance head and context are kept, after applying the
--   substitution obtained by unifying them with the head we would have
--   generated. The context collected by @f@ is discarded.
--
-- * any other declaration: it is normalized with 'normalizeDec', and the
--   instance gets a generated head plus a context made of the datatype
--   context followed by whatever @f@ collected.
--
-- Fails in 'Q' on a class-name mismatch, on a prototype head without class
-- parameters, or on a datatype without type parameters.
deriveForDec
  :: Name
  -> (DatatypeInfo -> WriterT [Type] Q Dec)
  -> Dec
  -> Q [Dec]
deriveForDec className f = \case
  InstanceD overlaps instCxt instanceHead _ -> do
    let (givenClassName, params) = classHeadToParams instanceHead
    when (givenClassName /= className) $
      fail $
        "while deriving " ++ show className
          ++ ": wrong class name in prototype declaration: "
          ++ show givenClassName
    firstParam <- case params of
      p : _ -> pure p
      [] ->
        fail $
          "while deriving " ++ show className
            ++ ": prototype instance declaration has no class parameters"
    dataTypeInfo <- reifyDatatype (headOfType firstParam)
    generatedInstanceHead <- mkInstanceHead dataTypeInfo
    unifiedTypes <- unifyTypes [generatedInstanceHead, instanceHead]
    -- We are not using the generated context that we collect from f, instead
    -- relying on a correct instance head from the user.
    (dec, _) <- runWriterT (f dataTypeInfo)
    pure
      [ InstanceD
          overlaps
          (applySubstitution unifiedTypes instCxt)
          (applySubstitution unifiedTypes instanceHead)
          [dec]
      ]
  dataDec -> do
    dataTypeInfo <- normalizeDec dataDec
    instanceHead <- mkInstanceHead dataTypeInfo
    (dec, collectedCxt) <- runWriterT (f dataTypeInfo)
    pure
      [ InstanceD
          Nothing
          (datatypeContext dataTypeInfo ++ collectedCxt)
          instanceHead
          [dec]
      ]
  where
    -- Build @className (T a1 ... a(n-1))@ for a datatype @T a1 ... an@: the
    -- last instance type is dropped (presumably because the class ranges
    -- over a partially applied type constructor, e.g. @GEq f@). Failing
    -- here, in 'Q', aborts the splice; the previous list-monad 'fail'
    -- silently produced an empty argument list instead.
    mkInstanceHead :: DatatypeInfo -> Q Type
    mkInstanceHead info = case reverse (datatypeInstTypes info) of
      [] ->
        fail $
          "while deriving " ++ show className
            ++ ": not enough type parameters on "
            ++ show (datatypeName info)
      _ : initRev ->
        pure $
          AppT (ConT className) $
            foldl' AppT (ConT (datatypeName info)) (reverse initRev)

-- | Find the name at the head of a type, e.g. @Maybe@ for @Maybe a@ or
-- @(->)@ for @a -> b@.
--
-- Looks through @forall@s, type applications, kind signatures and
-- parentheses; for infix applications the operator is the head. Calls
-- 'error' with the pretty-printed type when no head name can be determined
-- (e.g. for type literals), instead of a bare pattern-match failure.
headOfType :: Type -> Name
headOfType = \case
  ForallT _ _ ty -> headOfType ty
  VarT name -> name
  ConT name -> name
  TupleT n -> tupleTypeName n
  ArrowT -> ''(->)
  ListT -> ''[]
  AppT t _ -> headOfType t
  SigT t _ -> headOfType t
  ParensT t -> headOfType t
  InfixT _ name _ -> name
  UInfixT _ name _ -> name
  ty -> error $ "headOfType: cannot determine the head of type " ++ pprint ty