module Database.MongoDB.Internal.Protocol (
FullCollection,
Pipe, newPipe, send, call,
Notice(..), InsertOption(..), UpdateOption(..), DeleteOption(..), CursorId,
Request(..), QueryOption(..),
Reply(..), ResponseFlag(..),
Username, Password, Nonce, pwHash, pwKey
) where
import Prelude as X
import Control.Applicative ((<$>))
import Control.Arrow ((***))
import Data.ByteString.Lazy as B (length, hPut)
import System.IO.Pipeline (IOE, Pipeline, newPipeline, IOStream(..))
import qualified System.IO.Pipeline as P (send, call)
import System.IO (Handle, hClose)
import Data.Bson (Document, UString)
import Data.Bson.Binary
import Data.Binary.Put
import Data.Binary.Get
import Data.Int
import Data.Bits
import Data.IORef
import System.IO.Unsafe (unsafePerformIO)
import qualified Crypto.Hash.MD5 as MD5 (hash)
import Data.UString as U (pack, append, toByteString)
import Control.Exception as E (try)
import Control.Monad.Error
import System.IO (hFlush)
import Database.MongoDB.Internal.Util (whenJust, hGetN, bitOr, byteStringHex)
-- | A 'Pipeline' whose outgoing items are 'Message's and whose incoming items
-- are 'Response's.
type Pipe = Pipeline Response Message

-- | Build a new 'Pipe' on top of an open 'Handle'. The underlying 'IOStream'
-- is assembled from 'writeMessage', 'readMessage' and 'hClose', each applied
-- to the given handle.
newPipe :: Handle -> IO Pipe
newPipe handle = newPipeline stream
  where
    stream = IOStream writeOut readIn shutDown
    writeOut = writeMessage handle
    readIn = readMessage handle
    shutDown = hClose handle
-- | Hand the notices to the pipe as one 'Message' that carries no request
-- (the request slot is 'Nothing').
send :: Pipe -> [Notice] -> IOE ()
send pipe notices = P.send pipe message
  where
    message = (notices, Nothing)
-- | Send the notices together with a request tagged with a fresh request id.
-- The outer action performs the send and yields an inner action (promise)
-- for the reply.
--
-- The reply produced by the promise is checked against the request id that
-- was sent: if the response's id differs, evaluating the reply calls 'error'.
call :: Pipe -> [Notice] -> Request -> IOE (IOE Reply)
call pipe notices request = do
    reqId <- genRequestId
    pending <- P.call pipe (notices, Just (request, reqId))
    return (fmap (verify reqId) pending)
  where
    -- Pass the reply through only when it answers the request we sent.
    verify expected (responseTo, reply)
        | expected == responseTo = reply
        | otherwise = error $ "expected response id (" ++ show responseTo ++ ") to match request id (" ++ show expected ++ ")"
-- | What travels down a 'Pipe' in one go: any number of notices plus an
-- optional request paired with its request id.
type Message = ([Notice], Maybe (Request, RequestId))

-- | Serialise a 'Message' onto the handle and flush it. Every notice is
-- written under a freshly generated request id; the request (if present) is
-- written under the id supplied with it. Exceptions raised while writing are
-- captured by 'E.try' and surfaced through 'ErrorT'.
writeMessage :: Handle -> Message -> IOE ()
writeMessage handle (notices, mRequest) = ErrorT (E.try writeAll)
  where
    writeAll = do
        forM_ notices $ \notice -> do
            noticeId <- genRequestId
            writeReq (Left notice, noticeId)
        whenJust mRequest $ \(request, requestId) ->
            writeReq (Right request, requestId)
        hFlush handle

    -- Write one length-prefixed frame: a 4-byte size followed by the body.
    writeReq (payload, requestId) = do
        hPut handle sizePrefix
        hPut handle body
      where
        body = runPut (either putNotice putRequest payload requestId)
        -- The size field counts its own 4 bytes as well as the body.
        sizePrefix = runPut (putInt32 (toEnum (fromEnum (B.length body)) + 4))
-- | A decoded reply paired with the request id it claims to answer.
type Response = (ResponseTo, Reply)

-- | Read one length-prefixed frame from the handle and decode it with
-- 'getReply'. Exceptions raised while reading are captured by 'E.try' and
-- surfaced through 'ErrorT'.
readMessage :: Handle -> IOE Response
readMessage handle = ErrorT (E.try fetchReply)
  where
    fetchReply = do
        sizeBytes <- hGetN handle 4
        -- The size field includes its own 4 bytes; subtract them to get the
        -- number of bytes still to read.
        let bodyLen = fromEnum (runGet getInt32 sizeBytes - 4)
        replyBytes <- hGetN handle bodyLen
        return (runGet getReply replyBytes)
-- | Collection name as written (via 'putCString') into inserts, updates,
-- deletes, queries and get-mores.
-- NOTE(review): presumably database-qualified ("db.collection") — confirm.
type FullCollection = UString
-- | Numeric operation code written into a message header by 'putHeader'.
type Opcode = Int32
-- | Identifier attached to each outgoing message; see 'genRequestId'.
type RequestId = Int32
-- | The 'RequestId' a reply says it answers; 'call' compares it with the id
-- that was sent.
type ResponseTo = RequestId
-- | Generate a fresh request id: atomically returns the current value of a
-- process-wide counter and stores that value plus one.
genRequestId :: (MonadIO m) => m RequestId
genRequestId = liftIO $ atomicModifyIORef requestIdCounter $ \n -> (n + 1, n)

-- | Process-wide counter behind 'genRequestId', starting at 0.
--
-- This is deliberately a top-level binding with a NOINLINE pragma. It is
-- created with 'unsafePerformIO', so it must be evaluated exactly once.
-- As a local @where@ binding under the @MonadIO m@ constraint it could be
-- re-created per call (when not floated out by the optimiser) or duplicated
-- by inlining, which would hand out repeated ids.
requestIdCounter :: IORef RequestId
requestIdCounter = unsafePerformIO (newIORef 0)
{-# NOINLINE requestIdCounter #-}
-- | Write the three header fields that follow the length prefix, in order:
-- the request id, a zero (the slot 'getHeader' reads back as the
-- response-to id), and the opcode. The length itself is written separately
-- by 'writeMessage'.
putHeader :: Opcode -> RequestId -> Put
putHeader opcode requestId =
    mapM_ putInt32 [requestId, responseTo, opcode]
  where
    -- Outgoing messages always carry 0 in this field.
    responseTo = 0
-- | Read the three header fields that follow the length prefix. The first
-- (the sender's own request id) is discarded; the result pairs the opcode
-- with the response-to id.
getHeader :: Get (Opcode, ResponseTo)
getHeader = skipRequestId >> liftM2 swapped getInt32 getInt32
  where
    -- The leading id is read only to advance past it.
    skipRequestId = getInt32
    -- Fields arrive as response-to then opcode; callers want the reverse.
    swapped responseTo opcode = (opcode, responseTo)
-- | An operation serialised by 'putNotice'. Notices travel in a 'Message'
-- alongside the optional 'Request'.
data Notice
    = Insert
        { iFullCollection :: FullCollection  -- ^ collection name written after the flags
        , iOptions :: [InsertOption]         -- ^ combined into the flags word by 'iBits'
        , iDocuments :: [Document]           -- ^ documents written one after another
        }
    | Update
        { uFullCollection :: FullCollection  -- ^ collection name
        , uOptions :: [UpdateOption]         -- ^ combined into the flags word by 'uBits'
        , uSelector :: Document              -- ^ written first of the two documents
        , uUpdater :: Document               -- ^ written second of the two documents
        }
    | Delete
        { dFullCollection :: FullCollection  -- ^ collection name
        , dOptions :: [DeleteOption]         -- ^ combined into the flags word by 'dBits'
        , dSelector :: Document              -- ^ the single document written
        }
    | KillCursors
        { kCursorIds :: [CursorId]           -- ^ written as a count followed by each id
        }
    deriving (Show, Eq)
-- | Flags that can accompany an 'Insert'; see 'iBit' for the wire encoding.
data InsertOption
    = KeepGoing
      -- ^ Encoded as bit 0 of the insert flags.
      -- NOTE(review): presumably asks the server to carry on after a failed
      -- document — confirm against the server documentation.
    deriving (Show, Eq)
-- | Flags that can accompany an 'Update'; see 'uBit' for the wire encoding.
data UpdateOption
    = Upsert
      -- ^ Encoded as bit 0 of the update flags.
    | MultiUpdate
      -- ^ Encoded as bit 1 of the update flags.
    deriving (Show, Eq)
-- | Flags that can accompany a 'Delete'; see 'dBit' for the wire encoding.
data DeleteOption
    = SingleRemove
      -- ^ Encoded as bit 0 of the delete flags.
    deriving (Show, Eq)
-- | 64-bit cursor identifier: read from replies ('rCursorId') and written in
-- 'GetMore' and 'KillCursors' messages.
type CursorId = Int64
-- | The opcode placed in the header for each kind of 'Notice'.
nOpcode :: Notice -> Opcode
nOpcode notice = case notice of
    Update{} -> 2001
    Insert{} -> 2002
    Delete{} -> 2006
    KillCursors{} -> 2007
-- | Serialise a 'Notice': the header (opcode from 'nOpcode' plus the given
-- request id) followed by the notice-specific body.
putNotice :: Notice -> RequestId -> Put
putNotice notice requestId =
    putHeader (nOpcode notice) requestId >> putBody notice
  where
    -- Several bodies start with a 32-bit zero.
    reserved = putInt32 0

    -- Insert is the only body that leads with its flags instead of a zero.
    putBody (Insert coll opts docs) = do
        putInt32 (iBits opts)
        putCString coll
        mapM_ putDocument docs
    putBody (Update coll opts selector updater) = do
        reserved
        putCString coll
        putInt32 (uBits opts)
        putDocument selector
        putDocument updater
    putBody (Delete coll opts selector) = do
        reserved
        putCString coll
        putInt32 (dBits opts)
        putDocument selector
    putBody (KillCursors cursorIds) = do
        reserved
        -- Count first, then each 64-bit id.
        putInt32 (toEnum (X.length cursorIds))
        mapM_ putInt64 cursorIds
-- | Flag-word value with only the bit for one 'InsertOption' set.
iBit :: InsertOption -> Int32
iBit option = case option of
    KeepGoing -> bit 0

-- | Combine a list of insert options into a single flags word with 'bitOr'.
iBits :: [InsertOption] -> Int32
iBits options = bitOr (map iBit options)
-- | Flag-word value with only the bit for one 'UpdateOption' set.
uBit :: UpdateOption -> Int32
uBit option = bit $ case option of
    Upsert -> 0
    MultiUpdate -> 1

-- | Combine a list of update options into a single flags word with 'bitOr'.
uBits :: [UpdateOption] -> Int32
uBits options = bitOr (map uBit options)
-- | Flag-word value with only the bit for one 'DeleteOption' set.
dBit :: DeleteOption -> Int32
dBit option = case option of
    SingleRemove -> bit 0

-- | Combine a list of delete options into a single flags word with 'bitOr'.
dBits :: [DeleteOption] -> Int32
dBits options = bitOr (map dBit options)
-- | An operation serialised by 'putRequest'. 'call' sends it with a request
-- id and yields a promise for the matching 'Reply'.
data Request
    = Query
        { qOptions :: [QueryOption]          -- ^ combined into the flags word by 'qBits'
        , qFullCollection :: FullCollection  -- ^ collection name
        , qSkip :: Int32                     -- ^ written as the first 32-bit value after the name
        , qBatchSize :: Int32                -- ^ written as the second 32-bit value after the name
        , qSelector :: Document              -- ^ always written
        , qProjector :: Document             -- ^ written only when non-empty
        }
    | GetMore
        { gFullCollection :: FullCollection  -- ^ collection name
        , gBatchSize :: Int32                -- ^ written after the collection name
        , gCursorId :: CursorId              -- ^ written last, as 64 bits
        }
    deriving (Show, Eq)
-- | Flags that can accompany a 'Query'; see 'qBit' for the wire encoding.
-- NOTE(review): the server-side meaning of each flag is not visible in this
-- module — consult the server documentation.
data QueryOption
    = TailableCursor
      -- ^ Encoded as bit 1 of the query flags.
    | SlaveOK
      -- ^ Encoded as bit 2 of the query flags.
    | NoCursorTimeout
      -- ^ Encoded as bit 4 of the query flags.
    | AwaitData
      -- ^ Encoded as bit 5 of the query flags.
    | Partial
      -- ^ Encoded as bit 7 of the query flags.
    deriving (Show, Eq)
-- | The opcode placed in the header for each kind of 'Request'.
qOpcode :: Request -> Opcode
qOpcode request = case request of
    Query{} -> 2004
    GetMore{} -> 2005
-- | Serialise a 'Request': the header (opcode from 'qOpcode' plus the given
-- request id) followed by the request-specific body.
putRequest :: Request -> RequestId -> Put
putRequest request requestId =
    putHeader (qOpcode request) requestId >> putBody request
  where
    putBody (Query opts coll skip batchSize selector projector) = do
        putInt32 (qBits opts)
        putCString coll
        putInt32 skip
        putInt32 batchSize
        putDocument selector
        -- The projector is optional on the wire: omit it when empty.
        when (not (null projector)) (putDocument projector)
    putBody (GetMore coll batchSize cursorId) = do
        putInt32 0
        putCString coll
        putInt32 batchSize
        putInt64 cursorId
-- | Flag-word value with only the bit for one 'QueryOption' set. Bit
-- positions 0, 3 and 6 are not assigned to any option here.
qBit :: QueryOption -> Int32
qBit option = bit $ case option of
    TailableCursor -> 1
    SlaveOK -> 2
    NoCursorTimeout -> 4
    AwaitData -> 5
    Partial -> 7

-- | Combine a list of query options into a single flags word with 'bitOr'.
qBits :: [QueryOption] -> Int32
qBits options = bitOr (map qBit options)
-- | A decoded server reply, as produced by 'getReply'.
data Reply = Reply
    { rResponseFlags :: [ResponseFlag]  -- ^ flags decoded from the first 32-bit word by 'rFlags'
    , rCursorId :: CursorId             -- ^ 64-bit cursor id read after the flags
    , rStartingFrom :: Int32            -- ^ 32-bit value read after the cursor id
    , rDocuments :: [Document]          -- ^ as many documents as the reply's count field announces
    }
    deriving (Show, Eq)
-- | Flags a reply can carry; see 'rBit' for the bit each one is read from.
-- Constructor order matters: 'rFlags' enumerates from 'CursorNotFound'.
data ResponseFlag
    = CursorNotFound
      -- ^ Read from bit 0 of the reply flags.
    | QueryError
      -- ^ Read from bit 1 of the reply flags.
    | AwaitCapable
      -- ^ Read from bit 3 of the reply flags.
    deriving (Show, Eq, Enum)
-- | The only opcode 'getReply' accepts in an incoming header.
replyOpcode :: Opcode
replyOpcode = 1

-- | Decode one reply (everything after the length prefix): header, flags,
-- cursor id, starting offset, document count, then that many documents.
-- Fails in the 'Get' monad when the header's opcode is not 'replyOpcode'.
getReply :: Get (ResponseTo, Reply)
getReply = do
    (opcode, responseTo) <- getHeader
    when (opcode /= replyOpcode) $
        fail ("expected reply opcode (1) but got " ++ show opcode)
    flags <- fmap rFlags getInt32
    cursorId <- getInt64
    startingFrom <- getInt32
    docCount <- fmap fromIntegral getInt32
    docs <- replicateM docCount getDocument
    return (responseTo, Reply flags cursorId startingFrom docs)
-- | Decode a reply flags word into the list of flags whose bit is set, in
-- constructor order.
rFlags :: Int32 -> [ResponseFlag]
rFlags bits = [flag | flag <- [CursorNotFound ..], testBit bits (rBit flag)]

-- | Bit position each 'ResponseFlag' is read from. Position 2 is not mapped
-- to any flag here.
rBit :: ResponseFlag -> Int
rBit flag = case flag of
    CursorNotFound -> 0
    QueryError -> 1
    AwaitCapable -> 3
-- | User name fed into 'pwHash' and 'pwKey'.
type Username = UString
-- | Clear-text password fed into 'pwHash' and 'pwKey'.
type Password = UString
-- | Value that 'pwKey' prepends to the user name and password hash.
-- NOTE(review): presumably a one-time value issued by the server — confirm.
type Nonce = UString
-- | Hex-encoded MD5 digest of @user ++ ":mongo:" ++ password@.
pwHash :: Username -> Password -> UString
pwHash u p = pack (byteStringHex digest)
  where
    salted = U.append (U.append u ":mongo:") p
    digest = MD5.hash (toByteString salted)
-- | Hex-encoded MD5 digest of @nonce ++ user ++ 'pwHash' user password@.
pwKey :: Nonce -> Username -> Password -> UString
pwKey n u p = pack (byteStringHex (MD5.hash (toByteString material)))
  where
    material = n `U.append` (u `U.append` pwHash u p)