module Network.TLS.Record.Layer (
    RecordLayer (..),
    newTransparentRecordLayer,
) where

import Network.TLS.Context
import Network.TLS.Imports
import Network.TLS.Record
import Network.TLS.Struct

import qualified Data.ByteString as B

-- | Build a 'RecordLayer' that performs no record framing of its own:
-- outgoing fragments are passed along as raw bytes paired with an
-- annotation, and incoming bytes are wrapped directly into plaintext
-- records.
--
-- The first argument produces the annotation attached to each encoded
-- fragment, the second delivers the annotated chunks, and the third
-- fetches the next chunk of incoming bytes (or an error).
--
-- The same encoder and the same receiver are installed for both the
-- regular and the @13@ variants of the corresponding fields.
newTransparentRecordLayer
    :: Eq ann
    => (Context -> IO ann)
    -> ([(ann, ByteString)] -> IO ())
    -> (Context -> IO (Either TLSError ByteString))
    -> RecordLayer [(ann, ByteString)]
newTransparentRecordLayer get send recv =
    RecordLayer
        { recordEncode = transparentEncodeRecord get
        , recordEncode13 = transparentEncodeRecord get
        , recordSendBytes = transparentSendBytes send
        , -- The 'Int' argument supplied to 'recordRecv' is ignored.
          recordRecv = \ctx _ -> transparentRecvRecord recv ctx
        , recordRecv13 = transparentRecvRecord recv
        }

-- | Encode a plaintext record as a list of annotated byte chunks.
--
-- ChangeCipherSpec and Alert records encode to an empty list, and the
-- annotation action is not run for them. Any other record yields a
-- single chunk: the record's fragment bytes paired with the annotation
-- obtained by running the supplied action on the context.
--
-- This function always returns 'Right'.
transparentEncodeRecord
    :: (Context -> IO ann)
    -> Context
    -> Record Plaintext
    -> IO (Either TLSError [(ann, ByteString)])
transparentEncodeRecord _ _ (Record ProtocolType_ChangeCipherSpec _ _) =
    return $ Right []
transparentEncodeRecord _ _ (Record ProtocolType_Alert _ _) =
    -- all alerts are silent and must be transported externally based on
    -- TLS exceptions raised by the library
    return $ Right []
transparentEncodeRecord get ctx (Record _ _ frag) = do
    a <- get ctx
    return $ Right [(a, fragmentGetBytes frag)]

-- | Deliver annotated chunks through the supplied send action.
--
-- Consecutive chunks carrying equal annotations are first merged (via
-- 'compress' and 'B.concat') into one chunk per run, and any merged
-- chunk that turns out to be empty is dropped. The send action is
-- invoked exactly once with the resulting list, even if that list is
-- empty. The 'Context' argument is unused.
transparentSendBytes
    :: Eq ann
    => ([(ann, ByteString)] -> IO ())
    -> Context
    -> [(ann, ByteString)]
    -> IO ()
transparentSendBytes send _ input =
    send
        [ (a, bs)
        | (a, frgs) <- compress input
        , let bs = B.concat frgs
        , not (B.null bs)
        ]

-- | Receive bytes with the supplied action and wrap them as a record.
--
-- A successful result is turned into a plaintext fragment and placed
-- in a 'Record' that is always tagged 'ProtocolType_Handshake' and
-- 'TLS12', regardless of what the bytes contain. A 'Left' error from
-- the receive action is returned unchanged.
transparentRecvRecord
    :: (Context -> IO (Either TLSError ByteString))
    -> Context
    -> IO (Either TLSError (Record Plaintext))
transparentRecvRecord recv ctx =
    fmap (Record ProtocolType_Handshake TLS12 . fragmentPlaintext) <$> recv ctx

-- | Group runs of consecutive entries that share the same annotation.
--
-- Each run becomes a single entry holding the annotation and the run's
-- values in their original order. Only adjacent entries are merged, so
-- an annotation that reappears after a different one starts a new
-- group:
--
-- > compress [(1,'a'),(1,'b'),(2,'c'),(1,'d')]
-- >   == [(1,"ab"),(2,"c"),(1,"d")]
compress :: Eq ann => [(ann, val)] -> [(ann, [val])]
compress [] = []
compress ((a, v) : xs) =
    let (ys, zs) = span ((== a) . fst) xs
     in (a, v : map snd ys) : compress zs