module Data.LargeHashable.TH (
deriveLargeHashable, deriveLargeHashableCtx, deriveLargeHashableNoCtx
, deriveLargeHashableCustomCtx
) where
import Data.LargeHashable.Class
import Language.Haskell.TH
import Foreign.C.Types (CULong (..))
import Control.Monad (forM)
-- | Derive a 'LargeHashable' instance for the type constructor or data
-- family named by the argument.
--
-- * A @data@ or @newtype@ declaration yields exactly one instance.
-- * A data family yields one instance for each @data instance@ /
--   @newtype instance@ declaration that 'reify' reports.
--
-- Anything else aborts the splice with the message built by
-- 'notDeriveAbleErrorMsg'.
deriveLargeHashable :: Name -> Q [Dec]
deriveLargeHashable n = do
    info <- reify n
    let unsupported :: Q [Dec]
        unsupported = fail (notDeriveAbleErrorMsg n info)

        -- Plain data type / newtype: its binders become the type variables
        -- of the instance head.
        fromTyCon tyDec =
            case tyDec of
#if MIN_VERSION_template_haskell(2,11,0)
              DataD context name tyvars _ cons _ ->
#else
              DataD context name tyvars cons _ ->
#endif
                  buildInstance (ConT name) context tyvars cons
#if MIN_VERSION_template_haskell(2,11,0)
              NewtypeD context name tyvars _ con _ ->
#else
              NewtypeD context name tyvars con _ ->
#endif
                  buildInstance (ConT name) context tyvars [con]
              _ -> unsupported

        -- Data/newtype family instance: the head is the family name applied
        -- to the instance's type arguments, and an empty binder list is
        -- passed. NOTE(review): free type variables inside those arguments
        -- therefore get no LargeHashable constraint.
        fromFamilyInstance instDec =
            case instDec of
#if MIN_VERSION_template_haskell(2,11,0)
              DataInstD context name types _ cons _ ->
#else
              DataInstD context name types cons _ ->
#endif
                  buildInstance (foldl AppT (ConT name) types) context [] cons
#if MIN_VERSION_template_haskell(2,11,0)
              NewtypeInstD context name types _ con _ ->
#else
              NewtypeInstD context name types con _ ->
#endif
                  buildInstance (foldl AppT (ConT name) types) context [] [con]
              _ -> unsupported
    case info of
      TyConI tyDec -> fromTyCon tyDec
      FamilyI _ instDecs -> concat <$> forM instDecs fromFamilyInstance
      _ -> unsupported
-- | Derive a 'LargeHashable' instance, keeping the context that
-- 'deriveLargeHashable' generates and appending the predicates that
-- @extraPreds@ builds from the instance's type arguments.
deriveLargeHashableCtx ::
     Name
  -> ([TypeQ] -> [PredQ])
  -> Q [Dec]
deriveLargeHashableCtx tyName extraPreds =
    deriveLargeHashableCustomCtx tyName
        (\typeArgs generatedCtx -> generatedCtx ++ extraPreds typeArgs)
-- | Derive a 'LargeHashable' instance whose context is empty: both the
-- generated context and the type arguments are ignored.
deriveLargeHashableNoCtx ::
     Name
  -> (Q [Dec])
deriveLargeHashableNoCtx tyName =
    deriveLargeHashableCustomCtx tyName noPredicates
  where
    noPredicates _typeArgs _generatedCtx = []
-- | Derive 'LargeHashable' like 'deriveLargeHashable', but let the caller
-- compute the instance context.
--
-- The second argument receives
--
--   1. the type arguments of the instance head, left to right (for
--      @LargeHashable (T a b)@ these are @[a, b]@), and
--   2. the context 'deriveLargeHashable' generated,
--
-- and returns the predicates that form the new context.
--
-- When the name refers to a data family, 'deriveLargeHashable' returns one
-- instance per family instance; every one of them is rewritten and returned.
deriveLargeHashableCustomCtx ::
     Name
  -> ([TypeQ] -> [PredQ] -> [PredQ])
  -> (Q [Dec])
deriveLargeHashableCustomCtx tyName extraPreds =
    deriveLargeHashable tyName >>= mapM withCustomCtx
  where
    -- Replace the context of one generated instance, keeping head and body.
#if MIN_VERSION_template_haskell(2,11,0)
    withCustomCtx (InstanceD overlap ctx ty body) =
        do newCtx <- customCtx ctx ty
           return (InstanceD overlap newCtx ty body)
#else
    withCustomCtx (InstanceD ctx ty body) =
        do newCtx <- customCtx ctx ty
           return (InstanceD newCtx ty body)
#endif
    -- 'deriveLargeHashable' only builds instance declarations, so anything
    -- else is a bug; 'fail' reports it as an ordinary splice error.
    withCustomCtx other =
        fail $
          "Unexpected declarations returned by deriveLargeHashable: "
            ++ show (ppr other)

    customCtx ctx ty =
        sequence (extraPreds (map return (instanceArgs ty)) (map return ctx))

    -- Type arguments of an instance head @C (T t1 .. tn)@ as @[t1, .., tn]@.
    instanceArgs :: Type -> [Type]
    instanceArgs (AppT _ applied) = reverse (collectArgs applied)
    instanceArgs _ = []

    -- Peel applications off @T t1 .. tn@, last argument first, giving
    -- @[tn, .., t1]@; the head @T@ itself is not included.
    collectArgs :: Type -> [Type]
    collectArgs (AppT l@(AppT _ _) r) = r : collectArgs l
    collectArgs (AppT _ r) = [r]
    collectArgs _ = []
-- | Error text used when 'deriveLargeHashable' cannot handle what 'reify'
-- returned; it includes the name and the shown 'Info'.
notDeriveAbleErrorMsg :: Name -> Info -> String
notDeriveAbleErrorMsg name info =
    concat
      [ "Could not derive LargeHashable instance for "
      , show name
      , "("
      , show info
      , "). If you think this should be possible, file an issue."
      ]
-- | Assemble @instance ctx => LargeHashable (basicType v1 .. vn)@.
--
-- The type variables in the binder list are applied to @basicType@, and the
-- context is built by 'makeConstraints'. The single method is 'updateHash',
-- with the constructors numbered from 0 in declaration order.
buildInstance :: Type -> Cxt -> [TyVarBndr] -> [Con] -> Q [Dec]
buildInstance basicType context vars cons = do
    inst <- instanceD (makeConstraints context vars) instanceHead [method]
    return [inst]
  where
    appliedType = foldl appT (return basicType) [varT (varName v) | v <- vars]
    instanceHead = conT ''LargeHashable `appT` appliedType
    method = updateHashDeclaration (zip [0 ..] cons)
-- | Build the 'updateHash' method from the numbered constructors, with one
-- clause per constructor (see 'updateHashClause').
--
-- A type without constructors still needs at least one equation, because
-- GHC rejects a spliced function binding that has none. Such types get a
-- single catch-all clause that hashes nothing. The other clauses chain
-- 'updateHash' results with '>>', so @return ()@ is the matching no-op
-- action.
updateHashDeclaration :: [(Integer, Con)] -> Q Dec
updateHashDeclaration [] =
    funD 'updateHash [clause [wildP] (normalB [| return () |]) []]
updateHashDeclaration consWIds =
    funD 'updateHash (map (uncurry updateHashClause) consWIds)
-- | Build the 'updateHash' clause for constructor number @i@.
--
-- The clause matches the constructor, binding every field to a variable
-- (see 'patternForCon'). It then hashes @i@ as a 'CULong', followed by each
-- field from left to right, with the actions chained by '>>'.
updateHashClause :: Integer -> Con -> Q Clause
updateHashClause i con = clause [return pat] (normalB body) []
  where
    pat = patternForCon con
    body = foldl sequenceExps hashTag (map hashField fieldVars)
    -- Hash the constructor index first so values built with different
    -- constructors feed different data into the hash.
    hashTag = [| updateHash ($(litE (IntegerL i)) :: CULong) |]
    hashField v = [| updateHash $(varE v) |]
    fieldVars =
        case pat of
          ConP _ subPats -> map (\(VarP v) -> v) subPats
          InfixP (VarP lhs) _ (VarP rhs) -> [lhs, rhs]
          _ -> error "Pattern in patVarNames not matched!"
-- | A pattern that matches the given constructor and binds every field to a
-- variable taken from 'names'. Infix constructors bind @x@ and @y@. For
-- constructors under a @forall@, the inner constructor's pattern is used.
patternForCon :: Con -> Pat
patternForCon con =
    case con of
      NormalC n fields -> conWithVars n (length fields)
      RecC n fields -> conWithVars n (length fields)
      InfixC _ n _ -> InfixP (VarP (mkName "x")) n (VarP (mkName "y"))
      ForallC _ _ inner -> patternForCon inner
#if MIN_VERSION_template_haskell(2,11,0)
      GadtC [n] fields _ -> conWithVars n (length fields)
      RecGadtC [n] fields _ -> conWithVars n (length fields)
#endif
  where
    conWithVars n count = ConP n (map (VarP . mkName) (take count names))
-- | Combine two generated monadic expressions into @first >> second@.
sequenceExps :: Q Exp -> Q Exp -> Q Exp
sequenceExps first second = infixE (Just first) thenOp (Just second)
  where
    thenOp = varE '(>>)
-- | The datatype's own context plus one constraint per type variable.
--
-- Variables annotated with kind @* -> *@ get 'LargeHashable''; every other
-- variable (unannotated or any other kind) gets 'LargeHashable'.
makeConstraints :: Cxt -> [TyVarBndr] -> Q Cxt
makeConstraints context vars =
    return (context ++ [constraintFor v | v <- vars])
  where
    constraintFor v = AppT (ConT (toLargeHashableClass v)) (VarT (varName v))
    toLargeHashableClass :: TyVarBndr -> Name
    toLargeHashableClass (KindedTV _ (AppT (AppT ArrowT StarT) StarT)) =
        ''LargeHashable'
    toLargeHashableClass _ = ''LargeHashable
-- | The name bound by a type-variable binder, whether or not it carries a
-- kind annotation.
varName :: TyVarBndr -> Name
varName binder =
    case binder of
      PlainTV n -> n
      KindedTV n _ -> n
-- | Infinite supply of distinct lowercase identifiers, used as pattern
-- variables. It starts with @a@ .. @z@, then all two-letter names, then all
-- three-letter names, and so on. Within one length the first letter varies
-- fastest (@aa@, @ba@, .., @za@, @ab@, ..).
names :: [String]
names = concat (iterate extend singles)
  where
    singles :: [String]
    singles = map (: []) ['a' .. 'z']
    -- Prefix every name of the previous length with each letter, so level
    -- k is built once from level k-1 instead of being recomputed from the
    -- single letters.
    extend :: [String] -> [String]
    extend shorter = concatMap (\suffix -> map (: suffix) ['a' .. 'z']) shorter