-- |
-- Module      : Data.Memory.PtrMethods
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- Low-level methods to manipulate raw memory through pointers.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE UnboxedTuples #-}
{-# LANGUAGE ForeignFunctionInterface #-}
module Data.Memory.PtrMethods
    ( memCreateTemporary
    , memXor
    , memXorWith
    , memCopy
    , memSet
    , memReverse
    , memEqual
    , memConstEqual
    , memCompare
    ) where

import           Data.Memory.Internal.Imports
import           Foreign.Ptr              (Ptr, plusPtr)
import           Foreign.Storable         (peek, poke, peekByteOff)
import           Foreign.C.Types
import           Foreign.Marshal.Alloc    (allocaBytesAligned)
import           Data.Bits                ((.|.), xor)

-- | Run an action with a temporary buffer of @size@ bytes whose address
-- is aligned to 8 bytes.
--
-- The buffer comes from 'allocaBytesAligned', so it is released when the
-- action returns; the pointer must not escape the action.
memCreateTemporary :: Int -> (Ptr Word8 -> IO a) -> IO a
memCreateTemporary size = allocaBytesAligned size alignment
  where
    alignment :: Int
    alignment = 8

-- | XOR @n@ bytes from @s1@ and @s2@ and store the result in @d@:
--
-- > d[i] = s1[i] `xor` s2[i]    for 0 <= i < n
--
-- Neither @s1@ nor @s2@ is modified unless @d@ points to one of them.
-- A length of zero or less is a no-op.
memXor :: Ptr Word8 -> Ptr Word8 -> Ptr Word8 -> Int -> IO ()
memXor d s1 s2 n
    -- Guard with '<=' instead of matching the literal 0: a negative
    -- length would otherwise never reach the base case and would keep
    -- reading and writing past the ends of all three buffers.
    | n <= 0    = return ()
    | otherwise = do
        -- Both source bytes are read before the destination is written,
        -- so d may alias s1 or s2 exactly.
        (xor <$> peek s1 <*> peek s2) >>= poke d
        memXor (d `plusPtr` 1) (s1 `plusPtr` 1) (s2 `plusPtr` 1) (n - 1)

-- | For each of the first @bytes@ bytes, store @source[i] `xor` v@ into
-- @destination[i]@.
--
-- When @destination@ and @source@ are the same pointer the bytes are
-- updated in place. Regions that overlap without being identical are not
-- special-cased. Nothing is written when @bytes@ is zero or negative.
memXorWith :: Ptr Word8 -> Word8 -> Ptr Word8 -> Int -> IO ()
memXorWith destination !v source bytes
    | source /= destination = go destination source bytes
    | otherwise             = goInplace source bytes
  where
    -- Separate buffers: read from src, write the masked byte to dst.
    go :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
    go !dst !src remaining
        | remaining <= 0 = return ()
        | otherwise      = do
            byte <- peek src
            poke dst (v `xor` byte)
            go (dst `plusPtr` 1) (src `plusPtr` 1) (remaining - 1)

    -- Same buffer: each byte is read and then overwritten.
    goInplace :: Ptr Word8 -> Int -> IO ()
    goInplace !p remaining
        | remaining <= 0 = return ()
        | otherwise      = do
            byte <- peek p
            poke p (v `xor` byte)
            goInplace (p `plusPtr` 1) (remaining - 1)

-- | Copy @n@ bytes from @src@ to @dst@ using C @memcpy@.
--
-- The two regions must not overlap, as required by @memcpy@.
-- A length of zero or less copies nothing.
memCopy :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
memCopy dst src n
    -- A negative Int would wrap to a huge CSize in 'fromIntegral' and make
    -- memcpy run far past both buffers; treat it like the other
    -- length-taking functions in this module and do nothing.
    | n <= 0    = return ()
    | otherwise = c_memcpy dst src (fromIntegral n)
{-# INLINE memCopy #-}

-- | Fill @n@ bytes starting at @start@ with the byte value @v@, using C
-- @memset@. A length of zero or less writes nothing.
memSet :: Ptr Word8 -> Word8 -> Int -> IO ()
memSet start v n
    -- A negative Int would wrap to a huge CSize in 'fromIntegral' and make
    -- memset overwrite memory far past the buffer.
    | n <= 0    = return ()
    -- c_memset already returns IO (), so no result needs discarding.
    | otherwise = c_memset start v (fromIntegral n)
{-# INLINE memSet #-}

-- | Copy @n@ bytes from @s@ into @d@ in reverse order, so that
-- @d[i] = s[n - 1 - i]@ for @0 <= i < n@.
--
-- The two regions must not overlap. Nothing is written when @n@ is zero
-- or negative.
memReverse :: Ptr Word8 -> Ptr Word8 -> Int -> IO ()
memReverse d s n = go 0
  where
    -- i indexes the destination; the matching source byte is counted
    -- from the end of s.
    go :: Int -> IO ()
    go i
        | i >= n    = return ()
        | otherwise = do
            byte <- peekByteOff s (n - 1 - i) :: IO Word8
            poke (d `plusPtr` i) byte
            go (i + 1)

-- | Check whether the first @n@ bytes at @p1@ and @p2@ are equal.
--
-- Stops reading at the first differing byte. A length of zero or less
-- compares nothing and yields 'True'.
memEqual :: Ptr Word8 -> Ptr Word8 -> Int -> IO Bool
memEqual p1 p2 n = loop 0
  where
    loop :: Int -> IO Bool
    loop i
        -- '>=' rather than '==' so a negative length terminates
        -- immediately instead of scanning past both buffers.
        | i >= n    = return True
        | otherwise = do
            e <- (==) <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8)
            if e then loop (i + 1) else return False

-- | Lexicographically compare the first @n@ bytes at @p1@ and @p2@,
-- treating each byte as an unsigned 'Word8'.
--
-- Returns the ordering of the first differing byte, or 'EQ' when all
-- @n@ bytes match. A length of zero or less yields 'EQ'.
memCompare :: Ptr Word8 -> Ptr Word8 -> Int -> IO Ordering
memCompare p1 p2 n = loop 0
  where
    loop :: Int -> IO Ordering
    loop i
        -- '>=' rather than '==' so a negative length terminates
        -- immediately instead of scanning past both buffers.
        | i >= n    = return EQ
        | otherwise = do
            e <- compare <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8)
            if e == EQ then loop (i + 1) else return e

-- | Equality test for two memory regions of @n@ bytes, intended to run
-- in constant time.
--
-- Unlike 'memEqual', this always reads all @n@ bytes of both regions
-- before producing a result, even after a difference has been found.
-- The loop has no data-dependent early exit. A length of zero or less
-- yields 'True'.
memConstEqual :: Ptr Word8 -> Ptr Word8 -> Int -> IO Bool
memConstEqual p1 p2 n = loop 0 0
  where
    loop :: Int -> Word8 -> IO Bool
    loop i !acc
        -- '>=' rather than '==' so a negative length terminates
        -- immediately instead of scanning past both buffers.
        | i >= n    = return $! acc == 0
        | otherwise = do
            -- Differences are OR-ed into the accumulator instead of being
            -- branched on, so every byte is visited regardless of content.
            e <- xor <$> peekByteOff p1 i <*> (peekByteOff p2 i :: IO Word8)
            loop (i + 1) (acc .|. e)

-- | Raw binding to C @memset@: fill the given number of bytes at the
-- pointer with a byte value. The C function's @void *@ result is
-- discarded (declared here as @IO ()@).
--
-- NOTE(review): C's @memset@ takes its fill value as an @int@; passing a
-- 'Word8' relies on the FFI widening it, which is harmless because memset
-- only uses the value converted to @unsigned char@.
foreign import ccall unsafe "memset"
    c_memset :: Ptr Word8 -> Word8 -> CSize -> IO ()

-- | Raw binding to C @memcpy@. Arguments are destination, source and
-- byte count, in C's order. The regions must not overlap. The C
-- function's @void *@ result is discarded (declared here as @IO ()@).
foreign import ccall unsafe "memcpy"
    c_memcpy :: Ptr Word8 -> Ptr Word8 -> CSize -> IO ()