module Idris.TypeSearch (
searchByType
, searchPred
, defaultScoreFunction
) where
import Idris.AbsSyntax (addUsingConstraints, getIState, implicit, putIState)
import Idris.AbsSyntaxTree (IState(idris_docstrings, idris_interfaces, idris_outputmode, tt_ctxt),
Idris, InterfaceInfo, OutputMode(..), PTerm,
defaultSyntax, eqTy, implicitAllowed,
interface_implementations, toplevel)
import Idris.Core.Evaluate (Context(definitions), Def(CaseOp, Function, TyDecl),
normaliseC)
import Idris.Core.TT hiding (score)
import Idris.Core.Unify (match_unify)
import Idris.Delaborate (delabTy)
import Idris.Docstrings (noDocs, overview)
import Idris.Elab.Type (elabType)
import Idris.IBC
import Idris.Imports (PkgName)
import Idris.Output (iPrintResult, iRenderError, iRenderOutput, iRenderResult,
iputStrLn, prettyDocumentedIst)
import Util.Pretty (Doc, annotate, char, text, vsep, (<>))
#if (MIN_VERSION_base(4,11,0))
import Prelude hiding (Semigroup(..), pred)
import qualified Prelude as S (Semigroup(..))
#else
import Prelude hiding (pred)
#endif
import Control.Applicative (Applicative(..), (<$>), (<*>), (<|>))
import Control.Arrow (first, second, (&&&), (***))
import Control.Monad (guard, when)
import Data.List (find, partition, (\\))
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe (catMaybes, fromMaybe, isJust, mapMaybe, maybeToList)
import Data.Monoid (Monoid(mappend, mempty))
import Data.Ord (comparing)
import qualified Data.PriorityQueue.FingerTree as Q
import Data.Set (Set)
import qualified Data.Set as S
import qualified Data.Text as T (isPrefixOf, pack)
import Data.Traversable (traverse)
-- | Run a type-directed search for definitions whose types match @pterm@,
-- printing at most 'numLimit' results (one line each: a match-quality
-- marker, a space, then the documented name).
--
-- The package indices in @pkgs@ are loaded first. The interpreter state
-- captured on entry is written back with 'putIState' at the end, so the
-- elaboration done here is discarded afterwards.
--
-- NOTE(review): the search, delaboration and docstring lookups all use @st@,
-- the state captured before 'loadPkgIndex' runs — confirm that definitions
-- from the loaded packages are actually visible to the search.
searchByType :: [PkgName] -> PTerm -> Idris ()
searchByType pkgs pterm = do
    st <- getIState
    when (not (null pkgs)) $
      iputStrLn ("Searching packages: " ++ showSep ", " (map show pkgs))
    mapM_ loadPkgIndex pkgs
    withUsing <- addUsingConstraints syn emptyFC pterm
    -- NOTE(review): this result is never used; the term without implicits
    -- bound is what gets elaborated below. Presumably elabType binds
    -- implicits itself — confirm before removing this call.
    _withImplicits <- implicit toplevel syn name withUsing
    queryTy <- elabType toplevel syn (fst noDocs) (snd noDocs) emptyFC [] name NoFC withUsing
    let hits = take numLimit (searchUsing searchPred st queryTy)
        docs = map (renderHit st) hits
    if null docs
      then iRenderError (text "No results found")
      else case idris_outputmode st of
        RawOutput _ -> do
          mapM_ iRenderOutput docs
          iPrintResult ""
        IdeMode _ _ -> iRenderResult (vsep docs)
    putIState st
  where
    numLimit = 50
    syn = defaultSyntax { implicitAllowed = True }
    name = sMN 0 "searchType"
    -- Render one hit: score marker, a space, then the name with its docs.
    renderHit st (n, theScore) =
      let docInfo = (n, delabTy st n, fmap (overview . fst) (lookupCtxtExact n (idris_docstrings st)))
      in displayScore theScore <> char ' ' <> prettyDocumentedIst st docInfo
-- | Apply a search predicate to every searchable definition in the context.
--
-- Both the query type and every candidate type are normalised with
-- 'normaliseC' first. A group of definitions is skipped when its outer
-- context key is "special" (see 'special'); definitions without a type
-- (per 'typeFromDef') are skipped too. Candidates are passed in ascending
-- key order (outer key, then inner key).
searchUsing :: (IState -> Type -> [(Name, Type)] -> [(Name, a)])
  -> IState -> Type -> [(Name, a)]
searchUsing pred istate ty = pred istate (normalise ty) candidates
  where
    ctxt = tt_ctxt istate
    normalise = normaliseC ctxt []
    candidates =
      [ (n, normalise defTy)
      | (key, defsHere) <- M.toAscList (definitions ctxt)
      , not (special key)
      , (n, def) <- M.toAscList defsHere
      , Just defTy <- [typeFromDef def]
      ]
    -- Names never offered as results: any 'SN', @default#@-prefixed names,
    -- and believe_me / really_believe_me, looking through namespaces.
    special :: Name -> Bool
    special nm = case nm of
      NS inner _ -> special inner
      SN _ -> True
      UN txt -> T.pack "default#" `T.isPrefixOf` txt
             || txt `elem` map T.pack ["believe_me", "really_believe_me"]
      _ -> False
-- | The default search predicate: rank candidates with 'matchTypesBulk',
-- which stops once the best remaining 'defaultScoreFunction' value exceeds
-- the cutoff of 100.
searchPred :: IState -> Type -> [(Name, Type)] -> [(Name, Score)]
searchPred istate ty1 = matchTypesBulk istate scoreCutoff ty1
  where
    scoreCutoff = 100
-- | The type of a context entry, if it has one.
--
-- Only the first tuple component (the 'Def') is inspected. 'Function',
-- 'TyDecl' and 'CaseOp' definitions carry a type; every other kind of
-- definition yields 'Nothing'.
typeFromDef :: (Def, r, i, b, c, d) -> Maybe Type
typeFromDef (def, _, _, _, _, _) = case def of
  Function ty _ -> Just ty
  TyDecl _ ty -> Just ty
  CaseOp _ ty _ _ _ _ -> Just ty
  _ -> Nothing
-- | Strip @Delayed@ wrappers from a type.
--
-- Any application @Delayed r ty@ is replaced by @ty@ (itself unLazied).
-- The traversal continues through binders, applications and projections;
-- every other term is returned unchanged.
unLazy :: Type -> Type
unLazy (App _ (App _ (P _ lazy _) _) inner)
  | lazy == sUN "Delayed" = unLazy inner
unLazy (Bind n b body) = Bind n (fmap unLazy b) (unLazy body)
unLazy (App s fn arg) = App s (unLazy fn) (unLazy arg)
unLazy (Proj t idx) = Proj (unLazy t) idx
unLazy other = other
-- | Invert the edges of a dependency DAG.
--
-- Each entry's set is replaced by the keys of the entries whose original
-- sets contain that entry's key: "who depends on me" instead of "what I
-- depend on". Entry order and payloads are preserved.
reverseDag :: Ord k => [((k, a), Set k)] -> [((k, a), Set k)]
reverseDag xs =
  [ (node, S.fromList [ k' | ((k', _), deps) <- xs, k `S.member` deps ])
  | (node@(k, _), _) <- xs ]
-- | Split a Pi-type into its arguments and its return type.
--
-- Returns, each in order of appearance:
--
-- * the arguments for which @removePred@ is false, each paired with the
--   set of bound variables its type mentions (see 'usedVars');
-- * the arguments for which @removePred@ is true;
-- * the type left after stripping all leading Pi binders.
computeDagP :: Ord n
  => (TT n -> Bool)
  -> TT n
  -> ([((n, TT n), Set n)], [(n, TT n)], TT n)
computeDagP removePred t = (map withDeps (reverse kept), reverse removed, retTy)
  where
    withDeps (n, ty) = ((n, ty), M.keysSet (usedVars ty))
    (kept, removed, retTy) = split [] [] t
    -- Both accumulators are built back-to-front and reversed above.
    split ks rs (Bind n (Pi _ _ ty _) sc)
      | removePred ty = split ks ((n, ty) : rs) sc
      | otherwise = split ((n, ty) : ks) rs sc
    split ks rs sc = (ks, rs, sc)
-- | Collect the bound variables ('P' 'Bound') occurring in a term.
--
-- Each variable maps to its type and an injectivity flag. The flag starts
-- 'True' and becomes 'False' once the occurrence sits in the argument of an
-- application whose function part is not injective (per 'isInjective').
-- A variable bound by a 'Bind' is removed from the result of its scope.
-- On duplicate keys the left operand of each 'M.union' wins.
--
-- Calls 'error' on de Bruijn indices ('V'): run 'vToP' first.
usedVars :: Ord n => TT n -> Map n (TT n, Bool)
usedVars = go True
  where
    go inj tm = case tm of
      P Bound n ty -> M.singleton n (ty, inj) `M.union` go inj ty
      Bind n binder body ->
        let fromBinder = case binder of
              Let _ ty val -> go inj ty `M.union` go inj val
              Guess ty val -> go inj ty `M.union` go inj val
              other -> go inj (binderTy other)
        in M.delete n (go inj body) `M.union` fromBinder
      App _ fn arg -> go inj fn `M.union` go (inj && isInjective fn) arg
      Proj t _ -> go inj t
      V _ -> error "unexpected! run vToP first"
      _ -> M.empty
-- | Remove every entry named @name@ from a DAG, and drop @name@ from the
-- dependency sets of all remaining entries. Order is preserved.
deleteFromDag :: Ord n => n -> [((n, TT n), (a, Set n))] -> [((n, TT n), (a, Set n))]
deleteFromDag name = foldr step []
  where
    step ((name2, ty), (ix, deps)) rest
      | name == name2 = rest
      | otherwise = ((name2, ty), (ix, S.delete name deps)) : rest
-- | Drop every argument with the given name, preserving order.
deleteFromArgList :: Ord n => n -> [(n, TT n)] -> [(n, TT n)]
deleteFromArgList n args = [ arg | arg@(n', _) <- args, n' /= n ]
-- | Counters for the asymmetric modifications made on one side of a match:
-- argument applications, interface applications (replacing a constraint
-- by an implementation) and interface introductions (constraints left out).
data AsymMods = Mods
  { argApp, interfaceApp, interfaceIntro :: !Int
  } deriving (Eq, Show)
-- | One value for each side of a match: the searched-for type on the left,
-- the candidate type on the right.
data Sided a = Sided
  { left, right :: !a
  } deriving (Eq, Show)

-- | Combine the two sides with a binary function (left argument first).
sided :: (a -> a -> b) -> Sided a -> b
sided combine s = case s of
  Sided l r -> combine l r

-- | Apply a function to both sides.
both :: (a -> b) -> Sided a -> Sided b
both g s = case s of
  Sided l r -> Sided (g l) (g r)
-- | How far apart the search judges two types to be (see
-- 'defaultScoreFunction' for how the parts are weighted).
data Score = Score
  { transposition :: !Int          -- ^ summed position distance between matched arguments
  , equalityFlips :: !Int          -- ^ number of equality types flipped
  , asymMods :: !(Sided AsymMods)  -- ^ asymmetric modifications on each side
  } deriving (Eq, Show)
-- | The one-character marker printed before each search result.
--
-- @=@ when neither side needed asymmetric modifications, @<@ when only the
-- right side did, @>@ when only the left side did, and an unannotated @_@
-- when both did. The first three carry an 'AnnSearchResult' annotation
-- with the corresponding 'Ordering'.
displayScore :: Score -> Doc OutputAnnotation
displayScore score =
  case (untouched l, untouched r) of
    (True, True) -> marked EQ "="
    (True, False) -> marked LT "<"
    (False, True) -> marked GT ">"
    (False, False) -> text "_"
  where
    Sided l r = asymMods score
    untouched m = argApp m + interfaceApp m + interfaceIntro m == 0
    marked ordr = annotate (AnnSearchResult ordr) . text
scoreCriterion :: Score -> Bool
scoreCriterion (Score _ _ amods) = not
( sided (&&) (both ((> 0) . argApp) amods)
|| sided (+) (both argApp amods) > 4
|| sided (||) (both (\(Mods _ tcApp tcIntro) -> tcApp > 3 || tcIntro > 3) amods)
)
-- | Collapse a 'Score' into one number; lower means a better match.
--
-- The transposition and equality-flip counts are added directly. The
-- left-side modification weight counts three times as much as the right
-- side's. A further penalty of 100 times the product of the two sides'
-- weights applies, which is zero unless both sides were modified.
defaultScoreFunction :: Score -> Int
defaultScoreFunction (Score trans eqFlip (Sided l r)) =
  trans + eqFlip + linearPenalty + upAndDowncastPenalty
  where
    linearWeight m = 3 * argApp m + 4 * interfaceApp m + 2 * interfaceIntro m
    crossWeight m = 2 * argApp m + interfaceApp m + interfaceIntro m
    linearPenalty = 3 * linearWeight l + linearWeight r
    upAndDowncastPenalty = 100 * (crossWeight l * crossWeight r)
-- | Scores are ordered by 'defaultScoreFunction'.
--
-- This ordering is coarser than the derived structural 'Eq': two
-- different scores can still compare as 'EQ'.
instance Ord Score where
  compare a b = compare (defaultScoreFunction a) (defaultScoreFunction b)
#if (MIN_VERSION_base(4,11,0))
-- Semigroup instances, compiled only for base >= 4.11 (see the guard).
-- Sided combines side-wise; the other two delegate to their mappend.
instance S.Semigroup a => S.Semigroup (Sided a) where
  Sided l1 r1 <> Sided l2 r2 = Sided (l1 S.<> l2) (r1 S.<> r2)

instance S.Semigroup AsymMods where
  x <> y = x `mappend` y

instance S.Semigroup Score where
  x <> y = x `mappend` y
#endif
-- | Side-wise combination.
instance Monoid a => Monoid (Sided a) where
  mempty = Sided mempty mempty
  mappend (Sided l1 r1) (Sided l2 r2) = Sided (mappend l1 l2) (mappend r1 r2)

-- | Counters add component-wise; all-zero is the identity.
instance Monoid AsymMods where
  mempty = Mods 0 0 0
  mappend (Mods a b c) (Mods a' b' c') = Mods (a + a') (b + b') (c + c')

-- | Counts add component-wise; modifications combine side-wise.
instance Monoid Score where
  mempty = Score 0 0 mempty
  mappend (Score t e m) (Score t' e' m') = Score (t + t') (e + e') (mappend m m')
-- | The ordinary arguments of a type. Each (name, type) entry carries an
-- index and a set of argument names (as built by makeDag inside
-- 'matchTypesBulk': the position, and the arguments that depend on it).
type ArgsDAG = [((Name, Type), (Int, Set Name))]

-- | Interface-constraint arguments as (name, type) pairs.
type Interfaces = [(Name, Type)]

-- | A partial match between the searched-for type (left side) and a
-- candidate type (right side).
data State = State
  { holes :: ![(Name, Type)]                            -- ^ names still open to unification
  , argsAndInterfaces :: !(Sided (ArgsDAG, Interfaces)) -- ^ arguments not yet matched, per side
  , score :: !Score                                     -- ^ cost accumulated so far
  , usedNames :: ![Name]                                -- ^ names already taken, for 'uniqueBinders'
  } deriving Show
-- | Apply a type transformation to every argument type and every
-- interface-constraint type on one side, leaving names and DAG metadata
-- untouched. Irrefutable patterns keep the laziness of the Control.Arrow
-- combinators this replaces.
modifyTypes :: (Type -> Type) -> (ArgsDAG, Interfaces) -> (ArgsDAG, Interfaces)
modifyTypes f ~(dag, interfaces) = (map onArg dag, map onInterface interfaces)
  where
    onArg ~(~(n, ty), meta) = ((n, f ty), meta)
    onInterface ~(n, ty) = (n, f ty)
-- | Look up the first argument with the given name in a DAG. The result is
-- its type together with its index, always wrapped in 'Just'.
findNameInArgsDAG :: Name -> ArgsDAG -> Maybe (Type, Maybe Int)
findNameInArgsDAG name dag =
  case find (\((n, _), _) -> name == n) dag of
    Nothing -> Nothing
    Just ~((_, ty), (pos, _)) -> Just (ty, Just pos)
-- | Look up a name among one side's unmatched arguments.
--
-- Ordinary arguments are searched first and report their DAG index.
-- Interface-constraint arguments are searched next. They have no index,
-- so they are returned with 'Nothing' as the position (which makes the
-- transposition cost computed in resolveUnis zero for them).
--
-- The interface fallback previously applied @<*> Nothing@, which made it
-- 'Nothing' for every input, so interface arguments were never found.
findName :: Name -> (ArgsDAG, Interfaces) -> Maybe (Type, Maybe Int)
findName n (args, interfaces) =
  findNameInArgsDAG n args <|> fmap (\ty -> (ty, Nothing)) (lookup n interfaces)
-- | Remove a name from one side: from the DAG (including every dependency
-- set there) and from the interface-constraint arguments.
deleteName :: Name -> (ArgsDAG, Interfaces) -> (ArgsDAG, Interfaces)
deleteName n (args, interfaces) = (deleteFromDag n args, deleteFromArgList n interfaces)
-- | Forget the error of a 'TC' computation: 'OK' becomes 'Just', 'Error'
-- becomes 'Nothing'.
tcToMaybe :: TC a -> Maybe a
tcToMaybe tc = case tc of
  OK x -> Just x
  Error _ -> Nothing
-- | Apply a type transformation to every argument type in a DAG, leaving
-- names and metadata untouched.
inArgTys :: (Type -> Type) -> ArgsDAG -> ArgsDAG
inArgTys f dag = map (first (second f)) dag
-- | Try to discharge the constraint @ty@ with an implementation of type
-- @tyTry@.
--
-- The implementation's return type is matched against @ty@. Its
-- non-interface arguments are the holes, and every one of them must be
-- solved. On success, the result is the implementation's own
-- interface-constraint arguments with all solutions substituted in.
interfaceUnify :: Ctxt InterfaceInfo -> Context -> Type -> Type -> Maybe [(Name, Type)]
interfaceUnify interfaceInfo ctxt ty tyTry = do
    solution <- tcToMaybe $ match_unify ctxt [] (ty, Nothing) (implRet, Nothing) [] holeNames []
    -- every hole must have been solved
    guard (null (holeNames \\ map fst solution))
    let substAll = foldr (.) id [ subst n t | (n, t) <- solution ]
    return [ (n, substAll t) | (n, t) <- constraintArgs ]
  where
    implTy = vToP tyTry
    implRet = getRetTy implTy
    (constraintArgs, otherArgs) = partition (isInterfaceArg interfaceInfo . snd) (getArgTys implTy)
    holeNames = map fst otherArgs
-- | Whether a type is headed by a type constructor that has an entry in
-- the interface context, i.e. whether it is an interface constraint.
isInterfaceArg :: Ctxt InterfaceInfo -> Type -> Bool
isInterfaceArg interfaceInfo ty = case fst (unApply ty) of
  P (TCon _ _) interfaceName _ -> not (null (lookupCtxt interfaceName interfaceInfo))
  _ -> False
-- | All sublists of a list, each keeping the original element order.
-- Sublists containing the head come before those that do not, so the
-- full list comes first and @[]@ comes last.
subsets :: [a] -> [[a]]
subsets = foldr (\x rest -> map (x :) rest ++ rest) [[]]
-- | Variants of a type obtained by optionally flipping equality types
-- @(=) tyL tyR valL valR@ into @(=) tyR tyL valR valL@, each paired with
-- the number of flips counted. The unflipped type is always first.
--
-- Arguments of an equality are not searched further. For a binder, only
-- flips inside the binder's type are counted.
flipEqualities :: Type -> [(Int, Type)]
flipEqualities t = case t of
  eq1@(App _ (App _ (App _ (App _ eq@(P _ eqty _) tyL) tyR) valL) valR)
    | eqty == eqTy -> [(0, eq1), (1, app (app (app (app eq tyR) tyL) valR) valL)]
  Bind n binder sc ->
    [ (fst (binderTy b') + j, Bind n (fmap snd b') sc')
    | b' <- traverse flipEqualities binder
    , (j, sc') <- flipEqualities sc ]
  App _ f x ->
    [ (i + j, app f' x')
    | (i, f') <- flipEqualities f
    , (j, x') <- flipEqualities x ]
  t' -> [(0, t')]
  where
    app = App Complete
-- | Match the searched-for type @type1@ against every candidate in @types@,
-- lazily producing the candidates that match together with their 'Score'.
--
-- Each candidate gets a priority queue of partial-match 'State's, seeded
-- with one start state per way of flipping equalities in the candidate.
-- An outer queue, keyed by the best score in each inner queue, decides
-- which candidate to advance next. A candidate is emitted as soon as the
-- minimum of its queue is a final state, and dropped when its queue runs
-- dry. The whole search stops once the best remaining outer key has a
-- 'defaultScoreFunction' value greater than @maxScore@.
matchTypesBulk :: forall info. IState -> Int -> Type -> [(info, Type)] -> [(info, Score)]
matchTypesBulk istate maxScore type1 types = getAllResults startQueueOfQueues
  where
    ctxt = tt_ctxt istate
    interfaceInfo = idris_interfaces istate

    -- The searched-for type, split into argument DAG, interface
    -- constraints and return type.
    (dag1, interfaceArgs1, retTy1) = makeDag type1
    argNames1 = map fst dag1

    -- Split a type (after removing Delayed and converting de Bruijn
    -- indices) and number its arguments with their position.
    makeDag :: Type -> (ArgsDAG, Interfaces, Type)
    makeDag = first3 (zipWith (\i (ty, deps) -> (ty, (i, deps))) [0..] . reverseDag) .
      computeDagP (isInterfaceArg interfaceInfo) . vToP . unLazy
    first3 f (a, b, c) = (f a, b, c)

    -- Build the inner queue for one candidate, or Nothing when no start
    -- state survives unifying the two return types.
    getStartQueues :: (info, Type) -> Maybe (Score, (info, Q.PQueue Score State))
    getStartQueues (info, type2) =
      case mapMaybe startStates ty2s of
        [] -> Nothing
        xs -> Just (minimum (map fst xs), (info, Q.fromList xs))
      where
        -- every combination of equality flips in the candidate's
        -- arguments and return type, with the total number of flips
        ty2s = (\(i, dag) (j, retTy) -> (i + j, dag, retTy))
          <$> flipEqualitiesDag dag2 <*> flipEqualities retTy2
        flipEqualitiesDag dag = case dag of
          [] -> [(0, [])]
          ((n, ty), (pos, deps)) : rest ->
            (\(i, ty') (j, rest') -> (i + j, ((n, ty'), (pos, deps)) : rest'))
              <$> flipEqualities ty <*> flipEqualitiesDag rest
        startStates (numEqFlips, sndDag, sndRetTy) = do
          state <- unifyQueue
                     (State startingHoles
                            (Sided (dag1, interfaceArgs1) (sndDag, interfaceArgs2))
                            (mempty { equalityFlips = numEqFlips })
                            usedns)
                     [(retTy1, sndRetTy)]
          return (score state, state)
        -- rename the candidate's binders apart from the query's arguments
        (dag2, interfaceArgs2, retTy2) = makeDag (uniqueBinders (map fst argNames1) type2)
        argNames2 = map fst dag2
        usedns = map fst startingHoles
        startingHoles = argNames1 ++ argNames2

    startQueueOfQueues :: Q.PQueue Score (info, Q.PQueue Score State)
    startQueueOfQueues = Q.fromList $ mapMaybe getStartQueues types

    getAllResults :: Q.PQueue Score (info, Q.PQueue Score State) -> [(info, Score)]
    getAllResults q = case Q.minViewWithKey q of
      Nothing -> []
      Just ((nextScore, (info, stateQ)), q') ->
        if defaultScoreFunction nextScore <= maxScore
          then case nextStepsQueue stateQ of
            Nothing -> getAllResults q'
            Just (Left stateQ') -> case Q.minViewWithKey stateQ' of
              Nothing -> getAllResults q'
              Just ((newQscore, _), _) -> getAllResults (Q.add newQscore (info, stateQ') q')
            Just (Right theScore) -> (info, theScore) : getAllResults q'
          else []

    -- Apply unification solutions to a state. Returns the new state plus
    -- further pairs of types that must be unified, or Nothing on failure.
    resolveUnis :: [(Name, Type)] -> State -> Maybe (State, [(Type, Type)])
    resolveUnis [] state = Just (state, [])
    -- A hole solved by another variable: the two arguments are matched
    -- with each other.
    resolveUnis ((name, term@(P Bound name2 _)) : xs) state
      | isJust findArgs = do
          ((ty1, ix1), (ty2, ix2)) <- findArgs
          (state'', queue) <- resolveUnis xs state'
          -- cost = distance between the two argument positions; interface
          -- arguments have no position and cost nothing
          let transScore = fromMaybe 0 (abs <$> ((-) <$> ix1 <*> ix2))
          return ( inScore (\s -> s { transposition = transposition s + transScore }) state''
                 , (ty1, ty2) : queue )
      where
        unresolved = argsAndInterfaces state
        inScore f stat = stat { score = f (score stat) }
        findArgs = ((,) <$> findName name (left unresolved) <*> findName name2 (right unresolved)) <|>
                   ((,) <$> findName name2 (left unresolved) <*> findName name (right unresolved))
        matchnames = [name, name2]
        deleteArgs = deleteName name . deleteName name2
        state' = state { holes = filter (not . (`elem` matchnames) . fst) (holes state)
                       , argsAndInterfaces = both (modifyTypes (subst name term) . deleteArgs) unresolved }
    -- A hole solved by an arbitrary term: counts as an argument application.
    resolveUnis ((name, term) : xs) state@(State hs unresolved _ _) =
      case both (findName name) unresolved of
        Sided Nothing Nothing -> Nothing
        Sided (Just _) (Just _) -> error "Idris internal error: TypeSearch.resolveUnis"
        oneOfEach -> first (addScore (both scoreFor oneOfEach)) <$> nextStep
      where
        scoreFor (Just _) = mempty { argApp = 1 }
        scoreFor Nothing = mempty { argApp = otherApplied }
        -- holes used by the solving term, split by the injectivity flag
        matchedVarMap = usedVars term
        bothT f (x, y) = (f x, f y)
        (injUsedVars, notInjUsedVars) = bothT M.keys . M.partition id
          . M.filterWithKey (\k _ -> k `elem` map fst hs) $ M.map snd matchedVarMap
        varsInTy = injUsedVars ++ notInjUsedVars
        toDelete = name : varsInTy
        deleteMany = foldr (.) id (map deleteName toDelete)
        otherApplied = length notInjUsedVars
        addScore additions theState =
          let s = score theState
          in theState { score = s { asymMods = asymMods s `mappend` additions } }
        state' = state { holes = filter (not . (`elem` toDelete) . fst) hs
                       , argsAndInterfaces = both (modifyTypes (subst name term) . deleteMany) (argsAndInterfaces state) }
        nextStep = resolveUnis xs state'

    -- Unify pending type pairs one by one, pruning by 'scoreCriterion'.
    unifyQueue :: State -> [(Type, Type)] -> Maybe State
    unifyQueue state [] = return state
    unifyQueue state ((ty1, ty2) : queue) = do
      res <- tcToMaybe $ match_unify ctxt
               [ (n, RigW, Pi RigW Nothing ty (TType (UVar [] 0))) | (n, ty) <- holes state ]
               (ty1, Nothing)
               (ty2, Nothing) [] (map fst $ holes state) []
      (state', queueAdditions) <- resolveUnis res state
      guard $ scoreCriterion (score state')
      unifyQueue state' (queue ++ queueAdditions)

    -- Types of the known implementations of the interface heading @ty@,
    -- normalised and with binders renamed apart from @usedns@.
    possInterfaceImplementations :: [Name] -> Type -> [Type]
    possInterfaceImplementations usedns ty = do
        interfaceName <- getInterfaceName clss
        interfaceDef <- lookupCtxt interfaceName interfaceInfo
        n <- interface_implementations interfaceDef
        def <- lookupCtxt (fst n) (definitions ctxt)
        nty <- normaliseC ctxt [] <$> maybeToList (typeFromDef def)
        return (vToP (uniqueBinders usedns nty))
      where
        (clss, _) = unApply ty
        getInterfaceName (P (TCon _ _) interfaceName _) = [interfaceName]
        getInterfaceName _ = []

    -- Advance one candidate's queue by one step.
    -- Nothing: the queue is exhausted. Right: its best state is final.
    -- Left: the updated queue.
    nextStepsQueue :: Q.PQueue Score State -> Maybe (Either (Q.PQueue Score State) Score)
    nextStepsQueue queue = do
        ((nextScore, next), rest) <- Q.minViewWithKey queue
        let additions
              | scoreCriterion nextScore = Q.fromList [ (score st, st) | st <- nextSteps next ]
              | otherwise = Q.empty
        Just $
          if isFinal next
            then Right nextScore
            else Left (Q.union rest additions)
      where
        isFinal (State [] (Sided ([], []) ([], [])) _ _) = True
        isFinal _ = False

    -- All successor states of a partial match, in stages: arguments are
    -- matched first, then interface constraints.
    nextSteps :: State -> [State]
    -- Only interface constraints remain: first try matching a constraint
    -- from each side; if none match, try replacing a constraint by one of
    -- its implementations.
    nextSteps (State [] unresolved@(Sided ([], c1) ([], c2)) scoreAcc usedns) =
        if null results3 then results4 else results3
      where
        results3 = catMaybes
          [ unifyQueue (State []
                              (Sided ([], deleteFromArgList n1 c1)
                                     ([], map (second subst2for1) (deleteFromArgList n2 c2)))
                              scoreAcc usedns)
                       [(ty1, ty2)]
          | (n1, ty1) <- c1, (n2, ty2) <- c2
          , let subst2for1 = psubst n2 (P Bound n1 ty1) ]
        results4 =
          [ State [] (both (\(cs, _, _) -> ([], cs)) sds)
                  (scoreAcc `mappend` Score 0 0 (both (\(_, amods, _) -> amods) sds))
                  (usedns ++ sided (++) (both (\(_, _, hs) -> hs) sds))
          | sds <- allMods ]
          where
            allMods = parallel defMod mods
            mods :: Sided [(Interfaces, AsymMods, [Name])]
            mods = both (implementationMods . snd) unresolved
            defMod :: Sided (Interfaces, AsymMods, [Name])
            defMod = both (\(_, cs) -> (cs, mempty, [])) unresolved
            -- modify exactly one side, keeping the other at its default
            parallel :: Sided a -> Sided [a] -> [Sided a]
            parallel (Sided l r) (Sided ls rs) = map (flip Sided r) ls ++ map (Sided l) rs
            implementationMods :: Interfaces -> [(Interfaces, AsymMods, [Name])]
            implementationMods interfaces =
              [ (newInterfaceArgs, mempty { interfaceApp = 1 }, newHoles)
              | (_, ty) <- interfaces
              , impl <- possInterfaceImplementations usedns ty
              , newInterfaceArgs <- maybeToList $ interfaceUnify interfaceInfo ctxt ty impl
              , let newHoles = map fst newInterfaceArgs ]
    -- Arguments remain: match an argument from each side that no other
    -- argument depends on.
    nextSteps (State hs (Sided (dagL, c1) (dagR, c2)) scoreAcc usedns) =
        concatMap takeSomeInterfaces results1
      where
        canBeFirst :: ArgsDAG -> [(Name, Type)]
        canBeFirst = map fst . filter (S.null . snd . snd)
        results1 = catMaybes
          [ unifyQueue (State (filter (not . (`elem` [n1, n2]) . fst) hs)
                              (Sided (deleteFromDag n1 dagL, c1)
                                     (inArgTys subst2for1 $ deleteFromDag n2 dagR, map (second subst2for1) c2))
                              scoreAcc usedns)
                       [(ty1, ty2)]
          | (n1, ty1) <- canBeFirst dagL, (n2, ty2) <- canBeFirst dagR
          , let subst2for1 = psubst n2 (P Bound n1 ty1) ]

    -- Once all arguments are matched, branch on every subset of interface
    -- constraints to keep on each side; the left-out constraints are
    -- counted as interface introductions (recorded on the opposite side).
    takeSomeInterfaces (State [] unresolved@(Sided ([], _) ([], _)) scoreAcc usedns) =
        map statesFromMods . prod $ both (interfaceMods . snd) unresolved
      where
        swap (Sided l r) = Sided r l
        statesFromMods :: Sided (Interfaces, AsymMods) -> State
        statesFromMods sides =
          let interfaces = both (\(c, _) -> ([], c)) sides
              mods = swap (both snd sides)
          in State [] interfaces (scoreAcc `mappend` (mempty { asymMods = mods })) usedns
        interfaceMods :: Interfaces -> [(Interfaces, AsymMods)]
        interfaceMods cs =
          let lcs = length cs
          in [ (cs', mempty { interfaceIntro = lcs - length cs' }) | cs' <- subsets cs ]
        prod :: Sided [a] -> [Sided a]
        prod (Sided ls rs) = [Sided l r | l <- ls, r <- rs]
    -- still have arguments to match
    takeSomeInterfaces s = [s]