{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE ViewPatterns #-}
module Aws.Lambda.Wai (waiHandler, waiHandler', WaiHandler) where
import Aws.Lambda
import Control.Concurrent.MVar
import Data.Aeson
import qualified Data.Aeson as Aeson
import qualified Data.Aeson.Types as Aeson
import qualified Data.Binary.Builder as Binary
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Lazy as BL
import qualified Data.CaseInsensitive as CI
import qualified Data.HashMap.Strict as HMap
import Data.IORef
import qualified Data.IP as IP
import Data.Text (Text)
import qualified Data.Text as T
import Data.Text.Encoding (decodeUtf8', encodeUtf8)
import qualified Data.Text.Encoding as T
import qualified Data.Vault.Lazy as Vault
import GHC.IO.Unsafe (unsafePerformIO)
import qualified Network.HTTP.Types as H
import qualified Network.Socket as Socket
import Network.Wai (Application)
import qualified Network.Wai as Wai
import qualified Network.Wai.Internal as Wai
import Text.Read (readMaybe)
-- | Shape of a Lambda handler serving API Gateway events: it receives the
-- gateway request (body as 'Text') and the Lambda 'Context', and yields an
-- 'Either' of two gateway responses.
type WaiHandler context =
  ApiGatewayRequest Text ->
  Context context ->
  IO (Either (ApiGatewayResponse Text) (ApiGatewayResponse Text))
-- | Build a 'WaiHandler' from an action that produces the WAI application.
--
-- The @initApp@ action is executed on every invocation of the returned
-- handler, before the request is dispatched to the resulting application.
waiHandler :: forall context. IO Wai.Application -> WaiHandler context
waiHandler initApp gatewayRequest context = do
  app <- initApp
  waiHandler'' app gatewayRequest context
-- | Build a 'WaiHandler' whose WAI application is derived from the Lambda
-- custom context.
--
-- The custom context is read from the 'IORef' in 'customContext' on each
-- invocation, and @getApp@ turns it into the application to run.
waiHandler' :: forall context. (context -> Wai.Application) -> WaiHandler context
waiHandler' getApp request context =
  readIORef (customContext context) >>= \customCtx ->
    waiHandler'' (getApp customCtx) request context
-- | Run a WAI application against one API Gateway request and convert the
-- application's response into an 'ApiGatewayResponse'.
--
-- The Lambda 'Context' argument is ignored here. The response is always
-- returned as 'Right'. The response body must be valid UTF-8, because it is
-- carried as 'Text'; otherwise this throws via 'error'.
waiHandler'' :: forall context. Wai.Application -> WaiHandler context
waiHandler'' waiApplication gatewayRequest _ = do
  waiRequest <- mkWaiRequest gatewayRequest
  (status, headers, body) <- processRequest waiApplication waiRequest >>= readResponse
  -- An empty body decodes to Right "", so it needs no special case.
  case decodeUtf8' body of
    Right responseBodyText ->
      pure . Right . wrapInResponse (H.statusCode status) headers $ responseBodyText
    Left _ -> error "Expected a response body that is valid UTF-8."
-- | Translate an API Gateway request into a WAI 'Wai.Request'.
--
-- * The method, path and headers are UTF-8 encoded from the gateway's 'Text'.
-- * The query string is rebuilt from the single-valued query parameter map.
-- * The request is always marked secure ('True').
-- * The remote host is parsed from the identity's source IP; see 'parseIp'.
-- * The body is served once from an 'MVar' and then reads as empty.
--
-- NOTE(review): the body is taken as the raw UTF-8 bytes of the gateway's body
-- text; no base64 decoding is done here — confirm this is intended for binary
-- payloads.
mkWaiRequest :: ApiGatewayRequest Text -> IO Wai.Request
mkWaiRequest ApiGatewayRequest{..} = do
  let ApiGatewayRequestContext{..} = apiGatewayRequestRequestContext
      ApiGatewayRequestContextIdentity{..} = apiGatewayRequestContextIdentity
  ip <- parseIp apiGatewayRequestContextIdentitySourceIp
  let requestBodyRaw = maybe mempty T.encodeUtf8 apiGatewayRequestBody
  requestBodyMVar <- newMVar requestBodyRaw
  let requestHeaders = map toHeader (HMap.toList apiGatewayRequestHeaders)
      -- Header names are case-insensitive (H.HeaderName is a CI ByteString),
      -- so this finds e.g. "Host", "host" or "HOST" alike. The previous
      -- exact-match lookups on the Text map missed differently-cased names.
      lookupHeader name = lookup name requestHeaders
      queryParameters = toQueryStringParameters apiGatewayRequestQueryStringParameters
  pure $
    Wai.Request
      (encodeUtf8 apiGatewayRequestHttpMethod)
      (getHttpVersion apiGatewayRequestContextProtocol)
      (encodeUtf8 apiGatewayRequestPath)
      (H.renderQuery True queryParameters)
      requestHeaders
      True
      ip
      (H.decodePathSegments (encodeUtf8 apiGatewayRequestPath))
      queryParameters
      (takeRequestBodyChunk requestBodyMVar)
      Vault.empty
      (Wai.KnownLength (fromIntegral (BS.length requestBodyRaw)))
      (lookupHeader H.hHost)
      (lookupHeader H.hRange)
      (lookupHeader H.hReferer)
      (lookupHeader H.hUserAgent)
-- | Map a protocol string (as reported in the request context) to an HTTP
-- version by its numeric suffix. Falls back to HTTP/1.1 when no known
-- suffix matches.
getHttpVersion :: Text -> H.HttpVersion
getHttpVersion protocol =
  case filter ((`T.isSuffixOf` protocol) . fst) suffixTable of
    ((_, version) : _) -> version
    [] -> H.http11
  where
    -- Checked in order; the first matching suffix wins.
    suffixTable =
      [ ("0.9", H.http09)
      , ("1.0", H.http10)
      , ("1.1", H.http11)
      , ("2.0", H.http20)
      ]
-- | Read the next request-body chunk.
--
-- The whole body sits in the 'MVar' as a single chunk. The first call empties
-- the 'MVar' and returns it; every later call returns 'BS.empty', which WAI
-- treats as end of body.
takeRequestBodyChunk :: MVar ByteString -> IO ByteString
takeRequestBodyChunk = fmap (maybe BS.empty id) . tryTakeMVar
-- | Convert the gateway's optional query-parameter map into WAI query items.
-- Every key gets a present ('Just') value; a missing map yields no items.
toQueryStringParameters :: Maybe (HMap.HashMap Text Text) -> [H.QueryItem]
toQueryStringParameters = maybe [] (map encodePair . HMap.toList)
  where
    encodePair (paramName, paramValue) =
      (encodeUtf8 paramName, Just (encodeUtf8 paramValue))
-- | Turn the gateway-reported source IP into a socket address with port 0
-- (and, for IPv6, flow info and scope id 0).
--
-- Throws via 'error' when the IP is absent or cannot be parsed as an IPv4
-- or IPv6 address.
-- NOTE(review): any non-IP sourceIp value makes the whole request fail —
-- confirm whether a fallback address would be preferable.
parseIp :: Maybe Text -> IO Socket.SockAddr
parseIp Nothing = error "Missing source ip."
parseIp (Just sourceIp) =
  case readMaybe (T.unpack sourceIp) of
    Nothing -> error "Could not parse source ip."
    Just parsed -> pure (toSockAddr parsed)
  where
    toSockAddr (IP.IPv4 v4) = Socket.SockAddrInet 0 (IP.toHostAddress v4)
    toSockAddr (IP.IPv6 v6) = Socket.SockAddrInet6 0 0 (IP.toHostAddress6 v6) 0
-- | Run a WAI application on a request and capture the response it hands to
-- its respond callback.
--
-- The callback stores the response in an 'MVar', which is read once the
-- application returns. If the application invokes the callback more than
-- once, the second 'putMVar' blocks.
processRequest :: Application -> Wai.Request -> IO Wai.Response
processRequest app req = do
  responseVar <- newEmptyMVar
  let respond response = Wai.ResponseReceived <$ putMVar responseVar response
  Wai.ResponseReceived <- app req respond
  takeMVar responseVar
-- | Fully materialise a WAI response as its status, headers and strict body.
--
-- The response is converted to its streaming form. Every chunk passed to the
-- send callback is appended to a builder, and the flush callback is a no-op.
readResponse :: Wai.Response -> IO (H.Status, H.ResponseHeaders, ByteString)
readResponse (Wai.responseToStream -> (st, hdrs, mkBody)) = do
  body <- mkBody drainBody
  pure (st, hdrs, body)
  where
    drainBody :: Wai.StreamingBody -> IO ByteString
    drainBody streamingBody = do
      acc <- newIORef Binary.empty
      streamingBody
        -- Append (not prepend): chunks must keep the order they were sent in.
        -- The previous `chunk <> soFar` reversed multi-chunk bodies
        -- (e.g. streamed responses or files larger than one read chunk).
        (\chunk -> atomicModifyIORef' acc (\soFar -> (soFar <> chunk, ())))
        (pure ())
      BL.toStrict . Binary.toLazyByteString <$> readIORef acc
-- | Build an 'ApiGatewayResponse' from a status code, headers and body.
-- The base64 flag is always 'False'.
wrapInResponse
  :: Int
  -> H.ResponseHeaders
  -> res
  -> ApiGatewayResponse res
wrapInResponse httpCode hdrs payload =
  let base64Encoded = False
   in ApiGatewayResponse httpCode hdrs payload base64Encoded
-- | Convert a 'Text' name/value pair into an HTTP header. The name is UTF-8
-- encoded and wrapped as case-insensitive; the value is UTF-8 encoded.
toHeader :: (Text, Text) -> H.Header
toHeader (headerName, headerValue) =
  let nameBytes = encodeUtf8 headerName
      valueBytes = encodeUtf8 headerValue
   in (CI.mk nameBytes, valueBytes)
-- | Render a value with its 'Show' instance as 'Text'.
tshow :: Show a => a -> Text
tshow value = T.pack (show value)