{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings  #-}
module Data.Constraint.Deriving.ClassDict
  ( ClassDict (..)
  , classDictPass
  , CorePluginEnvRef, initCorePluginEnv
  ) where

import Control.Monad (join, unless, when)
import Data.Data     (Data)
import Data.Maybe    (fromMaybe, isJust)


import Data.Constraint.Deriving.CorePluginM
import Data.Constraint.Deriving.Import


{- | A marker to tell the core plugin to replace the implementation of a
     top-level function by a corresponding class data constructor
       (wrapped into `Data.Constraint.Dict`).

     Example:

@

class BarClass a => FooClass a where
  fooFun1 :: a -> a -> Int
  fooFun2 :: a -> Bool

{\-\# ANN deriveFooClass ClassDict \#-\}
deriveFooClass :: forall a . BarClass a
               => (a -> a -> Int)
               -> (a -> Bool)
               -> Dict (FooClass a)
deriveFooClass = deriveFooClass
@

     That is, the plugin replaces the RHS of the @deriveFooClass@ function with
     `DataCon.classDataCon` wrapped into the `Dict` data constructor.

     Note:

     * The plugin requires you to create a dummy function `deriveFooClass` and
       annotate it with `ClassDict` instead of automatically creating this function
       for you; this way, the function is visible during type checking:
       you can use it in the same module (avoiding orphans) and you see its type signature.
     * You have to provide a correct signature for `deriveFooClass` function;
       the plugin compares this signature against visible classes and their constructors.
       An incorrect signature will result in a compile-time error.
     * The dummy implementation @deriveFooClass = deriveFooClass@ is used here to
       prevent GHC from inlining the function before the plugin can replace it.
       But you can implement it in any way you like at your own risk.
 -}
data ClassDict = ClassDict
  -- 'Data' is required to read the annotation back via 'getModuleAnns'.
  deriving (Eq, Show, Read, Data)

-- | Run `ClassDict` plugin pass.
--
--   Every top-level binding annotated with 'ClassDict' is rewritten by
--   'classDictPass''. If the whole pass fails ('runCorePluginM' yields
--   'Nothing'), the input 'ModGuts' are returned unchanged.
classDictPass :: CorePluginEnvRef -> CoreToDo
classDictPass eref = CoreDoPluginPass "Data.Constraint.Deriving.ClassDict" runPass
  where
    -- if a plugin pass totally fails to do anything useful,
    -- copy original ModGuts as its output, so that next passes can do their jobs.
    runPass guts = fromMaybe guts <$> runCorePluginM (classDictPass' guts) eref

-- | Rewrite all bindings of the module that carry a 'ClassDict' annotation.
--
--   Each annotated binding is processed by 'classDict'. If that fails, the
--   original binding is kept. Annotations that were never matched to a
--   binding are reported together in one warning at the end.
classDictPass' :: ModGuts -> CorePluginM ModGuts
classDictPass' guts = do
    (remAnns, processedBinds) <-
      runWithAnns (traverse go (mg_binds guts)) annotateds
    unless (isNullUFM remAnns) $
      pluginWarning $
        "One or more ClassDict annotations are ignored:"
          $+$ vcat (map pprBulletNameLoc . join $ eltsUFM remAnns)
          $$  "Note possible issues:"
          $$  pprNotes
                [ "ClassDict is meant to be used only on bindings of type Ctx => Dict (Class t1 .. tn)."
                , "GHC may remove the annotated definition if it is not reachable from module exports."
                ]
    return guts { mg_binds = processedBinds }
  where
    -- Names carrying a ClassDict annotation; looked up below by the
    -- unique of each binder.
    annotateds :: UniqMap [Name]
    annotateds = map fst <$> (getModuleAnns guts :: UniqMap [(Name, ClassDict)])

    go :: CoreBind -> WithAnns CoreBind
    go (NonRec b e) = NonRec b <$> classDict' b e
    go (Rec bes)    = Rec <$> traverse (\(b, e) -> (,) b <$> classDict' b e) bes

    -- one bullet line: "  • name location"
    pprBulletNameLoc n =
      hsep [" ", bullet, ppr (occName n), ppr (nameSrcSpan n)]

    pprNotes = vcat . map (\note -> hsep [" ", bullet, note])

    -- Process one binder. If it is annotated, consume its annotation(s) from
    -- the map and try to replace its RHS. Otherwise leave it untouched.
    classDict' x origBind = WithAnns $ \anns ->
      case lookupUFM anns (getUnique x) of
        Just (xn:xns) -> do
          unless (null xns) $
            pluginLocatedWarning (nameSrcSpan xn) $
              "Ignoring redundant ClassDict annotations" $$
              hcat
                [ "(the plugin needs only one annotation per binding, but got "
                , speakN (length xns + 1)
                , ")"
                ]
          -- add new definitions and continue; fall back to the original
          -- RHS when the rewrite fails
          newBind <- try (classDict x)
          return (delFromUFM anns (getUnique x), fromMaybe origBind newBind)
        _ -> return (anns, origBind)

-- | A small state transformer over 'CorePluginM' for tracking the annotations
--   that have not yet been consumed while traversing the bindings.
newtype WithAnns a = WithAnns
  { runWithAnns :: UniqMap [Name] -> CorePluginM (UniqMap [Name], a) }

-- | Map over the result; the annotation map is passed through untouched.
instance Functor WithAnns where
  fmap f (WithAnns run) = WithAnns $ \anns -> fmap f <$> run anns

-- | Sequence left to right: the annotation map left over by the first
--   computation is the input of the second.
instance Applicative WithAnns where
  pure x = WithAnns $ \anns -> pure (anns, x)
  WithAnns runF <*> WithAnns runX = WithAnns $ \anns0 -> do
    (anns1, f) <- runF anns0
    (anns2, x) <- runX anns1
    pure (anns2, f x)


-- | Replace a given CoreBind with a corresponding class DataCon fun implementation.
--
--   The core bind must have type `Ctx => Dict (Class t1 .. tn)`;
--   it does not change.
--
--   The new RHS is a lambda over the binding's type binders and over fresh
--   variables for all of its arguments. Its body applies the `Dict` data
--   constructor to the class data constructor, which is saturated with the
--   class arguments and those variables.
--
--   A located plugin error is raised in three cases:
--
--   * the result type is not `Dict` applied to one argument;
--   * that argument is not a class constraint;
--   * the binding's type cannot match the type of the class data constructor.
classDict :: CoreBndr -> CorePluginM CoreExpr
classDict bindVar = do

    -- get necessary definitions
    tcDict <- ask tyConDict
    let conDict = tyConSingleDataCon tcDict

    -- check that the outermost constructor of the result type is Dict
    -- and unwrap it.
    dictContentTy <- case splitTyConApp_maybe dictTy of
      Just (tcDict', [resTy])
        | tcDict' == tcDict -> pure resTy
      err -> pluginLocatedError loc $ vcat
        [ hsep ["Expected `Dict (Cls t1..tn)', but found", ppr dictTy]
        , if isJust err
            then "(constructor or number of arguments do not match)."
            else "(I could not split apart a constructor application)."
        , notGoodMsg
        ]

    -- check if the content of the result Dict is indeed a class constraint
    -- and get the class and its arguments.
    (klass, instanceArgs) <- case splitTyConApp_maybe dictContentTy of
      Just (klassTyCon, iArgs)
        | Just klas <- tyConClass_maybe klassTyCon
          -> pure (klas, iArgs)
        | otherwise
          -> pluginLocatedError loc $ vcat
               [ hsep ["Expected a class constructor, but found", ppr klassTyCon]
               , "(not a class data constructor)."
               , notGoodMsg
               ]
      Nothing -> pluginLocatedError loc $ vcat
        [ hsep ["Expected a class constructor, but found", ppr dictContentTy]
        , "(I could not split apart a constructor application)."
        , notGoodMsg
        ]

    -- the core of the plugin: use a class data constructor
    let klassDataCon = classDataCon klass

    -- check if the types agree: the type of the class data constructor
    -- with its result wrapped into Dict must match the user's signature
    let expectedType = mapResultType (mkTyConApp tcDict . (: []))
                     . idType $ dataConWorkId klassDataCon

    when (typesCantMatch [(origBindTy, expectedType)]) $
      pluginLocatedError loc $ vcat
        [ hsep
          [ "Cannot match the expected type (the type of the data constructor of the given class)"
          , "and the found type (the user-supplied binding)."]
        , hsep ["Expected type:", ppr expectedType]
        , hsep ["Found type:   ", ppr origBindTy]
        ]

    argVars <- traverse (`newLocalVar` "t") argTys
    return
      . mkCoreLams (bndrs ++ argVars)
      $ mkCoreConApps conDict
          [ mkTyArg dictContentTy
          , klassDataCon `mkCoreConApps`
              (map mkTyArg instanceArgs ++ varsToCoreExprs argVars)
          ]

  where
    origBindTy = idType bindVar
    -- type binders, then argument types (incl. constraints), then result
    (bndrs, bindTy) = splitForAllTys origBindTy
    (argTys, dictTy) = splitFunTysCompat bindTy
    loc = nameSrcSpan $ getName bindVar
    notGoodMsg =
         "ClassDict plugin pass failed to process a Dict declaration."
      $$ "The declaration must have form `forall a1..an . Ctx => Dict (Cls t1..tn)'"
      $$ "Declaration:"
      $$ hcat ["  ", ppr bindVar, " :: ", ppr origBindTy]

-- | Transform the result type in a more complex fun type.
--
--   Leading foralls and function arguments are rebuilt around the
--   transformed result. The function passed in is applied only to the
--   innermost result type. Each argument's visibility flag goes through
--   `mkConstraintInvis` (see the workaround note below).
mapResultType :: (Type -> Type) -> Type -> Type
mapResultType f t
  | (bndrs@(_:_), t') <- splitForAllTys t
    = mkSpecForAllTys bndrs $ mapResultType f t'
  -- Looks like `idType (dataConWorkId klassDataCon)` has constraints as visible arguments.
  -- I guess usually that does not change anything for the user, because they don't ever observe
  -- type signatures of class data constructors.
  -- This only pops up since 8.10 with the introduction of visibility arguments.
  -- The check below workarounds this.
  | Just (vis, m, at, rt) <- splitFunTyCompat t
    = mkFunTyCompat (mkConstraintInvis at vis) m at (mapResultType f rt)
  | otherwise
    = f t