{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
module Database.Bolt.Connection.Pipe where

import           Database.Bolt.Connection.Instances
import           Database.Bolt.Connection.Type
import           Database.Bolt.Value.Instances
import           Database.Bolt.Value.Type
import qualified Database.Bolt.Connection.Connection as C (close, connect, recv,
                                                           send, sendMany)

import           Control.Exception                   (throwIO)
import           Control.Monad                       (forM_, unless, void, when)
import           Control.Monad.Except                (MonadError (..), ExceptT, runExceptT)
import           Control.Monad.Trans                 (MonadIO (..))
import           Data.ByteString                     (ByteString)
import qualified Data.ByteString                     as B (concat, length, null,
                                                           splitAt)
import           Data.Word                           (Word16)
import           GHC.Stack                           (HasCallStack)

-- | Constraint synonym for the monads the internal pipe operations run in:
-- they perform I/O through 'MonadIO' and report protocol failures as
-- 'BoltError' values through 'MonadError' (e.g. via 'throwError').
type MonadPipe m = (MonadIO m, MonadError BoltError m)

-- | Opens a connection described by the 'BoltCfg', performs the Bolt
-- handshake on it and returns a 'Pipe' through which all subsequent requests
-- are sent.
--
-- A 'BoltError' raised along the way is rethrown as an 'IO' exception by
-- 'makeIO'. If the handshake fails with a 'BoltError', the freshly opened
-- connection is closed before the error propagates, so the socket is not
-- leaked. Exceptions thrown directly in 'IO' by the connection layer are not
-- intercepted here.
connect :: MonadIO m => HasCallStack => BoltCfg -> m Pipe
connect = makeIO connect'
  where
    connect' :: MonadPipe m => BoltCfg -> m Pipe
    connect' bcfg = do
      conn <- C.connect (secure bcfg) (host bcfg) (fromIntegral $ port bcfg) (socketTimeout bcfg)
      let pipe = Pipe conn (maxChunkSize bcfg)
      -- The caller never receives 'pipe' when the handshake fails, so nobody
      -- else could close it: release the connection before rethrowing.
      handshake pipe bcfg `catchError` \e -> C.close conn >> throwError e
      pure pipe

-- | Closes the connection underlying the given 'Pipe'.
close :: MonadIO m => HasCallStack => Pipe -> m ()
close pipe = C.close (connection pipe)

-- | Resets the current session: sends 'RequestReset' and reads the server's
-- reply. If that reply is a failure, 'ResetFailed' is raised, surfacing as an
-- 'IO' exception through 'makeIO'.
reset :: MonadIO m => HasCallStack => Pipe -> m ()
reset = makeIO reset'
  where
    reset' :: MonadPipe m => Pipe -> m ()
    reset' pipe = do
      flush pipe RequestReset
      failed <- isFailure <$> fetch pipe
      when failed $ throwError ResetFailed

-- | Runs a pipe action in 'ExceptT' over the caller's monad. A 'BoltError'
-- result is rethrown as an 'IO' exception with 'throwIO'; a successful result
-- is returned unchanged.
makeIO :: MonadIO m => HasCallStack => (a -> ExceptT BoltError m b) -> a -> m b
makeIO action arg = runExceptT (action arg) >>= either (liftIO . throwIO) pure

-- = Internal interfaces

-- | Sends 'RequestAckFailure' and reads the server's reply; the decoded
-- response itself is discarded.
ackFailure :: MonadPipe m => HasCallStack => Pipe -> m ()
ackFailure pipe = do
  flush pipe RequestAckFailure
  _ <- fetch pipe
  pure ()

-- | Sends 'RequestDiscardAll' and reads the server's reply; the decoded
-- response itself is discarded.
discardAll :: MonadPipe m => HasCallStack => Pipe -> m ()
discardAll pipe = do
  flush pipe RequestDiscardAll
  _ <- fetch pipe
  pure ()

-- | Serialises a 'Request' and writes it to the pipe's connection in chunked
-- form. The packed message is cut with 'split' into pieces whose length is
-- chosen by 'chunkSizeFor' from the pipe's 'mcs'. Each piece is sent with
-- its byte length encoded as a 'Word16' header in front, and a zero
-- 'Word16' marker is sent last to terminate the message.
flush :: MonadPipe m => HasCallStack => Pipe -> Request -> m ()
flush pipe request = do
    mapM_ (C.sendMany conn . frame) pieces
    C.send conn endMarker
  where
    conn      = connection pipe
    payload   = pack (toStructure request)
    pieces    = split (chunkSizeFor (mcs pipe) payload) payload
    endMarker = encodeStrict (0 :: Word16)

    -- Length header followed by the chunk body, sent together in one call.
    frame :: ByteString -> [ByteString]
    frame piece = [encodeStrict (fromIntegral (B.length piece) :: Word16), piece]

-- | Reads one chunked message from the pipe's connection and decodes it into
-- a 'Response'. Each chunk is read as a two-byte 'Word16' length followed by
-- that many bytes. Reading stops at the first empty chunk. A payload that
-- cannot be unpacked into a 'Response' is reported as 'WrongMessageFormat'.
fetch :: MonadPipe m => HasCallStack => Pipe -> m Response
fetch pipe = do
    message <- B.concat <$> readChunks
    decoded <- unpackAction (unpackT >>= fromStructure) message
    case decoded of
      Left  err  -> throwError (WrongMessageFormat err)
      Right resp -> pure resp
  where
    conn = connection pipe

    -- Collects chunk bodies up to, but not including, the empty terminator.
    readChunks :: MonadPipe m => m [ByteString]
    readChunks = do
      header <- recvChunk conn 2
      body   <- recvChunk conn (decodeStrict header)
      if B.null body
        then pure []
        else (body :) <$> readChunks

-- Helper functions

-- | Performs the Bolt handshake on a freshly opened pipe, in this order:
--
-- 1. Sends the configured 'magic' value and the 'boltVersionProposal'.
-- 2. Reads the four-byte version chosen by the server and raises
--    'UnsupportedServerVersion' unless it equals the configured 'version'.
-- 3. Sends the request built by 'createInit' and raises
--    'AuthentificationFailed' unless the reply is a success.
handshake :: MonadPipe m => HasCallStack => Pipe -> BoltCfg -> m ()
handshake pipe bcfg = do
    C.send conn (encodeStrict (magic bcfg))
    C.send conn (boltVersionProposal bcfg)
    agreed <- decodeStrict <$> recvChunk conn 4
    when (agreed /= version bcfg) $ throwError UnsupportedServerVersion
    flush pipe (createInit bcfg)
    accepted <- isSuccess <$> fetch pipe
    unless accepted $ throwError AuthentificationFailed
  where
    conn = connection pipe

-- | Encodes the four version slots sent during the handshake: the configured
-- 'version' comes first, followed by three zero slots.
boltVersionProposal :: BoltCfg -> ByteString
boltVersionProposal bcfg = B.concat (map encodeStrict slots)
  where
    slots = [version bcfg, 0, 0, 0]

-- | Reads exactly @size@ bytes from the connection. It issues as many
-- 'C.recv' calls as needed and concatenates the pieces.
--
-- Throws 'CannotReadChunk' when 'C.recv' yields 'Nothing', or when it makes
-- no progress (returns an empty chunk while bytes are still outstanding).
-- Without the latter check the loop would spin forever. An empty read is
-- presumably how the underlying connection reports end-of-stream; confirm
-- against 'C.recv'.
recvChunk :: MonadPipe m => HasCallStack => ConnectionWithTimeout -> Word16 -> m ByteString
recvChunk conn size = B.concat <$> helper (fromIntegral size)
  where
    helper :: MonadPipe m => Int -> m [ByteString]
    helper sz
      -- '<=' rather than '== 0': a recv that returns more than requested
      -- would otherwise push the count below zero and never terminate.
      | sz <= 0   = pure []
      | otherwise = do
          mbChunk <- C.recv conn sz
          case mbChunk of
            Just chunk
              | B.null chunk -> throwError CannotReadChunk
              | otherwise    -> (chunk :) <$> helper (sz - B.length chunk)
            Nothing -> throwError CannotReadChunk

-- | Chooses the piece length used to split the payload @bs@ into chunks.
-- With @len = B.length bs@ and @parts = 1 + len `div` maxSize@, the result
-- is @1 + len `div` parts@. For @maxSize >= 1@ this never exceeds @maxSize@,
-- so every piece fits in its 'Word16' length header.
--
-- A @maxSize@ of 0 is treated as 1; without that clamp the computation would
-- divide by zero. The result is always at least 1.
chunkSizeFor :: Word16 -> ByteString -> Int
chunkSizeFor maxSize bs = 1 + len `div` parts
  where
    len   = B.length bs
    limit = max 1 (fromIntegral maxSize)
    parts = 1 + len `div` limit

-- | Cuts a 'ByteString' into consecutive pieces of @size@ bytes. The last
-- piece holds whatever remains and may be shorter. An empty input yields no
-- pieces.
--
-- @size@ must be positive. With @size <= 0@, 'B.splitAt' makes no progress
-- and the result is an infinite list of empty pieces.
split :: Int -> ByteString -> [ByteString]
split size = go
  where
    go bytes
      | B.null bytes = []
      | otherwise    = let (piece, remainder) = B.splitAt size bytes
                       in  piece : go remainder