{-# OPTIONS_HADDOCK hide #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE FlexibleContexts #-}
module Network.Xmpp.Tls where
import Control.Applicative ((<$>))
import qualified Control.Exception.Lifted as Ex
import Control.Monad
import Control.Monad.Except
import Control.Monad.State.Strict
import qualified Data.ByteString as BS
import qualified Data.ByteString.Char8 as BSC8
import qualified Data.ByteString.Lazy as BL
import Data.Conduit
import Data.IORef
import Data.Monoid
import Data.XML.Types
import Network.DNS.Resolver (ResolvConf)
import Network.TLS
import Network.Xmpp.Stream
import Network.Xmpp.Types
import System.Log.Logger (debugM, errorM, infoM)
import System.X509
-- | Build a TLS 'Backend' whose transport operations are forwarded to the
-- given 'StreamHandle'.
--
-- * 'backendSend' forwards to 'streamSend' and discards its result.
-- * 'backendRecv' forwards to 'streamReceive'; a 'Left' result is re-thrown
--   as an exception, and short reads are retried until the requested number
--   of bytes has been collected.
-- * 'backendFlush' and 'backendClose' forward to 'streamFlush' and
--   'streamClose' respectively.
mkBackend :: StreamHandle -> Backend
mkBackend con = Backend
    { -- NOTE(review): a 'Left' returned by 'streamSend' is dropped by 'void',
      -- so send failures are not reported through this backend -- confirm
      -- this is intentional.
      backendSend  = \bs -> void (streamSend con bs)
    , backendRecv  = bufferReceive (streamReceive con)
    , backendFlush = streamFlush con
    , backendClose = streamClose con
    }
  where
    -- Asking for zero bytes never touches the underlying receive action.
    bufferReceive _ 0 = return BS.empty
    -- Otherwise gather chunks until @n@ bytes are available or the receive
    -- action yields an empty chunk, then glue the chunks together.
    bufferReceive recv n = BS.concat `liftM` go n
      where
        -- @go m@ collects chunks for the @m@ bytes that are still missing.
        go m = do
            mbBs <- recv m
            bs <- case mbBs of
                Left e  -> Ex.throwIO e
                Right r -> return r
            case BS.length bs of
                -- An empty chunk ends the loop (presumably end of stream);
                -- whatever was collected so far is returned.
                0 -> return []
                l | l < m     -> (bs :) `liftM` go (m - l)
                  | otherwise -> return [bs]
-- | The empty @starttls@ element in the @urn:ietf:params:xml:ns:xmpp-tls@
-- namespace (no attributes, no children), sent to request TLS negotiation.
starttlsE :: Element
starttlsE = Element "{urn:ietf:params:xml:ns:xmpp-tls}starttls" [] []
-- | Check for TLS support and run the STARTTLS procedure if applicable.
--
-- The decision is taken from the configured 'tlsBehaviour' and from whether
-- the stream features advertise TLS ('streamFeaturesTls'):
--
-- * 'RequireTls': negotiate TLS if offered, otherwise fail with
--   'TlsNoServerSupport'.
-- * 'PreferTls': negotiate TLS if offered, otherwise skip.
-- * 'PreferPlain': negotiate TLS only when the feature is @Just True@,
--   otherwise skip.
-- * 'RefuseTls': fail with 'XmppOtherFailure' when the feature is
--   @Just True@, otherwise skip.
--
-- Fails with 'XmppNoStream' when the stream is closed or finished and with
-- 'TlsStreamSecured' when it is already secured. Exceptions raised while
-- negotiating are converted to 'Left' by 'wrapExceptions'.
tls :: Stream -> IO (Either XmppFailure ())
tls con = fmap join -- Failures can come both from caught exceptions and
                    -- from the error monad; 'join' merges the two layers.
          . wrapExceptions
          . flip withStream con
          . runExceptT $ do
    conf <- gets streamConfiguration
    sState <- gets streamConnectionState
    case sState of
        Plain -> return ()
        Closed -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is closed."
            throwError XmppNoStream
        Finished -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is finished."
            throwError XmppNoStream
        Secured -> do
            liftIO $ errorM "Pontarius.Xmpp.Tls" "The stream is already secured."
            throwError TlsStreamSecured
    features <- lift $ gets streamFeatures
    case (tlsBehaviour conf, streamFeaturesTls features) of
        (RequireTls  , Just _   ) -> startTls
        (RequireTls  , Nothing  ) -> throwError TlsNoServerSupport
        (PreferTls   , Just _   ) -> startTls
        (PreferTls   , Nothing  ) -> skipTls
        (PreferPlain , Just True) -> startTls
        (PreferPlain , _        ) -> skipTls
        (RefuseTls   , Just True) -> throwError XmppOtherFailure
        (RefuseTls   , _        ) -> skipTls
  where
    -- Leave the stream untouched; only record the decision in the log.
    skipTls = liftIO $ infoM "Pontarius.Xmpp.Tls" "Skipping TLS negotiation"

    -- Send <starttls/>, wait for the server's verdict, run the TLS handshake
    -- over the current handle and restart the stream on the secured handle.
    startTls = do
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Running StartTLS"
        params <- gets $ tlsParams . streamConfiguration
        ExceptT $ pushElement starttlsE
        answer <- lift pullElement
        case answer of
            Left e -> throwError e
            -- Any <proceed/> reply is accepted, regardless of attributes or
            -- children.
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}proceed" _ _) ->
                return ()
            Right (Element "{urn:ietf:params:xml:ns:xmpp-tls}failure" _ _) -> do
                liftIO $ errorM "Pontarius.Xmpp" "startTls: TLS initiation failed."
                throwError XmppOtherFailure
            -- Anything else is a protocol violation: abort instead of
            -- starting a handshake the server never agreed to.
            Right r -> do
                liftIO $ errorM "Pontarius.Xmpp.Tls" $
                            "Unexpected element: " ++ show r
                throwError XmppOtherFailure
        hand <- gets streamHandle
        (_raw, _snk, psh, recv, ctx) <- lift $ tlsinit params (mkBackend hand)
        -- The secured handle routes all traffic through the TLS context and
        -- closes the underlying plain handle after saying goodbye.
        let newHand = StreamHandle { streamSend = catchPush . psh
                                   , streamReceive = wrapExceptions . recv
                                   , streamFlush = contextFlush ctx
                                   , streamClose = bye ctx >> streamClose hand
                                   }
        lift $ modify (\x -> x {streamHandle = newHand})
        liftIO $ infoM "Pontarius.Xmpp.Tls" "Stream Secured."
        -- A failed stream restart is raised as an exception; the outer
        -- 'wrapExceptions' turns it back into a 'Left'.
        either (lift . Ex.throwIO) return =<< lift restartStream
        modify (\s -> s {streamConnectionState = Secured})
        return ()
-- | Create a new TLS client 'Context' from the given parameters and backend.
--
-- This is 'contextNew' with its two arguments swapped.
client :: MonadIO m => ClientParams -> Backend -> m Context
client params backend = contextNew backend params
-- | Set up a TLS client session over the given backend and perform the
-- handshake.
--
-- The system certificate store is loaded and combined (via '<>') with the CA
-- store already present in the supplied parameters before the context is
-- created.
--
-- Returns, in order:
--
-- 1. a source that endlessly yields data received from the TLS context
--    (each chunk is also written to the debug log),
-- 2. a sink that sends every awaited chunk through the TLS context,
-- 3. a function that sends one strict 'BS.ByteString',
-- 4. a buffered read function built with 'mkReadBuffer' that returns at most
--    the requested number of bytes,
-- 5. the TLS 'Context' itself.
tlsinit :: (MonadIO m, MonadIO m1) =>
        ClientParams
     -> Backend
     -> m ( ConduitT () BS.ByteString m1 ()
          , ConduitT BS.ByteString Void m1 ()
          , BS.ByteString -> IO ()
          , Int -> m1 BS.ByteString
          , Context
          )
tlsinit params backend = do
    liftIO $ debugM "Pontarius.Xmpp.Tls" "TLS with debug mode enabled."
    sysCStore <- liftIO getSystemCertificateStore
    -- Extend (rather than replace) the caller's CA store with the system one.
    let shared  = clientShared params
        params' = params
            { clientShared = shared
                { sharedCAStore = sysCStore <> sharedCAStore shared }
            }
    con <- client params' backend
    handshake con
    let src = forever $ do
            dt <- liftIO $ recvData con
            liftIO $ debugM "Pontarius.Xmpp.Tls" ("In :" ++ BSC8.unpack dt)
            yield dt
    let snk = do
            d <- await
            case d of
                -- Upstream is exhausted: stop.
                Nothing -> return ()
                Just x -> do
                    sendData con (BL.fromChunks [x])
                    snk
    readWithBuffer <- liftIO $ mkReadBuffer (recvData con)
    return ( src
           , snk
           , \s -> sendData con $ BL.fromChunks [s]
           , liftIO . readWithBuffer
           , con
           )
-- | Wrap a chunk-producing receive action in a reader that hands out at most
-- @n@ bytes per call.
--
-- Bytes left over from a chunk are kept in an 'IORef' and served by later
-- calls before the receive action is invoked again. A call may return fewer
-- than @n@ bytes: it never combines buffered data with a fresh chunk, and it
-- returns the empty string when the receive action does.
--
-- A request for zero (or fewer) bytes returns the empty string immediately,
-- without invoking the receive action.
mkReadBuffer :: IO BS.ByteString -> IO (Int -> IO BS.ByteString)
mkReadBuffer recv = do
    buffer <- newIORef BS.empty
    let read' n
            -- Nothing was asked for: do not block on the underlying source.
            | n <= 0 = return BS.empty
            | otherwise = do
                pending <- readIORef buffer
                -- Only fetch a new chunk once the leftover is used up.
                bs <- if BS.null pending then recv else return pending
                let (result, rest) = BS.splitAt n bs
                writeIORef buffer rest
                return result
    return read'
-- | Connect to a host and immediately secure the connection with TLS.
--
-- The connection is opened with 'connectSrv'; if that yields no handle the
-- action fails with 'TcpConnectionFailure'. When the host-name part of
-- 'clientServerIdentification' in the supplied parameters is empty, it is
-- replaced by @(host, "")@; otherwise the parameters are used unchanged.
--
-- The returned 'StreamHandle' sends and receives through the TLS context;
-- closing it sends the TLS goodbye and then closes the underlying handle.
connectTls :: ResolvConf   -- ^ Resolver configuration passed to 'connectSrv'
           -> ClientParams -- ^ TLS parameters used to secure the connection
           -> String       -- ^ Host to connect to
           -> ExceptT XmppFailure IO StreamHandle
connectTls config params host = do
    h <- connectSrv config host
           >>= maybe (throwError TcpConnectionFailure) return
    let hand = handleToStreamHandle h
    let params' = params
            { clientServerIdentification =
                case clientServerIdentification params of
                    ("", _) -> (host, "")
                    csi     -> csi
            }
    -- If TLS setup throws, nobody else holds the plain handle, so close it
    -- here before the exception propagates.
    (_raw, _snk, psh, recv, ctx) <- liftIO $
        tlsinit params' (mkBackend hand) `Ex.onException` streamClose hand
    return StreamHandle { streamSend = catchPush . psh
                        , streamReceive = wrapExceptions . recv
                        , streamFlush = contextFlush ctx
                        , streamClose = bye ctx >> streamClose hand
                        }
-- | Run an 'IO' action and convert the exceptions it may throw into a
-- 'Left' 'XmppFailure'.
--
-- Handled, in this order:
--
-- * 'IOException', wrapped in 'XmppIOException';
-- * (only when built against @tls@ older than 1.8.0) the exception type
--   accepted by the 'XmppTlsError' constructor, wrapped in 'TlsError';
-- * the exception type accepted by 'XmppTlsException', wrapped in
--   'TlsError';
-- * a thrown 'XmppFailure', returned as is.
--
-- Any other exception propagates unchanged. A normal result is returned as
-- 'Right'.
wrapExceptions :: IO a -> IO (Either XmppFailure a)
wrapExceptions f = Ex.catches (Right <$> f)
                 [ Ex.Handler $ return . Left . XmppIOException
#if !MIN_VERSION_tls(1,8,0)
                 , Ex.Handler $ wrap . XmppTlsError
#endif
                 , Ex.Handler $ wrap . XmppTlsException
                 , Ex.Handler $ return . Left
                 ]
  where
    -- Report a TLS-level problem as a 'TlsError' failure.
    wrap = return . Left . TlsError