{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE DataKinds #-}

-- | Shared functions for dependent-sum-template
module Data.GADT.TH.Internal where

import Control.Monad
import Control.Monad.Writer
import qualified Data.Kind
import Data.List (foldl', drop)
import Data.Maybe
import Data.Map (Map)
import qualified Data.Map as Map
import qualified Data.Map.Merge.Lazy as Map
import Data.Set (Set)
import qualified Data.Set as Set
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
import Language.Haskell.TH.Datatype.TyVarBndr

-- | Split an applied type, such as a class or instance head, into the name at
-- its head and the arguments it is applied to, in left-to-right order.
--
-- For example, the type @C a b c@ yields @(C, [a, b, c])@.  A type that is
-- not an application yields its head name (as computed by 'headOfType')
-- paired with an empty argument list.
classHeadToParams :: Type -> (Name, [Type])
classHeadToParams = go []
  where
    -- Applications nest to the left, so the outermost 'AppT' carries the
    -- /last/ argument.  Peeling from the outside and prepending each argument
    -- to the accumulator therefore leaves the arguments in source order once
    -- the head is reached, with no reversal needed.
    go :: [Type] -> Type -> (Name, [Type])
    go args = \case
      AppT f x -> go (x : args) f
      t -> (headOfType t, args)

-- Do not export this type family, and never give it any instances: it must
-- remain empty. Because an empty type family can never reduce, GHC has to
-- treat @Skolem a@ as an opaque type. 'skolemize' wraps rigid type variables
-- in it as a way to trick GHC into not unifying those variables.
type family Skolem :: k -> k

-- | Wrap every occurrence of a rigid type variable in the 'Skolem' type
-- family.
--
-- The first argument is the set of variable names to treat as rigid.
-- Variables bound by a @forall@ are removed from that set while its body is
-- traversed.  The @forall@'s context and the kind of a kind-annotated type
-- are passed through unchanged, as is any type form not listed below.
skolemize :: Set Name -> Type -> Type
skolemize rigids = go
  where
    go :: Type -> Type
    go ty = case ty of
      ForallT binders ctx body ->
        let bound = Set.fromList (map tvName binders)
        in ForallT binders ctx (skolemize (rigids `Set.difference` bound) body)
      AppT fun arg -> AppT (go fun) (go arg)
      SigT inner kind -> SigT (go inner) kind
      VarT var
        | var `Set.member` rigids -> AppT (ConT ''Skolem) ty
        | otherwise -> ty
      InfixT lhs op rhs -> InfixT (go lhs) op (go rhs)
      UInfixT lhs op rhs -> UInfixT (go lhs) op (go rhs)
      ParensT inner -> ParensT (go inner)
      _ -> ty

-- | Probe whether instance reification misbehaves on this compiler.
--
-- A fresh type variable is made rigid and @Show@ instances are requested for
-- it; the result is 'True' when that query returns any instance at all.
reifyInstancesBroken :: Q Bool
reifyInstancesBroken = do
  probe <- newName "a"
  not . null <$> reifyInstancesWithRigids' (Set.singleton probe) ''Show [VarT probe]

-- | Look up instances of a class with 'reifyInstances', after first passing
-- every argument type through 'skolemize' with the given set of rigid
-- variables.  No check is made that the compiler handles this correctly;
-- see 'reifyInstancesWithRigids' for the checked variant.
reifyInstancesWithRigids' :: Set Name -> Name -> [Type] -> Q [InstanceDec]
reifyInstancesWithRigids' rigids cls = reifyInstances cls . map (skolemize rigids)

-- | Look up instances of a class while treating the given type variables as
-- rigid.
--
-- The query is only attempted after 'reifyInstancesBroken' reports that the
-- compiler behaves as expected; otherwise the splice fails with a message
-- pointing at the relevant GHC issue.
reifyInstancesWithRigids :: Set Name -> Name -> [Type] -> Q [InstanceDec]
reifyInstancesWithRigids rigids cls tys = do
  broken <- reifyInstancesBroken
  when broken $
    fail "Unsupported GHC version: 'reifyInstances' in this version of GHC returns instances when we expect an empty list. See https://gitlab.haskell.org/ghc/ghc/-/issues/23743"
  reifyInstancesWithRigids' rigids cls tys

-- | Determine the type variables which occur freely in a type.
--
-- Variables bound by a @forall@ are excluded from the variables of its body.
-- Infix applications (resolved or not) and parenthesised types are traversed
-- like ordinary applications, matching what 'skolemize' and 'subst' do.
--
-- The context of a @forall@ and the kind in a kind signature are not
-- inspected; any other type form contributes no variables.
freeTypeVariables :: Type -> Set Name
freeTypeVariables = \case
  ForallT bndrs _ body ->
    freeTypeVariables body `Set.difference` Set.fromList (map tvName bndrs)
  AppT t1 t2 -> freeTypeVariables t1 `Set.union` freeTypeVariables t2
  InfixT t1 _ t2 -> freeTypeVariables t1 `Set.union` freeTypeVariables t2
  UInfixT t1 _ t2 -> freeTypeVariables t1 `Set.union` freeTypeVariables t2
  SigT t _ -> freeTypeVariables t
  ParensT t -> freeTypeVariables t
  VarT n -> Set.singleton n
  _ -> Set.empty

-- | Apply a substitution of types for type variables throughout a type.
--
-- Each variable that has an entry in the map is replaced by the mapped type;
-- other variables are left alone.  Under a @forall@, the variables it binds
-- are dropped from the substitution before the body is processed, so bound
-- variables are never replaced.  The @forall@'s context and the kind in a
-- kind signature are passed through unchanged, as is any type form not
-- listed below.
subst :: Map Name Type -> Type -> Type
subst s = go
  where
    go :: Type -> Type
    go = \case
      ForallT bndrs cxt body ->
        -- Binders shadow any substitution entries of the same name.
        let s' = Map.withoutKeys s (Set.fromList (map tvName bndrs))
        in ForallT bndrs cxt (subst s' body)
      AppT t t' -> AppT (go t) (go t')
      SigT t k -> SigT (go t) k
      ParensT t -> ParensT (go t)
      VarT n -> Map.findWithDefault (VarT n) n s
      InfixT t op t' -> InfixT (go t) op (go t')
      UInfixT t op t' -> UInfixT (go t) op (go t')
      other -> other

-- | Invoke the deriver for the given class instance.  We assume that the type
-- we're deriving for is always the first typeclass parameter, if there are
-- multiple.
--
-- Arguments: the name of the class being derived, a function producing the
-- single declaration that forms the instance body (it may also emit
-- constraints through the writer), and the declaration to derive for.
--
-- Two kinds of declaration are accepted:
--
-- * An instance declaration, used as a prototype.  Its head must name the
--   given class.  The data type is found from the head of the first class
--   parameter and reified; the instance head generated from it is unified
--   with the prototype's head, and the resulting substitution is applied to
--   the prototype's context and head.  The constraints emitted by the deriver
--   and the prototype's own body declarations are both discarded.
--
-- * Any other declaration is handed to 'normalizeDec'.  The instance context
--   is the data type's context followed by the constraints emitted by the
--   deriver.
--
-- In both cases the generated head is the class applied to the data type
-- with its last type argument dropped.  The splice fails if the prototype
-- names a different class, if the prototype head has no parameters, or if
-- the data type has no type arguments to drop.
deriveForDec
  :: Name
  -> (DatatypeInfo -> WriterT [Type] Q Dec)
  -> Dec
  -> Q [Dec]
deriveForDec className f = \case
  InstanceD overlaps cxt instanceHead _ -> do
    firstParam <- case classHeadToParams instanceHead of
      (givenClassName, params)
        | givenClassName /= className ->
            fail $ "while deriving " ++ show className ++ ": wrong class name in prototype declaration: " ++ show givenClassName
        | p : _ <- params -> pure p
        | otherwise ->
            fail $ "while deriving " ++ show className ++ ": prototype declaration has no class parameters"
    dataTypeInfo <- reifyDatatype (headOfType firstParam)
    generatedInstanceHead <- mkInstanceHead dataTypeInfo
    unifiedTypes <- unifyTypes [generatedInstanceHead, instanceHead]
    let newInstanceHead = applySubstitution unifiedTypes instanceHead
        newContext = applySubstitution unifiedTypes cxt
    -- We are not using the generated context that we collect from f, instead
    -- relying on a correct instance head from the user
    (dec, _) <- runWriterT (f dataTypeInfo)
    pure [InstanceD overlaps newContext newInstanceHead [dec]]
  dataDec -> do
    dataTypeInfo <- normalizeDec dataDec
    instanceHead <- mkInstanceHead dataTypeInfo
    (dec, cxt') <- runWriterT (f dataTypeInfo)
    pure [InstanceD Nothing (datatypeContext dataTypeInfo ++ cxt') instanceHead [dec]]
  where
    -- Build @className (T a1 .. a(n-1))@ from a data type @T a1 .. an@.
    -- The emptiness check runs in 'Q' so that a data type without type
    -- arguments is reported as an error rather than silently producing a
    -- head with no arguments.
    mkInstanceHead :: DatatypeInfo -> Q Type
    mkInstanceHead info = case reverse (datatypeInstTypes info) of
      [] -> fail $ "while deriving " ++ show className ++ ": Not enough type parameters"
      _ : revInit ->
        pure $ AppT (ConT className) (foldl' AppT (ConT (datatypeName info)) (reverse revInit))

-- | Find the name at the head of a type.
--
-- Quantifiers, applications, kind signatures and parentheses are looked
-- through; a resolved infix application yields its operator name.  Type
-- variables and constructors yield their own name, and the built-in tuple,
-- function-arrow and list constructors yield the corresponding names.
--
-- Any other type form has no head that this function knows how to name and
-- results in a call to 'error' describing the offending type.
headOfType :: Type -> Name
headOfType = \case
  ForallT _ _ ty -> headOfType ty
  AppT t _ -> headOfType t
  SigT t _ -> headOfType t
  ParensT t -> headOfType t
  InfixT _ name _ -> name
  VarT name -> name
  ConT name -> name
  TupleT n -> tupleTypeName n
  ArrowT -> ''(->)
  ListT -> ''[]
  ty -> error $ "headOfType: cannot determine the head of type: " ++ show ty