{-# LANGUAGE ScopedTypeVariables #-}
module Crypto.Data.AFIS
( split
, merge
) where
import Crypto.Hash
import Crypto.Random.Types
import Crypto.Internal.Compat
import Control.Monad (forM_, foldM)
import Data.Word
import Data.Bits
import Foreign.Storable
import Foreign.Ptr
import Crypto.Internal.ByteArray (ByteArray, Bytes, MemView(..))
import qualified Crypto.Internal.ByteArray as B
import Data.Memory.PtrMethods (memSet, memCopy)
-- | Split data into @expandTimes@ blocks, each as long as the input.
--
-- The first @expandTimes - 1@ blocks are filled with bytes drawn from the
-- random generator. The last block is the input XORed with an accumulator
-- that starts at all-zero and, for every random block in order, is XORed
-- with that block and then diffused in place with the hash algorithm:
--
-- > acc(0)   = 0
-- > acc(k+1) = diffuse (acc(k) `xor` rand(k))
-- > last     = src `xor` acc(expandTimes - 1)
--
-- Fails with 'error' when @expandTimes <= 1@.
split :: (ByteArray ba, HashAlgorithm hash, DRG rng)
      => hash      -- ^ hash algorithm used as the diffuser
      -> rng       -- ^ random generator supplying the filler blocks
      -> Int       -- ^ number of output blocks (must be greater than 1)
      -> ba        -- ^ data to split
      -> (ba, rng) -- ^ diffused data and the advanced generator
{-# NOINLINE split #-}
split hashAlg rng expandTimes src
    | expandTimes <= 1 = error "invalid expandTimes value"
    | otherwise        = unsafeDoIO $ do
        (rng', diffused) <- B.allocRet (blockLen * expandTimes) writeBlocks
        return (diffused, rng')
  where
    blockLen = B.length src

    -- Fill the freshly allocated output buffer; yields the generator state
    -- after all random blocks have been drawn.
    writeBlocks outPtr = do
        let blockAt :: Int -> Ptr Word8
            blockAt k = outPtr `plusPtr` (k * blockLen)
            accPtr    = blockAt (expandTimes - 1)
            randPtrs  = map blockAt [0 .. expandTimes - 2]
        -- the accumulator lives in the final block and starts at zero
        memSet accPtr 0 blockLen
        rng' <- foldM fillRandom rng randPtrs
        forM_ randPtrs $ \randPtr -> do
            xorMem randPtr accPtr blockLen
            diffuse hashAlg accPtr blockLen
        B.withByteArray src $ \srcPtr -> xorMem srcPtr accPtr blockLen
        return rng'

    -- Copy one block's worth of random bytes into @dst@ and hand back the
    -- generator that remains after drawing them.
    fillRandom g dst = do
        let (bytes :: Bytes, g') = randomBytesGenerate blockLen g
        B.withByteArray bytes $ \bytesPtr -> memCopy dst bytesPtr blockLen
        return g'
-- | Merge data previously diffused by 'split' back into the original data.
--
-- The diffused input is treated as @expandTimes@ equal blocks. An
-- accumulator starting at zero is XORed with each of the first
-- @expandTimes - 1@ blocks, being diffused with the hash after each one;
-- the result is finally XORed with the last block, which recovers the
-- original data. An @expandTimes@ of 1 is accepted and simply copies the
-- input.
--
-- Fails with 'error' when @expandTimes <= 0@, when the input length is
-- not a multiple of @expandTimes@, or when the blocks would be empty.
merge :: (ByteArray ba, HashAlgorithm hash)
      => hash -- ^ hash algorithm that was used as diffuser by 'split'
      -> Int  -- ^ number of blocks the data was split into
      -> ba   -- ^ diffused data
      -> ba   -- ^ original data
{-# NOINLINE merge #-}
merge hashAlg expandTimes bs
    -- Checked first so that a zero block count never reaches the lazily
    -- evaluated `quotRem` below (which would throw divide-by-zero).
    | expandTimes <= 0  = error "invalid expandTimes value"
    | r /= 0            = error "diffused data not a multiple of expandTimes"
    | originalSize <= 0 = error "diffused data null"
    | otherwise         = B.allocAndFreeze originalSize $ \dstPtr ->
        B.withByteArray bs $ \srcPtr -> do
            let blockAt :: Int -> Ptr Word8
                blockAt i = srcPtr `plusPtr` (i * originalSize)
            memSet dstPtr 0 originalSize
            forM_ [0 .. expandTimes - 2] $ \i -> do
                xorMem (blockAt i) dstPtr originalSize
                diffuse hashAlg dstPtr originalSize
            xorMem (blockAt (expandTimes - 1)) dstPtr originalSize
  where
    (originalSize, r) = len `quotRem` expandTimes
    len               = B.length bs
-- | In-place XOR of @sz@ bytes: @dst[i] <- src[i] `xor` dst[i]@ for every
-- @0 <= i < sz@.
--
-- The work is done a machine word at a time when the length is a whole
-- number of words and both pointers satisfy that word type's alignment.
-- 'Storable' only promises 'peek'/'poke' on suitably aligned addresses, and
-- the source may be an arbitrary slice supplied by the caller. Otherwise
-- the loop falls back to single bytes. The result is the same whichever
-- width is used. A non-positive @sz@ leaves @dst@ untouched.
xorMem :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
xorMem src dst sz
    | usable (undefined :: Word64) = loop (castPtr src :: Ptr Word64) (castPtr dst) sz
    | usable (undefined :: Word32) = loop (castPtr src :: Ptr Word32) (castPtr dst) sz
    | otherwise                    = loop src dst sz
  where
    -- A word type can be used when the length is a multiple of its size
    -- and both buffers are aligned for it (the argument is never forced).
    usable :: Storable w => w -> Bool
    usable w = sz `mod` sizeOf w == 0
            && aligned (alignment w) src
            && aligned (alignment w) dst

    aligned :: Int -> Ptr Word8 -> Bool
    aligned a p = alignPtr p a == p

    -- XOR one element per step; @n@ counts the bytes still to process.
    loop :: (Storable w, Bits w) => Ptr w -> Ptr w -> Int -> IO ()
    loop s d n
        | n <= 0    = return ()
        | otherwise = do
            a <- peek s
            b <- peek d
            poke d (a `xor` b)
            let step = sizeOf a
            loop (s `plusPtr` step) (d `plusPtr` step) (n - step)
-- | Diffuse a buffer in place with a hash function.
--
-- The buffer is cut into consecutive chunks of the hash's digest size;
-- the final chunk may be shorter. Chunk number @i@ is overwritten with the
-- first @len@ bytes of @hash (be32 i ++ chunk)@, where @len@ is that
-- chunk's own length.
diffuse :: HashAlgorithm hash
        => hash      -- ^ hash algorithm used as the diffuser
        -> Ptr Word8 -- ^ buffer to diffuse; overwritten in place
        -> Int       -- ^ length of the buffer in bytes
        -> IO ()
diffuse hashAlg buf bufLen =
    forM_ chunks $ \(idx, chunkLen) -> do
        let chunkPtr = buf `plusPtr` (idx * digestSize) :: Ptr Word8
        -- Force the digest before copying: it reads exactly the bytes
        -- that the copy below overwrites.
        digest <- return $! hashChunk idx chunkPtr chunkLen
        B.withByteArray digest $ \digestPtr -> memCopy chunkPtr digestPtr chunkLen
  where
    digestSize = hashDigestSize hashAlg

    (fullCount, tailLen) = bufLen `quotRem` digestSize

    -- (index, length) of every chunk: whole digests first, then the
    -- shorter remainder if there is one.
    chunks = [ (idx, digestSize) | idx <- [0 .. fullCount - 1] ]
          ++ [ (fullCount, tailLen) | tailLen /= 0 ]

    -- Hash of [ BE32(idx), ptr .. ptr+len ).
    hashChunk idx ptr len =
        let ctx0 = hashInitWith hashAlg
            ctx1 = hashUpdate ctx0 (be32 idx)
            ctx2 = hashUpdate ctx1 (MemView ptr len)
        in  hashFinalize ctx2
-- | Encode the low 32 bits of an 'Int' as 4 big-endian bytes
-- (most significant byte first).
be32 :: Int -> Bytes
be32 n = B.allocAndFreeze 4 $ \(ptr :: Ptr Word8) ->
    forM_ [0 .. 3] $ \off ->
        pokeByteOff ptr off (byteAt (3 - off))
  where
    -- byte @k@ counted from the least significant end (truncating)
    byteAt :: Int -> Word8
    byteAt k = fromIntegral (n `shiftR` (8 * k))