{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE OverloadedStrings #-}
module Transit.Internal.FileTransfer
( sendFile
, receiveFile
, MessageType(..)
)
where
import Protolude

import qualified Conduit as C
import qualified Data.Aeson as Aeson
import qualified Data.ByteString.Lazy as BL
import qualified Data.Set as Set
import qualified MagicWormhole
import Network.Socket (socketPort, Socket)
import System.Directory (removeFile, getTemporaryDirectory)
import System.FilePath ((</>), takeFileName)
import System.IO.Temp (createTempDirectory)

import Transit.Internal.Crypto (CipherText(..))
import Transit.Internal.Errors (Error(..))
import Transit.Internal.Messages
  ( TransitMsg( Transit, Answer )
  , Ability(..)
  , AbilityV1(..)
  , Ack( FileAck )
  , TransitAck (..))
import Transit.Internal.Network
  ( tcpListener
  , buildHints
  , buildRelayHints
  , startServer
  , startClient
  , closeConnection
  , RelayEndpoint
  , CommunicationError(..)
  , TransitEndpoint(..))
import Transit.Internal.Peer
  ( makeRecordKeys
  , senderHandshakeExchange
  , senderTransitExchange
  , senderOfferExchange
  , receiveWormholeMessage
  , sendTransitMsg
  , sendWormholeMessage
  , receiverHandshakeExchange
  , makeAckMessage
  , generateTransitSide
  , sendRecord
  , receiveRecord
  , unzipInto)
import Transit.Internal.Pipeline
  ( sendPipeline
  , receivePipeline)
-- | A payload description: either an inline text message or a path to a
-- file on the local file system.
data MessageType
  = TMsg Text
    -- ^ An inline text message.
  | TFile FilePath
    -- ^ A path to a file.
  deriving (Eq, Show)
-- | Purpose string used with 'MagicWormhole.deriveKey' to derive the transit
-- key: the application ID followed by @/transit-key@.
transitPurpose :: MagicWormhole.AppID -> ByteString
transitPurpose (MagicWormhole.AppID appId) = mconcat [toS appId, "/transit-key"]
-- | Send an acknowledgement carrying the given SHA-256 digest to the peer.
--
-- The acknowledgement is built and encrypted by 'makeAckMessage' using the
-- last key stored in the 'TransitEndpoint' (the second key returned by
-- 'makeRecordKeys'). It is then written as a record on the transit connection.
--
-- A failure to build the message is reported as 'CipherError'; a failure to
-- send it is reported as 'NetworkError'.
sendAckMessage :: TransitEndpoint -> ByteString -> IO (Either Error ())
sendAckMessage (TransitEndpoint ep _ key) sha256Sum =
  either (return . Left . CipherError) transmit (makeAckMessage key sha256Sum)
  where
    transmit (CipherText encMsg) =
      either (Left . NetworkError) (const (Right ())) <$> sendRecord ep encMsg
-- | Wait for the peer's acknowledgement record and return the checksum it
-- reports.
--
-- The record is read and decrypted with the last key stored in the
-- 'TransitEndpoint', then decoded from JSON as a 'TransitAck'.
--
-- * A decryption failure becomes 'CipherError'.
-- * A JSON decode failure becomes a 'TransitError' that includes the decoder
--   message.
-- * A status other than @"ok"@ becomes a plain 'TransitError'.
receiveAckMessage :: TransitEndpoint -> IO (Either Error Text)
receiveAckMessage (TransitEndpoint ep _ key) = interpret <$> receiveRecord ep key
  where
    interpret (Left cryptoErr) = Left (CipherError cryptoErr)
    interpret (Right raw) =
      case Aeson.eitherDecode (BL.fromStrict raw) of
        Left reason ->
          Left (NetworkError (TransitError (toS ("transit ack failure: " <> reason))))
        Right (TransitAck status checksum)
          | status == "ok" -> Right checksum
          | otherwise -> Left (NetworkError (TransitError "transit ack failure"))
-- | Establish the sender's end of the transit connection.
--
--   1. Open a local TCP listener and build our connection hints from its port
--      and the relay endpoint.
--   2. Exchange transit messages with the peer over the wormhole.
--   3. Race accepting an inbound connection on the listener against
--      connecting out to the peer's hints plus our relay hints.
--   4. Derive the transit key and record keys, then run the sender side of
--      the transit handshake.
--
-- Returns the connected 'TransitEndpoint' or the first error encountered. A
-- peer reply that is not a 'Transit' message yields 'UnknownPeerMessage'.
--
-- NOTE(review): 'race' returns as soon as either action finishes, even when
-- it finishes with a 'Left'. A fast outbound failure therefore cancels a
-- listener that might still accept a connection. The listener socket is also
-- not closed in this function; presumably 'startServer' owns it — confirm.
establishSenderTransit :: MagicWormhole.EncryptedConnection -> RelayEndpoint -> MagicWormhole.AppID -> IO (Either Error TransitEndpoint)
establishSenderTransit conn transitserver appid = do
  listener <- tcpListener
  port <- socketPort listener
  transitSide <- generateTransitSide
  directHints <- buildHints port transitserver
  peerReply <- senderTransitExchange conn (Set.toList directHints)
  case peerReply of
    Left commErr -> return $ Left (NetworkError commErr)
    Right (Transit _ peerHints) -> do
      let candidates = Set.toList (buildRelayHints transitserver <> peerHints)
      winner <- race (startServer listener) (startClient candidates)
      -- Both racers report 'Either CommunicationError endpoint'; collapse the
      -- outer 'Either' from 'race' and dispatch on the result.
      either (return . Left . NetworkError) (handshake transitSide) (either identity identity winner)
    Right _ -> return $ Left (NetworkError (UnknownPeerMessage "Could not decode message"))
  where
    transitKey = MagicWormhole.deriveKey conn (transitPurpose appid)

    -- Derive the record keys and run the sender-side handshake on a freshly
    -- connected endpoint.
    handshake transitSide endpoint =
      case makeRecordKeys transitKey of
        Left cryptoErr -> return (Left (CipherError cryptoErr))
        Right (sendKey, recvKey) ->
          bimap HandshakeError (const (TransitEndpoint endpoint sendKey recvKey))
            <$> senderHandshakeExchange endpoint transitKey transitSide
-- | Establish the receiver's end of the transit connection.
--
-- Given the peer's 'Transit' message and an already-open listening socket,
-- this races accepting an inbound connection on that socket against
-- connecting out to the peer's hints plus our relay hints. On success it
-- derives the record keys from the transit key and runs the receiver side of
-- the transit handshake.
--
-- Any message other than 'Transit' yields 'UnknownPeerMessage'.
--
-- NOTE(review): as with the sender, 'race' returns whichever action finishes
-- first, including one that finishes with a 'Left' — confirm this is intended.
establishReceiverTransit :: MagicWormhole.EncryptedConnection -> RelayEndpoint -> MagicWormhole.AppID -> TransitMsg -> Socket -> IO (Either Error TransitEndpoint)
establishReceiverTransit conn transitserver appid peerMsg listener =
  case peerMsg of
    Transit _ peerHints -> do
      transitSide <- generateTransitSide
      let candidates = Set.toList (peerHints <> buildRelayHints transitserver)
      winner <- race (startServer listener) (startClient candidates)
      either (return . Left . NetworkError) (handshake transitSide) (either identity identity winner)
    _ -> return $ Left (NetworkError (UnknownPeerMessage "Could not recognize the message"))
  where
    transitKey = MagicWormhole.deriveKey conn (transitPurpose appid)

    -- Derive the record keys and run the receiver-side handshake on a
    -- freshly connected endpoint.
    handshake transitSide endpoint =
      case makeRecordKeys transitKey of
        Left cryptoErr -> return $ Left (CipherError cryptoErr)
        Right (sendKey, recvKey) ->
          bimap HandshakeError (const (TransitEndpoint endpoint sendKey recvKey))
            <$> receiverHandshakeExchange endpoint transitKey transitSide
-- | Send the file at @filepath@ to the peer.
--
--   1. Establish the transit connection ('establishSenderTransit').
--   2. Send a file offer over the wormhole ('senderOfferExchange').
--   3. Stream the file through 'sendPipeline' over the transit connection.
--   4. Wait for the receiver's acknowledgement and compare the SHA-256 digest
--      it reports with the one computed while sending.
--
-- Once the transit connection is established, it is closed on every exit
-- path. This includes a failed or rejected offer and an exception thrown
-- during the transfer.
sendFile :: MagicWormhole.EncryptedConnection -> RelayEndpoint -> MagicWormhole.AppID -> FilePath -> IO (Either Error ())
sendFile conn transitserver appid filepath = do
  endpoint <- establishSenderTransit conn transitserver appid
  case endpoint of
    Left e -> return $ Left e
    -- 'finally' covers the offer exchange too, so an offer failure cannot
    -- leave the transit connection open.
    Right ep -> transfer ep `finally` closeConnection ep
  where
    -- Offer the file, send it, and check the receiver's acknowledgement.
    transfer ep = do
      offerResp <- senderOfferExchange conn filepath
      case offerResp of
        Left s -> return (Left (NetworkError (OfferError s)))
        Right pathToSend -> do
          (txSha256Hash, _) <- C.runConduitRes (sendPipeline pathToSend ep)
          rxAckMsg <- receiveAckMessage ep
          return $ case rxAckMsg of
            Left e -> Left e
            Right rxSha256Hash
              | txSha256Hash /= rxSha256Hash ->
                  Left (NetworkError (Sha256SumError "sha256 mismatch"))
              | otherwise -> Right ()
-- | Receive a file or directory offered by the peer over the wormhole.
--
-- @transit@ is the peer's transit message. It is handed to
-- 'establishReceiverTransit' once an offer has been accepted.
--
--   1. Open a local TCP listener and send our abilities and connection hints
--      to the peer.
--   2. Receive the peer's offer and decode it from JSON.
--   3. For a 'MagicWormhole.File' offer, accept it and receive the data via
--      'receivePipeline' into a path built from the offered name.
--   4. For a 'MagicWormhole.Directory' offer, receive the zip archive into a
--      fresh temporary directory, unpack it with 'unzipInto', then remove the
--      archive.
--
-- An undecodable offer yields 'OfferError'. Any other offer yields
-- 'UnknownPeerMessage'. A failure while receiving or while sending the
-- acknowledgement is returned to the caller rather than ignored.
--
-- Offered names come from the peer and are untrusted, so only their final
-- path component ('takeFileName') is used. Otherwise an absolute name or a
-- @..@ segment could redirect the write; note that @dir '</>' "/abs"@ is
-- simply @"/abs"@.
--
-- NOTE(review): the listening socket is not closed in this function. The
-- temporary directory itself (as opposed to the archive inside it) is not
-- removed, and the result of 'unzipInto' is discarded — confirm these are
-- intended.
receiveFile :: MagicWormhole.EncryptedConnection -> RelayEndpoint -> MagicWormhole.AppID -> TransitMsg -> IO (Either Error ())
receiveFile conn transitserver appid transit = do
  let abilities' = [Ability DirectTcpV1, Ability RelayV1]
  s <- tcpListener
  portnum <- socketPort s
  ourHints <- buildHints portnum transitserver
  sendTransitMsg conn abilities' (Set.toList ourHints)
  offerMsg <- receiveWormholeMessage conn
  case Aeson.eitherDecode (toS offerMsg) of
    Left err -> return $ Left (NetworkError (OfferError $ "unable to decode offer msg: " <> toS err))
    Right (MagicWormhole.File name size) ->
      rxFile s (takeFileName (toS name)) size
    Right (MagicWormhole.Directory _mode name zipSize _ _uncompressedSize) -> do
      systemTmpDir <- getTemporaryDirectory
      tmpDir <- createTempDirectory systemTmpDir "wormhole"
      let dirName = takeFileName (toS name)
          zipFile = tmpDir </> dirName
      rxResult <- rxFile s zipFile zipSize
      case rxResult of
        -- Do not unpack, or report success for, an archive that was not
        -- received and acknowledged successfully.
        Left e -> return (Left e)
        Right () -> do
          _ <- unzipInto (toS dirName) zipFile
          Right <$> removeFile zipFile
    Right _ -> return $ Left (NetworkError (UnknownPeerMessage "cannot decipher the message from peer"))
  where
    -- Accept the offer, establish the transit connection, and stream the data
    -- to @name@ via 'receivePipeline'. Then send the computed digest back to
    -- the sender. The ack result is returned so a failed ack is reported, and
    -- the transit connection is closed however this ends.
    rxFile socket name size = do
      let ans = Answer (FileAck "ok")
      sendWormholeMessage conn (Aeson.encode ans)
      endpoint <- establishReceiverTransit conn transitserver appid transit socket
      case endpoint of
        Left e -> return $ Left e
        Right ep ->
          finally
            (do
               (rxSha256Sum, ()) <- C.runConduitRes $ receivePipeline name (fromIntegral size) ep
               sendAckMessage ep (toS rxSha256Sum))
            (closeConnection ep)