{-# LANGUAGE CPP, DeriveDataTypeable, RankNTypes, RecordWildCards, ScopedTypeVariables #-}
module Data.Acid.Remote
(
acidServer
, acidServerSockAddr
, acidServer'
, openRemoteState
, openRemoteStateSockAddr
, skipAuthenticationCheck
, skipAuthenticationPerform
, sharedSecretCheck
, sharedSecretPerform
, AcidRemoteException(..)
, CommChannel(..)
, process
, processRemoteState
) where
import Prelude hiding ( catch )
import Control.Concurrent.STM ( atomically )
import Control.Concurrent.STM.TMVar ( newEmptyTMVar, readTMVar, takeTMVar, tryTakeTMVar, putTMVar )
import Control.Concurrent.STM.TQueue
import Control.Exception ( AsyncException(ThreadKilled)
, Exception(fromException), IOException, Handler(..)
, SomeException, catch, catches, throw, bracketOnError )
import Control.Exception ( throwIO, finally )
import Control.Monad ( forever, liftM, join, when )
import Control.Concurrent ( ThreadId, forkIO, threadDelay, killThread, myThreadId )
import Control.Concurrent.MVar ( MVar, newEmptyMVar, putMVar, takeMVar )
import Control.Concurrent.Chan ( newChan, readChan, writeChan )
import Data.Acid.Abstract
import Data.Acid.Core
import Data.Acid.Common
import Data.Monoid ((<>))
import qualified Data.ByteString as Strict
import Data.ByteString.Char8 ( pack )
import qualified Data.ByteString.Lazy as Lazy
import Data.IORef ( newIORef, readIORef, writeIORef )
import Data.Serialize
import Data.Set ( Set, member )
import Data.Typeable ( Typeable )
import GHC.IO.Exception ( IOErrorType(..) )
import Network.BSD ( PortNumber, getProtocolNumber, getHostByName, hostAddress )
import Network.Socket
import Network.Socket.ByteString as NSB ( recv, sendAll )
import System.Directory ( removeFile )
import System.IO ( Handle, hPrint, hFlush, hClose, stderr, IOMode(..) )
import System.IO.Error ( ioeGetErrorType, isFullError, isDoesNotExistError )
-- | Tracing hook used by the client machinery in this module. It is a no-op:
-- the message argument is ignored (and therefore never evaluated).
debugStrLn :: String -> IO ()
debugStrLn _msg = pure ()
-- | A bidirectional byte channel, abstracting over the transport (this
-- module builds one from either a 'Socket' or a 'Handle').
data CommChannel = CommChannel
  { ccPut     :: Strict.ByteString -> IO ()
    -- ^ Send a bytestring to the peer.
  , ccGetSome :: Int -> IO Strict.ByteString
    -- ^ Receive at most the given number of bytes. An empty result is
    -- treated by 'process' as the end of the stream.
  , ccClose   :: IO ()
    -- ^ Close the underlying transport.
  }
-- | Exceptions raised by the remote (client/server) backend.
data AcidRemoteException
  = RemoteConnectionError
    -- ^ Connection-level failure. NOTE(review): confirm which code paths
    -- raise this.
  | AcidStateClosed
    -- ^ A command was submitted after the remote state was closed.
  | SerializeError String
    -- ^ The server could not decode a command received from a client.
  | AuthenticationError String
    -- ^ The client-side shared-secret handshake was rejected.
  deriving (Eq, Show, Typeable)

instance Exception AcidRemoteException
-- | Wrap a 'Handle' as a 'CommChannel'. Every write is followed by a flush,
-- so messages are not held back in the handle's buffer.
handleToCommChannel :: Handle -> CommChannel
handleToCommChannel h = CommChannel
  { ccPut     = \bytes -> do Strict.hPut h bytes
                             hFlush h
  , ccGetSome = \n -> Strict.hGetSome h n
  , ccClose   = hClose h
  }
-- | Wrap a connected 'Socket' as a 'CommChannel'. Writes use 'sendAll', so
-- the whole bytestring is sent before 'ccPut' returns.
socketToCommChannel :: Socket -> CommChannel
socketToCommChannel sock = CommChannel
  { ccPut     = \bytes -> sendAll sock bytes
  , ccGetSome = \n -> NSB.recv sock n
  , ccClose   = close sock
  }
-- | Server-side check that accepts every client without touching the
-- channel.
skipAuthenticationCheck :: CommChannel -> IO Bool
skipAuthenticationCheck = const (pure True)
-- | Client-side counterpart of 'skipAuthenticationCheck': sends nothing and
-- always succeeds.
skipAuthenticationPerform :: CommChannel -> IO ()
skipAuthenticationPerform = const (pure ())
-- | Server-side shared-secret authentication. Reads one chunk of at most
-- 1024 bytes from the client and accepts the client iff that chunk is one
-- of the given secrets. Replies @OK@ on success and @FAIL@ otherwise.
--
-- NOTE(review): the secret is expected to arrive in a single read. A secret
-- longer than 1024 bytes, or one split across several TCP segments, may be
-- rejected.
sharedSecretCheck :: Set Strict.ByteString
                  -> (CommChannel -> IO Bool)
sharedSecretCheck secrets cc =
  do received <- ccGetSome cc 1024
     let accepted = received `member` secrets
     ccPut cc (pack (if accepted then "OK" else "FAIL"))
     pure accepted
-- | Client-side shared-secret authentication. Sends the secret, then reads
-- the server's reply (at most 1024 bytes).
--
-- Throws 'AuthenticationError' unless the reply is exactly @OK@.
sharedSecretPerform :: Strict.ByteString
                    -> (CommChannel -> IO ())
sharedSecretPerform pw cc =
  do ccPut cc pw
     reply <- ccGetSome cc 1024
     when (reply /= pack "OK") $
       throwIO (AuthenticationError "shared secret authentication failed.")
-- | Listen on the given address and serve the 'AcidState' to every client
-- that passes the authentication check (see 'acidServer'').
--
-- When the server stops (by exception), the listening socket is closed. For
-- a 'SockAddrUnix' address, the socket file is also removed. A socket file
-- that no longer exists is not treated as an error, so the cleanup cannot
-- replace the exception that ended the server.
acidServerSockAddr :: (CommChannel -> IO Bool) -- ^ authentication check
                   -> SockAddr                 -- ^ address to listen on
                   -> AcidState st             -- ^ state to serve
                   -> IO ()
acidServerSockAddr checkAuth sockAddr acidState =
  do listenSocket <- listenOn sockAddr
     acidServer' checkAuth listenSocket acidState `finally` cleanup listenSocket
  where
    -- Close the listening socket and, where Unix domain sockets exist,
    -- delete the socket file at the 'SockAddrUnix' path.
    cleanup sock =
      do close sock
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
         case sockAddr of
           SockAddrUnix path -> removeFile path `catch` ignoreMissing
           _                 -> pure ()

    -- The file may already have been removed by someone else; only other
    -- IO errors are rethrown.
    ignoreMissing :: IOException -> IO ()
    ignoreMissing e
      | isDoesNotExistError e = pure ()
      | otherwise             = throwIO e
#endif
-- | Serve the 'AcidState' over TCP on the given port. The listening address
-- is the IPv4 wildcard (host address 0), i.e. all local interfaces.
acidServer :: (CommChannel -> IO Bool)
           -> PortNumber
           -> AcidState st
           -> IO ()
acidServer checkAuth port =
  acidServerSockAddr checkAuth (SockAddrInet port 0)
-- | Create a stream socket for the given address, set @SO_REUSEADDR@, bind
-- it and start listening. If any step after socket creation fails, the
-- socket is closed before the exception propagates.
listenOn :: SockAddr -> IO Socket
listenOn sockAddr =
  do protocol <- protocolFor sockAddr
     bracketOnError (socket addressFamily Stream protocol) close $ \sock ->
       do setSocketOption sock ReuseAddr 1
          bind sock sockAddr
          listen sock maxListenQueue
          pure sock
  where
    -- Address family matching the address constructor.
    addressFamily = case sockAddr of
      SockAddrInet{}  -> AF_INET
      SockAddrInet6{} -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
      SockAddrUnix{}  -> AF_UNIX
#endif

    -- Unix domain sockets use the default protocol (0); any other address
    -- is looked up as TCP.
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    protocolFor SockAddrUnix{} = pure 0
#endif
    protocolFor _              = getProtocolNumber "tcp"
-- | Serve the 'AcidState' on an already listening socket. Never returns
-- normally.
--
-- Each accepted connection is handled in its own thread:
--
--   * The client is authenticated with @checkAuth@.
--   * If accepted, its commands are run by 'process' until it disconnects.
--   * The connection is closed when that thread finishes, whether it
--     finishes normally or by exception.
--
-- Some IO errors raised in the accept loop are dropped and the loop is
-- restarted: resource exhaustion ('isFullError'), missing resources
-- ('isDoesNotExistError') and @ResourceVanished@. Any other exception
-- propagates to the caller.
acidServer' :: (CommChannel -> IO Bool)
            -> Socket
            -> AcidState st
            -> IO ()
acidServer' checkAuth listenSocket acidState = infi
  where
    infi = loop `catchSome` logError >> infi

    loop = forever $
      do (sock, _sockAddr) <- accept listenSocket
         let commChannel = socketToCommChannel sock
         -- 'finally' ensures a client whose authentication or command
         -- processing throws does not leak its socket.
         _ <- forkIO $ serveClient commChannel `finally` ccClose commChannel
         return ()

    -- Authenticate one client and, if accepted, process its commands.
    serveClient commChannel =
      do authorized <- checkAuth commChannel
         when authorized $ process commChannel acidState

    logError :: (Show e) => e -> IO ()
    logError e = hPrint stderr e

    isResourceVanishedError :: IOException -> Bool
    isResourceVanishedError = isResourceVanishedType . ioeGetErrorType

    isResourceVanishedType :: IOErrorType -> Bool
    isResourceVanishedType ResourceVanished = True
    isResourceVanishedType _                = False

    -- Run the action, dropping the IOExceptions listed above and rethrowing
    -- everything else. NOTE: the handler argument is currently not invoked,
    -- so dropped errors are not logged.
    catchSome :: IO () -> (Show e => e -> IO ()) -> IO ()
    catchSome op _h =
      op `catches`
        [ Handler $ \(e :: IOException) ->
            if isFullError e || isDoesNotExistError e || isResourceVanishedError e
              then return ()
              else throwIO e
        ]
-- | A request sent from a remote client to the server.
data Command = RunQuery (Tagged Lazy.ByteString)
               -- ^ Run a query given as a method tag plus encoded event.
             | RunUpdate (Tagged Lazy.ByteString)
               -- ^ Run an update given as a method tag plus encoded event.
             | CreateCheckpoint
             | CreateArchive

-- Wire format: a one-byte tag (0-3) followed by the payload, if any.
instance Serialize Command where
  put cmd = case cmd of
    RunQuery query   -> putWord8 0 >> put query
    RunUpdate update -> putWord8 1 >> put update
    CreateCheckpoint -> putWord8 2
    CreateArchive    -> putWord8 3
  -- An unknown tag is reported through the parser ('fail') instead of
  -- 'error'. Decoding bad input then yields a 'Fail' result, which 'process'
  -- turns into a 'SerializeError', rather than a pure 'ErrorCall'.
  get = do tag <- getWord8
           case tag of
             0 -> liftM RunQuery get
             1 -> liftM RunUpdate get
             2 -> return CreateCheckpoint
             3 -> return CreateArchive
             _ -> fail $ "Data.Acid.Remote: Serialize.get for Command, invalid tag: " ++ show tag
-- | A reply to a 'Command'.
data Response = Result Lazy.ByteString
                -- ^ Encoded result of a query or update.
              | Acknowledgement
                -- ^ A checkpoint or archive request has completed.
              | ConnectionError
                -- ^ Not sent by the server in this module. The client
                -- delivers it locally to requests that were pending when
                -- their connection was dropped or the state was shut down.

-- Wire format: a one-byte tag (0-2) followed by the payload, if any.
instance Serialize Response where
  put resp = case resp of
    Result result   -> putWord8 0 >> put result
    Acknowledgement -> putWord8 1
    ConnectionError -> putWord8 2
  -- Unknown tags are reported with the parser's 'fail' (a 'Fail' result)
  -- rather than a pure 'error', as in the 'Command' instance.
  get = do tag <- getWord8
           case tag of
             0 -> liftM Result get
             1 -> return Acknowledgement
             2 -> return ConnectionError
             _ -> fail $ "Data.Acid.Remote: Serialize.get for Response, invalid tag: " ++ show tag
-- | Serve one (already authenticated) client. Commands are decoded from the
-- channel and executed in arrival order. Their replies are written back by
-- a separate writer thread, in the same order.
--
-- Returns when a read yields an empty chunk (the client closed the
-- connection). Throws 'SerializeError' if the input cannot be decoded.
--
-- Each queued item is an @IO Response@:
--
--   * Queries, checkpoints and archives are executed before their item is
--     queued.
--   * Updates queue an action that waits for the scheduled update.
--
-- Because the queue is FIFO and has a single writer, replies are never
-- reordered.
process :: CommChannel
        -> AcidState st
        -> IO ()
process CommChannel{..} acidState =
  do replies <- newChan
     _ <- forkIO $ forever $
       do reply <- join (readChan replies)
          ccPut (encode reply)
     readCommands replies (runGetPartial get Strict.empty)
  where
    -- Feed input chunks to the incremental decoder, running each complete
    -- command and continuing with the leftover bytes.
    readCommands replies parse =
      case parse of
        Fail msg _ ->
          throwIO (SerializeError msg)
        Partial more ->
          do chunk <- ccGetSome 1024
             if Strict.null chunk
               then return ()
               else readCommands replies (more chunk)
        Done cmd leftover ->
          do runCommand replies cmd
             readCommands replies (runGetPartial get leftover)

    -- Execute one command and queue the action producing its reply.
    runCommand replies cmd =
      case cmd of
        RunQuery query ->
          do result <- queryCold acidState query
             writeChan replies (return (Result result))
        RunUpdate update ->
          do pending <- scheduleColdUpdate acidState update
             writeChan replies (fmap Result (takeMVar pending))
        CreateCheckpoint ->
          do createCheckpoint acidState
             writeChan replies (return Acknowledgement)
        CreateArchive ->
          do createArchive acidState
             writeChan replies (return Acknowledgement)
-- | Client-side handle on a remote state. 'processRemoteState' builds it
-- from two parts: a function that submits a 'Command' and returns the 'MVar'
-- that will receive its 'Response', and the shutdown action.
data RemoteState st =
  RemoteState
    (Command -> IO (MVar Response))  -- submit a command
    (IO ())                          -- shut the client down
  deriving (Typeable)
-- | Connect to a remote server over IPv4/TCP. The host name is resolved
-- with 'getHostByName' and the resolved entry's 'hostAddress' is used.
-- Connection handling and authentication are as in
-- 'openRemoteStateSockAddr'.
openRemoteState :: IsAcidic st =>
                   (CommChannel -> IO ())
                -> HostName
                -> PortNumber
                -> IO (AcidState st)
openRemoteState performAuthorization host port =
  getHostByName host >>= \hostEntry ->
    openRemoteStateSockAddr performAuthorization
                            (SockAddrInet port (hostAddress hostEntry))
-- | Connect to a remote server at the given address and return a client
-- 'AcidState' backed by it.
--
-- Reconnection behaviour:
--
--   * Connecting is retried once per second for as long as it fails with an
--     'IOError'. The same action is used later to re-establish a broken
--     connection.
--   * @performAuthorization@ runs on every fresh connection. If it throws,
--     that connection is closed before the exception propagates.
--   * An 'IOError' from authorization then leads to another retry; any other
--     exception (e.g. 'AuthenticationError') escapes.
openRemoteStateSockAddr :: IsAcidic st =>
                           (CommChannel -> IO ())
                        -> SockAddr
                        -> IO (AcidState st)
openRemoteStateSockAddr performAuthorization sockAddr =
  withSocketsDo $ processRemoteState reconnect
  where
    af :: Family
    af = case sockAddr of
      SockAddrInet{}  -> AF_INET
      SockAddrInet6{} -> AF_INET6
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
      SockAddrUnix{}  -> AF_UNIX
#endif

    -- Unix domain sockets use the default protocol (0); any other address
    -- is looked up as TCP.
#if !defined(mingw32_HOST_OS) && !defined(cygwin32_HOST_OS) && !defined(_WIN32)
    protocolFor SockAddrUnix{} = pure 0
#endif
    protocolFor _              = getProtocolNumber "tcp"

    -- Connect a new socket to 'sockAddr' and convert it to a 'Handle'. The
    -- socket is closed if either step fails.
    openHandle :: IO Handle
    openHandle =
      do proto <- protocolFor sockAddr
         bracketOnError (socket af Stream proto) close $ \sock ->
           do connect sock sockAddr
              socketToHandle sock ReadWriteMode

    -- Connect and authenticate, retrying every second on 'IOError'.
    reconnect :: IO CommChannel
    reconnect =
      (do debugStrLn "Reconnecting."
          h <- openHandle
          -- Close the fresh connection if authorization throws. Otherwise
          -- every failed attempt would leave an open handle behind.
          cc <- bracketOnError (return (handleToCommChannel h)) ccClose $ \fresh ->
                  do performAuthorization fresh
                     return fresh
          debugStrLn "Reconnected."
          return cc)
      `catch`
        ((\_ -> threadDelay 1000000 >> reconnect) :: IOError -> IO CommChannel)
-- | Build a client-side 'AcidState' on top of a (re)connect action.
--
-- @reconnect@ is run once up front and again whenever the connection is
-- found to be broken. Shared state:
--
--   * @cmdQueue@: commands submitted by callers, each paired with the 'MVar'
--     that will receive its 'Response'.
--   * @ccTMV@: the current channel, its queue of pending response callbacks,
--     and the id of its listener thread. It is empty while 'handleReconnect'
--     is re-establishing the connection, and after shutdown has taken it.
--   * @isClosed@: set by shutdown; later submissions throw
--     'AcidStateClosed'.
--
-- Two kinds of thread do the work. A single actor thread sends commands.
-- One listener thread per connection reads responses and passes each one to
-- the oldest pending callback.
processRemoteState :: IsAcidic st =>
                      IO CommChannel
                   -> IO (AcidState st)
processRemoteState reconnect
  = do cmdQueue <- atomically newTQueue
       ccTMV <- atomically newEmptyTMVar
       isClosed <- newIORef False
       -- actor: the submit function stored in the 'RemoteState'. It refuses
       -- new work once 'isClosed' is set. Otherwise it enqueues the command
       -- with a fresh empty MVar and returns that MVar.
       let actor :: Command -> IO (MVar Response)
           actor command =
             do debugStrLn "actor: begin."
                readIORef isClosed >>= flip when (throwIO AcidStateClosed)
                ref <- newEmptyMVar
                atomically $ writeTQueue cmdQueue (command, ref)
                debugStrLn "actor: end."
                return ref
           -- Drain a listen queue, answering every pending callback with
           -- 'ConnectionError' (e.g. 'remoteQueryCold' retries on it).
           expireQueue listenQueue =
             do mCallback <- atomically $ tryReadTQueue listenQueue
                case mCallback of
                  Nothing -> return ()
                  (Just callback) ->
                    do callback ConnectionError
                       expireQueue listenQueue
           -- Exception handler for the listener and for failed sends.
           --
           -- ThreadKilled (delivered by killThread from shutdown or from
           -- another reconnect) is ignored. Otherwise the handler tries to
           -- take the connection out of 'ccTMV'. If it is already empty,
           -- another thread is handling the failure and this one returns.
           --
           -- The thread that took it then:
           --   * stops the old listener, unless that listener is itself;
           --   * closes the old channel;
           --   * fails every request still pending on the old channel;
           --   * reconnects and installs a new listener.
           handleReconnect :: SomeException -> IO ()
           handleReconnect e
             = case fromException e of
                 (Just ThreadKilled) ->
                   do debugStrLn "handleReconnect: ThreadKilled. Not attempting to reconnect."
                      return ()
                 _ ->
                   do debugStrLn $ "handleReconnect begin."
                      tmv <- atomically $ tryTakeTMVar ccTMV
                      case tmv of
                        Nothing ->
                          do debugStrLn $ "handleReconnect: error handling already in progress."
                             debugStrLn $ "handleReconnect end."
                             return ()
                        (Just (oldCC, oldListenQueue, oldListenerTID)) ->
                          do thisTID <- myThreadId
                             -- Killing our own thread would abort this reconnect.
                             when (thisTID /= oldListenerTID) (killThread oldListenerTID)
                             ccClose oldCC
                             expireQueue oldListenQueue
                             -- NOTE(review): 'reconnect' may block, retrying
                             -- until it succeeds; 'ccTMV' stays empty meanwhile.
                             cc <- reconnect
                             listenQueue <- atomically $ newTQueue
                             listenerTID <- forkIO $ listener cc listenQueue
                             atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)
                             debugStrLn $ "handleReconnect end."
                             return ()
           -- Per-connection reader. Decodes Responses from the channel. For
           -- each one it pops the oldest callback from the listen queue and
           -- passes the response to it. 'actorThread' registers callbacks in
           -- the order commands are sent, which pairs replies with requests.
           -- Any exception, including the 'error' raised on a decode
           -- failure, goes to 'handleReconnect'.
           -- NOTE(review): at end of stream ccGetSome returns an empty chunk.
           -- Presumably the decoder continuation then reports Fail, routing
           -- EOF through the same reconnect path -- confirm against cereal.
           listener :: CommChannel -> TQueue (Response -> IO ()) -> IO ()
           listener cc listenQueue
             = getResponse Strict.empty `catch` handleReconnect
             where
               getResponse leftover =
                 do debugStrLn $ "listener: listening for Response."
                    let go inp = case inp of
                          Fail msg _ -> error $ "Data.Acid.Remote: " <> msg
                          Partial cont -> do debugStrLn $ "listener: ccGetSome"
                                             bs <- ccGetSome cc 1024
                                             go (cont bs)
                          Done resp rest -> do debugStrLn $ "listener: getting callback"
                                               callback <- atomically $ readTQueue listenQueue
                                               debugStrLn $ "listener: passing Response to callback"
                                               callback (resp :: Response)
                                               return rest
                    rest <- go (runGetPartial get leftover)
                    getResponse rest
           -- Single sender thread. One STM transaction does three things:
           --   * takes the next command;
           --   * reads the current connection, blocking while 'ccTMV' is
           --     empty (e.g. during a reconnect);
           --   * registers the command's MVar as the next callback on that
           --     connection's listen queue.
           -- Only then is the command sent. If the send fails,
           -- 'handleReconnect' (here or in the listener) expires that queue.
           -- The command's MVar should then receive 'ConnectionError'.
           actorThread :: IO ()
           actorThread = forever $
             do debugStrLn "actorThread: waiting for something to do."
                (cc, cmd) <- atomically $
                  do (cmd, ref) <- readTQueue cmdQueue
                     (cc, listenQueue, _) <- readTMVar ccTMV
                     writeTQueue listenQueue (putMVar ref)
                     return (cc, cmd)
                debugStrLn "actorThread: sending command."
                ccPut cc (encode cmd) `catch` handleReconnect
                debugStrLn "actorThread: sent."
                return ()
           -- Close the remote state, in order:
           --   * refuse new submissions;
           --   * stop the actor;
           --   * take the connection (blocks while a reconnect holds it);
           --   * stop its listener;
           --   * fail the pending callbacks;
           --   * close the channel.
           -- NOTE(review): commands still in 'cmdQueue' (submitted but not
           -- yet picked up by the actor) are never answered, so their MVars
           -- stay empty.
           shutdown :: ThreadId -> IO ()
           shutdown actorTID =
             do debugStrLn "shutdown: update isClosed IORef to True."
                writeIORef isClosed True
                debugStrLn "shutdown: killing actor thread."
                killThread actorTID
                debugStrLn "shutdown: taking ccTMV."
                (cc, listenQueue, listenerTID) <- atomically $ takeTMVar ccTMV
                debugStrLn "shutdown: killing listener thread."
                killThread listenerTID
                debugStrLn "shutdown: expiring listen queue."
                expireQueue listenQueue
                debugStrLn "shutdown: closing connection."
                ccClose cc
                return ()
       -- Establish the first connection, then start the worker threads.
       cc <- reconnect
       listenQueue <- atomically $ newTQueue
       actorTID <- forkIO $ actorThread
       listenerTID <- forkIO $ listener cc listenQueue
       atomically $ putTMVar ccTMV (cc, listenQueue, listenerTID)
       return (toAcidState $ RemoteState actor (shutdown actorTID))
-- | Run a typed query against the remote state.
--
-- The event is encoded with the serialiser found in the method map and sent
-- via 'remoteQueryCold'; the reply is then decoded. A reply that fails to
-- decode becomes an 'error' thunk, raised only when the result is forced.
remoteQuery :: QueryEvent event
            => RemoteState (EventState event)
            -> MethodMap (EventState event)
            -> event
            -> IO (EventResult event)
remoteQuery acidState mmap event =
  do reply <- remoteQueryCold acidState (methodTag event, encodeMethod serialiser event)
     return $ either (\msg -> error $ "Data.Acid.Remote: " <> msg) id (decodeResult serialiser reply)
  where
    serialiser = snd (lookupHotMethodAndSerialiser mmap event)
-- | Send an already-encoded query and wait for its encoded result.
--
-- A 'ConnectionError' reply means the connection was lost before an answer
-- arrived; the query is then submitted again, as queries are safe to repeat.
-- An 'Acknowledgement' reply is a protocol violation and raises 'error'.
remoteQueryCold :: RemoteState st -> Tagged Lazy.ByteString -> IO Lazy.ByteString
remoteQueryCold (RemoteState submit _shutdown) event = attempt
  where
    attempt =
      do reply <- takeMVar =<< submit (RunQuery event)
         case reply of
           Result result   -> return result
           ConnectionError ->
             do debugStrLn "retrying query event."
                attempt
           Acknowledgement ->
             error "Data.Acid.Remote: remoteQueryCold got Acknowledgement. That should never happen."
-- | Schedule a typed update on the remote state.
--
-- Returns immediately with an 'MVar'. A helper thread fills it once the
-- server's reply arrives:
--
--   * 'Result': the MVar holds the decoded result. A reply that fails to
--     decode becomes an 'error' thunk.
--   * 'ConnectionError': the MVar holds a value that throws
--     'RemoteConnectionError' when forced. A caller waiting on the MVar is
--     therefore not left blocked forever.
--
-- The update is not resent after a lost connection, because the server may
-- already have applied it.
scheduleRemoteUpdate :: UpdateEvent event
                     => RemoteState (EventState event)
                     -> MethodMap (EventState event)
                     -> event
                     -> IO (MVar (EventResult event))
scheduleRemoteUpdate (RemoteState fn _shutdown) mmap event =
  do let encoded = encodeMethod ms event
     parsed  <- newEmptyMVar
     respRef <- fn (RunUpdate (methodTag event, encoded))
     _ <- forkIO $
       do resp <- takeMVar respRef
          putMVar parsed $ case resp of
            Result bytes -> case decodeResult ms bytes of
              Left msg     -> error $ "Data.Acid.Remote: " <> msg
              Right result -> result
            ConnectionError -> throw RemoteConnectionError
            Acknowledgement -> error "Data.Acid.Remote: scheduleRemoteUpdate got Acknowledgement. That should never happen."
     return parsed
  where
    (_, ms) = lookupHotMethodAndSerialiser mmap event
-- | Schedule an already-encoded update on the remote state.
--
-- Returns immediately with an 'MVar'. A helper thread fills it with the
-- encoded result when the reply arrives. If the reply is 'ConnectionError',
-- the MVar instead holds a value that throws 'RemoteConnectionError' when
-- forced, so waiting callers are not blocked forever. The update is not
-- resent, since the server may already have applied it.
scheduleRemoteColdUpdate :: RemoteState st -> Tagged Lazy.ByteString -> IO (MVar Lazy.ByteString)
scheduleRemoteColdUpdate (RemoteState fn _shutdown) event =
  do parsed  <- newEmptyMVar
     respRef <- fn (RunUpdate event)
     _ <- forkIO $
       do resp <- takeMVar respRef
          putMVar parsed $ case resp of
            Result bytes    -> bytes
            ConnectionError -> throw RemoteConnectionError
            Acknowledgement -> error "Data.Acid.Remote: scheduleRemoteColdUpdate got Acknowledgement. That should never happen."
     return parsed
-- | Close the remote state by running the shutdown action stored in it.
closeRemoteState :: RemoteState st -> IO ()
closeRemoteState (RemoteState _submit shutdownAction) = shutdownAction
-- | Ask the server to create a checkpoint and wait for its acknowledgement.
--
-- Throws 'RemoteConnectionError' if the connection is lost before the
-- acknowledgement arrives. A 'Result' reply is a protocol violation and
-- raises 'error'.
createRemoteCheckpoint :: RemoteState st -> IO ()
createRemoteCheckpoint (RemoteState fn _shutdown) =
  do resp <- takeMVar =<< fn CreateCheckpoint
     case resp of
       Acknowledgement -> return ()
       ConnectionError -> throwIO RemoteConnectionError
       Result _        -> error "Data.Acid.Remote: createRemoteCheckpoint got Result. That should never happen."
-- | Ask the server to archive old checkpoints/events and wait for its
-- acknowledgement.
--
-- Throws 'RemoteConnectionError' if the connection is lost before the
-- acknowledgement arrives. A 'Result' reply is a protocol violation and
-- raises 'error'.
createRemoteArchive :: RemoteState st -> IO ()
createRemoteArchive (RemoteState fn _shutdown) =
  do resp <- takeMVar =<< fn CreateArchive
     case resp of
       Acknowledgement -> return ()
       ConnectionError -> throwIO RemoteConnectionError
       Result _        -> error "Data.Acid.Remote: createRemoteArchive got Result. That should never happen."
-- | Present a 'RemoteState' through the generic 'AcidState' interface.
-- Typed queries and updates use a method map built from the state's
-- 'acidEvents'. The cold variants forward tagged, encoded events unchanged.
toAcidState :: forall st . IsAcidic st => RemoteState st -> AcidState st
toAcidState remote =
  let methods :: MethodMap st
      methods = mkMethodMap (eventsToMethods acidEvents)
  in AcidState
       { _scheduleUpdate    = scheduleRemoteUpdate remote methods
       , scheduleColdUpdate = scheduleRemoteColdUpdate remote
       , _query             = remoteQuery remote methods
       , queryCold          = remoteQueryCold remote
       , createCheckpoint   = createRemoteCheckpoint remote
       , createArchive      = createRemoteArchive remote
       , closeAcidState     = closeRemoteState remote
       , acidSubState       = mkAnyState remote
       }