{-|

Some utility functions for byte strings.

-}

{-# LANGUAGE FlexibleContexts #-}
module Raaz.Core.Util.ByteString
       ( length, replicate
       , fromByteStringStorable
       , create, createFrom
       , withByteString
       , unsafeCopyToPointer
       , unsafeNCopyToPointer
       ) where

import           Prelude            hiding (length, replicate)
import qualified Data.ByteString    as B
import           Data.ByteString    (ByteString)
import qualified Data.ByteString.Internal as BI
import           Data.Word
import           Foreign.ForeignPtr (withForeignPtr)
import           Foreign.Ptr        (castPtr, plusPtr)
import           Foreign.Storable   (peek, sizeOf, Storable)

import           System.IO.Unsafe   (unsafePerformIO)

import           Raaz.Core.Types.Pointer
import           Raaz.Core.Types.Copying

-- | The length of a 'ByteString', wrapped in 'BYTES' so that it is a
-- type safe length rather than a bare 'Int'.
length :: ByteString -> BYTES Int
length bs = BYTES (B.length bs)

-- | A type safe version of 'B.replicate'. The length argument may be
-- given in any 'LengthUnit'; it is converted to a byte count with
-- 'inBytes', and the resulting bytestring consists of that many copies
-- of the given byte.
replicate :: LengthUnit l => l -> Word8 -> ByteString
replicate l = case inBytes l of
  BYTES sz -> B.replicate sz

-- | Copy the whole bytestring into the buffer pointed to by the given
-- pointer. Exactly as many bytes as the bytestring holds are copied,
-- so the behaviour is undefined if the destination area is smaller
-- than the bytestring.
unsafeCopyToPointer :: ByteString   -- ^ The source.
                    -> Pointer      -- ^ The destination.
                    -> IO ()
unsafeCopyToPointer bs cptr = withForeignPtr fptr transfer
  where
    (fptr, offset, len) = BI.toForeignPtr bs
    -- The bytestring's payload starts @offset@ bytes into its buffer
    -- and is @len@ bytes long.
    transfer basePtr = memcpy (destination $ castPtr cptr)
                              (source $ basePtr `plusPtr` offset)
                              (BYTES len)


-- | Like `unsafeCopyToPointer`, but copies only the first @n@ bytes
-- (with @n@ given in any type safe length unit). No bounds are
-- checked: the behaviour is undefined if the bytestring is shorter
-- than @n@ or if the destination buffer is smaller than @n@.
unsafeNCopyToPointer :: LengthUnit n
                       => n              -- ^ length of data to be copied
                       -> ByteString     -- ^ The source byte string
                       -> Pointer        -- ^ The buffer
                       -> IO ()
unsafeNCopyToPointer n bs cptr =
  withForeignPtr fptr $ \ basePtr ->
    memcpy (destination $ castPtr cptr)
           (source $ basePtr `plusPtr` offset)
           n
  where (fptr, offset, _) = BI.toForeignPtr bs

-- | Run an action on a pointer to the first byte of the bytestring's
-- contents. The action must only read through the pointer and must not
-- modify the memory it points to.
withByteString :: ByteString -> (Pointer -> IO a) -> IO a
withByteString bs f =
  withForeignPtr fptr $ \ basePtr -> f (basePtr `plusPtr` off)
  where (fptr, off, _) = BI.toForeignPtr bs

-- | Read a value of a 'Storable' type from the start of a bytestring
-- using `peek`.
--
-- The bytestring must hold at least @'sizeOf' k@ bytes. A shorter
-- bytestring raises an 'error' instead of reading past its end. Any
-- bytes after the first @'sizeOf' k@ are ignored.
--
-- NOTE(review): 'peek' may require an address aligned for @k@ on some
-- architectures, and the bytestring's data is not guaranteed to be so
-- aligned; confirm this is acceptable for the types used here.
fromByteStringStorable :: Storable k => ByteString -> k
fromByteStringStorable str
  | B.length str < sizeOf result =
      error "Raaz.Core.Util.ByteString.fromByteStringStorable: bytestring too short"
  | otherwise = result
  where
    -- 'sizeOf' does not use the value of its argument, so mentioning
    -- @result@ there only fixes the type and does not perform the read.
    result = unsafePerformIO $ withByteString str (peek . castPtr)


-- | The action @create l act@ allocates a bytestring of length @l@
-- (converted to bytes with 'inBytes'). Its contents are filled by
-- running @act@ on a pointer to the freshly allocated buffer.
create :: LengthUnit l => l -> (Pointer -> IO ()) -> IO ByteString
create l act = BI.create sz $ \ buf -> act (castPtr buf)
  where BYTES sz = inBytes l

-- | The IO action @createFrom n cptr@ creates a bytestring of length
-- @n@ whose contents are copied from the memory starting at @cptr@.
-- The source area must be at least @n@ long.
createFrom :: LengthUnit l => l -> Pointer -> IO ByteString
createFrom l cptr = create l $ \ buf ->
  memcpy (destination buf) (source cptr) l

----------------------  Hexadecimal encoding. -----------------------------------