{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase       #-}

module Ide.Plugin.Tactic.Context where

import           Bag
import           Control.Arrow
import           Control.Monad.Reader
import           Data.List
import           Data.Maybe (mapMaybe)
import           Data.Set (Set)
import qualified Data.Set as S
import           Development.IDE.GHC.Compat
import           Ide.Plugin.Tactic.GHC (tacticsThetaTy)
import           Ide.Plugin.Tactic.Machinery (methodHypothesis)
import           Ide.Plugin.Tactic.Types
import           OccName
import           TcRnTypes
import           TcType (substTy, tcSplitSigmaTy)
import           Unify (tcUnifyTy)
import Ide.Plugin.Tactic.FeatureSet (FeatureSet)


-- | Build the tactic 'Context'.
--
-- The functions currently being defined are taken verbatim from @locals@.
-- The module-level functions are recovered from the typechecked bindings in
-- @tcg@: every 'Id' reported by 'getFunBindId' is turned into a name/type
-- pair by 'splitId'. The given 'FeatureSet' is stored unchanged.
mkContext :: FeatureSet -> [(OccName, CType)] -> TcGblEnv -> Context
mkContext features locals tcg =
  Context
    { ctxDefiningFuncs = locals
    , ctxModuleFuncs   = moduleFuncs
    , ctxFeatureSet    = features
    }
  where
    -- Walk every located binding, strip its location, and collect the Ids it
    -- exports, preserving the order in which they appear in the bag.
    moduleFuncs =
      [ splitId ident
      | lbind <- bagToList (tcg_binds tcg)
      , ident <- getFunBindId (unLoc lbind)
      ]


------------------------------------------------------------------------------
-- | Find all of the class methods that exist from the givens in the context.
--
-- For each function in 'ctxDefiningFuncs' whose theta type can be recovered
-- by 'definedThetaType':
--
--   * its constraints are extracted with 'tacticsThetaTy';
--   * each constraint is expanded by 'methodHypothesis' (constraints for
--     which it yields 'Nothing' are dropped);
--   * the resulting entries are concatenated in order.
--
-- Finally, 'excludeForbiddenMethods' removes methods we never want to offer.
contextMethodHypothesis :: Context -> Hypothesis CType
contextMethodHypothesis ctx =
  Hypothesis $ excludeForbiddenMethods
    [ hy
    | (name, _)  <- ctxDefiningFuncs ctx
    , Just theta <- [definedThetaType ctx name]
    , predTy     <- tacticsThetaTy (unCType theta)
    , Just hys   <- [methodHypothesis predTy]
    , hy         <- hys
    ]


------------------------------------------------------------------------------
-- | Many operations are defined in typeclasses for performance reasons, rather
-- than being a true part of the class. This function filters out those, in
-- order to keep our hypothesis space small.
--
-- An entry is dropped when its 'hi_name' is in the forbidden set; the
-- relative order of the remaining entries is preserved.
excludeForbiddenMethods :: [HyInfo a] -> [HyInfo a]
excludeForbiddenMethods = filter isAllowed
  where
    isAllowed hy = hi_name hy `S.notMember` forbiddenMethods

    -- Method names (as variable OccNames) that should never be offered.
    forbiddenMethods :: Set OccName
    forbiddenMethods = S.fromList $ fmap mkVarOcc
      [ -- monadfail
        "fail"
      ]


------------------------------------------------------------------------------
-- | Given the name of a function that exists in 'ctxDefiningFuncs', get its
-- theta type.
--
-- The type recorded for @name@ in 'ctxDefiningFuncs' (bound as @mono@) is
-- unified against the body of the type recorded in 'ctxModuleFuncs' (bound
-- as @poly@). That body is what remains after 'tcSplitSigmaTy' strips the
-- outer foralls and constraints. The resulting substitution is then applied
-- to @poly@ with only its outer foralls removed, so the constraints survive
-- in instantiated form.
--
-- Returns 'Nothing' when @name@ is missing from either list, or when the
-- two types do not unify.
definedThetaType :: Context -> OccName -> Maybe CType
definedThetaType ctx name = do
  CType mono <- lookup name (ctxDefiningFuncs ctx)
  CType poly <- lookup name (ctxModuleFuncs ctx)
  let (_, _, polyTau) = tcSplitSigmaTy poly
      unquantified    = snd (splitForAllTys poly)
  subst <- tcUnifyTy polyTau mono
  pure (CType (substTy subst unquantified))


-- | Pair an 'Id' with its 'OccName' and its type (via 'idType'), wrapped in
-- 'CType'.
splitId :: Id -> (OccName, CType)
splitId ident = (occName ident, CType (idType ident))


-- | Collect the polymorphic 'Id's exported by a typechecked binding.
--
-- Only 'AbsBinds' are inspected. Each 'ABE' export contributes its
-- polymorphic binder, and any other export constructor is skipped. Every
-- other kind of binding yields no 'Id's.
getFunBindId :: HsBindLR GhcTc GhcTc -> [Id]
getFunBindId = \case
  AbsBinds _ _ _ exports _ _ _ ->
    -- A failed 'ABE' pattern match simply skips that export.
    [ polyId | ABE _ polyId _ _ _ <- exports ]
  _ -> []


-- | Read the functions currently being defined ('ctxDefiningFuncs') from the
-- ambient 'Context'.
getCurrentDefinitions :: MonadReader Context m => m [(OccName, CType)]
getCurrentDefinitions = ctxDefiningFuncs <$> ask

-- | Read the module-level functions ('ctxModuleFuncs') from the ambient
-- 'Context'.
getModuleHypothesis :: MonadReader Context m => m [(OccName, CType)]
getModuleHypothesis = ctxModuleFuncs <$> ask