-- |
-- Module      : Network.TLS.Wire
-- License     : BSD-style
-- Maintainer  : Vincent Hanquez <vincent@snarc.org>
-- Stability   : experimental
-- Portability : unknown
--
-- The Wire module is a specialized marshalling/unmarshalling package for the TLS protocol.
-- All multibyte values are encoded and decoded in big-endian byte order.
--
module Network.TLS.Wire
    ( Get
    , GetResult(..)
    , GetContinuation
    , runGet
    , runGetErr
    , runGetMaybe
    , tryGet
    , remaining
    , getWord8
    , getWords8
    , getWord16
    , getWords16
    , getWord24
    , getWord32
    , getWord64
    , getBytes
    , getOpaque8
    , getOpaque16
    , getOpaque24
    , getInteger16
    , getBigNum16
    , getList
    , processBytes
    , isEmpty
    , Put
    , runPut
    , putWord8
    , putWords8
    , putWord16
    , putWords16
    , putWord24
    , putWord32
    , putWord64
    , putBytes
    , putOpaque8
    , putOpaque16
    , putOpaque24
    , putInteger16
    , putBigNum16
    , encodeWord16
    , encodeWord32
    , encodeWord64
    ) where

import Data.Serialize.Get hiding (runGet)
import qualified Data.Serialize.Get as G
import Data.Serialize.Put
import qualified Data.ByteString as B
import Network.TLS.Struct
import Network.TLS.Imports
import Network.TLS.Util.Serialization

-- | Continuation handed back when a parse needs more input; it is fed the
-- next chunk of bytes and yields a new 'GetResult'.
type GetContinuation a = ByteString -> GetResult a
-- | Outcome of an incremental parse as produced by 'runGet'.
data GetResult a =
      GotError TLSError
      -- ^ the parser failed
    | GotPartial (GetContinuation a)
      -- ^ the parser needs more input to finish
    | GotSuccess a
      -- ^ the parser succeeded with no input left over
    | GotSuccessRemaining a ByteString
      -- ^ the parser succeeded; the bytes it did not consume are attached

-- | Run a parser incrementally over one chunk of input, tagging cereal
-- failures with the label @lbl@.
--
-- The cereal 'Result' is translated as follows:
--
-- * 'G.Fail' becomes 'GotError' carrying 'Error_Packet_Parsing';
-- * 'G.Partial' becomes 'GotPartial', whose continuation feeds the next
--   chunk to cereal and translates that result the same way;
-- * 'G.Done' becomes 'GotSuccess' when no bytes are left over, or
--   'GotSuccessRemaining' with the leftover bytes otherwise.
runGet :: String -> Get a -> ByteString -> GetResult a
runGet lbl f input = translate (G.runGetPartial (label lbl f) input)
  where
    translate res = case res of
        G.Fail msg _      -> GotError (Error_Packet_Parsing msg)
        G.Partial resume  -> GotPartial (translate . resume)
        G.Done value rest
            | B.null rest -> GotSuccess value
            | otherwise   -> GotSuccessRemaining value rest

-- | Run a parser over a complete input and demand an exact, finished parse.
--
-- Any outcome other than 'GotSuccess' is an error. A parser that still wants
-- more input, or one that leaves bytes unconsumed, yields an
-- 'Error_Packet_Parsing' whose message is prefixed with @lbl@.
runGetErr :: String -> Get a -> ByteString -> Either TLSError a
runGetErr lbl getter b =
    case runGet lbl getter b of
        GotSuccess r            -> Right r
        GotError err            -> Left err
        GotPartial _            -> parseFailure ": parsing error: partial packet"
        GotSuccessRemaining _ _ -> parseFailure ": parsing error: remaining bytes"
  where
    parseFailure suffix = Left (Error_Packet_Parsing (lbl ++ suffix))

-- | Run a parser with cereal's 'G.runGet' and keep only the result,
-- dropping the failure message.
--
-- NOTE(review): unlike 'runGetErr', this does not itself check for
-- leftover input; whatever cereal's 'G.runGet' does applies.
runGetMaybe :: Get a -> ByteString -> Maybe a
runGetMaybe f input =
    case G.runGet f input of
        Left _      -> Nothing
        Right value -> Just value

-- | Run a parser and keep only the result, dropping the failure message.
--
-- Behaves exactly like 'runGetMaybe'. It is defined in terms of it so the two
-- exported names cannot drift apart.
tryGet :: Get a -> ByteString -> Maybe a
tryGet = runGetMaybe

-- | Read a one-byte element count followed by that many bytes.
getWords8 :: Get [Word8]
getWords8 = do
    count <- getWord8
    replicateM (fromIntegral count) getWord8

-- | Read a 16-bit unsigned integer in big-endian byte order.
getWord16 :: Get Word16
getWord16 = G.getWord16be

-- | Read a list of 16-bit values preceded by a 16-bit length prefix.
--
-- The prefix counts bytes, not elements, so it is halved to get the number
-- of values to read. An odd byte count cannot describe a whole number of
-- 16-bit values. It is rejected as a parse failure instead of silently
-- leaving a stray byte for whichever field is parsed next.
getWords16 :: Get [Word16]
getWords16 = do
    byteLen <- getWord16
    if odd byteLen
        then fail "getWords16: odd byte length for a list of 16-bit values"
        else replicateM (fromIntegral byteLen `div` 2) getWord16

-- | Read a 24-bit unsigned integer stored as three bytes, most significant
-- byte first.
getWord24 :: Get Int
getWord24 = combine <$> getWord8 <*> getWord8 <*> getWord8
  where
    combine :: Word8 -> Word8 -> Word8 -> Int
    combine hi mid lo =
        (fromIntegral hi `shiftL` 16) .|. (fromIntegral mid `shiftL` 8) .|. fromIntegral lo

-- | Read a 32-bit unsigned integer in big-endian byte order.
getWord32 :: Get Word32
getWord32 = G.getWord32be

-- | Read a 64-bit unsigned integer in big-endian byte order.
getWord64 :: Get Word64
getWord64 = G.getWord64be

-- | Read a byte string preceded by a one-byte length.
getOpaque8 :: Get ByteString
getOpaque8 = do
    len <- getWord8
    getBytes (fromIntegral len)

-- | Read a byte string preceded by a 16-bit big-endian length.
getOpaque16 :: Get ByteString
getOpaque16 = do
    len <- getWord16
    getBytes (fromIntegral len)

-- | Read a byte string preceded by a 24-bit big-endian length.
getOpaque24 :: Get ByteString
getOpaque24 = do
    len <- getWord24
    getBytes len

-- | Read a 16-bit length-prefixed byte string and decode it with 'os2ip'.
getInteger16 :: Get Integer
getInteger16 = fmap os2ip getOpaque16

-- | Read a 16-bit length-prefixed byte string and wrap it as a 'BigNum'
-- without decoding it.
getBigNum16 :: Get BigNum
getBigNum16 = fmap BigNum getOpaque16

-- | Parse a list that occupies @totalLen@ bytes.
--
-- The element parser returns the number of bytes it accounts for together
-- with the element. That count is subtracted from the remaining budget, and
-- elements are read until the budget reaches zero. The loop runs under
-- cereal's 'isolate', which confines it to @totalLen@ bytes of input.
getList :: Int -> Get (Int, a) -> Get [a]
getList totalLen getElement = isolate totalLen (getElements totalLen)
  where
    getElements len
        -- Element sizes are self-reported by the caller-supplied parser. If
        -- they overstate what was actually consumed, the budget can go
        -- negative even under 'isolate'. That is a malformed-input
        -- condition, so it is reported as a parse failure inside 'Get'
        -- rather than thrown as an exception from pure code.
        | len < 0   = fail "list consumed too much data. should never happen with isolate."
        | len == 0  = return []
        | otherwise = do
            (elementLen, a) <- getElement
            (a :) <$> getElements (len - elementLen)

-- | Run a parser confined to the next @i@ bytes of input; a synonym for
-- cereal's 'isolate'.
processBytes :: Int -> Get a -> Get a
processBytes = isolate

-- | Write a one-byte element count followed by the bytes themselves.
--
-- NOTE(review): the count is narrowed to 'Word8' with 'fromIntegral', so
-- lists longer than 255 elements get a wrapped prefix. Callers presumably
-- respect that bound; confirm.
putWords8 :: [Word8] -> Put
putWords8 ws = putWord8 (fromIntegral (length ws)) >> mapM_ putWord8 ws

-- | Write a 16-bit unsigned integer in big-endian byte order.
putWord16 :: Word16 -> Put
putWord16 w = putWord16be w

-- | Write a 32-bit unsigned integer in big-endian byte order.
putWord32 :: Word32 -> Put
putWord32 w = putWord32be w

-- | Write a 64-bit unsigned integer in big-endian byte order.
putWord64 :: Word64 -> Put
putWord64 w = putWord64be w

-- | Write a list of 16-bit values preceded by its length in bytes (two
-- bytes per element) as a 16-bit prefix.
--
-- NOTE(review): the byte length is computed in 'Word16', so lists longer
-- than 32767 elements produce a wrapped prefix. Confirm callers stay below
-- that.
putWords16 :: [Word16] -> Put
putWords16 ws = do
    let byteLen = fromIntegral (length ws) * 2 :: Word16
    putWord16 byteLen
    mapM_ putWord16 ws

-- | Write the low 24 bits of an 'Int' as three bytes, most significant
-- byte first. Higher bits are discarded.
putWord24 :: Int -> Put
putWord24 i = mapM_ emitByte [16, 8, 0]
  where
    emitByte s = putWord8 (fromIntegral ((i `shiftR` s) .&. 0xff))

-- | Write raw bytes with no length prefix.
putBytes :: ByteString -> Put
putBytes bs = putByteString bs

-- | Write a byte string preceded by its length as a single byte.
--
-- NOTE(review): the length is narrowed to 'Word8' with 'fromIntegral', so
-- inputs longer than 255 bytes get a wrapped prefix. Confirm callers respect
-- the bound.
putOpaque8 :: ByteString -> Put
putOpaque8 b = do
    putWord8 (fromIntegral (B.length b))
    putBytes b

-- | Write a byte string preceded by its length as a 16-bit big-endian value.
--
-- NOTE(review): the length is narrowed to 'Word16' with 'fromIntegral', so
-- inputs of 65536 bytes or more get a wrapped prefix. Confirm callers respect
-- the bound.
putOpaque16 :: ByteString -> Put
putOpaque16 b = do
    putWord16 (fromIntegral (B.length b))
    putBytes b

-- | Write a byte string preceded by its length as a 24-bit big-endian value.
--
-- NOTE(review): 'putWord24' keeps only the low 24 bits of the length, so
-- inputs of 2^24 bytes or more get a wrapped prefix.
putOpaque24 :: ByteString -> Put
putOpaque24 b = do
    putWord24 (B.length b)
    putBytes b

-- | Encode an 'Integer' with 'i2osp' and write it with a 16-bit length
-- prefix.
putInteger16 :: Integer -> Put
putInteger16 n = putOpaque16 (i2osp n)

-- | Write the bytes wrapped in a 'BigNum' with a 16-bit length prefix.
putBigNum16 :: BigNum -> Put
putBigNum16 bn = case bn of
    BigNum bytes -> putOpaque16 bytes

-- | Serialize a 'Word16' to a two-byte big-endian 'ByteString'.
encodeWord16 :: Word16 -> ByteString
encodeWord16 w = runPut (putWord16 w)

-- | Serialize a 'Word32' to a four-byte big-endian 'ByteString'.
encodeWord32 :: Word32 -> ByteString
encodeWord32 w = runPut (putWord32 w)

-- | Serialize a 'Word64' to an eight-byte big-endian 'ByteString'.
--
-- Goes through the module's 'putWord64' wrapper, like 'encodeWord16' and
-- 'encodeWord32' do with their wrappers, instead of calling cereal's
-- 'putWord64be' directly.
encodeWord64 :: Word64 -> ByteString
encodeWord64 = runPut . putWord64