-- |
-- Module      : Crypto.Data.AFIS
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Haskell implementation of the Anti-forensic information splitter
-- available in LUKS. <http://clemens.endorphin.org/AFsplitter>
--
-- The algorithm inflates an arbitrary secret with many extra bits, all of which
-- are needed to recover the key (merge). This makes it easier to permanently
-- destroy a key stored on disk: erasing any part of the diffused data is
-- enough to make the key unrecoverable.
--
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Data.AFIS
    ( split
    , merge
    ) where

import           Crypto.Hash
import           Crypto.Random.Types
import           Crypto.Internal.Compat
import           Control.Monad (forM_, foldM)
import           Data.Word
import           Data.Bits
import           Foreign.Storable
import           Foreign.Ptr

import           Crypto.Internal.ByteArray (ByteArray, Bytes, MemView(..))
import qualified Crypto.Internal.ByteArray as B

import           Data.Memory.PtrMethods (memSet, memCopy)

-- | Split data to diffused data, using a random generator and
-- a hash algorithm.
--
-- The output consists of @expandTimes@ blocks, each the size of the input.
-- The first @expandTimes-1@ blocks are random data. The last block is the
-- original data XORed with an accumulator built from those random blocks
-- through the hash diffuser:
--
-- > +--------+
-- > |  orig  |
-- > +--------+
-- >
-- > +--------+ +--------+     +--------------+
-- > | rand0  | | rand1  | ... |  orig ^ acc  |
-- > +--------+ +--------+     +--------------+
--
-- The accumulator is defined as follows. @diffuse@ rewrites a buffer chunk
-- by chunk; each digest-sized chunk is replaced by the hash of its
-- big-endian 32-bit index followed by the chunk itself.
--
-- > acc(0)   = 0
-- > acc(n+1) = diffuse (acc(n) ^ rand(n))
--
-- The last block uses @acc(expandTimes-1)@.
--
-- Calls 'error' when @expandTimes@ is less than 2.
split :: (ByteArray ba, HashAlgorithm hash, DRG rng)
      => hash  -- ^ Hash algorithm to use as diffuser
      -> rng   -- ^ Random generator to use
      -> Int   -- ^ Number of blocks to expand the data into (must be at least 2)
      -> ba    -- ^ original data to diffuse.
      -> (ba, rng)         -- ^ The diffused data and the advanced random generator
{-# NOINLINE split #-}
split hashAlg rng expandTimes src
    | expandTimes <= 1 = error "invalid expandTimes value"
    | otherwise        = unsafeDoIO $ do
        (rng', bs) <- B.allocRet diffusedLen runOp
        return (bs, rng')
  where
    blockSize   = B.length src
    diffusedLen = blockSize * expandTimes

    -- Fill the freshly allocated output buffer and return the advanced
    -- generator.
    runOp dstPtr = do
        -- The accumulator lives directly in the last output block.
        let lastBlock = dstPtr `plusPtr` (blockSize * (expandTimes - 1))
        memSet lastBlock 0 blockSize
        let randomBlockPtrs = map (plusPtr dstPtr . (*) blockSize) [0 .. expandTimes - 2]
        rng' <- foldM fillRandomBlock rng randomBlockPtrs
        -- acc(n+1) = diffuse (acc(n) ^ rand(n)), in block order.
        mapM_ (addRandomBlock lastBlock) randomBlockPtrs
        -- Finally mix the original data into the accumulator.
        B.withByteArray src $ \srcPtr -> xorMem srcPtr lastBlock blockSize
        return rng'

    -- One accumulator step: XOR a random block in, then diffuse.
    addRandomBlock lastBlock blockPtr = do
        xorMem blockPtr lastBlock blockSize
        diffuse hashAlg lastBlock blockSize

    -- Fill one output block with fresh random bytes, threading the generator.
    fillRandomBlock g blockPtr = do
        let (rand :: Bytes, g') = randomBytesGenerate blockSize g
        B.withByteArray rand $ \randPtr -> memCopy blockPtr randPtr blockSize
        return g'

-- | Merge previously diffused data back to the original data.
--
-- This is the inverse of 'split'. The input is viewed as @expandTimes@
-- consecutive blocks of equal size. An accumulator starts at zero; each of
-- the first @expandTimes-1@ blocks is XORed into it, and the result is run
-- through the hash diffuser. The final block XORed with the accumulator
-- gives the original data.
--
-- Calls 'error' in three cases: @expandTimes@ is not positive, the input
-- length is not a multiple of @expandTimes@, or the blocks would be empty.
merge :: (ByteArray ba, HashAlgorithm hash)
      => hash  -- ^ Hash algorithm used as diffuser
      -> Int   -- ^ Number of blocks the data was expanded into (the @expandTimes@ given to 'split')
      -> ba    -- ^ Diffused data
      -> ba    -- ^ Original data
{-# NOINLINE merge #-}
merge hashAlg expandTimes bs
    -- Checked first so the lazy quotRem below is never forced with a zero
    -- divisor (which would raise a divide-by-zero ArithException).
    | expandTimes <= 0  = error "invalid expandTimes value"
    | r /= 0            = error "diffused data not a multiple of expandTimes"
    | originalSize <= 0 = error "diffused data null"
    | otherwise         = B.allocAndFreeze originalSize $ \dstPtr ->
        B.withByteArray bs $ \srcPtr -> do
            let blockAt :: Int -> Ptr Word8
                blockAt i = srcPtr `plusPtr` (i * originalSize)
            memSet dstPtr 0 originalSize
            forM_ [0 .. expandTimes - 2] $ \i -> do
                xorMem (blockAt i) dstPtr originalSize
                diffuse hashAlg dstPtr originalSize
            xorMem (blockAt (expandTimes - 1)) dstPtr originalSize
  where
    -- Size of one block, and the remainder that must be zero.
    (originalSize, r) = B.length bs `quotRem` expandTimes

-- | In-place XOR of @sz@ bytes: @dst = src `xor` dst@.
--
-- When the length allows it, bytes are processed 8 (or 4) at a time through
-- 'Word64' (or 'Word32') loads and stores. The wide paths are taken only
-- when both pointers are aligned to the word size; otherwise a byte-wise
-- loop is used. This is because 'peek' and 'poke' may require properly
-- aligned addresses on some architectures, and callers pass pointers into
-- arbitrary user byte arrays. A non-positive size is a no-op.
xorMem :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
xorMem src dst sz
    | sz `mod` 64 == 0 && alignedTo 8 = loop 8 (castPtr src :: Ptr Word64) (castPtr dst) sz
    | sz `mod` 32 == 0 && alignedTo 4 = loop 4 (castPtr src :: Ptr Word32) (castPtr dst) sz
    | otherwise                       = loop 1 src dst sz
  where
    -- True when both buffers start on a multiple of @k@ bytes.
    alignedTo :: Int -> Bool
    alignedTo k = alignPtr src k == src && alignPtr dst k == dst

    -- Walk both buffers @incr@ bytes at a time until @n@ bytes are done.
    -- Stopping on @n <= 0@ (not just exactly 0) guarantees termination.
    loop :: (Storable w, Bits w) => Int -> Ptr w -> Ptr w -> Int -> IO ()
    loop incr s d n
        | n <= 0    = return ()
        | otherwise = do
            a <- peek s
            b <- peek d
            poke d (a `xor` b)
            loop incr (s `plusPtr` incr) (d `plusPtr` incr) (n - incr)

-- | Diffuse a buffer in place with a hash function.
--
-- The buffer is cut into digest-sized chunks, and the final chunk may be
-- shorter. Chunk @i@ is overwritten with the hash of the big-endian 32-bit
-- encoding of @i@ followed by the chunk's own bytes, truncated to the
-- chunk's length.
diffuse :: HashAlgorithm hash
        => hash      -- ^ Hash function to use as diffuser
        -> Ptr Word8 -- ^ buffer to diffuse, modify in place
        -> Int       -- ^ length of buffer to diffuse
        -> IO ()
diffuse hashAlg src sz = go 0 src
  where
    digestSize            = hashDigestSize hashAlg
    (fullChunks, tailLen) = sz `quotRem` digestSize

    -- Visit each whole chunk in order, then the short trailing one, if any.
    go :: Int -> Ptr Word8 -> IO ()
    go idx p
        | idx < fullChunks = do
            rehashChunk idx p digestSize
            go (idx + 1) (p `plusPtr` digestSize)
        | tailLen /= 0     = rehashChunk idx p tailLen
        | otherwise        = return ()

    -- Overwrite the @len@ bytes at @p@ with the first @len@ bytes of
    -- hash [BE32(idx), p .. p+len). The digest is forced with '$!' before
    -- the copy, because the 'MemView' reads the very bytes that the copy
    -- overwrites.
    rehashChunk idx p len = do
        let ctx0 = hashInitWith hashAlg
            ctx1 = hashUpdate ctx0 (be32 idx)
            ctx2 = hashUpdate ctx1 (MemView p len)
        digest <- return $! hashFinalize ctx2
        B.withByteArray digest $ \digestPtr -> memCopy p digestPtr len

    -- Big-endian 32-bit encoding of a chunk index.
    be32 :: Int -> Bytes
    be32 n = B.allocAndFreeze 4 $ \ptr ->
        forM_ (zip [0 ..] [24, 16, 8, 0]) $ \(off, sh) ->
            pokeByteOff ptr off (fromIntegral (n `shiftR` sh) :: Word8)