{-# LANGUAGE BangPatterns      #-}
{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TupleSections     #-}

module Database.CQL.Protocol.Codec
    ( encodeByte
    , decodeByte

    , encodeSignedByte
    , decodeSignedByte

    , encodeShort
    , decodeShort

    , encodeSignedShort
    , decodeSignedShort

    , encodeInt
    , decodeInt

    , encodeString
    , decodeString

    , encodeLongString
    , decodeLongString

    , encodeBytes
    , decodeBytes

    , encodeShortBytes
    , decodeShortBytes

    , encodeUUID
    , decodeUUID

    , encodeList
    , decodeList

    , encodeMap
    , decodeMap

    , encodeMultiMap
    , decodeMultiMap

    , encodeSockAddr
    , decodeSockAddr

    , encodeConsistency
    , decodeConsistency

    , encodeOpCode
    , decodeOpCode

    , encodePagingState
    , decodePagingState

    , decodeKeyspace
    , decodeTable
    , decodeColumnType
    , decodeQueryId

    , putValue
    , getValue
    ) where

import Control.Applicative
import Control.Monad
import Data.Bits
import Data.ByteString (ByteString)
import Data.Decimal
import Data.Int
import Data.IP
#ifdef INCOMPATIBLE_VARINT
import Data.List (unfoldr)
#else
import Data.List (foldl')
#endif
import Data.Text (Text)
import Data.UUID (UUID)
import Data.Word
import Data.Serialize hiding (decode, encode)
import Database.CQL.Protocol.Types
import Network.Socket (SockAddr (..), PortNumber (..))
import Prelude

import qualified Data.ByteString         as B
import qualified Data.ByteString.Lazy    as LB
import qualified Data.Text.Encoding      as T
import qualified Data.Text.Lazy          as LT
import qualified Data.Text.Lazy.Encoding as LT
import qualified Data.UUID               as UUID

------------------------------------------------------------------------------
-- Byte

-- | Write a single unsigned byte.
encodeByte :: Putter Word8
encodeByte = putWord8

-- | Read a single unsigned byte.
decodeByte :: Get Word8
decodeByte = getWord8

------------------------------------------------------------------------------
-- Signed Byte

-- | Write a single signed byte (two's complement).
encodeSignedByte :: Putter Int8
encodeSignedByte = putWord8 . fromIntegral

-- | Read a single signed byte (two's complement).
decodeSignedByte :: Get Int8
decodeSignedByte = fromIntegral <$> getWord8

------------------------------------------------------------------------------
-- Short

-- | Write an unsigned 16-bit integer in big-endian byte order.
encodeShort :: Putter Word16
encodeShort = putWord16be

-- | Read an unsigned 16-bit integer in big-endian byte order.
decodeShort :: Get Word16
decodeShort = getWord16be

------------------------------------------------------------------------------
-- Signed Short

-- | Write a signed 16-bit integer (two's complement, big-endian).
encodeSignedShort :: Putter Int16
encodeSignedShort = putWord16be . fromIntegral

-- | Read a signed 16-bit integer (two's complement, big-endian).
decodeSignedShort :: Get Int16
decodeSignedShort = fromIntegral <$> getWord16be

------------------------------------------------------------------------------
-- Int

-- | Write a signed 32-bit integer (two's complement, big-endian).
encodeInt :: Putter Int32
encodeInt = putWord32be . fromIntegral

-- | Read a signed 32-bit integer (two's complement, big-endian).
decodeInt :: Get Int32
decodeInt = fromIntegral <$> getWord32be

------------------------------------------------------------------------------
-- String

-- | Write a text value as UTF-8, prefixed with its byte length as an
-- unsigned 16-bit integer (see 'encodeShortBytes').
encodeString :: Putter Text
encodeString txt = encodeShortBytes (T.encodeUtf8 txt)

-- | Read a text value: an unsigned 16-bit byte length followed by that many
-- bytes of UTF-8 (see 'decodeShortBytes').
--
-- Malformed UTF-8 is reported as a parse failure. The bytes are validated
-- eagerly with 'T.decodeUtf8'' because 'T.decodeUtf8' would instead throw an
-- impure exception at whatever later point the resulting 'Text' is forced.
decodeString :: Get Text
decodeString = do
    raw <- decodeShortBytes
    case T.decodeUtf8' raw of
        Left  e -> fail $ "decode-string: invalid utf-8: " ++ show e
        Right t -> return t

------------------------------------------------------------------------------
-- Long String

-- | Write a lazy text value as UTF-8, prefixed with its byte length as a
-- signed 32-bit integer (see 'encodeBytes').
encodeLongString :: Putter LT.Text
encodeLongString txt = encodeBytes (LT.encodeUtf8 txt)

-- | Read a lazy text value: a signed 32-bit byte length followed by that
-- many bytes of UTF-8.
--
-- Fails the parse when the length is negative or when the payload is not
-- valid UTF-8. Validation is done eagerly with 'LT.decodeUtf8'' so that bad
-- input surfaces here rather than as an exception when the text is forced.
decodeLongString :: Get LT.Text
decodeLongString = do
    n <- get :: Get Int32
    when (n < 0) $
        fail $ "decode-long-string: negative length: " ++ show n
    raw <- getLazyByteString (fromIntegral n)
    case LT.decodeUtf8' raw of
        Left  e -> fail $ "decode-long-string: invalid utf-8: " ++ show e
        Right t -> return t

------------------------------------------------------------------------------
-- Bytes

-- | Write a lazy byte string prefixed with its length as a signed 32-bit
-- integer.
encodeBytes :: Putter LB.ByteString
encodeBytes payload =
    encodeInt (fromIntegral (LB.length payload)) >> putLazyByteString payload

-- | Read a byte string prefixed with a signed 32-bit length.
--
-- A negative length yields 'Nothing' and consumes no payload bytes.
decodeBytes :: Get (Maybe LB.ByteString)
decodeBytes = decodeInt >>= payload
  where
    payload len
        | len < 0   = return Nothing
        | otherwise = Just <$> getLazyByteString (fromIntegral len)

------------------------------------------------------------------------------
-- Short Bytes

-- | Write a strict byte string prefixed with its length as an unsigned
-- 16-bit integer.
encodeShortBytes :: Putter ByteString
encodeShortBytes payload =
    encodeShort (fromIntegral (B.length payload)) >> putByteString payload

-- | Read a strict byte string prefixed with an unsigned 16-bit length.
decodeShortBytes :: Get ByteString
decodeShortBytes = decodeShort >>= getByteString . fromIntegral

------------------------------------------------------------------------------
-- UUID

-- | Write a UUID as the bytes produced by 'UUID.toByteString', without any
-- length prefix.
encodeUUID :: Putter UUID
encodeUUID uuid = putLazyByteString (UUID.toByteString uuid)

-- | Read 16 bytes and interpret them as a UUID.
--
-- Fails the parse if 'UUID.fromByteString' rejects the bytes.
decodeUUID :: Get UUID
decodeUUID = do
    raw <- getLazyByteString 16
    case UUID.fromByteString raw of
        Nothing -> fail "decode-uuid: invalid"
        Just u  -> return u

------------------------------------------------------------------------------
-- String List

-- | Write a list of strings: an unsigned 16-bit element count followed by
-- each element encoded with 'encodeString'.
encodeList :: Putter [Text]
encodeList items =
    encodeShort (fromIntegral (length items)) >> forM_ items encodeString

-- | Read a list of strings: an unsigned 16-bit element count followed by
-- that many 'decodeString' values.
decodeList :: Get [Text]
decodeList = decodeShort >>= \count ->
    replicateM (fromIntegral count) decodeString

------------------------------------------------------------------------------
-- String Map

-- | Write a string map: an unsigned 16-bit entry count followed by each
-- key and value encoded with 'encodeString'.
encodeMap :: Putter [(Text, Text)]
encodeMap entries =
    encodeShort (fromIntegral (length entries)) >> mapM_ entry entries
  where
    entry (key, value) = encodeString key >> encodeString value

-- | Read a string map: an unsigned 16-bit entry count followed by that many
-- key/value pairs, each component read with 'decodeString'.
decodeMap :: Get [(Text, Text)]
decodeMap = decodeShort >>= \count ->
    replicateM (fromIntegral count) entry
  where
    entry = liftA2 (,) decodeString decodeString

------------------------------------------------------------------------------
-- String Multi-Map

-- | Write a string multi-map: an unsigned 16-bit entry count followed by
-- each key ('encodeString') and its list of values ('encodeList').
encodeMultiMap :: Putter [(Text, [Text])]
encodeMultiMap entries =
    encodeShort (fromIntegral (length entries)) >> mapM_ entry entries
  where
    entry (key, values) = encodeString key >> encodeList values

-- | Read a string multi-map: an unsigned 16-bit entry count followed by
-- that many entries, each a 'decodeString' key and a 'decodeList' value.
decodeMultiMap :: Get [(Text, [Text])]
decodeMultiMap = decodeShort >>= \count ->
    replicateM (fromIntegral count) entry
  where
    entry = liftA2 (,) decodeString decodeList

------------------------------------------------------------------------------
-- Inet Address

-- | Write a socket address: one size byte (4 for IPv4, 16 for IPv6), the
-- raw address, and the port as a 32-bit big-endian integer.
--
-- For IPv6 the two remaining 'SockAddrInet6' fields are ignored.
--
-- Address families other than IPv4/IPv6 cannot be represented and abort
-- with 'error'. 'error' is used rather than 'fail' because the put monad
-- has no failure channel: there 'fail' used to default to 'error' and is
-- not available at all once 'fail' lives in @MonadFail@.
--
-- NOTE(review): the IPv4 address is written with 'putWord32le' and the IPv6
-- words with 'putWord32host'; this presumably assumes a little-endian
-- host -- confirm before relying on it on big-endian platforms.
encodeSockAddr :: Putter SockAddr
encodeSockAddr (SockAddrInet p a) = do
    putWord8 4
    putWord32le a
    putWord32be (fromIntegral p)
encodeSockAddr (SockAddrInet6 p _ (a, b, c, d) _) = do
    putWord8 16
    putWord32host a
    putWord32host b
    putWord32host c
    putWord32host d
    putWord32be (fromIntegral p)
encodeSockAddr (SockAddrUnix _) = error "encode-socket: unix address not allowed"
#if MIN_VERSION_network(2,6,1)
encodeSockAddr (SockAddrCan _) = error "encode-socket: can address not allowed"
#endif

-- | Read a socket address: one size byte selecting IPv4 (4) or IPv6 (16),
-- the raw address, and the port as a 32-bit big-endian integer.
--
-- IPv6 results carry 0 in both remaining 'SockAddrInet6' fields. Any other
-- size byte fails the parse.
decodeSockAddr :: Get SockAddr
decodeSockAddr = getWord8 >>= \size -> case size of
    4  -> flip SockAddrInet <$> getWord32le <*> portNumber
    16 -> (\addr port -> SockAddrInet6 port 0 addr 0) <$> hostAddr6 <*> portNumber
    _  -> fail $ "decode-socket: unknown: " ++ show size
  where
    -- The port follows the address on the wire.
    portNumber :: Get PortNumber
    portNumber = fromIntegral <$> getWord32be

    -- Four consecutive 32-bit words read in host byte order.
    hostAddr6 :: Get (Word32, Word32, Word32, Word32)
    hostAddr6 = do
        a <- getWord32host
        b <- getWord32host
        c <- getWord32host
        d <- getWord32host
        return (a, b, c, d)

------------------------------------------------------------------------------
-- Consistency

-- | Write a consistency level as its unsigned 16-bit protocol code.
encodeConsistency :: Putter Consistency
encodeConsistency level = encodeShort $ case level of
    Any         -> 0x00
    One         -> 0x01
    Two         -> 0x02
    Three       -> 0x03
    Quorum      -> 0x04
    All         -> 0x05
    LocalQuorum -> 0x06
    EachQuorum  -> 0x07
    Serial      -> 0x08
    LocalSerial -> 0x09
    LocalOne    -> 0x0A

-- | Read an unsigned 16-bit protocol code and map it to a consistency
-- level. Unknown codes fail the parse.
--
-- The mapping is the exact inverse of 'encodeConsistency'; in particular
-- 'LocalOne' is code @0x0A@ (it was previously matched as @0x10@, which
-- made values written by 'encodeConsistency' undecodable).
decodeConsistency :: Get Consistency
decodeConsistency = decodeShort >>= mapCode
  where
    mapCode :: Word16 -> Get Consistency
    mapCode code = case code of
        0x00 -> return Any
        0x01 -> return One
        0x02 -> return Two
        0x03 -> return Three
        0x04 -> return Quorum
        0x05 -> return All
        0x06 -> return LocalQuorum
        0x07 -> return EachQuorum
        0x08 -> return Serial
        0x09 -> return LocalSerial
        0x0A -> return LocalOne
        _    -> fail $ "decode-consistency: unknown: " ++ show code

------------------------------------------------------------------------------
-- OpCode

-- | Write an opcode as its single-byte protocol code.
encodeOpCode :: Putter OpCode
encodeOpCode op = encodeByte $ case op of
    OcError         -> 0x00
    OcStartup       -> 0x01
    OcReady         -> 0x02
    OcAuthenticate  -> 0x03
    OcOptions       -> 0x05
    OcSupported     -> 0x06
    OcQuery         -> 0x07
    OcResult        -> 0x08
    OcPrepare       -> 0x09
    OcExecute       -> 0x0A
    OcRegister      -> 0x0B
    OcEvent         -> 0x0C
    OcBatch         -> 0x0D
    OcAuthChallenge -> 0x0E
    OcAuthResponse  -> 0x0F
    OcAuthSuccess   -> 0x10

-- | Read a single byte and map it to an opcode. Unknown codes fail the
-- parse.
decodeOpCode :: Get OpCode
decodeOpCode = decodeByte >>= \word -> case word of
    0x00 -> return OcError
    0x01 -> return OcStartup
    0x02 -> return OcReady
    0x03 -> return OcAuthenticate
    0x05 -> return OcOptions
    0x06 -> return OcSupported
    0x07 -> return OcQuery
    0x08 -> return OcResult
    0x09 -> return OcPrepare
    0x0A -> return OcExecute
    0x0B -> return OcRegister
    0x0C -> return OcEvent
    0x0D -> return OcBatch
    0x0E -> return OcAuthChallenge
    0x0F -> return OcAuthResponse
    0x10 -> return OcAuthSuccess
    _    -> fail $ "decode-opcode: unknown: " ++ show word

------------------------------------------------------------------------------
-- ColumnType

-- | Read a column type description.
--
-- A type starts with an unsigned 16-bit identifier. Custom types are
-- followed by a string, list/set/map types by their element type(s), UDTs
-- by keyspace, type name and a counted list of named field types, and
-- tuples by a counted list of component types. Unknown identifiers fail
-- the parse.
decodeColumnType :: Get ColumnType
decodeColumnType = nested
  where
    -- A complete type: identifier followed by whatever that identifier needs.
    nested :: Get ColumnType
    nested = decodeShort >>= fromCode

    -- An unsigned 16-bit count followed by that many items.
    counted :: Get a -> Get [a]
    counted item = do
        n <- decodeShort
        replicateM (fromIntegral n) item

    fromCode :: Word16 -> Get ColumnType
    fromCode code = case code of
        0x0000 -> CustomColumn <$> decodeString
        0x0001 -> return AsciiColumn
        0x0002 -> return BigIntColumn
        0x0003 -> return BlobColumn
        0x0004 -> return BooleanColumn
        0x0005 -> return CounterColumn
        0x0006 -> return DecimalColumn
        0x0007 -> return DoubleColumn
        0x0008 -> return FloatColumn
        0x0009 -> return IntColumn
        0x000A -> return TextColumn
        0x000B -> return TimestampColumn
        0x000C -> return UuidColumn
        0x000D -> return VarCharColumn
        0x000E -> return VarIntColumn
        0x000F -> return TimeUuidColumn
        0x0010 -> return InetColumn
        0x0020 -> ListColumn <$> nested
        0x0021 -> MapColumn  <$> nested <*> nested
        0x0022 -> SetColumn  <$> nested
        0x0030 -> do
            _        <- decodeString -- Keyspace (not used by this library)
            typeName <- decodeString
            UdtColumn typeName <$> counted ((,) <$> decodeString <*> nested)
        0x0031 -> TupleColumn <$> counted nested
        _      -> fail $ "decode-type: unknown: " ++ show code

------------------------------------------------------------------------------
-- Paging State

-- | Write a paging state as a length-prefixed byte string ('encodeBytes').
encodePagingState :: Putter PagingState
encodePagingState state = case state of
    PagingState raw -> encodeBytes raw

-- | Read an optional paging state; 'Nothing' exactly when 'decodeBytes'
-- yields 'Nothing'.
decodePagingState :: Get (Maybe PagingState)
decodePagingState = do
    raw <- decodeBytes
    return (PagingState <$> raw)

------------------------------------------------------------------------------
-- Value

-- | Write a value including its length prefix.
--
-- * Lists, sets and maps are wrapped in a 4-byte length; inside, V3 uses a
--   32-bit element count and 4-byte element lengths while V2 uses a 16-bit
--   count and 2-byte element lengths.
-- * V3 tuples are the concatenation of their 'putValue'-encoded components,
--   wrapped in a 4-byte length.
-- * @CqlMaybe Nothing@ is written as the 32-bit length -1; @CqlMaybe (Just x)@
--   is written exactly like @x@.
-- * Everything else is delegated to 'putNative' with a 4-byte length.
putValue :: Version -> Putter Value
putValue ver val = case (ver, val) of
    (V3, CqlList xs)       -> toBytes 4 $ count32 xs >> mapM_ item3 xs
    (V2, CqlList xs)       -> toBytes 4 $ count16 xs >> mapM_ item2 xs
    (V3, CqlSet xs)        -> toBytes 4 $ count32 xs >> mapM_ item3 xs
    (V2, CqlSet xs)        -> toBytes 4 $ count16 xs >> mapM_ item2 xs
    (V3, CqlMap kvs)       -> toBytes 4 $ count32 kvs >> mapM_ (pair item3) kvs
    (V2, CqlMap kvs)       -> toBytes 4 $ count16 kvs >> mapM_ (pair item2) kvs
    (V3, CqlTuple xs)      -> toBytes 4 (putByteString (runPut (mapM_ (putValue V3) xs)))
    (_, CqlMaybe Nothing)  -> put (-1 :: Int32)
    (_, CqlMaybe (Just x)) -> putValue ver x
    _                      -> toBytes 4 (putNative ver val)
  where
    -- Element count as a 32-bit / 16-bit integer respectively.
    count32 xs = encodeInt (fromIntegral (length xs))
    count16 xs = encodeShort (fromIntegral (length xs))

    -- A single collection element with the version's length-prefix width.
    item3 = toBytes 4 . putNative V3
    item2 = toBytes 2 . putNative V2

    -- Key followed by value, both written with the same element encoder.
    pair item (k, v) = item k >> item v

-- | Write the raw payload of a value /without/ any length prefix.
--
-- Collections, tuples and 'CqlMaybe' are not handled here; they have
-- dedicated equations in 'putValue'. Passing one of them (or a UDT under
-- V2) is a programming error and aborts with 'error'.
--
-- 'error' is used rather than 'fail' because the put monad has no failure
-- channel: there 'fail' used to default to 'error' and is not available at
-- all once 'fail' lives in @MonadFail@.
putNative :: Version -> Putter Value
putNative _ (CqlCustom x)    = putLazyByteString x
putNative _ (CqlBoolean x)   = putWord8 $ if x then 1 else 0
putNative _ (CqlInt x)       = put x
putNative _ (CqlBigInt x)    = put x
putNative _ (CqlFloat x)     = putFloat32be x
putNative _ (CqlDouble x)    = putFloat64be x
putNative _ (CqlText x)      = putByteString (T.encodeUtf8 x)
putNative _ (CqlUuid x)      = encodeUUID x
putNative _ (CqlTimeUuid x)  = encodeUUID x
putNative _ (CqlTimestamp x) = put x
putNative _ (CqlAscii x)     = putByteString (T.encodeUtf8 x)
putNative _ (CqlBlob x)      = putLazyByteString x
putNative _ (CqlCounter x)   = put x
putNative _ (CqlInet x)      = case x of
    IPv4 i -> putWord32le (toHostAddress i)
    IPv6 i -> do
        let (a, b, c, d) = toHostAddress6 i
        putWord32host a
        putWord32host b
        putWord32host c
        putWord32host d
putNative _ (CqlVarInt x)   = integer2bytes x
-- A decimal is its scale as a 32-bit integer followed by the mantissa.
putNative _ (CqlDecimal x)  = do
    put (fromIntegral (decimalPlaces x) :: Int32)
    integer2bytes (decimalMantissa x)
-- A UDT is the concatenation of its field values; field names are dropped.
putNative V3   (CqlUdt   x) = putByteString $ runPut (mapM_ (putValue V3 . snd) x)
putNative V2 v@(CqlUdt   _) = error $ "putNative: udt: " ++ show v
putNative _  v@(CqlList  _) = error $ "putNative: collection type: " ++ show v
putNative _  v@(CqlSet   _) = error $ "putNative: collection type: " ++ show v
putNative _  v@(CqlMap   _) = error $ "putNative: collection type: " ++ show v
putNative _  v@(CqlMaybe _) = error $ "putNative: collection type: " ++ show v
putNative _  v@(CqlTuple _) = error $ "putNative: tuple type: " ++ show v

-- | Read a length-prefixed value of the given column type.
--
-- Note: Empty lists, maps and sets are represented as null in cassandra,
-- so a negative collection length decodes to an empty collection (see
-- 'getList').
--
-- * Lists, sets and maps: V3 uses a 32-bit element count and 4-byte element
--   lengths, V2 a 16-bit count and 2-byte element lengths.
-- * V3 tuples: a 4-byte length, then one 'getValue'-encoded component per
--   component type.
-- * 'MaybeColumn': a negative 32-bit length yields @CqlMaybe Nothing@,
--   otherwise the wrapped type is decoded and returned inside 'Just'.
-- * Everything else: a 4-byte length followed by a 'getNative' payload.
getValue :: Version -> ColumnType -> Get Value
getValue ver ty = case (ver, ty) of
    (V3, ListColumn t)   -> CqlList <$> getList (counted32 (item3 t))
    (V2, ListColumn t)   -> CqlList <$> getList (counted16 (item2 t))
    (V3, SetColumn t)    -> CqlSet  <$> getList (counted32 (item3 t))
    (V2, SetColumn t)    -> CqlSet  <$> getList (counted16 (item2 t))
    (V3, MapColumn k v)  -> CqlMap  <$> getList (counted32 ((,) <$> item3 k <*> item3 v))
    (V2, MapColumn k v)  -> CqlMap  <$> getList (counted16 ((,) <$> item2 k <*> item2 v))
    (V3, TupleColumn ts) -> do
        raw <- withBytes 4 remainingBytes
        either fail return (runGet (CqlTuple <$> mapM (getValue V3) ts) raw)
    (_, MaybeColumn t)   -> do
        len <- lookAhead (get :: Get Int32)
        if len >= 0
            then CqlMaybe . Just <$> getValue ver t
            else uncheckedSkip 4 >> return (CqlMaybe Nothing)
    _                    -> withBytes 4 (getNative ver ty)
  where
    -- A 32-bit / 16-bit element count followed by that many items.
    counted32 item = do
        n <- decodeInt
        replicateM (fromIntegral n) item
    counted16 item = do
        n <- decodeShort
        replicateM (fromIntegral n) item

    -- A single collection element with the version's length-prefix width.
    item3 = withBytes 4 . getNative V3
    item2 = withBytes 2 . getNative V2

-- | Read the raw payload of a value of the given column type.
--
-- The parser is expected to run on exactly the payload bytes (as arranged
-- by 'withBytes'): variable-length types consume everything that remains,
-- and the inet type uses the remaining length to tell IPv4 from IPv6.
--
-- Text, varchar and ascii payloads are validated as UTF-8 and malformed
-- input fails the parse (see 'remainingText').
--
-- Collections, tuples and 'MaybeColumn' are not handled here (they have
-- dedicated equations in 'getValue'), nor are UDTs under V2; those fail the
-- parse.
getNative :: Version -> ColumnType -> Get Value
getNative _ (CustomColumn _) = CqlCustom <$> remainingBytesLazy
getNative _ BooleanColumn    = CqlBoolean . (/= 0) <$> getWord8
getNative _ IntColumn        = CqlInt <$> get
getNative _ BigIntColumn     = CqlBigInt <$> get
getNative _ FloatColumn      = CqlFloat  <$> getFloat32be
getNative _ DoubleColumn     = CqlDouble <$> getFloat64be
getNative _ TextColumn       = CqlText <$> remainingText
getNative _ VarCharColumn    = CqlText <$> remainingText
getNative _ AsciiColumn      = CqlAscii <$> remainingText
getNative _ BlobColumn       = CqlBlob <$> remainingBytesLazy
getNative _ UuidColumn       = CqlUuid <$> decodeUUID
getNative _ TimeUuidColumn   = CqlTimeUuid <$> decodeUUID
getNative _ TimestampColumn  = CqlTimestamp <$> get
getNative _ CounterColumn    = CqlCounter <$> get
getNative _ InetColumn       = CqlInet <$> do
    len <- remaining
    case len of
        4  -> IPv4 . fromHostAddress <$> getWord32le
        16 -> do
            a <- (,,,) <$> getWord32host <*> getWord32host <*> getWord32host <*> getWord32host
            return $ IPv6 (fromHostAddress6 a)
        n  -> fail $ "getNative: invalid Inet length: " ++ show n
getNative _ VarIntColumn  = CqlVarInt <$> bytes2integer
-- A decimal is its scale as a 32-bit integer followed by the mantissa.
getNative _ DecimalColumn = do
    x <- get :: Get Int32
    y <- bytes2integer
    return (CqlDecimal (Decimal (fromIntegral x) y))
-- A UDT is the concatenation of its field values, in declaration order.
getNative V3 (UdtColumn _ x) = do
    b <- remainingBytes
    either fail return $ flip runGet b $ CqlUdt <$> do
        let (n, t) = unzip x
        zip n <$> mapM (getValue V3) t
getNative V2 c@(UdtColumn _ _) = fail $ "getNative: udt: " ++ show c
getNative _  c@(ListColumn  _) = fail $ "getNative: collection type: " ++ show c
getNative _  c@(SetColumn   _) = fail $ "getNative: collection type: " ++ show c
getNative _  c@(MapColumn _ _) = fail $ "getNative: collection type: " ++ show c
getNative _  c@(MaybeColumn _) = fail $ "getNative: collection type: " ++ show c
getNative _  c@(TupleColumn _) = fail $ "getNative: tuple type: " ++ show c

-- | Consume all remaining input and decode it as UTF-8.
--
-- Uses 'T.decodeUtf8'' so that malformed input becomes a parse failure
-- here; 'T.decodeUtf8' would throw an impure exception only once the
-- resulting 'Text' is forced.
remainingText :: Get Text
remainingText = remainingBytes >>= either (fail . message) return . T.decodeUtf8'
  where
    message e = "getNative: invalid utf-8: " ++ show e

-- | Run a collection parser on a 4-byte length-prefixed payload.
--
-- A negative length consumes only the 4 length bytes and yields the empty
-- list without running the parser.
getList :: Get [a] -> Get [a]
getList body = do
    len <- lookAhead (get :: Get Int32)
    if len >= 0
        then withBytes 4 body
        else uncheckedSkip 4 >> return []

-- | Read a length prefix, then run a parser on exactly that many bytes.
--
-- The first argument is the width of the prefix: 2 reads an unsigned
-- 16-bit length, 4 a signed 32-bit length; any other width fails. A
-- negative length fails with @withBytes: null@. Failures of the inner
-- parser are re-raised with a @withBytes: @ prefix.
withBytes :: Int -> Get a -> Get a
withBytes width inner = do
    len <- prefix
    when (len < 0) (fail "withBytes: null")
    chunk <- getBytes len
    either (fail . ("withBytes: " ++)) return (runGet inner chunk)
  where
    prefix :: Get Int
    prefix
        | width == 2 = fromIntegral <$> (get :: Get Word16)
        | width == 4 = fromIntegral <$> (get :: Get Int32)
        | otherwise  = fail $ "withBytes: invalid size: " ++ show width

-- | Consume everything that 'remaining' reports as a strict byte string.
remainingBytes :: Get ByteString
remainingBytes = do
    n <- remaining
    getByteString (fromIntegral n)

-- | Consume everything that 'remaining' reports as a lazy byte string.
remainingBytesLazy :: Get LB.ByteString
remainingBytesLazy = do
    n <- remaining
    getLazyByteString (fromIntegral n)

-- | Run a writer and emit its output prefixed with its byte length.
--
-- A width of 2 writes the length as an unsigned 16-bit integer; any other
-- width writes it as a signed 32-bit integer.
toBytes :: Int -> Put -> Put
toBytes width body = lengthPrefix >> putByteString payload
  where
    payload = runPut body
    size    = B.length payload

    lengthPrefix
        | width == 2 = put (fromIntegral size :: Word16)
        | otherwise  = put (fromIntegral size :: Int32)

#ifdef INCOMPATIBLE_VARINT

-- 'integer2bytes' and 'bytes2integer' implementations are taken
-- from cereal's instance declaration of 'Serialize' for 'Integer'
-- except that no distinction between small and large integers is made.
-- Cf. to LICENSE for copyright details.
--
-- The encoding is a sign byte (the 'signum' of the number truncated to a
-- 'Word8') followed by the 'put'-serialised list of magnitude bytes, least
-- significant byte first.
integer2bytes :: Putter Integer
integer2bytes n = put signByte >> put magnitude
  where
    signByte :: Word8
    signByte = fromIntegral (signum n)

    magnitude :: [Word8]
    magnitude = unfoldr nextByte (abs n)

    -- Peel off the low byte until nothing is left.
    nextByte :: Integer -> Maybe (Word8, Integer)
    nextByte 0 = Nothing
    nextByte k = Just (fromIntegral k, k `shiftR` 8)

-- | Inverse of 'integer2bytes': read the sign byte and the list of
-- magnitude bytes (least significant first) and rebuild the number.
--
-- The magnitude is kept as is when the sign byte equals 1 and negated
-- otherwise.
bytes2integer :: Get Integer
bytes2integer = do
    signByte  <- get :: Get Word8
    magnitude <- get :: Get [Word8]
    let absolute = foldr addByte 0 magnitude
    return $! (if signByte == 1 then absolute else negate absolute)
  where
    addByte :: Word8 -> Integer -> Integer
    addByte b acc = (acc `shiftL` 8) .|. fromIntegral b

#else

-- | Write an integer as big-endian two's complement bytes.
--
-- 0 and -1 are written as the single bytes @0x00@ and @0xFF@. Otherwise the
-- bytes from 'explode' are written, preceded by one extra byte (@0xFF@ for
-- negative, @0x00@ for positive numbers) whenever the top bit of the first
-- byte would not already match the sign.
integer2bytes :: Putter Integer
integer2bytes n
    | n == 0    = putWord8 0x00
    | n == -1   = putWord8 0xFF
    | n <  0    = padded 0xFF (>= 0x80) (explode (-1) n)
    | otherwise = padded 0x00 (<  0x80) (explode 0 n)
  where
    -- Emit the filler byte only if the leading byte has the wrong sign bit.
    -- 'explode' yields at least one byte for the inputs reaching this point.
    padded :: Word8 -> (Word8 -> Bool) -> [Word8] -> Put
    padded filler signOk bytes = do
        case bytes of
            (b:_) | not (signOk b) -> putWord8 filler
            _                      -> return ()
        mapM_ putWord8 bytes

-- | Split an integer into bytes, most significant first.
--
-- The low byte is taken and the number shifted right by 8 bits until it
-- equals the first argument. Callers pass 0 for non-negative and -1 for
-- negative numbers, the values an arithmetic right shift converges to;
-- with the wrong stop value the loop does not terminate.
explode :: Integer -> Integer -> [Word8]
explode stop = go []
  where
    go !acc !i
        | i == stop = acc
        | otherwise = go (fromIntegral i : acc) (i `shiftR` 8)

-- | Read all remaining input as a big-endian two's complement integer.
--
-- At least one byte is required. If the top bit of the first byte is set
-- the number is negative and is recovered as @-(complement + 1)@.
bytes2integer :: Get Integer
bytes2integer = do
    msb  <- getWord8
    rest <- remainingBytes
    let octets = msb : B.unpack rest
    return $ if testBit msb 7
        then negate (implode (map complement octets) + 1)
        else implode octets

-- | Combine bytes, most significant first, into a non-negative integer.
-- The empty list yields 0.
implode :: [Word8] -> Integer
implode = foldl' (\acc b -> (acc `shiftL` 8) .|. fromIntegral b) 0

#endif
------------------------------------------------------------------------------
-- Various

-- | Read a keyspace name encoded as a string ('decodeString').
decodeKeyspace :: Get Keyspace
decodeKeyspace = fmap Keyspace decodeString

-- | Read a table name encoded as a string ('decodeString').
decodeTable :: Get Table
decodeTable = fmap Table decodeString

-- | Read a query id encoded as short bytes ('decodeShortBytes').
decodeQueryId :: Get (QueryId k a b)
decodeQueryId = fmap QueryId decodeShortBytes