module Database.EventStore.Internal.Connection
( InternalConnection
, ConnectionException(..)
, connUUID
, connClose
, connSend
, connRecv
, connIsClosed
, newConnection
) where
import Control.Concurrent
import Control.Concurrent.STM
import Control.Exception
import qualified Data.ByteString as B
import Data.Foldable (for_)
import Data.IORef
import Data.Typeable
import Text.Printf
import Data.Serialize
import Data.UUID
import Data.UUID.V4
import Network.Connection
import Database.EventStore.Internal.Discovery
import Database.EventStore.Internal.Types
import Database.EventStore.Logging
-- | Failures raised by the low-level connection layer.
data ConnectionException
    = MaxAttemptConnectionReached
      -- ^ 'newState' used up its 'AtMost' retry budget without connecting.
    | ClosedConnection
      -- ^ An operation other than closing was attempted on a closed handle.
    | WrongPackageFraming
      -- ^ The length prefix of an incoming frame could not be decoded.
    | PackageParsingError String
      -- ^ An incoming frame body failed to decode; carries the message.
    deriving (Show, Typeable)

instance Exception ConnectionException
-- | Requests understood by 'execute'. The type index is the request's
-- result type, which lets one 'execute' function return a 'UUID' for 'Id'
-- and a 'Package' for 'Recv'.
data In a where
    -- Ask for the UUID tagging the current connection.
    Id    :: In UUID
    -- Close the connection; the handle stays closed afterwards.
    Close :: In ()
    -- Write one package to the connection.
    Send  :: Package -> In ()
    -- Read the next package from the connection.
    Recv  :: In Package
-- | Handle to a single server connection, shared by every caller of the
-- @conn*@ functions in this module.
data InternalConnection =
    InternalConnection
    { _var   :: TMVar State
      -- ^ Current 'State'. 'execute' takes it while it inspects or changes
      --   the state (including for the whole of a reconnection), so other
      --   callers wait on it during that time.
    , _last  :: IORef (Maybe EndPoint)
      -- ^ Endpoint of the last successful connection; passed to
      --   'runDiscovery' on the next connection attempt.
    , _disc  :: Discovery
      -- ^ Source of the endpoint to connect to.
    , _setts :: Settings
      -- ^ Retry policy, reconnect delay, TLS settings and logger.
    , _ctx   :: ConnectionContext
      -- ^ Context handed to every 'connectTo' call.
    }
-- | Lifecycle of the underlying socket, stored in '_var'.
data State
    = Offline
      -- ^ No live socket. The next 'execute' request other than 'Close'
      --   goes through 'newState' to connect.
    | Online !UUID !Connection
      -- ^ Live connection, tagged with the UUID generated in 'connect'.
    | Closed
      -- ^ Terminal. 'execute' never leaves this state and throws
      --   'ClosedConnection' for every request except 'Close'.
-- | Creates a connection handle in the 'Offline' state.
--
-- No socket is opened here. The first request that needs one ('connUUID',
-- 'connSend' or 'connRecv') runs endpoint discovery and connects.
newConnection :: Settings -> Discovery -> IO InternalConnection
newConnection setts disc = do
    ctx <- initConnectionContext
    InternalConnection
      <$> newTMVarIO Offline
      <*> newIORef Nothing
      <*> pure disc
      <*> pure setts
      <*> pure ctx
-- | UUID of the current connection. Connects first when 'Offline'; throws
-- 'ClosedConnection' once the handle has been closed.
connUUID :: InternalConnection -> IO UUID
connUUID = (`execute` Id)
-- | Closes the live socket (if any) and marks the handle 'Closed'.
-- Closing an already closed or never-connected handle is harmless.
connClose :: InternalConnection -> IO ()
connClose = (`execute` Close)
-- | Writes a package, connecting first when 'Offline'.
connSend :: InternalConnection -> Package -> IO ()
connSend conn = execute conn . Send
-- | Reads the next package, connecting first when 'Offline'.
connRecv :: InternalConnection -> IO Package
connRecv = (`execute` Recv)
-- | Whether the handle has reached the terminal 'Closed' state.
--
-- Uses 'readTMVar', so the transaction retries (blocks) while another
-- thread holds the state variable, e.g. during a reconnection in 'execute'.
connIsClosed :: InternalConnection -> STM Bool
connIsClosed InternalConnection{..} = isClosed <$> readTMVar _var
  where
    isClosed Closed = True
    isClosed _      = False
-- | Runs a single request against the shared connection state.
--
-- The state 'TMVar' is taken for the whole state transition, which may be a
-- long reconnection through 'newState'. It is always put back, even when a
-- synchronous or asynchronous exception interrupts the transition.
-- Otherwise every other caller would block forever on 'takeTMVar'.
--
-- * 'Close' closes any live socket and leaves the state 'Closed'.
-- * 'Id', 'Send' and 'Recv' reuse the live connection, or connect through
--   'newState' when 'Offline'. On a closed handle they throw
--   'ClosedConnection'. When 'newState' fails, the state becomes 'Closed'
--   and its error is thrown.
--
-- The actual 'send' / 'recv' I/O runs after the state has been put back, so
-- it does not hold the lock.
execute :: forall a. InternalConnection -> In a -> IO a
execute InternalConnection{..} i =
    case i of
      Close    -> closeConnection
      Id       -> fmap fst acquire
      Send pkg -> acquire >>= \(_, con) -> send con pkg
      Recv     -> acquire >>= \(_, con) -> recv con
  where
    putState :: State -> IO ()
    putState = atomically . putTMVar _var

    -- Close under 'mask_' so the state is set to 'Closed' even when
    -- 'connectionClose' throws or an async exception arrives mid-way.
    closeConnection :: IO ()
    closeConnection = mask_ $ do
        s <- atomically $ takeTMVar _var
        case s of
          Online _ con -> connectionClose con `finally` putState Closed
          _            -> putState Closed

    -- Yields the live connection, establishing one first when 'Offline'.
    acquire :: IO (UUID, Connection)
    acquire = mask $ \restore -> do
        s <- atomically $ takeTMVar _var
        case s of
          Closed -> do
            putState Closed
            throwIO ClosedConnection
          Online u con -> do
            putState s
            return (u, con)
          Offline -> do
            -- The reconnection may run for a long time, so let async
            -- exceptions in. If one (or any other exception) escapes,
            -- restore the 'Offline' state before rethrowing.
            res <- restore (newState _setts _ctx _last _disc)
                     `onException` putState Offline
            case res of
              Right st@(Online u con) -> do
                putState st
                return (u, con)
              Right st -> do
                -- 'newState' only produces 'Online' states; this branch
                -- exists only to keep the match total.
                putState st
                throwIO ClosedConnection
              Left e -> do
                putState Closed
                throwIO e
-- | Establishes a new connection, retrying according to 's_retry'.
--
-- Each attempt logs @Connecting i@, asks 'runDiscovery' for an endpoint
-- (passing the previously used one from @ref@), connects to it and records
-- it in @ref@. An attempt fails when discovery yields no endpoint or when a
-- synchronous exception is raised. After a failed attempt:
--
-- * @AtMost n@: once attempt @i@ satisfies @n <= i@, gives up immediately
--   with 'MaxAttemptConnectionReached' (no pointless final delay).
-- * @KeepRetrying@: never gives up.
--
-- Otherwise it waits 's_reconnect_delay_secs' seconds and tries again.
-- Asynchronous exceptions (e.g. 'killThread', 'System.Timeout.timeout')
-- are never counted as a failed attempt; they are rethrown so the calling
-- thread can be cancelled.
newState :: Settings
         -> ConnectionContext
         -> IORef (Maybe EndPoint)
         -> Discovery
         -> IO (Either ConnectionException State)
newState sett ctx ref disc = go 1
  where
    delay = s_reconnect_delay_secs sett * secs

    secs :: Int
    secs = 1000000

    -- True once attempt @i@ was the last one allowed by the retry policy.
    exhausted :: Int -> Bool
    exhausted i =
      case s_retry sett of
        AtMost n     -> n <= i
        KeepRetrying -> False

    -- Retry loop. It recurses outside any exception handler, so attempts
    -- never run masked and handler frames do not pile up.
    go :: Int -> IO (Either ConnectionException State)
    go i = do
      _settingsLog sett (Info $ Connecting i)
      outcome <- tryAttempt
      case outcome of
        Just st -> return (Right st)
        Nothing
          | exhausted i -> return (Left MaxAttemptConnectionReached)
          | otherwise   -> threadDelay delay >> go (i + 1)

    -- One attempt; synchronous failures become 'Nothing'.
    tryAttempt :: IO (Maybe State)
    tryAttempt = do
      r <- try attempt
      case r of
        Right st -> return st
        Left e
          | isAsync e -> throwIO e
          | otherwise -> return Nothing

    attempt :: IO (Maybe State)
    attempt = do
      old <- readIORef ref
      ept_opt <- runDiscovery disc old
      case ept_opt of
        Nothing  -> return Nothing
        Just ept -> do
          st <- connect sett ctx (endPointIp ept) (endPointPort ept)
          writeIORef ref (Just ept)
          return (Just st)

    isAsync :: SomeException -> Bool
    isAsync e =
      case fromException e :: Maybe SomeAsyncException of
        Just _  -> True
        Nothing -> False
-- | Opens a connection to @host:port@, with the secure-transport settings
-- taken from 's_ssl'. The connection is tagged with a freshly generated
-- random UUID, and a 'Connected' event is logged.
connect :: Settings -> ConnectionContext -> String -> Int -> IO State
connect sett ctx host port = do
    conn <- connectTo ctx ConnectionParams
      { connectionHostname  = host
      , connectionPort      = fromIntegral port
      , connectionUseSecure = s_ssl sett
      , connectionUseSocks  = Nothing
      }
    uuid <- nextRandom
    _settingsLog sett (Info $ Connected uuid)
    pure (Online uuid conn)
-- | Reads and decodes one frame from the connection.
--
-- A frame is a 4-byte length prefix (decoded by 'getLengthPrefix'), followed
-- by that many bytes of package body (decoded by 'getPackage').
--
-- 'connectionGet' returns as soon as /some/ data is available, which may be
-- fewer bytes than requested. Both reads therefore go through 'recvExact',
-- which keeps reading until the requested count is reached or the peer
-- closes the stream.
--
-- Throws 'WrongPackageFraming' when the prefix cannot be read or decoded.
-- Throws 'PackageParsingError' when the body is truncated or fails to decode.
recv :: Connection -> IO Package
recv con = do
    header_bs <- recvExact 4
    -- A short header (peer closed) makes the decoder fail as well.
    case runGet getLengthPrefix header_bs of
      Left _ -> throwIO WrongPackageFraming
      Right length_prefix -> do
        bs <- recvExact length_prefix
        if B.length bs < length_prefix
          then throwIO $ PackageParsingError $
                 printf "TCP: connection closed after %d of %d package bytes"
                        (B.length bs) length_prefix
          else case runGet getPackage bs of
                 Left e    -> throwIO $ PackageParsingError e
                 Right pkg -> return pkg
  where
    -- Reads exactly @n@ bytes, or fewer only if the peer closed the stream
    -- (signalled by 'connectionGet' returning an empty chunk).
    recvExact :: Int -> IO B.ByteString
    recvExact n = go n []
      where
        go missing acc
          | missing <= 0 = return (B.concat (reverse acc))
          | otherwise = do
              chunk <- connectionGet con missing
              if B.null chunk
                then return (B.concat (reverse acc))
                else go (missing - B.length chunk) (chunk : acc)
-- | Serialises a package with 'putPackage' and writes the whole frame to
-- the connection.
send :: Connection -> Package -> IO ()
send con = connectionPut con . runPut . putPackage
-- | Serialises a 'Package' into a length-prefixed frame.
--
-- Layout: a little-endian 32-bit length of everything that follows, then
-- the command byte, a flag byte (0x01 when credentials are attached, 0x00
-- otherwise), the correlation UUID bytes, the optional credentials (login
-- then password, each as a one-byte length followed by the bytes), and
-- finally the payload.
putPackage :: Package -> Put
putPackage pack = do
    putWord32le frameLen
    putWord8 (packageCmd pack)
    putWord8 flagByte
    putLazyByteString (toByteString (packageCorrelation pack))
    for_ creds putCreds
    putByteString payload
  where
    creds   = packageCred pack
    payload = packageData pack

    flagByte =
      case creds of
        Nothing -> 0x00
        Just _  -> 0x01

    frameLen =
      fromIntegral (mandatorySize + B.length payload + maybe 0 credSize creds)

    putCreds (Credentials login passw) = do
      putShortBytes login
      putShortBytes passw

    -- NOTE(review): the length is narrowed to one byte with 'fromIntegral',
    -- so an entry longer than 255 bytes yields a malformed frame.
    putShortBytes bs = do
      putWord8 (fromIntegral (B.length bs))
      putByteString bs
-- | Encoded size of a credentials section. Login and password are each
-- written as a one-byte length followed by their bytes.
credSize :: Credentials -> Int
credSize (Credentials login passw) =
    sum [1 + B.length entry | entry <- [login, passw]]
-- | Size of the header fields present in every frame body: command byte,
-- flag byte and the 16-byte correlation UUID.
mandatorySize :: Int
mandatorySize = cmdBytes + flagBytes + correlationBytes
  where
    cmdBytes         = 1
    flagBytes        = 1
    correlationBytes = 16
-- | Decodes the little-endian 32-bit frame length prefix.
getLengthPrefix :: Get Int
getLengthPrefix = fromIntegral <$> getWord32le
-- | Decodes a frame body: command byte, flag, correlation UUID, optional
-- credentials (present unless the flag is 'None'), then all remaining bytes
-- as the payload.
getPackage :: Get Package
getPackage = do
    cmd     <- getWord8
    flg     <- getFlag
    corr    <- getUUID
    cred    <- getCredentials flg
    payload <- getBytes =<< remaining
    return Package
      { packageCmd         = cmd
      , packageCorrelation = corr
      , packageData        = payload
      , packageCred        = cred
      }
-- | Decodes the flag byte: 0x00 is 'None', 0x01 is 'Authenticated'.
-- Any other value fails the decoder.
getFlag :: Get Flag
getFlag = getWord8 >>= toFlag
  where
    toFlag 0x00 = return None
    toFlag 0x01 = return Authenticated
    toFlag wd   = fail $ printf "TCP: Unhandled flag value 0x%x" wd
-- | Decodes the one-byte length that precedes each credential entry.
getCredEntryLength :: Get Int
getCredEntryLength = fromIntegral <$> getWord8
-- | Decodes the credentials section when the flag is anything but 'None':
-- a length-prefixed login followed by a length-prefixed password.
getCredentials :: Flag -> Get (Maybe Credentials)
getCredentials flg =
    case flg of
      None -> return Nothing
      _    -> do
        login <- getEntry
        passw <- getEntry
        return (Just (credentials login passw))
  where
    getEntry = getCredEntryLength >>= getBytes
-- | Decodes a 16-byte UUID, failing when the bytes are not a valid UUID.
getUUID :: Get UUID
getUUID =
    getLazyByteString 16
      >>= maybe (fail "TCP: Wrong UUID format") return . fromByteString