{-# LINE 1 "src/Network/Socket/Eager.hsc" #-}
-- This Source Code Form is subject to the terms of the Mozilla Public
{-# LINE 2 "src/Network/Socket/Eager.hsc" #-}
-- License, v. 2.0. If a copy of the MPL was not distributed with this
-- file, You can obtain one at http://mozilla.org/MPL/2.0/.

{-# LANGUAGE BangPatterns       #-}
{-# LANGUAGE CPP                #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE MultiWayIf         #-}
{-# LANGUAGE OverloadedStrings  #-}

-- | This module exports basic functions that wait up to a given number of
-- milliseconds for 'Socket's to become available for reading or writing.
--
-- On top of these, higher-level variants of 'connect', 'send' and 'recv'
-- that take timeouts are provided.
--
-- The implementation is Unix-only and uses GHC-specific functionality
-- from "GHC.IO.Device" and "GHC.IO.FD". In particular, the basic wait
-- functions are thin wrappers around 'GHC.IO.Device.ready'.
module Network.Socket.Eager
    ( Error        (..)
    , Milliseconds (..)
    , Descriptor
    , descriptor
    , connect
    , connect'
    , send
    , send'
    , recv
    , recv'

    -- * basic wait functions
    , waitRead
    , waitRead'
    , waitWrite
    , waitWrite'
    ) where

import Control.Applicative
import Control.Exception
import Control.Monad
import Data.ByteString.Builder
import Data.ByteString.Lazy (ByteString)
import Data.Monoid
import Data.Typeable
import Foreign
import Foreign.C
import GHC.IO.FD
import Network
import Network.Socket hiding (connect, recv, send)
import Network.Socket.Internal (withSockAddr)

import qualified Data.ByteString.Lazy           as B
import qualified Data.ByteString                as BS
import qualified Network.Socket.ByteString      as NS
import qualified Network.Socket.ByteString.Lazy as N
import qualified GHC.IO.Device                  as IODevice

-- | Failures reported by the timed operations in this module.
--
-- Every constructor is raised with 'throwIO', so callers handle them like
-- any other 'Exception'.
data Error
    = ConnectTimeout   -- ^ 'connect' did not complete within the timeout
    | RecvTimeout      -- ^ the socket did not become readable in time
    | SendTimeout      -- ^ the socket did not become writable in time
    | ConnectError     -- ^ @connect@ failed, or left a pending socket error
    | ConnectionClosed -- ^ a read during 'recv' returned no data
    deriving (Show, Eq, Typeable)

instance Exception Error

-- | A timeout expressed in milliseconds. The wait functions hand the raw
-- count to 'GHC.IO.Device.ready' unchanged.
newtype Milliseconds = Milliseconds
    { ms :: Int -- ^ the millisecond count
    }

-- | A 'Descriptor' wraps a "GHC.IO.FD" 'FD'.
--
-- The primed functions ('waitRead'', 'waitWrite'', 'connect'', 'send'',
-- 'recv'') take a 'Descriptor' explicitly. Their unprimed counterparts
-- accept only a 'Socket' and derive the descriptor with 'descriptor'.
newtype Descriptor = Descriptor
    { fd :: FD -- ^ the wrapped file descriptor
    }

-- | Build a 'Descriptor' from the file descriptor of a 'Socket'.
--
-- The resulting 'FD' is flagged as non-blocking (@fdIsNonBlocking = 1@).
-- NOTE(review): this assumes the socket really is in non-blocking mode —
-- confirm for sockets not created through the network package.
descriptor :: Socket -> Descriptor
descriptor sock =
    Descriptor (FD { fdFD = fdSocket sock, fdIsNonBlocking = 1 })

-- | Wait up to the given timeout for a 'Socket' to become readable.
-- Yields 'True' when it is ready and 'False' when the timeout expires.
waitRead :: Milliseconds -> Socket -> IO Bool
waitRead t = waitRead' t . descriptor
{-# INLINEABLE waitRead #-}

-- | Like 'waitRead', but for an explicit 'Descriptor'. Delegates to
-- 'GHC.IO.Device.ready' in read mode (its write flag is 'False').
waitRead' :: Milliseconds -> Descriptor -> IO Bool
waitRead' (Milliseconds timeout) (Descriptor dev) =
    IODevice.ready dev False timeout
{-# INLINEABLE waitRead' #-}

-- | Wait up to the given timeout for a 'Socket' to become writable.
-- Yields 'True' when it is ready and 'False' when the timeout expires.
waitWrite :: Milliseconds -> Socket -> IO Bool
waitWrite t = waitWrite' t . descriptor
{-# INLINEABLE waitWrite #-}

-- | Like 'waitWrite', but for an explicit 'Descriptor'. Delegates to
-- 'GHC.IO.Device.ready' in write mode (its write flag is 'True').
waitWrite' :: Milliseconds -> Descriptor -> IO Bool
waitWrite' (Milliseconds timeout) (Descriptor dev) =
    IODevice.ready dev True timeout
{-# INLINEABLE waitWrite' #-}

-- | Connect a 'Socket' to the given address, giving up after the timeout.
-- See 'connect'' for the exceptions it may throw.
connect :: Milliseconds -> SockAddr -> Socket -> IO ()
connect t addr = connect' t addr . descriptor

-- | Like 'connect', but operates on an explicit 'Descriptor'.
--
-- Issues @connect@ on the descriptor, and returns at once if that
-- succeeds immediately. If the call reports @EINPROGRESS@ or @EINTR@,
-- the descriptor gets up to the timeout to become writable. Its pending
-- socket error (read via 'socketError') then decides the outcome.
--
-- Throws 'ConnectTimeout' if the descriptor does not become writable in
-- time, and 'ConnectError' for any other failure.
connect' :: Milliseconds -> SockAddr -> Descriptor -> IO ()
connect' t a d = withSockAddr a $ \ptr sz -> do
    r <- c_connect (fdFD (fd d)) ptr (fromIntegral sz)
    when (r == -1) $ do
        e <- getErrno
        -- POSIX: a connect() interrupted by a signal is not aborted; the
        -- connection is established asynchronously, and calling connect()
        -- again fails with EALREADY or EISCONN. So EINTR must be handled
        -- like EINPROGRESS (wait for completion) rather than retried.
        if | e == eINPROGRESS || e == eINTR -> wait
           | otherwise                      -> throwIO ConnectError
  where
    -- Wait for the in-flight connection, then inspect its final status.
    wait = do
        isReady <- waitWrite' t d
        unless isReady $
            throwIO ConnectTimeout
        e <- socketError d
        when (e /= 0) $
            throwIO ConnectError

-- | Send a whole lazy 'ByteString' over a 'Socket'. Before each write it
-- waits up to the timeout for writability. See 'send''.
send :: Milliseconds -> ByteString -> Socket -> IO ()
send t bs s = send' t bs d s
  where
    d = descriptor s

-- | Send the entire lazy 'ByteString'. Before every individual send it
-- waits up to the timeout for the descriptor to become writable, so the
-- timeout applies per write, not to the whole transfer.
--
-- Throws 'SendTimeout' if the descriptor is not writable in time.
-- NOTE(review): the 'Descriptor' and 'Socket' presumably must refer to
-- the same socket (as 'send' arranges) — callers should ensure this.
send' :: Milliseconds -> ByteString -> Descriptor -> Socket -> IO ()
send' t b d s
    | B.null b  = return ()
    | otherwise = do
        writable <- waitWrite' t d
        unless writable $ throwIO SendTimeout
        sent <- N.send s b
        send' t (B.drop sent b) d s

-- | Receive @n@ bytes from a 'Socket', waiting up to the timeout before
-- each read. See 'recv''.
recv :: Milliseconds -> Int -> Socket -> IO ByteString
recv t n s =
    let d = descriptor s
    in  recv' t n d s

-- | Read @n@ bytes into a lazy 'ByteString'. Each read requests only the
-- bytes still missing, and the loop stops once @n@ have arrived. A count
-- of @0@ returns an empty result without touching the socket.
--
-- Before every read it waits up to the timeout for the descriptor to
-- become readable, so the timeout applies per read, not to the whole call.
--
-- Throws 'RecvTimeout' if the descriptor is not readable in time, and
-- 'ConnectionClosed' if a read returns no data before @n@ bytes arrived.
-- NOTE(review): a negative @n@ is forwarded to 'NS.recv' unchanged —
-- confirm callers never pass one.
recv' :: Milliseconds -> Int -> Descriptor -> Socket -> IO ByteString
recv' _ 0 _ _ = return B.empty
recv' t n d s = fmap toLazyByteString (loop 0 mempty)
  where
    loop received acc = do
        readable <- waitRead' t d
        unless readable $ throwIO RecvTimeout
        chunk <- NS.recv s (n - received)
        when (BS.null chunk) $ throwIO ConnectionClosed
        let received' = received + BS.length chunk
            acc'      = acc <> byteString chunk
        if received' < n
            then loop received' acc'
            else return acc'

-----------------------------------------------------------------------------
-- Internal:


{-# LINE 154 "src/Network/Socket/Eager.hsc" #-}

-- | Read the socket-level error option of a descriptor with @getsockopt@,
-- retrying on @EINTR@ and throwing an errno-based 'IOError' on other
-- failures. 'connect'' calls it after a pending connect finishes and
-- treats any non-zero result as failure.
socketError :: Descriptor -> IO CInt
socketError (Descriptor dev) =
    with (0 :: CInt) $ \valPtr ->
    with optLen $ \lenPtr -> do
        throwErrnoIfMinus1Retry_ "getsockopt" $
            c_getsockopt (fdFD dev) sockLevel sockError valPtr lenPtr
        peek valPtr
  where
    -- size of the option value buffer, passed in/out as the length
    optLen = fromIntegral (sizeOf (0 :: CInt)) :: CInt

-- | Protocol level passed to @getsockopt@ by 'socketError'.
-- NOTE(review): presumably @SOL_SOCKET@ as expanded by hsc2hs on the build
-- platform. If so, the value is platform-specific: regenerate it from the
-- .hsc source rather than editing this literal by hand.
sockLevel :: CInt
sockLevel = 1
{-# LINE 165 "src/Network/Socket/Eager.hsc" #-}

-- | Option name passed to @getsockopt@ by 'socketError'.
-- NOTE(review): presumably @SO_ERROR@ as expanded by hsc2hs on the build
-- platform. If so, the value is platform-specific: regenerate it from the
-- .hsc source rather than editing this literal by hand.
sockError :: CInt
sockError = 4
{-# LINE 168 "src/Network/Socket/Eager.hsc" #-}

-- | Raw binding to the C @connect@ function. It takes the socket
-- descriptor, the address pointer and the address length. 'connect''
-- treats a result of -1 as failure and then inspects errno.
-- NOTE(review): imported @unsafe@, which presumes the call never blocks
-- (i.e. a non-blocking socket) — confirm.
foreign import ccall unsafe "connect"
    c_connect :: CInt -> Ptr SockAddr -> CInt -> IO CInt

-- | Raw binding to the C @getsockopt@ function, used by 'socketError'.
-- NOTE(review): the C prototype takes a @socklen_t *@ length; using
-- @Ptr CInt@ assumes @socklen_t@ has the size of @int@ — confirm per
-- platform.
foreign import ccall unsafe "getsockopt"
    c_getsockopt :: CInt      -- ^ socket descriptor
                 -> CInt      -- ^ protocol level
                 -> CInt      -- ^ option name
                 -> Ptr CInt  -- ^ option value buffer
                 -> Ptr CInt  -- ^ in/out length of the value buffer
                 -> IO CInt