{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Data.AFIS
( split
, merge
) where
import Crypto.Hash
import Crypto.Random.Types
import Crypto.Internal.Compat
import Control.Monad (forM_, foldM)
import Data.Word
import Data.Bits
import Foreign.Storable
import Foreign.Ptr
import Crypto.Internal.ByteArray (ByteArray, Bytes, MemView(..))
import qualified Crypto.Internal.ByteArray as B
import Data.Memory.PtrMethods (memSet, memCopy)
-- | Split data into diffused data, using a random generator and a hash
-- algorithm as diffuser (anti-forensic information splitting).
--
-- The result holds @expandTimes@ blocks, each as long as the input. The
-- first @expandTimes - 1@ blocks are fresh random bytes; the last block is
-- the input XORed with an accumulator folded over the random blocks:
--
-- > acc(0)   = 0
-- > acc(k+1) = diffuse (acc(k) `xor` rand(k))
-- > last     = orig `xor` acc(expandTimes - 1)
--
-- Calls 'error' when @expandTimes <= 1@.
split :: (ByteArray ba, HashAlgorithm hash, DRG rng)
      => hash      -- ^ Hash algorithm to use as diffuser
      -> rng       -- ^ Random generator to use
      -> Int       -- ^ Number of times to diffuse the data
      -> ba        -- ^ Original data to diffuse
      -> (ba, rng) -- ^ Diffused data and the advanced generator
{-# NOINLINE split #-}
split hashAlg rng expandTimes src
    | expandTimes <= 1 = error "invalid expandTimes value"
    | otherwise        = unsafeDoIO $ do
        -- The IO only writes into the buffer allocated right here and reads
        -- 'src'; the outcome depends solely on the arguments.
        (rng', diffused) <- B.allocRet (blockSize * expandTimes) build
        return (diffused, rng')
  where
    blockSize = B.length src

    -- Populate the output buffer; yields the generator after all draws.
    build outPtr = do
        let accPtr   = outPtr `plusPtr` (blockSize * (expandTimes - 1)) :: Ptr Word8
            randPtrs = [ outPtr `plusPtr` (blockSize * k)
                       | k <- [0 .. expandTimes - 2] ] :: [Ptr Word8]
        memSet accPtr 0 blockSize
        rngOut <- foldM (absorb accPtr) rng randPtrs
        B.withByteArray src $ \srcPtr -> xorMem srcPtr accPtr blockSize
        return rngOut

    -- Draw one random block into @blockPtr@, then fold it into the
    -- accumulator at @accPtr@ (xor followed by diffusion).
    absorb accPtr gen blockPtr = do
        let (rand :: Bytes, gen') = randomBytesGenerate blockSize gen
        B.withByteArray rand $ \randPtr -> memCopy blockPtr randPtr blockSize
        xorMem blockPtr accPtr blockSize
        diffuse hashAlg accPtr blockSize
        return gen'
-- | Merge previously diffused data back into the original data.
--
-- The input must consist of @expandTimes@ equally sized blocks. The
-- accumulator is rebuilt from the first @expandTimes - 1@ blocks (xor, then
-- diffuse, per block) and finally XORed with the last block.
--
-- Calls 'error' when @expandTimes <= 0@, when the input length is not a
-- multiple of @expandTimes@, or when the resulting blocks would be empty.
merge :: (ByteArray ba, HashAlgorithm hash)
      => hash -- ^ Hash algorithm used as diffuser
      -> Int  -- ^ Number of blocks the data was split into (the @expandTimes@ given to 'split')
      -> ba   -- ^ Diffused data
      -> ba   -- ^ Original data
{-# NOINLINE merge #-}
merge hashAlg expandTimes bs
    -- Checked first: the 'quotRem' below would otherwise divide by zero.
    | expandTimes <= 0  = error "invalid expandTimes value"
    | r /= 0            = error "diffused data not a multiple of expandTimes"
    | originalSize <= 0 = error "diffused data null"
    | otherwise         = B.allocAndFreeze originalSize $ \dstPtr ->
        B.withByteArray bs $ \srcPtr -> do
            memSet dstPtr 0 originalSize
            forM_ [0 .. expandTimes - 2] $ \i -> do
                xorMem (srcPtr `plusPtr` (i * originalSize)) dstPtr originalSize
                diffuse hashAlg dstPtr originalSize
            xorMem (srcPtr `plusPtr` ((expandTimes - 1) * originalSize)) dstPtr originalSize
  where
    len               = B.length bs
    (originalSize, r) = len `quotRem` expandTimes
-- | In-place XOR of @sz@ bytes: @dst[k] = src[k] `xor` dst[k]@.
--
-- Works a 64- or 32-bit word at a time when the size is a multiple of the
-- word width and both pointers are aligned for that word; otherwise it falls
-- back to single bytes. XOR is bytewise, so every path gives the same result.
-- A size of zero or less is a no-op.
xorMem :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
xorMem src dst sz
    | fits 8    = loop (castPtr src :: Ptr Word64) (castPtr dst) (sz `quot` 8)
    | fits 4    = loop (castPtr src :: Ptr Word32) (castPtr dst) (sz `quot` 4)
    | otherwise = loop src dst sz
  where
    -- Wide access is only safe when the length divides evenly and neither
    -- pointer is misaligned (buffers may be slices at arbitrary offsets).
    fits :: Int -> Bool
    fits width = sz `rem` width == 0 && aligned width src && aligned width dst

    aligned :: Int -> Ptr Word8 -> Bool
    aligned width p = alignPtr p width == p

    -- XOR @n@ consecutive elements; an empty range (n <= 0) does nothing.
    loop :: (Storable w, Bits w) => Ptr w -> Ptr w -> Int -> IO ()
    loop s d n = forM_ [0 .. n - 1] $ \k -> do
        a <- peekElemOff s k
        b <- peekElemOff d k
        pokeElemOff d k (a `xor` b)
-- | Diffuse a buffer in place using a hash function.
--
-- The buffer is cut into consecutive chunks of the digest size (the last
-- chunk may be shorter). Chunk @i@ (counting from 0) is overwritten with the
-- leading bytes of @hash (BE32(i) ++ chunk)@, as many as the chunk is long.
diffuse :: HashAlgorithm hash
        => hash      -- ^ Hash function to use as diffuser
        -> Ptr Word8 -- ^ Buffer to diffuse, modified in place
        -> Int       -- ^ Length of the buffer to diffuse
        -> IO ()
diffuse hashAlg src sz = go 0 src
  where
    digestSize            = hashDigestSize hashAlg
    (fullChunks, tailLen) = sz `quotRem` digestSize

    -- Walk the whole chunks, then the trailing partial chunk if there is one.
    go :: Int -> Ptr Word8 -> IO ()
    go idx ptr
        | idx < fullChunks = do
            rehashChunk idx ptr digestSize
            go (idx + 1) (ptr `plusPtr` digestSize)
        | tailLen /= 0     = rehashChunk idx ptr tailLen
        | otherwise        = return ()

    -- Replace @len@ bytes at @ptr@ with the prefix of hash(BE32(idx) ++ bytes).
    rehashChunk :: Int -> Ptr Word8 -> Int -> IO ()
    rehashChunk idx ptr len = do
        -- The digest reads the chunk through a raw 'MemView'; force it to
        -- WHNF before the very same bytes are overwritten below.
        digest <- return $! chunkDigest idx ptr len
        B.withByteArray digest $ \digestPtr -> memCopy ptr digestPtr len

    chunkDigest idx ptr len =
        let ctx = hashUpdate (hashInitWith hashAlg) (be32 idx)
        in  hashFinalize (hashUpdate ctx (MemView ptr len))

    -- Big-endian encoding of the low 32 bits of the chunk counter.
    be32 :: Int -> Bytes
    be32 n = B.allocAndFreeze 4 $ \ptr ->
        forM_ [0 .. 3] $ \k ->
            pokeByteOff ptr k (fromIntegral (n `shiftR` (24 - 8 * k)) :: Word8)