module Database.MongoDB.Internal.Protocol (
FullCollection,
Pipe, newPipe, newPipeWith, send, call,
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
Request(..), QueryOption(..),
Reply(..), ResponseFlag(..),
Username, Password, Nonce, pwHash, pwKey
) where
#if !MIN_VERSION_base(4,8,0)
import Control.Applicative ((<$>))
#endif
import Control.Arrow ((***))
import Control.Monad (forM_, replicateM, unless)
import Data.Binary.Get (Get, runGet)
import Data.Binary.Put (Put, runPut)
import Data.Bits (bit, testBit)
import Data.Int (Int32, Int64)
import Data.IORef (IORef, newIORef, atomicModifyIORef)
import System.IO (Handle)
import System.IO.Unsafe (unsafePerformIO)
import qualified Data.ByteString.Lazy as L
import Control.Monad.Trans (MonadIO, liftIO)
import Data.Bson (Document)
import Data.Bson.Binary (getDocument, putDocument, getInt32, putInt32, getInt64,
putInt64, putCString)
import Data.Text (Text)
import qualified Crypto.Hash.MD5 as MD5
import qualified Data.Text as T
import qualified Data.Text.Encoding as TE
import Database.MongoDB.Internal.Util (whenJust, bitOr, byteStringHex)
import System.IO.Pipeline (Pipeline, newPipeline, IOStream(..))
import qualified System.IO.Pipeline as P
import Database.MongoDB.Internal.Connection (Connection)
import qualified Database.MongoDB.Internal.Connection as Connection
-- | A pipelined connection to the server: outgoing 'Message's are written to
-- the connection and incoming 'Response's are read back from it.
type Pipe = Pipeline Response Message

-- | Wrap a raw 'Handle' in a 'Connection' and build a 'Pipe' on top of it.
newPipe :: Handle -> IO Pipe
newPipe handle = do
  conn <- Connection.fromHandle handle
  newPipeWith conn
-- | Build a 'Pipe' over an existing 'Connection'.
--
-- The pipeline writes with 'writeMessage', reads with 'readMessage', and
-- closes the underlying connection with 'Connection.close'.
newPipeWith :: Connection -> IO Pipe
newPipeWith conn = newPipeline stream
  where
    stream = IOStream (writeMessage conn) (readMessage conn) (Connection.close conn)
-- | Write notices to the pipe without attaching a request, so no reply is
-- awaited.
send :: Pipe -> [Notice] -> IO ()
send pipe notices = P.send pipe message
  where
    message = (notices, Nothing)
-- | Send notices followed by a request, returning an action that waits for
-- the request's reply.
--
-- A fresh request id is generated and paired with the request. Running the
-- returned action blocks on the pipeline's promise and then checks that the
-- reply's @responseTo@ field equals that id; on mismatch it fails with an
-- 'ErrorCall' (via 'error').
call :: Pipe -> [Notice] -> Request -> IO (IO Reply)
call pipe notices request = do
  requestId <- genRequestId
  promise <- P.call pipe (notices, Just (request, requestId))
  -- '$!' forces the id check (and hence decoding of the response) when the
  -- returned action runs. Without it, a mismatch would be returned as an
  -- 'error' thunk disguised as a Reply and only blow up wherever the Reply
  -- is first inspected, outside the caller's exception handling.
  return $ promise >>= \response -> return $! check requestId response
  where
    check requestId (responseTo, reply) =
      if requestId == responseTo
        then reply
        else error $ "expected response id (" ++ show responseTo ++ ") to match request id (" ++ show requestId ++ ")"
-- | One outgoing unit for the pipeline: zero or more notices plus, optionally,
-- a request paired with the id its reply must echo.
type Message = ([Notice], Maybe (Request, RequestId))

-- | Serialize a 'Message' onto the connection and flush it.
--
-- Each notice is given a freshly generated request id; the request (if any)
-- uses the id already paired with it. Every message is framed as a 32-bit
-- length prefix, which counts its own 4 bytes, followed by the encoded body.
writeMessage :: Connection -> Message -> IO ()
writeMessage conn (notices, mRequest) = do
  forM_ notices $ \notice -> do
    noticeId <- genRequestId
    frame (Left notice, noticeId)
  whenJust mRequest (frame . (Right *** id))
  Connection.flush conn
  where
    frame (payload, msgId) = do
      let body = runPut (either putNotice putRequest payload msgId)
          bodyLen = toEnum (fromEnum (L.length body)) :: Int32
      Connection.writeLazy conn (runPut (putInt32 (bodyLen + 4)))
      Connection.writeLazy conn body
-- | A decoded reply paired with the 'ResponseTo' id it carries.
type Response = (ResponseTo, Reply)

-- | Read one framed reply from the connection.
--
-- The first 4 bytes are a 32-bit length prefix that includes itself, so 4 is
-- subtracted to get the number of body bytes still to read; the body is then
-- decoded with 'getReply'.
readMessage :: Connection -> IO Response
readMessage conn = do
  prefix <- Connection.readExactly conn 4
  let bodyLen = fromEnum (runGet getInt32 prefix - 4)
  body <- Connection.readExactly conn bodyLen
  return (runGet getReply body)
-- | Collection name as written on the wire with 'putCString'
-- (presumably the @database.collection@ form — confirm against callers).
type FullCollection = Text
-- | Operation code stored in every message header (see 'putHeader').
type Opcode = Int32
-- | Identifier attached to each outgoing message header.
type RequestId = Int32
-- | The request id that a reply echoes back; 'call' compares it with the id
-- it sent.
type ResponseTo = RequestId
-- | Generate a fresh request id from a process-wide counter.
--
-- The counter starts at 0 and is incremented atomically on every call. Being
-- an 'Int32', it wraps around on overflow.
--
-- The counter is a global 'IORef' created with 'unsafePerformIO'. Both it and
-- 'genRequestId' carry @NOINLINE@ so GHC can neither inline nor duplicate the
-- 'newIORef' call. Without the pragmas, different call sites could end up
-- with separate counters and hand out duplicate ids, which would break the
-- reply matching done in 'call'.
genRequestId :: (MonadIO m) => m RequestId
{-# NOINLINE genRequestId #-}
genRequestId = liftIO $ atomicModifyIORef counter $ \n -> (n + 1, n)
  where
    counter :: IORef RequestId
    counter = unsafePerformIO (newIORef 0)
    {-# NOINLINE counter #-}
-- | Write the header fields that follow the length prefix: the request id, a
-- zero @responseTo@ field, and the opcode, each as a 32-bit integer.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId = mapM_ putInt32 [requestId, 0, opcode]
-- | Read the header fields that follow the length prefix.
--
-- Returns @(opcode, responseTo)@. The reply's own request id is read and
-- discarded.
getHeader :: Get (Opcode, ResponseTo)
getHeader =
  getInt32 >>
    getInt32 >>= \responseTo ->
      fmap (\opcode -> (opcode, responseTo)) getInt32
-- | Messages for which no reply is read: 'send' writes them on their own, and
-- 'call' writes them ahead of its request.
data Notice
  = Insert
      { iFullCollection :: FullCollection
      , iOptions :: [InsertOption]
      , iDocuments :: [Document]
      }
  | Update
      { uFullCollection :: FullCollection
      , uOptions :: [UpdateOption]
      , uSelector :: Document
      , uUpdater :: Document
      }
  | Delete
      { dFullCollection :: FullCollection
      , dOptions :: [DeleteOption]
      , dSelector :: Document
      }
  | KillCursors
      { kCursorIds :: [CursorId]
      }
  deriving (Show, Eq)
-- | Flags for an 'Insert' notice; 'iBit' gives each one's bit.
data InsertOption
  = KeepGoing  -- ^ bit 0
  deriving (Show, Eq)

-- | Flags for an 'Update' notice; 'uBit' gives each one's bit.
data UpdateOption
  = Upsert       -- ^ bit 0
  | MultiUpdate  -- ^ bit 1
  deriving (Show, Eq)

-- | Flags for a 'Delete' notice; 'dBit' gives each one's bit.
data DeleteOption
  = SingleRemove  -- ^ bit 0
  deriving (Show, Eq)

-- | Cursor identifier, written and read as a 64-bit integer.
type CursorId = Int64
-- | Wire opcode for each kind of 'Notice'.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
  Update {} -> 2001
  Insert {} -> 2002
  Delete {} -> 2006
  KillCursors {} -> 2007
-- | Encode a 'Notice': the header from 'putHeader' followed by the
-- notice-specific body. The length prefix is added by 'writeMessage'.
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId = putHeader (nOpcode notice) requestId >> putBody notice
  where
    putBody Insert{..} = do
      putInt32 (iBits iOptions)
      putCString iFullCollection
      mapM_ putDocument iDocuments
    putBody Update{..} = do
      putInt32 0  -- constant zero field ahead of the collection name
      putCString uFullCollection
      putInt32 (uBits uOptions)
      putDocument uSelector
      putDocument uUpdater
    putBody Delete{..} = do
      putInt32 0  -- constant zero field ahead of the collection name
      putCString dFullCollection
      putInt32 (dBits dOptions)
      putDocument dSelector
    putBody KillCursors{..} = do
      putInt32 0  -- constant zero field ahead of the cursor count
      putInt32 (toEnum (length kCursorIds))
      mapM_ putInt64 kCursorIds
-- | Flag bit for a single 'InsertOption'.
iBit :: InsertOption -> Int32
iBit opt = case opt of
  KeepGoing -> bit 0

-- | Combine 'InsertOption's into the insert flags word using 'bitOr'.
iBits :: [InsertOption] -> Int32
iBits opts = bitOr (map iBit opts)

-- | Flag bit for a single 'UpdateOption'.
uBit :: UpdateOption -> Int32
uBit opt = case opt of
  Upsert -> bit 0
  MultiUpdate -> bit 1

-- | Combine 'UpdateOption's into the update flags word using 'bitOr'.
uBits :: [UpdateOption] -> Int32
uBits opts = bitOr (map uBit opts)

-- | Flag bit for a single 'DeleteOption'.
dBit :: DeleteOption -> Int32
dBit opt = case opt of
  SingleRemove -> bit 0

-- | Combine 'DeleteOption's into the delete flags word using 'bitOr'.
dBits :: [DeleteOption] -> Int32
dBits opts = bitOr (map dBit opts)
-- | Messages that expect a 'Reply'; 'call' sends one and returns an action
-- that waits for its reply.
data Request
  = Query
      { qOptions :: [QueryOption]
      , qFullCollection :: FullCollection
      , qSkip :: Int32
      , qBatchSize :: Int32
      , qSelector :: Document
      , qProjector :: Document  -- ^ omitted from the wire when empty (see 'putRequest')
      }
  | GetMore
      { gFullCollection :: FullCollection
      , gBatchSize :: Int32
      , gCursorId :: CursorId
      }
  deriving (Show, Eq)
-- | Flags for a 'Query' request; 'qBit' gives each one's bit in the flags word.
data QueryOption
  = TailableCursor   -- ^ bit 1
  | SlaveOK          -- ^ bit 2
  | NoCursorTimeout  -- ^ bit 4
  | AwaitData        -- ^ bit 5
  | Partial          -- ^ bit 7
  deriving (Show, Eq)
-- | Wire opcode for each kind of 'Request'.
qOpcode :: Request -> Opcode
qOpcode request = case request of
  Query {} -> 2004
  GetMore {} -> 2005
-- | Encode a 'Request': the header from 'putHeader' followed by the
-- request-specific body. The length prefix is added by 'writeMessage'.
putRequest :: Request -> RequestId -> Put
putRequest request requestId = do
  putHeader (qOpcode request) requestId
  putBody request
  where
    putBody Query{..} = do
      putInt32 (qBits qOptions)
      putCString qFullCollection
      putInt32 qSkip
      putInt32 qBatchSize
      putDocument qSelector
      -- An empty projector is left off the wire entirely.
      unless (null qProjector) $ putDocument qProjector
    putBody GetMore{..} = do
      putInt32 0  -- constant zero field ahead of the collection name
      putCString gFullCollection
      putInt32 gBatchSize
      putInt64 gCursorId
-- | Flag bit for a single 'QueryOption'.
qBit :: QueryOption -> Int32
qBit opt = bit position
  where
    position = case opt of
      TailableCursor -> 1
      SlaveOK -> 2
      NoCursorTimeout -> 4
      AwaitData -> 5
      Partial -> 7

-- | Combine 'QueryOption's into the query flags word using 'bitOr'.
qBits :: [QueryOption] -> Int32
qBits opts = bitOr (map qBit opts)
-- | A decoded reply message, as produced by 'getReply'.
data Reply = Reply
  { rResponseFlags :: [ResponseFlag]
  , rCursorId :: CursorId
  , rStartingFrom :: Int32
  , rDocuments :: [Document]
  }
  deriving (Show, Eq)

-- | Flags that may be set in a reply's flags word; 'rBit' gives each one's
-- bit position. Constructor order matters: 'rFlags' enumerates every flag
-- starting from 'CursorNotFound' via the derived 'Enum' instance.
data ResponseFlag
  = CursorNotFound  -- ^ bit 0
  | QueryError      -- ^ bit 1
  | AwaitCapable    -- ^ bit 3
  deriving (Show, Eq, Enum)
-- | The only opcode 'getReply' accepts in an incoming header.
replyOpcode :: Opcode
replyOpcode = 1

-- | Decode a reply message body that follows the length prefix.
--
-- Fails in the 'Get' monad if the header's opcode is not 'replyOpcode'.
-- Returns the header's @responseTo@ id together with the decoded 'Reply'.
getReply :: Get (ResponseTo, Reply)
getReply = do
  (opcode, responseTo) <- getHeader
  unless (opcode == replyOpcode) . fail $ "expected reply opcode (1) but got " ++ show opcode
  flagWord <- getInt32
  cursorId <- getInt64
  startingFrom <- getInt32
  docCount <- getInt32
  docs <- replicateM (fromIntegral docCount) getDocument
  let reply = Reply
        { rResponseFlags = rFlags flagWord
        , rCursorId = cursorId
        , rStartingFrom = startingFrom
        , rDocuments = docs
        }
  return (responseTo, reply)
-- | Decode a reply flags word into the 'ResponseFlag's whose bits are set, in
-- constructor order.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- enumFrom CursorNotFound, testBit bits (rBit flag)]

-- | Bit position of each 'ResponseFlag' in the reply flags word.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
  CursorNotFound -> 0
  QueryError -> 1
  AwaitCapable -> 3
type Username = Text
type Password = Text
type Nonce = Text

-- | MD5 of the UTF-8 encoding of @user ++ ":mongo:" ++ password@, rendered
-- as text via 'byteStringHex'.
pwHash :: Username -> Password -> Text
pwHash u p = md5HexText (T.concat [u, ":mongo:", p])

-- | MD5 of the UTF-8 encoding of @nonce ++ user ++ pwHash user password@,
-- rendered as text via 'byteStringHex'.
pwKey :: Nonce -> Username -> Password -> Text
pwKey n u p = md5HexText (T.concat [n, u, pwHash u p])

-- | UTF-8 encode, MD5 hash, and render the digest with 'byteStringHex'.
md5HexText :: Text -> Text
md5HexText = T.pack . byteStringHex . MD5.hash . TE.encodeUtf8