{-# LANGUAGE OverloadedStrings #-}
module AWSLambda.Handler
( lambdaMain
, lambdaMainRaw
) where
import Control.Exception (ErrorCall(..))
import Control.Exception.Safe (MonadCatch, SomeException(..), displayException, throwM, tryAny, tryAnyDeep)
import Control.Monad (forever, void)
import Control.Monad.IO.Class
import Data.Aeson ((.=))
import qualified Data.Aeson as Aeson
import qualified Data.ByteString as ByteString
import qualified Data.ByteString.Char8 as Char8
import qualified Data.ByteString.Lazy as LBS
import qualified Data.Text.Encoding as Text
import qualified Data.Text.IO as Text
import Data.Typeable (typeOf)
import GHC.IO.Handle (BufferMode(..), hSetBuffering)
import Network.HTTP.Client
import Network.HTTP.Types (HeaderName)
import System.Environment (lookupEnv)
import System.IO (stdout)
-- | Run a handler that consumes and produces JSON.
--
-- The raw invocation payload is decoded into the handler's @event@ type, and
-- the handler's result is encoded back to JSON. Transport (Runtime API or
-- local stdin/stdout) is delegated to 'lambdaMainRaw'.
--
-- A payload that fails to decode raises an 'ErrorCall' whose message is
-- Aeson's parse error.
lambdaMain ::
     (Aeson.FromJSON event, Aeson.ToJSON res, MonadCatch m, MonadIO m)
  => (event -> m res)
  -> m ()
lambdaMain act =
  lambdaMainRaw $ \input ->
    case Aeson.eitherDecode input of
      -- Throw inside the monad rather than calling pure 'error'. The failure
      -- is then an ordinary exception raised when the action runs, and its
      -- message is just the parse error (no appended source CallStack).
      Left err -> throwM (ErrorCall err)
      Right event -> Aeson.encode <$> act event
-- | Run a handler that works directly on raw request/response bytes.
--
-- If the @AWS_LAMBDA_RUNTIME_API@ environment variable (see
-- 'lambdaApiAddressEnv') is set, this polls the Lambda Runtime API forever:
--
-- * each invocation body is passed to the handler;
-- * its output is posted to the invocation's response endpoint;
-- * any exception the handler raises, including one hidden inside the lazily
--   produced output, is posted to the invocation's error endpoint instead.
--
-- Otherwise it runs once locally: a single line is read from stdin, passed to
-- the handler, and the result is printed to stdout.
lambdaMainRaw :: (MonadCatch m, MonadIO m) => (LBS.ByteString -> m LBS.ByteString) -> m ()
lambdaMainRaw act = do
  lambdaApiAddress <- liftIO $ lookupEnv lambdaApiAddressEnv
  maybe runLocally runWithRuntimeApi lambdaApiAddress
  where
    -- Runtime mode: answer invocations from the Runtime API; never returns.
    runWithRuntimeApi address = do
      liftIO $ hSetBuffering stdout LineBuffering
      manager <- liftIO $ newManager defaultManagerSettings
      forever $ do
        invocation <- liftIO $ httpLbs (invocationRequest address) manager
        let input = responseBody invocation
            requestId = responseRequestId invocation
        -- Deep-force the handler's output inside the try. With plain
        -- 'tryAny', an exception buried in a lazy body would only surface
        -- while the body is being sent, escape the try, and kill this loop
        -- instead of being reported to Lambda.
        resultOrError <- tryAnyDeep $ act input
        let reply =
              either
                (errorRequest address requestId)
                (resultRequest address requestId)
                resultOrError
        liftIO $ void $ httpNoBody reply manager

    -- Local mode: one line from stdin in, one line to stdout out.
    runLocally = do
      input <- liftIO $ LBS.fromStrict <$> ByteString.getLine
      result <- act input
      liftIO $ Text.putStrLn $ Text.decodeUtf8 $ LBS.toStrict result
-- | Name of the environment variable holding the Runtime API address. The
-- address is used as the authority part of an @http://@ URL. When the
-- variable is absent, 'lambdaMainRaw' falls back to local stdin/stdout mode.
lambdaApiAddressEnv :: String
lambdaApiAddressEnv =
  "AWS_LAMBDA_RUNTIME_API"
-- | Build a plain-HTTP request for @rqPath@ under the @2018-06-01@ Runtime API
-- version, served at @apiAddress@.
--
-- 'parseRequest_' throws (impurely) if the assembled URL is invalid.
lambdaRequest :: String -> String -> Request
lambdaRequest apiAddress rqPath =
  parseRequest_ (concat ["http://", apiAddress, "/2018-06-01", rqPath])
-- | Request that fetches the next invocation.
--
-- The response timeout is disabled because this call long-polls: the Runtime
-- API may block until an event is available.
invocationRequest :: String -> Request
invocationRequest apiAddress =
  nextInvocation { responseTimeout = responseTimeoutNone }
  where
    nextInvocation = lambdaRequest apiAddress "/runtime/invocation/next"
-- | POST request that reports the handler's output @result@ as the response
-- for invocation @requestId@.
resultRequest :: String -> String -> LBS.ByteString -> Request
resultRequest apiAddress requestId result =
  endpoint { method = "POST", requestBody = RequestBodyLBS result }
  where
    endpoint =
      lambdaRequest apiAddress
        (concat ["/runtime/invocation/", requestId, "/response"])
-- | POST request that reports @exception@ as the failure of invocation
-- @requestId@.
--
-- The body is a JSON object with two fields:
--
-- * @errorMessage@: the exception's 'displayException' text;
-- * @errorType@: the name of the wrapped exception's type.
errorRequest :: String -> String -> SomeException -> Request
errorRequest apiAddress requestId exception =
  endpoint { method = "POST", requestBody = RequestBodyLBS payload }
  where
    endpoint =
      lambdaRequest apiAddress
        (concat ["/runtime/invocation/", requestId, "/error"])
    payload =
      Aeson.encode $
        Aeson.object
          [ "errorMessage" .= displayException exception
          , "errorType" .= exceptionType exception
          ]
-- | Name of the concrete exception type wrapped in a 'SomeException', e.g.
-- @"IOException"@ or @"ErrorCall"@.
exceptionType :: SomeException -> String
exceptionType wrapped =
  case wrapped of
    SomeException inner -> show (typeOf inner)
-- | Header on each invocation response that carries the invocation's request
-- id. As a 'HeaderName', it compares case-insensitively.
requestIdHeader :: HeaderName
requestIdHeader =
  "Lambda-Runtime-Aws-Request-Id"
-- | Request id of an invocation, read from its 'requestIdHeader' header. It is
-- used to address that invocation's response and error endpoints.
--
-- If the header occurs more than once, the first occurrence wins. If it is
-- missing, this fails with a descriptive 'error' rather than the opaque
-- empty-list failure of 'head'.
responseRequestId :: Response a -> String
responseRequestId rsp =
  case lookup requestIdHeader (responseHeaders rsp) of
    Just value -> Char8.unpack value
    Nothing ->
      error "AWSLambda.Handler.responseRequestId: invocation response has no Lambda-Runtime-Aws-Request-Id header"