{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE FlexibleContexts #-}

-- Functions to build the graph of computations.
-- The following steps are performed:
--  - type checking
--  - caching checks
--  - building the final json
--
-- All the functions in this module are pure and use SparkStatePure for transforms.

module Spark.Core.Internal.ContextInternal(
  FinalResult,
  prepareExecution1,
  buildComputationGraph,
  performGraphTransforms,
  getTargetNodes,
  storeResults,
) where

import Control.Monad.State(get, put)
import Control.Monad(forM)
import Data.Text(pack)
import Debug.Trace(trace)
import Data.Foldable(toList)
import Control.Arrow((&&&))
import Formatting
import qualified Data.Map.Strict as M
import qualified Data.Vector as V

import Spark.Core.Dataset
import Spark.Core.Try
import Spark.Core.Row
import Spark.Core.Types
import Spark.Core.Internal.Caching
import Spark.Core.Internal.CachingUntyped
import Spark.Core.Internal.ContextStructures
import Spark.Core.Internal.Client
import Spark.Core.Internal.ComputeDag
import Spark.Core.Internal.PathsUntyped
-- Required to import the instances.
import Spark.Core.Internal.Paths()
import Spark.Core.Internal.TypesStructures
import Spark.Core.Internal.TypesFunctions(arrayType)
import Spark.Core.Internal.DAGFunctions(buildVertexList)
import Spark.Core.Internal.DAGStructures
import Spark.Core.Internal.DatasetFunctions
import Spark.Core.Internal.DatasetStructures
import Spark.Core.Internal.Utilities

-- The result from querying the status of a computation:
-- a NodeComputationFailure on the Left, or a NodeComputationSuccess on the
-- Right.
type FinalResult = Either NodeComputationFailure NodeComputationSuccess


-- Entry point for a single output node: attempts to turn it into a valid
-- computation, using the session currently held in the state.
--
-- The graph construction, the graph transforms and the final packaging are
-- chained in the Try monad, so the first failing step short-circuits the
-- following ones. The command counter of the session is only increased when
-- a computation was successfully produced; on failure the state is untouched.
prepareExecution1 :: LocalData a -> SparkStatePure (Try Computation)
prepareExecution1 ld = do
  session <- get
  let comp = buildComputationGraph ld
               >>= performGraphTransforms
               >>= _buildComputation session
  case comp of
    Right _ -> _increaseCompCounter
    Left _ -> return ()
  return comp

-- Here are the steps being run
--  - node collection + cycle detection
--  - naming:
--    -> everything after that can be done with names, and on server
--    -> for convenience, the vertex ids will be still the hash ids
--  - verification of cache/uncache
--  - deconstruction of unions and aggregations
--  - caching swap
--
-- There is a lot more that could be done (merging the aggregations, etc.)
-- but it is outside the scope of this MVP.

{-| Builds the computation graph by expanding a single node until a transitive
closure is reached.

The node is first stripped of its type information, then the graph is built
from it and finally the paths are assigned to the nodes of the graph. Any
failure in one of these steps is reported through 'Try'.

It performs the naming, node deduplication and cycle detection.

TODO(kps) use the caching information to have a correct fringe
-}
buildComputationGraph :: ComputeNode loc a -> Try ComputeGraph
buildComputationGraph ld =
  tryEither (buildCGraph (untyped ld)) >>= assignPathsUntyped

{-| Performs all the operations that are done on the compute graph:

- fulfilling autocache requests
- checking the cache/uncache pairs
- deconstructions of the unions (in the future)

The compute graph is converted to a plain graph, the autocache pass is run on
it, and the caching checks are run on the outcome. If the checks report no
failure, the graph is converted back to a compute graph; otherwise an error
listing the failures is returned.

This could all be done on the server side at this point.
-}
performGraphTransforms :: ComputeGraph -> Try ComputeGraph
performGraphTransforms cg = do
  cachedGraph <- tryEither autocached
  cachingErrors <- tryEither (checkCaching cachedGraph cachingType)
  case cachingErrors of
    [] -> return (graphToComputeGraph cachedGraph)
    _ -> tryError $ sformat ("Found some caching errors: "%sh) cachingErrors
  where
    -- The intermediate values are traced for debugging purposes.
    rawGraph =
      traceHint "_performGraphTransforms g=" (computeGraphToGraph cg)
    autocached =
      traceHint "_performGraphTransforms: After autocaching:"
        (fillAutoCache cachingType autocacheGen rawGraph)

-- Packages a compute graph into a computation for the given session.
--
-- The computation id is the textual form of the current command counter of
-- the session. The graph must have exactly one output node, which is used
-- both as the list of terminal nodes and as the final element of the
-- computation; any other number of outputs is reported as an error.
_buildComputation :: SparkSession -> ComputeGraph -> Try Computation
_buildComputation session cg = case outputNames of
    [out] -> return $ Computation (ssId session) compId nodes [out] out
    _ -> tryError $ sformat ("Programming error in _build1: cg="%sh) cg
  where
    compId = ComputationID . pack . show $ ssCommandCounter session
    tied = tieNodes cg
    nodes = vertexData <$> toList (cdVertices tied)
    -- TODO it is missing the first node here, hoping it is the first one.
    outputNames = [nodeName (vertexData v) | v <- toList (cdOutputs tied)]

-- Increments by one the command counter of the session held in the state.
_increaseCompCounter :: SparkStatePure ()
_increaseCompCounter = do
  session <- get
  put session { ssCommandCounter = ssCommandCounter session + 1 }

-- Starting from an end point, collects the list of vertices built from its
-- untyped form. A failure of the underlying traversal is reported in Try.
_gatherNodes :: LocalData a -> Try [UntypedNode]
_gatherNodes ld = tryEither (buildVertexList (untyped ld))

-- The type used when decoding results: the underlying representation of the
-- given type is re-wrapped and passed through arrayType.
_extractionType :: SQLType a -> SQLType [a]
_extractionType sqlt = arrayType (SQLType (unSQLType sqlt))

-- Like the type, remove the row wrapper in the case of basic elements:
-- inside a row array, every element that is itself a row array holding
-- exactly one integer or one string is replaced by that integer or string.
-- Everything else is returned unchanged.
-- TODO(kps) figure out what the exact semantics are.
-- It seems collect is behaving differently than the other nodes.
_postprocessBasic :: (HasCallStack) => Cell -> Cell
_postprocessBasic cell = case cell of
  RowArray rows -> RowArray (fmap unwrapSingleton rows)
  -- Other cells are passed through instead of failing with:
  --failure $ "Could not interpret this cell: " ++ show x
  _ -> cell
  where
    -- Only single-element rows are candidates for unwrapping; when the guard
    -- fails, the catch-all equation below returns the element as is.
    unwrapSingleton inner@(RowArray arr)
      | V.length arr == 1 = case V.head arr of
          e@(IntElement _) -> e
          e@(StringElement _) -> e
          _ -> inner
    unwrapSingleton other = other

-- Given a result, tries to build the corresponding object out of it.
--
-- A failed computation is turned into an error. For a successful one, the
-- data is decoded from JSON using the array-wrapped version of the expected
-- type; the decoded value must be a row array with exactly one element, which
-- is post-processed and returned. Any other decoded shape is an error.
--
-- The decoding is traced (type, raw result and outcome) for debugging.
_extract1 :: FinalResult -> SQLType Cell -> Try Cell
_extract1 (Left nf) _ = tryError $ sformat ("got an error "%shown) nf
_extract1 (Right ncs) sqlt = res0 where
  -- Because of the Row semantics, all results are wrapped in a row.
  -- We are using the equivalence between arrays and rows during decoding here.
  wrappingType = _extractionType sqlt
  trow = tryEither $ jsonToCell (unSQLType wrappingType) (ncsData ncs)
  res = trow >>= \l -> case l of
      -- The guard makes the use of V.head safe.
      RowArray arr | V.length arr == 1 -> Right $ _postprocessBasic (V.head arr)
      x -> tryError $ sformat ("ContextInternal:_extract1: Expected one element, got "%shown) x
  res0 = trace ("_extract1: wrappingType = " ++ show wrappingType ++ " ncs = " ++ show ncs ++ " res = " ++ show res) res


-- Gets the relevant nodes for this computation: the terminal nodes of the
-- computation, looked up by name among its nodes and cast to local
-- observables.
-- The computation is assumed to be correct and to contain all the nodes
-- already: a terminal name that cannot be found, or a node that cannot be
-- cast, makes the corresponding element fail when it is evaluated.
-- TODO: make it a total function
getTargetNodes :: (HasCallStack) => Computation -> [UntypedLocalData]
getTargetNodes comp = fun2 . findNode <$> terminalNames
  where
    terminalNames =
      traceHint "_getTargetNodes: finalNodeNames=" (cTerminalNodes comp)
    -- Index of all the nodes of the computation, keyed by their name.
    nodesByName =
      traceHint "_getTargetNodes: dct=" (M.fromList ((nodeName &&& id) <$> cNodes comp))
    -- The default is a failing value, only evaluated if the name is missing.
    findNode name = M.findWithDefault missing name nodesByName
      where
        missing = failure $ sformat ("Could not find "%sh%" in "%sh) name nodesByName
    fun2 :: (HasCallStack) => UntypedNode -> UntypedLocalData
    fun2 n = case asLocalObservable <$> castLocality n of
      Right (Right x) -> x
      err -> failure $ sformat ("_getNodes:fun2: err="%shown%" n="%shown) err n


-- Stores the results of the computation in the state (so that we can accelerate the
-- next sessions) and returns the expected final results (as a Cell to be converted)
--
-- Currently nothing is stored: the state is left untouched. All the results
-- are decoded in order; the first decoding failure is returned if there is
-- one, otherwise the cell decoded from the first result is returned. An empty
-- list of results is reported as an error.
storeResults :: Computation -> [(LocalData Cell, FinalResult)] -> SparkStatePure (Try Cell)
storeResults comp [] = return e where
  e = tryError $ sformat ("No result returned for computation "%shown) comp
storeResults _ (firstRes : otherRes) =
  let
    fun4 :: (LocalData Cell, FinalResult) -> Try Cell
    fun4 (node, fresult) =
      trace ("_storeResults node=" ++ show node ++ "final = " ++ show fresult) $
        _extract1 fresult (nodeType node)
    -- Matching on the non-empty list avoids the partial 'head'. The other
    -- results are still decoded so that any of their failures is reported.
    expResult = do
      cell <- fun4 firstRes -- Just accessing the first result for now
      _ <- forM otherRes fun4
      return cell
  in
    -- TODO store the results:
    return expResult