{-# LANGUAGE OverloadedStrings #-}

module Network.Wai.Predicate.Request
    ( Req
    , HasMethod  (..)
    , HasHeaders (..)
    , HasCookies (..)
    , HasQuery   (..)
    , HasPath    (..)
    , HasRequest (..)
    , HasVault   (..)

    , fromRequest
    , lookupHeader
    , lookupQuery
    , lookupCookie
    , lookupSegment
    ) where

import Data.ByteString (ByteString)
import Data.Maybe (mapMaybe)
import Data.Vector (Vector, (!?))
import Data.Vault.Lazy (Vault)
import Data.Word
import Network.HTTP.Types
import Network.Wai (Request)
import Web.Cookie
import Prelude

import qualified Data.ByteString as B
import qualified Network.Wai     as Wai
import qualified Data.Vector     as Vec

-- | Types from which the underlying WAI 'Wai.Request' can be obtained.
class HasRequest r where
    -- | Extract the WAI request.
    getRequest :: r -> Wai.Request

-- | Types that expose an HTTP request 'Method'.
class HasMethod r where
    -- | The request method.
    method :: r -> Method

-- | Types that expose a list of request headers.
class HasHeaders r where
    -- | All request headers.
    headers :: r -> RequestHeaders

-- | Types that expose a list of cookies.
class HasCookies r where
    -- | All cookies as name\/value pairs.
    cookies :: r -> Cookies

-- | Types that expose the items of a query string.
class HasQuery r where
    -- | All query items; the value of an item is optional.
    queryItems :: r -> Query

-- | Types that expose a request path as a vector of segments.
class HasPath r where
    -- | The path segments, in order.
    segments :: r -> Vector ByteString

-- | Types that expose a 'Vault'.
class HasVault r where
    -- | The vault associated with the request.
    requestVault :: r -> Vault

-- | A WAI 'Request' bundled with values derived from it: the cookies and
-- the path segments. Values are constructed with 'fromRequest'.
data Req = Req
    { _request  :: Request            -- ^ the wrapped WAI request
    , _cookies  :: Cookies            -- ^ cookies computed by 'fromRequest'
    , _segments :: Vector ByteString  -- ^ path segments computed by 'fromRequest'
    }

-- | A 'Req' hands back the request it wraps.
instance HasRequest Req where
    getRequest = _request

-- | The method of the wrapped request.
instance HasMethod Req where
    method = Wai.requestMethod . getRequest

-- | A plain WAI request exposes its method directly.
instance HasMethod Wai.Request where
    method = Wai.requestMethod

-- | The headers of the wrapped request.
instance HasHeaders Req where
    headers = Wai.requestHeaders . getRequest

-- | A plain WAI request exposes its headers directly.
instance HasHeaders Wai.Request where
    headers = Wai.requestHeaders

-- | The query items of the wrapped request.
instance HasQuery Req where
    queryItems = Wai.queryString . getRequest

-- | A plain WAI request exposes its query items directly.
instance HasQuery Wai.Request where
    queryItems = Wai.queryString

-- | The cookies stored in the 'Req' when it was built.
instance HasCookies Req where
    cookies = _cookies

-- | The path segments stored in the 'Req' when it was built.
instance HasPath Req where
    segments = _segments

-- | The vault of the wrapped request.
instance HasVault Req where
    requestVault = Wai.vault . _request

-- | Wrap a WAI 'Request' in a 'Req'.
--
-- The cookies are obtained by running 'parseCookies' over the value of
-- every @Cookie@ header and concatenating the results. The path segments
-- are obtained by splitting 'Wai.rawPathInfo' with 'splitSegments'.
fromRequest :: Request -> Req
fromRequest rq = Req rq parsedCookies pathSegments
  where
    parsedCookies = concatMap parseCookies (getHeaders "Cookie" rq)
    pathSegments  = Vec.fromList (splitSegments (Wai.rawPathInfo rq))

-- | All values of the headers whose name equals the given 'HeaderName',
-- in the order in which they appear in 'headers'. The result is empty
-- when no header matches.
lookupHeader :: HasHeaders r => HeaderName -> r -> [ByteString]
lookupHeader = getHeaders

-- | The path segment at the given index of 'segments', or 'Nothing' when
-- the vector has no element at that index.
lookupSegment :: HasPath r => Word -> r -> Maybe ByteString
lookupSegment i r = segments r !? fromIntegral i

-- | The values of all cookies whose name equals the given name, in the
-- order in which they appear in 'cookies'. The result is empty when no
-- cookie matches.
lookupCookie :: HasCookies r => ByteString -> r -> [ByteString]
lookupCookie name = map snd . filter ((name ==) . fst) . cookies

-- | The values of all query items whose key equals the given name, in the
-- order in which they appear in 'queryItems'. Matching items that carry
-- no value ('Nothing') are dropped by 'mapMaybe'.
lookupQuery :: HasQuery r => ByteString -> r -> [ByteString]
lookupQuery name = mapMaybe snd . filter ((name ==) . fst) . queryItems

-- | Values of every header whose name equals @name@, compared with the
-- 'Eq' instance of 'HeaderName'. Not exported; 'lookupHeader' is the
-- public entry point and 'fromRequest' uses it to collect cookie headers.
getHeaders :: HasHeaders r => HeaderName -> r -> [ByteString]
getHeaders name = map snd . filter ((name ==) . fst) . headers

-----------------------------------------------------------------------------
-- Internal

-- | Split a raw request path into its segments.
--
-- The empty path and the path @\"/\"@ both yield no segments. Otherwise a
-- single leading slash (if present) is dropped, the remainder is cut at
-- every slash byte, and each piece is passed through @'urlDecode' False@.
-- Empty pieces are kept, so a trailing slash produces a final empty
-- segment.
splitSegments :: ByteString -> [ByteString]
splitSegments a
    | B.null a  = []
    | "/" == a  = []
    -- 'B.head' and 'B.tail' are safe here: the first guard has already
    -- ruled out the empty string.
    | otherwise = if B.head a == slash then go (B.tail a) else go a
  where
    go :: ByteString -> [ByteString]
    go b =
        let (x, y) = B.break (== slash) b
        -- When @y@ is non-empty it starts with the slash that ended @x@,
        -- so 'B.tail' skips exactly that separator.
        in urlDecode False x : if B.null y then [] else go (B.tail y)

    -- ASCII code of @\'/\'@.
    slash :: Word8
    slash = 47