{-# LANGUAGE FlexibleContexts     #-}

-- | This module exports a single function that computes the dependency
-- information needed to eliminate non-cut KVars, and then transitively
-- collapse the resulting constraint dependencies.
-- See the type of `SolverInfo` for details.

module Language.Fixpoint.Solver.Eliminate ( solverInfo ) where

import qualified Data.HashSet        as S
import qualified Data.HashMap.Strict as M

import           Language.Fixpoint.Types.Config    (Config)
import qualified Language.Fixpoint.Types.Solutions as Sol
import           Language.Fixpoint.Types
import           Language.Fixpoint.Types.Visitor   (kvarsExpr, isConcC)
import           Language.Fixpoint.Graph
import           Language.Fixpoint.Misc            (safeLookup, group, errorstar)
import           Language.Fixpoint.Solver.Sanitize

--------------------------------------------------------------------------------
-- | `solverInfo` constructs a `SolverInfo` comprising the Solution and various
--   indices needed by the worklist-based refinement loop
--------------------------------------------------------------------------------
{-# SCC solverInfo #-}
solverInfo :: Config -> SInfo a -> SolverInfo a b
--------------------------------------------------------------------------------
solverInfo cfg sI = SI solution cutInfo deps cutKs
  where
    -- dependency edges together with the cut / non-cut kvars (from `kutVars`)
    (edges, cutKs, nonCutKs) = kutVars cfg sI
    -- constraint ids indexed by the kvars occurring on their right-hand sides
    rhsIdx    = kIndex sI
    deps      = elimDeps sI edges nonCutKs ebindSyms
    -- the input restricted to the cut kvars
    cutInfo   = cutSInfo sI rhsIdx cutKs
    -- initial solution: two empty kvar association lists, the hypotheses of
    -- the non-cut kvars, the kvar scopes, no `EbindSol` entries, and the
    -- id/sort of every binder
    solution  = Sol.fromList (symbolEnv cfg sI) mempty mempty
                  (nonCutHyps sI rhsIdx nonCutKs)
                  (kvScopes sI edges)
                  []
                  bindSorts
    bindSorts = fromListSEnv
                  [ (x, (i, sr_sort sr)) | (i, x, sr) <- bindEnvToList (bs sI) ]
    -- symbols of the binders listed in `ebinds`
    ebindSyms = S.fromList [ fst (lookupBindEnv i (bs sI)) | i <- ebinds sI ]


--------------------------------------------------------------------------------
-- | For every kvar, the intersection of the environments (`senv`) of all
--   constraints that share an edge with that kvar in the given edge list,
--   in either direction.
kvScopes :: SInfo a -> [CEdge] -> M.HashMap KVar IBindEnv
kvScopes sI es = fmap commonEnv kvToCstrs
  where
    -- NOTE(review): `foldr1` relies on `group` never producing an empty
    -- list for a key — confirm against Language.Fixpoint.Misc.group
    commonEnv ids = foldr1 intersectionIBindEnv [ senv (getSubC sI i) | i <- ids ]
    kvToCstrs     = group (cstrToKv ++ kvToCstr)
    -- edges oriented constraint -> kvar
    cstrToKv      = [ (k, i) | (Cstr i, KVar k) <- es ]
    -- edges oriented kvar -> constraint
    kvToCstr      = [ (k, i) | (KVar k, Cstr i) <- es ]

--------------------------------------------------------------------------------

-- | Restrict an `SInfo` to the cut kvars @cKs@. It keeps the `ws` entries of
--   cut kvars. It keeps a constraint from `cm` when its id is listed in @kI@
--   under some cut kvar, or when it satisfies `isConcC`.
cutSInfo :: SInfo a -> KIndex -> S.HashSet KVar -> SInfo a
cutSInfo si kI cKs = si { ws = keptWfs, cm = keptCstrs }
  where
    isCut k _  = k `S.member` cKs
    keptWfs    = M.filterWithKey isCut (ws si)
    keep i c   = i `S.member` cutCstrIds || isConcC c
    keptCstrs  = M.filterWithKey keep (cm si)
    -- ids of all constraints indexed under some cut kvar
    cutCstrIds = S.fromList [ i | k <- S.toList cKs, i <- M.lookupDefault [] k kI ]

-- | Run `elimVars` and split its result into the dependency edges, the cut
--   kvars (`depCuts`) and the non-cut kvars (`depNonCuts`).
kutVars :: Config -> SInfo a -> ([CEdge], S.HashSet KVar, S.HashSet KVar)
kutVars cfg si =
  -- a lazy `let` pattern, so `elimVars` is only run once a component is demanded
  let (edges, elims) = elimVars cfg si
  in  (edges, depCuts elims, depNonCuts elims)

--------------------------------------------------------------------------------
-- | Maps each `KVar` to the ids (keys of the `cm` map) of the constraints in
--   whose right-hand side it appears. Built by `kIndex`.
--------------------------------------------------------------------------------
type KIndex = M.HashMap KVar [Integer]

--------------------------------------------------------------------------------
-- | Build the `KIndex` of an `SInfo`: every constraint in `cm` is recorded
--   under each kvar that `kvarsExpr` finds in its right-hand side (`crhs`).
kIndex :: SInfo a -> KIndex
kIndex si = group (concatMap rhsPairs (M.toList (cm si)))
  where
    rhsPairs (i, c) = [ (k, i) | k <- kvarsExpr (crhs c) ]

-- | Pair every kvar in @nKs@ with the hypothesis `nonCutHyp` builds for it.
nonCutHyps :: SInfo a -> KIndex -> S.HashSet KVar -> [(KVar, Sol.Hyp)]
nonCutHyps si kI nKs = map withHyp (S.toList nKs)
  where
    withHyp k = (k, nonCutHyp kI si k)


-- | The hypothesis for kvar @k@: one cube per constraint id that @kI@ lists
--   for @k@. If @k@ is not in @kI@, the hypothesis is empty.
nonCutHyp :: KIndex -> SInfo a -> KVar -> Sol.Hyp
nonCutHyp kI si k =
  [ nonCutCube (getSubC si i) | i <- M.lookupDefault [] k kI ]

-- | Package a constraint as a `Sol.Cube` from four parts: its environment
--   (`senv`), the substitution on its RHS kvar (`rhsSubst`), its id
--   (`subcId`), and its tag (`stag`).
nonCutCube :: SimpC a -> Sol.Cube
nonCutCube c = Sol.Cube env su cid tag
  where
    env = senv c
    su  = rhsSubst c
    cid = subcId c
    tag = stag c

-- | The substitution of the `PKVar` on a constraint's right-hand side.
--   Fails via `errorstar` if the RHS is not a `PKVar`.
rhsSubst :: SimpC a -> Subst
rhsSubst c = case crhs c of
  PKVar _ su -> su
  _          -> errorstar "Eliminate.rhsSubst called on bad input"

-- | Look up constraint @i@ in the `cm` map of @si@ via `safeLookup`. The
--   message passed along ("getSubC: " followed by the id) is presumably
--   what `safeLookup` reports when the id is missing.
getSubC :: SInfo a -> Integer -> SimpC a
getSubC si i = safeLookup errMsg i constraints
  where
    constraints = cm si
    errMsg      = "getSubC: " ++ show i