{-# LANGUAGE ScopedTypeVariables #-}
module Network.TLS.Util
        ( sub
        , takelast
        , partition3
        , partition6
        , fromJust
        , (&&!)
        , bytesEq
        , fmapEither
        , catchException
        , forEitherM
        , mapChunks_
        , getChunks
        , Saved
        , saveMVar
        , restoreMVar
        ) where

import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import Network.TLS.Imports

import Control.Exception (SomeException)
import Control.Concurrent.Async
import Control.Concurrent.MVar

-- | @sub b offset len@ returns the @len@ bytes of @b@ that start at byte
-- position @offset@.
--
-- Returns 'Nothing' when the requested range does not lie entirely inside
-- @b@, including when @offset@ or @len@ is negative.
sub :: ByteString -> Int -> Int -> Maybe ByteString
sub b offset len
    -- Without this guard a negative offset would make the slice silently
    -- start at position 0, and a negative length would "succeed" with "".
    | offset < 0 || len < 0           = Nothing
    -- Written as a subtraction so that huge offset/len values cannot
    -- overflow 'Int' (both operands are non-negative here).
    | offset > B.length b - len       = Nothing
    | otherwise                       = Just $ B.take len $ B.drop offset b

-- | @takelast i b@ returns the final @i@ bytes of @b@, or 'Nothing' when @b@
-- holds fewer than @i@ bytes. The slicing itself is delegated to 'sub'.
takelast :: Int -> ByteString -> Maybe ByteString
takelast i b
    | total < i = Nothing
    | otherwise = sub b (total - i) i
  where
    total = B.length b

-- | Split @bytes@ into three consecutive pieces of the given lengths.
--
-- Succeeds only when every length is non-negative and the lengths add up to
-- exactly the size of @bytes@; otherwise returns 'Nothing'.
partition3 :: ByteString -> (Int, Int, Int) -> Maybe (ByteString, ByteString, ByteString)
partition3 bytes (d1, d2, d3)
    | fits      = Just (part1, part2, part3)
    | otherwise = Nothing
  where
    sizes = [d1, d2, d3]
    fits  = all (>= 0) sizes && sum sizes == B.length bytes
    (part1, rest1) = B.splitAt d1 bytes
    (part2, rest2) = B.splitAt d2 rest1
    part3          = B.take d3 rest2

-- | Split @bytes@ into six consecutive pieces of the given lengths.
--
-- Returns 'Nothing' when any length is negative, or when @bytes@ is shorter
-- than the sum of the lengths. Bytes beyond that sum are ignored, so the
-- input may be longer than required (unlike 'partition3', which demands an
-- exact fit).
partition6
    :: ByteString
    -> (Int, Int, Int, Int, Int, Int)
    -> Maybe (ByteString, ByteString, ByteString, ByteString, ByteString, ByteString)
partition6 bytes (d1, d2, d3, d4, d5, d6)
    -- A negative length would make 'B.splitAt' yield an empty piece and
    -- shrink the required total, silently producing a wrong partition.
    | any (< 0) sizes             = Nothing
    | B.length bytes < sum sizes  = Nothing
    | otherwise                   = Just (p1, p2, p3, p4, p5, p6)
  where
    sizes = [d1, d2, d3, d4, d5, d6]
    (p1, r1) = B.splitAt d1 bytes
    (p2, r2) = B.splitAt d2 r1
    (p3, r3) = B.splitAt d3 r2
    (p4, r4) = B.splitAt d4 r3
    (p5, r5) = B.splitAt d5 r4
    p6       = B.take d6 r5

-- | Unwrap a 'Just'. On 'Nothing' this calls 'error' with a message that
-- includes the caller-supplied label @what@, so the failure can be traced.
--
-- NOTE(review): this is a partial function; prefer handling 'Nothing'
-- explicitly at call sites where failure is possible.
fromJust :: String -> Maybe a -> a
fromJust what = maybe (error ("fromJust " ++ what ++ ": Nothing")) id

-- | Strict boolean conjunction. Unlike '&&', the second operand is always
-- evaluated, even when the first one is 'False'.
(&&!) :: Bool -> Bool -> Bool
True  &&! other = other
False &&! other = other `seq` False

-- | Check whether two byte strings are equal by examining every byte, rather
-- than stopping at the first difference (delegates to 'BA.constEq').
-- Inputs of different lengths are rejected immediately, without comparing
-- any bytes.
bytesEq :: ByteString -> ByteString -> Bool
bytesEq lhs rhs = BA.constEq lhs rhs

-- | Apply a function to the 'Right' payload of an 'Either', passing a 'Left'
-- through unchanged.
fmapEither :: (a -> b) -> Either l a -> Either l b
fmapEither f (Right value) = Right (f value)
fmapEither _ (Left err)    = Left err

-- | Run @action@ and, if it throws any exception, run @handler@ on it.
--
-- The action runs in a separate thread via 'withAsync', and its outcome is
-- collected with 'waitCatch'. Exceptions raised inside the action are
-- therefore passed to @handler@. Exceptions delivered to the calling thread
-- while it waits are not: 'withAsync' cancels the child and lets them
-- propagate (per the async documentation).
catchException :: IO a -> (SomeException -> IO a) -> IO a
catchException action handler = do
    outcome <- withAsync action waitCatch
    case outcome of
        Left err    -> handler err
        Right value -> return value

-- | Run a monadic step over each element, left to right, collecting the
-- 'Right' results.
--
-- Stops at the first 'Left': that 'Left' is returned and the step is not run
-- on any remaining elements.
forEitherM :: Monad m => [a] -> (a -> m (Either l b)) -> m (Either l [b])
forEitherM items step = loop items
  where
    loop []       = return (Right [])
    loop (y : ys) = step y >>= either (return . Left) (\r -> fmap (r :) <$> loop ys)

-- | Run @f@ on each chunk of @bs@ in order, discarding the results. The
-- chunks are produced by 'getChunks' with the same optional size @len@.
mapChunks_ :: Monad m
           => Maybe Int -> (B.ByteString -> m a) -> B.ByteString -> m ()
mapChunks_ len f bs = mapM_ f (getChunks len bs)

-- | Split a byte string into consecutive chunks of at most @len@ bytes.
--
-- * 'Nothing' means "no chunking": the input is returned as a single chunk.
-- * @Just len@ with @len <= 0@ is treated the same way. Splitting at a
--   non-positive size would never consume input and would otherwise produce
--   an infinite list of empty chunks.
--
-- The result is never empty: an empty input yields a single empty chunk.
getChunks :: Maybe Int -> B.ByteString -> [B.ByteString]
getChunks Nothing = (: [])
getChunks (Just len)
    | len <= 0  = (: [])
    | otherwise = go
  where
    go bs
        | B.length bs > len =
            let (chunk, remain) = B.splitAt len bs
             in chunk : go remain
        | otherwise = [bs]

-- | An opaque wrapper around a value captured from an 'MVar' by 'saveMVar',
-- so that callers cannot inspect or modify the saved content directly.
-- Only the type (not its constructor) is exported from this module, which
-- keeps the content opaque to other modules.
newtype Saved a = Saved a

-- | Capture the current content of an 'MVar' so it can be restored later
-- with 'restoreMVar'.
--
-- The content is read with 'readMVar', so the 'MVar' is left full. The call
-- blocks while the 'MVar' is empty.
saveMVar :: MVar a -> IO (Saved a)
saveMVar ref = do
    current <- readMVar ref
    return (Saved current)

-- | Put a previously saved value back into an 'MVar'. The content that was
-- just replaced is returned, wrapped as 'Saved'.
--
-- Uses 'swapMVar', which per the base documentation is atomic only when no
-- other thread is also putting values into the same 'MVar'.
restoreMVar :: MVar a -> Saved a -> IO (Saved a)
restoreMVar ref (Saved val) = fmap Saved replaced
  where
    replaced = swapMVar ref val