{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE CPP #-}
module Polysemy.Plugin.Fundep (fundepPlugin) where
import Control.Monad
import Data.Bifunctor
import Data.Coerce
import Data.IORef
import qualified Data.Map as M
import Data.Maybe
import qualified Data.Set as S
import Polysemy.Plugin.Fundep.Stuff
import Polysemy.Plugin.Fundep.Unification
import Polysemy.Plugin.Fundep.Utils
import TcEvidence
import TcPluginM (TcPluginM, tcPluginIO)
import TcRnTypes
#if __GLASGOW_HASKELL__ >= 810
import Constraint
#endif
import TcSMonad hiding (tcLookupClass)
import Type
-- | Typechecker plugin that tries to fix the effect type of wanted
-- constraints of the class 'findClass' (looked up via 'polysemyStuff') by
-- emitting equality constraints; the solving logic lives in 'solveFundep'.
--
-- The plugin state is built once, in 'tcPluginInit':
--
--   * an 'IORef' holding the 'Unification's emitted so far, which
--     'solveFundep' consults (via 'unzipNewWanteds') and extends on every
--     call;
--   * the 'PolysemyStuff' (classes and tycons) the solver matches against.
fundepPlugin :: TcPlugin
fundepPlugin = TcPlugin
  { tcPluginInit  = initState
  , tcPluginSolve = solveFundep
  , tcPluginStop  = \_ -> pure ()
  }
  where
    initState :: TcPluginM (IORef (S.Set Unification), PolysemyStuff 'Things)
    initState = do
      emitted <- tcPluginIO $ newIORef S.empty
      stuff   <- polysemyStuff
      pure (emitted, stuff)
-- | A dictionary constraint of the class 'findClass', decomposed by
-- 'getFindConstraints' into the parts the solver works with.
data FindConstraint = FindConstraint
  { fcLoc :: CtLoc
    -- ^ Location of the original constraint; new wanted equalities derived
    -- from it are emitted at this location (see 'mkWantedForce').
  , fcEffectName :: Type
    -- ^ 'fcEffect' with its type arguments stripped (see 'getEffName').
  , fcEffect :: Type
    -- ^ The second type argument of the class constraint (the effect).
  , fcRow :: Type
    -- ^ The third type argument of the class constraint (the row).
  }
-- | Pick out the dictionary constraints of the class 'findClass' that have
-- exactly three type arguments, and decompose each into a 'FindConstraint'.
--
-- The second argument becomes the effect and the third becomes the row. The
-- first argument is ignored. Every other constraint is dropped.
getFindConstraints :: PolysemyStuff 'Things -> [Ct] -> [FindConstraint]
getFindConstraints stuff cts =
  [ FindConstraint
      { fcLoc        = ctLoc cd
      , fcEffectName = getEffName eff
      , fcEffect     = eff
      , fcRow        = r
      }
  | cd@CDictCan{cc_class = cls, cc_tyargs = [_, eff, r]} <- cts
  , target == cls
  ]
  where
    target = findClass stuff
-- | Find the effect of a given constraint that could satisfy a wanted one.
--
-- A given is a candidate when both of these are equal ('eqType') to the
-- wanted's:
--
--   * its effect name;
--   * its row.
--
-- It must also pass 'canUnifyRecursive' in the 'FunctionDef' context.
--
-- The candidate effects are handed to 'singleListToJust'. That presumably
-- yields a result only when there is exactly one candidate — confirm against
-- its definition.
findMatchingEffectIfSingular
  :: FindConstraint    -- ^ the wanted constraint
  -> [FindConstraint]  -- ^ constraints to search (the givens, at the call site)
  -> Maybe Type
findMatchingEffectIfSingular (FindConstraint _ name wantedEff row) givens =
  singleListToJust
    [ givenEff
    | FindConstraint _ name' givenEff row' <- givens
    , eqType name name'
    , eqType row row'
    , canUnifyRecursive FunctionDef wantedEff givenEff
    ]
-- | Strip the arguments from a type application and keep only its head; for
-- @T a b@ this gives @T@. Effects are compared by this head in
-- 'findMatchingEffectIfSingular'.
getEffName :: Type -> Type
getEffName = fst . splitAppTys
-- | Emit a new wanted nominal equality @fcEffect fc ~ given@ at the location
-- of @fc@, without checking whether the two types can unify.
--
-- Returns the equality as a non-canonical 'Ct'. It is paired with a
-- 'Unification' that records the (wanted, given) type pair, so the caller
-- can keep track of which equalities have already been emitted.
mkWantedForce
  :: FindConstraint
  -> Type
  -> TcPluginM (Unification, Ct)
mkWantedForce fc given = do
  let wanted = fcEffect fc
      record = Unification (OrdType wanted) (OrdType given)
  (evidence, _) <- unsafeTcPluginTcM . runTcSDeriveds $
    newWantedEq (fcLoc fc) Nominal wanted given
  pure (record, CNonCanonical evidence)
-- | Conditional form of 'mkWantedForce'.
--
-- When the 'SolveContext' does not require unification ('mustUnify' is
-- 'False'), the equality is always emitted. Otherwise it is emitted only if
-- 'canUnifyRecursive' accepts the wanted effect against @given@.
--
-- The 'Maybe' wrapper comes from 'whenA'; presumably it is 'Nothing' when the
-- equality is skipped — confirm against its definition.
mkWanted
  :: FindConstraint
  -> SolveContext
  -> Type
  -> TcPluginM (Maybe (Unification, Ct))
mkWanted fc solve_ctx given =
  whenA shouldEmit $ mkWantedForce fc given
  where
    unconditional = not (mustUnify solve_ctx)
    shouldEmit =
      unconditional || canUnifyRecursive solve_ctx (fcEffect fc) given
-- | Collect rows from irreducible wanteds.
--
-- This only looks at each 'CIrredCan' whose predicate, split with
-- 'splitAppTys', has exactly four arguments. Of the last two arguments,
-- every one that is an application of 'semTyCon' contributes its row (see
-- 'extractRowFromSem').
--
-- NOTE(review): the four-argument shape presumably targets stuck
-- (primitive) equalities between two 'semTyCon' applications — confirm.
getBogusRs :: PolysemyStuff 'Things -> [Ct] -> [Type]
getBogusRs stuff wanteds = do
  CIrredCan ev _ <- wanteds
  (_, [_, _, lhs, rhs]) <- [splitAppTys (ctev_pred ev)]
  mapMaybe (extractRowFromSem stuff) [lhs, rhs]
-- | If the type is an application of 'semTyCon' to exactly two arguments,
-- return the first one (the row). Otherwise return 'Nothing'.
extractRowFromSem :: PolysemyStuff 'Things -> Type -> Maybe Type
extractRowFromSem stuff ty =
  case splitTyConApp_maybe ty of
    Just (tc, [r, _]) | tc == semTyCon stuff -> Just r
    _ -> Nothing
-- | Report as solved the wanteds that are stuck on a row from 'getBogusRs'.
--
-- A wanted is picked when it is irreducible and has the shape
-- @S _ _ (L _ _ r) _ _@, where:
--
--   * @S@ is 'ifStuckTyCon' and @L@ is 'locateEffectTyCon';
--   * @r@ is one of the rows returned by 'getBogusRs'.
--
-- The evidence is an 'error' thunk, so it must never be forced.
-- NOTE(review): this presumably relies on the underlying irreducible wanted
-- still being reported, so the program is rejected before the evidence can
-- matter — confirm.
solveBogusError :: PolysemyStuff 'Things -> [Ct] -> [(EvTerm, Ct)]
solveBogusError stuff wanteds =
  [ (error "bogus proof for stuck type family", ct)
  | ct@(CIrredCan ev _) <- wanteds
  , Just (stuck, [_, _, expr, _, _]) <- [splitTyConApp_maybe (ctev_pred ev)]
  , stuck == ifStuckTyCon stuff
  , Just (idx, [_, _, r]) <- [splitTyConApp_maybe expr]
  , idx == locateEffectTyCon stuff
  , OrdType r `elem` bogusRows
  ]
  where
    -- Computed once and shared across all wanteds.
    bogusRows = OrdType <$> getBogusRs stuff wanteds
-- | Compute the flag passed to 'InterpreterUse' for a row.
--
-- The flag depends on how many of the given 'FindConstraint's have that row
-- as 'fcRow' (compared via 'OrdType'):
--
--   * 'True' when the row occurs a number of times other than one;
--   * 'False' when it occurs exactly once or not at all.
--
-- NOTE(review): despite its name, this returns the /negation/ of "exactly
-- one". The flag goes through 'InterpreterUse' to 'mkWanted', where a 'True'
-- from 'mustUnify' makes emission depend on 'canUnifyRecursive'. Confirm the
-- polarity is intended before renaming anything.
exactlyOneWantedForR
  :: [FindConstraint]
  -> Type
  -> Bool
exactlyOneWantedForR wanteds = \r ->
  maybe False (/= 1) $ M.lookup (OrdType r) rowCounts
  where
    -- Occurrence count of each distinct row, as reported by 'countLength'.
    -- It depends only on @wanteds@, so it is shared across rows.
    rowCounts = M.fromList . countLength $ OrdType . fcRow <$> wanteds
-- | The plugin's constraint solver.
--
-- For every wanted 'FindConstraint':
--
--   1. If 'findMatchingEffectIfSingular' finds a matching given, emit an
--      equality between the wanted effect and that given's effect,
--      unconditionally.
--   2. Otherwise, if the row splits into a head applied to three arguments
--      (presumably a promoted cons, @e ': r@), emit an equality with the
--      middle argument via 'mkWanted'. The 'InterpreterUse' context for this
--      comes from 'exactlyOneWantedForR'.
--
-- The emitted equalities are passed through 'unzipNewWanteds' together with
-- the set of already-emitted 'Unification's held in the state 'IORef'. That
-- presumably drops duplicates. The resulting unifications are then added to
-- the set. Wanteds matched by 'solveBogusError' are reported as solved.
--
-- With no wanteds, this returns immediately and does not touch the 'IORef'.
solveFundep
  :: ( IORef (S.Set Unification)
     , PolysemyStuff 'Things
     )
  -> [Ct]  -- ^ givens
  -> [Ct]  -- ^ deriveds (unused)
  -> [Ct]  -- ^ wanteds
  -> TcPluginM TcPluginResult
solveFundep _ _ _ [] = pure $ TcPluginOk [] []
solveFundep (ref, stuff) given _ wanted = do
  eqs <- traverse solveFind wanted_finds
  already_emitted <- tcPluginIO $ readIORef ref
  let (unifications, new_wanteds) =
        unzipNewWanteds already_emitted $ catMaybes eqs
  -- Strict update. The ref is created once in 'tcPluginInit' and updated on
  -- every solver call, so the lazy 'modifyIORef' could pile up a chain of
  -- unevaluated 'S.union' thunks whenever nothing else forces the set.
  tcPluginIO $ modifyIORef' ref $ S.union $ S.fromList unifications
  pure $ TcPluginOk (solveBogusError stuff wanted) new_wanteds
  where
    wanted_finds = getFindConstraints stuff wanted
    given_finds  = getFindConstraints stuff given

    -- Try to produce one new equality for a single wanted constraint.
    solveFind fc
      | Just eff' <- findMatchingEffectIfSingular fc given_finds
      = Just <$> mkWantedForce fc eff'
      | (_, [_, eff', _]) <- splitAppTys row
      = mkWanted fc
          (InterpreterUse $ exactlyOneWantedForR wanted_finds row)
          eff'
      | otherwise
      = pure Nothing
      where
        row = fcRow fc