-- | Internal functions
module Network.Transport.Internal
  ( -- * Encoders/decoders
    encodeWord32
  , decodeWord32
  , encodeEnum32
  , decodeNum32
  , encodeWord16
  , decodeWord16
  , encodeEnum16
  , decodeNum16
  , prependLength
    -- * Miscellaneous abstractions
  , mapIOException
  , tryIO
  , tryToEnum
  , timeoutMaybe
  , asyncWhenCancelled
  -- * Replicated functionality from "base"
  , void
  , forkIOWithUnmask
    -- * Debugging
  , tlog
  ) where

#if ! MIN_VERSION_base(4,6,0)
import Prelude hiding (catch)
#endif

import Foreign.Storable (pokeByteOff, peekByteOff)
import Foreign.ForeignPtr (withForeignPtr)
import Data.ByteString (ByteString)
import Data.List (foldl')
import qualified Data.ByteString as BS (length)
import qualified Data.ByteString.Internal as BSI
  ( unsafeCreate
  , toForeignPtr
  )
import Data.Word (Word32, Word16)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Exception
  ( IOException
  , SomeException
  , AsyncException
  , Exception
  , catch
  , try
  , throw
  , throwIO
  , mask_
  )
import Control.Concurrent (ThreadId, forkIO)
import Control.Concurrent.MVar (MVar, newEmptyMVar, takeMVar, putMVar)
import GHC.IO (unsafeUnmask)
import System.IO.Unsafe (unsafeDupablePerformIO)
import System.Timeout (timeout)
--import Control.Concurrent (myThreadId)

-- Host <-> network (big-endian) byte-order conversions, bound directly to the
-- C routines of the same names via the FFI and used by the encoders/decoders
-- below. They are imported as 'unsafe' calls, which is only valid because
-- these routines never call back into Haskell.
-- On Windows builds (mingw32_HOST_OS) they are imported with the stdcall
-- convention (presumably resolved from Winsock); every other platform uses
-- ccall.
-- NOTE(review): GHC treats stdcall as ccall (with a warning) on 64-bit
-- Windows; confirm this is intended there.
#ifdef mingw32_HOST_OS

foreign import stdcall unsafe "htonl" htonl :: Word32 -> Word32
foreign import stdcall unsafe "ntohl" ntohl :: Word32 -> Word32
foreign import stdcall unsafe "htons" htons :: Word16 -> Word16
foreign import stdcall unsafe "ntohs" ntohs :: Word16 -> Word16

#else

foreign import ccall unsafe "htonl" htonl :: Word32 -> Word32
foreign import ccall unsafe "ntohl" ntohl :: Word32 -> Word32
foreign import ccall unsafe "htons" htons :: Word16 -> Word16
foreign import ccall unsafe "ntohs" ntohs :: Word16 -> Word16

#endif

-- | Serialize a 32-bit word in network (big-endian) byte order, producing a
-- 4-byte 'ByteString'.
encodeWord32 :: Word32 -> ByteString
encodeWord32 w32 = BSI.unsafeCreate 4 writeBigEndian
  where
    -- 'unsafeCreate' allocates exactly 4 bytes; this single poke fills them.
    writeBigEndian dst = pokeByteOff dst 0 (htonl w32)

-- | Deserialize a 32-bit word from network (big-endian) byte order.
--
-- Throws (via 'throw') an 'IOError' built with 'userError' if the input is
-- not exactly 4 bytes long.
decodeWord32 :: ByteString -> Word32
decodeWord32 bs
  | BS.length bs == 4 = unsafeDupablePerformIO readBigEndian
  | otherwise         = throw (userError "decodeWord32: not 4 bytes")
  where
    (fptr, off, _) = BSI.toForeignPtr bs
    -- Only reads the (immutable) bytes of 'bs' after the length check, so
    -- running it via 'unsafeDupablePerformIO' (possibly more than once) is
    -- harmless.
    -- NOTE(review): the Word32 peek may be at an unaligned address; fine on
    -- x86, confirm on strict-alignment targets.
    readBigEndian =
      withForeignPtr fptr $ \ptr -> fmap ntohl (peekByteOff ptr off)

-- | Serialize a 16-bit word in network (big-endian) byte order, producing a
-- 2-byte 'ByteString'.
encodeWord16 :: Word16 -> ByteString
encodeWord16 w16 = BSI.unsafeCreate 2 writeBigEndian
  where
    -- 'unsafeCreate' allocates exactly 2 bytes; this single poke fills them.
    writeBigEndian dst = pokeByteOff dst 0 (htons w16)

-- | Deserialize a 16-bit word from network (big-endian) byte order.
--
-- Throws (via 'throw') an 'IOError' built with 'userError' if the input is
-- not exactly 2 bytes long.
decodeWord16 :: ByteString -> Word16
decodeWord16 bs
  | BS.length bs == 2 = unsafeDupablePerformIO readBigEndian
  | otherwise         = throw (userError "decodeWord16: not 2 bytes")
  where
    (fptr, off, _) = BSI.toForeignPtr bs
    -- Only reads the (immutable) bytes of 'bs' after the length check, so
    -- running it via 'unsafeDupablePerformIO' (possibly more than once) is
    -- harmless.
    readBigEndian =
      withForeignPtr fptr $ \ptr -> fmap ntohs (peekByteOff ptr off)

-- | Encode an 'Enum' in 32 bits by converting its 'fromEnum' value to a
-- 'Word32' with 'fromIntegral' and serializing that with 'encodeWord32'.
-- Beware of truncation: the 'Int' is reduced modulo 2^32, so an 'Enum' with
-- more than 2^32 points (or with negative 'fromEnum' values) does not
-- round-trip unchanged.
encodeEnum32 :: Enum a => a -> ByteString
encodeEnum32 x = encodeWord32 (fromIntegral (fromEnum x))

-- | Decode any 'Num' type from 32 bits. The bytes are read with
-- 'decodeWord32' and the resulting 'Word32' is converted with
-- 'fromIntegral'.
decodeNum32 :: Num a => ByteString -> a
decodeNum32 bs = fromIntegral (decodeWord32 bs)

-- | Encode an 'Enum' in 16 bits by converting its 'fromEnum' value to a
-- 'Word16' with 'fromIntegral' and serializing that with 'encodeWord16'.
-- Beware of truncation: the 'Int' is reduced modulo 2^16, so an 'Enum' with
-- more than 2^16 points (or with negative 'fromEnum' values) does not
-- round-trip unchanged.
encodeEnum16 :: Enum a => a -> ByteString
encodeEnum16 x = encodeWord16 (fromIntegral (fromEnum x))

-- | Decode any 'Num' type from 16 bits. The bytes are read with
-- 'decodeWord16' and the resulting 'Word16' is converted with
-- 'fromIntegral'.
decodeNum16 :: Num a => ByteString -> a
decodeNum16 bs = fromIntegral (decodeWord16 bs)

-- | Prepend a list of bytestrings with their total length, serialized as a
-- 32-bit word in network byte order (see 'encodeWord32').
--
-- Throws (via 'throw', when the result list is evaluated) an 'IOError' built
-- with 'userError' if the sum of the lengths overflows 'Int', or if that sum
-- does not fit in a 'Word32'.
prependLength :: [ByteString] -> [ByteString]
prependLength bss
  -- Compare as 'Integer' rather than going through 'tryToEnum': that would
  -- evaluate @fromEnum (maxBound :: Word32)@, which is outside 'Int' on
  -- 32-bit platforms, where GHC's Enum Word32 instance raises an error.
  | toInteger intLength > toInteger (maxBound :: Word32) = overflow
  | otherwise = encodeWord32 (fromIntegral intLength) : bss
  where
    intLength :: Int
    intLength = foldl' safeAdd 0 . map BS.length $ bss

    -- Non-negative integer addition with overflow check.
    safeAdd :: Int -> Int -> Int
    safeAdd i j
      | r >= 0    = r
      | otherwise = overflow
      where
        r = i + j

    overflow :: a
    overflow = throw $ userError "prependLength: input is too long (overflow)"

-- | Run an IO action. If it throws an 'IOException', convert that exception
-- with the given function and rethrow the result with 'throwIO'. Other
-- exceptions pass through untouched.
mapIOException :: Exception e => (IOException -> e) -> IO a -> IO a
mapIOException f p = catch p rethrow
  where
    rethrow ioe = throwIO (f ioe)

-- | Like 'try', but it only catches 'IOException's and is lifted into any
-- 'MonadIO' monad.
tryIO :: MonadIO m => IO a -> m (Either IOException a)
tryIO action = liftIO (try action)

-- | Debug logging hook. It is currently a no-op that ignores its message.
-- To enable it, swap in the commented-out definition below, which prints the
-- message prefixed with the current thread id (this also needs the
-- 'myThreadId' import).
tlog :: MonadIO m => String -> m ()
tlog = const (return ())
{-
tlog msg = liftIO $ do
  tid <- myThreadId
  putStrLn $ show tid ++ ": "  ++ msg
-}

-- | Run a monadic action and discard its result. This is a local copy
-- because not all versions of "base" export 'void'.
void :: Monad m => m a -> m ()
void = (>> return ())

-- | Fork a thread and pass it a function that runs an action with
-- asynchronous exceptions unmasked. The new thread inherits the caller's
-- masking state.
--
-- This is a local copy because older versions of "base" (the ones shipped
-- with GHC 7.0.x) do not provide it.
forkIOWithUnmask :: ((forall a . IO a -> IO a) -> IO ()) -> IO ThreadId
forkIOWithUnmask body = forkIO $ body unsafeUnmask

-- | Safe version of 'toEnum'. It returns 'Nothing' instead of failing when
-- the 'Int' lies outside @[fromEnum minBound, fromEnum maxBound]@.
--
-- NOTE: it evaluates 'fromEnum' on the type's bounds, so it is itself
-- partial for types whose bounds are not representable as 'Int'. For
-- example, GHC's 'fromEnum' errors for @maxBound :: Word@ / 'Word64', and
-- for 'Word32' on 32-bit platforms.
tryToEnum :: (Enum a, Bounded a) => Int -> Maybe a
tryToEnum n = within minBound maxBound
  where
    within :: Enum b => b -> b -> Maybe b
    within lo hi
      | n < fromEnum lo = Nothing
      | n > fromEnum hi = Nothing
      | otherwise       = Just (toEnum n)

-- | Run an action with an optional time limit. With 'Nothing' the action runs
-- unchanged. With @Just n@ it runs under 'timeout' with a limit of @n@
-- microseconds. If it does not finish in time, the given exception is thrown
-- with 'throwIO'.
timeoutMaybe :: Exception e => Maybe Int -> e -> IO a -> IO a
timeoutMaybe Nothing  _   act = act
timeoutMaybe (Just n) err act = timeout n act >>= maybe (throwIO err) return

-- | @asyncWhenCancelled g f@ runs @f@ in a separate thread and waits for it
-- to complete. If @f@ throws an exception, we catch it and rethrow it in the
-- current thread.
--
-- If the current thread is interrupted while waiting (it receives any
-- exception, e.g. 'killThread' or a 'System.Timeout.timeout' expiry), we
-- rethrow that exception immediately. We also fork a thread that waits for
-- @f@ to finish and then runs the clean-up handler @g@ on its result. If @f@
-- throws an exception, we assume that no clean-up is necessary.
--
-- NOTE: @f@ is forked from inside 'mask_', so it inherits the masked state
-- and runs with asynchronous exceptions masked.
asyncWhenCancelled :: forall a. (a -> IO ()) -> IO a -> IO a
asyncWhenCancelled g f = mask_ $ do
    mvar <- newEmptyMVar
    _ <- forkIO $ try f >>= putMVar mvar
    -- takeMVar is interruptible (even inside a mask_)
    catch (takeMVar mvar) (exceptionHandler mvar) >>= either throwIO return
  where
    -- Catch 'SomeException', not just 'AsyncException'. Interruptions such as
    -- System.Timeout's timeout exception are not 'AsyncException's, and
    -- catching only that type would silently skip the clean-up of @f@'s
    -- result.
    exceptionHandler :: MVar (Either SomeException a)
                     -> SomeException
                     -> IO (Either SomeException a)
    exceptionHandler mvar ex = do
      _ <- forkIO $ takeMVar mvar >>= either (const $ return ()) g
      throwIO ex