{-# LANGUAGE CPP #-}

module Database.PostgreSQL.Pure.Internal.SocketIO
  ( SocketIO
  , runSocketIO
  , send
  , buildAndSend
  , receive
  ) where

import           Database.PostgreSQL.Pure.Internal.Data      (Buffer (Buffer), Carry, Config, Error (Error))
import qualified Database.PostgreSQL.Pure.Internal.Exception as Exception
import qualified Database.PostgreSQL.Pure.Internal.Parser    as Parser

import           Control.Concurrent                          (yield)
import           Control.Monad.IO.Class                      (liftIO)
import           Control.Monad.Reader                        (ReaderT, ask, runReaderT)
import           Control.Monad.State.Strict                  (StateT, get, put, runStateT)
import qualified Data.Attoparsec.ByteString                  as AP
import qualified Data.ByteString                             as BS
import qualified Data.ByteString.Builder                     as BSB
import qualified Data.ByteString.Builder.Extra               as BSB
import qualified Data.ByteString.Internal                    as BSI
import           Data.List                                   (intercalate)
import           Data.Word                                   (Word8)
import           Foreign                                     (ForeignPtr, Ptr, withForeignPtr)
import qualified Network.Socket                              as NS
import qualified Network.Socket.ByteString                   as NSB

#if MIN_VERSION_base(4,13,0)
import           Control.Exception.Safe                      (throw, try)
import           Control.Monad                               (unless)
#else
import           Control.Exception.Safe                      (throw, try, tryJust)
import           Control.Monad                               (guard, unless)
import           System.IO.Error                             (isEOFError)
#endif

type SocketIO = StateT Carry (ReaderT (NS.Socket, Buffer, Buffer, Config) IO)

-- | Run a 'SocketIO' action against a connection.
--
-- The environment is @(socket, sb, rb, config)@:
--
-- * @sb@ is the buffer that 'buildAndSend' serialises outgoing data into.
-- * @rb@ is the buffer that incoming bytes are read into while parsing.
--
-- The action starts with an empty 'Carry' (no pending, unparsed input). If
-- received bytes are still left unconsumed when the action finishes,
-- 'Exception.InternalExtraData' is thrown with those bytes instead of the
-- result being returned.
runSocketIO :: NS.Socket -> Buffer -> Buffer -> Config -> SocketIO a -> IO a
runSocketIO s sb rb c m =
  flip runReaderT (s, sb, rb, c) $ do
    (a, carry) <- runStateT m BS.empty
    -- Bytes were received but no parser consumed them.
    unless (BS.null carry) $ throw $ Exception.InternalExtraData carry
    pure a

-- | Write the whole message to the connection's socket.
--
-- After the write completes, 'yield' gives other Haskell threads a chance
-- to run.
send :: BS.ByteString -> SocketIO ()
send message = do
  (sock, _, _, _) <- ask
  liftIO $ do
    NSB.sendAll sock message
    yield

-- | Serialise a 'BSB.Builder' and send it over the socket.
--
-- The builder runs directly into the connection's send buffer (the second
-- component of the environment). Each time the builder pauses:
--
-- * the bytes it wrote so far are sent;
-- * if it asks for more room than the current buffer offers, a larger
--   buffer is allocated and used for the rest of this call;
-- * if it hands over a ready-made chunk, that chunk is sent as-is.
buildAndSend :: BSB.Builder -> SocketIO ()
buildAndSend builder = do
  (_, Buffer fp len, _, _) <- ask
  go fp len $ BSB.runBuilder builder
  where
    go :: ForeignPtr Word8 -> Int -> BSB.BufferWriter -> SocketIO ()
    go bfp blen writer = do
      (wc, next) <- liftIO $ withForeignPtr bfp $ \ptr -> writer ptr blen
      -- The ByteString below aliases the buffer. That is only safe because
      -- 'send' (via 'NSB.sendAll') has written it out completely before the
      -- buffer is reused by the next step.
      -- Nothing written (e.g. a Chunk or a grow request came first): skip
      -- the empty socket write and the extra 'yield'.
      unless (wc == 0) $ send $ BSI.PS bfp 0 wc
      case next of
        BSB.Done -> pure ()
        BSB.More newLen w
          | newLen <= blen -> go bfp blen w
          | otherwise -> do
              -- The current buffer is too small for the next step; grow it.
              newFPtr <- liftIO $ BSI.mallocByteString newLen
              go newFPtr newLen w
        BSB.Chunk bs w -> do
          send bs
          go bfp blen w

-- | Parse one value from the socket.
--
-- Parsing starts from the leftover input @carry@. Whenever the parser needs
-- more input, more bytes are read into the receive buffer.
--
-- Returns the parsed value together with the unconsumed input that follows
-- it. When the parser fails, 'Exception.InternalResponseParsingFailed' is
-- thrown with the failure message (prefixed by the parser's context labels,
-- if any) and the unconsumed input. A read of 0 bytes (EOF) is handed to the
-- parser as end of input.
recvAndParse :: NS.Socket -> Buffer -> Carry -> AP.Parser response -> IO (response, Carry)
recvAndParse sock (Buffer bfptr blen) carry parser =
  withForeignPtr bfptr $ \bptr -> do
    let
      recv :: IO BS.ByteString
      recv = do
        len <- recvBuf sock bptr blen
        case len of
          0 -> pure BS.empty -- EOF
          -- Copy out: the receive buffer is overwritten by the next read,
          -- so the chunk handed to the parser must not alias it.
          _ -> pure $ BS.copy $ BSI.PS bfptr 0 len
    result <- AP.parseWith recv parser carry
    case result of
      AP.Done rest response -> pure (response, rest)
      AP.Fail rest ctxs msg ->
        throw $ Exception.InternalResponseParsingFailed (describeFailure ctxs msg) rest
      -- 'AP.parseWith' keeps refilling until it gets a final result; an
      -- empty refill (EOF) ends the input, so 'Partial' is not returned.
      AP.Partial _ -> Exception.cantReachHere
  where
    -- Render an attoparsec failure as "ctx1 > ctx2: msg", or just "msg"
    -- when the parser supplied no context labels.
    describeFailure :: [String] -> String -> String
    describeFailure [] msg = msg
    describeFailure ctxs msg = intercalate " > " ctxs <> ": " <> msg

-- | Parse one value from the socket with 'recvAndParse'.
--
-- Uses the receive buffer (third component of the environment). The
-- leftover input is taken from the 'Carry' state and written back to it.
-- Parse failures propagate as thrown exceptions.
receiveJust :: AP.Parser response -> SocketIO response
receiveJust parser = do
  carry <- get
  (sock, _, buff, _) <- ask
  (response, carry') <- liftIO $ recvAndParse sock buff carry parser
  put carry'
  pure response

-- | Like 'receiveJust', but turns an embedded error response into a
-- dedicated exception.
--
-- When parsing fails with 'Exception.InternalResponseParsingFailed', the
-- unparsed input is scanned with 'Parser.skipUntilError':
--
-- * If an 'Error' is found, 'Exception.InternalErrorResponse' is thrown
--   with its fields and the input after it.
-- * Otherwise the original exception is rethrown unchanged, as is any other
--   'Exception.InternalException'.
receive :: AP.Parser response -> SocketIO response
receive parser = do
  result <- try $ receiveJust parser
  case result of
    Right response -> pure response
    Left e@(Exception.InternalResponseParsingFailed _ raw) ->
      case AP.parse Parser.skipUntilError raw of
        AP.Done rest (Error fields) -> throw $ Exception.InternalErrorResponse fields Nothing rest
        AP.Fail {}                  -> throw e
        -- The raw input ends before an error could be recognised.
        AP.Partial _                -> throw e
    Left e -> throw (e :: Exception.InternalException)

-- | Read up to @nbytes@ bytes from the socket into @ptr@ and return how many
-- were read; @0@ means end of stream.
--
-- Before network 3.0.0.0, 'NS.recvBuf' raised an EOF 'IOError' instead of
-- returning 0, so older versions map that error to 0 here.
--
-- NOTE(review): the names the pre-3.0 branch uses ('tryJust', 'guard',
-- 'isEOFError') are imported under @MIN_VERSION_base(4,13,0)@ rather than
-- the network version. Confirm the build never combines base >= 4.13 with
-- network < 3, or align the two CPP conditions.
recvBuf :: NS.Socket -> Ptr Word8 -> Int -> IO Int
#if MIN_VERSION_network(3, 0, 0)
recvBuf = NS.recvBuf
#else
recvBuf s ptr nbytes = do
  r <- tryJust (guard . isEOFError) $ NS.recvBuf s ptr nbytes
  case r of
    Left _  -> pure 0
    Right l -> pure l
#endif