module Thentos.CookieSession.CSRF
( CsrfSecret(..)
, CsrfToken(..)
, CsrfNonce(..)
, GetCsrfSecret(..)
, HasSessionCsrfToken(..)
, MonadHasSessionCsrfToken
, MonadViewCsrfSecret
, genCsrfSecret
, validFormatCsrfSecretField
, validFormatCsrfToken
, checkCsrfToken
, refreshCsrfToken
, clearCsrfToken
) where
import Control.Lens
import Control.Monad (when)
import Control.Monad.Reader.Class (MonadReader)
import Control.Monad.State.Class (MonadState)
import Crypto.Hash (SHA256)
import Crypto.MAC.HMAC (HMAC,hmac)
import Crypto.Random (MonadRandom(..))
import Data.Aeson (FromJSON, ToJSON)
import Data.ByteArray (constEq)
import Data.ByteArray.Encoding (convertToBase, convertFromBase, Base(Base16))
import Data.String.Conversions (SBS, ST, cs, (<>))
import Data.String (IsString)
import Data.Typeable (Typeable)
import GHC.Generics (Generic)
import qualified Data.ByteString as SBS
import qualified Data.Text as ST
import Servant.Missing (MonadError500, throwError500)
import Thentos.CookieSession.Types (ThentosSessionToken(fromThentosSessionToken), MonadUseThentosSessionToken, getThentosSessionToken)
-- | A CSRF token in its textual form: the nonce followed by the
-- hex-encoded HMAC-SHA256 tag computed by 'makeCsrfToken'.
newtype CsrfToken = CsrfToken { fromCsrfToken :: ST }
    deriving ( Eq, Ord, Show, Read
             , FromJSON, ToJSON, IsString
             , Typeable, Generic
             )
-- | The key used to compute the HMAC tag of every 'CsrfToken'.
--
-- NOTE(review): the derived 'Show' instance renders the key bytes; take
-- care not to 'show' a secret into logs or error messages.
newtype CsrfSecret = CsrfSecret SBS
    deriving (Eq, Show)
-- | The per-token random part of a 'CsrfToken'. Nonces produced by
-- 'genCsrfNonce' are 32 random bytes in hex encoding (64 ASCII bytes).
newtype CsrfNonce = CsrfNonce SBS
    deriving (Eq, Show)
-- | Types from which a 'CsrfSecret' can be extracted, if one is present.
class GetCsrfSecret a where
    -- | Read-only view of the secret; 'Nothing' when none is available.
    csrfSecret :: Getter a (Maybe CsrfSecret)
-- | Interpret a text value as a base16 (hex) encoded secret; the decoded
-- bytes become the HMAC key.
--
-- Text that is not valid base16 yields 'Nothing'. It does not yield a
-- 'Just' whose payload is an unevaluated pattern-match failure that would
-- only crash later, when the key is first used to compute an HMAC.
instance GetCsrfSecret ST.Text where
    csrfSecret = to $ \secret ->
        either (const Nothing) (Just . CsrfSecret)
            (convertFromBase Base16 (cs secret :: SBS))
-- | Types that can hold an optional CSRF token. These serve as the state
-- type in 'MonadHasSessionCsrfToken'.
class HasSessionCsrfToken a where
    -- | Lens onto the stored token; 'Nothing' means no token is set.
    sessionCsrfToken :: Lens' a (Maybe CsrfToken)
-- | Constraint synonym: @m@ is a state monad over @s@, and @s@ can store a
-- CSRF token.
type MonadHasSessionCsrfToken s m = (HasSessionCsrfToken s, MonadState s m)
-- | Constraint synonym: @m@ is a reader monad over @e@, and a CSRF secret
-- can be obtained from @e@.
type MonadViewCsrfSecret e m = (GetCsrfSecret e, MonadReader e m)
-- | Check that an optional text field holds a well-formed CSRF secret:
-- it is present, it is valid base16, and it decodes to exactly 32 bytes.
validFormatCsrfSecretField :: Maybe ST -> Bool
validFormatCsrfSecretField = maybe False hasSecretShape
  where
    hasSecretShape txt = case convertFromBase Base16 (cs txt :: SBS) of
        Right raw -> SBS.length raw == 32
        Left _    -> False
-- | Check that a 'CsrfToken' is valid base16 that decodes to exactly 64
-- bytes: a 32-byte nonce followed by a 32-byte HMAC-SHA256 tag.
validFormatCsrfToken :: CsrfToken -> Bool
validFormatCsrfToken =
    either (const False) ((== 64) . SBS.length) . decodeHex . fromCsrfToken
  where
    decodeHex :: ST -> Either String SBS
    decodeHex txt = convertFromBase Base16 (cs txt :: SBS)
-- | Compute the CSRF token for a nonce and the current session.
--
-- The token is the nonce followed by the hex-encoded HMAC-SHA256 of
-- @sessionToken <> nonce@, keyed with the 'CsrfSecret'. Because the session
-- token is part of the MAC input, the resulting token only verifies within
-- that session.
--
-- Throws a 500 error if there is no session token or no CSRF secret. The
-- missing secret is handled explicitly rather than with a failable
-- @Just ... <-@ pattern. Such a pattern would need a 'MonadFail' instance
-- that the constraints do not provide.
makeCsrfToken
    :: (MonadError500 err m, MonadViewCsrfSecret e m, MonadUseThentosSessionToken s m)
    => CsrfNonce -> m CsrfToken
makeCsrfToken (CsrfNonce rnd) = do
    maySessionToken <- use getThentosSessionToken
    case maySessionToken of
        Nothing -> throwError500 "No session token"
        Just sessionToken -> do
            mayCsrfSecret <- view csrfSecret
            case mayCsrfSecret of
                Nothing -> throwError500 "No CSRF secret"
                Just (CsrfSecret key) -> do
                    let tok = cs (fromThentosSessionToken sessionToken) :: SBS
                        tag = hmac key (tok <> rnd) :: HMAC SHA256
                    return . CsrfToken . cs $ rnd <> convertToBase Base16 tag
-- | Recover the nonce from a 'CsrfToken'. The nonce is the first 64 bytes
-- of the token's encoding, which is the length of a 'genCsrfNonce' nonce.
csrfNonceFromCsrfToken :: CsrfToken -> CsrfNonce
csrfNonceFromCsrfToken (CsrfToken txt) = CsrfNonce (SBS.take 64 (cs txt))
-- | Verify a CSRF token submitted by a client. The token must be
-- well-formed and equal to the token recomputed from its own nonce and the
-- current session.
--
-- Throws a 500 error if the token is ill-formatted or does not match. The
-- 500 errors raised by 'makeCsrfToken' propagate when the session token or
-- the CSRF secret is missing.
checkCsrfToken
    :: (MonadError500 err m, MonadViewCsrfSecret e m, MonadUseThentosSessionToken s m)
    => CsrfToken -> m ()
checkCsrfToken csrfToken
    | not (validFormatCsrfToken csrfToken) =
        throwError500 $ "Ill-formatted CSRF Token " <> show csrfToken
    | otherwise = do
        expected <- makeCsrfToken (csrfNonceFromCsrfToken csrfToken)
        -- Constant-time comparison. A short-circuiting (/=) would let
        -- response timing reveal how many leading bytes of the HMAC tag
        -- were guessed correctly. UTF-8 encoding is injective, so byte
        -- equality here is exactly Text equality.
        when (not (constEq (tokenBytes csrfToken) (tokenBytes expected))) $
            throwError500 "Invalid CSRF token"
  where
    tokenBytes :: CsrfToken -> SBS
    tokenBytes = cs . fromCsrfToken
-- | Generate a fresh CSRF secret from 32 random bytes.
--
-- The wrapped value is the hex encoding of those bytes (64 ASCII bytes),
-- not the raw bytes. NOTE(review): the 'ST.Text' instance of
-- 'GetCsrfSecret' decodes hex into raw bytes. A secret used straight from
-- here therefore differs in form from one read back from text. Confirm
-- that this is intended, e.g. that this output is only rendered into
-- configuration.
genCsrfSecret :: MonadRandom m => m CsrfSecret
genCsrfSecret = do
    raw <- getRandomBytes 32
    return $ CsrfSecret (convertToBase Base16 (raw :: SBS))
-- | Generate a fresh nonce: 32 random bytes, hex-encoded to 64 ASCII bytes.
-- 'csrfNonceFromCsrfToken' strips exactly this length off a token.
genCsrfNonce :: MonadRandom m => m CsrfNonce
genCsrfNonce = fmap (CsrfNonce . toHex) (getRandomBytes 32)
  where
    toHex :: SBS -> SBS
    toHex = convertToBase Base16
-- | Issue a new CSRF token and store it in the session state, replacing any
-- previous one. The token is derived from a fresh nonce and the current
-- session.
refreshCsrfToken
    :: ( MonadError500 err m, MonadHasSessionCsrfToken s m
       , MonadRandom m, MonadViewCsrfSecret e m, MonadUseThentosSessionToken s m )
    => m ()
refreshCsrfToken = do
    nonce <- genCsrfNonce
    token <- makeCsrfToken nonce
    sessionCsrfToken .= Just token
-- | Remove the CSRF token from the session state.
clearCsrfToken :: MonadHasSessionCsrfToken s m => m ()
clearCsrfToken = assign sessionCsrfToken Nothing