{-# LANGUAGE ConstraintKinds, RankNTypes #-}
module Data.SensitiveBytes.Internal
( withSecureMemory
, WithSecureMemory
, SodiumInitialised
, SecureMemoryInitException
, SensitiveBytes (..)
, allocate
, free
, unsafePtr
, resized
, withSensitiveBytes
, SensitiveBytesAllocException
) where
import Control.Exception.Safe (Exception, MonadMask, bracket, throwIO)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Data.ByteArray (ByteArrayAccess (length, withByteArray))
import Data.Reflection (Given, give, given)
import Foreign.Ptr (Ptr, castPtr, nullPtr)
import Libsodium (sodium_free, sodium_init, sodium_malloc, sodium_memzero)
-- | Constraint carried by code that needs libsodium's secure-memory
-- allocator ('allocate', 'free', 'withSensitiveBytes'). It is brought into
-- scope by 'withSecureMemory'.
type WithSecureMemory = Given SodiumInitialised

-- | Token handed to 'give' by 'withSecureMemory' once @sodium_init@ has
-- reported success. Its constructor is not exported from this module.
data SodiumInitialised = SodiumInitialised
-- | Initialise libsodium, then run an action that may use secure memory.
--
-- @sodium_init@ is called every time; a result of 0 (freshly initialised)
-- or 1 (already initialised) is accepted. Any other result raises
-- 'SodiumInitFailed'. On success the action runs with 'WithSecureMemory'
-- in scope.
withSecureMemory
  :: forall m r. MonadIO m
  => (WithSecureMemory => m r)
  -> m r
withSecureMemory act = do
  status <- liftIO sodium_init
  if status == 0 || status == 1
    then give SodiumInitialised act
    else liftIO (throwIO SodiumInitFailed)
-- | Raised by 'withSecureMemory' when @sodium_init@ returns something
-- other than 0 or 1.
data SecureMemoryInitException = SodiumInitFailed

instance Show SecureMemoryInitException where
  -- Precedence is irrelevant for a fixed message.
  showsPrec _ SodiumInitFailed =
    showString "Failed to initialise a secure memory region"

instance Exception SecureMemoryInitException
-- | A raw memory buffer together with its capacity and logical length.
--
-- Buffers built by 'allocate' point at memory from @sodium_malloc@ and
-- start with 'allocSize' equal to 'dataSize'. The type parameter @s@ is
-- phantom: no field mentions it.
data SensitiveBytes s = SensitiveBytes
  { allocSize :: Int
    -- ^ Capacity in bytes: the size originally requested.
  , dataSize :: Int
    -- ^ Logical length in bytes, reported by the 'ByteArrayAccess'
    -- instance; changed with 'resized'.
  , bufPtr :: Ptr ()
    -- ^ Start of the underlying region.
  }
-- | The reported length is the logical 'dataSize', not the 'allocSize'
-- capacity. 'withByteArray' hands out the raw buffer pointer, cast to
-- the caller's element type.
instance ByteArrayAccess (SensitiveBytes s) where
  length (SensitiveBytes _ used _) = used
  withByteArray (SensitiveBytes _ _ ptr) k = k (castPtr ptr)
-- | The raw pointer behind a buffer.
--
-- Unsafe because nothing ties the pointer's lifetime to the buffer.
-- Once the buffer has been passed to 'free' (which calls @sodium_free@),
-- the pointer dangles.
unsafePtr :: SensitiveBytes s -> Ptr ()
unsafePtr (SensitiveBytes _ _ ptr) = ptr
-- | Allocate @size@ bytes with @sodium_malloc@.
--
-- The result has both 'allocSize' and 'dataSize' equal to @size@.
-- Throws 'SodiumMallocFailed' when @sodium_malloc@ returns a null pointer.
-- The buffer is not released automatically; pair this with 'free', or use
-- 'withSensitiveBytes'.
allocate
  :: forall s m. (MonadIO m, WithSecureMemory)
  => Int
  -> m (SensitiveBytes s)
allocate size =
  requiringSecureMemory (liftIO (sodium_malloc (fromIntegral size) >>= wrap))
  where
    -- Turn the allocator's result into a buffer, or fail on NULL.
    wrap ptr
      | ptr == nullPtr = throwIO SodiumMallocFailed
      | otherwise = pure (SensitiveBytes size size ptr)
-- | Release a buffer with @sodium_free@. The buffer, and any pointer
-- obtained from it, must not be used afterwards.
free
  :: forall s m. (MonadIO m, WithSecureMemory)
  => SensitiveBytes s
  -> m ()
free (SensitiveBytes _ _ ptr) = requiringSecureMemory (liftIO (sodium_free ptr))
-- | Overwrite the whole buffer with zeros using @sodium_memzero@.
--
-- The full 'allocSize' capacity is wiped, not just 'dataSize'. Unlike
-- 'allocate' and 'free', this needs no 'WithSecureMemory' constraint.
memzero
  :: forall s m. (MonadIO m)
  => SensitiveBytes s
  -> m ()
memzero (SensitiveBytes capacity _ ptr) =
  liftIO (sodium_memzero ptr (fromIntegral capacity))
-- | Change the logical length ('dataSize') of a buffer without
-- reallocating or touching its contents.
--
-- The new size must lie in @[0, allocSize]@. Calls 'error' when it is
-- negative, because a negative 'dataSize' would be reported as a negative
-- 'Data.ByteArray.length'. Also calls 'error' when the new size exceeds
-- the allocated capacity.
resized
  :: forall s. ()
  => Int
  -> SensitiveBytes s
  -> SensitiveBytes s
resized newSize sb@(SensitiveBytes capacity _ _)
  | newSize < 0 =
      error "SensitiveBytes.Internal.resized: the new size is negative"
  | newSize <= capacity = sb { dataSize = newSize }
  | otherwise =
      error "SensitiveBytes.Internal.resized: the new size is too large"
-- | Allocate a buffer of @size@ bytes and pass it to the continuation.
--
-- 'bracket' guarantees that afterwards the buffer is wiped with 'memzero'
-- and then released with 'free', even if the continuation throws.
--
-- NOTE(review): libsodium documents that @sodium_free@ already zero-fills
-- the region before deallocating, so the explicit wipe looks like defence
-- in depth. It is kept as-is.
withSensitiveBytes
  :: forall s m r. (MonadIO m, MonadMask m, WithSecureMemory)
  => Int
  -> (SensitiveBytes s -> m r)
  -> m r
withSensitiveBytes size =
  bracket (allocate size) (\buf -> memzero buf *> free buf)
-- | Raised by 'allocate' when @sodium_malloc@ returns a null pointer.
data SensitiveBytesAllocException = SodiumMallocFailed

instance Show SensitiveBytesAllocException where
  -- Precedence is irrelevant for a fixed message.
  showsPrec _ SodiumMallocFailed =
    showString "Failed to allocate secure memory"

instance Exception SensitiveBytesAllocException
-- | Return the value unchanged, but demand a 'WithSecureMemory' constraint.
--
-- The body mentions 'given', so the constraint is genuinely needed.
-- 'const' never evaluates that token.
requiringSecureMemory :: r -> (WithSecureMemory => r)
requiringSecureMemory act = const act (given :: SodiumInitialised)