{-# LANGUAGE CPP                       #-}
{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE TypeFamilies              #-}
{-# LANGUAGE FlexibleContexts          #-}
{-# LANGUAGE FlexibleInstances         #-}
{-# LANGUAGE MultiParamTypeClasses     #-}
{-# LANGUAGE ScopedTypeVariables       #-}
{-# LANGUAGE ConstraintKinds           #-}

-- | This module exposes the low-level internal details of ciphers. Do
-- not import this module unless you want to implement a new cipher or
-- give a new implementation of an existing cipher.
module Raaz.Cipher.Internal
       (
         -- * Internals of a cipher.
         -- $cipherdoc$
         Cipher, CipherMode(..)
         -- ** Cipher implementation
       , CipherI(..), SomeCipherI(..)

         -- ** Stream ciphers.
         -- $streamcipher$
       , StreamCipher, makeCipherI
       , transform, transform'
       -- ** Unsafe encryption and decryption.
       -- $unsafecipher$
       --
       , unsafeEncrypt, unsafeDecrypt, unsafeEncrypt', unsafeDecrypt'

       ) where

import Control.Monad.IO.Class          (liftIO)
import Data.ByteString.Internal as IB
import Foreign.Ptr                     (castPtr)
import System.IO.Unsafe                (unsafePerformIO)

import Raaz.Core
import Raaz.Core.Util.ByteString as B

-- $cipherdoc$
--
-- Ciphers provide symmetric encryption in the raaz library and are
-- captured by the type class `Cipher`.  They are instances of the
-- class `Symmetric`, and the associated type `Key` captures all that
-- is required to determine the encryption and decryption process. In
-- most ciphers, this includes what is known as the _encryption key_
-- as well as the _initialisation vector_.
--
-- Instances of `Cipher` are only required to provide full-block
-- encryption/decryption algorithms.  Implementations are captured by
-- two types.
--
-- [`CipherI`:] Values of this type capture implementations of a
-- cipher.  This type is parameterised over the memory elements that
-- are used internally by the implementation.
--
-- [`SomeCipherI`:] The existentially quantified version of `CipherI`
-- over its memory elements. By wrapping the memory elements inside the
-- existential quantifier, values of this type expose only the
-- interface and not the internals of the implementation. The
-- `Implementation` associated type of a cipher is the type
-- `SomeCipherI`.
--
-- To support a new cipher, a developer needs to:
--
-- 1. Define a new type which captures the cipher. This type should be
--    an instance of the class `Cipher`.
--
-- 2. Define an implementation, i.e. a value of the type `SomeCipherI`.
--
-- 3. Define a recommended implementation, i.e. an instance of the
--    type class `Raaz.Core.Primitives.Recommendation`
--

-- $streamcipher$
--
-- Stream ciphers are a special class of ciphers which can encrypt
-- messages of any length (not necessarily multiples of the block
-- length). Typically, stream ciphers are obtained by xoring the data
-- with a stream of prg values that the stream cipher generates. As a
-- consequence, encryption and decryption are the same algorithm. One
-- can also use a stream cipher as a pseudo-random generator.
--
-- We have the class `StreamCipher` that captures valid stream ciphers.
--


-- | Modes of operation for block ciphers.
data CipherMode
  = CBC  -- ^ Cipher-block chaining
  | CTR  -- ^ Counter
  deriving (Eq, Show)

-- | The implementation of a block cipher. It is parameterised by the
-- cipher being implemented and by the memory elements used for
-- encryption (@encMem@) and for decryption (@decMem@).
data CipherI cipher encMem decMem = CipherI
  { cipherIName          :: String
    -- ^ Name of the implementation (used as its 'name').
  , cipherIDescription   :: String
    -- ^ Description of the implementation (used as its 'description').
  , encryptBlocks        :: Pointer -> BLOCKS cipher -> MT encMem ()
    -- ^ The underlying block encryption function.
  , decryptBlocks        :: Pointer -> BLOCKS cipher -> MT decMem ()
    -- ^ The underlying block decryption function.
  , cipherStartAlignment :: Alignment
    -- ^ Used as the 'bufferStartAlignment' of this implementation.
  }

-- | Constraints on the memory elements of a block cipher
-- implementation: both the encryption and the decryption memory must
-- be initialisable from the cipher's key.
type CipherM cipher encMem decMem =
  ( Primitive cipher
  , Initialisable encMem (Key cipher)
  , Initialisable decMem (Key cipher)
  ) -- TODO: More need initialisable from buffer.

-- | Some implementation of a block cipher. This type is existentially
-- quantified over the memory elements used by the implementation, so
-- that only its interface, and not its internals, is exposed.
data SomeCipherI cipher
  = forall encMem decMem . CipherM cipher encMem decMem
    => SomeCipherI (CipherI cipher encMem decMem)


-- The buffer alignment of a concrete implementation is the one stored
-- in its 'cipherStartAlignment' field.
instance BlockAlgorithm (CipherI cipher encMem decMem) where
  bufferStartAlignment impl = cipherStartAlignment impl

-- Name and description are read directly from the record fields.
instance Describable (CipherI cipher encMem decMem) where
  name impl        = cipherIName impl
  description impl = cipherIDescription impl


-- Unwrap the existential and delegate to the wrapped implementation.
instance Describable (SomeCipherI cipher) where
  name someImpl = case someImpl of
    SomeCipherI impl -> name impl
  description someImpl = case someImpl of
    SomeCipherI impl -> description impl

-- Unwrap the existential and delegate to the wrapped implementation.
instance BlockAlgorithm (SomeCipherI cipher) where
  bufferStartAlignment someImpl = case someImpl of
    SomeCipherI impl -> bufferStartAlignment impl


-- | Class capturing ciphers. An instance should provide encryption
-- and decryption algorithms for messages whose length is a multiple
-- of the block size; for such messages, encryption and decryption
-- must be inverses of each other. The implementations of a cipher
-- are values of type 'SomeCipherI'.
class ( Primitive cipher
      , Implementation cipher ~ SomeCipherI cipher
      , Describable cipher
      ) => Cipher cipher

-- | Class that captures stream ciphers. Every `StreamCipher` is a
-- `Cipher` that additionally satisfies the following:
--
-- 1. Encryption and decryption are the same algorithm.
--
-- 2. Encryption/decryption can be applied to a message of length @l@
--    even if @l@ is not a multiple of the block length.
--
-- 3. Encrypting the length-@l@ prefix of a message @m@ gives the
--    length-@l@ prefix of the encryption of @m@.
--
-- These conditions are not checked; whoever declares an instance is
-- responsible for ensuring that they hold.
class Cipher cipher => StreamCipher cipher


-- | Builds a `CipherI` from a single stream transformation function,
-- which is used for both encryption and decryption. Useful when
-- building the `Cipher` instance of a stream cipher.
makeCipherI :: String                                -- ^ name
            -> String                                -- ^ description
            -> (Pointer -> BLOCKS prim -> MT mem ()) -- ^ stream transformer
            -> Alignment                             -- ^ buffer starting alignment
            -> CipherI prim mem mem
makeCipherI implName implDesc streamFn alignment =
  CipherI { cipherIName          = implName
          , cipherIDescription   = implDesc
          , encryptBlocks        = streamFn
          , decryptBlocks        = streamFn
          , cipherStartAlignment = alignment
          }

------------------ Unsafe cipher operations ------------------------

-- $unsafecipher$
--
-- We expose some unsafe functions to encrypt and decrypt bytestrings.
-- These functions work correctly only if the input byte string has a
-- length which is a multiple of the block size of the cipher and
-- hence are unsafe to use as general methods of encryption and
-- decryption of data.  Use these functions for testing and
-- benchmarking and nothing else.
--
-- There are multiple ways to handle arbitrary sized strings like
-- padding, cipher block stealing etc. They are not exposed here
-- though.

-- | Encrypt the given `ByteString`. This function is unsafe because
-- it only works correctly when the input `ByteString` has a length
-- that is a multiple of the block length of the cipher. The output
-- covers only the whole blocks of the input (its length is rounded
-- down with 'atMost'), so any trailing partial block is dropped.
unsafeEncrypt' :: Cipher c
               => c                -- ^ The cipher to use
               -> Implementation c -- ^ The implementation to use
               -> Key c            -- ^ The key to use
               -> ByteString       -- ^ The string to encrypt.
               -> ByteString
unsafeEncrypt' c simp@(SomeCipherI imp) key bs = IB.unsafeCreate outLen fill
  where
    -- Number of complete blocks contained in the input.
    nBlocks = atMost (B.length bs) `asTypeOf` blocksOf 1 c
    -- Size of the output in bytes.
    outLen  = fromEnum (inBytes nBlocks)
    -- Encrypt in a scratch buffer allocated for this implementation,
    -- then copy the ciphertext into the output pointer.
    fill dest = allocBufferFor simp nBlocks $ \ scratch -> insecurely $ do
      initialise key
      liftIO $ unsafeNCopyToPointer nBlocks bs scratch
      encryptBlocks imp scratch nBlocks
      liftIO $ Raaz.Core.memcpy (destination $ castPtr dest) (source scratch) nBlocks

-- | Transforms a given bytestring using a stream cipher. We use the
-- name /transform/ instead of encrypt/decrypt because, for stream
-- ciphers, these operations are the same.
--
-- The input may have any length. It is copied into a scratch buffer
-- rounded up to a whole number of blocks ('atLeast') and transformed
-- there. Only the first @length bs@ bytes of the result are returned.
transform' :: StreamCipher c
           => c
           -> Implementation c
           -> Key c
           -> ByteString
           -> ByteString
-- The output is allocated with exactly @len@ bytes. This avoids the
-- block-rounded allocation followed by a trimming copy that
-- 'IB.createAndTrim' would do for non-block-aligned inputs.
-- 'unsafePerformIO' is used because the action allocates its own
-- output and scratch buffers and depends only on the arguments.
transform' c simp@(SomeCipherI imp) key bs =
  unsafePerformIO $ IB.create (fromIntegral len) action
  where
    -- Scratch size: the input length rounded up to whole blocks.
    blks = atLeast len `asTypeOf` blocksOf 1 c
    len  = B.length bs
    action ptr = allocBufferFor simp blks $ \ buf -> insecurely $ do
      initialise key
      liftIO $ unsafeCopyToPointer bs buf -- copy data into the buffer
      encryptBlocks imp buf blks          -- transform the padded blocks
      -- Copy back exactly the input's length to the output pointer.
      liftIO $ Raaz.Core.memcpy (destination (castPtr ptr)) (source buf) len

-- | Transform a given bytestring with a stream cipher, using the
-- cipher's recommended implementation.
transform :: (StreamCipher c, Recommendation c)
          => c
          -> Key c
          -> ByteString
          -> ByteString
transform c = transform' c impl
  where impl = recommended c



-- | Encrypt using the cipher's recommended implementation. This
-- function is unsafe because it only works correctly when the input
-- `ByteString` has a length that is a multiple of the block length of
-- the cipher.
unsafeEncrypt :: (Cipher c, Recommendation c)
              => c            -- ^ The cipher
              -> Key c        -- ^ The key to use
              -> ByteString   -- ^ The string to encrypt
              -> ByteString
unsafeEncrypt c = unsafeEncrypt' c impl
  where impl = recommended c

-- | Decrypts the given `ByteString`. This function is unsafe because
-- it only works correctly when the input `ByteString` has a length
-- that is a multiple of the block length of the cipher. The output
-- covers only the whole blocks of the input (its length is rounded
-- down with 'atMost'), so any trailing partial block is dropped.
unsafeDecrypt' :: Cipher c
               => c                -- ^ The cipher to use
               -> Implementation c -- ^ The implementation to use
               -> Key c            -- ^ The key to use
               -> ByteString       -- ^ The string to decrypt.
               -> ByteString
unsafeDecrypt' c simp@(SomeCipherI imp) key bs = IB.unsafeCreate outLen fill
  where
    -- Number of complete blocks contained in the input.
    nBlocks = atMost (B.length bs) `asTypeOf` blocksOf 1 c
    -- Size of the output in bytes.
    outLen  = fromEnum (inBytes nBlocks)
    -- Decrypt in a scratch buffer allocated for this implementation,
    -- then copy the plaintext into the output pointer.
    fill dest = allocBufferFor simp nBlocks $ \ scratch -> insecurely $ do
      initialise key
      liftIO $ unsafeNCopyToPointer nBlocks bs scratch
      decryptBlocks imp scratch nBlocks
      liftIO $ Raaz.Core.memcpy (destination $ castPtr dest) (source scratch) nBlocks

-- | Decrypt using the cipher's recommended implementation. This
-- function is unsafe because it only works correctly when the input
-- `ByteString` has a length that is a multiple of the block length of
-- the cipher.
unsafeDecrypt :: (Cipher c, Recommendation c)
              => c            -- ^ The cipher
              -> Key c        -- ^ The key to use
              -> ByteString   -- ^ The string to decrypt
              -> ByteString
unsafeDecrypt c = unsafeDecrypt' c impl
  where impl = recommended c