module Aws.Lambda.Runtime where
import Control.Exception (Exception, IOException, try)
import Control.Monad.Except (ExceptT, catchError, throwError)
import Data.Aeson
import System.Exit (ExitCode (..))
import qualified Data.Text as Text
import Data.Text (Text)
import GHC.Generics
import qualified Data.ByteString.Lazy as LazyByteString
import qualified Data.ByteString as ByteString
import Control.Monad.Trans
import Text.Read (readMaybe)
import qualified Data.Text.Encoding as Encoding
import Control.Monad
import Data.Function ((&))
import Data.Maybe (listToMaybe)
import Data.Monoid ((<>))
import qualified Data.CaseInsensitive as CI
import Lens.Micro.Platform hiding ((.=))
import qualified Network.Wreq as Wreq
import qualified Network.HTTP.Client as Http
import qualified System.Environment as Environment
import qualified System.Process as Process
import qualified Data.UUID as UUID
import qualified Data.UUID.V4 as UUID
import System.IO (hFlush, stdout)
-- | Shorthand for the lazy byte string type from "Data.ByteString.Lazy".
type LByteString = LazyByteString.ByteString
-- | Shorthand for the strict byte string type from "Data.ByteString".
type ByteString = ByteString.ByteString
-- | Monad in which the runtime executes: 'IO' actions that may
-- short-circuit with a 'RuntimeError'.
type App a = ExceptT RuntimeError IO a
-- | Everything that can go wrong while the runtime bootstraps or serves an
-- invocation.
data RuntimeError
  = EnvironmentVariableNotSet Text
    -- ^ Name of the environment variable that could not be found.
  | ApiConnectionError
    -- ^ Rendered to JSON as a failure to connect to the API.
  | ApiHeaderNotSet Text
    -- ^ Name of the header concerned.
  | ParseError Text Text
    -- ^ What was being parsed, followed by the value that failed to parse.
  | InvocationError Text
    -- ^ Error text produced while invoking the handler.
  deriving Show

-- | Lets a 'RuntimeError' be thrown and caught as a regular exception.
instance Exception RuntimeError
-- | JSON rendering of a 'RuntimeError'.
--
-- Every constructor except 'InvocationError' becomes an object with an
-- @errorType@ and an @errorMessage@ field; 'InvocationError' is rendered as
-- the JSON encoding of the text it carries.
instance ToJSON RuntimeError where
  toJSON err =
    case err of
      EnvironmentVariableNotSet msg ->
        describe "EnvironmentVariableNotSet" msg
      ApiConnectionError ->
        describe
          "ApiConnectionError"
          "Could not connect to API to retrieve AWS Lambda parameters"
      ApiHeaderNotSet headerName ->
        describe "ApiHeaderNotSet" headerName
      ParseError objectBeingParsed value ->
        describe
          "ParseError"
          ("Parse error for " <> objectBeingParsed <> ", could not parse value '" <> value <> "'")
      InvocationError inner ->
        toJSON inner
    where
      -- Builds the two-field error object shared by most constructors.
      describe :: Text -> Text -> Value
      describe errorType errorMessage =
        object
          [ "errorType" .= errorType
          , "errorMessage" .= errorMessage
          ]
-- | Per-invocation information handed to the handler.
--
-- The values are assembled by 'initializeContext' from environment variables
-- and from the headers of the "next invocation" response.
data Context = Context
  { memoryLimitInMb    :: !Int
    -- ^ Parsed from @AWS_LAMBDA_FUNCTION_MEMORY_SIZE@.
  , functionName       :: !Text
    -- ^ Read from @AWS_LAMBDA_FUNCTION_NAME@.
  , functionVersion    :: !Text
    -- ^ Read from @AWS_LAMBDA_FUNCTION_VERSION@.
  , invokedFunctionArn :: !Text
    -- ^ Taken from the @Lambda-Runtime-Invoked-Function-Arn@ header.
  , awsRequestId       :: !Text
    -- ^ Taken from the @Lambda-Runtime-Aws-Request-Id@ header.
  , xrayTraceId        :: !Text
    -- ^ Taken from the @Lambda-Runtime-Trace-Id@ header.
  , logStreamName      :: !Text
    -- ^ Read from @AWS_LAMBDA_LOG_STREAM_NAME@.
  , logGroupName       :: !Text
    -- ^ Read from @AWS_LAMBDA_LOG_GROUP_NAME@.
  , deadline           :: !Int
    -- ^ Parsed from the @Lambda-Runtime-Deadline-Ms@ header.
  } deriving Generic

-- | Generic-derived decoding; keys are the record field names.
instance FromJSON Context

-- | Generic-derived encoding; keys are the record field names.
instance ToJSON Context
-- | Textual result produced by a successful handler invocation.
newtype LambdaResult = LambdaResult Text
-- | Version path segment placed in every runtime API URL built by this module.
awsLambdaVersion :: String
awsLambdaVersion = "2018-06-01"
-- | URL polled to obtain the next invocation.
--
-- The argument is the host (and port) of the runtime API; the result has the
-- shape @http://HOST/VERSION/runtime/invocation/next@.
nextInvocationEndpoint :: Text -> String
nextInvocationEndpoint endpoint =
  mconcat
    [ "http://"
    , Text.unpack endpoint
    , "/"
    , awsLambdaVersion
    , "/runtime/invocation/next"
    ]
-- | URL to which the result of the invocation identified by @requestId@ is
-- posted: @http://HOST/VERSION/runtime/invocation/REQUEST_ID/response@.
responseEndpoint :: Text -> Text -> String
responseEndpoint lambdaApi requestId =
  mconcat
    [ "http://"
    , Text.unpack lambdaApi
    , "/"
    , awsLambdaVersion
    , "/runtime/invocation/"
    , Text.unpack requestId
    , "/response"
    ]
-- | URL to which a failure of the invocation identified by @requestId@ is
-- posted: @http://HOST/VERSION/runtime/invocation/REQUEST_ID/error@.
invocationErrorEndpoint :: Text -> Text -> String
invocationErrorEndpoint lambdaApi requestId =
  mconcat
    [ "http://"
    , Text.unpack lambdaApi
    , "/"
    , awsLambdaVersion
    , "/runtime/invocation/"
    , Text.unpack requestId
    , "/error"
    ]
-- | URL to which errors not tied to a specific invocation are posted:
-- @http://HOST/VERSION/runtime/init/error@.
runtimeInitErrorEndpoint :: Text -> String
runtimeInitErrorEndpoint lambdaApi =
  mconcat
    [ "http://"
    , Text.unpack lambdaApi
    , "/"
    , awsLambdaVersion
    , "/runtime/init/error"
    ]
-- | Look up an environment variable by name.
--
-- Fails with 'EnvironmentVariableNotSet' (carrying the variable name) when the
-- variable is not present in the environment.
readEnvironmentVariable :: Text -> App Text
readEnvironmentVariable envVar =
  lift (Environment.lookupEnv (Text.unpack envVar))
    >>= maybe missing (pure . Text.pack)
  where
    missing = throwError (EnvironmentVariableNotSet envVar)
-- | Read @AWS_LAMBDA_FUNCTION_MEMORY_SIZE@ and parse it as an 'Int'.
--
-- Fails with 'EnvironmentVariableNotSet' when the variable is missing and with
-- 'ParseError' (variable name plus raw value) when it is not a valid 'Int'.
readFunctionMemory :: App Int
readFunctionMemory = do
  rawMemory <- readEnvironmentVariable memoryVariable
  maybe
    (throwError (ParseError memoryVariable rawMemory))
    pure
    (readMaybe (Text.unpack rawMemory))
  where
    memoryVariable :: Text
    memoryVariable = "AWS_LAMBDA_FUNCTION_MEMORY_SIZE"
-- | Fetch the next invocation from the runtime API at the given host.
--
-- The request is issued with the response timeout disabled and is repeated
-- for as long as it fails with an 'IOException'.
--
-- NOTE(review): only 'IOException' is caught, and retries happen immediately
-- with no delay or attempt limit. Any other exception type raised by the HTTP
-- client escapes this function — confirm that is the intended coverage.
getApiData :: Text -> App (Wreq.Response LByteString)
getApiData endpoint = keepRetrying fetchNext
  where
    fetchNext :: IO (Wreq.Response LByteString)
    fetchNext = Wreq.getWith requestOptions (nextInvocationEndpoint endpoint)

    -- Manager settings with the response timeout switched off.
    managerSettings =
      Http.defaultManagerSettings
        { Http.managerResponseTimeout = Http.responseTimeoutNone }

    requestOptions =
      Wreq.defaults & Wreq.manager .~ Left managerSettings

    -- Runs the action, starting over whenever it throws an 'IOException'.
    keepRetrying :: IO (Wreq.Response LByteString) -> App (Wreq.Response LByteString)
    keepRetrying action =
      liftIO (try action) >>= either retry pure
      where
        retry :: IOException -> App (Wreq.Response LByteString)
        retry _ = keepRetrying action
-- | Read a response header by name and decode its value as UTF-8 text.
--
-- The header name is matched through a case-insensitive key.
--
-- NOTE(review): what is returned for an absent header depends on
-- 'Wreq.responseHeader' — presumably empty text; confirm.
extractHeader :: Wreq.Response LByteString -> Text -> Text
extractHeader apiData header = Encoding.decodeUtf8 rawValue
  where
    headerKey = CI.mk (Encoding.encodeUtf8 header)
    rawValue = apiData ^. Wreq.responseHeader headerKey
-- | Read a response header by name and parse its value as an 'Int'.
--
-- Fails with 'ParseError', carrying the header name and the raw header text,
-- when the value is not a valid 'Int'. The error names the header that was
-- actually requested rather than a fixed label, so the message stays accurate
-- for any header this function is used with.
extractIntHeader :: Wreq.Response LByteString -> Text -> App Int
extractIntHeader apiData headerName =
  maybe
    (throwError (ParseError headerName rawValue))
    pure
    (readMaybe (Text.unpack rawValue))
  where
    rawValue = extractHeader apiData headerName
-- | Response body converted to a strict byte string and decoded as UTF-8 text.
extractBody :: Wreq.Response LByteString -> Text
extractBody apiData = Encoding.decodeUtf8 strictBody
  where
    strictBody = LazyByteString.toStrict (apiData ^. Wreq.responseBody)
-- | Store the given trace id in the @_X_AMZN_TRACE_ID@ environment variable of
-- the current process.
propagateXRayTrace :: Text -> App ()
propagateXRayTrace xrayTraceId =
  liftIO (Environment.setEnv "_X_AMZN_TRACE_ID" (Text.unpack xrayTraceId))
-- | Assemble the 'Context' for one invocation.
--
-- Function name, version, log stream, log group and memory limit come from
-- environment variables; deadline, trace id, request id and invoked ARN come
-- from the headers of the supplied response. As a side effect the trace id is
-- exported through 'propagateXRayTrace'.
--
-- Fails with whatever 'readEnvironmentVariable', 'readFunctionMemory' or
-- 'extractIntHeader' report, in that order of evaluation.
initializeContext :: Wreq.Response LByteString -> App Context
initializeContext apiData = do
  name        <- readEnvironmentVariable "AWS_LAMBDA_FUNCTION_NAME"
  version     <- readEnvironmentVariable "AWS_LAMBDA_FUNCTION_VERSION"
  logStream   <- readEnvironmentVariable "AWS_LAMBDA_LOG_STREAM_NAME"
  logGroup    <- readEnvironmentVariable "AWS_LAMBDA_LOG_GROUP_NAME"
  memoryLimit <- readFunctionMemory
  deadlineMs  <- extractIntHeader apiData "Lambda-Runtime-Deadline-Ms"
  let headerText = extractHeader apiData
      traceId    = headerText "Lambda-Runtime-Trace-Id"
      requestId  = headerText "Lambda-Runtime-Aws-Request-Id"
      invokedArn = headerText "Lambda-Runtime-Invoked-Function-Arn"
  propagateXRayTrace traceId
  pure Context
    { memoryLimitInMb    = memoryLimit
    , functionName       = name
    , functionVersion    = version
    , invokedFunctionArn = invokedArn
    , awsRequestId       = requestId
    , xrayTraceId        = traceId
    , logStreamName      = logStream
    , logGroupName       = logGroup
    , deadline           = deadlineMs
    }
-- | Split captured handler output into log lines and the handler result.
--
-- The output is broken into lines. Every line before the first line equal to
-- the textual UUID is echoed to stdout (flushing after each one). The result
-- is the first line after the UUID marker line(s), or 'Nothing' when no such
-- line exists.
getFunctionResult :: UUID.UUID -> Text -> App (Maybe Text)
getFunctionResult u stdOut = do
  forM_ logLines $ \logLine -> liftIO $ do
    putStrLn (Text.unpack logLine)
    hFlush stdout
  pure (listToMaybe afterMarker)
  where
    marker = Text.pack (UUID.toString u)
    -- Everything before the first marker line is log output.
    (logLines, fromMarker) = break (== marker) (Text.lines stdOut)
    -- Skip the marker line(s) themselves to reach the result.
    afterMarker = dropWhile (== marker) fromMarker
-- | Run the handler executable for one event and collect its result.
--
-- The executable is @haskell_lambda@ inside the directory named by
-- @LAMBDA_TASK_ROOT@. It receives the event, the JSON-encoded context, the
-- handler name from @_HANDLER@ and a fresh UUID as command-line arguments,
-- with empty standard input.
--
-- * Exit success: the result is extracted with 'getFunctionResult'; a missing
--   result is a 'ParseError' carrying the whole captured stdout.
-- * Any other exit code with non-empty stderr: 'InvocationError' carrying
--   stderr.
-- * Any other exit code with empty stderr: the extracted result is thrown as
--   an 'InvocationError'; a missing result is a 'ParseError'.
invoke :: Text -> Context -> App LambdaResult
invoke event context = do
  handlerName      <- readEnvironmentVariable "_HANDLER"
  runningDirectory <- readEnvironmentVariable "LAMBDA_TASK_ROOT"
  uuid             <- liftIO UUID.nextRandom
  let contextJSON = Encoding.decodeUtf8 (LazyByteString.toStrict (encode context))
      executable  = Text.unpack runningDirectory <> "/haskell_lambda"
      arguments   =
        [ "--eventObject", Text.unpack event
        , "--contextObject", Text.unpack contextJSON
        , "--functionHandler", Text.unpack handlerName
        , "--executionUuid", UUID.toString uuid
        ]
  (exitCode, stdOut, stdErr) <-
    liftIO (Process.readProcessWithExitCode executable arguments "")
  let capturedOutput = Text.pack stdOut
  case exitCode of
    ExitSuccess ->
      getFunctionResult uuid capturedOutput
        >>= maybe
              (throwError (ParseError "parsing result" capturedOutput))
              (pure . LambdaResult)
    ExitFailure _
      | stdErr /= "" ->
          throwError (InvocationError (Text.pack stdErr))
      | otherwise ->
          getFunctionResult uuid capturedOutput
            >>= maybe
                  (throwError (ParseError "parsing error" capturedOutput))
                  (throwError . InvocationError)
-- | POST a successful result, UTF-8 encoded, to the response endpoint of the
-- invocation identified by the context's request id. The HTTP response is
-- discarded.
publishResult :: Context -> Text -> LambdaResult -> App ()
publishResult Context {awsRequestId = requestId} lambdaApi (LambdaResult result) =
  void . liftIO $ Wreq.post targetUrl payload
  where
    targetUrl = responseEndpoint lambdaApi requestId
    payload = Encoding.encodeUtf8 result
-- | Invoke the handler for the event and, if that succeeds, publish its
-- result to the runtime API at the given host.
invokeAndPublish :: Context -> Text -> Text -> App ()
invokeAndPublish ctx event lambdaApiEndpoint =
  invoke event ctx >>= publishResult ctx lambdaApiEndpoint
-- | Report a 'RuntimeError' to the runtime API, discarding the HTTP response.
--
-- * 'InvocationError': its text is posted UTF-8 encoded to the invocation
--   error endpoint for the context's request id.
-- * 'ParseError': its JSON rendering is posted to that same endpoint.
-- * Anything else: its JSON rendering is posted to the init error endpoint.
publishError :: Context -> Text -> RuntimeError -> App ()
publishError Context {awsRequestId = requestId} lambdaApiEndpoint err =
  void . liftIO $
    case err of
      InvocationError message ->
        Wreq.post invocationUrl (Encoding.encodeUtf8 message)
      ParseError _ _ ->
        Wreq.post invocationUrl (toJSON err)
      _ ->
        Wreq.post (runtimeInitErrorEndpoint lambdaApiEndpoint) (toJSON err)
  where
    invocationUrl = invocationErrorEndpoint lambdaApiEndpoint requestId
-- | Serve a single invocation.
--
-- Reads the API host from @AWS_LAMBDA_RUNTIME_API@, fetches the next
-- invocation, builds its 'Context', then invokes the handler and publishes the
-- result. Errors raised by the invoke-and-publish step are reported through
-- 'publishError'; errors raised by the earlier steps are outside that handler
-- and propagate to the caller.
lambdaRunner :: App ()
lambdaRunner = do
  lambdaApiEndpoint <- readEnvironmentVariable "AWS_LAMBDA_RUNTIME_API"
  apiData <- getApiData lambdaApiEndpoint
  ctx <- initializeContext apiData
  let event = extractBody apiData
      reportFailure = publishError ctx lambdaApiEndpoint
  catchError (invokeAndPublish ctx event lambdaApiEndpoint) reportFailure