{-# LANGUAGE OverloadedStrings #-} {-# LANGUAGE ScopedTypeVariables #-} {-# LANGUAGE FlexibleContexts #-} module Spark.Core.Internal.ContextIOInternal( returnPure, createSparkSession, createSparkSession', executeCommand1 ) where import Control.Concurrent(threadDelay) import Control.Lens((^.)) import Control.Monad.State(mapStateT, get) import Control.Monad(forM) import Data.Aeson(toJSON) import Data.Functor.Identity(runIdentity) import Data.Text(Text, pack) import qualified Data.Text as T import qualified Network.Wreq as W import Network.Wreq(responseBody) import Control.Monad.Trans(lift) import Control.Monad.Logger(runStdoutLoggingT, LoggingT, logDebugN, logInfoN, MonadLoggerIO) import System.Random(randomIO) import Data.Word(Word8) import Control.Monad.IO.Class -- import Formatting import Network.Wreq.Types(Postable) import Data.ByteString.Lazy(ByteString) import Spark.Core.Dataset import Spark.Core.Internal.Client import Spark.Core.Internal.ContextInternal import Spark.Core.Internal.ContextStructures import Spark.Core.Row import Spark.Core.StructuresInternal import Spark.Core.Try import Spark.Core.Internal.Utilities returnPure :: forall a. SparkStatePure a -> SparkState a returnPure p = lift $ mapStateT (return . runIdentity) p {- | Creates a new Spark session. This session is unique, and it will not try to reconnect to an existing session. -} createSparkSession :: (MonadLoggerIO m) => SparkSessionConf -> m SparkSession createSparkSession conf = do sessionName <- case confRequestedSessionName conf of "" -> liftIO _randomSessionName x -> pure x let session = _createSparkSession conf sessionName 0 let url = _sessionEndPoint session logDebugN $ "Creating spark session at url: " <> url -- TODO get the current counter from remote _ <- _ensureSession session return session {-| Convenience function for simple cases that do not require monad stacks. -} createSparkSession' :: SparkSessionConf -> IO SparkSession createSparkSession' = _runLogger . createSparkSession {- | Executes a command: - performs the transforms and the optimizations in the pure state - sends the computation to the backend - waits for the terminal nodes to reach a final state - commits the final results to the state If any failure is detected that is internal to Krapsh, it returns an error. If the error comes from an underlying library (http stack, programming failure), an exception may be thrown instead. -} executeCommand1 :: forall a. (FromSQL a, HasCallStack) => LocalData a -> SparkState (Try a) executeCommand1 ld = do session <- get tcomp <- returnPure $ prepareExecution1 ld case tcomp of Left err -> return (Left err) Right comp -> let obss = getTargetNodes comp fun3 ld2 = do result <- _waitSingleComputation session comp (nodeName ld2) return (ld2, result) nodeResults :: SparkState [(LocalData Cell, FinalResult)] nodeResults = sequence (fun3 <$> obss) in do _ <- _sendComputation session comp nrs <- nodeResults tcell <- returnPure $ storeResults comp nrs return $ tcell >>= (tryEither . cellToValue) _randomSessionName :: IO Text _randomSessionName = do ws <- forM [1..10] (\(_::Int) -> randomIO :: IO Word8) let ints = (`mod` 10) <$> ws return . T.pack $ "session" ++ concat (show <$> ints) type DefLogger a = LoggingT IO a _runLogger :: DefLogger a -> IO a _runLogger = runStdoutLoggingT _post :: (MonadIO m, Postable a) => Text -> a -> m (W.Response ByteString) _post url = liftIO . W.post (T.unpack url) _get :: (MonadIO m) => Text -> m (W.Response ByteString) _get url = liftIO $ W.get (T.unpack url) -- TODO move to more general utilities -- Performs repeated polling until the result can be converted -- to a certain other type. -- Int controls the delay in milliseconds between each poll. _pollMonad :: (MonadIO m) => m a -> Int -> (a -> Maybe b) -> m b _pollMonad rec delayMillis check = do curr <- rec case check curr of Just res -> return res Nothing -> do _ <- liftIO $ threadDelay (delayMillis * 1000) _pollMonad rec delayMillis check -- Creates a new session from a string containing a session ID. _createSparkSession :: SparkSessionConf -> Text -> Integer -> SparkSession _createSparkSession conf sessionId = SparkSession conf sid where sid = LocalSessionId sessionId -- The URL of the end point _sessionEndPoint :: SparkSession -> Text _sessionEndPoint sess = let port = (pack . show . confPort . ssConf) sess sid = (unLocalSession . ssId) sess in T.concat [ (confEndPoint . ssConf) sess, ":", port, "/session/", sid] _sessionPortText :: SparkSession -> Text _sessionPortText = pack . show . confPort . ssConf -- The URL of the computation end point _compEndPoint :: SparkSession -> ComputationID -> Text _compEndPoint sess compId = let port = _sessionPortText sess sid = (unLocalSession . ssId) sess cid = unComputationID compId in T.concat [ (confEndPoint . ssConf) sess, ":", port, "/computation/", sid, "/", cid] -- The URL of the status of a computation _compEndPointStatus :: SparkSession -> ComputationID -> Text _compEndPointStatus sess compId = let port = _sessionPortText sess sid = (unLocalSession . ssId) sess cid = unComputationID compId in T.concat [ (confEndPoint . ssConf) sess, ":", port, "/status/", sid, "/", cid] -- Ensures that the server has instantiated a session with the given ID. _ensureSession :: (MonadLoggerIO m) => SparkSession -> m () _ensureSession session = do let url = _sessionEndPoint session <> "/create" -- logDebugN $ "url:" <> url _ <- _post url (toJSON 'a') return () _sendComputation :: (MonadLoggerIO m) => SparkSession -> Computation -> m () _sendComputation session comp = do let base' = _compEndPoint session (cId comp) let url = base' <> "/create" logInfoN $ "Sending computations at url: " <> url _ <- _post url (toJSON (cNodes comp)) return () _computationStatus :: (MonadLoggerIO m) => SparkSession -> ComputationID -> NodeName -> m PossibleNodeStatus _computationStatus session compId nname = do let base' = _compEndPointStatus session compId let rest = unNodeName nname let url = base' <> "/" <> rest logDebugN $ "Sending computations status request at url: " <> url _ <- _get url -- raw <- _get url --logDebugN $ sformat ("Got raw status: "%sh) raw status <- liftIO (W.asJSON =<< W.get (T.unpack url) :: IO (W.Response PossibleNodeStatus)) --logDebugN $ sformat ("Got status: "%sh) status let s = status ^. responseBody case s of NodeFinishedSuccess _ -> logInfoN $ rest <> " finished: success" NodeFinishedFailure _ -> logInfoN $ rest <> " finished: failure" _ -> return () return s _waitSingleComputation :: (MonadLoggerIO m) => SparkSession -> Computation -> NodeName -> m FinalResult _waitSingleComputation session comp nname = let extract :: PossibleNodeStatus -> Maybe FinalResult extract (NodeFinishedSuccess s) = Just $ Right s extract (NodeFinishedFailure f) = Just $ Left f extract _ = Nothing -- getStatus :: m PossibleNodeStatus getStatus = _computationStatus session (cId comp) nname i = confPollingIntervalMillis $ ssConf session in _pollMonad getStatus i extract