{-# LANGUAGE DataKinds         #-}
{-# LANGUAGE KindSignatures    #-}
{-# LANGUAGE OverloadedStrings #-}

module Network.Wai.Predicate.Accept
    ( accept
    , module Network.Wai.Predicate.MediaType
    ) where

import Control.Monad
import Data.ByteString (ByteString)
import Data.List (foldl')
import Data.Monoid hiding (All)
import Data.Maybe
import Data.Predicate
import Data.Singletons.TypeLits (Symbol)
import Network.Wai.Predicate.Error
import Network.Wai.Predicate.Request
import Network.Wai.Predicate.MediaType
import Network.Wai.Predicate.Utility

import qualified Network.Wai.Predicate.Parser.MediaType as M

-- | Predicate that checks whether the request's @Accept@ header admits the
-- media type @t/s@.
--
-- * If 'M.readMediaTypes' yields no media types for the @accept@ header,
--   the predicate succeeds via 'return' with a 'Media' of quality @1.0@ and
--   no parameters. (Presumably this covers both an absent and an unparsable
--   header; confirm against the parser.)
-- * Otherwise every media range matching @t/s@ (see 'findMediaType') is a
--   candidate. The one with the highest quality is chosen, and the earliest
--   one wins on ties. The result is @Okay (1.0 - quality) media@, so a
--   higher-quality match carries a smaller value.
-- * If nothing matches, the predicate fails with 'e406' and a message
--   naming the expected media type.
--
-- NOTE(review): a range with @q=0@ still matches here and yields
-- @Okay 1.0@. RFC 7231 treats @q=0@ as "not acceptable", so consider
-- rejecting such ranges.
accept :: HasHeaders r
       => ByteString  -- ^ expected media type, e.g. @\"application\"@
       -> ByteString  -- ^ expected media subtype, e.g. @\"json\"@
       -> Predicate r Error (Media (t :: Symbol) (s :: Symbol))
accept t s r =
    let mtypes = M.readMediaTypes "accept" r
    in if null mtypes
        then return (Media t s 1.0 [])
        else case findMediaType t s mtypes of
            m:rest -> let best = foldl' preferHigher m rest
                      in Okay (1.0 - mediaQuality best) best
            []     -> Fail (e406 & setMessage msg)
  where
    msg = "Expected 'Accept: " <> t <> "/" <> s <> "'."

    -- Choose the best candidate explicitly instead of trusting the order
    -- of the parsed header. Keep the earlier candidate unless the later one
    -- has strictly higher quality, so ties resolve to header order.
    preferHigher a b
        | mediaQuality b > mediaQuality a = b
        | otherwise                       = a

-- | Every media range in the list that admits the media type @t/s@.
--
-- Each match becomes a 'Media' that carries the expected @t@ and @s@
-- together with the range's quality ('M.medQuality') and parameters
-- ('M.medParams').
--
-- A range matches when both of these hold:
--
-- * its type is @*@ or equals @t@;
-- * its subtype is @*@ or equals @s@.
--
-- Comparison is exact byte equality. Input order is preserved.
findMediaType :: ByteString -> ByteString -> [M.MediaType] -> [Media t s]
findMediaType t s = mapMaybe toMedia
  where
    toMedia m = do
        let mt = M.medType m
            ms = M.medSubtype m
        guard (matches t mt && matches s ms)
        return $ Media t s (M.medQuality m) (M.medParams m)

    -- A header component matches if it is the wildcard or the exact value.
    matches :: ByteString -> ByteString -> Bool
    matches expected actual = actual == "*" || expected == actual