-- | Module for reading from and writing into buffers.

{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE TypeSynonymInstances       #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE CPP                        #-}

module Raaz.Core.Transfer
       ( -- * Transfer actions.
         -- $transfer$

         -- ** Read action
         ReadM, ReadIO, bytesToRead, unsafeRead
       , readBytes, readInto

         -- ** Write action.
       ,  WriteM, WriteIO, bytesToWrite, unsafeWrite
       , write, writeStorable, writeVector, writeStorableVector
       , writeFrom, writeBytes
       , padWrite, prependWrite, glueWrites
       , writeByteString, skipWrite

       ) where

import           Control.Monad.IO.Class
import           Data.ByteString           (ByteString)
import           Data.String
import           Data.ByteString.Internal  (unsafeCreate)
import           Data.Monoid
import qualified Data.Vector.Generic       as G
import           Data.Word                 (Word8)
import           Foreign.Ptr               (castPtr, Ptr)
import           Foreign.Storable          ( Storable, poke )

import           Raaz.Core.MonoidalAction
import           Raaz.Core.Types.Copying
import           Raaz.Core.Types.Endian
import           Raaz.Core.Types.Pointer
import           Raaz.Core.Util.ByteString as BU
import           Raaz.Core.Encode

-- $transfer$
--
-- Low level buffer operations are problematic portions of any
-- crypto-library. Buffers are usually represented by the starting
-- pointer and one needs to keep track of the buffer sizes
-- carefully. An operation that writes into a buffer, if it writes
-- beyond the actual size of the buffer, can lead to a possible remote
-- code execution. On the other hand, when reading from a buffer, if
-- we read beyond the buffer it can leak private data to the attacker
-- (as in the case of the Heartbleed bug). This module is intended to
-- give a relatively high level interface to this problem. We expose
-- two types, `ReadM` and `WriteM`, which deal with these
-- two aspects. Both these actions keep track of the number of bytes
-- that they transfer.

-- Complex reads and writes can be constructed using the monoid
-- instance of these types.



-- | A monoid of result-less monadic actions. Combining two values
-- with `<>` (or `mappend`) runs the left action and then the right
-- one; `mempty` is the action that does nothing.
newtype TransferM m = TransferM { unTransferM :: m () }

#if MIN_VERSION_base(4,11,0)
instance Monad m => Semigroup (TransferM m) where
  TransferM earlier <> TransferM later = TransferM (earlier >> later)
#endif


instance Monad m => Monoid (TransferM m) where
  mempty = TransferM (return ())
  {-# INLINE mempty #-}

  mappend (TransferM earlier) (TransferM later) = TransferM (earlier >> later)
  {-# INLINE mappend #-}

  -- Runs every action in the list, from left to right.
  mconcat transfers = TransferM (mapM_ unTransferM transfers)
  {-# INLINE mconcat #-}

-- | An action that, given a pointer, transfers bytes starting at that
-- pointer. The transfer could be either a write or a read.
type TransferAction m = Pointer -> TransferM m

-- Acting with a byte offset on a transfer action first acts with the
-- same offset on the input pointer (using the `LAction` instance of
-- `Pointer`) and then runs the original action on the result.
instance LAction (BYTES Int) (TransferAction m) where
  offset <.> action = \ ptr -> action (offset <.> ptr)
  {-# INLINE (<.>) #-}

instance Monad m => Distributive (BYTES Int) (TransferAction m)

-- | Byte transfers that keep track of the number of bytes that were
-- transferred from (or into) their input buffer.
type Transfer m = SemiR (TransferAction m) (BYTES Int)

-- | Build a transfer from its length and the action to run on the
-- buffer pointer. The length is recorded in bytes.
makeTransfer :: LengthUnit u => u -> (Pointer -> m ()) -> Transfer m
{-# INLINE makeTransfer #-}
makeTransfer sz action = SemiR (\ ptr -> TransferM (action ptr)) (inBytes sz)


-------------------------- Monoid for writing stuff --------------------------------------

-- | A value of type @`WriteM` m@ is an action which, when run,
-- transfers bytes /into/ its input buffer. The type @`WriteM` m@ is a
-- monoid, so write actions can be concatenated with the `<>` operator.
newtype WriteM m = WriteM { unWriteM :: Transfer m }
#if MIN_VERSION_base(4,11,0)
                 deriving (Semigroup, Monoid)
#else
                 deriving Monoid
#endif


-- | A write action that runs in `IO`.
type WriteIO = WriteM IO

-- | The number of bytes that the given write action writes when it
-- is performed.
bytesToWrite :: WriteM m -> BYTES Int
bytesToWrite wr = semiRMonoid (unWriteM wr)

-- | Run the write action on the given buffer. No check is made that
-- the buffer can hold @`bytesToWrite` wr@ bytes; that is the caller's
-- responsibility.
unsafeWrite :: WriteM m
            -> Pointer   -- ^ The pointer for the buffer to be written into.
            -> m ()
unsafeWrite wr ptr = unTransferM (semiRSpace (unWriteM wr) ptr)

-- | Build a write action from the length it writes and the action to
-- run on the buffer pointer.
makeWrite     :: LengthUnit u => u -> (Pointer -> m ()) -> WriteM m
makeWrite sz action = WriteM (makeTransfer sz action)


-- | The expression @`writeStorable` a@ gives a write action that
-- stores the value @a@ in machine endian, using its `Storable`
-- instance (`poke`). Use it when exchanging data with C functions. Do
-- not use it for data meant for the outside world, because the byte
-- order would then depend on the machine. For endian-independent
-- serialisation use the `write` combinator instead.
writeStorable :: (MonadIO m, Storable a) => a -> WriteM m
writeStorable a = WriteM $ makeTransfer (sizeOf a) $ \ ptr ->
  liftIO (poke (castPtr ptr) a)

-- | The expression @`write` a@ gives a write action that stores the
-- value @a@ using its `EndianStore` instance (`store`). The type of
-- @a@ must be an instance of `EndianStore`, and the endian conversion
-- done by `store` applies whatever the machine endianness is. The
-- main use of this write is to serialise data for consumption by the
-- outside world.
write :: (MonadIO m, EndianStore a) => a -> WriteM m
write a = makeWrite (sizeOf a) storeIt
  where storeIt ptr = liftIO $ store (castPtr ptr) a

-- | The write action @writeFrom n src@ writes @n@ consecutive elements
-- taken from the buffer @src@ (via `copyToBytes`). It occupies @n@
-- times the size of one element.
writeFrom :: (MonadIO m, EndianStore a) => Int -> Src (Ptr a) -> WriteM m
writeFrom n src = makeWrite totalBytes copyIt
  where copyIt ptr   = liftIO $ copyToBytes (destination ptr) src n
        -- The `undefined` argument only fixes the element type for
        -- `sizeOf`; its value is not inspected.
        totalBytes   = elemBytes undefined src
        elemBytes :: Storable b => b -> Src (Ptr b) -> BYTES Int
        elemBytes x _ = toEnum n * sizeOf x

-- | The vector version of `writeStorable`: the elements of the vector
-- are written, from first to last, each with `writeStorable`.
writeStorableVector :: (Storable a, G.Vector v a, MonadIO m) => v a -> WriteM m
{-# INLINE writeStorableVector #-}
writeStorableVector = G.foldl' (\ acc x -> acc <> writeStorable x) mempty

{-

TODO: This function can be slow because each use of the semi-direct
product incurs a cost due to the lambda not being lifted.

-}

-- | The vector version of `write`: the elements of the vector are
-- written, from first to last, each with `write`.
writeVector :: (EndianStore a, G.Vector v a, MonadIO m) => v a -> WriteM m
{-# INLINE writeVector #-}
{- TODO: improve this using the fact that the size is known -}
writeVector = G.foldl' (\ acc x -> acc <> write x) mempty
{- TODO: Same as in writeStorableVector -}


-- | The combinator @writeBytes b n@ writes the byte @b@ into each byte
-- of the next @n@ (measured in the length unit of @n@), using
-- `memset`. Note the argument order: the byte comes first, then the
-- length.
writeBytes :: (LengthUnit n, MonadIO m) => Word8 -> n -> WriteM m
writeBytes w8 n = makeWrite n $ \ ptr -> liftIO (memset ptr w8 n)

{-
-- | The write action @padWriteTo w n wr@ is wr padded with the byte @w@ so that the total length
-- is n. If the total bytes written by @wr@ is greater than @n@ then this throws an error.
padWriteTo :: ( LengthUnit n, MonadIO m)
              => Word8     -- ^ the padding byte to use
              -> n         -- ^ the total length to pad to
              -> WriteM m  -- ^ the write that needs padding
              -> WriteM m
padWriteTo w8 n wrm | pl < 0    = error "padToLength: padding length smaller than total length"
                    | otherwise = wrm <> writeBytes w8 n
  where pl = inBytes n - bytesToWrite wrm

-}

-- | The combinator @glueWrites w n hdr ftr@ is equivalent to
-- @hdr <> glue <> ftr@. The write @glue@ writes just enough copies of
-- the byte @w@ (possibly none) for the total length to be a multiple
-- of the boundary @n@. When @hdr@ and @ftr@ together are already
-- aligned, no glue bytes are written. The alignment is computed with
-- `rem`, so @n@ must be non-zero.
glueWrites :: ( LengthUnit n, MonadIO m)
           =>  Word8    -- ^ The bytes to use in the glue
           -> n        -- ^ The length boundary to align to.
           -> WriteM m -- ^ The header write
           -> WriteM m -- ^ The footer write
           -> WriteM m
glueWrites w8 n hdr ftr = hdr <> writeBytes w8 lglue <> ftr
  where lhead   = bytesToWrite hdr
        lfoot   = bytesToWrite ftr
        -- Bytes by which header plus footer exceed the previous boundary.
        lexceed = (lhead + lfoot) `rem` nBytes
        -- Pad only when not already aligned. Using @nBytes - lexceed@
        -- unconditionally would add a whole extra block of glue when
        -- @lexceed@ is 0.
        lglue   = if lexceed > 0 then nBytes - lexceed else 0
        nBytes  = inBytes n



-- | The write action @prependWrite w n wr@ is @wr@ preceded by copies
-- of the byte @w@, chosen by `glueWrites` so that the total length
-- ends at a multiple of @n@.
prependWrite  :: ( LengthUnit n, MonadIO m)
              => Word8     -- ^ the byte to pre-pend with.
              -> n         -- ^ the length to align the message to
              -> WriteM m  -- ^ the message that needs pre-pending
              -> WriteM m
prependWrite w8 n wr = glueWrites w8 n mempty wr

-- | The write action @padWrite w n wr@ is @wr@ followed by copies of
-- the byte @w@, chosen by `glueWrites` so that the total length ends
-- at a multiple of @n@.
padWrite :: ( LengthUnit n, MonadIO m)
         => Word8     -- ^ the padding byte to use
         -> n         -- ^ the length to align message to
         -> WriteM m  -- ^ the message that needs padding
         -> WriteM m
padWrite w8 n wr = glueWrites w8 n wr mempty

-- | A write action that copies the bytes of a strict `ByteString` into
-- the buffer. It writes exactly @BU.length bs@ bytes.
writeByteString :: MonadIO m => ByteString -> WriteM m
writeByteString bs = makeWrite (BU.length bs) $ \ ptr ->
  liftIO (BU.unsafeCopyToPointer bs ptr)

-- | A write action that leaves the given length of the buffer
-- untouched. Its action does nothing, but the skipped length still
-- counts towards `bytesToWrite`.
skipWrite :: (LengthUnit u, Monad m) => u -> WriteM m
skipWrite sz = makeWrite sz (\ _ -> return ())

-- String literals can be used as write actions. The string is first
-- converted to a `ByteString` through its `IsString` instance and then
-- written with `writeByteString`.
-- NOTE(review): ByteString's IsString instance keeps only the low 8
-- bits of each Char, so non-Latin-1 characters are truncated — confirm
-- this is acceptable for callers.
instance MonadIO m => IsString (WriteM m)  where
  fromString str = writeByteString (fromString str)

-- Encoding a write action runs it into a freshly allocated
-- `ByteString` of exactly `bytesToWrite` bytes. Decoding a
-- `ByteString` gives the action that writes its bytes.
instance Encodable (WriteM IO) where
  {-# INLINE toByteString #-}
  toByteString w = unsafeCreate len fill
    where BYTES len = bytesToWrite w
          fill p    = unsafeWrite w (castPtr p)

  {-# INLINE unsafeFromByteString #-}
  unsafeFromByteString bs = writeByteString bs

  {-# INLINE fromByteString #-}
  fromByteString bs = Just (writeByteString bs)

------------------------  Read action ----------------------------

-- | The `ReadM` type captures the act of reading from a buffer and
-- possibly doing some action on the bytes read. Although inaccurate,
-- it helps to think of an element of `ReadM` as an action that, given
-- an input buffer, transfers data from it to some unspecified
-- destination.
--
-- Read actions form a monoid with the following semantics: if @r1@
-- and @r2@ are two read actions, then @r1 `<>` r2@ first performs the
-- read associated with @r1@ and then the read associated with @r2@.
newtype ReadM m = ReadM { unReadM :: Transfer m}
#if MIN_VERSION_base(4,11,0)
                 deriving (Semigroup, Monoid)
#else
                 deriving Monoid
#endif

-- | A read action that runs in `IO`.
type ReadIO = ReadM IO

-- | Function that explicitly constructs a read action from the length
-- it reads and the action to run on the buffer pointer.
makeRead     :: LengthUnit u => u -> (Pointer -> m ()) -> ReadM m
makeRead sz action = ReadM (makeTransfer sz action)


-- | The expression @bytesToRead r@ gives the total number of bytes
-- that performing the read action @r@ reads from its input buffer.
bytesToRead :: ReadM m -> BYTES Int
bytesToRead rd = semiRMonoid (unReadM rd)

-- | The action @unsafeRead r ptr@ reads @bytesToRead r@ bytes from the
-- buffer pointed to by @ptr@. It is unsafe because it does not (and
-- cannot) check whether it reads beyond what is legally stored at
-- @ptr@.
unsafeRead :: ReadM m
           -> Pointer   -- ^ The pointer for the buffer to be read from.
           -> m ()
unsafeRead rd ptr = unTransferM (semiRSpace (unReadM rd) ptr)

-- | The action @readBytes sz dptr@ gives a read action which, if run
-- on an input buffer, transfers @sz@ worth of data (using `memcpy`)
-- to the destination buffer pointed to by @dptr@. The user must make
-- sure that @dptr@ has room for @sz@ units of data when the read
-- action is executed.
readBytes :: ( LengthUnit sz, MonadIO m)
          => sz             -- ^ how much to read.
          -> Dest Pointer   -- ^ buffer to read the bytes into
          -> ReadM m
readBytes sz dest = makeRead sz copyIt
  where copyIt ptr = liftIO $ memcpy dest (source ptr) sz

-- | The action @readInto n dptr@ gives a read action which, if run on
-- an input buffer, transfers @n@ elements of type @a@ (via
-- `copyFromBytes`) into the buffer pointed to by @dptr@. In
-- particular, when @a@ is `Word8`, @readInto n dptr@ is the same as
-- @readBytes (fromIntegral n :: BYTES Int) dptr@.
readInto :: (EndianStore a, MonadIO m)
         => Int             -- ^ how many elements to read.
         -> Dest (Ptr a)    -- ^ buffer to read the elements into
         -> ReadM m
readInto n dest = makeRead totalBytes copyIt
  where copyIt ptr   = liftIO $ copyFromBytes dest (source ptr) n
        -- The `undefined` argument only fixes the element type for
        -- `sizeOf`; its value is not inspected.
        totalBytes   = elemBytes undefined dest
        elemBytes :: Storable b => b -> Dest (Ptr b) -> BYTES Int
        elemBytes x _ = toEnum n * sizeOf x