-- | Low-level encoding and decoding of the messages exchanged over a 'Pipe':
-- notices and requests are serialised with "Data.Binary.Put", replies are
-- parsed with "Data.Binary.Get". Also exports the password hashing helpers
-- 'pwHash' and 'pwKey'.
--
-- NOTE(review): the definitions in this module use record wildcards
-- (@Insert{..}@), a tuple section (@(Left n,)@), class instances on type
-- synonyms, and a string literal used as a 'UString'. The matching language
-- extensions must be enabled either in this file or in the build
-- configuration -- confirm which.
module Database.MongoDB.Internal.Protocol (
Pipe, send, call,
FullCollection,
Notice(..), UpdateOption(..), DeleteOption(..), CursorId,
Request(..), QueryOption(..),
Reply(..), ResponseFlag(..),
Username, Password, Nonce, pwHash, pwKey
) where
import Prelude as X
import Control.Applicative ((<$>))
import Control.Arrow ((***))
import Data.ByteString.Lazy as B (length, hPut)
import qualified Control.Pipeline as P
import Data.Bson (Document, UString)
import Data.Bson.Binary
import Data.Binary.Put
import Data.Binary.Get
import Data.Int
import Data.Bits
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
import qualified Crypto.Hash.MD5 as MD5 (hash)
import Data.UString as U (pack, append, toByteString)
import qualified Data.ByteString as BS (ByteString, unpack)
import Data.Word (Word8)
import System.IO.Error as E (try)
import Control.Monad.Error
import Control.Monad.Util (whenJust)
import Network.Abstract hiding (send)
import System.IO (hFlush)
import Database.MongoDB.Internal.Util (hGetN, bitOr)
import Numeric (showHex)
-- | A pipeline that accepts outgoing 'Message's and yields incoming
-- 'Response's.
type Pipe = P.Pipeline Message Response
-- | Hand a batch of notices to the pipe as a single message that carries no
-- request part.
send :: Pipe -> [Notice] -> IOE ()
send pipe notices = P.send pipe message
  where
    -- Notices only; the request slot is left empty.
    message = (notices, Nothing)
-- | Send the notices together with one request, tagged with a freshly
-- generated request id.
--
-- The outer action performs the send and returns an inner action (a promise)
-- that yields the reply. The reply is checked against the request id; a
-- mismatch is reported with 'error' when the reply value is demanded.
call :: Pipe -> [Notice] -> Request -> IOE (IOE Reply)
call pipe notices request = do
    rid <- genRequestId
    pending <- P.call pipe (notices, Just (request, rid))
    return (fmap (verify rid) pending)
  where
    -- Accept the reply only if it answers the request we issued.
    verify expected (responseTo, reply)
        | expected == responseTo = reply
        | otherwise = error (concat
            [ "expected response id (", show responseTo
            , ") to match request id (", show expected, ")"
            ])
-- | What is written to the server in one go: a list of notices, optionally
-- followed by a single request paired with the request id it is sent under.
type Message = ([Notice], Maybe (Request, RequestId))
-- | Serialise a 'Message' onto the handle.
--
-- Every notice is written first, each under its own freshly generated request
-- id; the request, when present, is written last under the id supplied with
-- it. The handle is flushed once at the end. I/O errors raised while writing
-- are captured by 'E.try' and surfaced through 'ErrorT'.
instance WriteMessage Message where
    writeMessage handle (notices, mRequest) = ErrorT (E.try writeAll)
      where
        writeAll = do
            forM_ notices $ \notice -> do
                noticeId <- genRequestId
                writeReq (Left notice, noticeId)
            whenJust mRequest $ \(request, requestId) ->
                writeReq (Right request, requestId)
            hFlush handle

        -- Write one length-prefixed message: the size field first, then the
        -- encoded header and body.
        writeReq (payload, requestId) = do
            hPut handle lenBytes
            hPut handle bytes
          where
            bytes = runPut (either putNotice putRequest payload requestId)
            -- The prefix is the encoded length plus 4, matching the 4 that
            -- the reading side subtracts again.
            lenBytes = runPut (putInt32 (toEnum (fromEnum (B.length bytes)) + 4))
-- | What is read back from the server: the id of the request being answered,
-- paired with the decoded reply.
type Response = (ResponseTo, Reply)
-- | Read one length-prefixed reply from the handle.
--
-- The first 4 bytes hold an int32 size; that size minus 4 is the number of
-- bytes that follow and are decoded with 'getReply'. I/O errors raised while
-- reading are captured by 'E.try' and surfaced through 'ErrorT'.
instance ReadMessage Response where
    readMessage handle = ErrorT (E.try readResp)
      where
        readResp = do
            sizeBytes <- hGetN handle 4
            let bodyLen = fromEnum (runGet getInt32 sizeBytes - 4)
            body <- hGetN handle bodyLen
            return (runGet getReply body)
-- | Collection name as written (C-string encoded) into notices and requests.
-- NOTE(review): presumably the database-qualified name -- confirm with callers.
type FullCollection = UString
-- | Numeric operation code stored in the message header.
type Opcode = Int32
-- | Identifier written into the header of every outgoing message; see
-- 'genRequestId'.
type RequestId = Int32
-- | The request id a reply refers to, as read from the reply header.
type ResponseTo = RequestId
-- | Generate a fresh request id by atomically incrementing a process-wide
-- counter. The value returned is the counter's value before the increment,
-- so the first id handed out is 0.
genRequestId :: (MonadIO m) => m RequestId
genRequestId = liftIO $ atomicModifyIORef requestIdCounter $ \n -> (n + 1, n)

-- | The single counter behind 'genRequestId'.
--
-- It is created with 'unsafePerformIO', so it must be a top-level binding
-- marked NOINLINE: if the definition were inlined, or kept as a local binding
-- under the 'MonadIO' dictionary argument of 'genRequestId', each use could
-- allocate its own 'IORef' and request ids would repeat.
requestIdCounter :: IORef RequestId
requestIdCounter = unsafePerformIO (newIORef 0)
{-# NOINLINE requestIdCounter #-}
-- | Write the three int32 header fields that follow the size prefix: the
-- request id, a zero in the slot that 'getHeader' reads as @responseTo@, and
-- the opcode.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId = mapM_ putInt32 [requestId, 0, opcode]
-- | Read the three int32 header fields that follow the size prefix. The first
-- (the sender's own request id) is discarded; the result is the opcode paired
-- with the id of the request being answered.
getHeader :: Get (Opcode, ResponseTo)
getHeader =
    getInt32 >>= \_requestId ->
    getInt32 >>= \responseTo ->
    getInt32 >>= \opcode ->
    return (opcode, responseTo)
-- | An operation encoded by 'putNotice'. 'send' transmits notices on their
-- own; 'call' transmits them ahead of a 'Request'.
data Notice
    = Insert
        { iFullCollection :: FullCollection -- ^ collection to insert into
        , iDocuments :: [Document] -- ^ documents written one after another
        }
    | Update
        { uFullCollection :: FullCollection -- ^ collection to update
        , uOptions :: [UpdateOption] -- ^ combined into one flag word by 'uBits'
        , uSelector :: Document -- ^ written before the updater document
        , uUpdater :: Document -- ^ written after the selector document
        }
    | Delete
        { dFullCollection :: FullCollection -- ^ collection to delete from
        , dOptions :: [DeleteOption] -- ^ combined into one flag word by 'dBits'
        , dSelector :: Document -- ^ selector document
        }
    | KillCursors
        { kCursorIds :: [CursorId] -- ^ written as a count followed by the ids
        }
    deriving (Show, Eq)
-- | Flags that may accompany an 'Update' notice; see 'uBit' for the bit each
-- one sets.
data UpdateOption
    = Upsert -- ^ sets bit 0 of the update flags
    | MultiUpdate -- ^ sets bit 1 of the update flags
    deriving (Show, Eq)
-- | Flags that may accompany a 'Delete' notice; see 'dBit' for the bit each
-- one sets.
data DeleteOption
    = SingleRemove -- ^ sets bit 0 of the delete flags
    deriving (Show, Eq)
-- | 64-bit cursor identifier: read from replies by 'getReply' and written back
-- in 'GetMore' requests and 'KillCursors' notices.
type CursorId = Int64
-- | Opcode placed in the header for each kind of notice.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
    Insert{} -> 2002
    Update{} -> 2001
    Delete{} -> 2006
    KillCursors{} -> 2007
-- | Encode a notice under the given request id: the header, an int32 zero,
-- and then the body specific to the kind of notice.
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId = do
    putHeader (nOpcode notice) requestId
    putInt32 0
    putBody notice
  where
    -- Collection name followed by every document.
    putBody (Insert coll docs) = do
        putCString coll
        mapM_ putDocument docs
    -- Collection name, flag word, selector, updater.
    putBody (Update coll opts selector updater) = do
        putCString coll
        putInt32 (uBits opts)
        putDocument selector
        putDocument updater
    -- Collection name, flag word, selector.
    putBody (Delete coll opts selector) = do
        putCString coll
        putInt32 (dBits opts)
        putDocument selector
    -- Number of cursor ids followed by the ids themselves.
    putBody (KillCursors cursorIds) = do
        putInt32 (toEnum (X.length cursorIds))
        mapM_ putInt64 cursorIds
-- | The single flag bit that an update option contributes.
uBit :: UpdateOption -> Int32
uBit option = bit $ case option of
    Upsert -> 0
    MultiUpdate -> 1
-- | Combine a list of update options into one flag word.
uBits :: [UpdateOption] -> Int32
uBits options = bitOr (map uBit options)
-- | The single flag bit that a delete option contributes.
dBit :: DeleteOption -> Int32
dBit option = case option of
    SingleRemove -> bit 0
-- | Combine a list of delete options into one flag word.
dBits :: [DeleteOption] -> Int32
dBits options = bitOr (map dBit options)
-- | An operation encoded by 'putRequest' and sent through 'call', which
-- returns a promise for the matching 'Reply'.
data Request
    = Query
        { qOptions :: [QueryOption] -- ^ combined into one flag word by 'qBits'
        , qFullCollection :: FullCollection -- ^ collection to query
        , qSkip :: Int32 -- ^ int32 written after the collection name
        , qBatchSize :: Int32 -- ^ int32 written after 'qSkip'
        , qSelector :: Document -- ^ always written
        , qProjector :: Document -- ^ written only when non-empty
        }
    | GetMore
        { gFullCollection :: FullCollection -- ^ collection the cursor belongs to
        , gBatchSize :: Int32 -- ^ int32 written after the collection name
        , gCursorId :: CursorId -- ^ cursor to continue reading from
        }
    deriving (Show, Eq)
-- | Flags that may accompany a 'Query' request; see 'qBit' for the bit each
-- one sets.
data QueryOption
    = TailableCursor -- ^ sets bit 1 of the query flags
    | SlaveOK -- ^ sets bit 2 of the query flags
    | NoCursorTimeout -- ^ sets bit 4 of the query flags
    | AwaitData -- ^ sets bit 5 of the query flags
    deriving (Show, Eq)
-- | Opcode placed in the header for each kind of request.
qOpcode :: Request -> Opcode
qOpcode request = case request of
    Query{} -> 2004
    GetMore{} -> 2005
-- | Encode a request under the given request id: the header followed by the
-- body specific to the kind of request.
putRequest :: Request -> RequestId -> Put
putRequest request requestId = do
    putHeader (qOpcode request) requestId
    putBody request
  where
    -- Flag word, collection name, skip, batch size, selector, and the
    -- projector only when it is non-empty.
    putBody (Query opts coll skip batchSize selector projector) = do
        putInt32 (qBits opts)
        putCString coll
        putInt32 skip
        putInt32 batchSize
        putDocument selector
        when (not (null projector)) $ putDocument projector
    -- An int32 zero, collection name, batch size, cursor id.
    putBody (GetMore coll batchSize cursorId) = do
        putInt32 0
        putCString coll
        putInt32 batchSize
        putInt64 cursorId
-- | The single flag bit that a query option contributes. Bits 0 and 3 are not
-- assigned to any option here.
qBit :: QueryOption -> Int32
qBit option = bit $ case option of
    TailableCursor -> 1
    SlaveOK -> 2
    NoCursorTimeout -> 4
    AwaitData -> 5
-- | Combine a list of query options into one flag word.
qBits :: [QueryOption] -> Int32
qBits options = bitOr (map qBit options)
-- | A decoded server reply, as produced by 'getReply'.
data Reply = Reply
    { rResponseFlags :: [ResponseFlag] -- ^ flags decoded from the flag word by 'rFlags'
    , rCursorId :: CursorId -- ^ int64 read after the flag word
    , rStartingFrom :: Int32 -- ^ int32 read after the cursor id
    , rDocuments :: [Document] -- ^ as many documents as the reply announces
    }
    deriving (Show, Eq)
-- | Flags that can be set in a reply; see 'rBit' for the bit each one is read
-- from. The 'Enum' instance lets 'rFlags' enumerate every constructor.
data ResponseFlag
    = CursorNotFound -- ^ bit 0 of the reply flags
    | QueryError -- ^ bit 1 of the reply flags
    | AwaitCapable -- ^ bit 3 of the reply flags
    deriving (Show, Eq, Enum)
-- | The opcode every reply header is expected to carry; 'getReply' fails on
-- any other value.
replyOpcode :: Opcode
replyOpcode = 1
-- | Decode a reply: the header, then the flag word, cursor id, starting
-- offset, document count and that many documents. Fails in the 'Get' monad
-- when the header's opcode is not 'replyOpcode'.
getReply :: Get (ResponseTo, Reply)
getReply = do
    (opcode, responseTo) <- getHeader
    when (opcode /= replyOpcode) $
        fail ("expected reply opcode (1) but got " ++ show opcode)
    flags <- fmap rFlags getInt32
    cursorId <- getInt64
    startingFrom <- getInt32
    docCount <- getInt32
    docs <- replicateM (fromIntegral docCount) getDocument
    return (responseTo, Reply flags cursorId startingFrom docs)
-- | Decode a reply flag word into the list of flags whose bit is set, in
-- constructor order. Bits that no 'ResponseFlag' maps to are ignored.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- [CursorNotFound ..], testBit bits (rBit flag)]
-- | Position of the bit that signals each response flag. Bit 2 is not mapped
-- to any flag here.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
    CursorNotFound -> 0
    QueryError -> 1
    AwaitCapable -> 3
-- | User name as fed into 'pwHash' and 'pwKey'.
type Username = UString
-- | Plain-text password as fed into 'pwHash' and 'pwKey'.
type Password = UString
-- | Nonce that 'pwKey' prepends to the user name and password hash.
-- NOTE(review): presumably obtained from the server per authentication
-- attempt -- confirm with callers.
type Nonce = UString
-- | Hex-encoded MD5 digest of the user name, the literal @:mongo:@ and the
-- password, concatenated in that order.
pwHash :: Username -> Password -> UString
pwHash u p = pack (byteStringHex (MD5.hash (toByteString salted)))
  where
    salted = U.append (U.append u ":mongo:") p
-- | Hex-encoded MD5 digest of the nonce, the user name and the 'pwHash' of
-- the credentials, concatenated in that order.
pwKey :: Nonce -> Username -> Password -> UString
pwKey n u p = pack (byteStringHex digest)
  where
    digest = MD5.hash (toByteString keyMaterial)
    keyMaterial = n `U.append` (u `U.append` pwHash u p)
-- | Render every byte of a strict byte string as two hex digits, in order.
byteStringHex :: BS.ByteString -> String
byteStringHex bytes = concat [byteHex b | b <- BS.unpack bytes]
-- | Render one byte as exactly two lowercase hex digits, padding values below
-- 16 with a leading zero.
byteHex :: Word8 -> String
byteHex b
    | b < 16 = '0' : digits
    | otherwise = digits
  where
    digits = showHex b ""