{-# LANGUAGE ExistentialQuantification  #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE TypeFamilies               #-}
module Pinch.Internal.RPC
  ( Channel(..)
  , createChannel
  , createChannel1
  , readMessage
  , writeMessage

  , ReadResult(..)

  , ServiceName(..)
  , ThriftResult(..)

  , Unit(..)
  ) where

import           Data.Hashable            (Hashable (..))
import           Data.String              (IsString (..))
import           Data.Typeable            (Typeable)

import qualified Data.HashMap.Strict      as HM
import qualified Data.Text                as T

import           Pinch.Internal.Message
import           Pinch.Internal.Pinchable (Pinchable (..), Tag)
import           Pinch.Internal.TType     (TStruct)
import           Pinch.Internal.Value     (Value (..))
import           Pinch.Protocol           (Protocol, deserializeMessage',
                                           serializeMessage)
import           Pinch.Transport          (Connection, ReadResult (..),
                                           Transport)

import qualified Pinch.Transport          as Transport

-- | A bi-directional channel to read/write Thrift messages.
--
-- Incoming messages are read from 'cTransportIn' and decoded with
-- 'cProtocolIn'; outgoing messages are encoded with 'cProtocolOut' and
-- written to 'cTransportOut'.
data Channel = Channel
  { cTransportIn  :: !Transport -- ^ Transport incoming messages are read from.
  , cTransportOut :: !Transport -- ^ Transport outgoing messages are written to.
  , cProtocolIn   :: !Protocol  -- ^ Protocol used to decode incoming messages.
  , cProtocolOut  :: !Protocol  -- ^ Protocol used to encode outgoing messages.
  }

-- | Creates a channel using the same transport/protocol for both directions.
--
-- The transport is obtained by applying the given constructor to the
-- connection exactly once. The resulting 'Transport' and the given
-- 'Protocol' are then shared by the read and the write side of the channel.
createChannel :: Connection c => c -> (c -> IO Transport) -> Protocol -> IO Channel
createChannel c mkTransport p = do
  t <- mkTransport c
  pure $ Channel t t p p

-- | Creates a channel from separate @(transport, protocol)@ pairs.
--
-- The first pair is used for reading (input) and the second pair for
-- writing (output).
createChannel1 :: (Transport, Protocol) -> (Transport, Protocol) -> Channel
createChannel1 (tIn, pIn) (tOut, pOut) = Channel tIn tOut pIn pOut

-- | Reads a single message from the channel's input transport, decoding it
-- with the channel's input protocol.
readMessage :: Channel -> IO (ReadResult Message)
readMessage chan =
  Transport.readMessage (cTransportIn chan) $ deserializeMessage' (cProtocolIn chan)

-- | Encodes a message with the channel's output protocol and writes it to
-- the channel's output transport.
writeMessage :: Channel -> Message -> IO ()
writeMessage chan msg =
  Transport.writeMessage (cTransportOut chan) $ serializeMessage (cProtocolOut chan) msg


-- | The name of a Thrift service.
--
-- Equality and hashing coincide with those of the wrapped 'T.Text'.
newtype ServiceName = ServiceName T.Text
  deriving (Typeable, Eq, Hashable)

-- | Allows a 'ServiceName' to be written as a string literal
-- (with @OverloadedStrings@); the string is packed into 'T.Text'.
instance IsString ServiceName where
  fromString = ServiceName . T.pack


-- | The result datatype for a Thrift service method.
--
-- Instances are Thrift structs (@'Tag' a ~ 'TStruct'@) that represent either
-- the method's successful return value or one of the exceptions declared
-- for that method.
class (Pinchable a, Tag a ~ TStruct) => ThriftResult a where
  -- | The Haskell type returned when the Thrift call succeeds.
  type ResultType a

  -- | Tries to extract the result from a Thrift call. If the call threw any
  -- of the Thrift exceptions declared for this Thrift service method,
  -- the corresponding Haskell exception is thrown using
  -- 'Control.Exception.throwIO'.
  unwrap :: a -> IO (ResultType a)

  -- | Runs the given computation. If it throws any of the exceptions
  -- declared in the Thrift service definition, the exception is caught and
  -- converted to the corresponding constructor of the Haskell result
  -- datatype.
  wrap :: IO (ResultType a) -> IO a

-- | Result datatype for void methods not throwing any exceptions.
--
-- Derives 'Eq' and 'Show' so results can be compared and printed
-- (e.g. in tests and error messages).
data Unit = Unit
  deriving (Eq, Show)

-- | 'Unit' is encoded as an empty Thrift struct.
--
-- Decoding succeeds only for a struct with no fields. Any other value falls
-- through to the second equation and fails with a message that includes the
-- offending value.
instance Pinchable Unit where
  type Tag Unit = TStruct

  pinch Unit = VStruct mempty

  unpinch (VStruct fields)
    | HM.null fields = pure Unit
  unpinch x = fail $ "Failed to read void success. Got " ++ show x

-- | Result of a void method with no declared exceptions.
--
-- The success value is @()@. 'wrap' runs the action and tags its completion
-- with 'Unit'; exceptions raised by the action are not caught here.
-- 'unwrap' always succeeds.
instance ThriftResult Unit where
  type ResultType Unit = ()

  wrap m = Unit <$ m

  unwrap Unit = pure ()