module Database.MySQL.Protocol.Packet where
import Control.Applicative
import Control.Exception (Exception (..), throwIO)
import Data.Binary.Parser
import Data.Binary.Put
import Data.Binary (Binary(..), encode)
import Data.Bits
import qualified Data.ByteString as B
import Data.ByteString.Char8 as BC
import qualified Data.ByteString.Lazy as L
import qualified Data.ByteString.Lazy.Internal as L
import Data.Int.Int24
import Data.Int
import Data.Word
import Data.Typeable
import Data.Word.Word24
-- | One MySQL wire-protocol packet as handled by 'putPacket' / 'getPacket':
-- a 24-bit little-endian length, a one-byte sequence id, then the payload.
data Packet = Packet
    { pLen  :: !Int64         -- ^ payload length carried in the header
    , pSeqN :: !Word8         -- ^ sequence id of this packet
    , pBody :: !L.ByteString  -- ^ raw payload bytes
    } deriving (Eq, Show)
-- | Serialise a 'Packet'. The header length comes from 'pLen' as-is; it is
-- not recomputed from 'pBody'. It is written via 'putWord24le', so only its
-- low 24 bits reach the wire. The sequence id and raw payload follow.
putPacket :: Packet -> Put
putPacket (Packet size seqId payload) =
    putWord24le (fromIntegral size)
        >> putWord8 seqId
        >> putLazyByteString payload
-- | Parse a 'Packet': read a 24-bit little-endian length and the sequence
-- id, then take exactly that many payload bytes.
getPacket :: Get Packet
getPacket = do
    size <- fromIntegral <$> getWord24le
    Packet size <$> getWord8 <*> getLazyByteString (fromIntegral size)
-- | Wire-format instance; simply delegates to 'getPacket' and 'putPacket'.
instance Binary Packet where
    get = getPacket
    put = putPacket
-- | Does the payload start with @0xFF@, the header byte 'putERR' writes?
--
-- Total: a packet with an empty payload yields 'False' instead of failing
-- on an out-of-range index.
isERR :: Packet -> Bool
isERR p = fmap fst (L.uncons (pBody p)) == Just 0xFF
-- | Does the payload start with @0x00@, the header byte 'putOK' writes?
--
-- Total: a packet with an empty payload yields 'False' instead of failing
-- on an out-of-range index.
isOK :: Packet -> Bool
isOK p = fmap fst (L.uncons (pBody p)) == Just 0x00
-- | Does the payload start with @0xFE@, the header byte 'putEOF' writes?
--
-- Total: a packet with an empty payload yields 'False' instead of failing
-- on an out-of-range index.
--
-- NOTE(review): @0xFE@ is also the prefix of an 8-byte length-encoded
-- integer (see 'getLenEncInt'). This check alone therefore cannot tell an
-- EOF packet from a payload that begins with such an integer. Callers
-- presumably rely on protocol context; confirm.
isEOF :: Packet -> Bool
isEOF p = fmap fst (L.uncons (pBody p)) == Just 0xFE
-- | Decode a packet payload with the target type's 'Binary' 'get'. A parse
-- failure surfaces as a thrown 'DecodePacketFailed'.
decodeFromPacket :: Binary a => Packet -> IO a
decodeFromPacket packet = getFromPacket get packet
-- | Run a parser over a packet's payload in 'IO'.
--
-- On success the parsed value is returned and any unconsumed input is
-- ignored. On failure the parser's reported buffer, offset and message are
-- rethrown as 'DecodePacketFailed'.
getFromPacket :: Get a -> Packet -> IO a
getFromPacket parser (Packet _ _ payload) =
    either failed succeeded (parseDetailLazy parser payload)
  where
    failed (rest, off, msg) = throwIO (DecodePacketFailed rest off msg)
    succeeded (_, _, result) = return result
-- | Thrown by 'getFromPacket' when a payload fails to parse. The fields
-- hold the buffer reported by the parser at the failure, the byte offset
-- of the failure, and the parser's error message.
data DecodePacketException
    = DecodePacketFailed ByteString ByteOffset String
    deriving (Show, Typeable)

instance Exception DecodePacketException
-- | Serialise a value with its 'Binary' 'put' and wrap the bytes in a
-- 'Packet' carrying the given sequence id.
--
-- This delegates to 'putToPacket' (valid because @'encode' = 'runPut' . 'put'@),
-- so the payload-length bookkeeping lives in a single place.
--
-- NOTE(review): 'putPacket' writes 'pLen' as a 24-bit field. Payloads
-- longer than 0xFFFFFF bytes would presumably have to be split by the
-- caller; confirm no caller passes such values.
encodeToPacket :: Binary a => Word8 -> a -> Packet
encodeToPacket seqN payload = putToPacket seqN (put payload)
-- | Run a 'Put' and wrap its output in a 'Packet' with the given sequence
-- id. 'pLen' is set to the byte length of the produced payload.
putToPacket :: Word8 -> Put -> Packet
putToPacket seqN payload = Packet (L.length bytes) seqN bytes
  where
    bytes = runPut payload
-- | Contents of an OK packet (header byte @0x00@ on the wire, see 'putOK').
data OK = OK
    { okAffectedRows :: !Int     -- ^ first length-encoded integer
    , okLastInsertID :: !Int     -- ^ second length-encoded integer
    , okStatus       :: !Word16  -- ^ status word (little-endian on the wire)
    , okWarningCnt   :: !Word16  -- ^ warning count (little-endian on the wire)
    } deriving (Eq, Show)
-- | Parse an OK packet. The header byte is skipped, then come two
-- length-encoded integers (affected rows, last insert id) and two
-- little-endian 'Word16's (status, warning count).
getOK :: Get OK
getOK = do
    skipN 1
    affected <- getLenEncInt
    lastId   <- getLenEncInt
    status   <- getWord16le
    warnings <- getWord16le
    return (OK affected lastId status warnings)
-- | Serialise an OK packet: header byte @0x00@, both counters as
-- length-encoded integers, then the status and warning count as
-- little-endian 'Word16's.
putOK :: OK -> Put
putOK (OK affected lastId status warnings) =
    putWord8 0x00
        >> putLenEncInt affected
        >> putLenEncInt lastId
        >> putWord16le status
        >> putWord16le warnings
-- | Wire-format instance; simply delegates to 'putOK' and 'getOK'.
instance Binary OK where
    put = putOK
    get = getOK
-- | Contents of an ERR packet (header byte @0xFF@ on the wire, see 'putERR').
data ERR = ERR
    { errCode  :: !Word16      -- ^ error code (little-endian on the wire)
    , errState :: !ByteString  -- ^ SQL state; 'getERR' reads exactly 5 bytes
    , errMsg   :: !ByteString  -- ^ the rest of the payload
    } deriving (Eq, Show)
-- | Parse an ERR packet: header byte, little-endian error code, a one-byte
-- marker, a 5-byte SQL state, and the remaining bytes as the message.
getERR :: Get ERR
getERR = do
    skipN 1                      -- header byte (0xFF in 'putERR')
    code  <- getWord16le
    skipN 1                      -- marker byte ('#' in 'putERR')
    state <- getByteString 5
    msg   <- getRemainingByteString
    return (ERR code state msg)
-- | Serialise an ERR packet. The output is header byte @0xFF@, the
-- little-endian error code, a @'#'@ marker (byte 0x23), then the SQL state
-- and message written back to back with no length prefixes.
--
-- NOTE(review): 'getERR' reads exactly 5 bytes of SQL state. A round trip
-- therefore presumably needs a 5-byte 'errState', which is not checked here.
putERR :: ERR -> Put
putERR (ERR code state msg) = do
    putWord8 0xFF
    putWord16le code
    putWord8 0x23
    mapM_ putByteString [state, msg]
-- | Wire-format instance; simply delegates to 'putERR' and 'getERR'.
instance Binary ERR where
    put = putERR
    get = getERR
-- | Contents of an EOF packet (header byte @0xFE@ on the wire, see 'putEOF').
data EOF = EOF
    { eofWarningCnt :: !Word16  -- ^ warning count (little-endian on the wire)
    , eofStatus     :: !Word16  -- ^ status word (little-endian on the wire)
    } deriving (Eq, Show)
-- | Parse an EOF packet: skip the header byte, then read the warning count
-- and the status as little-endian 'Word16's, in that order.
getEOF :: Get EOF
getEOF = skipN 1 >> (EOF <$> getWord16le <*> getWord16le)
-- | Serialise an EOF packet: header byte @0xFE@, then the warning count and
-- the status as little-endian 'Word16's.
putEOF :: EOF -> Put
putEOF (EOF warnings status) =
    putWord8 0xFE >> putWord16le warnings >> putWord16le status
-- | Wire-format instance; simply delegates to 'putEOF' and 'getEOF'.
instance Binary EOF where
    put = putEOF
    get = getEOF
-- | Read a NUL-terminated byte string and return it strict. The terminator
-- is consumed but not included in the result.
getByteStringNul :: Get ByteString
getByteStringNul = fmap L.toStrict getLazyByteStringNul
-- | Consume everything left in the input and return it as a strict
-- 'ByteString'.
getRemainingByteString :: Get ByteString
getRemainingByteString = getRemainingLazyByteString >>= return . L.toStrict
-- | Write a byte string prefixed by its length as a length-encoded integer.
putLenEncBytes :: ByteString -> Put
putLenEncBytes bytes = putLenEncInt (B.length bytes) >> putByteString bytes
-- | Read a length-encoded integer @n@, then exactly @n@ bytes.
getLenEncBytes :: Get ByteString
getLenEncBytes = do
    n <- getLenEncInt
    getByteString n
-- | Decode a length-encoded integer.
--
-- * A first byte below @0xFB@ is the value itself.
-- * @0xFC@, @0xFD@ and @0xFE@ are followed by a 2-, 3- or 8-byte
--   little-endian value.
-- * Any other first byte (@0xFB@, @0xFF@) makes the parser fail.
--
-- The 8-byte form is narrowed to 'Int' with 'fromIntegral', so values above
-- @maxBound :: Int@ wrap.
getLenEncInt :: Get Int
getLenEncInt = getWord8 >>= decodeLen
  where
    decodeLen tag = case tag of
        0xFC -> fromIntegral <$> getWord16le
        0xFD -> fromIntegral <$> getWord24le
        0xFE -> fromIntegral <$> getWord64le
        _ | tag < 0xFB -> pure (fromIntegral tag)
          | otherwise  -> fail ("invalid length val " ++ show tag)
-- | Encode an 'Int' as a length-encoded integer, using the shortest form.
--
-- * Values below @0xFB@ take a single byte.
-- * Values up to @0xFFFF@ take @0xFC@ plus 2 bytes.
-- * Values up to @0xFFFFFF@ take @0xFD@ plus 3 bytes.
-- * Anything larger takes @0xFE@ plus 8 bytes, all little-endian.
--
-- NOTE(review): a negative argument falls into the first guard and only its
-- low byte is written; callers presumably never pass negatives.
putLenEncInt :: Int -> Put
putLenEncInt n
    | n < 0xFB      = putWord8 (fromIntegral n)
    | n <= 0xFFFF   = putWord8 0xFC >> putWord16le (fromIntegral n)
    | n <= 0xFFFFFF = putWord8 0xFD >> putWord24le (fromIntegral n)
    | otherwise     = putWord8 0xFE >> putWord64le (fromIntegral n)
-- | Write the low 24 bits of a 'Word32' as three little-endian bytes. Any
-- higher bits are dropped.
putWord24le :: Word32 -> Put
putWord24le v = mapM_ (putWord8 . fromIntegral . shiftR v) [0, 8, 16]
-- | Read three little-endian bytes as a 'Word32' (the top byte is always 0).
getWord24le :: Get Word32
getWord24le = do
    lo <- getWord16le
    hi <- getWord8
    return $! fromIntegral lo .|. (fromIntegral hi `shiftL` 16)
-- | Write the low 48 bits of a 'Word64' as six little-endian bytes: the low
-- 32 bits first, then the next 16. Higher bits are dropped.
putWord48le :: Word64 -> Put
putWord48le v =
    putWord32le (fromIntegral v) >> putWord16le (fromIntegral (shiftR v 32))
-- | Read six little-endian bytes as a 'Word64' (the top 16 bits are 0).
getWord48le :: Get Word64
getWord48le = do
    lo <- getWord32le
    hi <- getWord16le
    return $! (fromIntegral hi `shiftL` 32) .|. fromIntegral lo
-- | Read three big-endian bytes as a 'Word24'.
getWord24be :: Get Word24
getWord24be = do
    hi <- getWord16be
    lo <- getWord8
    return $! (fromIntegral hi `shiftL` 8) .|. fromIntegral lo
-- | Read three big-endian bytes as an 'Int24'.
--
-- The bits are first assembled as an unsigned 'Word24' and then converted
-- with 'fromIntegral'. Presumably this reinterprets them as two's
-- complement; confirm against the 'Int24' instance.
getInt24be :: Get Int24
getInt24be = do
    hi <- getWord16be
    lo <- getWord8
    let unsigned = (fromIntegral hi `shiftL` 8) .|. fromIntegral lo :: Word24
    return $! fromIntegral unsigned
-- | Read 5-, 6- and 7-byte big-endian unsigned integers into a 'Word64'.
-- Each reads a big-endian 'Word32' high part, then the remaining 1, 2 or 3
-- bytes as the low part.
getWord40be, getWord48be, getWord56be :: Get Word64
getWord40be =
    getWord32be >>= \hi ->
    getWord8 >>= \lo ->
    return $! (fromIntegral hi `shiftL` 8) .|. fromIntegral lo
getWord48be =
    getWord32be >>= \hi ->
    getWord16be >>= \lo ->
    return $! (fromIntegral hi `shiftL` 16) .|. fromIntegral lo
getWord56be =
    getWord32be >>= \hi ->
    getWord24be >>= \lo ->
    return $! (fromIntegral hi `shiftL` 24) .|. fromIntegral lo