{-# Options_GHC -Wno-unused-do-bind #-}
{-# Language OverloadedStrings #-}
module Client.Network.Async
  ( NetworkConnection
  , NetworkEvent(..)
  , createConnection
  , Client.Network.Async.send
  , Client.Network.Async.recv
  
  , abortConnection
  , TerminationReason(..)
  ) where
import           Client.Configuration.ServerSettings
import           Client.Network.Connect
import           Control.Concurrent
import           Control.Concurrent.Async
import           Control.Concurrent.STM
import           Control.Exception
import           Control.Lens
import           Control.Monad
import           Data.ByteString (ByteString)
import qualified Data.ByteString as B
import           Data.Foldable
import           Data.Time
import           Data.List
import           Irc.RateLimit
import           Hookup
import           Data.Text (Text)
import qualified Data.Text as Text
import           Data.Word (Word8)
import           Numeric (showHex)
import           OpenSSL.X509 (printX509)
-- | Handle to a server connection that is driven by background threads.
-- Interaction happens through its two queues; the supervisor thread handle
-- allows the connection to be cancelled.
data NetworkConnection = NetworkConnection
  { connOutQueue :: TQueue ByteString   -- ^ raw messages waiting to be sent
  , connInQueue  :: TQueue NetworkEvent -- ^ events produced by the connection
  , connAsync    :: Async ()            -- ^ supervisor thread running the connection
  }
-- | Events placed on a connection's inbound queue. Each event carries the
-- local time at which it was recorded.
data NetworkEvent
  -- | The connection is established. The list holds the peer certificate's
  -- textual rendering split into lines (in reverse order), or is empty when
  -- the peer presented no certificate.
  = NetworkOpen  !ZonedTime [Text]
  -- | A non-empty line received from the server.
  | NetworkLine  !ZonedTime !ByteString
  -- | The connection ended with an exception (this includes the reason
  -- given to 'abortConnection').
  | NetworkError !ZonedTime !SomeException
  -- | The connection ended without an exception.
  | NetworkClose !ZonedTime
-- | Shows an opaque placeholder, @NetworkConnection _@. The queues and the
-- thread handle are never rendered.
instance Show NetworkConnection where
  showsPrec prec _ =
    showParen (prec > 10) (showString "NetworkConnection _")
-- | Reasons for deliberately ending a connection. Values are used as
-- exceptions: they are passed to 'abortConnection' or thrown by the TLS
-- fingerprint checks.
data TerminationReason
  = PingTimeout      -- ^ connection killed due to a ping timeout
  | ForcedDisconnect -- ^ connection killed by a client command
  | StsUpgrade       -- ^ connection killed by an STS policy
  | BadCertFingerprint ByteString (Maybe ByteString)
    -- ^ certificate fingerprint mismatch: the expected digest, and the digest
    -- computed from the peer's certificate ('Nothing' when unavailable)
  | BadPubkeyFingerprint ByteString (Maybe ByteString)
    -- ^ public key fingerprint mismatch: the expected digest, and the digest
    -- computed from the peer's public key ('Nothing' when unavailable)
  deriving Show
-- | Human-readable descriptions of each termination reason. A fingerprint
-- mismatch shows both digests as colon-separated hex, and @none@ when no
-- digest could be obtained from the peer.
instance Exception TerminationReason where
  displayException reason =
    case reason of
      PingTimeout      -> "connection killed due to ping timeout"
      ForcedDisconnect -> "connection killed by client command"
      StsUpgrade       -> "connection killed by sts policy"
      BadCertFingerprint expect got ->
        mismatch "Expected certificate fingerprint: " expect got
      BadPubkeyFingerprint expect got ->
        mismatch "Expected public key fingerprint: " expect got
    where
      -- Shared layout for both kinds of fingerprint mismatch.
      mismatch intro wanted actual =
        intro ++ formatDigest wanted ++ "; got: " ++ maybe "none" formatDigest actual
-- | Queue a raw message for transmission. This only appends to the outbound
-- queue; the connection's sender thread later transmits it through the
-- flood rate limiter.
send :: NetworkConnection -> ByteString -> IO ()
send conn = atomically . writeTQueue (connOutQueue conn)
-- | Remove and return every event currently waiting on the inbound queue.
-- The transaction does not retry; it yields an empty list when nothing is
-- queued.
recv :: NetworkConnection -> STM [NetworkEvent]
recv conn = flushTQueue (connInQueue conn)
-- | Terminate a connection by cancelling its supervisor thread with the
-- given reason. 'cancelWith' raises the reason as an exception in that
-- thread.
abortConnection :: TerminationReason -> NetworkConnection -> IO ()
abortConnection reason = flip cancelWith reason . connAsync
-- | Start a connection in the background and return its handle immediately.
--
-- The connection attempt begins after the given delay. When the supervising
-- thread finishes, one final event is queued:
--
-- * 'NetworkClose' if it completed normally;
-- * 'NetworkError' carrying the exception otherwise, including a reason
--   passed to 'abortConnection'.
createConnection ::
  Int            {- ^ delay in seconds before the connection attempt -} ->
  ServerSettings {- ^ server to connect to                           -} ->
  IO NetworkConnection
createConnection delay settings =
  do outQueue <- newTQueueIO
     inQueue  <- newTQueueIO

     supervisor <- async $
       do threadDelay (delay * 1000000) -- threadDelay counts microseconds
          startConnection settings inQueue outQueue

     let -- Stamp an event with the current local time and queue it.
         emit :: (ZonedTime -> NetworkEvent) -> IO ()
         emit mkEvent =
           do now <- getZonedTime
              atomically (writeTQueue inQueue (mkEvent now))

     -- The outcome is reported from a thread separate from the supervisor,
     -- so cancelling the supervisor (as 'abortConnection' does) cannot
     -- interrupt the reporting of how the connection ended.
     _ <- forkIO $
       waitCatch supervisor >>=
         either (\ex -> emit (flip NetworkError ex))
                (const (emit NetworkClose))

     pure NetworkConnection
       { connOutQueue = outQueue
       , connInQueue  = inQueue
       , connAsync    = supervisor
       }
-- | Run one connection to completion, in this order:
--
-- 1. open it;
-- 2. verify any configured TLS certificate and public-key fingerprints;
-- 3. report 'NetworkOpen';
-- 4. run the sender and the receiver concurrently until one of them stops.
--
-- Returns normally when the receiver reaches end of input, and rethrows the
-- first exception raised by either worker. The sender loops forever, so a
-- normal return from it is reported as an internal error.
startConnection ::
  ServerSettings      {- ^ server to connect to                  -} ->
  TQueue NetworkEvent {- ^ queue receiving connection events     -} ->
  TQueue ByteString   {- ^ outbound messages waiting to be sent  -} ->
  IO ()
startConnection settings inQueue outQueue =
  do rate <- newRateLimit
               (view ssFloodPenalty settings)
               (view ssFloodThreshold settings)
     withConnection settings (session rate)
  where
    -- Everything that happens while the connection is open.
    session rate h =
      do traverse_ (checkCertFingerprint h)   (view ssTlsCertFingerprint   settings)
         traverse_ (checkPubkeyFingerprint h) (view ssTlsPubkeyFingerprint settings)
         reportNetworkOpen h inQueue
         withAsync (sendLoop h outQueue rate) $ \sender ->
           withAsync (receiveLoop h inQueue) $ \receiver ->
             waitEitherCatch sender receiver >>= \outcome ->
               either (either throwIO (const (fail "PANIC: sendLoop returned")))
                      (either throwIO (const (return ())))
                      outcome
-- | Check that the peer certificate's fingerprint equals the configured one.
-- The digest is computed with whichever hash algorithm the configured
-- fingerprint specifies.
--
-- On mismatch, throws 'BadCertFingerprint' with the expected digest and the
-- digest obtained from the peer ('Nothing' if none was available).
checkCertFingerprint :: Connection -> Fingerprint -> IO ()
checkCertFingerprint h fp =
  do got <- fetchDigest h
     unless (got == Just expect) $
       throwIO (BadCertFingerprint expect got)
  where
    -- Expected digest together with the matching peer-digest accessor.
    (expect, fetchDigest) =
      case fp of
        FingerprintSha1   digest -> (digest, getPeerCertFingerprintSha1)
        FingerprintSha256 digest -> (digest, getPeerCertFingerprintSha256)
        FingerprintSha512 digest -> (digest, getPeerCertFingerprintSha512)
-- | Check that the peer public key's fingerprint equals the configured one.
-- The digest is computed with whichever hash algorithm the configured
-- fingerprint specifies.
--
-- On mismatch, throws 'BadPubkeyFingerprint' with the expected digest and
-- the digest obtained from the peer ('Nothing' if none was available).
checkPubkeyFingerprint :: Connection -> Fingerprint -> IO ()
checkPubkeyFingerprint h fp =
  do got <- fetchDigest h
     unless (got == Just expect) $
       throwIO (BadPubkeyFingerprint expect got)
  where
    -- Expected digest together with the matching peer-digest accessor.
    (expect, fetchDigest) =
      case fp of
        FingerprintSha1   digest -> (digest, getPeerPubkeyFingerprintSha1)
        FingerprintSha256 digest -> (digest, getPeerPubkeyFingerprintSha256)
        FingerprintSha512 digest -> (digest, getPeerPubkeyFingerprintSha512)
-- | Queue a 'NetworkOpen' event stamped with the current local time.
--
-- If the peer presented an X.509 certificate, the event carries its
-- 'printX509' rendering split into lines, last line first. Otherwise the
-- list is empty. NOTE(review): the reason for the reversal is not visible
-- here; presumably the consumer expects it, so confirm before changing it.
reportNetworkOpen :: Connection -> TQueue NetworkEvent -> IO ()
reportNetworkOpen h inQueue =
  do now  <- getZonedTime
     cert <- getPeerCertificate h
     txts <- maybe (return []) renderCert cert
     atomically (writeTQueue inQueue (NetworkOpen now txts))
  where
    -- The $! forces the reversed list to WHNF, which walks the entire line
    -- split here rather than leaving that work as a thunk inside the event.
    renderCert x509 =
      do rendered <- printX509 x509
         return $! reverse (Text.lines (Text.pack rendered))
-- | Render a digest as colon-separated bytes, each written as two lowercase
-- hex digits, e.g. @0a:ff:10@. An empty digest renders as @""@.
formatDigest :: ByteString -> String
formatDigest digest = intercalate ":" [ showByte byte | byte <- B.unpack digest ]
-- | Render a byte as exactly two lowercase hex digits, zero-padded
-- (e.g. @0x0a@ becomes @"0a"@).
showByte :: Word8 -> String
showByte byte = padTo2 (showHex byte "")
  where
    -- showHex yields one digit for values below 0x10 and two otherwise.
    padTo2 digits = replicate (2 - length digits) '0' ++ digits
-- | Transmit queued messages forever. Each iteration:
--
-- 1. takes the next message from the outbound queue, blocking while it is
--    empty;
-- 2. passes it through the rate limiter ('tickRateLimit');
-- 3. writes it to the connection.
--
-- This never returns normally; it ends only by exception or cancellation.
sendLoop :: Connection -> TQueue ByteString -> RateLimit -> IO ()
sendLoop h outQueue rate = go
  where
    go =
      do msg <- atomically (readTQueue outQueue)
         tickRateLimit rate
         Hookup.send h msg
         go
-- | Maximum IRC message length in bytes, including the trailing CR-LF, as
-- specified by RFC 2812.
ircMaxMessageLength :: Int
ircMaxMessageLength = 512

-- | Read lines from the server until it closes its side of the connection.
-- Each non-empty line is queued as a 'NetworkLine' stamped with the current
-- local time.
--
-- Line length is capped at four times 'ircMaxMessageLength' through the
-- limit given to 'recvLine'. The loop returns when 'recvLine' yields
-- 'Nothing'; exceptions from 'recvLine' propagate to the caller.
receiveLoop :: Connection -> TQueue NetworkEvent -> IO ()
receiveLoop h inQueue = go
  where
    go =
      do mb <- recvLine h (4 * ircMaxMessageLength)
         case mb of
           Nothing  -> return ()
           Just msg ->
             do -- RFC says to silently ignore empty messages.
                unless (B.null msg) $
                  do now <- getZonedTime
                     atomically (writeTQueue inQueue (NetworkLine now msg))
                -- Tail call: the previous `for_`-based version ran the
                -- recursion under a trailing `pure ()`, which can keep one
                -- stack frame per received line for the connection's
                -- lifetime.
                go