{-# LANGUAGE CPP #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_HADDOCK show-extensions #-}
module GHC.TypeLits.Extra.Solver
( plugin )
where
import Control.Monad.Trans.Maybe (MaybeT (..))
import Data.Either (lefts)
import Data.Maybe (catMaybes)
import GHC.TcPluginM.Extra (evByFiat, lookupModule, lookupName,
tracePlugin)
import FastString (fsLit)
import Module (mkModuleName)
import OccName (mkTcOcc)
import Outputable (Outputable (..), (<+>), ($$), text)
import Plugins (Plugin (..), defaultPlugin)
#if MIN_VERSION_ghc(8,6,0)
import Plugins (purePlugin)
#endif
import TcEvidence (EvTerm)
import TcPluginM (TcPluginM, tcLookupTyCon, tcPluginTrace)
import TcRnTypes (TcPlugin(..), TcPluginResult (..))
import Type (Kind, eqType)
import TyCoRep (Type (..))
import TysWiredIn (typeNatKind, promotedTrueDataCon, promotedFalseDataCon)
import TcTypeNats (typeNatLeqTyCon)
#if MIN_VERSION_ghc(8,4,0)
import GHC.TcPluginM.Extra (flattenGivens)
import TcTypeNats (typeNatTyCons)
#else
import TcPluginM (zonkCt)
import Control.Monad ((<=<))
#endif
#if MIN_VERSION_ghc(8,10,0)
import Constraint (Ct, ctEvidence, ctEvPred, isWantedCt)
import Predicate (EqRel (NomEq), Pred (EqPred), classifyPredType)
import Type (typeKind)
#else
import TcRnTypes (Ct, ctEvidence, ctEvPred, isWantedCt)
import TcType (typeKind)
import Type (EqRel (NomEq), PredTree (EqPred), classifyPredType)
#endif
import GHC.TypeLits.Extra.Solver.Operations
import GHC.TypeLits.Extra.Solver.Unify
-- | Entry point of the type-checker plugin.
--
-- Installs 'normalisePlugin' as the constraint solver and ignores any
-- command-line options passed to the plugin. On GHC 8.6 and later the plugin
-- is declared pure via 'purePlugin' for GHC's recompilation checker.
plugin :: Plugin
plugin = defaultPlugin
  { tcPlugin        = \_opts -> Just normalisePlugin
#if MIN_VERSION_ghc(8,6,0)
  , pluginRecompile = purePlugin
#endif
  }
-- | The solver plugin, wrapped in 'tracePlugin' under the name
-- @"ghc-typelits-extra"@.
--
-- * Initialisation resolves the required type families ('lookupExtraDefs').
-- * Solving is delegated to 'decideEqualSOP'.
-- * Shutdown does nothing.
normalisePlugin :: TcPlugin
normalisePlugin =
  tracePlugin "ghc-typelits-extra" $
    TcPlugin
      { tcPluginInit  = lookupExtraDefs
      , tcPluginSolve = decideEqualSOP
      , tcPluginStop  = \_ -> return ()
      }
-- | Solve the constraints GHC hands to the plugin.
--
-- Only wanted constraints that 'toNatEquality' recognises are considered.
-- If there are none, the plugin reports no progress.
--
-- Otherwise the recognised givens are placed in front of the recognised
-- wanteds and the combined list goes to 'simplifyExtra'. On GHC >= 8.4 the
-- givens are extended with 'flattenGivens'; on older GHCs each given is
-- zonked first.
--
-- Evidence is returned only for wanted constraints. An 'Impossible' result is
-- reported to GHC as a contradiction on the offending constraint.
decideEqualSOP :: ExtraDefs -> [Ct] -> [Ct] -> [Ct] -> TcPluginM TcPluginResult
decideEqualSOP _ _givens _deriveds [] = return (TcPluginOk [] [])
decideEqualSOP defs givens _deriveds wanteds = do
    -- GHC 7.10.1 puts deriveds with the wanteds, so filter them out
    recognisedWanteds <- convert (filter isWantedCt wanteds)
    if null recognisedWanteds
      then return (TcPluginOk [] [])
      else do
#if MIN_VERSION_ghc(8,4,0)
        recognisedGivens <- convert (givens ++ flattenGivens givens)
#else
        recognisedGivens <- catMaybes <$> mapM ((runMaybeT . toNatEquality defs) <=< zonkCt) givens
#endif
        result <- simplifyExtra (recognisedGivens ++ recognisedWanteds)
        tcPluginTrace "normalised" (ppr result)
        return $ case result of
          Simplified evs -> TcPluginOk [ ev | ev@(_, ct) <- evs, isWantedCt ct ] []
          Impossible eq  -> TcPluginContradiction [fromNatEquality eq]
  where
    -- Keep only the constraints 'toNatEquality' understands.
    convert cts = catMaybes <$> mapM (runMaybeT . toNatEquality defs) cts
-- | An equality @u ~ v@ between two normalised naturals, paired with the
-- constraint it was extracted from.
type NatEquality = (Ct, ExtraOp, ExtraOp)

-- | A comparison @(u <=? v) ~ b@ between two normalised naturals, paired with
-- the constraint it was extracted from. The 'Bool' is the expected result @b@.
type NatInEquality = (Ct, ExtraOp, ExtraOp, Bool)
-- | Outcome of 'simplifyExtra'.
data SimplifyResult
  = Simplified [(EvTerm, Ct)]
    -- ^ evidence for each constraint that was solved, paired with it
  | Impossible (Either NatEquality NatInEquality)
    -- ^ a constraint the solver determined cannot hold
-- | Rendering used by the plugin's trace output.
instance Outputable SimplifyResult where
  ppr result = case result of
    Simplified evs -> text "Simplified" $$ ppr evs
    Impossible eq  -> text "Impossible" <+> ppr eq
-- | Walk the (in)equalities in order and try to discharge each one.
--
-- Equalities are handed to 'unifyExtra':
--
-- * 'Win' records evidence.
-- * 'Draw' skips the constraint.
-- * 'Lose' is reported as 'Impossible' only when nothing has been recorded
--   yet and the constraint is the last one in the list.
--
-- Inequalities @(u <=? v) ~ b@ are handled as follows:
--
-- * Two literals are decided directly; a mismatch is reported immediately.
-- * With @b@ 'True', the constraint is proven when @v@ is @Max u _@ or
--   @Max _ u@.
-- * With @b@ 'True' and @v@ a variable, the constraint is retried against the
--   @Max@ expression a non-wanted equality equates with that variable.
simplifyExtra :: [Either NatEquality NatInEquality] -> TcPluginM SimplifyResult
simplifyExtra eqs = do
    tcPluginTrace "simplifyExtra" (ppr eqs)
    loop [] eqs
  where
    -- Record a proof of the given constraint (when evidence can be built),
    -- then continue with the rest.
    proven acc ct rest = loop (((,) <$> evMagic ct <*> pure ct) : acc) rest

    loop :: [Maybe (EvTerm, Ct)] -> [Either NatEquality NatInEquality] -> TcPluginM SimplifyResult
    loop acc [] = return (Simplified (catMaybes acc))
    loop acc (eq@(Left (ct, u, v)) : rest) = do
      outcome <- unifyExtra ct u v
      tcPluginTrace "unifyExtra result" (ppr outcome)
      case outcome of
        Win  -> proven acc ct rest
        Draw -> loop acc rest
        Lose
          | null acc, null rest -> return (Impossible eq)
          | otherwise           -> loop acc rest
    loop acc (eq@(Right (ct, u, v, b)) : rest) = do
      tcPluginTrace "unifyExtra leq result" (ppr (u, v, b))
      case (u, v) of
        (I i, I j)
          | (i <= j) == b -> proven acc ct rest
          | otherwise     -> return (Impossible eq)
        (p, Max x y)
          | b, p == x || p == y -> proven acc ct rest
        (p, q@(V _))
          | b -> maybe (loop acc rest)
                       (\m -> loop acc (Right (ct, p, m, b) : rest))
                       (findMax q eqs)
        _ -> loop acc rest

    -- Find the first non-wanted equality of the form @c ~ Max ..@ or
    -- @Max .. ~ c@ and return its @Max@ side.
    findMax :: ExtraOp -> [Either NatEquality NatInEquality] -> Maybe ExtraOp
    findMax c = foldr pick Nothing . lefts
      where
        pick (ct, lhs, rhs) next
          | isWantedCt ct                 = next
          | Max _ _ <- rhs, c == lhs      = Just rhs
          | Max _ _ <- lhs, c == rhs      = Just lhs
          | otherwise                     = next
-- | Try to view a constraint as something this plugin can reason about.
--
-- * A nominal equality @t1 ~ t2@ where either side has kind 'Nat' becomes a
--   'NatEquality'.
-- * A nominal equality @(x <=? y) ~ 'True@ or @(x <=? y) ~ 'False@ becomes a
--   'NatInEquality'.
--
-- Both sides are normalised with 'normaliseNat'. Any other constraint, or a
-- side 'normaliseNat' yields no result for, gives 'Nothing' in the 'MaybeT'.
toNatEquality :: ExtraDefs -> Ct -> MaybeT TcPluginM (Either NatEquality NatInEquality)
toNatEquality defs ct = case classifyPredType (ctEvPred (ctEvidence ct)) of
    EqPred NomEq t1 t2
      | isNatKind (typeKind t1) || isNatKind (typeKind t2)
      -> do lhs <- normaliseNat defs t1
            rhs <- normaliseNat defs t2
            return (Left (ct, lhs, rhs))
      | TyConApp leq [x, y] <- t1
      , leq == typeNatLeqTyCon
      , TyConApp boolCon [] <- t2
      -> case promotedBool boolCon of
           Just expected -> do lhs <- normaliseNat defs x
                               rhs <- normaliseNat defs y
                               return (Right (ct, lhs, rhs, expected))
           Nothing       -> fail "Nothing"
    _ -> fail "Nothing"
  where
    isNatKind :: Kind -> Bool
    isNatKind k = k `eqType` typeNatKind

    -- Map the promoted 'True / 'False type constructors back to a 'Bool'.
    promotedBool tc
      | tc == promotedTrueDataCon  = Just True
      | tc == promotedFalseDataCon = Just False
      | otherwise                  = Nothing
-- | The constraint an (in)equality was extracted from.
fromNatEquality :: Either NatEquality NatInEquality -> Ct
fromNatEquality = either (\(ct, _, _) -> ct) (\(ct, _, _, _) -> ct)
-- | Resolve the type families the solver reasons about.
--
-- @Max@, @Min@, @FLog@, @CLog@, @Log@, @GCD@ and @LCM@ are looked up in
-- module "GHC.TypeLits.Extra" of package @ghc-typelits-extra@.
--
-- On GHC >= 8.4 the fields for @Div@ and @Mod@ are taken from GHC's
-- 'typeNatTyCons' at positions 5 and 6. On older GHCs they are looked up by
-- name in "GHC.TypeLits.Extra".
lookupExtraDefs :: TcPluginM ExtraDefs
lookupExtraDefs = do
    md <- lookupModule myModule myPackage
    let look s = tcLookupTyCon =<< lookupName md (mkTcOcc s)
    maxTc  <- look "Max"
    minTc  <- look "Min"
#if MIN_VERSION_ghc(8,4,0)
    -- NOTE(review): positional lookup into GHC's built-in list. This presumably
    -- relies on Div and Mod sitting at indices 5 and 6. If GHC reorders or
    -- shortens 'typeNatTyCons', it silently picks the wrong TyCon or crashes.
    -- Prefer named TyCons once the supported GHC range exports them.
    let divTc = typeNatTyCons !! 5
        modTc = typeNatTyCons !! 6
#else
    divTc  <- look "Div"
    modTc  <- look "Mod"
#endif
    flogTc <- look "FLog"
    clogTc <- look "CLog"
    logTc  <- look "Log"
    gcdTc  <- look "GCD"
    lcmTc  <- look "LCM"
    return (ExtraDefs maxTc minTc divTc modTc flogTc clogTc logTc gcdTc lcmTc)
  where
    myModule  = mkModuleName "GHC.TypeLits.Extra"
    myPackage = fsLit "ghc-typelits-extra"
-- | Evidence "by fiat" (via 'evByFiat', tagged @"ghc-typelits-extra"@) for a
-- nominal equality constraint @t1 ~ t2@. Any other kind of constraint gets no
-- evidence.
evMagic :: Ct -> Maybe EvTerm
evMagic ct
  | EqPred NomEq t1 t2 <- classifyPredType (ctEvPred (ctEvidence ct))
  = Just (evByFiat "ghc-typelits-extra" t1 t2)
  | otherwise
  = Nothing