module Aws.Lambda.Runtime where

import Control.Concurrent (threadDelay)
import Control.Exception (Exception, IOException, try)
import Control.Monad
import Data.Function ((&))
import Data.Maybe (listToMaybe)
import Data.Monoid ((<>))
import GHC.Generics
import System.Exit (ExitCode (..))
import System.IO (hFlush, stdout)
import Text.Read (readMaybe)
import qualified System.Environment as Environment

import Control.Monad.Except (ExceptT, catchError, throwError)
import Control.Monad.Trans
import Data.Aeson
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Lazy as LazyByteString
import qualified Data.CaseInsensitive as CI
import Data.Text (Text)
import qualified Data.Text as Text
import qualified Data.Text.Encoding    as Encoding
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
import Lens.Micro.Platform hiding ((.=))
import qualified Network.Wreq as Wreq
import qualified System.Process as Process


-- | Lazy bytestrings; used below as the body type of wreq responses.
type LByteString = LazyByteString.ByteString
-- | Strict bytestrings.
type ByteString = ByteString.ByteString

-- | The runtime's monad: 'IO' actions that can short-circuit with a
-- 'RuntimeError' (raised with 'throwError', handled with 'catchError').
type App a =
  ExceptT RuntimeError IO a


-- | Failures the runtime can run into while fetching, running, or reporting
-- an invocation. They travel through 'App' and are serialised for the
-- Runtime API by the 'ToJSON' instance.
data RuntimeError
  = EnvironmentVariableNotSet Text
  -- ^ A required environment variable is unset; carries the variable name.
  | ApiConnectionError
  -- ^ The Runtime API could not be reached.
  -- NOTE(review): not raised by any code visible in this module.
  | ApiHeaderNotSet Text
  -- ^ An expected response header is missing; carries the header name.
  -- NOTE(review): not raised by any code visible in this module.
  | ParseError Text Text
  -- ^ A value could not be parsed: what was being parsed, then the raw value.
  | InvocationError Text
  -- ^ The handler failed; carries the error text it produced.
  deriving (Show)
instance Exception RuntimeError

-- | JSON error bodies posted to the Runtime API.
--
-- Runtime-side failures become an object with an @errorType@ (the
-- constructor name) and a human-readable @errorMessage@. An
-- 'InvocationError' is not wrapped: the handler's own error text is
-- forwarded as a bare JSON string.
instance ToJSON RuntimeError where
  toJSON err = case err of
    EnvironmentVariableNotSet msg ->
      errorObject "EnvironmentVariableNotSet" msg
    ApiConnectionError ->
      errorObject "ApiConnectionError" "Could not connect to API to retrieve AWS Lambda parameters"
    ApiHeaderNotSet headerName ->
      errorObject "ApiHeaderNotSet" headerName
    ParseError objectBeingParsed value ->
      errorObject "ParseError"
        ("Parse error for " <> objectBeingParsed <> ", could not parse value '" <> value <> "'")
    -- The user's error is passed through untouched.
    InvocationError userError ->
      toJSON userError
   where
    -- Shared shape of every runtime-generated error body.
    errorObject :: Text -> Text -> Value
    errorObject errorType message = object
      [ "errorType" .= errorType
      , "errorMessage" .= message
      ]



-- | Per-invocation metadata assembled by 'initializeContext' from the
-- @AWS_LAMBDA_*@ environment variables and the Runtime API response headers.
-- 'invoke' JSON-encodes it (Generic-derived instances) and passes it to the
-- handler executable as @--contextObject@.
data Context = Context
  { memoryLimitInMb    :: !Int   -- ^ From @AWS_LAMBDA_FUNCTION_MEMORY_SIZE@.
  , functionName       :: !Text  -- ^ From @AWS_LAMBDA_FUNCTION_NAME@.
  , functionVersion    :: !Text  -- ^ From @AWS_LAMBDA_FUNCTION_VERSION@.
  , invokedFunctionArn :: !Text  -- ^ From the @Lambda-Runtime-Invoked-Function-Arn@ header.
  , awsRequestId       :: !Text  -- ^ From the @Lambda-Runtime-Aws-Request-Id@ header; used to build response/error URLs.
  , xrayTraceId        :: !Text  -- ^ From the @Lambda-Runtime-Trace-Id@ header.
  , logStreamName      :: !Text  -- ^ From @AWS_LAMBDA_LOG_STREAM_NAME@.
  , logGroupName       :: !Text  -- ^ From @AWS_LAMBDA_LOG_GROUP_NAME@.
  , deadline           :: !Int   -- ^ From the @Lambda-Runtime-Deadline-Ms@ header (presumably epoch milliseconds — confirm).
  } deriving (Generic)
instance FromJSON Context
instance ToJSON Context


-- | The handler's successful output: the result line extracted from its
-- stdout by 'invoke', posted verbatim by 'publishResult'.
newtype LambdaResult =
  LambdaResult Text


-- | Runtime API version segment that prefixes every endpoint path.
awsLambdaVersion :: String
awsLambdaVersion = "2018-06-01"


-- | Assemble @http://<api host>/<version><path>@; @path@ must start with @/@.
runtimeApiUrl :: Text -> String -> String
runtimeApiUrl lambdaApi path =
  concat ["http://", Text.unpack lambdaApi, "/", awsLambdaVersion, path]


-- | Endpoint from which 'getApiData' fetches the next invocation.
nextInvocationEndpoint :: Text -> String
nextInvocationEndpoint endpoint =
  runtimeApiUrl endpoint "/runtime/invocation/next"


-- | Endpoint that receives the successful result for the given request id.
responseEndpoint :: Text -> Text -> String
responseEndpoint lambdaApi requestId =
  runtimeApiUrl lambdaApi (concat ["/runtime/invocation/", Text.unpack requestId, "/response"])


-- | Endpoint that receives an error report for the given request id.
invocationErrorEndpoint :: Text -> Text -> String
invocationErrorEndpoint lambdaApi requestId =
  runtimeApiUrl lambdaApi (concat ["/runtime/invocation/", Text.unpack requestId, "/error"])


-- | Endpoint for reporting failures that happen during runtime initialization.
runtimeInitErrorEndpoint :: Text -> String
runtimeInitErrorEndpoint lambdaApi =
  runtimeApiUrl lambdaApi "/runtime/init/error"


-- | Look up an environment variable by name.
--
-- Fails with 'EnvironmentVariableNotSet', carrying the variable's name, when
-- the variable is not present in the process environment.
readEnvironmentVariable :: Text -> App Text
readEnvironmentVariable envVar =
  lift (Environment.lookupEnv (Text.unpack envVar))
    >>= maybe (throwError (EnvironmentVariableNotSet envVar)) (pure . Text.pack)


-- | The function's configured memory limit, read from
-- @AWS_LAMBDA_FUNCTION_MEMORY_SIZE@.
--
-- Fails with 'EnvironmentVariableNotSet' when the variable is unset, or with
-- a 'ParseError' (variable name, raw value) when it is not a valid 'Int'.
readFunctionMemory :: App Int
readFunctionMemory =
  readEnvironmentVariable memoryVariable >>= \raw ->
    maybe (throwError (ParseError memoryVariable raw)) pure (readMaybe (Text.unpack raw))
 where
  memoryVariable :: Text
  memoryVariable = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE"


-- | Fetch the next invocation from the Runtime API at @endpoint@.
--
-- The GET is retried indefinitely whenever it raises an 'IOException'. Any
-- other exception propagates to the caller unchanged.
--
-- NOTE(review): wreq/http-client normally report connection failures as an
-- @HttpException@, not an 'IOException', so this handler may not cover the
-- failures it is meant to retry — confirm the intended exception type.
getApiData :: Text -> App (Wreq.Response LByteString)
getApiData endpoint =
  keepRetrying (Wreq.get $ nextInvocationEndpoint endpoint)
 where
  -- Pause between attempts so a failing endpoint is not hammered in a tight
  -- loop that pins the CPU. The value is 0.1 s; 'threadDelay' takes microseconds.
  retryDelayMicros :: Int
  retryDelayMicros = 100000

  keepRetrying :: IO (Wreq.Response LByteString) -> App (Wreq.Response LByteString)
  keepRetrying request = do
    result <- liftIO (try request) :: App (Either IOException (Wreq.Response LByteString))
    case result of
      Right response -> pure response
      Left _ -> do
        liftIO (threadDelay retryDelayMicros)
        keepRetrying request


-- | Read a response header (case-insensitive name) and decode it as UTF-8.
--
-- NOTE(review): 'Wreq.responseHeader' is viewed with '^.', so a missing
-- header presumably yields empty text rather than an error — confirm.
-- 'Encoding.decodeUtf8' throws on invalid UTF-8.
extractHeader :: Wreq.Response LByteString -> Text -> Text
extractHeader apiData header =
  let headerName = CI.mk (Encoding.encodeUtf8 header)
      rawValue   = apiData ^. Wreq.responseHeader headerName
  in  Encoding.decodeUtf8 rawValue


-- | Read a response header and parse its text as an 'Int'.
--
-- Fails with @ParseError headerName value@ when the text does not parse. The
-- error names the header that was actually requested, so callers reading
-- headers other than the deadline get an accurate message.
--
-- NOTE(review): a missing header presumably reads as empty text through
-- 'extractHeader' and therefore also ends in a 'ParseError' — confirm.
extractIntHeader :: Wreq.Response LByteString -> Text -> App Int
extractIntHeader apiData headerName = do
  let header = extractHeader apiData headerName
  case readMaybe $ Text.unpack header of
    Nothing    -> throwError (ParseError headerName header)
    Just value -> pure value


-- | The response body (the invocation event), decoded from UTF-8.
--
-- NOTE(review): 'Encoding.decodeUtf8' throws on invalid UTF-8.
extractBody :: Wreq.Response LByteString -> Text
extractBody apiData =
  Encoding.decodeUtf8 . LazyByteString.toStrict $ apiData ^. Wreq.responseBody


-- | Store the invocation's trace id in the @_X_AMZN_TRACE_ID@ environment
-- variable of this process.
--
-- Presumably this is so the handler process spawned later by 'invoke'
-- inherits it — confirm.
propagateXRayTrace :: Text -> App ()
propagateXRayTrace traceId =
  liftIO (Environment.setEnv "_X_AMZN_TRACE_ID" (Text.unpack traceId))


-- | Build the 'Context' for the invocation described by @apiData@.
--
-- Reads the function name, version, log stream, log group and memory size
-- from the environment, and the deadline, trace id, request id and invoked
-- ARN from the response headers. Exports the trace id via
-- 'propagateXRayTrace' as a side effect.
--
-- Fails with 'EnvironmentVariableNotSet' or 'ParseError' from the helpers it
-- calls. Variables are checked in the order written.
initializeContext :: Wreq.Response LByteString -> App Context
initializeContext apiData = do
  name       <- readEnvironmentVariable "AWS_LAMBDA_FUNCTION_NAME"
  version    <- readEnvironmentVariable "AWS_LAMBDA_FUNCTION_VERSION"
  logStream  <- readEnvironmentVariable "AWS_LAMBDA_LOG_STREAM_NAME"
  logGroup   <- readEnvironmentVariable "AWS_LAMBDA_LOG_GROUP_NAME"
  memory     <- readFunctionMemory
  deadlineMs <- extractIntHeader apiData "Lambda-Runtime-Deadline-Ms"
  let fromHeader = extractHeader apiData
      traceId    = fromHeader "Lambda-Runtime-Trace-Id"
  propagateXRayTrace traceId
  pure Context
    { memoryLimitInMb    = memory
    , functionName       = name
    , functionVersion    = version
    , invokedFunctionArn = fromHeader "Lambda-Runtime-Invoked-Function-Arn"
    , awsRequestId       = fromHeader "Lambda-Runtime-Aws-Request-Id"
    , xrayTraceId        = traceId
    , logStreamName      = logStream
    , logGroupName       = logGroup
    , deadline           = deadlineMs
    }


-- | Split handler output at the execution marker line (the UUID's text).
--
-- Every line before the first marker line is echoed to stdout, flushing
-- after each line. The result is the first line that follows the run of
-- marker lines. It is 'Nothing' when the marker never appears or nothing
-- follows it.
getFunctionResult :: UUID.UUID -> Text -> App (Maybe Text)
getFunctionResult u processOutput = do
  let (logLines, fromMarker) = break (== marker) (Text.lines processOutput)
  forM_ logLines $ \line -> liftIO $ do
    putStrLn (Text.unpack line)
    hFlush stdout
  pure (listToMaybe (dropWhile (== marker) fromMarker))
 where
  marker = Text.pack (UUID.toString u)


-- | Run the user's handler executable @$LAMBDA_TASK_ROOT/haskell_lambda@.
--
-- The event, the JSON-encoded context, the handler name (from @_HANDLER@)
-- and a fresh UUID marker are passed as command-line flags. The handler's
-- result is the stdout line following the marker (see 'getFunctionResult').
--
-- * Exit code 0: the result line becomes the 'LambdaResult'. A missing
--   result is a @ParseError "parsing result"@.
-- * Non-zero exit with non-empty stderr: @InvocationError@ with the stderr
--   text. In this case stdout is not echoed.
-- * Non-zero exit with empty stderr: @InvocationError@ with the result line.
--   A missing result is a @ParseError "parsing error"@.
--
-- NOTE(review): on a zero exit code anything written to stderr is discarded.
invoke :: Text -> Context -> App LambdaResult
invoke event context = do
  handlerName      <- readEnvironmentVariable "_HANDLER"
  runningDirectory <- readEnvironmentVariable "LAMBDA_TASK_ROOT"
  executionUuid    <- liftIO UUID.nextRandom
  let executable        = Text.unpack runningDirectory <> "/haskell_lambda"
      serializedContext = Encoding.decodeUtf8 (LazyByteString.toStrict (encode context))
      arguments =
        [ "--eventObject", Text.unpack event
        , "--contextObject", Text.unpack serializedContext
        , "--functionHandler", Text.unpack handlerName
        , "--executionUuid", UUID.toString executionUuid
        ]
      -- Echo the handler's log lines and require a result line after the marker.
      parseOutput label raw =
        getFunctionResult executionUuid (Text.pack raw)
          >>= maybe (throwError (ParseError label (Text.pack raw))) pure
  (exitCode, stdOut, stdErr) <- liftIO (Process.readProcessWithExitCode executable arguments "")
  case exitCode of
    ExitSuccess ->
      LambdaResult <$> parseOutput "parsing result" stdOut
    ExitFailure _
      | not (null stdErr) -> throwError (InvocationError (Text.pack stdErr))
      | otherwise         -> parseOutput "parsing error" stdOut >>= throwError . InvocationError


-- | POST the handler's result, UTF-8 encoded, to the response endpoint of
-- the context's request id. The Runtime API's reply is ignored.
publishResult :: Context -> Text -> LambdaResult -> App ()
publishResult ctx lambdaApi (LambdaResult result) = do
  let url = responseEndpoint lambdaApi (awsRequestId ctx)
  _ <- liftIO (Wreq.post url (Encoding.encodeUtf8 result))
  pure ()


-- | Run the handler on @event@ and post its successful result. Any
-- 'RuntimeError' from 'invoke' short-circuits before publishing.
invokeAndPublish :: Context -> Text -> Text -> App ()
invokeAndPublish ctx event lambdaApiEndpoint =
  invoke event ctx >>= publishResult ctx lambdaApiEndpoint


-- | Report a failed invocation to the Runtime API.
--
-- A 'Context' only exists once an invocation has been fetched and its
-- request id read (see 'lambdaRunner'). Every error handled here therefore
-- belongs to that invocation, and all of them go to its invocation error
-- endpoint. The init-error endpoint is meant for failures during runtime
-- initialization, so it is not used here; using it would leave this
-- invocation unanswered.
--
-- An 'InvocationError' body is the handler's error text, UTF-8 encoded, sent
-- as-is. Every other error is sent as its 'ToJSON' object.
publishError :: Context -> Text -> RuntimeError -> App ()
publishError ctx lambdaApiEndpoint err =
  void $ liftIO $ case err of
    InvocationError userError ->
      Wreq.post errorUrl (Encoding.encodeUtf8 userError)
    _ ->
      Wreq.post errorUrl (toJSON err)
 where
  errorUrl = invocationErrorEndpoint lambdaApiEndpoint (awsRequestId ctx)


-- | Process one invocation end to end.
--
-- Locates the Runtime API via @AWS_LAMBDA_RUNTIME_API@, fetches the next
-- event, and builds its 'Context'. It then runs and publishes the result,
-- reporting any 'RuntimeError' from that step through 'publishError'.
-- Errors raised before a 'Context' exists propagate to the caller.
--
-- NOTE(review): handles a single event; presumably the caller loops — confirm.
lambdaRunner :: App ()
lambdaRunner = do
  runtimeApi <- readEnvironmentVariable "AWS_LAMBDA_RUNTIME_API"
  invocation <- getApiData runtimeApi
  ctx        <- initializeContext invocation
  catchError
    (invokeAndPublish ctx (extractBody invocation) runtimeApi)
    (publishError ctx runtimeApi)