{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE CPP #-}
module Control.Effect.Plugin.Fundep (fundepPlugin) where
import Control.Monad
import Data.Bifunctor
import Data.IORef
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Set as S
import Control.Effect.Plugin.Fundep.Unification
import Control.Effect.Plugin.Fundep.Utils
import TcEvidence
import TcPluginM (tcPluginIO, tcLookupClass)
import TcRnTypes
#if __GLASGOW_HASKELL__ >= 810
import Constraint
#endif
import TcSMonad hiding (tcLookupClass)
import Type
import GHC (Class, mkModuleName)
import GHC.TcPluginM.Extra (lookupName)
import OccName (mkTcOcc)
import Packages (lookupModuleWithSuggestions, LookupResult (..))
import Outputable (pprPanic, text, (<+>), ($$))
-- | Resolve the @Member@ class the plugin reasons about.
--
-- Looks up the module @Control.Effect.Internal.Membership@ in the
-- @in-other-words@ package and returns the class named @Member@ from it.
-- If the module lookup does not succeed (for example because
-- @in-other-words@ is not a dependency of the code being compiled), this
-- panics via 'pprPanic' with a message suggesting the fix.
getMemberClass :: TcPluginM Class
getMemberClass = do
  dflags <- unsafeTcPluginTcM getDynFlags
  case lookupModuleWithSuggestions dflags membershipModule (Just "in-other-words") of
    LookupFound md _ -> tcLookupClass =<< lookupName md (mkTcOcc "Member")
    _                -> pprPanic "in-other-words-plugin" missingPackageDoc
  where
    membershipModule = mkModuleName "Control.Effect.Internal.Membership"

    separator = text "--------------------------------------------------------------------------------"

    -- '<+>' binds tighter than '$$': the two sentences in parentheses are
    -- rendered on a single line.
    missingPackageDoc =
         text ""
      $$ separator
      $$ ( text "`in-other-words-plugin` is loaded, but"
           <+> text "`in-other-words` isn't available as a package." )
      $$ text "Probable fix: add `in-other-words` to your cabal `build-depends`"
      $$ separator
      $$ text ""
-- | The type-checker plugin.
--
-- Plugin state is a pair of:
--
--   * an 'IORef' holding the set of 'Unification's already emitted
--     (starts out empty), and
--   * the @Member@ class located by 'getMemberClass'.
--
-- Solving is delegated to 'solveFundep'; the stop action does nothing.
fundepPlugin :: TcPlugin
fundepPlugin = TcPlugin
  { tcPluginInit = do
      emittedRef <- tcPluginIO (newIORef S.empty)
      memberCls  <- getMemberClass
      pure (emittedRef, memberCls)
  , tcPluginSolve = solveFundep
  , tcPluginStop  = \_ -> pure ()
  }
-- | A @Member@ dictionary constraint, decomposed into the parts the solver
-- inspects. Values are produced by 'getMemberConstraints'.
--
-- The positional field order matters: other definitions in this module
-- pattern-match on the constructor positionally.
data MemberConstraint = MemberConstraint
  { mcLoc        :: CtLoc
    -- ^ Location of the original constraint (taken with 'ctLoc').
  , mcEffectName :: Type
    -- ^ 'mcEffect' with its type arguments stripped (see 'getEffName').
  , mcEffect     :: Type
    -- ^ The second of the constraint's three type arguments.
  , mcRow        :: Type
    -- ^ The third (last) of the constraint's three type arguments.
  }
-- | Select the constraints in a list that are canonical dictionary
-- constraints ('CDictCan') for the given class with exactly three type
-- arguments, and decompose each into a 'MemberConstraint'.
--
-- Constraints of any other shape, or for any other class, are skipped.
-- Input order is preserved.
getMemberConstraints :: Class -> [Ct] -> [MemberConstraint]
getMemberConstraints cls cts =
  [ MemberConstraint
      { mcLoc        = ctLoc ct
      , mcEffectName = getEffName eff
      , mcEffect     = eff
      , mcRow        = row
      }
  | ct@CDictCan{cc_class = cls', cc_tyargs = [_, eff, row]} <- cts
  , cls == cls'
  ]
-- | Find the effect in a list of (given) constraints that a wanted
-- constraint should be unified with.
--
-- A candidate matches when all of the following hold:
--
--   * its effect name is 'eqType' to the wanted one,
--   * its row is 'eqType' to the wanted one, and
--   * 'canUnifyRecursive' 'FunctionDef' accepts the wanted effect against
--     the candidate's effect.
--
-- The matching effects are handed to 'singleListToJust', which presumably
-- yields 'Just' only when exactly one candidate matched (confirm in Utils).
findMatchingEffectIfSingular
  :: MemberConstraint
  -> [MemberConstraint]
  -> Maybe Type
findMatchingEffectIfSingular (MemberConstraint _ effName wantedEff row) candidates =
  singleListToJust
    [ candEff
    | MemberConstraint _ candName candEff candRow <- candidates
    , eqType effName candName
    , eqType row candRow
    , canUnifyRecursive FunctionDef wantedEff candEff
    ]
-- | The head of a type application with all of its arguments removed,
-- i.e. the first component of 'splitAppTys'.
getEffName :: Type -> Type
getEffName = fst . splitAppTys
-- | Unconditionally create a new wanted nominal equality between the
-- constraint's effect ('mcEffect') and the supplied type.
--
-- The equality is created at the constraint's own location ('mcLoc').
-- The result pairs the 'Unification' record describing the pair of types
-- with the new wanted, wrapped as 'CNonCanonical'. The coercion returned by
-- 'newWantedEq' is not used.
mkWantedForce
  :: MemberConstraint
  -> Type
  -> TcPluginM (Unification, Ct)
mkWantedForce mc given = do
  (evidence, _coercion) <- unsafeTcPluginTcM (runTcSDeriveds newEquality)
  let unification = Unification (OrdType wanted) (OrdType given)
  pure (unification, CNonCanonical evidence)
  where
    wanted      = mcEffect mc
    newEquality = newWantedEq (mcLoc mc) Nominal wanted given
-- | Conditionally create the wanted equality produced by 'mkWantedForce'.
--
-- The equality is produced when the solve context does not require
-- unification ('mustUnify' is 'False'). It is also produced when
-- 'canUnifyRecursive' accepts the constraint's effect against the given
-- type. Otherwise 'whenA' skips it; presumably the result is then
-- 'Nothing' (confirm in Utils).
mkWanted
  :: MemberConstraint
  -> SolveContext
  -> Type
  -> TcPluginM (Maybe (Unification, Ct))
mkWanted mc solve_ctx given =
    whenA shouldEmit (mkWantedForce mc given)
  where
    wanted     = mcEffect mc
    shouldEmit = not (mustUnify solve_ctx)
              || canUnifyRecursive solve_ctx wanted given
-- | Compute the flag that 'solveFundep' wraps in 'InterpreterUse' for a
-- wanted constraint on row @r@.
--
-- NOTE: despite its name, this returns 'True' when @r@ occurs as the row
-- ('mcRow') of the given wanteds a number of times /other than/ one
-- (i.e. two or more). It returns 'False' when @r@ occurs exactly once or
-- not at all. The name is kept for compatibility with existing callers.
-- The result is passed on as the 'SolveContext' flag, which 'mkWanted'
-- inspects through 'mustUnify'.
--
-- Occurrences are tallied with a 'M.Map' keyed on the row. Equal rows are
-- therefore counted together wherever they appear in the input list, and
-- the result does not depend on the order of the wanteds.
exactlyOneWantedForR
  :: [MemberConstraint]  -- ^ all wanted @Member@ constraints
  -> Type                -- ^ the effect row to look up
  -> Bool
exactlyOneWantedForR wanteds = notExactlyOnce . OrdType
  where
    notExactlyOnce row = maybe False (/= 1) (M.lookup row rowCounts)

    -- Built once per partial application @exactlyOneWantedForR wanteds@.
    rowCounts :: M.Map OrdType Int
    rowCounts = M.fromListWith (+) [ (OrdType (mcRow mc), 1) | mc <- wanteds ]
-- | The plugin's solver.
--
-- For each wanted @Member@ constraint it picks a type that the
-- constraint's effect should be unified with:
--
--   1. If 'findMatchingEffectIfSingular' finds a matching given, a wanted
--      equality with that given's effect is created unconditionally.
--   2. Otherwise, if the row splits (via 'splitAppTys') into a head
--      applied to exactly three arguments, the middle argument is tried via
--      'mkWanted'. This is presumably the promoted cons @e ': rest@ with
--      its kind argument, making it the row's first effect; confirm.
--
-- The set of previously emitted 'Unification's is read from the 'IORef'
-- and passed to 'unzipNewWanteds' along with the new candidates. That
-- helper presumably drops already-emitted ones so the plugin does not loop.
-- The returned unifications are then added to the set. No constraints are
-- solved; only new wanteds are reported.
solveFundep
  :: ( IORef (S.Set Unification)
     , Class
     )
  -> [Ct]
  -> [Ct]
  -> [Ct]
  -> TcPluginM TcPluginResult
-- No wanteds: nothing to improve.
solveFundep _ _ _ [] = pure $ TcPluginOk [] []
solveFundep (ref, cls) given _ wanted = do
  let wanted_finds = getMemberConstraints cls wanted
      given_finds  = getMemberConstraints cls given

  eqs <- forM wanted_finds $ \mc -> do
    let r = mcRow mc
    case findMatchingEffectIfSingular mc given_finds of
      Just eff' -> Just <$> mkWantedForce mc eff'
      Nothing ->
        case splitAppTys r of
          (_, [_, eff', _]) ->
            mkWanted mc
                     (InterpreterUse $ exactlyOneWantedForR wanted_finds r)
                     eff'
          _ -> pure Nothing

  already_emitted <- tcPluginIO $ readIORef ref
  let (unifications, new_wanteds) =
        unzipNewWanteds already_emitted $ catMaybes eqs
  -- Strict update: the lazy 'modifyIORef' would stack an unevaluated
  -- 'S.union' thunk on every solver iteration whose result is never forced.
  tcPluginIO $ modifyIORef' ref $ S.union $ S.fromList unifications

  pure $ TcPluginOk [] new_wanteds