{-# LANGUAGE CPP, DeriveDataTypeable, RankNTypes, RecordWildCards, ScopedTypeVariables #-}
-----------------------------------------------------------------------------
{- |
 Module      :  Data.Acid.Remote
 Copyright   :  PublicDomain

 Maintainer  :  lemmih@gmail.com
 Portability :  non-portable (uses GHC extensions)

 This module provides the ability to perform 'update' and 'query' calls
from a remote process.

On the server-side you:

 1. open your 'AcidState' normally

 2. then use 'acidServer' to share the state

On the client-side you:

 1. use 'openRemoteState' to connect to the remote state

 2. use the returned 'AcidState' like any other 'AcidState' handle

'openRemoteState' and 'acidServer' communicate over an unencrypted
socket. If you need an encrypted connection, see @acid-state-tls@.

On Unix®-like systems you can use 'SockAddrUnix' to create a socket file for
local communication between the client and server. Access can be
controlled by setting the permissions of the parent directory
containing the socket file.

It is also possible to perform some simple authentication using
'sharedSecretCheck' and 'sharedSecretPerform'. Keep in mind that
secrets will be sent in plain-text if you do not use
@acid-state-tls@. If you are using a 'SockAddrUnix' additional
authentication may not be required, so you can use
'skipAuthenticationCheck' and 'skipAuthenticationPerform'.

Working with a remote 'AcidState' is nearly identical to working with
a local 'AcidState' with a few important differences.

The connection to the remote 'AcidState' can be lost. The client will
automatically attempt to reconnect every second. Because 'query'
events do not affect the state, an aborted 'query' will be retried
automatically after the server is reconnected.

If the connection was lost during an 'update' event, the event will
not be retried. Instead 'RemoteConnectionError' will be raised. This
is because it is impossible for the client to know if the aborted
update completed on the server-side or not.

When using a local 'AcidState', an update event in one thread does not
block query events taking place in other threads. With a remote
connection, all queries and requests are channeled over a single
connection. As a result, updates and queries are performed in the
order they are executed and do block each other. In the rare case
where this is an issue, you could create one remote connection per
thread.

When working with local state, a query or update which returns the
whole state is not usually a problem due to memory sharing. The
update/query event basically just needs to return a pointer to the
data already in memory. But, when working remotely, the entire result
will be serialized and sent to the remote client. Hence, it is good
practice to create queries and updates that will only return the
required data.

This module is designed to be extensible. You can easily add your own
authentication methods by creating a suitable pair of functions and
passing them to 'acidServer' and 'openRemoteState'.

It is also possible to create alternative communication layers using
'CommChannel', 'process', and 'processRemoteState'.

-}
module Data.Acid.Remote
    (
    -- * Server/Client
      acidServer
    , acidServerSockAddr
    , acidServer'
    , openRemoteState
    , openRemoteStateSockAddr
    -- * Authentication
    , skipAuthenticationCheck
    , skipAuthenticationPerform
    , sharedSecretCheck
    , sharedSecretPerform
    -- * Exception type
    , AcidRemoteException(..)
    -- * Low-Level functions needed to implement additional communication channels
    , CommChannel(..)
    , process
    , processRemoteState
    ) where

import Prelude                                hiding ( catch )
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative
#endif
import Control.Concurrent.STM                        ( atomically )
import Control.Concurrent.STM.TMVar                  ( newEmptyTMVar, readTMVar, takeTMVar, tryTakeTMVar, putTMVar )
import Control.Concurrent.STM.TQueue
import Control.Exception                             ( AsyncException(ThreadKilled)
                                                     , Exception(fromException), IOException, Handler(..)
                                                     , SomeException, catch, catches, throw, bracketOnError )
import Control.Exception                             ( throwIO, finally )
import Control.Monad                                 ( forever, liftM, join, when )
import Control.Concurrent                            ( ThreadId, forkIO, threadDelay, killThread, myThreadId )
import Control.Concurrent.MVar                       ( MVar, newEmptyMVar, putMVar, takeMVar )
import Control.Concurrent.Chan                       ( newChan, readChan, writeChan )
import Data.Acid.Abstract
import Data.Acid.Core
import Data.Acid.Common
#if !MIN_VERSION_base(4,11,0)
import Data.Monoid                                   ((<>))
#endif
import qualified Data.ByteString                     as Strict
import Data.ByteString.Char8                         ( pack )
import qualified Data.ByteString.Lazy                as Lazy
import Data.IORef                                    ( newIORef, readIORef, writeIORef )
import Data.Serialize
import Data.Set                                      ( Set, member )
import Data.Typeable                                 ( Typeable )
import GHC.IO.Exception                              ( IOErrorType(..) )
import Network.BSD                                   ( PortNumber, getProtocolNumber, getHostByName, hostAddress )
import Network.Socket
import Network.Socket.ByteString                     as NSB ( recv, sendAll )
import System.Directory                              ( removeFile )
import System.IO                                     ( Handle, hPrint, hFlush, hClose, stderr, IOMode(..) )
import System.IO.Error                               ( ioeGetErrorType, isFullError, isDoesNotExistError )

-- | Debug tracing hook used by the (re)connect logic.
--
-- Tracing is switched off, so the message is ignored and nothing is
-- printed. Replace the body with @putStrLn _msg@ to enable debugging output.
debugStrLn :: String -> IO ()
debugStrLn _msg = return ()

-- | 'CommChannel' bundles the IO actions needed for communication between
-- the server and the client.
--
-- Keeping them in a record, instead of wiring a socket directly into the
-- core processing functions, makes it easy to add support for SSL/TLS and
-- for unit testing.
data CommChannel = CommChannel
    { ccPut     :: Strict.ByteString -> IO ()
      -- ^ send the given bytes to the peer
    , ccGetSome :: Int -> IO (Strict.ByteString)
      -- ^ receive a chunk of bytes from the peer; the argument is the
      -- requested size. The server loop ('process') treats an empty
      -- result as the end of the input.
    , ccClose   :: IO ()
      -- ^ close the channel
    }

-- | Exceptions thrown by the remote-state client and server code.
data AcidRemoteException
    = RemoteConnectionError
      -- ^ the connection to the remote state was lost; see the module
      -- description for when this is raised for 'update' events
    | AcidStateClosed
      -- ^ NOTE(review): presumably raised when a remote state is used
      -- after it has been closed; the throw site is not in this section
    | SerializeError String
      -- ^ data received from the peer could not be decoded; carries the
      -- decoder's error message
    | AuthenticationError String
      -- ^ authentication was rejected; carries a description of the failure
      deriving (Eq, Show, Typeable)

instance Exception AcidRemoteException

-- | Wrap a 'Handle' in a 'CommChannel'. The 'Handle' should be some
-- two-way communication channel, such as a socket connection. Passing in
-- a 'Handle' to a normal file is unlikely to do anything useful.
--
-- Every write is followed by a flush of the handle.
handleToCommChannel :: Handle -> CommChannel
handleToCommChannel h = CommChannel
    { ccPut     = putAndFlush
    , ccGetSome = Strict.hGetSome h
    , ccClose   = hClose h
    }
  where
    -- write the bytes and push them out immediately
    putAndFlush bytes = Strict.hPut h bytes >> hFlush h

{- | Wrap a 'Socket' in a 'CommChannel'. The 'Socket' should be
     an accepted (connected) socket, not a listen socket.
-}
socketToCommChannel :: Socket -> CommChannel
socketToCommChannel sock = CommChannel
    { ccPut     = sendAll sock
    , ccGetSome = NSB.recv sock
    , ccClose   = close sock
    }

{- | Server-side authentication check that accepts every client without
     exchanging any data.
-}
skipAuthenticationCheck :: CommChannel -> IO Bool
skipAuthenticationCheck = const (return True)

{- | Client-side authentication step that does nothing; the counterpart of
     'skipAuthenticationCheck'.
-}
skipAuthenticationPerform :: CommChannel -> IO ()
skipAuthenticationPerform = const (return ())

{- | Check that the client knows a shared secret.

The function takes a 'Set' of shared secrets. A client that presents any
one of them is considered to be trusted.

A shared secret is any 'ByteString' of your choice. Handing each client
its own secret makes it possible to revoke access individually.

The check reads one chunk (requesting 1024 bytes) from the channel and
answers with @OK@ or @FAIL@.

see also: 'sharedSecretPerform'
-}
sharedSecretCheck :: Set Strict.ByteString -- ^ set of shared secrets
                  -> (CommChannel -> IO Bool)
sharedSecretCheck secrets cc =
    do candidate <- ccGetSome cc 1024
       let accepted = candidate `member` secrets
       -- tell the client the verdict before reporting it to the caller
       ccPut cc (pack (if accepted then "OK" else "FAIL"))
       return accepted

-- | Authenticate with the server using a shared secret.
--
-- Sends the secret, then reads the server's reply. Any reply other than
-- @OK@ results in an 'AuthenticationError' being thrown.
--
-- see also: 'sharedSecretCheck'
sharedSecretPerform :: Strict.ByteString -- ^ shared secret
                    -> (CommChannel -> IO ())
sharedSecretPerform pw cc =
    do ccPut cc pw
       reply <- ccGetSome cc 1024
       when (reply /= pack "OK") $
           throwIO (AuthenticationError "shared secret authentication failed.")

{- | Listen on @sockAddr@ and serve requests against the given 'AcidState'.
     This call doesn't return.

     When the server loop is terminated by an exception, the listen socket
     is closed and, for a 'SockAddrUnix' address, the socket file is removed.

     see also: 'acidServer', 'openRemoteState' and 'sharedSecretCheck'.
 -}
acidServerSockAddr :: (CommChannel -> IO Bool) -- ^ check authentication, see 'sharedSecretCheck'
           -> SockAddr                 -- ^ SockAddr to listen on
           -> AcidState st             -- ^ state to serve
           -> IO ()
acidServerSockAddr checkAuth sockAddr acidState
  = listenOn sockAddr >>= \listenSocket ->
        acidServer' checkAuth listenSocket acidState
            `finally` (close listenSocket >> removeSocketFile)
    where
      -- delete the file backing a Unix domain socket; nothing to do for
      -- any other kind of address
      removeSocketFile :: IO ()
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
      removeSocketFile =
          case sockAddr of
            SockAddrUnix path -> removeFile path
            _                 -> pure ()
#else
      removeSocketFile = return ()
#endif


{- | Listen on TCP @port@ and serve requests against the given 'AcidState'.
     This call doesn't return.

     This is 'acidServerSockAddr' applied to a 'SockAddrInet' built from
     @port@ and host address @0@.

     see also: 'acidServerSockAddr', 'openRemoteState' and 'sharedSecretCheck'.
 -}
acidServer :: (CommChannel -> IO Bool) -- ^ check authentication, see 'sharedSecretCheck'
           -> PortNumber               -- ^ Port to listen on
           -> AcidState st             -- ^ state to serve
           -> IO ()
acidServer checkAuth port = acidServerSockAddr checkAuth listenAddr
  where
    listenAddr = SockAddrInet port 0

-- | Create a listening stream socket for @sockAddr@.
--
-- The socket is created for the address family of @sockAddr@, has
-- 'ReuseAddr' set, is bound to @sockAddr@ and put into listening mode.
-- If any of the set-up steps throws, the socket is closed again before
-- the exception propagates.
listenOn :: SockAddr -> IO Socket
listenOn sockAddr =
    protoFor >>= \proto ->
        bracketOnError (socket af Stream proto) close prepare
  where
    -- configure a fresh socket and hand it back once it is listening
    prepare sock =
        setSocketOption sock ReuseAddr 1
            >> bind sock sockAddr
            >> listen sock maxListenQueue
            >> return sock

    -- protocol number to create the socket with: 0 for Unix domain
    -- sockets, the number registered for "tcp" otherwise
    protoFor :: IO ProtocolNumber
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    protoFor = case sockAddr of
        SockAddrUnix {} -> pure 0
        _               -> getProtocolNumber "tcp"
#else
    protoFor = getProtocolNumber "tcp"
#endif

    -- address family matching the constructor of @sockAddr@
    af = case sockAddr of
          (SockAddrInet {})  -> AF_INET
          (SockAddrInet6 {}) -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
          (SockAddrUnix {})  -> AF_UNIX
#endif

{- | Works the same way as 'acidServer', but uses the already bound socket @listenSocket@.

     Can be useful when fine-tuning of socket binding parameters is needed
     (for example, listening on a particular network interface, IPv4/IPv6 options).

     Each accepted connection is handled on its own thread: the client is
     authenticated with @checkAuth@ and, if accepted, served by 'process'.
     The connection is closed when that thread finishes, regardless of
     whether authentication was refused, the client disconnected, or an
     exception was thrown.

     This call doesn't return.
 -}
acidServer' :: (CommChannel -> IO Bool) -- ^ check authentication, see 'sharedSecretCheck'
           -> Socket                   -- ^ bound socket to accept connections from
           -> AcidState st             -- ^ state to serve
           -> IO ()
acidServer' checkAuth listenSocket acidState
  = serveForever
    where
      -- Restart the accept loop whenever it was interrupted by one of the
      -- IO errors that 'catchSome' swallows; anything else propagates.
      serveForever = acceptLoop `catchSome` logError >> serveForever

      -- Accept connections forever, handing each one to a new thread.
      acceptLoop = forever $
          do (conn, _sockAddr) <- accept listenSocket
             let commChannel = socketToCommChannel conn
             -- 'finally' guarantees the connection is released even when
             -- authentication or request processing throws
             forkIO $ serve commChannel `finally` ccClose commChannel

      -- Authenticate a client and, if it is accepted, process its requests.
      serve commChannel =
          do authorized <- checkAuth commChannel
             when authorized $ process commChannel acidState

      -- Print an error to stderr.
      logError :: (Show e) => e -> IO ()
      logError e = hPrint stderr e

      isResourceVanishedError :: IOException -> Bool
      isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

      isResourceVanishedType :: IOErrorType -> Bool
      isResourceVanishedType ResourceVanished = True
      isResourceVanishedType _                = False

      -- Run an action, silently discarding "full", "does not exist" and
      -- "resource vanished" IO errors and re-throwing every other
      -- 'IOException'. The handler argument is currently ignored: the
      -- errors could be logged, but there could be thousands of them.
      catchSome :: IO () -> (Show e => e -> IO ()) -> IO ()
      catchSome op _h =
          op `catches` [ Handler $ \(e :: IOException)    ->
                           if isFullError e || isDoesNotExistError e || isResourceVanishedError e
                            then return () -- h (toException e) -- we could log the exception, but there could be thousands of them
                            else throwIO e
                       ]

-- | Requests sent from the client to the server.
data Command
    = RunQuery (Tagged Lazy.ByteString)
      -- ^ run the given serialized query event; the server answers with
      -- the serialized result
    | RunUpdate (Tagged Lazy.ByteString)
      -- ^ run the given serialized update event; the server answers with
      -- the serialized result
    | CreateCheckpoint
      -- ^ ask the server to create a checkpoint
    | CreateArchive
      -- ^ ask the server to create an archive

-- | Wire format: a one-byte tag (0 = 'RunQuery', 1 = 'RunUpdate',
-- 2 = 'CreateCheckpoint', 3 = 'CreateArchive'), followed by the
-- serialized event for the first two.
--
-- An unknown tag makes the decoder fail (via 'fail'), so that callers
-- running the parser observe an ordinary decoding failure rather than an
-- exception raised from pure code.
instance Serialize Command where
  put cmd = case cmd of
              RunQuery query   -> putWord8 0 >> put query
              RunUpdate update -> putWord8 1 >> put update
              CreateCheckpoint -> putWord8 2
              CreateArchive    -> putWord8 3
  get = do tag <- getWord8
           case tag of
             0 -> liftM RunQuery get
             1 -> liftM RunUpdate get
             2 -> return CreateCheckpoint
             3 -> return CreateArchive
             _ -> fail $ "Data.Acid.Remote: Serialize.get for Command, invalid tag: " ++ show tag

-- | Replies sent from the server to the client.
data Response
    = Result Lazy.ByteString
      -- ^ serialized result of a query or update
    | Acknowledgement
      -- ^ a checkpoint or archive request has been carried out
    | ConnectionError
      -- ^ NOTE(review): not produced by the server code in this section;
      -- presumably used on the client side to signal a lost connection

-- | Wire format: a one-byte tag (0 = 'Result', 1 = 'Acknowledgement',
-- 2 = 'ConnectionError'), followed by the serialized payload for 'Result'.
--
-- An unknown tag makes the decoder fail (via 'fail'), so that callers
-- running the parser observe an ordinary decoding failure rather than an
-- exception raised from pure code.
instance Serialize Response where
  put resp = case resp of
               Result result   -> putWord8 0 >> put result
               Acknowledgement -> putWord8 1
               ConnectionError -> putWord8 2
  get = do tag <- getWord8
           case tag of
             0 -> liftM Result get
             1 -> return Acknowledgement
             2 -> return ConnectionError
             _ -> fail $ "Data.Acid.Remote: Serialize.get for Response, invalid tag: " ++ show tag

{- | Server inner-loop

     This function is generally only needed if you are adding a new communication channel.

     Commands are read and decoded from the channel one after another and
     executed against the state. Responses are sent back by a separate
     writer thread, in the order in which the commands were received. An
     update is only scheduled by the reading side; the writer waits for its
     result, so reading further commands is not blocked by it.

     The function returns when the channel reports end of input (an empty
     read). It throws 'SerializeError' if a command cannot be decoded. In
     both cases the writer thread is told to stop once it has handled the
     responses that were already queued. The channel itself is not closed
     here.
-}
process :: CommChannel  -- ^ a connected, authenticated communication channel
        -> AcidState st -- ^ state to share
        -> IO ()
process CommChannel{..} acidState
  = do chan <- newChan
       _ <- forkIO (writer chan)
       -- 'Nothing' is the end-of-stream marker for the writer; sending it
       -- from 'finally' makes sure the writer thread terminates no matter
       -- how the reader stops.
       reader chan (runGetPartial get Strict.empty)
           `finally` writeChan chan Nothing
  where
    -- Send queued responses until the end-of-stream marker shows up. Each
    -- queue entry is an action producing the response, which lets the
    -- writer block on a pending update result.
    writer chan =
      do next <- readChan chan
         case next of
           Nothing      -> return ()
           Just produce -> do response <- produce
                              ccPut (encode response)
                              writer chan

    -- Incrementally decode commands, fetching more input whenever the
    -- decoder needs it.
    reader chan inp
      = case inp of
          Fail msg _    -> throwIO (SerializeError msg)
          Partial cont  -> do bs <- ccGetSome 1024
                              if Strict.null bs
                                then return () -- end of input: the peer is gone
                                else reader chan (cont bs)
          Done cmd rest -> do processCommand chan cmd
                              -- leftover bytes belong to the next command
                              reader chan (runGetPartial get rest)

    -- Execute one command and queue the action that yields its response.
    processCommand chan cmd =
      case cmd of
        RunQuery query   -> do result <- queryCold acidState query
                               writeChan chan (Just (return (Result result)))
        RunUpdate update -> do result <- scheduleColdUpdate acidState update
                               writeChan chan (Just (liftM Result (takeMVar result)))
        CreateCheckpoint -> do createCheckpoint acidState
                               writeChan chan (Just (return Acknowledgement))
        CreateArchive    -> do createArchive acidState
                               writeChan chan (Just (return Acknowledgement))

-- | Client-side handle for a remote state.
--
-- NOTE(review): the code using the two fields is outside this section.
-- Judging by the types, the first field presumably submits a 'Command' and
-- returns an 'MVar' that will receive the 'Response', and the second is
-- presumably a shutdown action — confirm against 'processRemoteState'.
data RemoteState st
    = RemoteState
        (Command -> IO (MVar Response))
        (IO ())
    deriving (Typeable)

{- | Connect to an acid-state server which is sharing an 'AcidState'.

     The host name is resolved with 'getHostByName' and the connection is
     then set up by 'openRemoteStateSockAddr'.
-}
openRemoteState :: IsAcidic st =>
                   (CommChannel -> IO ()) -- ^ authentication function, see 'sharedSecretPerform'
                -> HostName               -- ^ remote host to connect to
                -> PortNumber             -- ^ remote port to connect to
                -> IO (AcidState st)
openRemoteState performAuthorization host port =
    getHostByName host >>= \entry ->
        let remoteAddr = SockAddrInet port (hostAddress entry)
        in  openRemoteStateSockAddr performAuthorization remoteAddr

{- | Connect to an acid-state server which is sharing an 'AcidState'.

     The (re-)connect action handed to 'processRemoteState' opens a stream
     socket to @sockAddr@ and runs the authentication function on it. If
     this fails with an 'IOException', the attempt is repeated after one
     second, indefinitely. Exceptions of any other type raised by the
     authentication function (for example 'AuthenticationError') are not
     caught here. In either case the freshly opened connection is closed
     before the failure is handled, so failed attempts do not leak
     connections.
-}
openRemoteStateSockAddr :: IsAcidic st =>
                   (CommChannel -> IO ()) -- ^ authentication function, see 'sharedSecretPerform'
                -> SockAddr               -- ^ remote SockAddr to connect to
                -> IO (AcidState st)
openRemoteStateSockAddr performAuthorization sockAddr
  = withSocketsDo $
    do processRemoteState reconnect
    where
      -- address family matching the constructor of @sockAddr@
      af :: Family
      af = case sockAddr of
          (SockAddrInet {})  -> AF_INET
          (SockAddrInet6 {}) -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
          (SockAddrUnix {})  -> AF_UNIX
#endif

      -- | connect and authenticate, retrying once per second for as long
      -- as the attempt fails with an 'IOException'
      reconnect :: IO CommChannel
      reconnect
          = (do debugStrLn "Reconnecting."
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
                -- Unix domain sockets use protocol 0
                proto <- case sockAddr of
                           (SockAddrUnix {}) -> pure 0
                           _                 -> getProtocolNumber "tcp"
#else
                proto <- getProtocolNumber "tcp"
#endif
                handle <- bracketOnError
                    (socket af Stream proto)
                    close  -- only done if there's an error
                    (\sock -> do
                      connect sock sockAddr
                      socketToHandle sock ReadWriteMode
                    )

                let cc = handleToCommChannel handle
                -- close the connection again if authentication throws;
                -- otherwise every failed attempt would leak a handle
                bracketOnError (return cc) ccClose performAuthorization
                debugStrLn "Reconnected."
                return cc
            )
            `catch`
            ((\_ -> threadDelay 1000000 >> reconnect) :: IOError -> IO CommChannel)


{- | Client inner-loop

     This function is generally only needed if you are adding a new communication channel.

     It connects using the supplied action and forks two threads: an actor
     thread which sends queued commands over the channel, and a listener
     thread which decodes responses and hands each one to the oldest pending
     callback. If sending or listening fails, pending callbacks receive
     'ConnectionError' and the supplied action is used to reconnect.
-}
processRemoteState :: IsAcidic st =>
                      IO CommChannel -- ^ (re-)connect function
                   -> IO (AcidState st)
processRemoteState reconnect
  = do -- Commands waiting to be sent, each paired with the MVar for its response.
       cmdQueue    <- atomically newTQueue
       -- Current (channel, pending callbacks, listener thread). It is empty
       -- before the first connection is published and during a reconnect.
       ccTMV       <- atomically newEmptyTMVar
       isClosed    <- newIORef False

       let -- Queue a command and return the MVar that will receive its
           -- response. Throws 'AcidStateClosed' after 'shutdown' has run.
           actor :: Command -> IO (MVar Response)
           actor command =
               do debugStrLn "actor: begin."
                  readIORef isClosed >>= flip when (throwIO AcidStateClosed)
                  ref <- newEmptyMVar
                  atomically $ writeTQueue cmdQueue (command, ref)
                  debugStrLn "actor: end."
                  return ref

           -- Drain the queue, answering every pending callback with
           -- 'ConnectionError'.
           expireQueue :: TQueue (Response -> IO a) -> IO ()
           expireQueue listenQueue =
               do mCallback <- atomically $ tryReadTQueue listenQueue
                  case mCallback of
                    Nothing         -> return ()
                    (Just callback) ->
                        do _ <- callback ConnectionError
                           expireQueue listenQueue

           -- Exception handler shared by the actor and listener threads.
           -- 'ThreadKilled' is ignored. For anything else the current
           -- connection is torn down, its pending callbacks are expired and
           -- a new connection with a new listener thread is published.
           handleReconnect :: SomeException -> IO ()
           handleReconnect e
             = case fromException e of
                 (Just ThreadKilled) ->
                     debugStrLn "handleReconnect: ThreadKilled. Not attempting to reconnect."
                 _ ->
                   do debugStrLn "handleReconnect begin."
                      -- Whoever empties ccTMV owns the reconnect; finding it
                      -- empty means another thread is already handling it.
                      tmv <- atomically $ tryTakeTMVar ccTMV
                      case tmv of
                        Nothing ->
                            do debugStrLn "handleReconnect: error handling already in progress."
                               debugStrLn "handleReconnect end."
                        (Just (oldCC, oldListenQueue, oldListenerTID)) ->
                            do thisTID <- myThreadId
                               -- Do not kill ourselves when the listener is the
                               -- thread running this handler.
                               when (thisTID /= oldListenerTID) (killThread oldListenerTID)
                               ccClose oldCC
                               expireQueue oldListenQueue
                               cc <- reconnect
                               listenQueue <- atomically newTQueue
                               listenerTID <- forkIO $ listener cc listenQueue
                               atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)
                               debugStrLn "handleReconnect end."

           -- Decode responses from the channel forever, passing each one to
           -- the next callback taken from the listen queue. Any exception is
           -- given to 'handleReconnect'.
           listener :: CommChannel -> TQueue (Response -> IO ()) -> IO ()
           listener cc listenQueue
             = getResponse Strict.empty `catch` handleReconnect
               where
                 -- The argument is input left over from the previous decode.
                 getResponse leftover =
                     do debugStrLn "listener: listening for Response."
                        let go inp = case inp of
                                   Fail msg _     -> error $ "Data.Acid.Remote: " <> msg
                                   Partial cont   -> do debugStrLn "listener: ccGetSome"
                                                        bs <- ccGetSome cc 1024
                                                        go (cont bs)
                                   Done resp rest -> do debugStrLn "listener: getting callback"
                                                        callback <- atomically $ readTQueue listenQueue
                                                        debugStrLn "listener: passing Response to callback"
                                                        callback (resp :: Response)
                                                        return rest
                        rest <- go (runGetPartial get leftover)
                        getResponse rest

           -- Take the next command, register its callback on the current
           -- listen queue, then send it. Both steps before the send happen in
           -- one transaction, which reads ccTMV and so cannot complete while
           -- ccTMV is empty.
           actorThread :: IO ()
           actorThread = forever $
             do debugStrLn "actorThread: waiting for something to do."
                (cc, cmd) <- atomically $
                  do (cmd, ref)           <- readTQueue cmdQueue
                     (cc, listenQueue, _) <- readTMVar ccTMV
                     writeTQueue listenQueue (putMVar ref)
                     return (cc, cmd)
                debugStrLn "actorThread: sending command."
                ccPut cc (encode cmd) `catch` handleReconnect
                debugStrLn "actorThread: sent."

           -- Close action of the returned 'AcidState'. The argument is the
           -- actor thread to kill.
           shutdown :: ThreadId -> IO ()
           shutdown actorTID =
               do debugStrLn "shutdown: update isClosed IORef to True."
                  writeIORef isClosed True
                  debugStrLn "shutdown: killing actor thread."
                  killThread actorTID
                  debugStrLn "shutdown: taking ccTMV."
                  -- takeTMVar waits while ccTMV is empty, i.e. while a
                  -- reconnect is in progress.
                  (cc, listenQueue, listenerTID) <- atomically $ takeTMVar ccTMV -- FIXME: or should this by tryTakeTMVar
                  debugStrLn "shutdown: killing listener thread."
                  killThread listenerTID
                  debugStrLn "shutdown: expiring listen queue."
                  expireQueue listenQueue
                  debugStrLn "shutdown: closing connection."
                  ccClose cc

       cc <- reconnect
       listenQueue <- atomically newTQueue

       actorTID    <- forkIO actorThread
       listenerTID <- forkIO $ listener cc listenQueue

       -- Publishing the connection lets the actor thread start sending.
       atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)

       return (toAcidState $ RemoteState actor (shutdown actorTID))

-- | Run a query event against the remote state and decode its result.
--
--   The event is serialised with the serialiser found in the method map and
--   sent through 'remoteQueryCold'. If the reply cannot be decoded, 'error'
--   is raised while this action runs rather than being hidden inside the
--   returned value.
remoteQuery :: QueryEvent event => RemoteState (EventState event) -> MethodMap (EventState event) -> event -> IO (EventResult event)
remoteQuery acidState mmap event
  = do let encoded = encodeMethod ms event
       resp <- remoteQueryCold acidState (methodTag event, encoded)
       -- Scrutinise the decode result here so a failure surfaces at the call
       -- site, inside any exception handler wrapped around the query.
       case decodeResult ms resp of
         Left msg     -> error $ "Data.Acid.Remote: " <> msg
         Right result -> return result
  where
    (_, ms) = lookupHotMethodAndSerialiser mmap event

-- | Run an already serialised query event and return the serialised result.
--
--   A 'ConnectionError' response causes the query to be submitted again.
--   An 'Acknowledgement' response is not expected for a query and is
--   reported with 'error'.
remoteQueryCold :: RemoteState st -> Tagged Lazy.ByteString -> IO Lazy.ByteString
remoteQueryCold rs@(RemoteState fn _shutdown) event
  = do resp <- takeMVar =<< fn (RunQuery event)
       case resp of
         (Result result) -> return result
         ConnectionError -> do debugStrLn "retrying query event."
                               remoteQueryCold rs event
         Acknowledgement -> error "Data.Acid.Remote: remoteQueryCold got Acknowledgement. That should never happen."

-- | Submit an update event to the remote state.
--
--   Returns at once with an 'MVar' which a forked thread fills when the
--   response arrives. The stored value is lazy: a reply that cannot be
--   decoded, a lost connection, or an unexpected response is stored as an
--   'error' call which is raised when the caller evaluates the result.
scheduleRemoteUpdate :: UpdateEvent event => RemoteState (EventState event) -> MethodMap (EventState event) -> event -> IO (MVar (EventResult event))
scheduleRemoteUpdate (RemoteState fn _shutdown) mmap event
  = do let encoded = encodeMethod ms event
       parsed <- newEmptyMVar
       respRef <- fn (RunUpdate (methodTag event, encoded))
       -- Every kind of response must fill 'parsed'. Matching only on 'Result'
       -- would end this thread on any other response and leave the caller
       -- waiting on an MVar that is never written.
       _ <- forkIO $ do resp <- takeMVar respRef
                        putMVar parsed $ case resp of
                          Result bytes ->
                            case decodeResult ms bytes of
                              Left msg     -> error $ "Data.Acid.Remote: " <> msg
                              Right result -> result
                          ConnectionError ->
                            error "Data.Acid.Remote: connection lost during update; the update may or may not have been applied."
                          _ ->
                            error "Data.Acid.Remote: scheduleRemoteUpdate got an unexpected response. That should never happen."
       return parsed
  where
    (_, ms) = lookupHotMethodAndSerialiser mmap event

-- | Submit an already serialised update event to the remote state.
--
--   Returns at once with an 'MVar' which a forked thread fills with the
--   serialised result. A lost connection or an unexpected response is stored
--   as an 'error' call which is raised when the caller evaluates the result.
scheduleRemoteColdUpdate :: RemoteState st -> Tagged Lazy.ByteString -> IO (MVar Lazy.ByteString)
scheduleRemoteColdUpdate (RemoteState fn _shutdown) event
  = do parsed <- newEmptyMVar
       respRef <- fn (RunUpdate event)
       -- Every kind of response must fill 'parsed'. Matching only on 'Result'
       -- would end this thread on any other response and leave the caller
       -- waiting on an MVar that is never written.
       _ <- forkIO $ do resp <- takeMVar respRef
                        putMVar parsed $ case resp of
                          Result bytes ->
                            bytes
                          ConnectionError ->
                            error "Data.Acid.Remote: connection lost during update; the update may or may not have been applied."
                          _ ->
                            error "Data.Acid.Remote: scheduleRemoteColdUpdate got an unexpected response. That should never happen."
       return parsed

-- | Close the remote state by running the shutdown action stored in it.
closeRemoteState :: RemoteState st -> IO ()
closeRemoteState (RemoteState _fn shutdown) = shutdown

-- | Send 'CreateCheckpoint' to the server and wait for the
--   'Acknowledgement'.
--
--   Any other response is reported as an 'IOError' with a descriptive
--   message.
createRemoteCheckpoint :: RemoteState st -> IO ()
createRemoteCheckpoint (RemoteState fn _shutdown)
  = do resp <- takeMVar =<< fn CreateCheckpoint
       case resp of
         Acknowledgement -> return ()
         ConnectionError -> ioError (userError "Data.Acid.Remote: connection lost while creating a checkpoint.")
         _               -> ioError (userError "Data.Acid.Remote: createRemoteCheckpoint got an unexpected response. That should never happen.")

-- | Send 'CreateArchive' to the server and wait for the 'Acknowledgement'.
--
--   Any other response is reported as an 'IOError' with a descriptive
--   message.
createRemoteArchive :: RemoteState st -> IO ()
createRemoteArchive (RemoteState fn _shutdown)
  = do resp <- takeMVar =<< fn CreateArchive
       case resp of
         Acknowledgement -> return ()
         ConnectionError -> ioError (userError "Data.Acid.Remote: connection lost while creating an archive.")
         _               -> ioError (userError "Data.Acid.Remote: createRemoteArchive got an unexpected response. That should never happen.")

-- | Wrap a 'RemoteState' in the generic 'AcidState' interface.
--
--   Each field is the corresponding remote operation applied to the given
--   state. The hot update and query operations also receive the method map
--   built from the 'acidEvents' of the state type.
toAcidState :: forall st . IsAcidic st => RemoteState st -> AcidState st
toAcidState remote
  = AcidState { _scheduleUpdate    = scheduleRemoteUpdate remote mmap
              , scheduleColdUpdate = scheduleRemoteColdUpdate remote
              , _query             = remoteQuery remote mmap
              , queryCold          = remoteQueryCold remote
              , createCheckpoint   = createRemoteCheckpoint remote
              , createArchive      = createRemoteArchive remote
              , closeAcidState     = closeRemoteState remote
              , acidSubState       = mkAnyState remote
              }
  where
    -- The explicit forall on the signature above keeps st in scope here.
    mmap :: MethodMap st
    mmap = mkMethodMap (eventsToMethods acidEvents)