{-# LANGUAGE FlexibleContexts  #-}
-- | This module defines byte sources.
module Raaz.Core.ByteSource
       ( -- * Byte sources.
         -- $bytesource$

         ByteSource(..), PureByteSource
       --    InfiniteSource(..)
       , FillResult(..)
       , fill, processChunks
       , withFillResult
       ) where

import           Control.Applicative
import           Control.Monad.IO.Class
import qualified Data.ByteString      as B
import qualified Data.ByteString.Lazy as L
import           Prelude hiding(length)
import           System.IO            (Handle, hIsEOF)

import           Raaz.Core.MonoidalAction
import           Raaz.Core.Types      (BYTES, Pointer, LengthUnit (..))
import           Raaz.Core.Util.ByteString( unsafeCopyToPointer
                                          , unsafeNCopyToPointer
                                          , length
                                          )
import           Raaz.Core.Types.Pointer  (hFillBuf)

-- $bytesource$
--
-- Cryptographic inputs come from various sources; they can come from
-- network sockets or might be just a string in Haskell. To give a
-- uniform interface for all such inputs, we define the abstract
-- concept of a /byte source/. Essentially, a byte source is one from
-- which we can fill a buffer with bytes.
--
-- Among instances of `ByteSource`, some, for example
-- `B.ByteString`, are /pure/ in the sense that filling a buffer with
-- bytes from such a source has no other side effects. This is in
-- contrast to a source like a socket. The type class `PureByteSource`
-- captures such byte sources.
--

-- | This type captures the result of a fill operation.
data FillResult a = Remaining a           -- ^ There are still bytes left
                                          -- in the source; carries what
                                          -- remains of the source.
                  | Exhausted (BYTES Int) -- ^ The source was exhausted;
                                          -- carries the number of bytes
                                          -- that were read.
                    deriving (Show, Eq)

-- | Mapping over a 'FillResult' transforms the remaining source, if
-- there is one; an exhausted result keeps its byte count unchanged.
instance Functor FillResult where
  fmap f result = case result of
    Remaining rest -> Remaining (f rest)
    Exhausted sz   -> Exhausted sz

-- | Eliminator for 'FillResult'. Applies the first continuation to the
-- remaining source, or the second continuation to the number of bytes
-- read when the source was exhausted.
withFillResult :: (a -> b)          -- ^ stuff to do when filled
               -> (BYTES Int -> b)  -- ^ stuff to do when exhausted
               -> FillResult a      -- ^ the fill result to process
               -> b
withFillResult onRemaining onExhausted fr =
  case fr of
    Remaining rest -> onRemaining rest
    Exhausted sz   -> onExhausted sz

------------------------ Byte sources ----------------------------------

-- | Abstract byte sources. A byte source is something that you can use
-- to fill a buffer.
--
-- __WARNING:__ The source is required to return `Exhausted` in the
-- boundary case where it has exactly the number of bytes
-- requested. In other words, if the source returns @Remaining@ on any
-- particular request, there should be at least 1 additional byte left
-- on the source for the next request. Cryptographic block primitives
-- have to do certain special processing for the last block, so they
-- need to know whether the last block has been read or not.
class ByteSource src where
  -- | Fills a buffer from the source.
  fillBytes :: BYTES Int  -- ^ Buffer size
            -> src        -- ^ The source to fill from.
            -> Pointer    -- ^ Buffer pointer
            -> IO (FillResult src)

-- | A variant of 'fillBytes' that takes a type-safe length. The length
-- may be given in any 'LengthUnit'; it is converted to bytes with
-- 'inBytes' before filling.
fill :: ( LengthUnit len
        , ByteSource src
        )
     => len
     -> src
     -> Pointer
     -> IO (FillResult src)
fill len = fillBytes (inBytes len)
{-# INLINE fill #-}

-- | Process data from a source in chunks of a particular size.
--
-- The buffer is repeatedly filled with one chunk from the source.
-- After every fill that reports 'Remaining', the action on a complete
-- chunk is run (its result is discarded) and filling continues with
-- the rest of the source. When a fill reports 'Exhausted', the final
-- action is given the number of bytes that went into the buffer on
-- that last fill, and its result is returned.
--
-- Because a 'ByteSource' reports 'Exhausted' even when it held exactly
-- the requested number of bytes, the last chunk is not necessarily
-- partial: its byte count can be anything up to the full chunk size,
-- including zero.
processChunks :: ( MonadIO m, LengthUnit chunkSize, ByteSource src)
              => m a                 -- ^ action on a complete chunk
              -> (BYTES Int -> m b)  -- ^ action on the last chunk, given
                                     --   the number of bytes it holds
              -> src                 -- ^ the source
              -> chunkSize           -- ^ size of each chunk
              -> Pointer             -- ^ buffer to fill the chunk in
              -> m b
processChunks mid end source csz ptr = go source
  where -- The fill itself is an IO action; lift it into the caller's monad.
        fillChunk src = liftIO $ fill csz src ptr
        step src      = mid >> go src
        go src        = fillChunk src >>= withFillResult step end


-- | A byte source @src@ is pure if filling from it does not have any
-- other side effect on the state of the byte source. Formally, two
-- different fills from the same source should fill the buffer with
-- the same bytes. This additional constraint on the source helps to
-- /purify/ certain crypto computations like computing the hash or MAC
-- of the source. Typically, sources like `B.ByteString` are pure byte
-- sources. A file handle is a byte source that is /not/ a pure
-- source.
class ByteSource src => PureByteSource src where

----------------------- Instances of byte source -----------------------

-- | __WARNING:__ `fillBytes` on a handle may block.
instance ByteSource Handle where
  {-# INLINE fillBytes #-}
  fillBytes sz hand cptr = do
    count <- hFillBuf hand cptr sz
    -- End of file is checked after the read, not inferred from
    -- @count@. That way a handle with exactly @sz@ bytes left reports
    -- 'Exhausted', as the 'ByteSource' boundary contract requires.
    -- NOTE(review): this assumes hFillBuf only returns fewer than @sz@
    -- bytes at end of file. Otherwise a short read on a still-open
    -- handle would be reported as 'Remaining' — confirm.
    eof   <- hIsEOF hand
    return $ if eof then Exhausted count else Remaining hand

-- | A strict byte string fills the buffer from its prefix.
instance ByteSource B.ByteString where
  {-# INLINE fillBytes #-}
  fillBytes sz bs cptr
    | sz < available = do
        -- More bytes than requested: copy only @sz@ of them and keep
        -- the unread suffix as the remaining source.
        unsafeNCopyToPointer sz bs cptr
        return $ Remaining $ B.drop (fromIntegral sz) bs
    | otherwise = do
        -- Everything fits. This includes the exact-fit case, which must
        -- be reported as exhausted.
        unsafeCopyToPointer bs cptr
        return $ Exhausted available
    where available = length bs

-- | A lazy byte string is filled through its list of strict chunks.
instance ByteSource L.ByteString where
  {-# INLINE fillBytes #-}
  fillBytes sz lbs cptr = do
    result <- fillBytes sz (L.toChunks lbs) cptr
    return (L.fromChunks <$> result)

-- | An absent source is immediately exhausted with no bytes read. A
-- present one is filled as usual, and its remainder is wrapped back in
-- 'Just'.
instance ByteSource src => ByteSource (Maybe src) where
  {-# INLINE fillBytes #-}
  fillBytes _  Nothing      _    = return (Exhausted 0)
  fillBytes sz (Just inner) cptr = fmap (fmap Just) (fillBytes sz inner cptr)

-- | A list of sources is read as their concatenation. When the head is
-- exhausted, filling continues with the tail just after the bytes
-- already written. The total byte count accumulates across the
-- exhausted sources.
instance ByteSource src => ByteSource [src] where
  fillBytes _  []            _    = return $ Exhausted 0
  fillBytes sz (first:later) cptr =
    fillBytes sz first cptr >>= withFillResult onRemaining onExhausted
    where
      -- The head still has data, so the tail is left untouched.
      onRemaining first' = return $ Remaining (first' : later)
      -- The head supplied @used@ bytes; fill the rest of the buffer
      -- from the tail and add @used@ to any final byte count.
      onExhausted used =
        withFillResult Remaining (Exhausted . (used +))
          <$> fillBytes (sz - used) later (used <.> cptr)


--------------------- Instances of pure byte source --------------------

-- Strict and lazy byte strings are pure sources, and optional values or
-- lists built from pure sources are pure as well.
instance PureByteSource B.ByteString
instance PureByteSource L.ByteString
instance PureByteSource src => PureByteSource (Maybe src)
instance PureByteSource src => PureByteSource [src]