{-# LANGUAGE ScopedTypeVariables #-}
module Network.TLS.Util
( sub
, takelast
, partition3
, partition6
, fromJust
, (&&!)
, bytesEq
, fmapEither
, catchException
, forEitherM
, mapChunks_
, getChunks
, Saved
, saveMVar
, restoreMVar
) where
import qualified Data.ByteArray as BA
import qualified Data.ByteString as B
import Network.TLS.Imports
import Control.Exception (SomeException)
import Control.Concurrent.Async
import Control.Concurrent.MVar
sub :: ByteString -> Int -> Int -> Maybe ByteString
sub b offset len
| B.length b < offset + len = Nothing
| otherwise = Just $ B.take len $ snd $ B.splitAt offset b
takelast :: Int -> ByteString -> Maybe ByteString
takelast i b
| B.length b >= i = sub b (B.length b - i) i
| otherwise = Nothing
partition3 :: ByteString -> (Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString)
partition3 bytes (d1,d2,d3)
| any (< 0) l = Nothing
| sum l /= B.length bytes = Nothing
| otherwise = Just (p1,p2,p3)
where l = [d1,d2,d3]
(p1, r1) = B.splitAt d1 bytes
(p2, r2) = B.splitAt d2 r1
(p3, _) = B.splitAt d3 r2
partition6 :: ByteString -> (Int,Int,Int,Int,Int,Int) -> Maybe (ByteString, ByteString, ByteString, ByteString, ByteString, ByteString)
partition6 bytes (d1,d2,d3,d4,d5,d6) = if B.length bytes < s then Nothing else Just (p1,p2,p3,p4,p5,p6)
where s = sum [d1,d2,d3,d4,d5,d6]
(p1, r1) = B.splitAt d1 bytes
(p2, r2) = B.splitAt d2 r1
(p3, r3) = B.splitAt d3 r2
(p4, r4) = B.splitAt d4 r3
(p5, r5) = B.splitAt d5 r4
(p6, _) = B.splitAt d6 r5
fromJust :: String -> Maybe a -> a
fromJust what Nothing = error ("fromJust " ++ what ++ ": Nothing")
fromJust _ (Just x) = x
(&&!) :: Bool -> Bool -> Bool
True &&! True = True
True &&! False = False
False &&! True = False
False &&! False = False
bytesEq :: ByteString -> ByteString -> Bool
bytesEq = BA.constEq
fmapEither :: (a -> b) -> Either l a -> Either l b
fmapEither f = fmap f
catchException :: IO a -> (SomeException -> IO a) -> IO a
catchException action handler = withAsync action waitCatch >>= either handler return
forEitherM :: Monad m => [a] -> (a -> m (Either l b)) -> m (Either l [b])
forEitherM [] _ = return (pure [])
forEitherM (x:xs) f = f x >>= doTail
where
doTail (Right b) = fmap (b :) <$> forEitherM xs f
doTail (Left e) = return (Left e)
mapChunks_ :: Monad m
=> Int -> (B.ByteString -> m a) -> B.ByteString -> m ()
mapChunks_ len f = mapM_ f . getChunks len
getChunks :: Int -> B.ByteString -> [B.ByteString]
getChunks len bs
| B.length bs > len =
let (chunk, remain) = B.splitAt len bs
in chunk : getChunks len remain
| otherwise = [bs]
newtype Saved a = Saved a
saveMVar :: MVar a -> IO (Saved a)
saveMVar ref = Saved <$> readMVar ref
restoreMVar :: MVar a -> Saved a -> IO (Saved a)
restoreMVar ref (Saved val) = Saved <$> swapMVar ref val