{-# LINE 1 "OpenSSL/Cipher.hsc" #-}



{-# LANGUAGE EmptyDataDecls           #-}
{-# LANGUAGE ForeignFunctionInterface #-}
{-# LANGUAGE CApiFFI #-}
-- | This module interfaces to some of the OpenSSL ciphers without using
--   EVP (see OpenSSL.EVP.Cipher). The EVP ciphers are easier to use;
--   however, in some cases you cannot avoid calling the OpenSSL
--   functions directly.
--
--   One of these cases (and the motivating example
--   for this module) is that the EVP CBC functions try to encode the
--   length of the input string in the output (thus hiding the fact that the
--   cipher is, in fact, block based and needs padding). This means that the
--   EVP CBC functions cannot, in some cases, interoperate with other
--   implementations that don't use that scheme (such as SSH).
module OpenSSL.Cipher
    ( Mode(..)
    , AESCtx
    , newAESCtx
    , aesCBC

{-# LINE 26 "OpenSSL/Cipher.hsc" #-}
    )
    where

import           Control.Monad (when, unless)
import           Data.IORef
import           Foreign
import           Foreign.C.Types
import qualified Data.ByteString as BS
import qualified Data.ByteString.Internal as BSI
import           OpenSSL.Utils

-- | Direction of a cipher operation. It selects which OpenSSL key
--   schedule 'newAESCtx' builds and which @enc@ flag 'aesCBC' passes
--   to OpenSSL.
data Mode = Encrypt | Decrypt deriving (Eq, Show)

-- | Convert a 'Mode' into the integer @enc@ flag passed as the last
--   argument of @AES_cbc_encrypt@: 1 for 'Encrypt', 0 for 'Decrypt'.
modeToInt :: Num a => Mode -> a
modeToInt Encrypt = 1
modeToInt Decrypt = 0

-- | OpenSSL's @AES_KEY@ key-schedule structure. It is only ever handled
--   through pointers here; the CTYPE pragma lets @capi@ imports name it
--   by its C type.
data {-# CTYPE "openssl/aes.h" "AES_KEY" #-} AES_KEY

-- | Cipher state built by 'newAESCtx'. The constructor fields are, in order:
data AESCtx =
  AESCtx
    (ForeignPtr AES_KEY)  -- expanded key schedule (encrypt or decrypt, per the Mode field)
    (ForeignPtr CUChar)   -- IV / counter block
    (ForeignPtr CUChar)   -- encrypted counter block, only meaningful in CTR mode
    (IORef CUInt)         -- how many bytes of the encrypted counter have been used
    Mode                  -- direction this context was created for

-- Raw C imports. All of them are @unsafe@ calls: they are short and
-- never call back into Haskell.

foreign import capi unsafe "string.h memcpy"
        _memcpy :: Ptr CUChar -> Ptr CChar -> CSize -> IO (Ptr ())

foreign import capi unsafe "string.h memset"
        _memset :: Ptr CUChar -> CChar -> CSize -> IO ()

-- Both key-setup functions return 0 on success; callers check the result.
foreign import capi unsafe "openssl/aes.h AES_set_encrypt_key"
        _AES_set_encrypt_key :: Ptr CChar -> CInt -> Ptr AES_KEY -> IO CInt
foreign import capi unsafe "openssl/aes.h AES_set_decrypt_key"
        _AES_set_decrypt_key :: Ptr CChar -> CInt -> Ptr AES_KEY -> IO CInt

-- The C prototype takes the length as size_t. Importing it as CSize
-- (rather than CULong) keeps the full width on LLP64 targets such as
-- 64-bit Windows, where unsigned long is only 32 bits and would truncate
-- very large inputs.
foreign import capi unsafe "openssl/aes.h AES_cbc_encrypt"
        _AES_cbc_encrypt :: Ptr CChar -> Ptr Word8 -> CSize -> Ptr AES_KEY -> Ptr CUChar -> CInt -> IO ()

foreign import capi unsafe "stdlib.h &free"
        _free :: FunPtr (Ptr a -> IO ())

-- | Construct a new context which holds the key schedule and IV.
--
--   The key must be 16, 24 or 32 bytes long and the IV exactly 16 bytes.
--   Otherwise the action fails in 'IO' via 'fail'. The key schedule is
--   expanded for encryption or decryption according to the 'Mode'. The
--   IV is copied into memory owned by the context, so later updates to
--   it do not touch the caller's 'BS.ByteString'.
newAESCtx :: Mode           -- ^ For CTR mode, this must always be Encrypt
          -> BS.ByteString  -- ^ Key: 128, 192 or 256 bits long
          -> BS.ByteString  -- ^ IV: 16 bytes long
          -> IO AESCtx
newAESCtx mode key iv = do
  let keyLen :: Int
      keyLen = BS.length key * 8
      -- Pick the OpenSSL key-expansion routine for this direction.
      setKey = case mode of
        Encrypt -> _AES_set_encrypt_key
        Decrypt -> _AES_set_decrypt_key
  unless (keyLen `elem` [128, 192, 256]) $ fail "Bad AES key length"
  when (BS.length iv /= 16) $ fail "Bad AES128 iv length"
  -- 244 is sizeof(AES_KEY) as substituted by hsc2hs at build time.
  -- NOTE(review): it must match the OpenSSL headers actually linked.
  ctx <- mallocForeignPtrBytes 244
  withForeignPtr ctx $ \ctxPtr ->
    BS.useAsCStringLen key $ \(keyPtr, _) ->
      setKey keyPtr (fromIntegral keyLen) ctxPtr >>= failIf_ (/= 0)
  ivbytes <- mallocForeignPtrBytes 16
  ecounter <- mallocForeignPtrBytes 16
  nref <- newIORef 0
  -- The encrypted-counter block (CTR state) starts out zeroed.
  withForeignPtr ecounter $ \ecptr -> _memset ecptr 0 16
  -- Copy the 16 validated IV bytes into the context's own buffer.
  withForeignPtr ivbytes $ \ivPtr ->
    BS.useAsCStringLen iv $ \(ivSrc, _) ->
      () <$ _memcpy ivPtr ivSrc 16
  return (AESCtx ctx ivbytes ecounter nref mode)

-- | Encrypt or decrypt (according to the context's 'Mode') some number of
--   whole blocks using CBC. This is an IO function because the context is
--   destructively updated. The context's IV buffer is passed to
--   @AES_cbc_encrypt@ as its mutable @ivec@ argument, so successive calls
--   continue a single CBC chain.
--
--   Fails in 'IO' if the input length is not a multiple of 16 bytes.
aesCBC :: AESCtx         -- ^ context
       -> BS.ByteString  -- ^ input, must be multiple of block size (16 bytes)
       -> IO BS.ByteString
aesCBC (AESCtx ctx iv _ _ mode) input = do
  when (BS.length input `mod` 16 /= 0) $ fail "Bad input length to aesCBC"
  withForeignPtr ctx $ \ctxPtr ->
    withForeignPtr iv $ \ivPtr ->
      BS.useAsCStringLen input $ \(inPtr, len) ->
        -- CBC output is exactly as long as the (block-aligned) input.
        BSI.create len $ \out ->
          _AES_cbc_encrypt inPtr out (fromIntegral len) ctxPtr ivPtr (modeToInt mode)


{-# LINE 129 "OpenSSL/Cipher.hsc" #-}