{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE FlexibleContexts #-}

{-| A set of utility functions to build and transform DAGs.

These functions are implemented here because no public library providing
such transforms could be found.

Most krapsh manipulations are converted into graph manipulations.
-}
module Spark.Core.Internal.DAGFunctions(
  DagTry,
  -- Building
  buildGraph,
  buildVertexList,
  buildGraphFromList,
  -- Queries
  graphSinks,
  graphSources,
  -- Transforms
  graphMapVertices,
  graphMapVertices',
  vertexMap,
  graphFlatMapEdges,
  graphMapEdges,
  reverseGraph,
  verticesAndEdges,
) where

import qualified Data.Set as S
import qualified Data.Map.Strict as M
import qualified Data.Vector as V
import Data.List(sortBy)
import Data.Maybe
import Data.Foldable(toList)
import Data.Text(Text)
import Control.Arrow((&&&))
import Control.Monad.Except
import Formatting
import Control.Monad.Identity

import Spark.Core.Internal.DAGStructures
import Spark.Core.Internal.Utilities

-- | The result of a DAG operation that may fail: either an error message
-- (a 'Text' in 'Left') or the computed value (in 'Right').
--
-- This is a separate type of error to make it more general and easier
-- to test.
type DagTry a = Either Text a

{-| Expands a start vertex into the list of all the vertices of its transitive
closure.

The result is in lexicographic order of dependencies: the graph corresponding
to this list has one sink (the start vertex) and potentially multiple sources,
and all the parents of a vertex come before the vertex itself.

This is the unbounded traversal: it delegates to the bounded variant with an
empty boundary.
-}
buildVertexList :: (GraphVertexOperations v, Show v) => v -> DagTry [v]
buildVertexList start = buildVertexListBounded start noBoundary
  where
    -- No vertex is excluded from the traversal.
    noBoundary = []

{-| Builds the list of vertices reachable from a start vertex, up to a boundary.

The vertices whose id belongs to the boundary are not traversed. The vertices
that were collected are then put in lexicographic order.
-}
buildVertexListBounded :: (GraphVertexOperations v, Show v) =>
  v -> [v] -> DagTry [v]
buildVertexListBounded x boundary = _lexicographic vertexToId visited
  where
    -- The ids that stop the traversal.
    stopIds = S.fromList (map vertexToId boundary)
    -- Everything collected by the traversal, as (index, vertex, next vertices).
    visited = M.elems (_buildList stopIds [x] M.empty)

-- | Builds a graph by expanding a start vertex.
--
-- The vertices are obtained (already ordered) from 'buildVertexList'. The
-- adjacency map associates the id of each vertex to the edges given by
-- 'expandVertex' for this vertex.
buildGraph :: forall v e. (GraphOperations v e, Show v, Show e) =>
  v -> DagTry (Graph v e)
buildGraph start = assemble <$> buildVertexList start
  where
    -- The edges and vertices are already in the right order, no need
    -- to do further checks.
    assemble :: [v] -> Graph v e
    assemble vxData =
      let vxs = traceHint "buildGraph: vertices=" $
            V.fromList [Vertex (vertexToId vx) vx | vx <- vxData]
          edges = traceHint "buildGraph: edges=" $ outgoing <$> vxData
      in Graph (M.fromList edges) vxs
    -- The id of a vertex, along with all the edges that start from it.
    outgoing :: v -> (VertexId, V.Vector (VertexEdge e v))
    outgoing x = (vid, V.fromList (toVertexEdge <$> expandVertex x))
      where
        vid = vertexToId x
        toVertexEdge :: (e, v) -> VertexEdge e v
        toVertexEdge (ed, x') =
          let toId = vertexToId x'
          in VertexEdge (Vertex toId x') (Edge vid toId ed)

{-| Attempts to build a graph from a collection of vertices and edges.

This collection may be invalid (cycles, etc.) and the vertices need not
be in topological order.

All the vertices referred to by the edges must be present in the list of
vertices, and the ids of the vertices must be unique; otherwise an error is
returned.
-}
buildGraphFromList :: forall v e. (Show v, Show e) =>
  [Vertex v] -> [Edge e] -> DagTry (Graph v e)
buildGraphFromList vxs eds = do
  -- Index the vertices by id (this fails on duplicate ids).
  vxById <- _vertexById vxs
  let lookupVertex :: VertexId -> DagTry (Vertex v)
      lookupVertex vid =
        maybe (throwError missing) pure (M.lookup vid vxById)
        where
          missing = sformat ("buildGraphFromList: vertex id found in edge but not in vertices: "%sh) vid
      -- The topological information: start id -> end ids.
      endsByStart = myGroupBy [(edgeFrom ed, edgeTo ed) | ed <- eds]
      -- A vertex along with the end vertices of the edges that start from it.
      withEnds :: Vertex v -> DagTry (Vertex v, [Vertex v])
      withEnds vx = do
        ends <- mapM lookupVertex (M.findWithDefault [] (vertexId vx) endsByStart)
        return (vx, ends)
  verticesWithEnds <- mapM withEnds vxs
  -- The position in the input list is used to break the ties.
  let indexedVertices =
        [(idx, vx, ends) | (idx, (vx, ends)) <- zip [1..] verticesWithEnds]
  -- The nodes in lexicographic order.
  lexico <- _lexicographic vertexId indexedVertices
  -- Build the edge map: id of the start vertex -> vertex edge.
  let toVertexEdge :: Edge e -> DagTry (VertexId, VertexEdge e v)
      toVertexEdge ed = do
        endVx <- lookupVertex (edgeTo ed)
        -- Used to confirm that the start vertex is here
        _ <- lookupVertex (edgeFrom ed)
        return (edgeFrom ed, VertexEdge endVx ed)
  vEdges <- mapM toVertexEdge eds
  return $ Graph (M.map V.fromList (myGroupBy vEdges)) (V.fromList lexico)

-- (internal)
-- Indexes the vertices by their id.
--
-- Returns an error if several vertices share the same id.
_vertexById :: (Show v) => [Vertex v] -> DagTry (M.Map VertexId (Vertex v))
_vertexById vxs = M.fromList <$> mapM unique (M.toList grouped)
  where
    -- All the vertices that claim a given id.
    grouped = myGroupBy [(vertexId vx, vx) | vx <- vxs]
    -- An id is valid when exactly one vertex carries it.
    unique (vid, candidates) = case candidates of
      [vx] -> pure (vid, vx)
      _ -> throwError $ sformat ("_VertexById: Multiple vertices with the same id: "%sh%" in "%sh) vid candidates

-- (internal)
-- Orders the vertices so that each vertex comes after all the vertices it
-- depends on.
--
-- Each element of the input is a triple (traversal index, vertex,
-- dependencies of the vertex):
--  - Int is the traversal order. It is just used to break the ties: among
--    the vertices that have no pending dependency, the lowest index comes
--    first.
--  - the function argument extracts the VertexId (the node id) of a vertex.
--
-- The dependencies that do not refer to a vertex of the input (for example
-- vertices that were not traversed) are ignored: they can never be emitted
-- by this function, so waiting for them would look like a cycle.
--
-- Returns an error if the dependencies contain a cycle.
--
-- This implementation is not very efficient (the remaining vertices are
-- sorted again at every step) and is probably a performance bottleneck.
_lexicographic :: forall v. (v -> VertexId) -> [(Int, v, [v])] -> DagTry [v]
_lexicographic f m0 =
  go [(idx, v, filter isKnown deps) | (idx, v, deps) <- m0]
  where
    -- The ids of all the vertices that will be emitted.
    knownIds = S.fromList [f v | (_, v, _) <- m0]
    isKnown d = S.member (f d) knownIds
    -- We use the traversal ordering to separate the ties.
    -- The vertices without pending dependencies come first, and among them
    -- the first nodes traversed get priority.
    fcmp (idx, _, []) (idx', _, []) = compare idx idx'
    fcmp (_, _, []) _ = LT
    fcmp _ (_, _, []) = GT
    fcmp _ _ = EQ -- This one does not matter
    go :: [(Int, v, [v])] -> DagTry [v]
    go [] = return []
    go m = case sortBy fcmp m of
      -- The head is ready: emit it and remove it from the pending
      -- dependencies of the remaining vertices.
      ((_, v, []) : t) ->
        let currentId = f v
            m' = [(idx, v', [d | d <- l, f d /= currentId]) | (idx, v', l) <- t]
        in (v :) <$> go m'
      -- Every remaining vertex still waits for another one.
      _ -> throwError "_lexicographic: there is a cycle"


-- (internal)
-- Gathers all the vertices reachable from the fringe, using
-- expandVertexAsVertices as the expansion function.
_buildList :: (Show v, GraphVertexOperations v) =>
  S.Set VertexId -> -- boundary ids, they will not be traversed
  [v] -> -- fringe ids
  M.Map VertexId (Int, v, [v]) -> -- all seen ids so far (the intermediate result)
  M.Map VertexId (Int, v, [v])
_buildList boundary fringe seen =
  _buildListGeneral boundary fringe expandVertexAsVertices seen

-- (internal)
-- Collects all the nodes that can be reached from the fringe.
--
-- The expansion function is provided by the caller.
--
-- The expansion is done in a DFS manner: the nodes obtained by expanding a
-- node are put in front of the fringe. Each node is recorded with the number
-- of nodes seen before it, which serves as its index in the traversal.
_buildListGeneral :: (Show v, GraphVertexOperations v) =>
  S.Set VertexId -> -- boundary ids, they will not be traversed
  [v] -> -- fringe ids: the nodes that have been touched but not expanded.
  (v -> [v]) -> -- The expansion function. They will be the next nodes to expand.
  -- all seen ids so far (the intermediate result)
  -- along with the index of the node during the traversal, and the
  -- node itself.
  M.Map VertexId (Int, v, [v]) ->
  M.Map VertexId (Int, v, [v])
_buildListGeneral _ [] _ allSeen = allSeen
_buildListGeneral boundaryIds (x : t) expand allSeen
  -- Already recorded, or behind the boundary: nothing to expand.
  | skipped = _buildListGeneral boundaryIds t expand allSeen
  | otherwise =
      _buildListGeneral boundaryIds (unseenChildren ++ t) expand seenWithX
  where
    vid = vertexToId x
    skipped = M.member vid allSeen || S.member vid boundaryIds
    children = expand x
    seenWithX = M.insert vid (M.size allSeen, x, children) allSeen
    -- Only the nodes that have not been recorded yet join the fringe.
    unseenChildren =
      [c | c <- children, M.notMember (vertexToId c) seenWithX]

{-| The sources of a DAG: the vertices that are not the end vertex of any
edge of the graph.

The vertices are returned in the order of the vertices of the graph.
-}
graphSources :: Graph v e -> [Vertex v]
graphSources g =
  [vx | vx <- toList (gVertices g), S.notMember (vertexId vx) endIds]
  where
    -- The ids of all the vertices that have an incoming edge.
    endIds = S.fromList
      [ vertexId (veEndVertex ve)
      | vedges <- M.elems (gEdges g)
      , ve <- toList vedges
      ]

{-| The sinks of a graph: the vertices that have no outgoing edge (either
no entry in the edge map, or an empty one).

The vertices are returned in the order of the vertices of the graph.
-}
graphSinks :: Graph v e -> [Vertex v]
graphSinks g = [vx | vx <- toList (gVertices g), hasNoEdge vx]
  where
    hasNoEdge vx = maybe True V.null (M.lookup (vertexId vx) (gEdges g))

-- | Flips the edges of this graph (it is also a DAG).
--
-- Each edge is replaced by an edge with the same data that goes in the
-- opposite direction, and the order of the vertices is reversed.
--
-- The edges that start from an id that is not in the vertices of the graph
-- are silently dropped.
reverseGraph :: forall v e. (HasCallStack, Show v, Show e) => Graph v e -> Graph v e
reverseGraph g = Graph (V.fromList <$> flipped) (V.reverse (gVertices g))
  where
    vxMap = M.fromList [(vertexId vx, vx) | vx <- toList (gVertices g)]
    -- The flipped edges, grouped by their new start id.
    flipped = myGroupBy (concatMap flipVEdge (M.toList (gEdges g)))
    flipVEdge :: (VertexId, V.Vector (VertexEdge e v)) -> [(VertexId, VertexEdge e v)]
    flipVEdge (fromNid, vec) =
      [ (oldEndVid, VertexEdge { veEdge = flippedEdge, veEndVertex = endVx })
      -- A missing start vertex should be a programming error: no edge is
      -- produced in that case.
      | endVx <- maybeToList (M.lookup fromNid vxMap)
      , ve <- toList vec
      , let oldEndVid = vertexId (veEndVertex ve)
            flippedEdge = Edge
              { edgeFrom = oldEndVid
              , edgeTo = fromNid
              , edgeData = edgeData (veEdge ve)
              }
      ]

-- | A generic transform over the graph that may account for potential failures
-- in the process.
--
-- The vertices are visited in the order of the vertices of the graph. For
-- each vertex, the transform receives the data of the vertex and, for each
-- edge that starts from this vertex, the already transformed data of the end
-- vertex of the edge along with the data of the edge.
--
-- The end vertex of every edge must therefore come before the start vertex in
-- the vertices of the graph. When a vertex cannot be located, 'failure' is
-- called with a message that names the missing vertex id.
--
-- The ids of the vertices, their order and the edges are left unchanged.
graphMapVertices :: forall m v e v2. (HasCallStack, Show v2, Show v, Show e, Monad m) =>
  Graph v e -> -- The start graph
  (v -> [(v2,e)] -> m v2) -> -- The transform
  m (Graph v2 e)
graphMapVertices g f =
  let
    -- Transforms the remaining vertices, given the transformed data of all
    -- the vertices that have been processed so far.
    fun :: M.Map VertexId v2 -> [Vertex v] -> m [Vertex v2]
    fun _ [] = return []
    fun done (vx : t) =
      let
        vid = vertexId vx
        parents = V.toList $ M.findWithDefault V.empty vid (gEdges g)
        getPairs :: Edge e -> (v2, e)
        getPairs ed =
          let vidTo = edgeTo ed
              msg = sformat ("graphMapVertices: Could not locate "%shown%" in "%shown)vidTo done
              -- The edges are flowing from child -> parent so
              -- to == parent
              vert = fromMaybe (failure msg) (M.lookup vidTo done)
          in (vert, edgeData ed)
        parents2 = getPairs . veEdge <$> parents
        -- Records the transformed vertex and moves on to the next ones.
        merge0 :: v2 -> m [Vertex v2]
        merge0 vx2Data =
          let done2 = M.insert vid vx2Data done
              vx2 = vx { vertexData = vx2Data }
          in (vx2 : ) <$> fun done2 t
      in
        f (vertexData vx) parents2 >>= merge0
  in do
    verts2 <- fun M.empty (toList (gVertices g))
    let
      idxs2 = M.fromList [(vertexId vx2, vx2) | vx2 <- verts2]
      -- The end vertices of the edges are replaced by their transformed
      -- counterpart. A missing vertex is reported with its id instead of
      -- crashing without any context.
      trans :: Vertex v -> Vertex v2
      trans vx =
        let vid = vertexId vx
            msg = sformat ("graphMapVertices: Could not locate transformed vertex "%shown%" in "%shown) vid (M.keys idxs2)
        in fromMaybe (failure msg) (M.lookup vid idxs2)
      conv :: VertexEdge e v -> VertexEdge e v2
      conv (VertexEdge vx1 e1) = VertexEdge (trans vx1) e1
      adj2 = M.map (conv <$>) (gEdges g)
    return Graph { gEdges = adj2, gVertices = V.fromList verts2 }

-- | (internal) Transforms the data of each edge, one to one.
--
-- This is the special case of 'graphFlatMapEdges' in which each edge
-- gives exactly one edge.
graphMapEdges :: Graph v e -> (e -> e') -> Graph v e'
graphMapEdges g f = graphFlatMapEdges g single
  where
    single ed = [f ed]

-- | (internal) Maps the edges, and may create more or fewer of them.
--
-- Each edge is replaced by one edge for each value returned by the function,
-- with the same start and end. The vertices are left unchanged.
graphFlatMapEdges :: Graph v e -> (e -> [e']) -> Graph v e'
graphFlatMapEdges g f = g { gEdges = expandAll <$> gEdges g }
  where
    -- All the edges obtained from the edges of one start vertex.
    expandAll vedges = V.fromList
      [ VertexEdge vx (ed { edgeData = ed' })
      | VertexEdge vx ed <- toList vedges
      , ed' <- f (edgeData ed)
      ]

-- | (internal) Transforms the data of each vertex with a pure function.
--
-- This is 'graphMapVertices' run in the 'Identity' monad, with a transform
-- that ignores the edges of the vertex.
graphMapVertices' :: (Show v, Show e, Show v') => (v -> v') -> Graph v e -> Graph v' e
graphMapVertices' f g = runIdentity $ graphMapVertices g ignoringEdges
  where
    ignoringEdges vxData _ = return (f vxData)

-- | The data of the vertices of the graph, indexed by vertex id.
vertexMap :: Graph v e -> M.Map VertexId v
vertexMap g =
  M.fromList [(vertexId vx, vertexData vx) | vx <- toList (gVertices g)]

-- (internal)
-- The data of the vertices, in the order of the vertices of the graph. Each
-- vertex comes with the edges that start from it, given as pairs
-- (data of the end vertex, data of the edge).
verticesAndEdges :: Graph v e -> [([(v, e)],v)]
verticesAndEdges g = describe <$> toList (gVertices g)
  where
    describe vx = (outgoing, vertexData vx)
      where
        vedges = V.toList $ M.findWithDefault V.empty (vertexId vx) (gEdges g)
        outgoing =
          [(vertexData endVx, edgeData ed) | VertexEdge endVx ed <- vedges]