{-# language BangPatterns #-}
{-# language DerivingStrategies #-}
{-# language DeriveAnyClass #-}
{-# language LambdaCase #-}
{-# language ScopedTypeVariables #-}
{-# language TypeFamilies #-}

import Control.Concurrent.Async (concurrently)
import Control.Monad (replicateM_)
import Control.Exception (Exception)
import Control.Exception (throwIO)
import Control.Monad.ST (runST)
import Data.Bool (bool)
import Data.Primitive (ByteArray)
import Data.Word (Word16,Word8)
import GHC.Exts (RealWorld)
import System.Exit (exitFailure)
import System.IO (stderr,hPutStrLn)
import Test.Tasty
import Test.Tasty.HUnit

import qualified Data.Primitive as PM
import qualified Data.Primitive.MVar as PM
import qualified GHC.Exts as E
import qualified Net.IPv4 as IPv4
import qualified Socket.Datagram.IPv4.Spoof as DIS
import qualified Socket.Datagram.IPv4.Undestined as DIU
import qualified Socket.Stream.IPv4 as SI

main :: IO ()
main = do
  canSpoof <- DIS.withSocket (const (pure ())) >>= \case
    Right () -> pure True
    Left e -> case e of
      DIS.SocketPermissionDenied -> pure False
      DIS.SocketFileDescriptorLimit -> do
        hPutStrLn stderr "All ephemeral ports are in use. Terminating."
        exitFailure
  defaultMain (tests canSpoof)

tests :: Bool -> TestTree
tests canSpoof = testGroup "socket"
  [ testGroup "datagram"
    [ testGroup "ipv4"
      [ testGroup "undestined"
        [ testCase "A" testDatagramUndestinedA
        , testCase "B" testDatagramUndestinedB
        , testCase "C" testDatagramUndestinedC
        ]
      , testGroup "spoof" $ if canSpoof
          then
            [ testCase "A" testDatagramSpoofA
            , testCase "B" testDatagramSpoofB
            ]
          else []
      ]
    ]
  , testGroup "stream"
    [ testGroup "ipv4"
      [ testCase "A" testStreamA
      , testGroup "B"
        [ testCase "1MB" (testStreamB 1)
        , testCase "4MB" (testStreamB 4)
        , testCase "32MB" (testStreamB 32)
        ]
      ]
    ]
  ]

unhandled :: Exception e => IO (Either e a) -> IO a
unhandled action = action >>= either throwIO pure

unhandledClose :: Either SI.CloseException () -> a -> IO a
unhandledClose m a = case m of
  Right () -> pure a
  Left e -> throwIO e

data MagicByteMismatch = MagicByteMismatch
  deriving stock (Show,Eq)
  deriving anyclass (Exception)

data NegativeByteCount = NegativeByteCount
  deriving stock (Show,Eq)
  deriving anyclass (Exception)

testDatagramUndestinedA :: Assertion
testDatagramUndestinedA = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  (port,received) <- concurrently (sender m) (receiver m)
  received @=? DIU.Message (DIU.Endpoint IPv4.loopback port) message
  where
  message = E.fromList [0,1,2,3] :: ByteArray
  sz = PM.sizeofByteArray message
  sender :: PM.MVar RealWorld Word16 -> IO Word16
  sender m = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock srcPort -> do
    dstPort <- PM.takeMVar m
    unhandled $ DIU.send sock (DIU.Endpoint IPv4.loopback dstPort) message 0 sz
    pure srcPort
  receiver :: PM.MVar RealWorld Word16 -> IO DIU.Message
  receiver m = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock port -> do
    PM.putMVar m port
    unhandled $ DIU.receiveByteArray sock sz

testDatagramUndestinedB :: Assertion
testDatagramUndestinedB = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  (n :: PM.MVar RealWorld ()) <- PM.newEmptyMVar
  (port,received) <- concurrently (sender m n) (receiver m n)
  received @=?
    ( DIU.Message (DIU.Endpoint IPv4.loopback port) message1
    , DIU.Message (DIU.Endpoint IPv4.loopback port) message2
    )
  where
  message1 = E.fromList [0,1,2,3] :: ByteArray
  message2 = E.fromList [4,5,6,8,9,10] :: ByteArray
  sz1 = PM.sizeofByteArray message1
  sz2 = PM.sizeofByteArray message2
  sender :: PM.MVar RealWorld Word16 -> PM.MVar RealWorld () -> IO Word16
  sender m n = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock srcPort -> do
    dstPort <- PM.takeMVar m
    unhandled $ DIU.send sock (DIU.Endpoint IPv4.loopback dstPort) message1 0 sz1
    unhandled $ DIU.send sock (DIU.Endpoint IPv4.loopback dstPort) message2 0 sz2
    PM.putMVar n ()
    pure srcPort
  receiver :: PM.MVar RealWorld Word16 -> PM.MVar RealWorld () -> IO (DIU.Message,DIU.Message)
  receiver m n = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock port -> do
    PM.putMVar m port
    PM.takeMVar n
    msgs <- unhandled $ DIU.receiveMany sock 3 (max sz1 sz2)
    if PM.sizeofArray msgs == 2
      then pure (PM.indexArray msgs 0, PM.indexArray msgs 1)
      else fail "received a number of messages other than 2"
      
testDatagramUndestinedC :: Assertion
testDatagramUndestinedC = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  (n :: PM.MVar RealWorld ()) <- PM.newEmptyMVar
  (port,received) <- concurrently (sender m n) (receiver m n)
  received @=?
    ( DIU.Message (DIU.Endpoint IPv4.loopback port) message1
    , DIU.Message (DIU.Endpoint IPv4.loopback port) message2
    , DIU.Message (DIU.Endpoint IPv4.loopback port) message3
    )
  where
  message1 = E.fromList (enumFromTo 0 9):: ByteArray
  message2 = E.fromList (enumFromTo 10 10) :: ByteArray
  message3 = E.fromList (enumFromTo 11 25) :: ByteArray
  sz1 = PM.sizeofByteArray message1
  sz2 = PM.sizeofByteArray message2
  sz3 = PM.sizeofByteArray message3
  sender :: PM.MVar RealWorld Word16 -> PM.MVar RealWorld () -> IO Word16
  sender m n = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock srcPort -> do
    dstPort <- PM.takeMVar m
    unhandled $ DIU.send sock (DIU.Endpoint IPv4.loopback dstPort) message1 0 sz1
    unhandled $ DIU.send sock (DIU.Endpoint IPv4.loopback dstPort) message2 0 sz2
    unhandled $ DIU.send sock (DIU.Endpoint IPv4.loopback dstPort) message3 0 sz3
    PM.putMVar n ()
    pure srcPort
  receiver :: PM.MVar RealWorld Word16 -> PM.MVar RealWorld () -> IO (DIU.Message,DIU.Message,DIU.Message)
  receiver m n = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock port -> do
    PM.putMVar m port
    PM.takeMVar n
    msgsX <- unhandled $ DIU.receiveMany sock 2 (max sz1 sz2)
    (msg1,msg2) <- if PM.sizeofArray msgsX == 2
      then pure (PM.indexArray msgsX 0, PM.indexArray msgsX 1)
      else fail "received a number of messages other than 2"
    msgsY <- unhandled $ DIU.receiveMany sock 2 sz3
    msg3 <- if PM.sizeofArray msgsY == 1
      then pure (PM.indexArray msgsY 0)
      else fail "received a number of messages other than 2"
    pure (msg1,msg2,msg3)
      

-- This test involves a made up protocol that goes like this:
-- The sender always starts by sending the length of the rest
-- of the payload as a native-endian encoded machine-sized int.
-- (This could only ever work for a machine that is communicating
-- with itself). Then, it sends a bytearray of that specified
-- length. Then, both ends are expected to shutdown their sides
-- of the connection.
testStreamA :: Assertion
testStreamA = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  ((),received) <- concurrently (sender m) (receiver m)
  received @=? message
  where
  message = E.fromList (enumFromTo 0 (100 :: Word8)) :: ByteArray
  sz = PM.sizeofByteArray message
  szb = runST $ do
    marr <- PM.newByteArray (PM.sizeOf (undefined :: Int))
    PM.writeByteArray marr 0 sz
    PM.unsafeFreezeByteArray marr
  sender :: PM.MVar RealWorld Word16 -> IO ()
  sender m = do
    dstPort <- PM.takeMVar m
    unhandled $ SI.withConnection (DIU.Endpoint IPv4.loopback dstPort) unhandledClose $ \conn -> do
      unhandled $ SI.sendByteArray conn szb
      unhandled $ SI.sendByteArray conn message
  receiver :: PM.MVar RealWorld Word16 -> IO ByteArray
  receiver m = unhandled $ SI.withListener (SI.Endpoint IPv4.loopback 0) $ \listener port -> do
    PM.putMVar m port
    unhandled $ SI.withAccepted listener unhandledClose $ \conn _ -> do
      serializedSize <- unhandled $ SI.receiveByteArray conn (PM.sizeOf (undefined :: Int))
      let theSize = PM.indexByteArray serializedSize 0 :: Int
      result <- unhandled $ SI.receiveByteArray conn theSize
      pure result

-- The sender sends a large amount of traffic that may exceed
-- the size of the operating system's TCP send buffer. The 
-- amount is configurable because the test suite wants to
-- check this for several values.
testStreamB :: Int -> Assertion
testStreamB megabytes = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  ((),()) <- concurrently (sender m) (receiver m)
  pure ()
  where
  message = E.fromList (replicate (32 * 1024) magicByte) :: ByteArray
  chunkSize = PM.sizeofByteArray message
  sender :: PM.MVar RealWorld Word16 -> IO ()
  sender m = do
    dstPort <- PM.takeMVar m
    unhandled $ SI.withConnection (DIU.Endpoint IPv4.loopback dstPort) unhandledClose $ \conn -> do
      replicateM_ (32 * megabytes) $ unhandled $ SI.sendByteArray conn message
  receiver :: PM.MVar RealWorld Word16 -> IO ()
  receiver m = unhandled $ SI.withListener (SI.Endpoint IPv4.loopback 0) $ \listener port -> do
    PM.putMVar m port
    unhandled $ SI.withAccepted listener unhandledClose $ \conn _ -> do
      buffer <- PM.newByteArray chunkSize
      let receiveLoop !remaining
            | remaining > 0 = do
                let recvSize = min remaining chunkSize
                PM.setByteArray buffer 0 chunkSize (0 :: Word8)
                bytesReceived <- unhandled (SI.receiveBoundedMutableByteArraySlice conn recvSize buffer 0)
                verifyClientSendBytes buffer bytesReceived >>= \case
                  True -> receiveLoop (remaining - bytesReceived)
                  False -> throwIO MagicByteMismatch
            | remaining == 0 = pure ()
            | otherwise = throwIO NegativeByteCount
      receiveLoop (32 * megabytes * chunkSize)
      pure ()

magicByte :: Word8
magicByte = 0xFA

verifyClientSendBytes :: PM.MutableByteArray RealWorld -> Int -> IO Bool
verifyClientSendBytes arr len = go (len - 1)
  where
  go !ix = if ix >= 0
    then do
      w <- PM.readByteArray arr ix
      if w == magicByte then go (ix - 1) else pure False
    else pure True

-- Here, the sender spoofs its ip address and port.
testDatagramSpoofA :: Assertion
testDatagramSpoofA = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  ((),received) <- concurrently (sender m) (receiver m)
  received @=? DIU.Message
    (DIU.Endpoint (IPv4.fromOctets 8 7 6 5) 60000)
    payload
  where
  sz = 16
  payload = E.fromList (enumFromTo (0 :: Word8) (fromIntegral sz - 1))
  sender :: PM.MVar RealWorld Word16 -> IO ()
  sender m = unhandled $ DIS.withSocket $ \sock -> do
    dstPort <- PM.takeMVar m
    marr <- PM.newByteArray sz
    PM.copyByteArray marr 0 payload 0 sz
    unhandled $ DIS.sendMutableByteArray sock
      (DIU.Endpoint (IPv4.fromOctets 8 7 6 5) 60000)
      (DIU.Endpoint IPv4.loopback dstPort)
      marr 0 sz
  receiver :: PM.MVar RealWorld Word16 -> IO DIU.Message
  receiver m = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock port -> do
    PM.putMVar m port
    unhandled $ DIU.receiveByteArray sock 500
      
-- Here, the sender spoofs its ip address and port twice, picking a
-- different port each time.
testDatagramSpoofB :: Assertion
testDatagramSpoofB = do
  (m :: PM.MVar RealWorld Word16) <- PM.newEmptyMVar
  ((),received) <- concurrently (sender m) (receiver m)
  received @=?
    ( DIU.Message
      (DIU.Endpoint (IPv4.fromOctets 8 7 6 5) 60000)
      payloadA
    , DIU.Message
      (DIU.Endpoint (IPv4.fromOctets 9 8 7 6) 59999)
      payloadB
    )
  where
  sz = 16
  payloadA = E.fromList (enumFromTo (1 :: Word8) (fromIntegral sz))
  payloadB = E.fromList (enumFromTo (2 :: Word8) (fromIntegral sz + 1))
  sender :: PM.MVar RealWorld Word16 -> IO ()
  sender m = unhandled $ DIS.withSocket $ \sock -> do
    dstPort <- PM.takeMVar m
    marrA <- PM.newByteArray sz
    marrB <- PM.newByteArray sz
    PM.copyByteArray marrA 0 payloadA 0 sz
    PM.copyByteArray marrB 0 payloadB 0 sz
    unhandled $ DIS.sendMutableByteArray sock
      (DIU.Endpoint (IPv4.fromOctets 8 7 6 5) 60000)
      (DIU.Endpoint IPv4.loopback dstPort)
      marrA 0 sz
    unhandled $ DIS.sendMutableByteArray sock
      (DIU.Endpoint (IPv4.fromOctets 9 8 7 6) 59999)
      (DIU.Endpoint IPv4.loopback dstPort)
      marrB 0 sz
  receiver :: PM.MVar RealWorld Word16 -> IO (DIU.Message,DIU.Message)
  receiver m = unhandled $ DIU.withSocket (DIU.Endpoint IPv4.loopback 0) $ \sock port -> do
    PM.putMVar m port
    msg1 <- unhandled $ DIU.receiveByteArray sock 500
    msg2 <- unhandled $ DIU.receiveByteArray sock 500
    return (msg1,msg2)