{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE NamedFieldPuns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeFamilies #-}
module Auth.Biscuit.Servant
(
RequireBiscuit
, CheckedBiscuit (..)
, authHandler
, genBiscuitCtx
, checkBiscuit
, checkBiscuitM
, WithVerifier (..)
, handleBiscuit
, withVerifier
, withVerifier_
, withVerifierM
, withVerifierM_
, noVerifier
, noVerifier_
, withFallbackVerifier
, withPriorityVerifier
, withFallbackVerifierM
, withPriorityVerifierM
) where
import Auth.Biscuit (Biscuit, PublicKey, Verifier,
checkBiscuitSignature,
parseB64, verifyBiscuit)
import Control.Applicative (liftA2)
import Control.Monad.Except (MonadError, throwError)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Reader (ReaderT, lift, runReaderT)
import Data.Bifunctor (first)
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as C8
import qualified Data.ByteString.Lazy as LBS
import Network.Wai
import Servant (AuthProtect)
import Servant.Server
import Servant.Server.Experimental.Auth
-- | Servant combinator marking an endpoint as protected by a biscuit token,
-- using servant's generalized authentication ('AuthProtect').
type RequireBiscuit = AuthProtect "biscuit"
-- | Endpoints protected by 'RequireBiscuit' receive a 'CheckedBiscuit'
-- (the value produced by the auth handler in the server context).
type instance AuthServerData RequireBiscuit = CheckedBiscuit
-- | A biscuit paired with the public key it was checked against.
-- 'authHandler' only builds one after 'checkBiscuitSignature' succeeded.
-- NOTE(review): the constructor is exported via @CheckedBiscuit (..)@, so
-- that "signature already checked" invariant holds by convention only.
data CheckedBiscuit = CheckedBiscuit PublicKey Biscuit
-- | An endpoint implementation bundled with the verifier its biscuit must
-- satisfy. 'handleBiscuit' runs the verifier first and only runs the
-- handler when verification succeeds.
data WithVerifier m a
  = WithVerifier
  { handler_  :: ReaderT Biscuit m a
    -- ^ Endpoint logic; the request's biscuit is the reader environment.
  , verifier_ :: m Verifier
    -- ^ Action producing the 'Verifier' the biscuit is checked against.
  }
-- | Extend a handler's verifier with a pure verifier placed /after/ the
-- existing one: the resulting verifier is @existing <> newV@.
withFallbackVerifier :: Functor m
                     => Verifier
                     -> WithVerifier m a
                     -> WithVerifier m a
withFallbackVerifier newV h =
  h { verifier_ = fmap (\existing -> existing <> newV) (verifier_ h) }
-- | Effectful variant of 'withFallbackVerifier'. The handler's existing
-- verifier action runs first, then @newV@; the results are combined as
-- @existing <> new@.
withFallbackVerifierM :: Applicative m
                      => m Verifier
                      -> WithVerifier m a
                      -> WithVerifier m a
withFallbackVerifierM newV h =
  let current = verifier_ h
   in h { verifier_ = liftA2 (<>) current newV }
-- | Extend a handler's verifier with a pure verifier placed /before/ the
-- existing one: the resulting verifier is @newV <> existing@.
withPriorityVerifier :: Functor m
                     => Verifier
                     -> WithVerifier m a
                     -> WithVerifier m a
withPriorityVerifier newV h =
  h { verifier_ = fmap (\existing -> newV <> existing) (verifier_ h) }
-- | Effectful variant of 'withPriorityVerifier'. The @newV@ action runs
-- first, then the handler's existing verifier action; the results are
-- combined as @new <> existing@.
withPriorityVerifierM :: Applicative m
                      => m Verifier
                      -> WithVerifier m a
                      -> WithVerifier m a
withPriorityVerifierM newV h =
  let current = verifier_ h
   in h { verifier_ = liftA2 (<>) newV current }
-- | Pair a biscuit-reading handler with a fixed, pure verifier.
withVerifier :: Applicative m => Verifier -> ReaderT Biscuit m a -> WithVerifier m a
withVerifier v h = WithVerifier h (pure v)
-- | Pair a biscuit-reading handler with an effectful verifier. The verifier
-- action is executed each time the result is passed to 'handleBiscuit'.
withVerifierM :: m Verifier -> ReaderT Biscuit m a -> WithVerifier m a
withVerifierM mv h =
  WithVerifier { verifier_ = mv, handler_ = h }
-- | Like 'withVerifier', for a handler that does not need to read the
-- biscuit: the plain action is lifted into the reader layer.
withVerifier_ :: Monad m => Verifier -> m a -> WithVerifier m a
withVerifier_ v action = withVerifier v (lift action)
-- | Like 'withVerifierM', for a handler that does not need to read the
-- biscuit: the plain action is lifted into the reader layer.
withVerifierM_ :: Monad m => m Verifier -> m a -> WithVerifier m a
withVerifierM_ mv action = withVerifierM mv (lift action)
-- | Pair a biscuit-reading handler with the empty verifier ('mempty').
-- NOTE(review): whether an empty 'Verifier' lets a token through depends
-- on 'verifyBiscuit' semantics (checks carried by the token still apply);
-- confirm before relying on this as "no authorization".
noVerifier :: Applicative m => ReaderT Biscuit m a -> WithVerifier m a
noVerifier h = WithVerifier h (pure mempty)
-- | Like 'noVerifier', for a handler that does not need to read the biscuit.
noVerifier_ :: Monad m => m a -> WithVerifier m a
noVerifier_ action = noVerifier (lift action)
-- | Extract and parse the biscuit carried in an
-- @Authorization: Bearer <base64 token>@ request header.
--
-- Fails with one of:
--
--   * @"Missing Authorization header"@ — no such header;
--   * @"Not a Bearer token"@ — the header's scheme is not @Bearer@ or no
--     token follows it;
--   * @"Not a B64-encoded biscuit"@ — 'parseB64' rejected the token.
extractBiscuit :: Request -> Either String Biscuit
extractBiscuit req = do
  let note :: a -> Maybe b -> Either a b
      note e = maybe (Left e) Right
  authHeader <- note "Missing Authorization header" . lookup "Authorization" $ requestHeaders req
  b64Token <- note "Not a Bearer token" $ stripBearer authHeader
  first (const "Not a B64-encoded biscuit") $ parseB64 b64Token
  where
    -- The auth-scheme is case-insensitive (RFC 7235 §2.1), and one or more
    -- spaces may separate it from the token (RFC 6750 §2.1: "Bearer" 1*SP).
    stripBearer :: BS.ByteString -> Maybe BS.ByteString
    stripBearer header
      | BS.map toLowerAscii scheme == "bearer" && not (BS.null rest) =
          Just (C8.dropWhile (== ' ') rest)
      | otherwise = Nothing
      where
        (scheme, rest) = C8.break (== ' ') header
    -- ASCII-only lowercasing of a byte ('A'..'Z' = 65..90); other bytes are
    -- left untouched, which is all the scheme comparison needs.
    toLowerAscii w
      | w >= 65 && w <= 90 = w + 32
      | otherwise = w
-- | Servant auth handler for 'RequireBiscuit'. It extracts the biscuit from
-- the request and checks its signature against @publicKey@. On success it
-- yields a 'CheckedBiscuit'.
--
-- Any failure is reported as a 401 whose body is the error message.
authHandler :: PublicKey -> AuthHandler Request CheckedBiscuit
authHandler publicKey = mkAuthHandler authenticate
  where
    -- 401 response carrying the message as its body.
    unauthorized :: String -> ServerError
    unauthorized msg = err401 { errBody = LBS.fromStrict (C8.pack msg) }

    -- Turn an extraction error into a 401; pass successes through.
    liftResult :: Either String a -> Handler a
    liftResult (Left msg) = throwError (unauthorized msg)
    liftResult (Right x)  = pure x

    authenticate :: Request -> Handler CheckedBiscuit
    authenticate req = do
      biscuit <- liftResult (extractBiscuit req)
      signatureOk <- liftIO (checkBiscuitSignature biscuit publicKey)
      if signatureOk
        then pure (CheckedBiscuit publicKey biscuit)
        else throwError (unauthorized "Invalid signature")
-- | Servant 'Context' holding only the biscuit 'authHandler' for the given
-- public key; pass it to @serveWithContext@.
genBiscuitCtx :: PublicKey -> Context '[AuthHandler Request CheckedBiscuit]
genBiscuitCtx = (:. EmptyContext) . authHandler
-- | Verify a checked biscuit against a pure 'Verifier' before running @h@.
--
-- The biscuit is verified using 'verifyBiscuit' with its paired public key.
-- On failure, the verification error is printed to stdout, and a 401 with
-- body @"Biscuit failed checks"@ is thrown through 'MonadError'. On success,
-- @h@ runs.
--
-- This delegates to 'checkBiscuitM' so the verification/error logic lives
-- in one place; wrapping the verifier in 'pure' adds no effects.
checkBiscuit :: (MonadIO m, MonadError ServerError m)
             => CheckedBiscuit
             -> Verifier
             -> m a
             -> m a
checkBiscuit cb v = checkBiscuitM cb (pure v)
-- | Verify a checked biscuit against an effectful verifier before running
-- @h@.
--
-- The verifier action runs first, then 'verifyBiscuit' is called with the
-- biscuit's paired public key. On failure, the verification error is
-- printed to stdout, and a 401 with body @"Biscuit failed checks"@ is
-- thrown. On success, @h@ runs and the verifier's query result is discarded.
checkBiscuitM :: (MonadIO m, MonadError ServerError m)
              => CheckedBiscuit
              -> m Verifier
              -> m a
              -> m a
checkBiscuitM (CheckedBiscuit pk b) mv h = do
  verifier <- mv
  outcome <- liftIO (verifyBiscuit b verifier pk)
  either rejected (const h) outcome
  where
    -- NOTE(review): 'print' writes the verification error to stdout;
    -- consider routing it through the application's logger instead.
    rejected err = do
      liftIO (print err)
      throwError (err401 { errBody = "Biscuit failed checks" })
-- | Run a 'WithVerifier' against a checked biscuit. The bundled verifier is
-- enforced via 'checkBiscuitM'. Only if it passes does the handler run,
-- with the biscuit supplied as its reader environment.
handleBiscuit :: (MonadIO m, MonadError ServerError m)
              => CheckedBiscuit
              -> WithVerifier m a
              -> m a
handleBiscuit cb@(CheckedBiscuit _ biscuit) wv =
  checkBiscuitM cb (verifier_ wv) (runReaderT (handler_ wv) biscuit)