{-# LANGUAGE CPP, DeriveDataTypeable, RankNTypes, RecordWildCards, ScopedTypeVariables #-}
module Data.Acid.Remote
(
acidServer
, acidServerSockAddr
, acidServer'
, openRemoteState
, openRemoteStateSockAddr
, skipAuthenticationCheck
, skipAuthenticationPerform
, sharedSecretCheck
, sharedSecretPerform
, AcidRemoteException(..)
, CommChannel(..)
, process
, processRemoteState
) where
import Prelude hiding ( catch )
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
import Control.Concurrent.STM ( atomically )
import Control.Concurrent.STM.TMVar ( newEmptyTMVar, readTMVar, takeTMVar, tryTakeTMVar, putTMVar )
import Control.Concurrent.STM.TQueue
import Control.Exception ( AsyncException(ThreadKilled)
, Exception(fromException), IOException, Handler(..)
, SomeException, catch, catches, throw, bracketOnError )
import Control.Exception ( throwIO, finally )
import Control.Monad ( forever, liftM, join, when )
import Control.Concurrent ( ThreadId, forkIO, threadDelay, killThread, myThreadId )
import Control.Concurrent.MVar ( MVar, newEmptyMVar, putMVar, takeMVar )
import Control.Concurrent.Chan ( newChan, readChan, writeChan )
import Data.Acid.Abstract
import Data.Acid.Core
import Data.Acid.Common
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid ((<>))
#endif
import qualified Data.ByteString as Strict
import Data.ByteString.Char8 ( pack )
import qualified Data.ByteString.Lazy as Lazy
import Data.IORef ( newIORef, readIORef, writeIORef )
import Data.Serialize
import Data.Set ( Set, member )
import Data.Typeable ( Typeable )
import GHC.IO.Exception ( IOErrorType(..) )
import Network.BSD ( PortNumber, getProtocolNumber, getHostByName, hostAddress )
import Network.Socket
import Network.Socket.ByteString as NSB ( recv, sendAll )
import System.Directory ( removeFile )
import System.IO ( Handle, hPrint, hFlush, hClose, stderr, IOMode(..) )
import System.IO.Error ( ioeGetErrorType, isFullError, isDoesNotExistError )
-- | Debug tracing hook used by the remote client and server code.
--
-- Currently a no-op: the message is never evaluated or printed, and the
-- action simply returns @()@.
debugStrLn :: String -> IO ()
debugStrLn _msg = pure ()
-- | A minimal byte-oriented transport used by both the client and the
-- server side of the remote protocol.
--
-- This module builds one from a 'Handle' ('handleToCommChannel') or from a
-- 'Socket' ('socketToCommChannel').
data CommChannel = CommChannel
  { ccPut     :: Strict.ByteString -> IO ()
    -- ^ Send the given bytes to the peer.
  , ccGetSome :: Int -> IO (Strict.ByteString)
    -- ^ Receive at most the requested number of bytes. 'process' treats an
    -- empty result as end of input.
  , ccClose   :: IO ()
    -- ^ Close the underlying transport.
  }
-- | Exceptions raised by the remote-access layer.
data AcidRemoteException
  = RemoteConnectionError
    -- ^ NOTE(review): presumably signals a broken connection; confirm at use sites.
  | AcidStateClosed
    -- ^ A command was submitted after the remote state had been shut down.
  | SerializeError String
    -- ^ The incoming byte stream could not be decoded; carries the decoder's message.
  | AuthenticationError String
    -- ^ The authentication handshake was rejected; carries a description.
  deriving (Eq, Show, Typeable)

instance Exception AcidRemoteException
-- | Wrap an open 'Handle' as a 'CommChannel'.
--
-- Each write is followed by 'hFlush', so a message is not left sitting in
-- the handle's buffer. Reads use 'Strict.hGetSome', which returns whatever
-- is available up to the requested size. Closing the channel closes the
-- handle.
handleToCommChannel :: Handle -> CommChannel
handleToCommChannel h = CommChannel
  { ccPut     = \bytes -> Strict.hPut h bytes *> hFlush h
  , ccGetSome = \n -> Strict.hGetSome h n
  , ccClose   = hClose h
  }
-- | Wrap a connected stream 'Socket' as a 'CommChannel'.
--
-- Writes use 'NSB.sendAll', so the whole buffer is sent. Reads use a single
-- 'NSB.recv' of at most the requested size. Closing the channel closes the
-- socket.
socketToCommChannel :: Socket -> CommChannel
socketToCommChannel sock = CommChannel
  { ccPut     = \bytes -> NSB.sendAll sock bytes
  , ccGetSome = \n -> NSB.recv sock n
  , ccClose   = close sock
  }
-- | Server-side authentication check that accepts every client.
--
-- Nothing is read from or written to the channel.
skipAuthenticationCheck :: CommChannel -> IO Bool
skipAuthenticationCheck = const (pure True)
-- | Client-side counterpart of 'skipAuthenticationCheck'.
--
-- Performs no handshake and leaves the channel untouched.
skipAuthenticationPerform :: CommChannel -> IO ()
skipAuthenticationPerform = const (pure ())
-- | Server-side shared-secret authentication check.
--
-- Reads one chunk of at most 1024 bytes from the client. The client is
-- accepted only if that chunk is exactly one of @secrets@. Before the
-- verdict is returned, the client is told the outcome with @\"OK\"@ or
-- @\"FAIL\"@. This matches what 'sharedSecretPerform' expects.
--
-- NOTE(review): the secret is taken from a single 'ccGetSome' call. A secret
-- longer than 1024 bytes, or one the transport delivers in several pieces,
-- will therefore not match.
sharedSecretCheck :: Set Strict.ByteString
                  -> (CommChannel -> IO Bool)
sharedSecretCheck secrets cc = do
  offered <- ccGetSome cc 1024
  let accepted = offered `member` secrets
  ccPut cc (pack (if accepted then "OK" else "FAIL"))
  return accepted
-- | Client-side shared-secret authentication.
--
-- Sends @pw@, then reads the server's reply (at most 1024 bytes). Any reply
-- other than exactly @\"OK\"@ raises 'AuthenticationError'.
sharedSecretPerform :: Strict.ByteString
                    -> (CommChannel -> IO ())
sharedSecretPerform pw cc = do
  ccPut cc pw
  reply <- ccGetSome cc 1024
  when (reply /= pack "OK") $
    throwIO (AuthenticationError "shared secret authentication failed.")
-- | Serve @acidState@ to remote clients on the given socket address.
--
-- A listening socket is opened with 'listenOn' and handed to 'acidServer''.
-- However the server exits, the listening socket is closed. On non-Windows
-- platforms the socket file of a Unix-domain address is also removed.
acidServerSockAddr :: (CommChannel -> IO Bool) -- ^ authentication check run for each client
                   -> SockAddr                 -- ^ address to listen on
                   -> AcidState st             -- ^ state to serve
                   -> IO ()
acidServerSockAddr checkAuth sockAddr acidState = do
  listenSocket <- listenOn sockAddr
  acidServer' checkAuth listenSocket acidState `finally` releaseListener listenSocket
  where
    releaseListener :: Socket -> IO ()
    releaseListener sock = do
      close sock
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
      case sockAddr of
        SockAddrUnix path -> removeFile path
        _                 -> pure ()
#endif
-- | Serve @acidState@ to remote clients over IPv4 TCP on the given port.
--
-- The server listens on every local interface: host address 0 is
-- INADDR_ANY.
acidServer :: (CommChannel -> IO Bool) -- ^ authentication check run for each client
           -> PortNumber               -- ^ TCP port to listen on
           -> AcidState st             -- ^ state to serve
           -> IO ()
acidServer checkAuth port = acidServerSockAddr checkAuth anyInterface
  where
    anyInterface = SockAddrInet port 0
-- | Open a listening stream socket bound to the given address.
--
-- The address family follows the 'SockAddr' constructor. IP addresses use
-- the \"tcp\" protocol number. On non-Windows platforms a Unix-domain
-- address uses protocol 0. @SO_REUSEADDR@ is set before binding. If any
-- setup step throws, the new socket is closed again.
listenOn :: SockAddr -> IO Socket
listenOn sockAddr = do
  proto <- protocolForAddr
  bracketOnError (socket familyForAddr Stream proto) close $ \sock -> do
    setSocketOption sock ReuseAddr 1
    bind sock sockAddr
    listen sock maxListenQueue
    return sock
  where
    protocolForAddr :: IO ProtocolNumber
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    protocolForAddr = case sockAddr of
      SockAddrUnix {} -> pure 0
      _               -> getProtocolNumber "tcp"
#else
    protocolForAddr = getProtocolNumber "tcp"
#endif

    familyForAddr :: Family
    familyForAddr = case sockAddr of
      SockAddrInet {}  -> AF_INET
      SockAddrInet6 {} -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
      SockAddrUnix {}  -> AF_UNIX
#endif
-- | Accept connections on an already-listening socket and serve each one on
-- its own thread.
--
-- Each accepted client is first checked with @checkAuth@. If the check
-- succeeds, its requests are handled by 'process'. The connection is closed
-- when its handler thread finishes, including when @checkAuth@ or 'process'
-- throws.
--
-- Some 'IOException's raised by the accept loop are ignored and the loop
-- restarts: resource exhausted ('isFullError'), does-not-exist, and
-- resource-vanished errors. Any other exception from the accept loop
-- escapes and stops the server.
acidServer' :: (CommChannel -> IO Bool) -- ^ authentication check run for each client
            -> Socket                   -- ^ socket that is already listening
            -> AcidState st             -- ^ state to serve
            -> IO ()
acidServer' checkAuth listenSocket acidState = serveForever
  where
    serveForever :: IO ()
    serveForever = ignoringBenignErrors acceptLoop >> serveForever

    acceptLoop :: IO ()
    acceptLoop = forever $ do
      (sock, _peer) <- accept listenSocket
      let commChannel = socketToCommChannel sock
      -- Close the connection however the handler ends. Without 'finally',
      -- an exception from authentication or 'process' would leak the socket.
      _ <- forkIO (serve commChannel `finally` ccClose commChannel)
      return ()

    serve :: CommChannel -> IO ()
    serve commChannel = do
      authorized <- checkAuth commChannel
      when authorized $ process commChannel acidState

    -- Benign errors are dropped without logging; a busy server could see
    -- thousands of them.
    ignoringBenignErrors :: IO () -> IO ()
    ignoringBenignErrors op =
      op `catch` \(e :: IOException) ->
        if isBenign e then return () else throwIO e

    isBenign :: IOException -> Bool
    isBenign e = isFullError e || isDoesNotExistError e || isResourceVanished e

    isResourceVanished :: IOException -> Bool
    isResourceVanished e = case ioeGetErrorType e of
      ResourceVanished -> True
      _                -> False
-- | Requests sent from a remote client to the server.
--
-- Query and update payloads stay in their 'Tagged' lazy-'Lazy.ByteString'
-- form. They are interpreted on the server by 'queryCold' and
-- 'scheduleColdUpdate'.
data Command = RunQuery (Tagged Lazy.ByteString)
             | RunUpdate (Tagged Lazy.ByteString)
             | CreateCheckpoint
             | CreateArchive

-- | Wire format: a single tag byte, optionally followed by a payload.
--
-- * 0 = 'RunQuery', followed by the serialised tagged payload.
-- * 1 = 'RunUpdate', followed by the serialised tagged payload.
-- * 2 = 'CreateCheckpoint'.
-- * 3 = 'CreateArchive'.
instance Serialize Command where
  put cmd = case cmd of
    RunQuery query   -> putWord8 0 >> put query
    RunUpdate update -> putWord8 1 >> put update
    CreateCheckpoint -> putWord8 2
    CreateArchive    -> putWord8 3

  -- The tag comes from the peer, so an unknown value is a decoding failure.
  -- It is reported with 'fail' rather than 'error', so 'runGetPartial'
  -- yields 'Fail' and the server raises 'SerializeError' instead of an
  -- 'ErrorCall'.
  get = do
    tag <- getWord8
    case tag of
      0 -> liftM RunQuery get
      1 -> liftM RunUpdate get
      2 -> return CreateCheckpoint
      3 -> return CreateArchive
      _ -> fail $ "Data.Acid.Remote: Serialize.get for Command, invalid tag: " ++ show tag
-- | Replies to a 'Command'.
data Response = Result Lazy.ByteString
                -- ^ Encoded result of a 'RunQuery' or 'RunUpdate'.
              | Acknowledgement
                -- ^ Reply to 'CreateCheckpoint' and 'CreateArchive'.
              | ConnectionError
                -- ^ Produced on the client to fail outstanding requests
                -- when the connection is dropped.

-- | Wire format: a single tag byte, optionally followed by a payload.
--
-- * 0 = 'Result', followed by the serialised payload.
-- * 1 = 'Acknowledgement'.
-- * 2 = 'ConnectionError'.
instance Serialize Response where
  put resp = case resp of
    Result result   -> putWord8 0 >> put result
    Acknowledgement -> putWord8 1
    ConnectionError -> putWord8 2

  -- An unknown tag is reported with 'fail' so that 'runGetPartial' yields
  -- 'Fail' instead of throwing an 'ErrorCall' from inside the decoder.
  get = do
    tag <- getWord8
    case tag of
      0 -> liftM Result get
      1 -> return Acknowledgement
      2 -> return ConnectionError
      _ -> fail $ "Data.Acid.Remote: Serialize.get for Response, invalid tag: " ++ show tag
-- | Serve one client connection.
--
-- Decodes 'Command's from the channel, runs them against @acidState@, and
-- streams 'Response's back.
--
-- Commands run on the calling thread, in arrival order. For each command an
-- @IO Response@ action is pushed onto a 'Chan'. A separate writer thread
-- runs those actions in the same order and sends the encoded responses, so
-- waiting for an update's result never stalls reading further commands.
--
-- Returns when 'ccGetSome' yields an empty chunk. Throws 'SerializeError'
-- if the input cannot be decoded.
--
-- NOTE(review): the writer thread is not stopped when this function
-- returns.
process :: CommChannel  -- ^ connection to read commands from and write responses to
        -> AcidState st -- ^ state the commands are run against
        -> IO ()
process cc acidState = do
  replies <- newChan
  _ <- forkIO $ forever $ do
    reply <- join (readChan replies)
    ccPut cc (encode reply)
  consume replies (runGetPartial get Strict.empty)
  where
    -- Drive the incremental decoder, reading more input whenever it asks.
    consume _ (Fail msg _) = throwIO (SerializeError msg)
    consume replies (Partial resume) = do
      chunk <- ccGetSome cc 1024
      if Strict.null chunk
        then return ()
        else consume replies (resume chunk)
    consume replies (Done cmd rest) = do
      execute replies cmd
      consume replies (runGetPartial get rest)

    -- Run one command and queue the action that yields its response.
    execute replies cmd = case cmd of
      RunQuery query -> do
        answer <- queryCold acidState query
        writeChan replies (return (Result answer))
      RunUpdate update -> do
        pendingAnswer <- scheduleColdUpdate acidState update
        writeChan replies (Result <$> takeMVar pendingAnswer)
      CreateCheckpoint -> do
        createCheckpoint acidState
        writeChan replies (return Acknowledgement)
      CreateArchive -> do
        createArchive acidState
        writeChan replies (return Acknowledgement)
-- | Client-side view of a remote acid-state. It holds two things:
--
-- * a function that submits a 'Command' and returns the 'MVar' that will
--   receive its 'Response';
-- * the action that shuts the connection down.
data RemoteState st
  = RemoteState (Command -> IO (MVar Response)) (IO ())
  deriving (Typeable)
-- | Connect to a remote acid-state server over IPv4 TCP.
--
-- @host@ is resolved with 'getHostByName', and the connection is delegated
-- to 'openRemoteStateSockAddr'.
openRemoteState :: IsAcidic st =>
                   (CommChannel -> IO ()) -- ^ client-side authentication, e.g. 'sharedSecretPerform'
                -> HostName               -- ^ server host name
                -> PortNumber             -- ^ server port
                -> IO (AcidState st)
openRemoteState performAuthorization host port = do
  entry <- getHostByName host
  let serverAddr = SockAddrInet port (hostAddress entry)
  openRemoteStateSockAddr performAuthorization serverAddr
-- | Connect to a remote acid-state server at an arbitrary 'SockAddr'.
--
-- The address may be IPv4, IPv6, or (outside Windows) a Unix-domain socket.
--
-- @performAuthorization@ runs on every freshly opened connection, both at
-- start-up and after each reconnect. A connection attempt that fails with an
-- 'IOError' is retried after one second, indefinitely. Any other exception
-- propagates to the caller, for example an 'AuthenticationError' from
-- 'sharedSecretPerform'. If authorization throws, the just-opened
-- connection is closed first.
openRemoteStateSockAddr :: IsAcidic st =>
                           (CommChannel -> IO ()) -- ^ client-side authentication, run on each new connection
                        -> SockAddr               -- ^ server address
                        -> IO (AcidState st)
openRemoteStateSockAddr performAuthorization sockAddr
  = withSocketsDo $ processRemoteState reconnect
  where
    af :: Family
    af = case sockAddr of
      (SockAddrInet {})  -> AF_INET
      (SockAddrInet6 {}) -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
      (SockAddrUnix {})  -> AF_UNIX
#endif

    reconnect :: IO CommChannel
    reconnect = connectOnce `catch` retry
      where
        retry :: IOError -> IO CommChannel
        retry _ = threadDelay 1000000 >> reconnect

    connectOnce :: IO CommChannel
    connectOnce = do
      debugStrLn "Reconnecting."
      proto <- lookupProtocol
      h <- bracketOnError (socket af Stream proto) close $ \sock -> do
        connect sock sockAddr
        socketToHandle sock ReadWriteMode
      -- From here on the handle owns the socket. If authorization throws,
      -- close it before rethrowing. Otherwise every failed or rejected
      -- attempt (including the IOError retries above) would leak a
      -- connection.
      cc <- bracketOnError (return (handleToCommChannel h)) ccClose $ \chan -> do
        performAuthorization chan
        return chan
      debugStrLn "Reconnected."
      return cc

    lookupProtocol :: IO ProtocolNumber
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    lookupProtocol = case sockAddr of
      SockAddrUnix {} -> pure 0
      _               -> getProtocolNumber "tcp"
#else
    lookupProtocol = getProtocolNumber "tcp"
#endif
-- | Build a client-side 'AcidState' on top of a (re)connect action.
--
-- How a command travels:
--
-- * 'actor' puts the command and a fresh result 'MVar' on a command queue.
-- * The actor thread sends the command on the current connection and
--   appends a callback that fills that 'MVar' to the connection's listen
--   queue. It does both in one STM transaction.
-- * The listener thread decodes 'Response's. It hands each one to the
--   oldest callback in the listen queue.
--
-- The listener may fail, or a send may fail, with any exception other than
-- 'ThreadKilled'. In that case the current connection is torn down, every
-- outstanding callback receives 'ConnectionError', and a new connection is
-- obtained from @reconnect@.
--
-- NOTE(review): 'shutdown' waits on 'takeTMVar', so it blocks while a
-- reconnect is in progress.
processRemoteState :: IsAcidic st =>
                      IO CommChannel -- ^ (re)connect action, run at start-up and after each failure
                   -> IO (AcidState st)
processRemoteState reconnect = do
  cmdQueue <- atomically newTQueue
  ccTMV    <- atomically newEmptyTMVar
  isClosed <- newIORef False

  let -- Queue a command for the actor thread and return the MVar that will
      -- receive its Response. Refuses new work once 'shutdown' has run.
      actor :: Command -> IO (MVar Response)
      actor command = do
        debugStrLn "actor: begin."
        closed <- readIORef isClosed
        when closed $ throwIO AcidStateClosed
        ref <- newEmptyMVar
        atomically $ writeTQueue cmdQueue (command, ref)
        debugStrLn "actor: end."
        return ref

      -- Answer every callback still waiting in a listen queue with
      -- 'ConnectionError'.
      expireQueue :: TQueue (Response -> IO ()) -> IO ()
      expireQueue listenQueue = do
        next <- atomically $ tryReadTQueue listenQueue
        case next of
          Nothing       -> return ()
          Just callback -> do
            callback ConnectionError
            expireQueue listenQueue

      -- Replace the current connection after a failure. Only one caller at
      -- a time wins the 'tryTakeTMVar'; the others return immediately.
      handleReconnect :: SomeException -> IO ()
      handleReconnect e
        | Just ThreadKilled <- fromException e =
            debugStrLn "handleReconnect: ThreadKilled. Not attempting to reconnect."
        | otherwise = do
            debugStrLn $ "handleReconnect begin."
            current <- atomically $ tryTakeTMVar ccTMV
            case current of
              Nothing -> do
                debugStrLn $ "handleReconnect: error handling already in progress."
                debugStrLn $ "handleReconnect end."
              Just (oldCC, oldListenQueue, oldListenerTID) -> do
                me <- myThreadId
                when (me /= oldListenerTID) $ killThread oldListenerTID
                ccClose oldCC
                expireQueue oldListenQueue
                cc <- reconnect
                listenQueue <- atomically $ newTQueue
                listenerTID <- forkIO $ listener cc listenQueue
                atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)
                debugStrLn $ "handleReconnect end."

      -- Decode Responses forever and pass each to the oldest waiting
      -- callback. Any failure is handed to 'handleReconnect'.
      listener :: CommChannel -> TQueue (Response -> IO ()) -> IO ()
      listener cc listenQueue = readLoop Strict.empty `catch` handleReconnect
        where
          readLoop leftover = do
            debugStrLn $ "listener: listening for Response."
            rest <- step (runGetPartial get leftover)
            readLoop rest

          step (Fail msg _) = error $ "Data.Acid.Remote: " <> msg
          step (Partial resume) = do
            debugStrLn $ "listener: ccGetSome"
            chunk <- ccGetSome cc 1024
            step (resume chunk)
          step (Done resp rest) = do
            debugStrLn $ "listener: getting callback"
            callback <- atomically $ readTQueue listenQueue
            debugStrLn $ "listener: passing Response to callback"
            callback (resp :: Response)
            return rest

      -- Take queued commands one at a time. Register the response callback
      -- and pick the connection in one transaction, then send.
      actorThread :: IO ()
      actorThread = forever $ do
        debugStrLn "actorThread: waiting for something to do."
        (cc, cmd) <- atomically $ do
          (cmd, ref) <- readTQueue cmdQueue
          (cc, listenQueue, _) <- readTMVar ccTMV
          writeTQueue listenQueue (putMVar ref)
          return (cc, cmd)
        debugStrLn "actorThread: sending command."
        ccPut cc (encode cmd) `catch` handleReconnect
        debugStrLn "actorThread: sent."

      -- Refuse new commands, stop both threads, fail outstanding requests
      -- and close the connection.
      shutdown :: ThreadId -> IO ()
      shutdown actorTID = do
        debugStrLn "shutdown: update isClosed IORef to True."
        writeIORef isClosed True
        debugStrLn "shutdown: killing actor thread."
        killThread actorTID
        debugStrLn "shutdown: taking ccTMV."
        (cc, listenQueue, listenerTID) <- atomically $ takeTMVar ccTMV
        debugStrLn "shutdown: killing listener thread."
        killThread listenerTID
        debugStrLn "shutdown: expiring listen queue."
        expireQueue listenQueue
        debugStrLn "shutdown: closing connection."
        ccClose cc

  cc <- reconnect
  listenQueue <- atomically $ newTQueue
  actorTID    <- forkIO actorThread
  listenerTID <- forkIO $ listener cc listenQueue
  atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)
  return (toAcidState $ RemoteState actor (shutdown actorTID))
-- | Run a query event against a remote state.
--
-- The event is encoded with the serialiser registered for it in @mmap@ and
-- sent through 'remoteQueryCold'. The reply is decoded the same way. A
-- decoding failure becomes an 'error', raised when the result is forced.
remoteQuery :: QueryEvent event => RemoteState (EventState event) -> MethodMap (EventState event) -> event -> IO (EventResult event)
remoteQuery acidState mmap event = do
  let (_, serialiser) = lookupHotMethodAndSerialiser mmap event
  reply <- remoteQueryCold acidState (methodTag event, encodeMethod serialiser event)
  return $ either (\msg -> error $ "Data.Acid.Remote: " <> msg) id
                  (decodeResult serialiser reply)
-- | Send an already-encoded query and wait for its encoded result.
--
-- A 'ConnectionError' reply means the connection dropped before an answer
-- arrived. The query is then simply resent.
remoteQueryCold :: RemoteState st -> Tagged Lazy.ByteString -> IO Lazy.ByteString
remoteQueryCold rs@(RemoteState submit _shutdown) event = do
  resp <- submit (RunQuery event) >>= takeMVar
  case resp of
    Result result -> return result
    ConnectionError -> do
      debugStrLn "retrying query event."
      remoteQueryCold rs event
    Acknowledgement ->
      error "Data.Acid.Remote: remoteQueryCold got Acknowledgement. That should never happen."
-- | Submit an update event to a remote state.
--
-- Returns an 'MVar' that is filled with the update's result once the server
-- answers. The answer is awaited on a separate thread, so the caller is
-- not blocked.
--
-- The 'MVar' is always filled, whatever the reply:
--
-- * 'Result': the decoded event result. A decoding failure becomes an
--   'error', raised when the value is forced.
-- * 'ConnectionError': the connection was lost before a reply arrived.
--   Unlike queries, the update is not resent, because the server may
--   already have applied it. Forcing the value throws
--   'RemoteConnectionError' instead.
-- * 'Acknowledgement': not a valid reply to an update; forcing the value
--   raises an 'error'.
--
-- Every constructor must be handled. If the helper thread died on an
-- unmatched reply, @parsed@ would never be filled and the caller would
-- block forever.
scheduleRemoteUpdate :: UpdateEvent event => RemoteState (EventState event) -> MethodMap (EventState event) -> event -> IO (MVar (EventResult event))
scheduleRemoteUpdate (RemoteState fn _shutdown) mmap event
  = do let encoded = encodeMethod ms event
       parsed <- newEmptyMVar
       respRef <- fn (RunUpdate (methodTag event, encoded))
       _ <- forkIO $ do
         resp <- takeMVar respRef
         putMVar parsed (interpret resp)
       return parsed
  where
    (_, ms) = lookupHotMethodAndSerialiser mmap event

    interpret (Result bytes) = case decodeResult ms bytes of
      Left msg     -> error $ "Data.Acid.Remote: " <> msg
      Right result -> result
    interpret ConnectionError = throw RemoteConnectionError
    interpret Acknowledgement =
      error "Data.Acid.Remote: scheduleRemoteUpdate got Acknowledgement. That should never happen."
-- | Submit an already-serialised (\"cold\") update to the remote state without
-- waiting for it.
--
-- A helper thread waits for the reply and fills the returned 'MVar' with the
-- raw result bytes. The 'MVar' is always filled, so a caller blocked on it is
-- never left hanging. If the reply is not a 'Result', the stored value is an
-- 'error' thunk that raises when forced:
--
-- * 'ConnectionError': the connection dropped before a reply arrived. The
--   update is not resent, because it may already have been applied remotely;
-- * 'Acknowledgement': should never happen for an update.
scheduleRemoteColdUpdate :: RemoteState st -> Tagged Lazy.ByteString -> IO (MVar Lazy.ByteString)
scheduleRemoteColdUpdate (RemoteState fn _shutdown) event
  = do parsed <- newEmptyMVar
       respRef <- fn (RunUpdate event)
       -- Match every reply. A failable @Result resp <-@ bind would kill this
       -- thread on any other reply and leave 'parsed' empty forever.
       _ <- forkIO $ do
         resp <- takeMVar respRef
         putMVar parsed $ case resp of
           Result bytes    -> bytes
           ConnectionError -> error "Data.Acid.Remote: connection lost before the cold update was acknowledged; it may or may not have been applied."
           Acknowledgement -> error "Data.Acid.Remote: scheduleRemoteColdUpdate got Acknowledgement. That should never happen."
       return parsed
-- | Close a remote state by running the shutdown action stored inside it.
-- The command function held by the 'RemoteState' is not used.
closeRemoteState :: RemoteState st -> IO ()
closeRemoteState remote =
  case remote of
    RemoteState _ shutdown -> shutdown
-- | Ask the remote side to create a checkpoint and block until it
-- acknowledges the request.
--
-- Any reply other than 'Acknowledgement' is reported with 'fail'. That raises
-- an 'IOError' (user error), the same exception type a failed pattern bind
-- would raise, but with a message naming the actual problem. The request is
-- not resent after a 'ConnectionError'.
createRemoteCheckpoint :: RemoteState st -> IO ()
createRemoteCheckpoint (RemoteState fn _shutdown)
  = do resp <- takeMVar =<< fn CreateCheckpoint
       case resp of
         Acknowledgement -> return ()
         ConnectionError -> fail "Data.Acid.Remote: createRemoteCheckpoint: connection lost before the checkpoint was acknowledged."
         Result _        -> fail "Data.Acid.Remote: createRemoteCheckpoint got Result. That should never happen."
-- | Ask the remote side to archive old checkpoints/events and block until it
-- acknowledges the request.
--
-- Any reply other than 'Acknowledgement' is reported with 'fail'. That raises
-- an 'IOError' (user error), the same exception type a failed pattern bind
-- would raise, but with a message naming the actual problem. The request is
-- not resent after a 'ConnectionError'.
createRemoteArchive :: RemoteState st -> IO ()
createRemoteArchive (RemoteState fn _shutdown)
  = do resp <- takeMVar =<< fn CreateArchive
       case resp of
         Acknowledgement -> return ()
         ConnectionError -> fail "Data.Acid.Remote: createRemoteArchive: connection lost before the archive request was acknowledged."
         Result _        -> fail "Data.Acid.Remote: createRemoteArchive got Result. That should never happen."
-- | Expose a 'RemoteState' through the generic 'AcidState' interface.
--
-- Every operation forwards to the corresponding remote implementation in this
-- module. Typed updates and queries share one 'MethodMap', built from the
-- state's 'acidEvents', to find each event's serialiser.
toAcidState :: forall st . IsAcidic st => RemoteState st -> AcidState st
toAcidState remote = AcidState
  { _scheduleUpdate    = scheduleRemoteUpdate remote methods
  , scheduleColdUpdate = scheduleRemoteColdUpdate remote
  , _query             = remoteQuery remote methods
  , queryCold          = remoteQueryCold remote
  , createCheckpoint   = createRemoteCheckpoint remote
  , createArchive      = createRemoteArchive remote
  , closeAcidState     = closeRemoteState remote
  , acidSubState       = mkAnyState remote
  }
  where
    -- Method table used to (de)serialise typed events for this state type.
    methods :: MethodMap st
    methods = mkMethodMap (eventsToMethods acidEvents)