module Cookie.Secure (encryptAndSign
                    , verifyAndDecrypt
                    , encryptAndSignIO
                    , verifyAndDecryptIO) where

import Control.Monad (replicateM)
import Data.ByteString (ByteString)
import qualified Data.ByteString.Char8 as BS
import Data.Char (chr)
import System.Environment (getEnv)
import System.Random (getStdRandom, randomR)

import Crypto.Error (CryptoFailable, maybeCryptoError, throwCryptoErrorIO)
import Crypto.Random (getRandomBytes)

import Crypto.Encryption (encrypt, decrypt)
import Crypto.Verification (sign
                          , verify
                          , serialize
                          , deserialize
                          , getSignable)

-- | Encrypt a message, then sign the ciphertext, and return the serialized
-- signed payload.
--
-- Arguments, in order:
--
-- * the IV and the encryption key, both handed to 'encrypt';
-- * the authentication key, handed to 'sign';
-- * the plaintext.
--
-- If 'encrypt' fails, the failure is propagated unchanged in the
-- 'CryptoFailable' result. Signing and serialization only run on success.
encryptAndSign
  :: String
  -> String
  -> String
  -> ByteString
  -> CryptoFailable ByteString
encryptAndSign iv encryptKey authKey message =
  -- Sign and serialize in one mapped step over the encryption result.
  fmap (serialize . sign authKey) (encrypt iv encryptKey message)

-- OPTIMIZE: return Either errorType instead of Maybe.
-- Ideally the result would be a CryptoFailable, but that only carries
-- a CryptoError, which has no constructors for signing/verification
-- (or deserialization) failures.
-- | Counterpart to 'encryptAndSign': deserialize a signed payload, check its
-- signature, and decrypt the ciphertext it carries.
--
-- Arguments, in order:
--
-- * the authentication key, used by 'verify';
-- * the encryption key, used by 'decrypt';
-- * the serialized payload.
--
-- The key order here differs from 'encryptAndSign', which takes the
-- encryption key before the authentication key.
--
-- The result is 'Nothing' in three cases: deserialization fails, the
-- signature does not verify, or 'decrypt' reports a crypto error.
-- Decryption is only attempted after verification succeeds.
verifyAndDecrypt :: String -> String -> ByteString -> Maybe ByteString
verifyAndDecrypt authKey encryptKey message = do
  signed <- deserialize message
  if verify authKey signed
    then maybeCryptoError (decrypt encryptKey (getSignable signed))
    else Nothing

-- | Run 'encryptAndSign' in 'IO'. The IV is freshly generated and the keys
-- are read from the environment, both via 'getIVAuthKeyEncryptKey'.
--
-- A 'CryptoFailable' failure is rethrown as an exception by
-- 'throwCryptoErrorIO'.
--
-- The helper returns the tuple as (IV, validation/auth key, encryption key),
-- but 'encryptAndSign' takes the encryption key before the auth key. The
-- lambda below swaps them.
encryptAndSignIO :: ByteString -> IO ByteString
encryptAndSignIO message =
  getIVAuthKeyEncryptKey >>= \(iv, authKey, encKey) ->
    throwCryptoErrorIO (encryptAndSign iv encKey authKey message)

-- | Run 'verifyAndDecrypt' in 'IO', reading the validation (authentication)
-- key and the encryption key from the environment.
--
-- Decryption needs only the two keys, so they are read directly here.
-- Going through 'getIVAuthKeyEncryptKey' would also generate a random IV on
-- every call, only to discard it. The environment-variable names must stay
-- in sync with 'getIVAuthKeyEncryptKey'.
--
-- 'getEnv' throws an 'IOError' if either variable is unset. Any
-- deserialization, verification or decryption failure yields 'Nothing'.
verifyAndDecryptIO :: ByteString -> IO (Maybe ByteString)
verifyAndDecryptIO message = do
  validationKey <- getEnv "WAI_COOKIE_VALIDATION_KEY"
  encryptionKey <- getEnv "WAI_COOKIE_ENCRYPTION_KEY"
  pure (verifyAndDecrypt validationKey encryptionKey message)

-- | Collect the per-cookie IV together with the validation (authentication)
-- key and the encryption key, in that order.
--
-- The IV is 16 fresh random bytes. The keys come from the
-- @WAI_COOKIE_VALIDATION_KEY@ and @WAI_COOKIE_ENCRYPTION_KEY@ environment
-- variables. 'getEnv' throws an 'IOError' if either variable is unset.
getIVAuthKeyEncryptKey :: IO (String, String, String)
getIVAuthKeyEncryptKey =
  (,,)
    <$> get16RandomBytes
    <*> getEnv "WAI_COOKIE_VALIDATION_KEY"
    <*> getEnv "WAI_COOKIE_ENCRYPTION_KEY"
  where
    -- The encryption function takes the IV as a String, but the AES-256/CTR
    -- algorithm only looks at bytes. Each random byte is unpacked to a Char
    -- in 0..255. Printability in ASCII, UTF-8 or any other encoding doesn't
    -- matter.
    --
    -- The bytes must come from a cryptographically secure source.
    -- System.Random's global StdGen is not suitable:
    --
    -- * it is seeded from the clock;
    -- * it has a small internal state;
    -- * separate processes or restarts can therefore replay the same IV
    --   sequence.
    --
    -- In CTR mode, reusing an IV under the same key reuses the keystream.
    get16RandomBytes :: IO String
    get16RandomBytes = BS.unpack <$> (getRandomBytes 16 :: IO ByteString)