{-# LANGUAGE RecordWildCards, StandaloneDeriving, OverloadedStrings #-}
{-# LANGUAGE CPP, FlexibleContexts, TupleSections, TypeSynonymInstances #-}
{-# LANGUAGE MultiParamTypeClasses, FlexibleInstances, UndecidableInstances #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE NamedFieldPuns, ScopedTypeVariables #-}
#if (__GLASGOW_HASKELL__ >= 706)
{-# LANGUAGE RecursiveDo #-}
#else
{-# LANGUAGE DoRec #-}
#endif
module Database.MongoDB.Internal.Protocol (
FullCollection,
Pipe, newPipe, newPipeWith, send, call,
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
Request(..), QueryOption(..),
Reply(..), ResponseFlag(..),
Username, Password, Nonce, pwHash, pwKey,
isClosed, close, ServerData(..), Pipeline(..)
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Control.Monad (forM, replicateM, unless)
import Data.Binary.Get (Get, runGet)
import Data.Binary.Put (Put, runPut)
import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle)
import System.IO.Error (doesNotExistErrorType, mkIOError)
import System.IO.Unsafe (unsafePerformIO)
import Data.Maybe (maybeToList)
import GHC.Conc (ThreadStatus(..), threadStatus)
import Control.Monad (forever)
import Control.Monad.STM (atomically)
import Control.Concurrent (ThreadId, killThread, forkIOWithUnmask)
import Control.Concurrent.STM.TChan (TChan, newTChan, readTChan, writeTChan, isEmptyTChan)
import Control.Exception.Lifted (SomeException, mask_, onException, throwIO, try)
import qualified Data.ByteString.Lazy as L
import Control.Monad.Trans (MonadIO, liftIO)
import Data.Bson (Document)
import Data.Bson.Binary (getDocument, putDocument, getInt32, putInt32, getInt64,
putInt64, putCString)
import Data.Text (Text)
import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (bitOr, byteStringHex)
import Database.MongoDB.Transport (Transport)
import qualified Database.MongoDB.Transport as Tr
#if MIN_VERSION_base(4,6,0)
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, mkWeakMVar, isEmptyMVar)
#else
import Control.Concurrent.MVar.Lifted (MVar, newEmptyMVar, newMVar, withMVar,
putMVar, readMVar, addMVarFinalizer)
#endif
#if !MIN_VERSION_base(4,6,0)
-- Compatibility shim for base < 4.6: expose the older 'addMVarFinalizer'
-- (imported in the matching CPP branch of the import list) under the name
-- 'mkWeakMVar', so the rest of the module can use a single name.
-- Note that this shim returns @IO ()@; the only use in this module
-- discards the result.
mkWeakMVar :: MVar a -> IO () -> IO ()
mkWeakMVar = addMVarFinalizer
#endif
-- | State of one pipelined connection: a transport shared by writers, a
-- queue of callers waiting for replies, and the background thread that
-- reads replies off the transport.
data Pipeline = Pipeline
  { vStream :: MVar Transport
    -- ^ The underlying transport. 'psend' and 'pcall' take this MVar with
    -- @withMVar@ while writing, so writes are serialized.
  , responseQueue :: TChan (MVar (Either IOError Response))
    -- ^ One empty MVar per outstanding 'pcall'; 'listen' pops one entry for
    -- each message it reads and fills it with the result.
  , listenThread :: ThreadId
    -- ^ Thread running 'listen' for this pipeline.
  , finished :: MVar ()
    -- ^ Filled by 'newPipeline''s finalizer once the listener thread ends.
  , serverData :: ServerData
    -- ^ Server description supplied by the caller of 'newPipeline'.
  }
-- | Description of the server a pipe is connected to. This module only
-- stores the value inside a 'Pipeline'; it never inspects the fields.
-- NOTE(review): the field names suggest they mirror the server's
-- isMaster/handshake response — confirm against the code that builds it.
data ServerData = ServerData
  { isMaster :: Bool
  , minWireVersion :: Int
  , maxWireVersion :: Int
  , maxMessageSizeBytes :: Int
  , maxBsonObjectSize :: Int
  , maxWriteBatchSize :: Int
  }
-- | Fork a thread that runs the given action with asynchronous exceptions
-- unmasked, then hands its outcome (exception or result) to the finalizer.
-- The fork itself and the finalizer run masked, so the finalizer is not
-- skipped by an asynchronous exception arriving around the action.
forkUnmaskedFinally :: IO a -> (Either SomeException a -> IO ()) -> IO ThreadId
forkUnmaskedFinally body finalizer = mask_ spawn
  where
    spawn = forkIOWithUnmask $ \restore -> do
      outcome <- try (restore body)
      finalizer outcome
-- | Build a 'Pipeline' over the given transport and start its listener
-- thread.
--
-- When the listener thread ends (for any reason) it marks the pipeline as
-- finished and then fails every caller still waiting in the response queue
-- with a \"Handle has been closed\" 'IOError'. A finalizer attached to the
-- transport MVar kills the listener and closes the transport.
newPipeline :: ServerData -> Transport -> IO Pipeline
newPipeline serverData stream = do
  vStream <- newMVar stream
  responseQueue <- atomically newTChan
  finished <- newEmptyMVar
  let closedError =
        mkIOError doesNotExistErrorType "Handle has been closed" Nothing Nothing
      -- Pop waiters one at a time until the queue is empty, giving each
      -- of them the "closed" error.
      failPending = do
        noneWaiting <- atomically (isEmptyTChan responseQueue)
        unless noneWaiting $ do
          waiter <- atomically (readTChan responseQueue)
          putMVar waiter (Left closedError)
          failPending
  -- 'pipe' needs 'listenThread', and the listener needs 'pipe', hence the
  -- recursive binding group.
  rec
    let pipe = Pipeline{..}
    listenThread <- forkUnmaskedFinally (listen pipe) $ \_ -> do
      putMVar finished ()
      failPending
  _ <- mkWeakMVar vStream $ do
    killThread listenThread
    Tr.close stream
  return pipe
-- | 'True' once the pipeline's @finished@ MVar has been filled, i.e. after
-- the listener thread's finalizer has started running.
isFinished :: Pipeline -> IO Bool
isFinished pipe = not <$> isEmptyMVar (finished pipe)
-- | Shut the pipeline down: kill the listener thread, then close the
-- underlying transport.
close :: Pipeline -> IO ()
close pipe = do
  killThread (listenThread pipe)
  transport <- readMVar (vStream pipe)
  Tr.close transport
-- | Report whether the pipeline's listener thread has stopped. A running or
-- blocked listener counts as open; a finished or died listener counts as
-- closed.
isClosed :: Pipeline -> IO Bool
isClosed pipe = hasStopped <$> threadStatus (listenThread pipe)
  where
    hasStopped ThreadRunning = False
    hasStopped ThreadFinished = True
    hasStopped (ThreadBlocked _) = False
    hasStopped ThreadDied = True
-- | Listener loop: repeatedly read one message from the transport and hand
-- the outcome (reply or 'IOError') to the oldest waiter in the response
-- queue. After delivering an error, the transport is closed and the error
-- is re-raised, which ends the loop.
listen :: Pipeline -> IO ()
listen Pipeline{vStream, responseQueue} = do
  transport <- readMVar vStream
  let serveOne = do
        outcome <- try (readMessage transport)
        -- Blocks until some caller has registered a reply slot.
        waiter <- atomically (readTChan responseQueue)
        putMVar waiter outcome
        either
          (\err -> Tr.close transport >> ioError err)
          (const (return ()))
          outcome
  forever serveOne
-- | Write a message to the pipeline without expecting a reply. The message
-- is forced to WHNF before the transport is taken; if anything
-- throws, the pipeline is closed and the exception is re-raised.
psend :: Pipeline -> Message -> IO ()
psend pipe !message = deliver `onException` close pipe
  where
    deliver = withMVar (vStream pipe) $ \transport ->
      writeMessage transport message
-- | Write a message and register for its reply. The outer action sends the
-- message and returns a promise; running the promise blocks until the
-- listener has filled in the reply, re-throwing an 'IOError' if the
-- listener delivered one.
--
-- Fails immediately with \"Handle has been closed\" if the listener has
-- already finished. Any exception while sending closes the pipeline.
pcall :: Pipeline -> Message -> IO (IO Response)
pcall pipe message = do
  listenerStopped <- isFinished pipe
  if listenerStopped
    then ioError closedError
    else enqueue `onException` close pipe
  where
    closedError =
      mkIOError doesNotExistErrorType "Handle has been closed" Nothing Nothing
    -- Writing and queueing happen under the transport lock, so the order
    -- of reply slots matches the order of requests on the wire.
    enqueue = withMVar (vStream pipe) $ \transport -> do
      writeMessage transport message
      slot <- newEmptyMVar
      liftIO . atomically $ writeTChan (responseQueue pipe) slot
      return (readMVar slot >>= either throwIO return)
-- | Public name for a 'Pipeline'; this is the connection type taken by
-- 'send' and 'call'.
type Pipe = Pipeline
-- | Create a pipe over a 'Handle' by wrapping it in a transport and
-- delegating to 'newPipeWith'.
newPipe :: ServerData -> Handle -> IO Pipe
newPipe sd handle = do
  transport <- Tr.fromHandle handle
  newPipeWith sd transport
-- | Create a pipe over an arbitrary 'Transport'.
newPipeWith :: ServerData -> Transport -> IO Pipe
newPipeWith = newPipeline
-- | Send a batch of notices over the pipe. No request is attached, so no
-- reply is awaited.
send :: Pipe -> [Notice] -> IO ()
send pipe notices = psend pipe message
  where
    message = (notices, Nothing)
-- | Send notices followed by a request, returning a promise for the reply.
-- A fresh request id is generated for the request; when the reply is
-- evaluated, its response id is compared with that request id and a
-- mismatch raises an 'error'.
call :: Pipe -> [Notice] -> Request -> IO (IO Reply)
call pipe notices request = do
  requestId <- genRequestId
  promise <- pcall pipe (notices, Just (request, requestId))
  return (fmap (verify requestId) promise)
  where
    verify sentId (responseTo, reply)
      | sentId == responseTo = reply
      | otherwise =
          error $ "expected response id (" ++ show responseTo ++ ") to match request id (" ++ show sentId ++ ")"
-- | What is written to the server in one go: a list of notices, optionally
-- followed by a single request paired with its request id.
type Message = ([Notice], Maybe (Request, RequestId))
-- | Serialize a 'Message' and write it to the transport in a single write,
-- then flush. Each notice gets its own freshly generated request id; the
-- request, if any, uses the id supplied with it. Every serialized item is
-- prefixed with its total length as an Int32 (payload length plus the four
-- bytes of the prefix itself).
writeMessage :: Transport -> Message -> IO ()
writeMessage conn (notices, mRequest) = do
  noticeFrames <- forM notices $ \notice -> do
    noticeId <- genRequestId
    return (frame (putNotice notice noticeId))
  let requestFrame =
        fmap (\(request, requestId) -> frame (putRequest request requestId)) mRequest
      payload = L.concat (noticeFrames ++ maybeToList requestFrame)
  Tr.write conn (L.toStrict payload)
  Tr.flush conn
  where
    -- Run the serializer and prepend the length prefix.
    frame body = sizePrefix serialized `L.append` serialized
      where
        serialized = runPut body
    -- Int64 length -> Int -> Int32, plus 4 for the prefix itself.
    sizePrefix bytes =
      runPut . putInt32 . (+ 4) . toEnum . fromEnum $ L.length bytes
-- | A decoded server message: the id of the request it answers, paired with
-- the reply itself.
type Response = (ResponseTo, Reply)
-- | Read one length-prefixed message from the transport and decode it as a
-- reply. The first four bytes hold the total message length as an Int32;
-- the remaining (length - 4) bytes are then read and parsed with
-- 'getReply'. Parsing is lazy: the result is decoded when it is evaluated.
readMessage :: Transport -> IO Response
readMessage conn = do
  sizeBytes <- Tr.read conn 4
  let bodyLen = fromEnum (runGet getInt32 (L.fromStrict sizeBytes) - 4)
  body <- Tr.read conn bodyLen
  return (runGet getReply (L.fromStrict body))
-- | Collection name as written on the wire with 'putCString'.
-- NOTE(review): presumably the database-qualified name
-- (@database.collection@) — confirm against callers.
type FullCollection = Text
-- | Numeric operation code; written as the third field of the message
-- header by 'putHeader'.
type Opcode = Int32
-- | Identifier attached to each outgoing message (see 'genRequestId').
type RequestId = Int32
-- | The request id a server reply refers to.
type ResponseTo = RequestId
-- | Generate a fresh request id by atomically post-incrementing a
-- process-wide counter: the value returned is the counter's value before
-- the increment, so ids start at 0 and increase by one per call.
genRequestId :: (MonadIO m) => m RequestId
genRequestId = liftIO $ atomicModifyIORef requestIdCounter $ \n -> (n + 1, n)

-- | The single global counter behind 'genRequestId'.
--
-- This must be a top-level CAF. If it were bound in a @where@ clause of the
-- class-constrained 'genRequestId', it would sit underneath the 'MonadIO'
-- dictionary lambda, and without let-floating (e.g. @-O0@ or GHCi) a new
-- 'IORef' could be allocated on every use, making every id 0. NOINLINE
-- keeps the 'unsafePerformIO' call from being duplicated by inlining.
requestIdCounter :: IORef RequestId
requestIdCounter = unsafePerformIO (newIORef 0)
{-# NOINLINE requestIdCounter #-}
-- | Serialize the message header fields that follow the length prefix, as
-- three Int32 values in order: the request id, a zero in the response-to
-- slot, and the opcode.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId = mapM_ putInt32 [requestId, 0, opcode]
-- | Parse the three Int32 header fields of an incoming message. The first
-- (the sender's own request id) is discarded; the result is the opcode
-- together with the id of the request being answered.
getHeader :: Get (Opcode, ResponseTo)
getHeader =
  getInt32 >>
  getInt32 >>= \responseTo ->
  getInt32 >>= \opcode ->
  return (opcode, responseTo)
-- | A write-style message; this module never reads a reply for it
-- ('send' attaches no request). Serialized by 'putNotice'.
data Notice =
    -- Insert the given documents into a collection.
    Insert {
        iFullCollection :: FullCollection,
        iOptions :: [InsertOption],
        iDocuments :: [Document]}
    -- Update documents matching the selector using the updater document.
    | Update {
        uFullCollection :: FullCollection,
        uOptions :: [UpdateOption],
        uSelector :: Document,
        uUpdater :: Document}
    -- Delete documents matching the selector.
    | Delete {
        dFullCollection :: FullCollection,
        dOptions :: [DeleteOption],
        dSelector :: Document}
    -- Carries a list of cursor ids (opcode 2007 in 'nOpcode').
    | KillCursors {
        kCursorIds :: [CursorId]}
    deriving (Show, Eq)
-- | Flags for an 'Insert' notice. 'KeepGoing' sets bit 0 of the insert
-- flags word (see 'iBit').
-- NOTE(review): presumably "continue after a failed document" — confirm
-- against the wire-protocol documentation.
data InsertOption = KeepGoing
    deriving (Show, Eq)
-- | Flags for an 'Update' notice; bit positions are given by 'uBit'.
data UpdateOption =
    -- Bit 0 of the update flags word.
    Upsert
    -- Bit 1 of the update flags word.
    | MultiUpdate
    deriving (Show, Eq)
-- | Flags for a 'Delete' notice. 'SingleRemove' sets bit 0 of the delete
-- flags word (see 'dBit').
data DeleteOption = SingleRemove
    deriving (Show, Eq)
-- | 64-bit cursor identifier, as carried by 'Reply', 'GetMore' and
-- 'KillCursors'.
type CursorId = Int64
-- | Wire opcode for each kind of 'Notice'.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
  Update{} -> 2001
  Insert{} -> 2002
  Delete{} -> 2006
  KillCursors{} -> 2007
-- | Serialize a 'Notice' (header followed by its opcode-specific body)
-- under the given request id. The length prefix is not written here; it is
-- added by 'writeMessage'.
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId =
  putHeader (nOpcode notice) requestId >> putBody notice
  where
    -- Constructor arguments are matched positionally, in declaration order.
    putBody (Insert coll opts docs) = do
      putInt32 (iBits opts)
      putCString coll
      mapM_ putDocument docs
    putBody (Update coll opts selector updater) = do
      putInt32 0
      putCString coll
      putInt32 (uBits opts)
      putDocument selector
      putDocument updater
    putBody (Delete coll opts selector) = do
      putInt32 0
      putCString coll
      putInt32 (dBits opts)
      putDocument selector
    putBody (KillCursors cursorIds) = do
      putInt32 0
      putInt32 (toEnum (length cursorIds))
      mapM_ putInt64 cursorIds
-- | Flag bit for a single insert option.
iBit :: InsertOption -> Int32
iBit option = case option of
  KeepGoing -> bit 0
-- | Combine a list of insert options into one flags word.
iBits :: [InsertOption] -> Int32
iBits options = bitOr (map iBit options)
-- | Flag bit for a single update option.
uBit :: UpdateOption -> Int32
uBit option = case option of
  Upsert -> bit 0
  MultiUpdate -> bit 1
-- | Combine a list of update options into one flags word.
uBits :: [UpdateOption] -> Int32
uBits options = bitOr (map uBit options)
-- | Flag bit for a single delete option.
dBit :: DeleteOption -> Int32
dBit option = case option of
  SingleRemove -> bit 0
-- | Combine a list of delete options into one flags word.
dBits :: [DeleteOption] -> Int32
dBits options = bitOr (map dBit options)
-- | A message that expects a 'Reply' (sent through 'call'). Serialized by
-- 'putRequest'.
data Request =
    -- Run a query against a collection.
    Query {
        qOptions :: [QueryOption],
        qFullCollection :: FullCollection,
        qSkip :: Int32,
        qBatchSize :: Int32,
        qSelector :: Document,
        -- Written only when non-empty (see 'putRequest').
        qProjector :: Document
    -- Request carrying a cursor id, a collection name and a batch size
    -- (opcode 2005 in 'qOpcode').
    } | GetMore {
        gFullCollection :: FullCollection,
        gBatchSize :: Int32,
        gCursorId :: CursorId}
    deriving (Show, Eq)
-- | Flags for a 'Query' request; bit positions are given by 'qBit'.
data QueryOption =
    -- Bit 1 of the query flags word.
    TailableCursor
    -- Bit 2 of the query flags word.
    | SlaveOK
    -- Bit 4 of the query flags word.
    | NoCursorTimeout
    -- Bit 5 of the query flags word.
    | AwaitData
    -- Bit 7 of the query flags word.
    | Partial
    deriving (Show, Eq)
-- | Wire opcode for each kind of 'Request'.
qOpcode :: Request -> Opcode
qOpcode request = case request of
  Query{} -> 2004
  GetMore{} -> 2005
-- | Serialize a 'Request' (header followed by its opcode-specific body)
-- under the given request id. The length prefix is not written here; it is
-- added by 'writeMessage'.
putRequest :: Request -> RequestId -> Put
putRequest request requestId =
  putHeader (qOpcode request) requestId >> putBody request
  where
    -- Constructor arguments are matched positionally, in declaration order.
    putBody (Query opts coll skip batchSize selector projector) = do
      putInt32 (qBits opts)
      putCString coll
      putInt32 skip
      putInt32 batchSize
      putDocument selector
      -- The projection document is omitted entirely when empty.
      unless (null projector) (putDocument projector)
    putBody (GetMore coll batchSize cursorId) = do
      putInt32 0
      putCString coll
      putInt32 batchSize
      putInt64 cursorId
-- | Flag bit for a single query option.
qBit :: QueryOption -> Int32
qBit option = case option of
  TailableCursor -> bit 1
  SlaveOK -> bit 2
  NoCursorTimeout -> bit 4
  AwaitData -> bit 5
  Partial -> bit 7
-- | Combine a list of query options into one flags word.
qBits :: [QueryOption] -> Int32
qBits options = bitOr (map qBit options)
-- | A decoded server reply, in the order the fields appear on the wire
-- (see 'getReply').
data Reply = Reply {
    -- Flags decoded from the reply's Int32 flags word by 'rFlags'.
    rResponseFlags :: [ResponseFlag],
    -- Cursor id read from the reply.
    rCursorId :: CursorId,
    -- Int32 read after the cursor id.
    rStartingFrom :: Int32,
    -- The documents carried by the reply; their count is read from the
    -- wire just before them.
    rDocuments :: [Document]
    } deriving (Show, Eq)
-- | Flags that can be set on a 'Reply'; bit positions are given by 'rBit'.
-- The 'Enum' instance is used by 'rFlags' to enumerate every flag.
data ResponseFlag =
    -- Bit 0 of the response flags word.
    CursorNotFound
    -- Bit 1 of the response flags word.
    | QueryError
    -- Bit 3 of the response flags word.
    | AwaitCapable
    deriving (Show, Eq, Enum)
-- | The opcode every incoming message must carry; 'getReply' fails on any
-- other value.
replyOpcode :: Opcode
replyOpcode = 1
-- | Parse a reply message (without its length prefix): the header, then
-- the flags word, cursor id, starting offset, document count and that many
-- documents. Fails if the header's opcode is not 'replyOpcode'.
getReply :: Get (ResponseTo, Reply)
getReply = do
  (opcode, responseTo) <- getHeader
  if opcode /= replyOpcode
    then fail ("expected reply opcode (1) but got " ++ show opcode)
    else return ()
  flagBits <- getInt32
  cursorId <- getInt64
  startingFrom <- getInt32
  docCount <- getInt32
  docs <- replicateM (fromIntegral docCount) getDocument
  return (responseTo, Reply (rFlags flagBits) cursorId startingFrom docs)
-- | Decode a response flags word into the list of flags whose bit is set,
-- in constructor order.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- [CursorNotFound ..], testBit bits (rBit flag)]
-- | Bit position of each response flag within the flags word.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
  CursorNotFound -> 0
  QueryError -> 1
  AwaitCapable -> 3
-- | User name, as fed into 'pwHash' and 'pwKey'.
type Username = Text
-- | Plain-text password, as fed into 'pwHash' and 'pwKey'.
type Password = Text
-- | Nonce that 'pwKey' prepends to the user name and password hash.
-- NOTE(review): presumably issued by the server — confirm against callers.
type Nonce = Text
-- | Hex-encoded MD5 digest of the UTF-8 encoding of
-- @username ++ \":mongo:\" ++ password@.
pwHash :: Username -> Password -> Text
pwHash u p = T.pack (byteStringHex digest)
  where
    digest = MD5.hash (TE.encodeUtf8 (T.concat [u, ":mongo:", p]))
-- | Hex-encoded MD5 digest of the UTF-8 encoding of
-- @nonce ++ username ++ 'pwHash' username password@.
pwKey :: Nonce -> Username -> Password -> Text
pwKey n u p = T.pack (byteStringHex (MD5.hash (TE.encodeUtf8 material)))
  where
    material = T.concat [n, u, pwHash u p]