{-# LANGUAGE PatternGuards #-}
module Idris.Erasure (performUsageAnalysis, mkFieldName) where
import Idris.AbsSyntax
import Idris.ASTUtils
import Idris.Core.CaseTree
import Idris.Core.Evaluate
import Idris.Core.TT
import Idris.Error
import Idris.Options
import Idris.Primitives
import Prelude hiding (id, (.))
import Control.Arrow
import Control.Category
import Control.Monad.State
import Data.IntMap (IntMap)
import qualified Data.IntMap as IM
import Data.IntSet (IntSet)
import qualified Data.IntSet as IS
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Set (Set)
import qualified Data.Set as S
import Data.Text (pack)
import qualified Data.Text as T
-- | For every name, the argument positions found to be used, each paired
-- with the set of reasons through which it was reached.
type UseMap = M.Map Name (IM.IntMap (S.Set Reason))
-- | A position of a definition: one of its numbered arguments, or its result.
data Arg = Arg Int | Result deriving (Eq, Ord)

-- | Arguments are rendered by their index alone; the result is rendered as @*@.
instance Show Arg where
    show position = case position of
        Arg idx -> show idx
        Result  -> "*"
-- | Why a node became used: a function name and the argument index
-- (reported as "arg#") through which the use was discovered.
type Reason = (Name, Int)

-- | A node of the dependency graph: a name together with one of its
-- argument positions or its result.
type Node = (Name, Arg)

-- | A conjunction of nodes; it holds once every one of them is used.
type Cond = Set Node

-- | A set of nodes, each annotated with the reasons it was reached.
type DepSet = Map Node (Set Reason)

-- | The dependency map: each condition maps to the nodes it makes used.
type Deps = Map Cond DepSet
-- | What the dependency analysis knows about a variable bound by a pattern.
data VarInfo = VI
    { viDeps   :: DepSet
      -- ^ nodes that become used whenever this variable is used
    , viFunArg :: Maybe Int
      -- ^ index of the enclosing function's argument this variable
      --   originates from (propagated into constructor fields), if any
    , viMethod :: Maybe Name
      -- ^ field name (see 'mkFieldName') when the variable was bound by
      --   matching on an implementation constructor
    }
    deriving Show

-- | The variables currently in scope, keyed by name.
type Vars = Map Name VarInfo
-- | Run the usage (erasure) analysis starting from the given names.
--
-- With no start names nothing is analysed and @[]@ is returned. Otherwise
-- this:
--
--   * builds the dependency map of everything reachable from the start names,
--   * computes the minimal usage by forward chaining,
--   * warns about reachable arguments recorded as inaccessible (only when
--     'WarnReach' is among the command-line options),
--   * fails via 'ifail' if any postulate is reachable,
--   * stores the used argument positions in the call graph,
--
-- and returns the reachable names.
performUsageAnalysis :: [Name] -> Idris [Name]
performUsageAnalysis startNames = case startNames of
    [] -> return []
    roots -> do
        ctx     <- tt_ctxt            <$> getIState
        ci      <- idris_interfaces   <$> getIState
        opt     <- idris_optimisation <$> getIState
        used    <- idris_erasureUsed  <$> getIState
        externs <- idris_externs      <$> getIState

        let depMap = buildDepMap ci used (S.toList externs) ctx roots
            (residDeps, (reachableNames, minUse)) = minimalUsage depMap
            usage = M.toList minUse

        logErasure 5 $ "Original deps:\n" ++ unlines (map fmtItem . M.toList $ depMap)
        logErasure 3 $ "Reachable names:\n" ++ unlines (map (indent . show) . S.toList $ reachableNames)
        logErasure 4 $ "Minimal usage:\n" ++ fmtUseMap usage
        logErasure 5 $ "Residual deps:\n" ++ unlines (map fmtItem . M.toList $ residDeps)

        -- Optionally report reachable arguments that were declared inaccessible.
        checkEnabled <- (WarnReach `elem`) . opt_cmdline . idris_options <$> getIState
        when checkEnabled $
            mapM_ (checkAccessibility opt) usage

        -- A reachable postulate is a hard error.
        reachablePostulates <- S.intersection reachableNames . idris_postulates <$> getIState
        unless (S.null reachablePostulates) $
            ifail ("reachable postulates:\n" ++ intercalate "\n" [" " ++ show n | n <- S.toList reachablePostulates])

        mapM_ storeUsage usage
        return $ S.toList reachableNames
  where
    indent = (" " ++)

    fmtItem :: (Cond, DepSet) -> String
    fmtItem (cond, deps) = indent $ show (S.toList cond) ++ " -> " ++ show (M.toList deps)

    fmtUseMap :: [(Name, IntMap (Set Reason))] -> String
    fmtUseMap = unlines . map (\(n, is) -> indent $ show n ++ " -> " ++ fmtIxs is)

    fmtIxs :: IntMap (Set Reason) -> String
    fmtIxs = intercalate ", " . map fmtArg . IM.toList
      where
        fmtArg (i, rs)
            | S.null rs = show i
            | otherwise = show i ++ " from " ++ intercalate ", " (map show $ S.toList rs)

    -- Record the used argument positions (with their reasons) for a name.
    storeUsage :: (Name, IntMap (Set Reason)) -> Idris ()
    storeUsage (n, args) = fputState (cg_usedpos . ist_callgraph n) flat
      where
        flat = [(i, S.toList rs) | (i, rs) <- IM.toList args]

    -- Warn (log level 0) about argument positions listed as inaccessible
    -- in the optimisation info of a name that nevertheless turned out used.
    checkAccessibility :: Ctxt OptInfo -> (Name, IntMap (Set Reason)) -> Idris ()
    checkAccessibility opt (n, reachable)
        | Just (Optimise inaccessible _ _) <- lookupCtxtExact n opt
        , eargs@(_:_) <- [ fmt argName (S.toList rs)
                         | (i, argName) <- inaccessible
                         , rs <- maybeToList $ IM.lookup i reachable
                         ]
        = warn $ show n ++ ": inaccessible arguments reachable:\n " ++ intercalate "\n " eargs
        | otherwise = return ()
      where
        fmt argName [] = show argName ++ " (no more information available)"
        fmt argName rs = show argName ++ " from " ++ intercalate ", " [show rn ++ " arg# " ++ show ri | (rn, ri) <- rs]
        warn = logErasure 0
-- | A single implication: once every node of the condition is used, all
-- nodes of the dependency set become used (see 'reduceConstraint').
type Constraint = (Cond, DepSet)
-- | Compute the minimal usage implied by a dependency map.
--
-- Constraints with an empty condition seed the solution, which is then
-- propagated with 'forwardChain'. Returns the residual (undischarged)
-- constraints as a dependency map, plus the set of names whose result is
-- used and the used argument positions per name.
minimalUsage :: Deps -> (Deps, (Set Name, UseMap))
minimalUsage deps = (fromNumbered residual, gather solution)
  where
    (residual, solution) = forwardChain (index numbered) seedDeps seedDeps numbered

    -- Give every constraint a stable integer key.
    numbered :: IntMap Constraint
    numbered = IM.fromList (zip [0 ..] (M.toList deps))

    -- Everything that holds unconditionally.
    seedDeps :: DepSet
    seedDeps = M.unionsWith S.union
        [ds | (cond, ds) <- IM.elems numbered, S.null cond]

    -- Turn the remaining numbered constraints back into a dependency map.
    fromNumbered :: IntMap Constraint -> Deps
    fromNumbered = IM.foldr
        (\(cond, ds) acc -> M.insertWith (M.unionWith S.union) cond ds acc)
        M.empty

    -- For each node, the keys of the constraints whose condition mentions it.
    index :: IntMap Constraint -> Map Node IntSet
    index constrs = M.fromListWith IS.union
        [ (node, IS.singleton key)
        | (key, (cond, _)) <- IM.toList constrs
        , node <- S.toList cond
        ]

    -- Split the solution into used results (names) and used arguments.
    gather :: DepSet -> (Set Name, UseMap)
    gather sol = (usedResults, usedArgs)
      where
        entries = M.toList sol
        usedResults = S.fromList [n | ((n, Result), _) <- entries]
        usedArgs = M.fromListWith (IM.unionWith S.union)
            [(n, IM.singleton i rs) | ((n, Arg i), rs) <- entries]
-- | Propagate newly used nodes through the constraints until a round
-- produces nothing new.
--
-- @index@ maps each node to the keys of the constraints whose condition
-- mentions it, @solution@ is everything found so far and @previouslyNew@
-- holds the nodes discovered in the previous round. Constraints are removed
-- as they fire; the result is the constraints left over together with the
-- final solution.
forwardChain
    :: Map Node IntSet
    -> DepSet
    -> DepSet
    -> IntMap Constraint
    -> (IntMap Constraint, DepSet)
forwardChain index solution previouslyNew constrs =
    if M.null fresh
        then (constrs, solution)
        else forwardChain index (M.unionWith S.union solution fresh) fresh remaining
  where
    -- Only constraints mentioning a freshly used node can make progress.
    touched = M.foldrWithKey
        (\node _ acc -> M.findWithDefault IS.empty node index `IS.union` acc)
        IS.empty
        previouslyNew

    (fresh, remaining) =
        IS.foldr (reduceConstraint (M.keysSet previouslyNew)) (M.empty, constrs) touched
-- | Remove the nodes of @previouslyNew@ from the condition of constraint @i@,
-- if that constraint still exists.
--
-- When the condition becomes empty, the constraint is deleted and its
-- dependencies are merged into the accumulated new nodes. When it only
-- shrinks, the smaller condition is stored. Otherwise the accumulator is
-- returned unchanged.
reduceConstraint
    :: Set Node
    -> Int
    -> (DepSet, IntMap (Cond, DepSet))
    -> (DepSet, IntMap (Cond, DepSet))
reduceConstraint previouslyNew i acc@(news, constrs) =
    case IM.lookup i constrs of
        Nothing -> acc
        Just (cond, deps)
            | S.null leftover -> (M.unionWith S.union news deps, IM.delete i constrs)
            | S.size leftover < S.size cond -> (news, IM.insert i (leftover, deps) constrs)
            | otherwise -> acc
          where
            leftover = cond S.\\ previouslyNew
-- | Build the dependency map for everything reachable from @startNames@.
--
-- The definitions in @ctx@ are traversed depth-first from @startNames@ and
-- the dependencies of every visited definition are merged. Then the
-- unconditional dependencies (empty condition) are added:
--
--   * the results of the start names,
--   * the argument positions listed in @used@,
--   * a fixed set of IO/primitive argument positions,
--   * every argument of each primitive mentioned in the map
--     (@prim__believe_me@ only gets argument 2),
--   * every argument of each entry of @externs@.
--
-- The interface context (first argument) is not consulted.
buildDepMap :: Ctxt InterfaceInfo -> [(Name, Int)] -> [(Name, Int)] ->
               Context -> [Name] -> Deps
buildDepMap _ used externs ctx startNames
    = addPostulates used $ dfs S.empty M.empty startNames
  where
    -- Merge two dependency maps, unioning reason sets of shared nodes.
    -- Kept local: an unqualified 'union' would be ambiguous with
    -- Data.List.union, which this module imports unqualified.
    unionDeps :: Deps -> Deps -> Deps
    unionDeps = M.unionWith (M.unionWith S.union)

    -- Add the dependencies that hold unconditionally.
    addPostulates :: [(Name, Int)] -> Deps -> Deps
    addPostulates usedArgs deps =
        foldr (\(ds, rs) -> M.insertWith (M.unionWith S.union) ds rs) deps postulates
      where
        ds ==> rs = (S.fromList ds, M.fromList [(r, S.empty) | r <- rs])
        it n is = [(sUN n, Arg i) | i <- is]

        -- prim__believe_me is excluded here and listed explicitly below.
        specialPrims = S.fromList [sUN "prim__believe_me"]
        usedNames = allNames deps S.\\ specialPrims
        usedPrims = [(p_name p, p_arity p) | p <- primitives, p_name p `S.member` usedNames]

        postulates =
            [ [] ==> concat
                [ map (\n -> (n, Result)) startNames
                , [(sUN "run__IO", Result), (sUN "run__IO", Arg 1)]
                , [(sUN "call__IO", Result), (sUN "call__IO", Arg 2)]
                , map (\(n, i) -> (n, Arg i)) usedArgs
                , it "MkIO" [2]
                , it "prim__IO" [1]
                , [(pairCon, Arg 2), (pairCon, Arg 3)]
                , it "prim_fork" [0]
                , it "unsafePerformPrimIO" [1]
                , it "prim__believe_me" [2]
                , [(n, Arg i) | (n, arity) <- usedPrims, i <- [0 .. arity - 1]]
                , [(n, Arg i) | (n, arity) <- externs, i <- [0 .. arity - 1]]
                ]
            ]

    -- Depth-first traversal collecting the dependencies of every definition
    -- reachable from the pending names; already visited names are skipped.
    dfs :: Set Name -> Deps -> [Name] -> Deps
    dfs _ deps [] = deps
    dfs visited deps (n : ns)
        | n `S.member` visited = dfs visited deps ns
        | otherwise = dfs (S.insert n visited) (unionDeps deps' deps) (next ++ ns)
      where
        next = [m | m <- S.toList depn, m `S.notMember` visited]
        depn = S.delete n $ allNames deps'
        deps' = getDeps n

    -- Every name mentioned in a dependency map, in conditions or targets.
    allNames :: Deps -> Set Name
    allNames = S.unions . map names . M.toList
      where
        names (cs, ns) = S.map fst cs `S.union` S.map fst (M.keysSet ns)

    -- Dependencies of a single definition.
    getDeps :: Name -> Deps
    -- Field names of implementation constructors (shaped as built by
    -- mkFieldName) are not looked up; methodDeps generates their deps.
    getDeps (SN (WhereN _ (SN (ImplementationCtorN _)) (MN _ _))) = M.empty
    getDeps n = case lookupDefExact n ctx of
        Just def -> getDepsDef n def
        Nothing  -> error $ "erasure checker: unknown reference: " ++ show n

    getDepsDef :: Name -> Def -> Deps
    getDepsDef _ (Function _ _) = error "a function encountered"
    getDepsDef _ (TyDecl _ _) = M.empty
    getDepsDef _ (Operator _ _ _) = M.empty
    getDepsDef fn (CaseOp _ _ tys _ _ cdefs)
        = getDepsSC fn etaVars (etaMap `M.union` varMap) sc
      where
        -- Arguments beyond the pattern variables get fresh "eta" names; the
        -- right-hand sides are applied to them by etaExpand.
        etaIdx = [length vars .. length tys - 1]
        etaVars = [eta i | i <- etaIdx]
        etaMap = M.fromList [varPair (eta i) i | i <- etaIdx]
        eta i = MN i (pack "eta")

        varMap = M.fromList [varPair v i | (v, i) <- zip vars [0 ..]]

        -- Using the variable means using argument argNo of fn.
        varPair n argNo = (n, VI
            { viDeps = M.singleton (fn, Arg argNo) S.empty
            , viFunArg = Just argNo
            , viMethod = Nothing
            })

        (vars, sc) = cases_runtime cdefs

    etaExpand :: [Name] -> Term -> Term
    etaExpand [] t = t
    etaExpand (n : ns) t = etaExpand ns (App Complete t (P Ref n Erased))

    getDepsSC :: Name -> [Name] -> Vars -> SC -> Deps
    getDepsSC _ _ _ ImpossibleCase = M.empty
    getDepsSC _ _ _ (UnmatchedCase _) = M.empty
    getDepsSC fn es vs (ProjCase (Proj t _) alts) = getDepsSC fn es vs (ProjCase t alts)
    getDepsSC fn es vs (ProjCase (P _ n _) alts) = getDepsSC fn es vs (Case Shared n alts)
    getDepsSC _ _ _ (ProjCase t alts) = error $ "ProjCase not supported:\n" ++ show (ProjCase t alts)
    getDepsSC fn es vs (STerm t) = getDepsTerm vs [] (S.singleton (fn, Result)) (etaExpand es t)
    getDepsSC fn es vs (Case _ n alts)
        = addTagDep $ unionMap (getDepsAlt fn es vs casedVar) alts
      where
        -- With more than one alternative, producing the result requires
        -- inspecting the cased variable.
        addTagDep = case alts of
            [_] -> id
            _   -> M.insertWith (M.unionWith S.union) (S.singleton (fn, Result)) (viDeps casedVar)
        casedVar = fromMaybe (error $ "nonpatvar in case: " ++ show n) (M.lookup n vs)

    getDepsAlt :: Name -> [Name] -> Vars -> VarInfo -> CaseAlt -> Deps
    getDepsAlt _ _ _ _ (FnCase _ _ _) = M.empty
    getDepsAlt fn es vs _ (ConstCase _ sc) = getDepsSC fn es vs sc
    getDepsAlt fn es vs _ (DefaultCase sc) = getDepsSC fn es vs sc
    getDepsAlt fn es vs var (SucCase n sc)
        = getDepsSC fn es (M.insert n var vs) sc
    getDepsAlt fn es vs var (ConCase n _ ns sc)
        = getDepsSC fn es (vs' `M.union` vs) sc
      where
        -- Using field j also uses the cased variable and field j of ctor n.
        vs' = M.fromList
            [ (v, VI
                { viDeps = M.insertWith S.union (n, Arg j) (S.singleton (fn, varIdx)) (viDeps var)
                , viFunArg = viFunArg var
                , viMethod = meth j
                })
            | (v, j) <- zip ns [0 ..]
            ]

        varIdx = fromJust (viFunArg var)

        meth :: Int -> Maybe Name
        meth j
            | SN (ImplementationCtorN _) <- n = Just (mkFieldName n j)
            | otherwise = Nothing

    getDepsTerm :: Vars -> [(Name, Cond -> Deps)] -> Cond -> Term -> Deps

    getDepsTerm vs bs cd (P _ n _)
        | Just deps <- lookup n bs
        = deps cd
        | Just var <- M.lookup n vs
        = M.singleton cd (viDeps var)
        | MN _ _ <- n
        = error $ "erasure analysis: variable " ++ show n ++ " unbound in " ++ show (S.toList cd)
        | otherwise = M.singleton cd (M.singleton (n, Result) S.empty)

    getDepsTerm _ bs cd (V i) = snd (bs !! i) cd

    -- Bound variables contribute nothing at the binder; let-bound values are
    -- analysed under the same condition as the body.
    getDepsTerm vs bs cd (Bind n bdr body)
        | Lam _ _ <- bdr = bodyDeps
        | Pi _ _ _ _ <- bdr = bodyDeps
        | Let _ _ t <- bdr = valueDeps t `unionDeps` bodyDeps
        | NLet _ t <- bdr = valueDeps t `unionDeps` bodyDeps
      where
        bodyDeps = getDepsTerm vs ((n, const M.empty) : bs) cd body
        valueDeps t = getDepsTerm vs bs cd t

    getDepsTerm vs bs cd app@(App _ _ _)
        | (fun, args) <- unApply app = case fun of
            P (DCon _ _ _) ctorName@(SN (ImplementationCtorN _)) _
                -> conditionalDeps ctorName args
                    `unionDeps` unionMap (methodDeps ctorName) (zip [0 ..] args)
            P (TCon _ _) _ _ -> unconditionalDeps args
            P (DCon _ _ _) n _ -> conditionalDeps n args
            P _ (UN n) _
                | n == T.pack "mkForeignPrim"
                -> unconditionalDeps $ drop 4 args
            P _ n _
                | Just deps <- lookup n bs
                -> deps cd `unionDeps` unconditionalDeps args
                | Just var <- M.lookup n vs
                , Just meth <- viMethod var
                -> viDeps var `ins` conditionalDeps meth args
                | Just var <- M.lookup n vs
                -> viDeps var `ins` unconditionalDeps args
                | otherwise
                -> conditionalDeps n args
            V i -> snd (bs !! i) cd `unionDeps` unconditionalDeps args
            Bind _ (Lam _ _) _ -> getDepsTerm vs bs cd (lamToLet app)
            Bind n (Let _ ty t') t -> getDepsTerm vs bs cd (App Complete (Bind n (Lam RigW ty) t) t')
            Bind n (NLet ty t') t -> getDepsTerm vs bs cd (App Complete (Bind n (Lam RigW ty) t) t')
            Proj t i
                -> error $ "cannot[0] analyse projection !" ++ show i ++ " of " ++ show t
            Erased -> M.empty
            _ -> error $ "cannot analyse application of " ++ show fun ++ " to " ++ show args
      where
        ins = M.insertWith (M.unionWith S.union) cd

        unconditionalDeps :: [Term] -> Deps
        unconditionalDeps = unionMap (getDepsTerm vs bs cd)

        -- The first (arity of n) arguments are only used if the matching
        -- argument position of n is used; any extra arguments unconditionally.
        conditionalDeps :: Name -> [Term] -> Deps
        conditionalDeps n
            = ins (M.singleton (n, Result) S.empty) . unionMap (getDepsArgs n) . zip indices
          where
            indices = map Just [0 .. getArity n - 1] ++ repeat Nothing
            getDepsArgs m (Just i, t) = getDepsTerm vs bs (S.insert (m, Arg i) cd) t
            getDepsArgs _ (Nothing, t) = getDepsTerm vs bs cd t

        -- Dependencies of one method of an implementation constructor,
        -- conditional on the corresponding field (see mkFieldName) being used.
        methodDeps :: Name -> (Int, Term) -> Deps
        methodDeps ctorName (methNo, t)
            = getDepsTerm (vars `M.union` vs) (bruijns ++ bs) cond body
          where
            vars = M.fromList
                [ (v, VI
                    { viDeps = deps i
                    , viFunArg = Just i
                    , viMethod = Nothing
                    })
                | (v, i) <- zip args [0 ..]
                ]
            deps i = M.singleton (metameth, Arg i) S.empty
            bruijns = reverse [(n, \c -> M.singleton c (deps i)) | (i, n) <- zip [0 ..] args]
            cond = S.singleton (metameth, Result)
            metameth = mkFieldName ctorName methNo
            (args, body) = unfoldLams t

    getDepsTerm vs bs cd (Proj t (-1)) = getDepsTerm vs bs cd t
    getDepsTerm _ _ _ (Proj t i) = error $ "cannot[1] analyse projection !" ++ show i ++ " of " ++ show t

    getDepsTerm vs bs cd (Inferred t) = getDepsTerm vs bs cd t
    getDepsTerm _ _ _ (Constant _) = M.empty
    getDepsTerm _ _ _ (TType _) = M.empty
    getDepsTerm _ _ _ (UType _) = M.empty
    getDepsTerm _ _ _ Erased = M.empty
    getDepsTerm _ _ _ Impossible = M.empty

    getDepsTerm _ _ _ t = error $ "cannot get deps of: " ++ show t

    -- Number of leading argument positions of a name whose use is tracked.
    getArity :: Name -> Int
    -- A field of an implementation constructor: the number of arguments of
    -- that field's type.
    getArity (SN (WhereN _ ctorName (MN i _)))
        | Just (TyDecl (DCon _ _ _) ty) <- lookupDefExact ctorName ctx
        = let argTys = map snd $ getArgTys ty
          in if i >= 0 && i < length argTys   -- valid indices are 0 .. length - 1
                then length $ getArgTys (argTys !! i)
                else error $ "invalid field number " ++ show i ++ " for " ++ show ctorName
        | otherwise = error $ "unknown implementation constructor: " ++ show ctorName
    getArity n = case lookupDefExact n ctx of
        Just (CaseOp _ _ tys _ _ _) -> length tys
        Just (TyDecl (DCon _ arity _) _) -> arity
        Just (TyDecl Ref ty) -> length $ getArgTys ty
        Just (Operator _ arity _) -> arity
        Just df -> error $ "Erasure/getArity: unrecognised entity '"
                        ++ show n ++ "' with definition: " ++ show df
        Nothing -> error $ "Erasure/getArity: definition not found for " ++ show n
-- | Rewrite an application whose head is a lambda into nested lets that bind
-- the arguments to the lambda's parameters, in order (via lamToLet').
lamToLet :: Term -> Term
lamToLet tm =
    let (headTerm, actuals) = unApply tm
    in lamToLet' actuals headTerm
-- | Bind each argument to the next lambda parameter with a let.
-- Errors when arguments remain but the term is no longer a lambda.
--
-- NOTE(review): every later argument ends up under the lets introduced for
-- the earlier ones; if arguments may contain de Bruijn 'V' references,
-- confirm that their indices are still meant to resolve correctly.
lamToLet' :: [Term] -> Term -> Term
lamToLet' [] tm = tm
lamToLet' (arg : rest) (Bind n (Lam rig ty) body) =
    Bind n (Let rig ty arg) (lamToLet' rest body)
lamToLet' vs tm = error $
    "Erasure.hs:lamToLet': unexpected input: "
    ++ "vs = " ++ show vs ++ ", tm = " ++ show tm
-- | Split @\\x_1 ... x_n -> T@ into the parameter names @[x_1 .. x_n]@ and
-- the body @T@.
unfoldLams :: Term -> ([Name], Term)
unfoldLams (Bind n (Lam _ _) body) = first (n :) (unfoldLams body)
unfoldLams t = ([], t)
-- | Merge two dependency maps; for shared conditions and nodes the reason
-- sets are unioned.
union :: Deps -> Deps -> Deps
union lhs rhs = M.unionWith (M.unionWith S.union) lhs rhs
-- | Map every element to a dependency map and merge all the results,
-- unioning reason sets of shared nodes.
unionMap :: (a -> Deps) -> [a] -> Deps
unionMap f xs = M.unionsWith (M.unionWith S.union) [f x | x <- xs]
-- | Make a field name out of a data constructor name and a field number.
mkFieldName :: Name -> Int -> Name
mkFieldName ctorName fieldNo = SN whereName
  where
    whereName = WhereN fieldNo ctorName fieldTag
    fieldTag = sMN fieldNo "field"