{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE OverloadedStrings #-}

{- Methods related to checking caching in the graph.
-}
module Spark.Core.Internal.Caching(
  NodeCachingType(..),
  CachingFailure(..),
  CacheTry,
  CacheGraph,
  AutocacheGen(..),
  checkCaching,
  fillAutoCache
) where

import Control.Monad.Identity
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.Vector as V
import Control.Arrow((&&&))
import Data.Foldable
import Data.Set(Set)
import Data.Maybe(mapMaybe)
import Debug.Trace(trace)
import Data.Text(Text)
import Formatting

-- import Debug.Trace

import Spark.Core.Internal.DAGFunctions
import Spark.Core.Internal.DAGStructures
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.Utilities
-- import Spark.Core.StructuresInternal(NodeId)

-- The caching role a vertex plays in the graph.
data NodeCachingType =
    -- Hinted caching. Will be fulfilled by the algorithm below.
    -- The node id is that of the node being cached
    AutocacheOp VertexId
    -- Unconditional caching
    -- The node id is that of the node being cached
  | CacheOp VertexId
    -- First one is the node id
    -- the second is the id of the matching cache or autocache node
  | UncacheOp VertexId VertexId
  | Through
  | Stop
  deriving (Show, Eq)

-- A caching error reported by checkCaching. It is built as
-- `CachingFailure cacheId uncacheId nodeId`, where `cacheId` is a cached node
-- that has a matching uncache node `uncacheId`, and `nodeId` is a vertex at
-- which the cache id is still present in the expansion set without being
-- present in the removal set (see _checkCaching).
data CachingFailure = CachingFailure {
  cachingNode :: !VertexId,
  uncachingNode :: !VertexId,
  escapingNode :: !VertexId
  } deriving (Show, Eq)

-- Result type of the caching checks: either an error message or a value.
type CacheTry t = Either Text t

-- A graph in which each vertex payload is paired with its caching type.
type CacheGraph v = Graph (v, NodeCachingType) StructureEdge

-- The vertex generators required by the autocache transform.
data AutocacheGen v = AutocacheGen {
  -- Generates an uncaching node to insert at the final location of uncaching
  deriveUncache :: Vertex v -> Vertex v,
  -- A function that given a node, generates an identity node (and a new node
  -- id) that can be inserted in place. The generated node id will be used with
  -- the identity node newly generated; the previous node will be moved around
  -- along with its identity.
  deriveIdentity :: Vertex v -> Vertex v
  }

-- Annotates each vertex of the graph with its caching type (using the
-- function provided), and then returns the list of caching failures found in
-- the annotated graph. Fails if the annotation function fails on a vertex.
checkCaching :: (Show v) =>
  Graph v StructureEdge ->
  (v -> CacheTry NodeCachingType) ->
  CacheTry [CachingFailure]
checkCaching g fun = _cacheGraph g fun >>= _checkCaching

-- Annotates the graph with caching information, runs the autocache
-- transform (_fillAutoCache) and strips the annotations from the result.
fillAutoCache :: (Show v) =>
  (v -> CacheTry NodeCachingType) ->
  AutocacheGen v ->
  Graph v StructureEdge ->
  -- The final graph being constructed.
  DagTry (Graph v StructureEdge)
fillAutoCache cacheFun acGen g = do
  -- Pair each vertex payload with its caching type.
  cg <- graphMapVertices g $ \vx _ -> (const vx &&& id) <$> cacheFun vx
  acg <- _fillAutoCache acGen cg
  -- Only the original payload is returned: the caching annotations are
  -- internal to this module.
  let acg' = graphMapVertices' fst acg
  return acg'

-- Some internal types to guarantee more correctness

-- A vertex annotated with AutocacheOp.
newtype AutocacheVertex v = AutocacheVertex (Vertex v)

-- A vertex annotated with Stop.
newtype StopVertex v = StopVertex { unStopVertex :: Vertex v }

-- The identity vertex associated to a stop vertex.
newtype IdentityVertex v = IdentityVertex (Vertex v)

-- An uncaching vertex, along with the id of the cache vertex it releases.
data UncacheVertex v = UncacheVertex (Vertex v) VertexId deriving (Show)

-- The id of a node that performs caching (either CacheOp or AutocacheOp).
newtype AnyCacheOp = AnyCacheOp { unAnyCacheOp :: VertexId } deriving (Show, Ord, Eq)

-- The result of creating a vertex:
--  - Left: the uncache vertex was already in the graph
--  - Right: a new uncache vertex, and the edge that links it to its cache
--    vertex
type CreateUncache v = Either (UncacheVertex v) (UncacheVertex v, Edge StructureEdge)

-- This performs a graph transform:
-- For each autocache node, it finds the transitive closure of Stop nodes.
-- Then it replaces the sink nodes by a layer of identity nodes and sink nodes
-- and intercalates the uncache node between the two layers, through some
-- logical dependencies.
--
-- If it does not find the closure, it leaves these
-- autocache nodes alone and does not attempt to remove them: they will be
-- considered as being unconditional caching without checks.
--
-- Note that it works on the reverse of the graph (flow instead of dependencies)
_fillAutoCache :: forall v. (Show v) =>
  AutocacheGen v ->
  -- The graph, already annotated with caching information
  Graph (v, NodeCachingType) StructureEdge ->
  -- The final graph being constructed.
  DagTry (Graph (v, NodeCachingType) StructureEdge)
_fillAutoCache acGen cg =
  -- Find the auto nodes.
  -- Compute the closure for each of them
  -- Perform the insertion.
  -- TODO: this function is too big, split or build subfunctions.
  let
    -- vertex id -> annotated vertex, for all the vertices of the input graph
    vxMap = M.fromList ((vertexId &&& id) <$> toList (gVertices cg))
    -- Reuses the identity vertex if a vertex with the derived id is already
    -- in the graph, and otherwise uses the freshly derived one.
    -- TODO: mark if the result was already in the graph
    findOrCreateIdentity :: StopVertex v -> IdentityVertex v
    findOrCreateIdentity (StopVertex vx) =
      let uvx = deriveIdentity acGen vx
      in case M.lookup (vertexId uvx) vxMap of
        Just vx' -> IdentityVertex $ fst <$> vx' -- Already created
        Nothing -> IdentityVertex uvx
    acNodesAndScopes = _autoCachingCandidates cg
    -- Add the uncaching nodes
    -- The autocache nodes for which the uncache vertex already exists are
    -- dropped (with a debugging trace).
    findOrCreateUncache' (acv, l) = case _findOrCreateUncache vxMap acGen acv of
      Left x -> trace ("findOrCreateUncache: dropping autocache node " ++ show x) Nothing
      Right (ucv, ed) -> Just (ed, (acv, ucv, l))
    acWithUncache' = mapMaybe findOrCreateUncache' acNodesAndScopes
    acWithUncache = snd <$> acWithUncache'
    -- The edges uncache vertex -> cache vertex
    acEdges = fst <$> acWithUncache'
    -- Now group by stop vertex, so that each stop vertex has a list of
    -- associated cache and uncache nodes.
    -- Not sure if they may be several, but it just sounds like good practice.
    tups = myGroupBy [(vertexId (unStopVertex svx), (cvx, uvx, svx)) | (cvx, uvx, l) <- acWithUncache, svx <- l]
    -- Just in this case, it should work because of the construction above
    -- TODO: put a lot more documentation here, it is tricky code
    -- All the elements of a group share the same stop vertex id (the grouping
    -- key), so the stop vertex of the first element is used for the group.
    group ((_, uvx, svx) : t) =
      (svx, findOrCreateIdentity svx, uvx : [uvx' | ( _, uvx', _) <- t])
    group [] = failure "_fillAutoCache:group: empty: should not happen"
    stopsWithCachingSteps :: [(StopVertex v, IdentityVertex v, [UncacheVertex v])]
    stopsWithCachingSteps = (group . snd) <$> M.toList tups
    -- One entry for each pair (stop vertex, uncache vertex)
    tups2 = [(svx, ivx, uvx) | (svx, ivx, l) <- stopsWithCachingSteps, uvx <- l]
    -- NOTE(review): when a stop vertex has several uncache vertices, the
    -- transform is applied once per uncache vertex on the same stop vertex,
    -- and each application also rewrites the edges added by the previous
    -- ones (see the FIXME on _performEdgeTransform) -- confirm whether this
    -- case can happen.
    folder eds (svx, ivx, uvx) = _performEdgeTransform svx ivx uvx eds
    startEdges = veEdge <$> [ve | (_, v) <- M.toList (gEdges cg), ve <- V.toList v]
    edges = acEdges ++ foldl' folder startEdges tups2
    -- Gather all the vertices and edges, and remove duplicates
    startVertices = V.toList (gVertices cg)
    ucVertices = acWithUncache <&> \(_, UncacheVertex vx cacheVid, _) ->
      -- TODO: propagate the cache vertexId with UncacheVertex
      let op = UncacheOp (vertexId vx) cacheVid
      in (id &&& const op) <$> vx
    idVertices = tups2 <&> \(_, IdentityVertex vx, _) ->
      (id &&& const Stop) <$> vx
    allVertices = startVertices ++ ucVertices ++ idVertices
    -- Make a new graph
  in buildGraphFromList allVertices edges

-- Derives the uncache vertex of an autocache vertex and looks it up (by
-- vertex id) in the vertices already present in the graph:
--  - if it is present as an uncache vertex, returns it as Left
--  - if it is present with another caching type, aborts with `failure`
--  - otherwise, returns the new uncache vertex, along with a parent edge from
--    the uncache vertex to the autocache vertex.
--
-- The vertex id carried by the UncacheVertex is the id of the matching cache
-- vertex (this is how _fillAutoCache reads it when it builds the UncacheOp).
-- TODO: should be a try to perform extra check operations
_findOrCreateUncache :: (HasCallStack, Show v) =>
  M.Map VertexId (Vertex (v, NodeCachingType)) ->
  AutocacheGen v ->
  AutocacheVertex v ->
  CreateUncache v
_findOrCreateUncache vxMap acGen (AutocacheVertex acv) =
  let uvx = deriveUncache acGen acv
      acVid = vertexId acv
      uVid = vertexId uvx
      look = vertexData <$> M.lookup uVid vxMap
  in case look of
    Just (x, UncacheOp _ cacheVid) ->
      -- That vertex already exists, we will not try to create
      -- an uncaching node then.
      -- It is returned with the cache id that is already recorded for it.
      Left $ UncacheVertex (Vertex uVid x) cacheVid
    Just _ ->
      -- That vertex already exists, but it is not the proper type.
      -- This is a programming error in AutocacheGen: we abort here.
      failure $ sformat ("_findOrCreateUncache:"%sh%"->"%sh) acv look
    Nothing ->
      -- The uncache node does not exist, we are going to create one.
      -- It refers to the autocache vertex it was derived from.
      let ed' = Edge uVid acVid ParentEdge
      in Right (UncacheVertex uvx acVid, ed')

-- Inserts an identity vertex and an uncache vertex around a stop vertex, in a
-- list of edges:
--  - all the edges that point to the stop vertex now point to the identity
--    vertex
--  - a parent edge identity -> stop is added
--  - two logical edges identity -> uncache and uncache -> stop are added
-- FIXME: duplicated work on the stop and identity: pass all the uncache vertexes to process them in one go
-- NOTE(review): the rewrite applies to every edge of the input list, including
-- the edges added by a previous call for the same stop vertex.
_performEdgeTransform ::
  StopVertex v ->
  IdentityVertex v ->
  UncacheVertex v ->
  [Edge StructureEdge] -> [Edge StructureEdge]
_performEdgeTransform (StopVertex svx) (IdentityVertex ivx) (UncacheVertex uvx _) eds =
  let stopVid = vertexId svx
      idenVid = vertexId ivx
      ucVid = vertexId uvx
      -- Rewrite the edges incoming to the stop node so that they point to the
      -- id node instead.
      f ed | edgeTo ed == stopVid = ed { edgeTo = idenVid }
      f ed = ed
      joinEd = Edge { edgeFrom = idenVid, edgeTo = stopVid, edgeData = ParentEdge }
      id1Ed = Edge { edgeFrom = idenVid, edgeTo = ucVid, edgeData = LogicalEdge }
      id2Ed = Edge { edgeFrom = ucVid, edgeTo = stopVid, edgeData = LogicalEdge }
  in id1Ed : id2Ed : joinEd : (f <$> eds)

-- The list of nodes that do autocaching, and the fringes for each of these
-- nodes.
-- Returns a list of caching node -> [stop node]
-- The autocache nodes with an empty list of stop nodes are not returned.
_autoCachingCandidates :: forall v. (Show v) =>
  Graph (v, NodeCachingType) StructureEdge ->
  [(AutocacheVertex v, [StopVertex v])]
_autoCachingCandidates cg =
  let
    -- The same graph, with the caching types only
    cg' = graphMapVertices' snd cg
    exps = gVertices $ _expansions cg'
    extractAutocache vx = case snd (vertexData vx) of
      AutocacheOp _ -> [AutocacheVertex (fst <$> vx)]
      _ -> []
    acVxs = concatMap extractAutocache (gVertices cg)
    -- All the stop nodes for each caching vertex id
    -- (one pair for each caching operation in the expansion set of a stop
    -- vertex)
    extractFringe vx = case vertexData vx of
      (Stop, set) -> (id &&& const (vertexId vx)) <$> toList set
      _ -> []
    -- cache vid -> Stop vertex id
    acWithFringe = myGroupBy $ concatMap extractFringe (toList exps)
    vmap = vertexMap cg
    vmap' = vertexMap cg'
    -- TODO: should be a try and it should not fail
    -- The ids that cannot be found are silently dropped by the caller.
    findStop :: VertexId -> Maybe (StopVertex v)
    findStop vid = do
      vx <- M.lookup vid vmap
      _ <- M.lookup vid vmap'
      return $ StopVertex (Vertex vid (fst vx))
    -- TODO: it should be a try, although it is a programming error here
    combineWithFringe :: AutocacheVertex v -> (AutocacheVertex v, [StopVertex v])
    combineWithFringe acv@(AutocacheVertex vx) =
      let vids = M.findWithDefault [] (AnyCacheOp (vertexId vx)) acWithFringe
      in (acv, mapMaybe findStop vids)
    -- Remove the nodes that do not have a fringe.
    -- In this case, they are passed through without uncaching operation.
    acWithFringeVx = filter (not . null . snd) $ combineWithFringe <$> acVxs
  in acWithFringeVx

-- Checks the caching operations of a graph annotated with caching types.
-- Only the cache nodes that have a matching uncache node are checked.
-- For each vertex, a failure is reported for every such cache node that is in
-- the expansion set of the vertex (see _expansions) but not in its removal set
-- (see _removals).
-- This function currently always returns a Right value.
_checkCaching :: Graph NodeCachingType StructureEdge -> CacheTry [CachingFailure]
_checkCaching cg =
  let -- vertex id -> set of caching operations
      expands = snd <$> vertexMap (_expansions cg)
      -- vertex id -> set of ids of uncached nodes
      removals = vertexMap $ _removals cg
      f :: NodeCachingType -> [(VertexId, VertexId)]
      f (UncacheOp uncacheNid cacheNid) = [(cacheNid, uncacheNid)]
      f _ = []
      -- cacheNID -> uncacheNID
      removedNodes = M.fromList $ concatMap f (vertexData <$> gVertices cg)
      removedNodeSet = S.fromList $ M.keys removedNodes
      checkErrors :: VertexId -> [CachingFailure]
      checkErrors nid =
        let rems = S.intersection removedNodeSet (M.findWithDefault S.empty nid removals)
            exps = S.intersection removedNodeSet (unAnyCacheOp `S.map` M.findWithDefault S.empty nid expands)
            -- In the expansion set but not in the removal set
            bad = S.toList $ S.difference exps rems
            badWithUncache = flip mapMaybe bad $ \ cacheNid ->
              M.lookup cacheNid removedNodes <&> \uncacheNid ->
                CachingFailure cacheNid uncacheNid nid
        in badWithUncache
  in return $ concatMap checkErrors (vertexId <$> gVertices cg)

-- Replaces the payload of each vertex by its caching type, using the function
-- provided. The second argument that graphMapVertices passes to the callback
-- is ignored.
_cacheGraph :: (Show v) =>
  Graph v StructureEdge ->
  (v -> CacheTry NodeCachingType) ->
  CacheTry (Graph NodeCachingType StructureEdge)
_cacheGraph g f = graphMapVertices g f' where
  f' vx _ = f vx

-- The set of node caching operations at each step.
-- This includes both regular cache and autocache.
-- Each vertex is paired with its own caching type and with the set of caching
-- operations accumulated from the neighbours that graphMapVertices passes to
-- the callback, plus the caching operation of the vertex itself (if any).
_expansions :: (Show e) =>
  Graph NodeCachingType e -> Graph (NodeCachingType, Set AnyCacheOp) e
_expansions g = runIdentity (graphMapVertices g annotate)
  where
    annotate op neighbours = return (op, S.union (ownOps op) inherited)
      where
        inherited :: S.Set AnyCacheOp
        inherited = S.unions (visible <$> neighbours)
    -- The caching operations that an already processed vertex passes on.
    visible ((neighbourOp, ops), _) = case neighbourOp of
      -- Nothing goes through a stop vertex
      Stop -> S.empty
      -- Uncaching drops the caching node from the expansions
      UncacheOp _ cacheVid -> S.delete (AnyCacheOp cacheVid) ops
      _ -> ops
    -- The caching operation introduced by the vertex itself.
    ownOps op = case op of
      CacheOp nid -> S.singleton (AnyCacheOp nid)
      AutocacheOp nid -> S.singleton (AnyCacheOp nid)
      _ -> S.empty

-- The set of the ids of the cached nodes that get uncached, at each step.
-- This is computed on the reversed graph (reverseGraph): each vertex gets the
-- union of the sets of the neighbours passed by graphMapVertices, plus the
-- cache id it releases itself when it is an uncache vertex.
_removals :: (Show e) => Graph NodeCachingType e -> Graph (Set VertexId) e
_removals g = runIdentity (graphMapVertices (reverseGraph g) collect)
  where
    collect op neighbours =
      return (S.union (released op) (S.unions (fst <$> neighbours)))
    -- The id of the cache node released by this vertex, if any.
    released (UncacheOp _ cacheNid) = S.singleton cacheNid
    released _ = S.empty