module Servant.Server.Experimental.Auth.Cookie
( CipherAlgorithm
, AuthCookieData
, Cookie (..)
, AuthCookieException (..)
, WithMetadata (..)
#if MIN_VERSION_servant(0,9,1)
, Cookied
, cookied
#endif
, RandomSource
, mkRandomSource
, getRandomBytes
, generateRandomBytes
, ServerKey
, ServerKeySet (..)
, PersistentServerKey
, mkPersistentServerKey
, RenewableKeySet
, RenewableKeySetHooks (..)
, mkRenewableKeySet
, AuthCookieSettings (..)
, EncryptedSession (..)
, emptyEncryptedSession
, encryptCookie
, decryptCookie
, encryptSession
, decryptSession
, addSession
, removeSession
, addSessionToErr
, removeSessionFromErr
, getSession
, defaultAuthHandler
, renderSession
, parseSessionRequest
, parseSessionResponse
) where
import Blaze.ByteString.Builder (toByteString)
import Control.Arrow ((&&&), first)
import Control.Monad
import Control.Monad.Catch (MonadThrow (..), Exception)
import Control.Monad.Except
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types
import Crypto.Error
import Crypto.Hash (HashAlgorithm(..))
import Crypto.Hash.Algorithms (SHA256)
import Crypto.MAC.HMAC (HMAC)
import Crypto.Random (DRG(..), drgNew)
import Data.ByteString (ByteString)
import Data.Default
import Data.IORef
import Data.List (partition)
import Data.Maybe (listToMaybe)
import Data.Monoid ((<>))
import Data.Proxy
import Data.Serialize
import Data.Time
import Data.Tagged (Tagged (..), retag)
import Data.Typeable
import GHC.TypeLits (Symbol)
import Network.HTTP.Types (hCookie, HeaderName, RequestHeaders, ResponseHeaders)
import Network.Wai (Request, requestHeaders)
import Servant (addHeader, ServantErr (..))
import Servant.API.Experimental.Auth (AuthProtect)
import Servant.API.ResponseHeaders (AddHeader)
import Servant.Server (err403)
import Servant.Server.Experimental.Auth
import Web.Cookie
import qualified Crypto.MAC.HMAC as H
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Char8 as BSC8
import qualified Network.HTTP.Types as N(Header)
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
#if MIN_VERSION_servant(0,9,0)
import Servant (ToHttpApiData (..))
#else
import Data.ByteString.Conversion (ToByteString (..))
#endif
#if MIN_VERSION_servant(0,9,1)
import Servant (noHeader, Handler)
import Servant.API.ResponseHeaders (Headers)
import qualified Servant.API.Header as S(Header)
#endif
#if MIN_VERSION_http_types(0,9,2)
import Network.HTTP.Types (hSetCookie)
#endif
#if MIN_VERSION_http_types(0,9,2)
#else
-- | Name of the @Set-Cookie@ response header.  This local definition is only
-- compiled when @http-types@ is older than 0.9.2 (see the surrounding CPP
-- condition); otherwise the name is imported from "Network.HTTP.Types".
hSetCookie :: HeaderName
hSetCookie = "Set-Cookie"
#endif
-- | A cipher operation: given an initialised cipher @c@, an initialisation
-- vector for that cipher and an input, produce the transformed output.  The
-- same shape is used for both the encrypting and the decrypting field of
-- 'AuthCookieSettings'.
type CipherAlgorithm c = c -> IV c -> ByteString -> ByteString
-- | Parameterless open type family naming the session payload type that the
-- @"cookie-auth"@ protection tag carries (wrapped in 'WithMetadata').  No
-- instance is visible in this part of the file; presumably the application
-- supplies one.
type family AuthCookieData
-- | A value decorated with information gathered while it was decrypted.
data WithMetadata a = WithMetadata
  { wmData :: a
    -- ^ The wrapped value itself.
  , wmRenew :: Bool
    -- ^ Set to 'True' by 'decryptCookie' when the MAC matched one of the
    -- rotated keys instead of the current key; 'cookied' uses it to decide
    -- whether to issue a fresh session header.
  }
-- | Handlers protected with @AuthProtect "cookie-auth"@ receive the session
-- payload ('AuthCookieData') wrapped in 'WithMetadata'.
type instance AuthServerData (AuthProtect "cookie-auth") = WithMetadata AuthCookieData
-- | The plain (not yet encrypted, or already decrypted) parts of a cookie.
data Cookie = Cookie
  { cookieIV :: ByteString
    -- ^ Initialisation vector handed to the cipher; also mixed into the
    -- derived encryption key (see 'encryptCookie').
  , cookieExpirationTime :: UTCTime
    -- ^ Moment after which 'decryptCookie' rejects the cookie.
  , cookiePayload :: ByteString
    -- ^ Payload bytes (plain text on this side of the cipher).
  } deriving (Eq, Show)
-- | Bytes of an already rendered session cookie, wrapped so that they can be
-- used as a typed header value.
newtype EncryptedSession = EncryptedSession ByteString
  deriving (Eq, Show, Typeable)
-- | An 'EncryptedSession' wrapping the empty byte string.
emptyEncryptedSession :: EncryptedSession
emptyEncryptedSession = EncryptedSession ""
#if MIN_VERSION_servant(0,9,0)
-- | Header rendering emits the wrapped bytes unchanged.  Rendering as a URL
-- piece is deliberately unsupported and calls 'error' when evaluated.
instance ToHttpApiData EncryptedSession where
  toHeader (EncryptedSession s) = s
  toUrlPiece = error "toUrlPiece @EncryptedSession: not implemented"
#else
-- | Used with servant versions older than 0.9.0 (see the surrounding CPP
-- condition): builds the wrapped bytes via the builder of the underlying
-- byte string.
instance ToByteString EncryptedSession where
  builder (EncryptedSession s) = builder s
#endif
#if MIN_VERSION_servant(0,9,1)
-- | A response of type @a@ that may carry a @Set-Cookie@ header holding an
-- 'EncryptedSession'.
type Cookied a = Headers '[S.Header "Set-Cookie" EncryptedSession] a
#endif
-- | Failures thrown (via 'throwM') by the functions of this module.
data AuthCookieException
  = CannotMakeIV ByteString
    -- ^ The given bytes were rejected by @makeIV@ for the chosen cipher.
  | BadProperKey CryptoError
    -- ^ @cipherInit@ failed on the derived key; carries the reported error.
  | TooShortProperKey Int Int
    -- ^ Derived key material is too short: minimal required length, then
    -- the length that was actually available.
  | IncorrectMAC ByteString
    -- ^ The MAC taken from the cookie matched neither the current nor any
    -- rotated server key; carries that MAC.
  | CannotParseExpirationTime ByteString
    -- ^ The expiration stamp could not be parsed; carries its raw bytes.
  | CookieExpired UTCTime UTCTime
    -- ^ The cookie is out of date: expiration time, then the current time.
  | SessionDeserializationFailed String
    -- ^ Base64 decoding or deserialisation of the session failed; carries
    -- the message reported by the decoder.
  deriving (Eq, Show, Typeable)

instance Exception AuthCookieException
-- | Phantom tag (no constructors): raw bytes of an encrypted cookie.
data EncryptedCookie

-- | Phantom tag (no constructors): Base64 text of an encrypted cookie.
data SerializedEncryptedCookie
-- | Base64-encode the bytes of an encrypted cookie, switching the phantom
-- tag from 'EncryptedCookie' to 'SerializedEncryptedCookie'.
base64Encode :: Tagged EncryptedCookie ByteString -> Tagged SerializedEncryptedCookie ByteString
base64Encode (Tagged raw) = Tagged (Base64.encode raw)
-- | Inverse of 'base64Encode'.  Yields 'Left' with the decoder message when
-- the input is not valid Base64.
base64Decode
  :: Tagged SerializedEncryptedCookie ByteString
  -> Either String (Tagged EncryptedCookie ByteString)
base64Decode encoded =
  case Base64.decode (unTagged encoded) of
    Left problem -> Left problem
    Right bytes -> Right (Tagged bytes)
-- | A source of random bytes that swaps in a fresh generator after a given
-- amount of output.  The constructor stores, in order: the action that
-- creates a new generator, the byte threshold, and a mutable cell holding
-- the generator in use together with the number of bytes drawn from it.
-- The generator type is hidden existentially.
data RandomSource where
  RandomSource :: DRG d => IO d -> Int -> IORef (d, Int) -> RandomSource
-- | Build a 'RandomSource'.  The generator-producing action is run once
-- right away to obtain the initial generator; the byte counter starts at 0.
mkRandomSource :: (MonadIO m, DRG d)
  => IO d -- ^ Action producing a fresh generator
  -> Int -- ^ Number of drawn bytes after which the generator is replaced
  -> m RandomSource
mkRandomSource mkDRG threshold = liftIO $ do
  initial <- mkDRG
  cell <- newIORef (initial, 0)
  return (RandomSource mkDRG threshold cell)
-- | Draw @n@ random bytes from the source.  The bytes always come from the
-- generator currently stored in the source; once the running byte count
-- reaches the threshold, the stored generator is replaced by a fresh one and
-- the counter is reset.
--
-- Note that a fresh generator is created on every call, before the atomic
-- update, and is simply dropped when the threshold is not reached.
getRandomBytes :: MonadIO m
  => RandomSource -- ^ Source to draw from
  -> Int -- ^ Number of bytes wanted
  -> m ByteString
getRandomBytes (RandomSource mkDRG threshold ref) n = liftIO $ do
  replacement <- mkDRG
  atomicModifyIORef' ref $ \(gen, used) ->
    let (chunk, gen') = randomBytesGenerate n gen
        used' = used + n
        nextState
          | used' >= threshold = (replacement, 0)
          | otherwise = (gen', used')
    in (nextState, chunk)
-- | Secret server key material, used as the HMAC key when signing cookies
-- and deriving encryption keys.
type ServerKey = ByteString
-- | Interface to a collection of server keys.
class ServerKeySet k where
  -- | Retrieve the keys as a pair; 'decryptCookie' treats the first
  -- component as the current key and the second as the rotated keys.
  getKeys :: (MonadThrow m, MonadIO m) => k -> m (ServerKey, [ServerKey])
  -- | Ask the key set to drop the given key.
  removeKey :: (MonadThrow m, MonadIO m) => k -> ServerKey -> m ()
-- | A key set consisting of exactly one fixed key.
data PersistentServerKey = PersistentServerKey
  { pskBytes :: ServerKey -- ^ The key itself.
  }
-- | The single stored key is always the current one and there are never any
-- rotated keys.  Key removal is not supported: 'removeKey' calls 'error'.
instance ServerKeySet PersistentServerKey where
  getKeys keySet = return (pskBytes keySet, [])
  removeKey _ = error "removeKey @PersistentServerKey: not implemented"
-- | Wrap the given bytes into a 'PersistentServerKey'.
mkPersistentServerKey :: ByteString -> PersistentServerKey
mkPersistentServerKey = PersistentServerKey
-- | Customisation points of a 'RenewableKeySet'.  @s@ is the user state kept
-- next to the keys, @p@ the type of the fixed parameters passed to each hook.
data RenewableKeySetHooks s p = RenewableKeySetHooks
  { rkshNewState :: forall m. (MonadIO m, MonadThrow m)
      => p
      -> ([ServerKey], s)
      -> m ([ServerKey], s)
    -- ^ Produce the next keys and user state from the parameters and the
    -- present keys and state.  Called by 'getKeys' when 'rkshNeedUpdate'
    -- answered 'True'.
  , rkshNeedUpdate :: forall m. (MonadIO m, MonadThrow m)
      => p
      -> ([ServerKey], s)
      -> m Bool
    -- ^ Decide whether the present keys and state have to be refreshed.
  , rkshRemoveKey :: forall m. (MonadIO m, MonadThrow m)
      => p
      -> ServerKey
      -> m ()
    -- ^ Called by 'removeKey' after the key was found in, and dropped from,
    -- the stored key list.
  }
-- | A key set whose keys are refreshed through user supplied hooks.
data RenewableKeySet s p = RenewableKeySet
  { rksState :: IORef ([ServerKey], s)
    -- ^ Mutable cell with the key list and the user state.
  , rksParameters :: p
    -- ^ Fixed parameters passed to every hook.
  , rksHooks :: RenewableKeySetHooks s p
    -- ^ The hooks driving refresh and removal.
  }
-- | 'getKeys' reads the stored state and asks 'rkshNeedUpdate' whether it is
-- stale.  If so, 'rkshNewState' computes a replacement which is stored
-- atomically, unless the user state in the cell changed in the meantime
-- (compared with '/='), in which case the value found in the cell wins.
-- The first key of the resulting list is reported as the current key, the
-- rest as rotated keys.
--
-- NOTE(review): the split uses 'head' and 'tail', so evaluating the result
-- fails when the key list is empty (which is how 'mkRenewableKeySet'
-- initialises it, until a refresh happens).
--
-- 'removeKey' atomically drops every occurrence of the key from the stored
-- list and calls 'rkshRemoveKey' only if at least one occurrence was there.
instance (Eq s) => ServerKeySet (RenewableKeySet s p) where
  getKeys RenewableKeySet {..} = do
    snapshot <- liftIO (readIORef rksState)
    stale <- rkshNeedUpdate rksHooks rksParameters snapshot
    if stale
      then do
        candidate <- rkshNewState rksHooks rksParameters snapshot
        liftIO . atomicModifyIORef' rksState $ \current ->
          let chosen
                | snd snapshot /= snd current = current
                | otherwise = candidate
          in (chosen, splitKeys chosen)
      else return (splitKeys snapshot)
    where
      -- Current key first, rotated keys second.
      splitKeys st = (head ks, tail ks)
        where ks = fst st

  removeKey RenewableKeySet {..} key = do
    wasPresent <- liftIO . atomicModifyIORef' rksState $ \(ks, st) ->
      let (matching, remaining) = partition (== key) ks
      in ((remaining, st), not (null matching))
    when wasPresent (rkshRemoveKey rksHooks rksParameters key)
-- | Create a 'RenewableKeySet' from hooks, parameters and an initial user
-- state.  The key list starts out empty.
mkRenewableKeySet :: (MonadIO m)
  => RenewableKeySetHooks s p -- ^ Hooks
  -> p -- ^ Parameters passed to the hooks
  -> s -- ^ Initial user state
  -> m (RenewableKeySet s p)
mkRenewableKeySet hooks params initialState = liftIO $ do
  cell <- newIORef ([], initialState)
  return RenewableKeySet
    { rksState = cell
    , rksParameters = params
    , rksHooks = hooks
    }
-- | Configuration shared by all cookie operations.  The hash algorithm @h@
-- and the block cipher @c@ are hidden existentially.
data AuthCookieSettings where
  AuthCookieSettings :: (HashAlgorithm h, BlockCipher c) =>
    { acsSessionField :: ByteString
      -- ^ Name of the cookie that holds the session.
    , acsCookieFlags :: [ByteString]
      -- ^ Extra attributes rendered with an empty value after the other
      -- cookie attributes.
    , acsMaxAge :: NominalDiffTime
      -- ^ Session lifetime: added to the current time to obtain the
      -- expiration time and rendered (in whole seconds) as @Max-Age@.
    , acsExpirationFormat :: String
      -- ^ Time format used to write and parse the expiration stamp stored
      -- inside the cookie.  'decryptCookie' derives the width of the stamp
      -- by formatting the current time, so the format should produce
      -- output of constant length.
    , acsPath :: ByteString
      -- ^ Value of the @Path@ attribute.
    , acsHashAlgorithm :: Proxy h
      -- ^ Hash algorithm used for HMAC signing.
    , acsCipher :: Proxy c
      -- ^ Block cipher; determines IV size and key size.
    , acsEncryptAlgorithm :: CipherAlgorithm c
      -- ^ Operation applied to the payload when encrypting.
    , acsDecryptAlgorithm :: CipherAlgorithm c
      -- ^ Operation applied to the payload when decrypting.
    } -> AuthCookieSettings
-- | Defaults: cookie named @Session@ with the @HttpOnly@ and @Secure@ flags,
-- a lifetime of 12 hours, path @/@, HMAC over SHA-256, and AES-256 with
-- 'ctrCombine' for both encryption and decryption.
instance Default AuthCookieSettings where
  def = AuthCookieSettings
    { acsSessionField = "Session"
    , acsCookieFlags = ["HttpOnly", "Secure"]
    , acsMaxAge = fromIntegral (12 * 3600 :: Integer) -- 12 hours, in seconds
    , acsExpirationFormat = "%0Y%m%d%H%M%S"
    , acsPath = "/"
    , acsHashAlgorithm = Proxy :: Proxy SHA256
    , acsCipher = Proxy :: Proxy AES256
    , acsEncryptAlgorithm = ctrCombine
    , acsDecryptAlgorithm = ctrCombine }
-- | Encrypt and sign a 'Cookie' with the current server key.
--
-- The result is the concatenation of: the IV, the formatted expiration
-- time, the encrypted payload, and an HMAC over those three parts.  The
-- encryption key is derived from an HMAC of IV and expiration stamp under
-- the server key, cut to a size accepted by the cipher.
--
-- May throw 'TooShortProperKey', 'CannotMakeIV' or 'BadProperKey'.
encryptCookie :: (MonadIO m, MonadThrow m, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> k -- ^ Key set; only the current key is used
  -> Cookie -- ^ Cookie to encrypt
  -> m (Tagged EncryptedCookie ByteString)
encryptCookie AuthCookieSettings {..} sks cookie = do
  (serverKey, _) <- getKeys sks
  let iv = cookieIV cookie
      stamp = BSC8.pack
        (formatTime defaultTimeLocale acsExpirationFormat (cookieExpirationTime cookie))
      hmacWith = sign acsHashAlgorithm serverKey
  key <- mkProperKey (cipherKeySize (unProxy acsCipher)) (hmacWith (iv <> stamp))
  encrypted <- applyCipherAlgorithm acsEncryptAlgorithm iv key (cookiePayload cookie)
  let signedPart = BS.concat [iv, stamp, encrypted]
  return (Tagged (signedPart <> hmacWith signedPart))
-- | Verify and decrypt a cookie produced by 'encryptCookie'.
--
-- The input is split into IV, expiration stamp, encrypted payload and MAC.
-- The IV width is the block size of the cipher, the stamp width is the
-- length of the current time rendered with 'acsExpirationFormat', the MAC
-- width is the digest size of the hash, and the payload is whatever remains.
--
-- The MAC is checked against the current key first and then against the
-- rotated keys; when only a rotated key matches, 'wmRenew' is set in the
-- result.
--
-- Throws 'IncorrectMAC' when no key matches, 'CannotParseExpirationTime'
-- when the stamp is unreadable, 'CookieExpired' when the expiration time is
-- not in the future, and whatever 'mkProperKey' or 'applyCipherAlgorithm'
-- throw.
decryptCookie :: (MonadIO m, MonadThrow m, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> k -- ^ Key set
  -> Tagged EncryptedCookie ByteString -- ^ Encrypted cookie bytes
  -> m (WithMetadata Cookie)
decryptCookie AuthCookieSettings {..} sks (Tagged s) = do
  currentTime <- liftIO getCurrentTime
  let ivSize = blockSize (unProxy acsCipher)
      expSize =
        length (formatTime defaultTimeLocale acsExpirationFormat currentTime)
      -- The payload is what is left once IV, stamp and MAC are accounted
      -- for.  For truncated input this is negative; the splits below then
      -- yield a MAC of the wrong length and verification fails.
      payloadSize = BS.length s - ivSize - expSize -
        hashDigestSize (unProxy acsHashAlgorithm)
      butMacSize = ivSize + expSize + payloadSize
      (iv, s0) = BS.splitAt ivSize s
      (expirationRaw, s1) = BS.splitAt expSize s0
      (payloadRaw, mac) = BS.splitAt payloadSize s1
      -- The cookie comes from the client, so compare MACs in constant time.
      checkMac sk =
        BA.constEq mac (sign acsHashAlgorithm sk (BS.take butMacSize s))
  (currentKey, rotatedKeys) <- getKeys sks
  (serverKey, renew) <-
    if checkMac currentKey
      then return (currentKey, False)
      else case filter checkMac rotatedKeys of
        matched : _ -> return (matched, True)
        [] -> throwM (IncorrectMAC mac)
  expirationTime <-
    maybe (throwM $ CannotParseExpirationTime expirationRaw) return $
      parseTimeM False defaultTimeLocale acsExpirationFormat
        (BSC8.unpack expirationRaw)
  when (currentTime >= expirationTime) $
    throwM (CookieExpired expirationTime currentTime)
  key <- mkProperKey
    (cipherKeySize (unProxy acsCipher))
    (sign acsHashAlgorithm serverKey $ BS.take (ivSize + expSize) s)
  payload <- applyCipherAlgorithm acsDecryptAlgorithm iv key payloadRaw
  return WithMetadata
    { wmData = Cookie
        { cookieIV = iv
        , cookieExpirationTime = expirationTime
        , cookiePayload = payload }
    , wmRenew = renew
    }
-- | Serialise a session value, pad it, encrypt it and Base64-encode the
-- result.
--
-- A random IV of one cipher block is drawn, the expiration time is set to
-- the current time plus 'acsMaxAge', and the serialised session is padded
-- with random bytes up to the next multiple of the cipher block size (no
-- padding when it already is a multiple).
encryptSession :: (MonadIO m, MonadThrow m, Serialize a, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> RandomSource -- ^ Source of random bytes for IV and padding
  -> k -- ^ Key set
  -> a -- ^ Session value
  -> m (Tagged SerializedEncryptedCookie ByteString)
encryptSession acs@AuthCookieSettings {..} randomSource sk session = do
  let cipherBlock = blockSize (unProxy acsCipher)
  iv <- getRandomBytes randomSource cipherBlock
  expirationTime <- liftM (addUTCTime acsMaxAge) (liftIO getCurrentTime)
  let payload = runPut (put session)
      -- Bytes missing up to the next block boundary; the outer `rem` maps a
      -- full block of padding to zero.
      padLen =
        (cipherBlock - (BS.length payload `rem` cipherBlock)) `rem` cipherBlock
  padding <- getRandomBytes randomSource padLen
  base64Encode `liftM` encryptCookie acs sk (Cookie
    { cookieIV = iv
    , cookieExpirationTime = expirationTime
    , cookiePayload = BS.concat [payload, padding] })
-- | Inverse of 'encryptSession': Base64-decode, decrypt via 'decryptCookie'
-- and deserialise the payload.  The metadata produced by 'decryptCookie' is
-- kept in the result.
--
-- Throws 'SessionDeserializationFailed' when Base64 decoding or
-- deserialisation fails, plus anything 'decryptCookie' throws.
decryptSession :: (MonadIO m, MonadThrow m, Serialize a, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> k -- ^ Key set
  -> Tagged SerializedEncryptedCookie ByteString -- ^ Cookie in Base64 form
  -> m (WithMetadata a)
decryptSession acs@AuthCookieSettings {} sks s = do
  raw <- orFail (base64Decode s)
  decrypted <- decryptCookie acs sks raw
  session <- orFail (runGet get (cookiePayload (wmData decrypted)))
  return decrypted { wmData = session }
  where
    orFail = either (throwM . SessionDeserializationFailed) return
-- | Render the session with 'renderSession' and attach it to the response
-- as a header holding an 'EncryptedSession'.
addSession
  :: ( MonadIO m
     , MonadThrow m
     , Serialize a
     , AddHeader (e :: Symbol) EncryptedSession s r
     , ServerKeySet k )
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> RandomSource -- ^ Source of random bytes
  -> k -- ^ Key set
  -> a -- ^ Session value
  -> s -- ^ Response to decorate
  -> m r
addSession acs rs sk sessionData response =
  renderSession acs rs sk sessionData >>= \rendered ->
    return (addHeader (EncryptedSession rendered) response)
-- | Attach the cookie produced by 'expiredCookie' to the response, which
-- overwrites the session on the client with an empty value.
removeSession :: ( Monad m,
                   AddHeader (e :: Symbol) EncryptedSession s r )
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> s -- ^ Response to decorate
  -> m r
removeSession acs response = return (addHeader clearing response)
  where
    clearing = EncryptedSession (expiredCookie acs)
-- | Like 'addSession', but for an error response: the rendered session is
-- put in front of the existing error headers as a @Set-Cookie@ header.
addSessionToErr
  :: ( MonadIO m
     , MonadThrow m
     , Serialize a
     , ServerKeySet k )
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> RandomSource -- ^ Source of random bytes
  -> k -- ^ Key set
  -> a -- ^ Session value
  -> ServantErr -- ^ Error to decorate
  -> m ServantErr
addSessionToErr acs rs sk sessionData err =
  withCookie `liftM` renderSession acs rs sk sessionData
  where
    withCookie rendered =
      err { errHeaders = (hSetCookie, rendered) : errHeaders err }
-- | Like 'removeSession', but for an error response: the cookie produced by
-- 'expiredCookie' is put in front of the existing error headers.
removeSessionFromErr :: ( Monad m )
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> ServantErr -- ^ Error to decorate
  -> m ServantErr
removeSessionFromErr acs err =
  return err { errHeaders = clearing : errHeaders err }
  where
    clearing = (hSetCookie, expiredCookie acs)
-- | Render a cookie that clears the session: empty value, the configured
-- path and flags, and an @Expires@ attribute lying far in the past.
--
-- The date is written in the RFC 1123 form expected for HTTP cookie dates.
-- The internal 'acsExpirationFormat' describes the stamp stored inside the
-- encrypted cookie and is not a date format user agents understand; an
-- unparseable @Expires@ value is ignored by them, so it must not be used
-- here.
expiredCookie :: AuthCookieSettings -> ByteString
expiredCookie AuthCookieSettings {..} = (toByteString . renderCookies) cookies
  where
    cookies =
      (acsSessionField, "") :
      ("Path", acsPath) :
      ("Expires", invalidDate) :
      ((,"") <$> acsCookieFlags)
    invalidDate = BSC8.pack $ formatTime
      defaultTimeLocale
      "%a, %d %b %Y %H:%M:%S GMT"
      timeOrigin
    -- Day number 0 of the calendar used by 'UTCTime', at midnight.
    timeOrigin = UTCTime (toEnum 0) 0
-- | Look for the session cookie among the request headers and decrypt it.
-- Yields 'Nothing' when the request carries no session cookie; decryption
-- failures are thrown as by 'decryptSession'.
getSession :: (MonadIO m, MonadThrow m, Serialize a, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> k -- ^ Key set
  -> Request -- ^ Incoming request
  -> m (Maybe (WithMetadata a))
getSession acs@AuthCookieSettings {} sk request =
  case parseSessionRequest acs (requestHeaders request) of
    Nothing -> return Nothing
    Just encoded -> Just `liftM` decryptSession acs sk encoded
-- | Find the first header with the given name, parse it as a cookie list
-- and pick the cookie named 'acsSessionField'.
parseSession
  :: AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> HeaderName -- ^ Header to inspect
  -> [N.Header] -- ^ All headers
  -> Maybe (Tagged SerializedEncryptedCookie ByteString)
parseSession AuthCookieSettings {..} hdr hdrs = do
  rawHeader <- lookup hdr hdrs
  Tagged <$> lookup acsSessionField (parseCookies rawHeader)
-- | Extract the session cookie from request headers (the @Cookie@ header).
parseSessionRequest
  :: AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> RequestHeaders -- ^ Request headers
  -> Maybe (Tagged SerializedEncryptedCookie ByteString)
parseSessionRequest acs = parseSession acs hCookie
-- | Extract the session cookie from response headers (the @Set-Cookie@
-- header).
parseSessionResponse
  :: AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> ResponseHeaders -- ^ Response headers
  -> Maybe (Tagged SerializedEncryptedCookie ByteString)
parseSessionResponse acs = parseSession acs hSetCookie
-- | Encrypt the session and render the complete cookie text: the session
-- field, @Path@, @Max-Age@ (the lifetime rounded down to whole seconds) and
-- finally each configured flag with an empty value.
renderSession
  :: ( MonadIO m
     , MonadThrow m
     , Serialize a
     , ServerKeySet k )
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> RandomSource -- ^ Source of random bytes
  -> k -- ^ Key set
  -> a -- ^ Session value
  -> m ByteString
renderSession acs@AuthCookieSettings {..} rs sk sessionData =
  (toByteString . renderCookies . attributes . unTagged)
    `liftM` encryptSession acs rs sk sessionData
  where
    attributes sessionBinary =
      [ (acsSessionField, sessionBinary)
      , ("Path", acsPath)
      , ("Max-Age", maxAgeSeconds)
      ] ++ [(flag, "") | flag <- acsCookieFlags]
    maxAgeSeconds = BSC8.pack (show (floor acsMaxAge :: Int))
#if MIN_VERSION_servant(0,9,1)
-- | Turn a pure function on the session into a handler that takes the
-- session with its metadata.  When the metadata asks for renewal a fresh
-- session header is attached to the result; otherwise the result is
-- returned without the header.
cookied :: (Serialize a, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> RandomSource -- ^ Source of random bytes
  -> k -- ^ Key set
  -> (a -> r) -- ^ Function producing the response from the session
  -> ((WithMetadata a) -> Handler (Cookied r))
cookied acs rs k f = \(WithMetadata session renew) ->
  let response = f session
  in if renew
       then addSession acs rs k session response
       else return (noHeader response)
#endif
-- | Authentication handler that reads the session from the request with
-- 'getSession' and answers with 'err403' when no session cookie is present.
defaultAuthHandler :: (Serialize a, ServerKeySet k)
  => AuthCookieSettings -- ^ Options, see 'AuthCookieSettings'
  -> k -- ^ Key set
  -> AuthHandler Request (WithMetadata a)
defaultAuthHandler acs sk = mkAuthHandler authenticate
  where
    authenticate request =
      liftIO (getSession acs sk request) >>= maybe (throwError err403) return
-- | HMAC of a message under a key, using the hash algorithm selected by the
-- proxy; the digest is returned as plain bytes.
sign :: forall h. HashAlgorithm h
  => Proxy h -- ^ Selects the hash algorithm
  -> ByteString -- ^ Key
  -> ByteString -- ^ Message
  -> ByteString
sign Proxy key msg = BA.convert digest
  where
    digest :: HMAC h
    digest = H.hmac key msg
-- | Cut key material down to a length the cipher accepts: the largest
-- admissible length not exceeding the available material.  Throws
-- 'TooShortProperKey' (smallest admissible length, available length) when
-- even the smallest admissible length is not available.
mkProperKey :: MonadThrow m
  => KeySizeSpecifier -- ^ Key sizes accepted by the cipher
  -> ByteString -- ^ Key material
  -> m ByteString
mkProperKey kss s = do
  plen <- chooseLength kss
  return (BS.take plen s)
  where
    available = BS.length s
    tooShort needed = throwM (TooShortProperKey needed available)
    chooseLength (KeySizeRange lo hi)
      | available < lo = tooShort lo
      | otherwise = return (min available hi)
    chooseLength (KeySizeEnum sizes) =
      case filter (<= available) sizes of
        [] -> tooShort (minimum sizes)
        fitting -> return (maximum fitting)
    chooseLength (KeySizeFixed size)
      | available < size = tooShort size
      | otherwise = return size
-- | Build the IV and the cipher context from raw bytes and run the given
-- cipher operation on the message.  Throws 'CannotMakeIV' when the IV bytes
-- are rejected and 'BadProperKey' when the cipher cannot be initialised
-- with the key.
applyCipherAlgorithm :: forall c m. (BlockCipher c, MonadThrow m)
  => CipherAlgorithm c -- ^ Operation to run
  -> ByteString -- ^ Raw IV
  -> ByteString -- ^ Raw key
  -> ByteString -- ^ Message
  -> m ByteString
applyCipherAlgorithm f ivRaw keyRaw msg = do
  iv <- maybe (throwM (CannotMakeIV ivRaw)) return
    (makeIV ivRaw :: Maybe (IV c))
  cipher <- case (cipherInit keyRaw :: CryptoFailable c) of
    CryptoPassed ready -> return ready
    CryptoFailed err -> throwM (BadProperKey err)
  return (BA.convert (f cipher iv msg))
-- | Obtain a value-level stand-in for the type carried by a proxy.  The
-- result is 'undefined' and must never be evaluated; within this module it
-- is only passed to class methods such as @blockSize@, @cipherKeySize@ and
-- @hashDigestSize@ to select an instance.
unProxy :: Proxy a -> a
unProxy Proxy = undefined
-- | Generate the requested number of random bytes from a generator freshly
-- created with 'drgNew'.
generateRandomBytes :: Int -> IO ByteString
generateRandomBytes size = do
  gen <- drgNew
  let (bytes, _) = randomBytesGenerate size gen
  return bytes