module Servant.Server.Experimental.Auth.Cookie
( CipherAlgorithm
, AuthCookieData
, Cookie (..)
, AuthCookieException (..)
, RandomSource
, mkRandomSource
, getRandomBytes
, ServerKey
, mkServerKey
, mkServerKeyFromBytes
, getServerKey
, AuthCookieSettings (..)
, encryptCookie
, decryptCookie
, encryptSession
, decryptSession
, addSession
, addSessionToErr
, getSession
, renderSession
, defaultAuthHandler
) where
import Blaze.ByteString.Builder (toByteString)
import Control.Monad
import Control.Monad.Catch (MonadThrow (..), Exception, catch)
import Control.Monad.Except
import Crypto.Cipher.AES (AES256)
import Crypto.Cipher.Types
import Crypto.Error
import Crypto.Hash (HashAlgorithm(..))
import Crypto.Hash.Algorithms (SHA256)
import Crypto.MAC.HMAC (HMAC)
import Crypto.Random (drgNew, DRG(..))
import Data.ByteString (ByteString)
import Data.Default
import Data.IORef
import Data.Maybe (fromMaybe, isNothing)
import Data.Monoid ((<>))
import Data.Proxy
import Data.Serialize
import Data.Time
import Data.Typeable
import GHC.TypeLits (Symbol)
import Network.HTTP.Types (hCookie)
import Network.Wai (Request, requestHeaders)
import Servant (addHeader, ServantErr (..))
import Servant.API.Experimental.Auth (AuthProtect)
import Servant.API.ResponseHeaders (AddHeader)
import Servant.Server (err403)
import Servant.Server.Experimental.Auth
import Web.Cookie
import qualified Crypto.MAC.HMAC as H
import qualified Data.ByteArray as BA
import qualified Data.ByteString as BS
import qualified Data.ByteString.Base64 as Base64
import qualified Data.ByteString.Char8 as BSC8
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
-- | A cipher operation: given an initialised cipher and an IV, transform the
-- input bytes. Both 'acsEncryptAlgorithm' and 'acsDecryptAlgorithm' have this
-- shape; the default settings use 'ctrCombine' for both.
type CipherAlgorithm c = c -> IV c -> ByteString -> ByteString
-- | The session data type delivered to handlers protected by
-- @AuthProtect "cookie-auth"@. No instance of this family is visible here;
-- NOTE(review): presumably the application provides it — confirm.
type family AuthCookieData
-- Handlers behind @AuthProtect "cookie-auth"@ receive an 'AuthCookieData'.
type instance AuthServerData (AuthProtect "cookie-auth") = AuthCookieData
-- | A cookie as handed to 'encryptCookie' (before encryption) or as returned
-- by 'decryptCookie' (after verification and decryption).
data Cookie = Cookie
{ cookieIV :: ByteString
-- ^ Initialisation vector; written in the clear at the start of the
-- encrypted form.
, cookieExpirationTime :: UTCTime
-- ^ Time after which 'decryptCookie' rejects the cookie.
, cookiePayload :: ByteString
-- ^ Plaintext payload (encrypted only in the wire form).
} deriving (Eq, Show)
-- | Errors raised (via 'throwM') while building or checking cookies.
data AuthCookieException
= CannotMakeIV ByteString
-- ^ 'makeIV' rejected the IV bytes; carries the raw IV.
| BadProperKey CryptoError
-- ^ 'cipherInit' rejected the derived key.
| TooShortProperKey Int Int
-- ^ Not enough key material: the minimum length the cipher needs, then
-- the length that was available.
| IncorrectMAC ByteString
-- ^ The cookie's MAC did not verify; carries the MAC found in the cookie.
| CannotParseExpirationTime ByteString
-- ^ The expiration field could not be parsed with 'acsExpirationFormat'.
| CookieExpired UTCTime UTCTime
-- ^ The cookie has expired: its expiration time, then the current time.
| SessionDeserializationFailed String
-- ^ Base64 decoding or deserialisation of the session failed; carries
-- the decoder's error message.
deriving (Eq, Show, Typeable)
instance Exception AuthCookieException
-- | A source of random bytes backed by a 'DRG'. The generator is replaced
-- by a freshly created one once the number of bytes drawn from it reaches
-- a threshold (see 'getRandomBytes').
data RandomSource where
-- Fields: the action creating a new generator, the byte threshold, and a
-- mutable cell holding the current generator together with the number of
-- bytes drawn from it so far.
RandomSource :: DRG d => IO d -> Int -> IORef (d, Int) -> RandomSource
-- | Create a 'RandomSource'. The generator action is run once now to seed
-- the source, and again every time the byte threshold is reached.
mkRandomSource :: (MonadIO m, DRG d)
  => IO d   -- ^ action producing a fresh generator
  -> Int    -- ^ number of bytes after which the generator is replaced
  -> m RandomSource
mkRandomSource mkDRG threshold = liftIO $ do
  initialDRG <- mkDRG
  cell <- newIORef (initialDRG, 0)
  return (RandomSource mkDRG threshold cell)
-- | Draw @n@ random bytes from a 'RandomSource'.
--
-- The bytes come from the current generator. When the total number of bytes
-- drawn from that generator reaches the source's threshold, the generator is
-- replaced by a freshly created one. The new generator is created only at
-- that point, not on every call: creating one may be expensive (for example
-- 'drgNew' reads system entropy).
getRandomBytes :: MonadIO m
  => RandomSource
  -> Int            -- ^ number of bytes to draw
  -> m ByteString
getRandomBytes (RandomSource mkDRG threshold ref) n = liftIO $ do
  (result, exhausted) <- atomicModifyIORef' ref $ \(drg, bytes) ->
    let (bytesOut, drg') = randomBytesGenerate n drg
        bytes' = bytes + n
    in ((drg', bytes'), (bytesOut, bytes' >= threshold))
  -- Reseed outside the atomic update, because creating a generator is an IO
  -- action. Bytes drawn in between still come from a valid generator.
  when exhausted $ do
    freshDRG <- mkDRG
    atomicWriteIORef ref (freshDRG, 0)
  return result
-- | The server-side secret used to sign cookies and to derive their
-- encryption keys. Fields: key size in bytes; optional maximum key age
-- (when present, 'getServerKey' regenerates the key once its expiration
-- time has passed); and a mutable cell with the current key and its
-- expiration time.
data ServerKey =
ServerKey Int (Maybe NominalDiffTime) (IORef (ByteString, UTCTime))
-- | Create a 'ServerKey' holding a freshly generated random key.
mkServerKey :: MonadIO m
  => Int                    -- ^ key size in bytes
  -> Maybe NominalDiffTime  -- ^ maximum key age; 'Nothing' means no rotation
  -> m ServerKey
mkServerKey size maxAge = liftIO $ do
  initialState <- mkServerKeyState size maxAge
  ServerKey size maxAge <$> newIORef initialState
-- | Wrap existing key bytes in a 'ServerKey'. The key never rotates: the
-- maximum age is 'Nothing', and the stored expiration time is just a
-- placeholder.
mkServerKeyFromBytes :: MonadIO m
  => ByteString   -- ^ key material
  -> m ServerKey
mkServerKeyFromBytes bytes = liftIO $ do
  cell <- newIORef (bytes, epoch)
  return (ServerKey (BS.length bytes) Nothing cell)
  where
    epoch = UTCTime (ModifiedJulianDay 0) 0
-- | Return the current server key. When a maximum age is configured and the
-- key's expiration time has passed, the key is first replaced with a newly
-- generated one.
--
-- A new key is generated only when the current one has expired. Expiry is
-- re-checked inside the atomic update, so concurrent callers rotate the key
-- at most once; a caller that loses the race uses the winner's key.
getServerKey :: MonadIO m
  => ServerKey
  -> m ByteString
getServerKey (ServerKey size maxAge ref) = liftIO $ do
  currentTime <- getCurrentTime
  let expired (_, expirationTime) =
        not (isNothing maxAge) && currentTime > expirationTime
  snapshot <- readIORef ref
  if not (expired snapshot)
    then return (fst snapshot)
    else do
      fresh <- mkServerKeyState size maxAge
      atomicModifyIORef' ref $ \current ->
        if expired current
          then (fresh, fst fresh)
          else (current, fst current)
-- | Generate a random key of the given size, together with its expiration
-- time: now plus the maximum age, or just now when there is no maximum age.
mkServerKeyState :: MonadIO m
  => Int
  -> Maybe NominalDiffTime
  -> m (ByteString, UTCTime)
mkServerKeyState size maxAge = liftIO $ do
  generator <- drgNew
  now <- getCurrentTime
  let key = fst (randomBytesGenerate size generator)
      expiresAt = addUTCTime (fromMaybe 0 maxAge) now
  return (key, expiresAt)
-- | Configuration for cookie encryption and rendering. The hash algorithm
-- @h@ and the block cipher @c@ are existentially quantified; they are fixed
-- by the 'Proxy' fields.
data AuthCookieSettings where
AuthCookieSettings :: (HashAlgorithm h, BlockCipher c) =>
{ acsSessionField :: ByteString
-- Name of the cookie that carries the session.
, acsCookieFlags :: [ByteString]
-- Extra cookie attributes, each rendered with an empty value.
, acsMaxAge :: NominalDiffTime
-- Session lifetime: added to the current time for the expiration field
-- and rendered (floored to whole seconds) as the Max-Age attribute.
, acsExpirationFormat :: String
-- 'formatTime' / 'parseTimeM' format for the expiration field.
-- 'decryptCookie' measures the field's length by formatting the current
-- time, so the format must produce the same width for every time.
, acsPath :: ByteString
-- Value of the cookie's Path attribute.
, acsHashAlgorithm :: Proxy h
-- Hash used for the HMAC (cookie signing and key derivation).
, acsCipher :: Proxy c
-- Block cipher type witness.
, acsEncryptAlgorithm :: CipherAlgorithm c
-- Applied to the payload by 'encryptCookie'.
, acsDecryptAlgorithm :: CipherAlgorithm c
-- Applied to the payload by 'decryptCookie'.
} -> AuthCookieSettings
-- | Defaults: HMAC-SHA256 and AES-256 in CTR mode; a cookie named
-- @Session@ on path @/@, flagged @HttpOnly@ and @Secure@; a 12-hour
-- lifetime; and a fixed-width @YYYYmmddHHMMSS@ expiration format.
instance Default AuthCookieSettings where
  def = AuthCookieSettings
    { acsHashAlgorithm    = Proxy :: Proxy SHA256
    , acsCipher           = Proxy :: Proxy AES256
    , acsEncryptAlgorithm = ctrCombine
    , acsDecryptAlgorithm = ctrCombine
    , acsSessionField     = "Session"
    , acsPath             = "/"
    , acsCookieFlags      = ["HttpOnly", "Secure"]
    , acsMaxAge           = twelveHours
    , acsExpirationFormat = "%0Y%m%d%H%M%S"
    }
    where
      twelveHours :: NominalDiffTime
      twelveHours = fromIntegral (12 * 3600 :: Integer)
-- | Encrypt and authenticate a 'Cookie', producing its binary form:
--
-- > IV ++ expiration ++ ciphertext ++ HMAC(IV ++ expiration ++ ciphertext)
--
-- The encryption key is HMAC(serverKey, IV ++ expiration), truncated to a
-- length the cipher accepts (see 'mkProperKey').
encryptCookie :: (MonadIO m, MonadThrow m)
  => AuthCookieSettings
  -> ServerKey
  -> Cookie
  -> m ByteString
encryptCookie AuthCookieSettings {..} sk cookie = do
  serverKey <- getServerKey sk
  let hmacOf = sign acsHashAlgorithm serverKey
      ivBytes = cookieIV cookie
      expirationBytes = BSC8.pack
        (formatTime defaultTimeLocale acsExpirationFormat
                    (cookieExpirationTime cookie))
  derivedKey <- mkProperKey
    (cipherKeySize (unProxy acsCipher))
    (hmacOf (ivBytes <> expirationBytes))
  ciphertext <- applyCipherAlgorithm acsEncryptAlgorithm
    ivBytes derivedKey (cookiePayload cookie)
  let signedPart = [ivBytes, expirationBytes, ciphertext]
      tag = hmacOf (BS.concat signedPart)
  return (runPut (mapM_ putByteString (signedPart ++ [tag])))
-- | Verify and decrypt the binary form produced by 'encryptCookie':
--
-- > IV (cipher block size) ++ expiration ++ ciphertext ++ HMAC (digest size)
--
-- Throws 'IncorrectMAC' when the input is too short or its MAC does not
-- verify, 'CannotParseExpirationTime' when the expiration field is not
-- parseable, and 'CookieExpired' when it is in the past. Errors from key
-- derivation or cipher set-up are thrown as well.
decryptCookie :: (MonadIO m, MonadThrow m)
  => AuthCookieSettings
  -> ServerKey
  -> ByteString
  -> m Cookie
decryptCookie AuthCookieSettings {..} sk s = do
  currentTime <- liftIO getCurrentTime
  let ivSize = blockSize (unProxy acsCipher)
      -- Relies on the expiration format having a fixed width: the stored
      -- field is as long as the current time formatted the same way.
      expSize =
        length (formatTime defaultTimeLocale acsExpirationFormat currentTime)
      macSize = hashDigestSize (unProxy acsHashAlgorithm)
      payloadSize = BS.length s - ivSize - expSize - macSize
      butMacSize = ivSize + expSize + payloadSize
      (iv, s0) = BS.splitAt ivSize s
      (expirationRaw, s1) = BS.splitAt expSize s0
      (payloadRaw, mac) = BS.splitAt payloadSize s1
  serverKey <- getServerKey sk
  -- The cookie comes from the client. Compare MACs in constant time so the
  -- response time does not reveal how many leading bytes of a forgery match.
  let expectedMac = sign acsHashAlgorithm serverKey (BS.take butMacSize s)
  unless (payloadSize >= 0 && BA.constEq mac expectedMac) $
    throwM (IncorrectMAC mac)
  expirationTime <-
    maybe (throwM $ CannotParseExpirationTime expirationRaw) return $
      parseTimeM False defaultTimeLocale acsExpirationFormat
        (BSC8.unpack expirationRaw)
  when (currentTime >= expirationTime) $
    throwM (CookieExpired expirationTime currentTime)
  -- Same key derivation as 'encryptCookie': HMAC over IV ++ expiration.
  key <- mkProperKey
    (cipherKeySize (unProxy acsCipher))
    (sign acsHashAlgorithm serverKey $ BS.take (ivSize + expSize) s)
  payload <- applyCipherAlgorithm acsDecryptAlgorithm iv key payloadRaw
  return Cookie
    { cookieIV = iv
    , cookieExpirationTime = expirationTime
    , cookiePayload = payload }
-- | Serialise a session and produce the Base64-encoded, encrypted and
-- authenticated cookie value.
--
-- A random IV of one cipher block is drawn. The serialised payload is
-- padded with random bytes up to a whole number of cipher blocks, and the
-- expiration time is now plus 'acsMaxAge'.
encryptSession :: (MonadIO m, MonadThrow m, Serialize a)
  => AuthCookieSettings
  -> RandomSource
  -> ServerKey
  -> a
  -> m ByteString
encryptSession acs@AuthCookieSettings {..} randomSource sk session = do
  let blockLen = blockSize (unProxy acsCipher)
      payload = runPut (put session)
      -- Bytes needed to reach the next block boundary (0 if already there).
      padLen = (blockLen - (BS.length payload `rem` blockLen)) `rem` blockLen
  iv <- getRandomBytes randomSource blockLen
  expirationTime <- liftM (addUTCTime acsMaxAge) (liftIO getCurrentTime)
  padding <- getRandomBytes randomSource padLen
  Base64.encode `liftM` encryptCookie acs sk (Cookie
    { cookieIV = iv
    , cookieExpirationTime = expirationTime
    , cookiePayload = BS.concat [payload, padding] })
-- | Inverse of 'encryptSession': Base64-decode, verify and decrypt the
-- cookie, then deserialise the session. Decoding failures are reported as
-- 'SessionDeserializationFailed'.
decryptSession :: (MonadIO m, MonadThrow m, Serialize a)
  => AuthCookieSettings
  -> ServerKey
  -> ByteString
  -> m a
decryptSession acs@AuthCookieSettings {} sk s = do
  raw <- orFail (Base64.decode s)
  cookie <- decryptCookie acs sk raw
  orFail (runGet get (cookiePayload cookie))
  where
    orFail :: MonadThrow n => Either String b -> n b
    orFail = either (throwM . SessionDeserializationFailed) return
-- | Attach the rendered session cookie to a response as a header.
addSession
  :: ( MonadIO m
     , MonadThrow m
     , Serialize a
     , AddHeader (e :: Symbol) ByteString s r )
  => AuthCookieSettings
  -> RandomSource
  -> ServerKey
  -> a    -- ^ session data
  -> s    -- ^ response to decorate
  -> m r
addSession acs rs sk sessionData response =
  liftM (`addHeader` response) (renderSession acs rs sk sessionData)
-- | Prepend a @set-cookie@ header carrying the rendered session to the
-- headers of a 'ServantErr'.
addSessionToErr
  :: ( MonadIO m
     , MonadThrow m
     , Serialize a )
  => AuthCookieSettings
  -> RandomSource
  -> ServerKey
  -> a
  -> ServantErr
  -> m ServantErr
addSessionToErr acs rs sk sessionData err =
  liftM withCookie (renderSession acs rs sk sessionData)
  where
    withCookie header =
      err { errHeaders = ("set-cookie", header) : errHeaders err }
-- | Look up the session cookie ('acsSessionField') in the request's Cookie
-- header and decrypt it. Returns 'Nothing' when there is no such cookie;
-- errors from decryption are thrown.
getSession :: (MonadIO m, MonadThrow m, Serialize a)
  => AuthCookieSettings
  -> ServerKey
  -> Request
  -> m (Maybe a)
getSession acs@AuthCookieSettings {..} sk request =
  case sessionCookie of
    Nothing -> return Nothing
    Just sessionBinary -> liftM Just (decryptSession acs sk sessionBinary)
  where
    sessionCookie =
      lookup hCookie (requestHeaders request)
        >>= lookup acsSessionField . parseCookies
-- | Encrypt the session and render it as a cookie header value, with the
-- Path and Max-Age attributes plus every flag from 'acsCookieFlags'.
renderSession
  :: ( MonadIO m
     , MonadThrow m
     , Serialize a )
  => AuthCookieSettings
  -> RandomSource
  -> ServerKey
  -> a
  -> m ByteString
renderSession acs@AuthCookieSettings {..} rs sk sessionData =
  liftM (toByteString . renderCookies . cookieAttributes)
        (encryptSession acs rs sk sessionData)
  where
    cookieAttributes sessionBinary =
      [ (acsSessionField, sessionBinary)
      , ("Path", acsPath)
      , ("Max-Age", BSC8.pack (show maxAgeSeconds))
      ] ++ [ (flag, "") | flag <- acsCookieFlags ]
    maxAgeSeconds :: Int
    maxAgeSeconds = floor acsMaxAge
-- | An 'AuthHandler' that accepts a request only if it carries a valid
-- session cookie, and otherwise answers with 403.
--
-- Besides a missing cookie, any 'AuthCookieException' raised while reading
-- the cookie (bad MAC, expired, undecodable, ...) also yields 403. Without
-- this, the exception would escape the handler and become a server error.
defaultAuthHandler :: Serialize a
  => AuthCookieSettings
  -> ServerKey
  -> AuthHandler Request a
defaultAuthHandler acs sk = mkAuthHandler $ \request -> do
  msession <- liftIO (getSession acs sk request `catch` rejectCookie)
  maybe (throwError err403) return msession
  where
    rejectCookie :: AuthCookieException -> IO (Maybe b)
    rejectCookie _ = return Nothing
-- | HMAC of a message under a key, using the hash algorithm selected by the
-- proxy, returned as raw bytes.
sign :: forall h. HashAlgorithm h
  => Proxy h
  -> ByteString   -- ^ key
  -> ByteString   -- ^ message
  -> ByteString
sign Proxy key = BA.convert . hmacWithKey
  where
    hmacWithKey :: ByteString -> HMAC h
    hmacWithKey = H.hmac key
-- | Truncate key material to a length the cipher accepts: the largest
-- acceptable length not exceeding the available bytes. Throws
-- 'TooShortProperKey' when even the smallest acceptable length is too long.
mkProperKey :: MonadThrow m
  => KeySizeSpecifier
  -> ByteString
  -> m ByteString
mkProperKey kss s = liftM (`BS.take` s) chosenLength
  where
    available = BS.length s
    tooShort needed = throwM (TooShortProperKey needed available)
    chosenLength = case kss of
      KeySizeRange lo hi
        | available < lo -> tooShort lo
        | otherwise -> return (min available hi)
      KeySizeEnum sizes -> case filter (<= available) sizes of
        [] -> tooShort (minimum sizes)
        fitting -> return (maximum fitting)
      KeySizeFixed fixed
        | available < fixed -> tooShort fixed
        | otherwise -> return fixed
-- | Build the IV and the cipher from raw bytes, then run the cipher
-- operation on the message. Throws 'CannotMakeIV' or 'BadProperKey' when
-- the IV or the key is rejected.
applyCipherAlgorithm :: forall c m. (BlockCipher c, MonadThrow m)
  => CipherAlgorithm c
  -> ByteString   -- ^ raw IV
  -> ByteString   -- ^ raw key
  -> ByteString   -- ^ message
  -> m ByteString
applyCipherAlgorithm f ivRaw keyRaw msg = do
  iv <- maybe (throwM (CannotMakeIV ivRaw)) return
          (makeIV ivRaw :: Maybe (IV c))
  key <- fromCrypto (cipherInit keyRaw :: CryptoFailable c)
  return (BA.convert (f key iv msg))
  where
    fromCrypto (CryptoFailed err) = throwM (BadProperKey err)
    fromCrypto (CryptoPassed x) = return x
-- | A value of the proxied type, usable only as a type witness for
-- functions that ignore their argument (e.g. 'blockSize', 'cipherKeySize',
-- 'hashDigestSize'). Forcing it is a bug; it fails with a descriptive
-- error rather than a bare 'undefined'.
unProxy :: Proxy a -> a
unProxy Proxy =
  error "Servant.Server.Experimental.Auth.Cookie.unProxy: value forced; it is only a type witness"