{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedStrings #-}
module Data.Constraint.Deriving.ToInstance
( ToInstance (..)
, OverlapMode (..)
, toInstancePass
, CorePluginEnvRef, initCorePluginEnv
) where
import Class (Class, classTyCon)
import Control.Applicative (Alternative (..))
import Control.Monad (join, unless)
import Data.Data (Data)
import Data.Maybe (fromMaybe, isJust)
import Data.Monoid (First (..))
import GhcPlugins hiding (OverlapMode (..), overlapMode)
import qualified InstEnv
import qualified OccName
import Panic (panicDoc)
import qualified Unify
import Data.Constraint.Deriving.CorePluginM
-- | Annotation consumed by 'toInstancePass'.
--
--   Put it on a binding of type @Ctx => Dict (Cls t1 .. tn)@ to make the
--   pass derive an instance @Ctx => Cls t1 .. tn@ from that binding.
newtype ToInstance = ToInstance
  { overlapMode :: OverlapMode
    -- ^ Overlap mode given to the generated instance.
  }
  deriving (Data, Eq, Read, Show)
-- | Build the Core-to-Core pass that processes 'ToInstance' annotations.
--
--   When running the plugin computation yields no result, the module is
--   passed on unchanged.
toInstancePass :: CorePluginEnvRef -> CoreToDo
toInstancePass eref =
    CoreDoPluginPass "Data.Constraint.Deriving.ToInstance" runOn
  where
    runOn guts = fromMaybe guts <$> runCorePluginM (toInstancePass' guts) eref
-- | Body of the plugin pass: turn every non-recursive binding annotated with
--   'ToInstance' into a class instance.
--
--   The bindings are walked from last to first and consed back onto an
--   emptied 'mg_binds', so their original order is kept. For every binding
--   that 'toInstance' processes successfully the pass
--
--   * registers the new instance in 'mg_insts' and 'mg_inst_env';
--   * keeps the original binding and places the new dictionary-function
--     binding right after it;
--   * removes the annotated name from 'mg_exports'.
--
--   If 'toInstance' fails (caught by 'try'), the binding is kept unchanged.
--   Annotations that are never consumed (e.g. on recursive bindings) are
--   reported in one warning at the end.
toInstancePass' :: ModGuts -> CorePluginM ModGuts
toInstancePass' gs = go (reverse $ mg_binds gs) annotateds gs { mg_binds = [] }
  where
    -- ToInstance annotations of this module; looked up below by binder,
    -- presumably keyed by the annotated name's unique (see getModuleAnns).
    annotateds :: UniqFM [(Name, ToInstance)]
    annotateds = getModuleAnns gs

    go :: [CoreBind] -> UniqFM [(Name, ToInstance)] -> ModGuts -> CorePluginM ModGuts
    -- All bindings visited: anything left in the map has been ignored.
    go [] anns guts = do
      unless (isNullUFM anns) $
        pluginWarning $
          "One or more ToInstance annotations are ignored:"
          $+$ vcat (map (pprBulletNameLoc . fst) . join $ eltsUFM anns)
          $$ "Note possible issues:"
          $$ pprNotes
            [ "ToInstance is meant to be used only on bindings of type Ctx => Dict (Class t1 .. tn)."
            , "Currently, I process non-recursive bindings only."
            , sep
                [ "Non-exported bindings may vanish before the plugin pass:"
                , "make sure you export annotated definitions!"
                ]
            ]
      return guts
    -- An annotated non-recursive binding: only the first annotation is used.
    go (cbx@(NonRec x _) : xs) anns guts
      | Just ((xn, ti) : ds) <- lookupUFM anns x = do
          unless (null ds) $
            pluginLocatedWarning (nameSrcSpan xn) $
              "Ignoring redundant ToInstance annotations"
              $$ hcat
                [ "(the plugin needs only one annotation per binding, but got "
                , speakN (length ds + 1)
                , ")"
                ]
          let anns' = delFromUFM anns x
          try (toInstance ti cbx) >>= \case
            Nothing ->
              go xs anns' guts { mg_binds = cbx : mg_binds guts }
            Just (newInstance, newBind) ->
              go xs anns' guts
                { mg_insts = newInstance : mg_insts guts
                , mg_inst_env = InstEnv.extendInstEnv (mg_inst_env guts) newInstance
                , mg_binds = cbx : newBind : mg_binds guts
                , mg_exports = filterAvails (xn /=) $ mg_exports guts
                }
    -- Not annotated (or recursive): keep the binding as is.
    go (x : xs) anns guts = go xs anns guts { mg_binds = x : mg_binds guts }

    -- One bullet line with the name and its source location.
    pprBulletNameLoc n = hsep
      [" " , bullet, ppr $ occName n, ppr $ nameSrcSpan n]

    -- A bulleted list of notes.
    pprNotes = vcat . map (\x -> hsep [" ", bullet, x])
-- | Turn one annotated binding @f :: forall a1..an . Ctx => Dict (Cls t1..tn)@
--   into an instance @Ctx => Cls t1..tn@ together with its dictionary-function
--   binding (built by 'mkNewInstance').
--
--   Recursive binding groups are rejected with an error located at the first
--   binder of the group (or the current source span if the group is empty).
--   A binding whose type does not have the shape above is rejected with
--   'notGoodMsg'.
toInstance :: ToInstance -> CoreBind -> CorePluginM (InstEnv.ClsInst, CoreBind)
toInstance _ (Rec xs) = do
    loc <- liftCoreM getSrcSpanM
    pluginLocatedError
      (fromMaybe loc $ getFirst $ foldMap (First . Just . nameSrcSpan . getName . fst) xs)
      $ "ToInstance plugin pass does not support recursive bindings"
      $$ hsep ["(group:", pprQuotedList (map (getName . fst) xs), ")"]
toInstance (ToInstance omode) (NonRec bindVar bindExpr) = do
    -- Every argument before the Dict result must be a constraint.
    unless (all (isConstraintKind . typeKind) theta) $
      pluginLocatedError loc notGoodMsg
    tcBareConstraint <- ask tyConBareConstraint
    tcDict <- ask tyConDict
    fDictToBare <- ask funDictToBare
    -- Match the result type against `Dict c` to extract the constraint c.
    varCls <- newTyVar constraintKind
    let tyMatcher = mkTyConApp tcDict [mkTyVarTy varCls]
    match <- case Unify.tcMatchTy tyMatcher dictTy of
      Nothing -> pluginLocatedError loc notGoodMsg
      Just ma -> pure ma
    let matchedTy = substTyVar match varCls
        instSig = mkSpecForAllTys bndrs $ mkFunTys theta matchedTy
        bindBareTy = mkSpecForAllTys bndrs $ mkFunTys theta $ mkTyConApp tcBareConstraint [matchedTy]
    -- The extracted constraint must be a single class application.
    matchedClass <- case tyConAppTyCon_maybe matchedTy >>= tyConClass_maybe of
      Nothing -> pluginLocatedError loc notGoodMsg
      Just cl -> pure cl
    -- Replace the Dict value with its bare constraint (via funDictToBare),
    -- then unsafely cast the result from the BareConstraint-typed signature
    -- to the instance signature.
    mnewExpr <- try $ unwrapDictExpr dictTy fDictToBare bindExpr
    newExpr <- case mnewExpr of
      Nothing -> pluginLocatedError loc notGoodMsg
      Just ex -> pure $ mkCast ex $ mkUnsafeCo Representational bindBareTy instSig
    mkNewInstance omode matchedClass bindVar newExpr
  where
    origBindTy = idType bindVar
    (bndrs, bindTy) = splitForAllTys origBindTy
    (theta, dictTy) = splitFunTys bindTy
    loc = nameSrcSpan $ getName bindVar
    notGoodMsg =
      "ToInstance plugin pass failed to process a Dict declaration."
      $$ "The declaration must have form `forall a1..an . Ctx => Dict (Cls t1..tn)'"
      $$ "Declaration:"
      $$ hcat
        [ " "
        , ppr bindVar, " :: "
        , ppr origBindTy
        ]
      $$ ""
      $$ "Please check:"
      $$ vcat
        ( map (\s -> hsep [" ", bullet, s])
          [ "It must not have arguments (i.e. it is not a function, but a value);"
          , "It must have type Dict;"
          , "The argument of Dict must be a single class (e.g. no constraint tuples or equalities);"
          , "It must not have implicit arguments or any other complicated things."
          ]
        )
-- | Build a local class instance for @cls@ together with its dictionary
--   function binding.
--
--   The dictionary function is a fresh exported local id named after
--   @bindVar@ with the suffix @_ToInstance@, and its body is @bindExpr@.
--   The instance type variables and the class arguments are read off the
--   type of @bindExpr@ (@forall tvs . theta => head@). This panics if @head@
--   is not a type-constructor application.
mkNewInstance :: OverlapMode
              -> Class
              -> Id
              -> CoreExpr
              -> CorePluginM (InstEnv.ClsInst, CoreBind)
mkNewInstance omode cls bindVar bindExpr = do
    dfunName <- newName OccName.varName (getOccString bindVar ++ "_ToInstance")
    let dfunDetails = DFunId (isNewTyCon (classTyCon cls))
        dfunId = mkExportedLocalId dfunDetails dfunName dfunTy
        clsInst = InstEnv.mkLocalInstance dfunId (toOverlapFlag omode) instTvs cls instArgs
    pure (clsInst, NonRec dfunId bindExpr)
  where
    dfunTy = exprType bindExpr
    (instTvs, dfunRho) = splitForAllTys dfunTy
    instHead = snd (splitFunTys dfunRho)
    instArgs = fromMaybe noClassApp (tyConAppArgs_maybe instHead)
    noClassApp = panicDoc "ToInstance" $ hsep
      [ "Impossible happened:"
      , "expected a class constructor in mkNewInstance, but got"
      , ppr instHead
      , "at", ppr $ nameSrcSpan $ getName bindVar
      ]
-- | Rewrite an expression so that its sub-expression of type @dictT@ is
--   replaced by @unwrapFun \@t e@, where @t@ is the single type argument of
--   that sub-expression's type.
--
--   Search order:
--
--   * if the whole expression's type matches @dictT@, wrap it right away;
--   * otherwise descend into application arguments (then the function),
--     lambda, @let@ and tick bodies;
--   * a variable with leading type or value arguments is first eta-expanded
--     with fresh @theta1@, @theta2@, ... binders, then searched again;
--   * literals, @case@, casts, types and coercions fail with an error.
unwrapDictExpr :: Type
               -> Id
               -> CoreExpr
               -> CorePluginM CoreExpr
unwrapDictExpr dictT unwrapFun ex = case ex of
    Var _       -> wrapOr giveUp <|> (etaExpanded >>= recur)
    Lit _       -> wrapOr giveUp
    App fun arg -> wrapOr $ (App fun <$> recur arg) <|> ((`App` arg) <$> recur fun)
    Lam b body  -> wrapOr $ Lam b <$> recur body
    Let bs body -> wrapOr $ Let bs <$> recur body
    Case{}      -> wrapOr giveUp
    Cast{}      -> wrapOr giveUp
    Tick t body -> wrapOr $ Tick t <$> recur body
    Type{}      -> giveUp
    Coercion{}  -> giveUp
  where
    giveUp = pluginError
      $ "Failed to match a definition signature."
      $$ hang "Looking for a dictionary:" 2 (ppr dictT)
      $$ hang "Inspecting an expression:" 2
          (hsep [ppr ex, "::", ppr $ exprType ex])

    recur = unwrapDictExpr dictT unwrapFun

    -- Wrap ex when it already has the Dict type; otherwise run the fallback.
    wrapOr fallback
      | hasDictType ex = applyUnwrap ex
      | otherwise      = fallback

    hasDictType = isJust . Unify.tcMatchTy dictT . exprType

    -- unwrapFun @t e
    applyUnwrap e = (\t -> mkApps (Var unwrapFun) [t, e]) <$> classArg e

    -- The single type argument of e's type, as a Type expression.
    classArg e = case tyConAppArgs_maybe (exprType e) of
      Just [t] -> pure (Type t)
      _        -> giveUp

    -- \ @tvs theta1 .. thetaN -> ex @tvs theta1 .. thetaN
    etaExpanded
      | null tvs && null argTys = giveUp
      | otherwise = do
          thetaVars <- traverse freshTheta (zip [1 ..] argTys)
          let saturated = mkApps ex (map (Type . mkTyVarTy) tvs ++ map Var thetaVars)
          pure $ mkLams (tvs ++ thetaVars) saturated
      where
        (tvs, rho) = splitForAllTys (exprType ex)
        (argTys, _) = splitFunTys rho
        freshTheta (i, ty) = newLocalVar ty ("theta" ++ show (i :: Int))