-- |
--
-- Module      : Raaz.Core.Memory
-- Description : Explicit, typesafe, low-level memory management in raaz
-- Copyright   : (c) Piyush P Kurur, 2019
-- License     : Apache-2.0 OR BSD-3-Clause
-- Maintainer  : Piyush P Kurur <ppk@iitpkd.ac.in>
-- Stability   : experimental
--

{-# LANGUAGE TypeFamilies               #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE RankNTypes                 #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE DataKinds                  #-}
module Raaz.Core.Memory
       (

         -- BANNED combinators
         --
         -- 1. copyMemory


         -- * Low level memory management in raaz.
         -- $memorysubsystem$

         -- ** The memory class
         Memory(..)
       , VoidMemory, withMemoryPtr
       , withMemory, withSecureMemory
         -- ** The allocator
       , Alloc
       , pointerAlloc

       -- * Initialisation and Extraction.
       -- $init-extract$

       , Initialisable(..), Extractable(..), modifyMem

       -- * Accessing the bytes directly
       -- $access$
       --
       , Access(..)
       , ReadAccessible(..), WriteAccessible(..), memTransfer
       -- * A basic memory cell.
       , MemoryCell, copyCell, withCellPointer, unsafeGetCellPointer

       ) where

import           Control.Exception           ( finally )
import           Foreign.Ptr                 ( castPtr )
import           Foreign.Storable            ( Storable )


import           Raaz.Core.Prelude
import           Raaz.Core.MonoidalAction
import           Raaz.Core.Types    hiding   ( zipWith       )
import           Raaz.Core.Types.Internal

-------------- BANNED FEATURES ---------------------------------------
--
-- This module has a lot of low level pointer gymnastics and hence
-- should be dealt with care. The following features are BANNED
-- and hence should never be exposed. Often they are subtle and can
-- be easily missed. Hence it is documented here.
--
-- * COPY BUG
--
-- ** Combinator:
--
-- >
-- > `copyMemory :: Memory mem => Dest mem -> Src mem -> IO ()
-- >
--
-- ** THE BUG. At first it looks like a useful, general function to
-- have, which is just a memcpy on the underlying pointers. For a
-- memory element we can easily get its pointer and size. However, this
-- has a very subtle bug. The actual data in certain memory elements,
-- like MemoryCells, has a runtime-dependent offset from the raw
-- pointer, and this offset can differ from one element to another. As
-- an example, consider two MemoryCells A and B of type `MemoryCell Word64`
-- and let us assume that the alignment restriction for both of these is
-- an 8-byte boundary. The allocation strategy for MemoryCell is the following.
--
-- (1) The size is 16 (using the atleastAligned function)
-- (2) The starting pointer is the next 8-byte aligned pointer from the
--     given pointer.
--
-- It is very well possible that on allocation A gets an 8-byte
-- aligned memory pointer internally and the nextAligned pointer would
-- be itself. However, B might not be aligned and hence the actual
-- pointer for B might have a non-zero offset from its raw
-- pointer. Clearly a memcpy from the associated raw pointers will
-- mean that the initial segment of A is lost to B.




-- $memorysubsystem$
--
-- __Warning:__ This module is pretty low level and should not be
-- needed in typical use cases. Only developers of protocols and
-- primitives might have a reason to look into this module.
--
-- The memory subsystem of raaz gives a relatively abstract and type
-- safe interface for performing low level size calculations and
-- pointer arithmetic. The two main components of this subsystem
-- are the class `Memory`, whose instances are essentially memory buffers that
-- are distinguished at the type level, and the type `Alloc`, which captures
-- the allocation strategies for these types.
--

------------------------ A memory allocator -----------------------

-- | The field used by allocators: from the start pointer of the
-- allocated buffer it computes the memory element.
type AllocField = Field (Ptr Word8)

-- | A memory allocator for the memory type @mem@. It pairs the number
-- of bytes that must be allocated with a recipe that builds @mem@ out
-- of the pointer to the allocated buffer. The `Applicative` instance
-- of @Alloc@ lets you assemble allocators for compound memory elements
-- from simpler ones, taking care of the size/offset bookkeeping.
type Alloc = TwistRF AllocField (BYTES Int)

-- | Make an allocator for a given memory type. The size @l@ is
-- converted to bytes with `atLeast`, and @memCreate@ is used to build
-- the memory element from the pointer to the allocated buffer.
makeAlloc :: LengthUnit l => l -> (Ptr Word8 -> mem) -> Alloc mem
makeAlloc l memCreate = TwistRF (WrapArrow memCreate) $ atLeast l

-- | Allocates a buffer of size @l@ and returns a pointer to it.
pointerAlloc :: LengthUnit l => l -> Alloc (Ptr Word8)
pointerAlloc l = makeAlloc l id

---------------------------------------------------------------------

-- | Types that hold some memory. Cryptographic primitives keep their
-- state in such memory elements, and applications often need to keep
-- that memory from being swapped out (think of private keys or
-- passwords). The abstraction therefore supports /securing/ memory:
-- on platforms that support memory locking, secured memory is never
-- swapped to disk, and it is overwritten with nonsense before being
-- freed.
--
-- Only a few basic memory elements, such as `MemoryCell`, are exposed
-- by the library; compound memory objects are usually built out of
-- them. The `Applicative` instance of `Alloc` keeps such instance
-- declarations short, as the instance for a pair of memory elements
-- shows:
--
-- > instance (Memory ma, Memory mb) => Memory (ma, mb) where
-- >
-- >    memoryAlloc             = (,) <$> memoryAlloc <*> memoryAlloc
-- >
-- >    unsafeToPointer (ma, _) =  unsafeToPointer ma
--
class Memory m where
  -- | The allocator (allocation strategy) for this memory type.
  memoryAlloc :: Alloc m
  -- | The pointer to the underlying buffer of this memory.
  unsafeToPointer :: m -> Ptr Word8


-- | A memory element that holds nothing. It still carries a pointer so
-- that product memories whose first component is a `VoidMemory` have a
-- pointer to return from `unsafeToPointer`.
newtype VoidMemory = VoidMemory { unVoidMemory :: Ptr Word8 }

--
-- DEVELOPER NOTE:
--
-- It might be tempting to define VoidMemory as follows.
--
-- >
-- > newtype VoidMemory = VoidMemory
-- >
--
-- However, this will lead to failure of memory instances of product
-- memories where the first component is VoidMemory. Imagine what
-- the member function unsafeToPointer of (VoidMemory,
-- SomeOtherMemory) would look like: there would be no pointer to return.
--
-- | A `VoidMemory` allocates zero bytes; its pointer is just the
-- pointer handed to it by the allocator.
instance Memory VoidMemory where
  memoryAlloc     = makeAlloc (0 :: BYTES Int) VoidMemory
  unsafeToPointer = unVoidMemory


-- | A pair of memories is allocated component by component; its
-- pointer is the pointer of the first component.
instance ( Memory ma, Memory mb ) => Memory (ma, mb) where
  memoryAlloc             = (,) <$> memoryAlloc <*> memoryAlloc
  unsafeToPointer (ma, _) = unsafeToPointer ma

-- | A triple of memories is allocated component by component; its
-- pointer is the pointer of the first component.
instance ( Memory ma
         , Memory mb
         , Memory mc
         )
         => Memory (ma, mb, mc) where
  memoryAlloc                = (,,)
                               <$> memoryAlloc
                               <*> memoryAlloc
                               <*> memoryAlloc
  unsafeToPointer (ma, _, _) = unsafeToPointer ma

-- | A quadruple of memories is allocated component by component; its
-- pointer is the pointer of the first component.
instance ( Memory ma
         , Memory mb
         , Memory mc
         , Memory md
         )
         => Memory (ma, mb, mc, md) where
  memoryAlloc                   = (,,,)
                                  <$> memoryAlloc
                                  <*> memoryAlloc
                                  <*> memoryAlloc
                                  <*> memoryAlloc
  unsafeToPointer (ma, _, _, _) = unsafeToPointer ma


-- | Apply some low level action on the underlying buffer of the
-- memory. The action is given the buffer size recorded by the
-- memory type's `memoryAlloc` and the pointer returned by
-- `unsafeToPointer`.
withMemoryPtr :: Memory m
              => (BYTES Int -> Ptr Word8 -> IO a)
              -> m -> IO a
withMemoryPtr action mem = action sz $ unsafeToPointer mem
  where sz :: BYTES Int
        sz = twistMonoidValue $ getAlloc mem
        -- Selects the allocator of the same memory type as @mem@.
        getAlloc :: Memory m => m -> Alloc m
        getAlloc _ = memoryAlloc

-- | Perform an action which makes use of this memory. The memory
-- allocated will automatically be freed when the action finishes
-- either gracefully or with some exception. Before that, the buffer
-- is wiped with `wipeMemory`, also on both paths. Besides being safer,
-- this method might be more efficient as the memory might be
-- allocated from the stack directly and will have very little GC
-- overhead.
withMemory   :: Memory mem => (mem -> IO a) -> IO a
withMemory   = withM memoryAlloc
  where withM :: Alloc mem -> (mem -> IO a) -> IO a
        withM alctr action = allocaBuffer sz actualAction
          where sz                 = twistMonoidValue alctr
                getM               = computeField $ twistFunctorValue alctr
                wipeIt cptr        = wipeMemory cptr sz
                -- Use `finally` rather than `<*` so that the buffer is
                -- wiped even when the action throws an exception.
                actualAction  cptr = action (getM cptr) `finally` wipeIt cptr


-- | Similar to `withMemory` but allocates a secure memory for the
-- action. Secure memories are never swapped on to disk and will be
-- wiped clean of sensitive data after use. However, be careful when
-- using this function in a child thread. Due to the daemonic nature
-- of Haskell threads, if the main thread exits before the child
-- thread is done with its job, sensitive data can leak. This is
-- essentially a limitation of the bracket which is used internally.
withSecureMemory :: Memory mem => (mem -> IO a) -> IO a
withSecureMemory = withSM memoryAlloc
  where withSM :: Alloc m -> (m -> IO b) -> IO b
        withSM alctr action = allocaSecure sz $ action . getM
          where sz   = twistMonoidValue alctr
                getM = computeField $ twistFunctorValue alctr


----------------------- Initialising and Extracting stuff ----------------------

-- $init-extract$
--
-- Memories often allow initialisation with and extraction of values
-- in the Haskell world. The `Initialisable` and `Extractable` class
-- captures this interface.
--
-- == Explicit Pointer
--
-- Using the `Initialisable` and `Extractable` interfaces for sensitive
-- data defeats one important purpose of the memory subsystem,
-- namely providing memory locking. Using these interfaces means
-- keeping the sensitive information as pure values in the Haskell
-- heap, which is impossible to lock. Worse still, the GC often moves the
-- data around, spreading it all over the memory. One should instead use
-- direct byte transfer via `memcpy` to effect such
-- initialisation. The type classes `ReadAccessible` and
-- `WriteAccessible` provide an interface for this: they give direct
-- access (via the `Access` buffer) to the portions of the internal
-- memory where sensitive data is kept.

-- | Memory elements that can be set from a pure Haskell value. That
-- value lives on the Haskell heap and so may get swapped out; avoid
-- this class whenever a leak of the initial value would be dangerous,
-- and use the `WriteAccessible` class instead.
class Memory m => Initialisable m v where
  initialise :: v -> m -> IO ()

-- | Memory elements whose contents can be read back as a pure Haskell
-- value. As with `Initialisable`, the extracted value may end up being
-- swapped out; if that is unacceptable, use the `ReadAccessible` class
-- instead.
class Memory m => Extractable m v where
  extract :: m -> IO v


-- | Apply the given function to the value in the memory element. For a
-- function @f :: b -> a@, the action @modifyMem f@ first extracts a
-- value of type @b@ from the memory element, applies @f@ to it and
-- puts the result back into the memory.
--
-- > modifyMem f mem = do b <- extract mem
-- >                      initialise (f b) mem
--
modifyMem :: (Initialisable mem a, Extractable mem b) =>  (b -> a) -> mem -> IO ()
modifyMem f mem = extract mem >>= flip initialise mem . f

-- $access$
--
-- To avoid the problems associated with the `Initialisable` and
-- `Extractable` interface, certain memory types give access to the
-- associated buffers directly via the `Access` buffer. Data then
-- needs to be transferred between these memories directly via
-- `memcpy` making use of the `Access` buffers thereby avoiding a copy
-- in the Haskell heap where it is prone to leak.
--
-- [`ReadAccessible`:] Instances of this class are memories that are
-- on the source side of the transfer. Examples include the memory
-- element that is used to implement a Diffie-Hellman key
-- exchange. The exchanged key is in the memory, which can then be used
-- to initialise a cipher for the actual transfer of encrypted data.
--
-- [`WriteAccessible`:] Instances of this class are memories that
-- are on the destination side of the transfer. The memory element
-- that stores the key for a cipher is an example of such an element.

-- | Data type that gives an access buffer to a portion of the memory.
data Access = Access
  { accessPtr  :: Ptr Word8
    -- ^ The buffer pointer associated with this access.
  , accessSize :: BYTES Int
    -- ^ The size of the buffer.
  }

-- | Transfer the bytes from the source memory to the destination
-- memory. The total bytes transferred is the minimum of the bytes
-- available at the source and the space available at the destination.
-- The source's `beforeReadAdjustment` is run before the copy and the
-- destination's `afterWriteAdjustment` after it.
memTransfer :: (ReadAccessible src, WriteAccessible dest)
            => Dest dest
            -> Src src
            -> IO ()
memTransfer dest src = do
  beforeReadAdjustment smem
  copyAccessList (writeAccess dmem) (readAccess smem)
  afterWriteAdjustment dmem
  where dmem = unDest dest
        smem = unSrc src


-- | Copy bytes from a list of source access buffers into a list of
-- destination access buffers, walking both lists in order (internal
-- function). Each step copies the minimum of the sizes of the current
-- destination and source buffers. It then continues with whatever
-- remains of the larger of the two (or with the next pair when they
-- were equal). Copying stops as soon as either list runs out, so the
-- total copied is the minimum of the total source and destination
-- sizes.
copyAccessList :: [Access] -> [Access] -> IO ()
copyAccessList (da:ds) (sa:ss) = copyChunk >> continue
  where dsize = accessSize da
        ssize = accessSize sa
        trans = min dsize ssize

        copyChunk = memcpy (destination $ accessPtr da)
                           (source      $ accessPtr sa)
                           trans

        -- The part of a buffer that is left after @trans@ bytes of it
        -- have been consumed.
        remainder acc = Access (accessPtr acc `movePtr` trans)
                               (accessSize acc - trans)

        continue = case compare dsize ssize of
          GT -> copyAccessList (remainder da : ds) ss
          LT -> copyAccessList ds (remainder sa : ss)
          EQ -> copyAccessList ds ss
copyAccessList _ _ = return ()

-- | Memory elements whose bytes can be read off directly from
-- (portions of) their underlying buffer.
class Memory mem => ReadAccessible mem where
  -- | Prepare the internal data before its bytes are read off the
  -- `readAccess` buffers. This is needed when, for example, the host
  -- machine's endianness differs from the standard byte order of the
  -- stored type.
  beforeReadAdjustment :: mem -> IO ()

  -- | The access buffers, in order, through which bytes may be read
  -- (once `beforeReadAdjustment` has been run).
  readAccess :: mem -> [Access]

-- | Memory elements that can be initialised by writing bytes directly
-- into (portions of) their underlying buffer.
class Memory mem => WriteAccessible mem where

  -- | The access buffers, in order, through which bytes may be written
  -- into the memory.
  writeAccess :: mem -> [Access]

  -- | Any further adjustment the memory needs, after the bytes have
  -- been written, before it counts as "initialised" with the
  -- sensitive data.
  afterWriteAdjustment :: mem -> IO ()

--------------------- Some instances of Memory --------------------

-- | A memory location to store a value of a type with a `Storable`
-- instance. The wrapped pointer is the raw pointer handed out by the
-- allocator. The value itself is stored at the aligned location given
-- by `unsafeGetCellPointer`, which may be offset from the raw pointer.
newtype MemoryCell a = MemoryCell { unMemoryCell :: Ptr a }


-- | A cell reserves `alignedSizeOf` bytes for its element type. This is
-- presumably enough slack to place the value at the next aligned
-- address; see `unsafeGetCellPointer`. The cell keeps the raw buffer
-- pointer.
instance Storable a => Memory (MemoryCell a) where

  -- The element type is passed as a 'Proxy' (instead of an 'undefined'
  -- value) because ScopedTypeVariables is not enabled here.
  memoryAlloc = allocator Proxy
    where allocator :: Storable b => Proxy b -> Alloc (MemoryCell b)
          allocator proxy = makeAlloc (alignedSizeOf proxy) $ MemoryCell . castPtr

  unsafeToPointer = castPtr . unMemoryCell

-- | The location where the element is actually stored: the
-- `nextLocation` of the cell's raw pointer. This pointer is guaranteed
-- to be aligned to the alignment restriction of @a@.
unsafeGetCellPointer :: Storable a => MemoryCell a -> Ptr a
unsafeGetCellPointer = nextLocation . unMemoryCell

-- | Work with the aligned pointer (`unsafeGetCellPointer`) of the
-- memory cell. Useful while working with ffi functions.
withCellPointer :: Storable a => (Ptr a -> IO b) -> MemoryCell a -> IO b
{-# INLINE withCellPointer #-}
withCellPointer action = action . unsafeGetCellPointer

-- | Copy the contents of one memory cell to another. The copy is done
-- between the aligned cell pointers (`unsafeGetCellPointer`), never
-- between the raw pointers, whose alignment offsets may differ. Exactly
-- @sizeOf@ bytes of the element type are copied.
copyCell :: Storable a => Dest (MemoryCell a) -> Src (MemoryCell a) -> IO ()
copyCell dest src = memcpy (unsafeGetCellPointer <$> dest)
                           (unsafeGetCellPointer <$> src)
                           (valueSize dest)
  where valueSize :: Storable b => Dest (MemoryCell b) -> BYTES Int
        valueSize = sizeOf . proxyOf
        proxyOf   :: Dest (MemoryCell b) -> Proxy b
        proxyOf _ = Proxy

-- | Initialise the cell by applying `pokeAligned` to its raw pointer.
instance Storable a => Initialisable (MemoryCell a) a where
  initialise a mem = pokeAligned (unMemoryCell mem) a
  {-# INLINE initialise #-}

-- | Extract the value by applying `peekAligned` to the cell's raw pointer.
instance Storable a => Extractable (MemoryCell a) a where
  extract = peekAligned . unMemoryCell
  {-# INLINE extract #-}


-- | A single access buffer covering the value stored in the cell. It
-- starts at the aligned pointer (`unsafeGetCellPointer`) and spans
-- @sizeOf@ bytes of the element type.
memCellToAccess :: EndianStore a => MemoryCell a -> [Access]
memCellToAccess mem = [ Access { accessPtr  = castPtr $ unsafeGetCellPointer mem
                               , accessSize = sizeOf $ proxyOf mem
                               }
                      ]
  where proxyOf   :: MemoryCell b -> Proxy b
        proxyOf _ = Proxy

-- | Before reading, `adjustEndian` is applied to the aligned cell
-- pointer with count 1 (presumably the single stored element). The
-- bytes are then read through `memCellToAccess`.
instance EndianStore a => ReadAccessible (MemoryCell a) where
  beforeReadAdjustment mem = adjustEndian (unsafeGetCellPointer mem) 1
  readAccess               = memCellToAccess

-- | Bytes are written through `memCellToAccess`. Afterwards,
-- `adjustEndian` is applied to the aligned cell pointer with count 1
-- (presumably the single stored element).
instance EndianStore a => WriteAccessible (MemoryCell a) where
  writeAccess              = memCellToAccess
  afterWriteAdjustment mem = adjustEndian (unsafeGetCellPointer mem) 1