{-# LANGUAGE RankNTypes #-}
module Pinch.Transport
( Transport(..)
, framedTransport
, unframedTransport
, Connection(..)
, ReadResult(..)
) where
import Data.IORef (newIORef, readIORef, writeIORef)
import Network.Socket (Socket)
import Network.Socket.ByteString (sendAll, recv)
import System.IO (Handle)
import qualified Data.ByteString as BS
import qualified Data.Serialize.Get as G
import qualified Pinch.Internal.Builder as B
-- | A byte-level, bidirectional connection on which a 'Transport' can be built.
class Connection c where
  -- | Sends the given bytes over the connection. (Both instances in this
  -- module write the whole string.)
  cPut :: c -> BS.ByteString -> IO ()
  -- | Receives at most the requested number of bytes. The transports in this
  -- module treat an empty result as end of input.
  cGetSome :: c -> Int -> IO BS.ByteString
-- | 'Handle'-backed connections: reads go through 'BS.hGetSome' and writes
-- through 'BS.hPut'.
instance Connection Handle where
  cGetSome = BS.hGetSome
  cPut = BS.hPut
-- | Socket-backed connections. Writes use 'sendAll'. Each read asks 'recv' for
-- at most 4096 bytes, whatever larger amount the caller requested.
instance Connection Socket where
  cGetSome sock = recv sock . min 4096
  cPut = sendAll
-- | Outcome of trying to read one message from a 'Transport'.
data ReadResult a
  = RRSuccess a
    -- ^ A message was received and decoded.
  | RRFailure String
    -- ^ Decoding failed; carries the error message.
  | RREOF
    -- ^ Input ended before a complete message was available.
  deriving (Show, Eq)
-- | A bidirectional, message-oriented transport.
data Transport = Transport
  { writeMessage :: B.Builder -> IO ()
    -- ^ Sends one serialized message.
  , readMessage :: forall a. G.Get a -> IO (ReadResult a)
    -- ^ Reads the next message and decodes it with the supplied parser.
  }
-- | Creates a Thrift framed transport over the given connection.
--
-- Every message on the wire is preceded by its length, encoded as a 4-byte
-- big-endian signed integer.
framedTransport :: Connection c => c -> IO Transport
framedTransport c = pure $ Transport writeMsg readMsg
  where
    -- Build the length prefix and the payload into one buffer and send it with
    -- a single 'cPut'. This avoids a tiny 4-byte write followed by the body.
    writeMsg msg =
      cPut c $ B.runBuilder $
        B.int32BE (fromIntegral $ B.getSize msg) <> msg

    readMsg p = do
      szBs <- getExactly c 4
      if BS.length szBs < 4
        then
          -- The input ended before a full length prefix arrived.
          pure RREOF
        else case fromIntegral <$> G.runGet G.getInt32be szBs of
          Left s -> pure $ RRFailure $ "Invalid frame size: " ++ show s
          Right x
            -- A negative length can only come from a corrupt or hostile peer.
            -- Report it instead of passing a negative size down to the
            -- connection's read function.
            | x < 0 -> pure $ RRFailure $ "Invalid frame size: " ++ show x
            | otherwise -> do
                msgBs <- getExactly c x
                pure $
                  if BS.length msgBs < x
                    then
                      -- Fewer bytes than announced: input ended mid-frame.
                      RREOF
                    else either RRFailure RRSuccess (G.runGet p msgBs)
-- | Creates a Thrift unframed transport over the given connection.
--
-- Messages carry no length prefix, so input is pulled in chunks of up to
-- 1024 bytes and fed to the parser until it finishes. Any bytes left over
-- after a decoded message are kept and used as the start of the next read.
unframedTransport :: Connection c => c -> IO Transport
unframedTransport c = do
  leftoverRef <- newIORef BS.empty
  pure $ Transport
    { writeMessage = cPut c . B.runBuilder
    , readMessage = readNext leftoverRef
    }
  where
    readNext ref parser = do
      pending <- readIORef ref
      -- Only hit the connection when nothing is buffered from a previous read.
      start <- if BS.null pending then fetch else pure pending
      (rest, result) <- runGetWith fetch parser start
      writeIORef ref rest
      pure result

    fetch = cGetSome c 1024
-- | Runs a 'G.Get' parser incrementally. It starts on @initial@ and pulls
-- further chunks from @getBs@ whenever the parser asks for more input.
--
-- Returns the unconsumed input together with the outcome. An empty chunk from
-- @getBs@ while the parser still needs data is reported as 'RREOF'.
runGetWith :: IO BS.ByteString -> G.Get a -> BS.ByteString -> IO (BS.ByteString, ReadResult a)
runGetWith getBs p initial = step (G.runGetPartial p initial)
  where
    step (G.Done a rest) = pure (rest, RRSuccess a)
    step (G.Fail err rest) = pure (rest, RRFailure err)
    step (G.Partial feed) = do
      chunk <- getBs
      if BS.null chunk
        then pure (chunk, RREOF) -- end of input before the parser finished
        else step (feed chunk)
-- | Reads exactly @sz@ bytes from the connection, calling 'cGetSome' as many
-- times as needed.
--
-- If the connection signals end of input (an empty 'cGetSome' result) first,
-- the bytes received so far are returned. Callers therefore detect EOF by
-- checking for a result shorter than @sz@. A non-positive @sz@ returns an
-- empty string without reading from the connection.
getExactly :: Connection c => c -> Int -> IO BS.ByteString
getExactly c sz = B.runBuilder <$> go sz mempty
  where
    go :: Int -> B.Builder -> IO B.Builder
    go n b
      -- Nothing (more) to read: don't issue a zero- or negative-length read.
      -- 'BS.hGetSome' throws on a negative size; whether a socket accepts a
      -- zero-length recv is not guaranteed here.
      | n <= 0 = pure b
      | otherwise = do
          bs <- cGetSome c n
          let b' = b <> B.byteString bs
          case BS.length bs of
            -- End of input: hand back whatever arrived before it.
            0 -> pure b
            n'
              | n' < n -> go (n - n') b'
              | otherwise -> pure b'