{-# LANGUAGE CPP #-}
module Data.Deriving.Via.Internal where
#if MIN_VERSION_template_haskell(2,12,0)
import Control.Monad ((<=<), unless)
import Data.Deriving.Internal
import qualified Data.Map as M
import Data.Map (Map)
import Data.Maybe (catMaybes)
import Language.Haskell.TH
import Language.Haskell.TH.Datatype
-- | Build a single instance declaration for the class instance described by
-- the quoted type. No explicit representation type is supplied to
-- 'deriveViaDecs', so it is inferred from the data type named in the last
-- class argument.
--
-- The quoted type may carry an outer @forall@ and context; both are kept as
-- part of the generated instance head.
deriveGND :: Q Type -> Q [Dec]
deriveGND qty = do
  ty <- qty
  let (instanceTvbs, instanceCxt, instanceTy) = decomposeType ty
      expandTy = resolveTypeSynonyms <=< resolveInfixT
  -- Only the copy handed to 'deriveViaDecs' is expanded; the instance head
  -- below is built from the type exactly as it was quoted.
  expandedTy <- expandTy instanceTy
  decs <- deriveViaDecs expandedTy Nothing
  inst <- instanceD (return [])
                    (return (ForallT instanceTvbs instanceCxt instanceTy))
                    (map return decs)
  return [inst]
-- | Build a single instance declaration whose methods are coerced from an
-- explicitly named representation type.
--
-- After an optional outer @forall@ and context, the quoted type must be an
-- application of the @Via@ type constructor to exactly two arguments: the
-- instance head and the representation type. Any other shape makes the
-- splice fail with a usage message.
deriveVia :: Q Type -> Q [Dec]
deriveVia qty = do
  ty <- qty
  let (instanceTvbs, instanceCxt, viaApp) = decomposeType ty
  viaApp' <- (resolveTypeSynonyms <=< resolveInfixT) viaApp
  (instanceTy, viaTy) <- splitVia (unapplyTy viaApp')
  decs <- deriveViaDecs instanceTy (Just viaTy)
  inst <- instanceD (return [])
                    (return (ForallT instanceTvbs instanceCxt instanceTy))
                    (map return decs)
  return [inst]
  where
    -- Accept only a two-argument application headed by the @Via@ constructor.
    splitVia :: (Type, [Type]) -> Q (Type, Type)
    splitVia (hd, [lhs, rhs])
      | hd == ConT viaTypeName = return (lhs, rhs)
    splitVia _ = fail $ unlines
      [ "Failure to meet ‘deriveVia‘ specification"
      , "\tThe ‘Via‘ type must be used, e.g."
      , "\t[t| forall a. C (T a) `Via` V a |]"
      ]
-- | Produce the body declarations of a derived instance.
--
-- The instance head must be a class constructor applied to its arguments,
-- and the class must have at least one parameter. When no representation
-- type is given, the last class argument must be headed by a type
-- constructor that 'newtypeRepType' accepts; its field type is eta-reduced
-- according to the kind of the last class type variable and instantiated at
-- the arguments found in the instance head.
--
-- Every failure is reported through 'fail' in 'Q'.
deriveViaDecs :: Type        -- ^ Instance head: the class applied to its arguments
              -> Maybe Type  -- ^ Representation type; 'Nothing' means infer it
              -> Q [Dec]
deriveViaDecs instanceTy mbViaTy =
  case unapplyTy instanceTy of
    (clsTy@(ConT clsName), clsArgs) -> do
      clsInfo <- reify clsName
      case clsInfo of
        ClassI (ClassD _ _ clsTvbs _ clsDecs) _
          | Just (_, dataApp)    <- unsnoc clsArgs
          , Just (_, clsLastTvb) <- unsnoc clsTvbs
          -> do
            repTy <- maybe (newtypeRep clsLastTvb dataApp) return mbViaTy
            perDec <- traverse (deriveViaDecs' clsName clsTvbs clsArgs repTy) clsDecs
            return (concat (catMaybes perDec))
          | otherwise
          -> fail $ "Cannot derive instance for nullary class " ++ pprint clsTy
        _ -> fail $ "Not a type class: " ++ pprint clsTy
    _ -> fail $ "Malformed instance: " ++ pprint instanceTy
  where
    -- Infer the representation type from the data type in the last class
    -- argument.
    newtypeRep clsLastTvb dataApp =
      case unapplyTy dataApp of
        (ConT dataName, dataArgs) -> do
          info <- reifyDatatype dataName
          let -- One less than the number of components in the kind of the
              -- last class type variable.
              etaCount = length (snd (uncurryTy (tvbKind clsLastTvb))) - 1
              -- Maps the variables of the reified declaration to the
              -- arguments used in the instance head.
              subst = M.fromList
                [ (varTToName var, arg)
                | (var, arg) <- zip (datatypeInstTypes info) dataArgs
                ]
          case newtypeRepType (datatypeVariant info) (datatypeCons info) of
            Nothing -> fail $ "Not a newtype: " ++ nameBase dataName
            Just fullRepTy ->
              maybe (etaReductionError instanceTy)
                    (return . applySubstitution subst)
                    (etaReduce etaCount fullRepTy)
        (dataTy, _) -> fail $ "Not a data type: " ++ pprint dataTy
-- | Translate one declaration from a class body into the matching instance
-- declarations.
--
-- * An open associated type family becomes a type instance whose right-hand
--   side is the same family with the last class argument replaced by the
--   representation type.
-- * A method signature becomes a signature plus a binding that applies the
--   coercion function, at explicit source and target types, to the method.
-- * Any other declaration yields 'Nothing'.
--
-- Fails if the number of supplied class arguments differs from the number of
-- class type variables.
deriveViaDecs' :: Name -> [TyVarBndr] -> [Type] -> Type -> Dec -> Q (Maybe [Dec])
deriveViaDecs' clsName clsTvbs clsArgs repTy dec = do
  unless (expected == actual) $
    fail $ "Mismatched number of class arguments"
        ++ "\n\tThe class " ++ nameBase clsName ++ " expects " ++ show expected ++ " argument(s),"
        ++ "\n\tbut was provided " ++ show actual ++ " argument(s)."
  case dec of
    OpenTypeFamilyD (TypeFamilyHead tfName tfTvbs _ _) -> do
      let tfParams = map tvbToType tfTvbs
          instantiate subst = map (applySubstitution subst) tfParams
      tfInst <- tySynInstDCompat tfName Nothing
                  (map pure (instantiate instSubst))
                  (pure (applyTy (ConT tfName) (instantiate repSubst)))
      pure (Just [tfInst])
    SigD methName methTy ->
      let (fromTy, toTy) =
            mkCoerceClassMethEqn clsTvbs clsArgs repTy (stripOuterForallT methTy)
          coerced =
            AppE (AppTypeE (AppTypeE (VarE coerceValName)
                                     (stripOuterForallT fromTy))
                           (stripOuterForallT toTy))
                 (VarE methName)
      in return (Just [ SigD methName toTy
                      , ValD (VarP methName) (NormalB coerced) []
                      ])
    _ -> return Nothing
  where
    expected = length clsTvbs
    actual   = length clsArgs
    -- Class variables mapped to the arguments of the instance being defined.
    instSubst = zipTvbSubst clsTvbs clsArgs
    -- The same mapping with the last argument swapped for the representation type.
    repSubst  = zipTvbSubst clsTvbs (changeLast clsArgs repTy)
-- | Instantiate a class method type twice and return both results.
--
-- The first component uses the class arguments with the last one replaced by
-- the representation type; the second uses the class arguments unchanged.
mkCoerceClassMethEqn :: [TyVarBndr] -> [Type] -> Type -> Type -> (Type, Type)
mkCoerceClassMethEqn clsTvbs clsArgs repTy methTy =
  (instantiateAt repArgs, instantiateAt clsArgs)
  where
    repArgs = changeLast clsArgs repTy
    instantiateAt args = applySubstitution (zipTvbSubst clsTvbs args) methTy
-- | Pair the name of each type variable binder with the type at the same
-- position and collect the pairs into a map. Pairing stops at the end of the
-- shorter list.
zipTvbSubst :: [TyVarBndr] -> [Type] -> Map Name Type
zipTvbSubst tvbs tys =
  M.fromList [ (tvName tvb, ty) | (tvb, ty) <- zip tvbs tys ]
-- | Replace the final element of a list, leaving every earlier element in
-- place. Calls 'error' when the list is empty.
changeLast :: [a] -> a -> [a]
changeLast items replacement = go items
  where
    go []         = error "changeLast"
    go [_]        = [replacement]
    go (y : rest) = y : go rest
-- | Remove one outermost @forall@ (together with its context) from a type
-- when compiled with a GHC older than 8.7; on newer compilers the type is
-- returned unchanged.
--
-- NOTE(review): presumably older compilers reify class method signatures
-- with an extra outer quantifier that has to be discarded. Confirm against
-- the Template Haskell release notes.
stripOuterForallT :: Type -> Type
stripOuterForallT ty =
  case ty of
#if __GLASGOW_HASKELL__ < 807
    ForallT _ _ body -> body
#endif
    _ -> ty
-- | Split a type into its outermost quantified variables, its context and
-- the remaining body. A type with no outer @forall@ yields empty lists and
-- the type itself.
decomposeType :: Type -> ([TyVarBndr], Cxt, Type)
decomposeType ty =
  case ty of
    ForallT tvbs ctxt body -> (tvbs, ctxt, body)
    _                      -> ([], [], ty)
-- | Return the type of the single field of a newtype.
--
-- The result is 'Nothing' unless the variant is 'Newtype' or
-- 'NewtypeInstance' and there is exactly one constructor with no
-- constructor-bound variables, an empty context and exactly one field.
newtypeRepType :: DatatypeVariant -> [ConstructorInfo] -> Maybe Type
newtypeRepType dv cons
  | isNewtypeVariant dv
  , [ConstructorInfo { constructorVars    = []
                     , constructorContext = []
                     , constructorFields  = [repTy]
                     }] <- cons
  = Just repTy
  | otherwise
  = Nothing
  where
    isNewtypeVariant :: DatatypeVariant -> Bool
    isNewtypeVariant Newtype         = True
    isNewtypeVariant NewtypeInstance = True
    isNewtypeVariant _               = False
-- | Drop the last @num@ arguments from an applied type.
--
-- The type is split into its head and arguments; the trailing @num@
-- arguments are removed when 'canEtaReduce' (defined elsewhere) approves the
-- split, and the head is re-applied to the remaining arguments.
--
-- Returns 'Nothing' when 'canEtaReduce' rejects the split, or when the type
-- has fewer than @num@ arguments.
etaReduce :: Int -> Type -> Maybe Type
etaReduce num ty
  -- 'splitAt' treats a negative index as zero, so without this check a type
  -- with too few arguments would have all of them dropped (fewer than
  -- requested) and could still be reported as a successful reduction.
  | num > length tyArgs
  = Nothing
  | canEtaReduce tyArgsRemaining tyArgsDropped
  = Just $ applyTy tyHead tyArgsRemaining
  | otherwise
  = Nothing
  where
    (tyHead, tyArgs) = unapplyTy ty
    (tyArgsRemaining, tyArgsDropped) = splitAt (length tyArgs - num) tyArgs
#endif