-- | The shard logic: connecting a single shard to the Discord gateway and running its event loop.
{-# LANGUAGE RecursiveDo #-}
{-# LANGUAGE TemplateHaskell #-}

module Calamity.Gateway.Shard
    ( Shard(..)
    , newShard ) where

import           Calamity.Gateway.DispatchEvents
import           Calamity.Gateway.Types
import           Calamity.Internal.Utils
import           Calamity.LogEff
import           Calamity.Types.Token

import           Control.Concurrent
import           Control.Concurrent.Async
import           Control.Concurrent.STM
import           Control.Concurrent.STM.TBMQueue
import           Control.Exception
import           Control.Lens
import           Control.Monad
import           Control.Monad.State.Lazy

import qualified Data.Aeson                      as A
import           Data.Functor
import           Data.Maybe
import           Data.Text.Lazy                  ( Text, stripPrefix )
import           Data.Text.Lazy.Lens
import           Data.Void

import           DiPolysemy                      hiding ( debug, error, info )

import           Fmt

import           Network.WebSockets              ( Connection, ConnectionException(..), receiveData, sendCloseCode
                                                 , sendTextData )

import           Polysemy                        ( Sem )
import qualified Polysemy                        as P
import qualified Polysemy.Async                  as P
import qualified Polysemy.AtomicState            as P
import qualified Polysemy.Error                  as P
import qualified Polysemy.Resource               as P

import           Prelude                         hiding ( error )

import           Wuss

-- | Effect for running an action against a websocket connection.
data Websocket m a where
  -- | @RunWebsocket host path action@: connect to @host@ at @path@ and run
  -- @action@ with the resulting 'Connection'. 'websocketToIO' interprets this
  -- with a secure (TLS) client on port 443.
  RunWebsocket :: Text -> Text -> (Connection -> m a) -> Websocket m a

-- Generates the 'runWebsocket' action used to send this effect.
P.makeSem ''Websocket

-- | Interprets the 'Websocket' effect in 'IO'.
--
-- @RunWebsocket host path action@ connects with 'runSecureClient' to @host@
-- on port 443 at @path@ and runs @action@ with the resulting 'Connection'.
-- Any 'Websocket' effects used inside @action@ are interpreted recursively.
--
-- The @finish@ callback handed out by 'P.withLowerToIO' is attached with
-- 'finally', so it is signalled on every exit path. That includes the case
-- where 'runSecureClient' throws before or during the callback, e.g. when the
-- connection or TLS handshake fails.
-- NOTE(review): polysemy's 'P.withLowerToIO' waits for this signal (or for
-- normal completion of the IO action); without it, an exception here could
-- leave the caller blocked rather than rethrown. Confirm against the polysemy
-- version in use.
websocketToIO :: forall r a. P.Member (P.Embed IO) r => Sem (Websocket ': r) a -> Sem r a
websocketToIO = P.interpretH
  (\case
     RunWebsocket host path a -> do
       istate <- P.getInitialStateT
       ma <- P.bindT a

       P.withLowerToIO $ \lower finish -> do
         let done :: Sem (Websocket ': r) x -> IO x
             done = lower . P.raise . websocketToIO

         runSecureClient (host ^. unpacked) 443 (path ^. unpacked)
           (\x -> done (ma $ istate $> x))
           `finally` finish)

-- | The initial state of a freshly created shard.
--
-- Only the 'Shard' itself is filled in. Every optional field starts as
-- 'Nothing' and the boolean field starts as 'False'.
newShardState :: Shard -> ShardState
newShardState shard =
  ShardState
    shard
    Nothing
    Nothing
    False
    Nothing
    Nothing
    Nothing

-- | Creates and launches a shard.
--
-- Allocates the shard's command queue and state 'TVar' and builds the
-- 'Shard' handle. It then starts 'shardLoop' on a new 'P.async' thread, with
-- its 'P.AtomicState' backed by that 'TVar'. The loop is wrapped with a
-- @shard-id@ logging attribute and @push "calamity-shard"@.
--
-- Returns the 'Shard' handle together with the thread running it.
newShard :: P.Members '[LogEff, P.Embed IO, P.Final IO, P.Async] r
         => Text
         -> Int
         -> Int
         -> Token
         -> TQueue DispatchMessage
         -> Sem r (Shard, Async (Maybe ()))
newShard gateway sid count token evtQueue = do
  -- The shard holds its state 'TVar', and the initial state holds the shard,
  -- so the two are tied together with 'mdo'.
  (shard, stateVar) <- P.embed $ mdo
    cmds <- newTQueueIO
    var  <- newTVarIO (newShardState shard)
    let shard = Shard sid count gateway evtQueue cmds var (rawToken token)
    pure (shard, var)

  let loop   = P.runAtomicStateTVar stateVar shardLoop
      tagged = attr "shard-id" sid . push "calamity-shard" $ loop

  thread <- P.async tagged
  pure (shard, thread)

-- | Encodes a message as JSON and sends it over the shard's current websocket
-- connection.
--
-- If no connection is stored in the shard state, the message is dropped and
-- a debug line is logged instead.
-- NOTE(review): the debug line renders the whole message with 'Show'; confirm
-- this does not leak the token for Identify/Resume messages.
sendToWs :: ShardC r => SentDiscordMessage -> Sem r ()
sendToWs data' = P.atomicGets wsConn >>= \case
  Nothing   -> debug "tried to send to closed WS"
  Just conn -> do
    let payload = A.encode data'
    debug $ "sending " +|| data' ||+ " encoded to " +|| payload ||+ " to gateway"
    P.embed $ sendTextData conn payload

-- | Extracts the value from an 'Either' whose 'Right' side is uninhabited.
--
-- A 'Right' can never be constructed (there is no value of type 'Void'), so
-- that case is discharged with 'absurd'.
fromEitherVoid :: Either a Void -> a
fromEitherVoid = either id absurd

-- | Runs a websocket action, turning a close request from the peer into a
-- 'ControlMessage' that decides whether the shard restarts or shuts down.
--
-- A 'CloseRequest' with code 1000, 4004, 4010 or 4011 yields
-- @Left ShutDownShard@; any other close code yields @Left RestartShard@.
-- The close request is printed to stdout. Every other 'ConnectionException'
-- is rethrown.
checkWSClose :: IO a -> IO (Either ControlMessage a)
checkWSClose m = catch (Right <$> m) onClose
  where
    fatalCodes = [1000, 4004, 4010, 4011]

    onClose e@(CloseRequest code _) = do
      print e
      pure . Left $
        if code `elem` fatalCodes then ShutDownShard else RestartShard
    onClose e = throwIO e

-- | Writes a value to a 'TBMQueue', blocking (via 'retry') while the queue is
-- full.
--
-- Returns 'True' once the value has been written, or 'False' without writing
-- if the queue has been closed.
tryWriteTBMQueue' :: TBMQueue a -> a -> STM Bool
tryWriteTBMQueue' queue item = tryWriteTBMQueue queue item >>= \case
  Nothing    -> pure False
  Just True  -> pure True
  Just False -> retry

-- | The loop a shard will run on.
--
-- Runs @outerloop@ until it finishes (the shard was told to shut down), then
-- logs that the shard has shut down.
shardLoop :: ShardC r => Sem r ()
shardLoop = do
  void outerloop
  debug "Shard shut down"
 where
  -- | Forwards commands from the shard's command queue into @outqueue@ until
  -- @outqueue@ is closed.
  --
  -- The closed check, the read from the command queue and the write into
  -- @outqueue@ happen in a single STM transaction. A command is therefore
  -- only removed from the command queue when it is actually delivered.
  -- Otherwise, a forwarder left over from a previous connection could stay
  -- blocked on the command queue and swallow a command meant for the new
  -- connection. Because the transaction also observes the closed flag, a
  -- stale forwarder exits as soon as its @outqueue@ is closed.
  controlStream :: Shard -> TBMQueue ShardMsg -> IO ()
  controlStream shard outqueue = inner
    where
      q = shard ^. #cmdQueue
      inner = do
        delivered <- atomically $ do
          closed <- isClosedTBMQueue outqueue
          if closed
            then pure False
            else do
              v <- readTQueue q
              tryWriteTBMQueue' outqueue (Control v)
        when delivered inner

  -- | Reads frames from the websocket and pushes them into @outqueue@.
  --
  -- A close request from the peer becomes a 'Control' message (see
  -- 'checkWSClose') and ends this loop. Frames that fail to decode are
  -- logged and skipped. The loop also ends once @outqueue@ is closed.
  discordStream :: P.Members '[LogEff, P.Embed IO] r => Connection -> TBMQueue ShardMsg -> Sem r ()
  discordStream ws outqueue = inner
    where inner = do
            msg <- P.embed . checkWSClose $ receiveData ws

            -- trace $ "Received from stream: "+||msg||+""

            case msg of
              Left c ->
                P.embed . atomically $ writeTBMQueue outqueue (Control c)

              Right msg' -> do
                let decoded = A.eitherDecode msg'
                r <- case decoded of
                  Right a ->
                    P.embed . atomically $ tryWriteTBMQueue' outqueue (Discord a)
                  Left e -> do
                    error $ "Failed to decode: "+|e|+""
                    pure True
                when r inner

  -- | The outer loop: opens the websocket connection and runs @innerloop@ on
  -- it, reconnecting whenever the inner loop asks for a restart.
  --
  -- Currently, when this takes the error path (a shutdown), we exit the
  -- 'forever' loop and the shard stops. We might want extra logic here to
  -- reboot the shard, or to force a resharding.
  outerloop :: ShardC r => Sem r (Either ShardException ())
  outerloop = P.runError . forever $ do
    shard :: Shard <- P.atomicGets (^. #shardS)
    let host = shard ^. #gateway
    let host' = fromMaybe host $ stripPrefix "wss://" host
    info $ "starting up shard "+| (shard ^. #shardID) |+" of "+| (shard ^. #shardCount) |+""

    innerLoopVal <- websocketToIO $ runWebsocket host' "/?v=7&encoding=json" innerloop

    case innerLoopVal of
      ShardExcShutDown -> do
        info "Shutting down shard"
        P.throw ShardExcShutDown

      ShardExcRestart ->
        -- we restart normally when we loop
        info "Restarting shard"

  -- | The inner loop: runs for the lifetime of one websocket connection.
  --
  -- It first resumes (when a session id and sequence number are known) or
  -- identifies. It then handles messages from Discord and from the shard's
  -- command queue until a handler throws a 'ShardException'; that exception
  -- is returned.
  innerloop :: ShardC r => Connection -> Sem r ShardException
  innerloop ws = do
    debug "Entering inner loop of shard"

    shard <- P.atomicGets (^. #shardS)
    P.atomicModify (#wsConn ?~ ws)

    seqNum'    <- P.atomicGets (^. #seqNum)
    sessionID' <- P.atomicGets (^. #sessionID)

    case (seqNum', sessionID') of
      (Just n, Just s) -> do
        debug $ "Resuming shard (sessionID: "+|s|+", seq: "+|n|+")"
        sendToWs (Resume ResumeData
                  { token = shard ^. #token
                  , sessionID = s
                  , seq = n
                  })
      _ -> do
        debug "Identifying shard"
        sendToWs (Identify IdentifyData
                  { token = shard ^. #token
                  , properties = IdentifyProps
                                 { browser = "Calamity: https://github.com/nitros12/calamity"
                                 , device = "Calamity: https://github.com/nitros12/calamity"
                                 }
                  , compress = False
                  , largeThreshold = 250
                  , shard = (shard ^. #shardID,
                             shard ^. #shardCount)
                  , presence = Nothing
                  })

    -- Both producer threads write into this queue. Closing it in the release
    -- action makes each of them stop once it next touches the queue.
    result <- P.runResource $ P.bracket (P.embed $ newTBMQueueIO 1)
      (\q -> P.embed . atomically $ closeTBMQueue q)
      (\q -> do
        debug "handling events now"
        _controlThread <- P.async . P.embed $ controlStream shard q
        _discordThread <- P.async $ discordStream ws q
        (fromEitherVoid <$>) . P.raise . P.runError . forever $ do
          -- Only we close the queue, and only after this loop has exited (in
          -- the bracket's release action), so a closed queue should never be
          -- observed here. If it ever is, restart the shard rather than crash
          -- on a partial pattern.
          msg <- P.embed . atomically $ readTBMQueue q
          maybe (P.throw ShardExcRestart) handleMsg msg)

    debug "Exiting inner loop of shard"

    P.atomicModify (#wsConn .~ Nothing)
    haltHeartBeat
    pure result

  -- | Handles a single message, from either Discord or the shard's command
  -- queue. Throwing a 'ShardException' ends the inner loop with that value.
  handleMsg :: (ShardC r, P.Member (P.Error ShardException) r) => ShardMsg -> Sem r ()
  handleMsg (Discord msg) = case msg of
    Dispatch sn data' -> do
      -- trace $ "Handling event: ("+||data'||+")"
      P.atomicModify (#seqNum ?~ sn)

      case data' of
        Ready rdata' ->
          P.atomicModify (#sessionID ?~ (rdata' ^. #sessionID))

        _ -> pure ()

      shard <- P.atomicGets (^. #shardS)
      P.embed . atomically $ writeTQueue (shard ^. #evtQueue) (DispatchData' data')
      -- sn' <- P.atomicGets (^. #seqNum)
      -- trace $ "Done handling event, seq is now: "+||sn'||+""

    HeartBeatReq -> do
      debug "Received heartbeat request"
      sendHeartBeat

    Reconnect -> do
      debug "Being asked to restart by Discord"
      P.throw ShardExcRestart

    -- @resumable@ says whether the session can still be resumed. A resumable
    -- session keeps its session id and sequence number, so the reconnect
    -- sends Resume. Otherwise both are cleared, so the reconnect identifies
    -- from scratch.
    InvalidSession resumable -> do
      if resumable
        then info "Received resumable invalid session"
        else do
          info "Received non-resumable invalid session, sleeping for 15 seconds then retrying"
          P.atomicModify (#sessionID .~ Nothing)
          P.atomicModify (#seqNum .~ Nothing)
          P.embed $ threadDelay (15 * 1000 * 1000)
      P.throw ShardExcRestart

    Hello interval -> do
      info $ "Received hello, beginning to heartbeat at an interval of "+|interval|+"ms"
      startHeartBeatLoop interval

    HeartBeatAck -> do
      debug "Received heartbeat ack"
      P.atomicModify (#hbResponse .~ True)

  handleMsg (Control msg) = case msg of
    SendPresence data' -> do
      debug $ "Sending presence: ("+||data'||+")"
      sendToWs $ StatusUpdate data'

    RestartShard       -> P.throw ShardExcRestart
    ShutDownShard      -> P.throw ShardExcShutDown

-- | Starts a heartbeat thread that beats every @interval@ milliseconds.
--
-- Any heartbeat thread that is already running is stopped first, and the new
-- thread's handle is stored in the shard state.
startHeartBeatLoop :: ShardC r => Int -> Sem r ()
startHeartBeatLoop interval = do
  haltHeartBeat
  P.async (heartBeatLoop interval) >>= \hb ->
    P.atomicModify (#hbThread ?~ hb)

-- | Stops the running heartbeat thread, if there is one.
--
-- The thread handle is taken out of the shard state and replaced with
-- 'Nothing' in a single atomic state update. The thread is then cancelled.
haltHeartBeat :: ShardC r => Sem r ()
haltHeartBeat = do
  running <- P.atomicState @ShardState $ \s ->
    let (old, s') = runState (use #hbThread <* (#hbThread .= Nothing)) s
    in (s', old)
  forM_ running $ \t -> do
    debug "Stopping heartbeat thread"
    P.embed (void $ cancel t)

-- | Sends a heartbeat carrying the last seen sequence number, and clears the
-- @hbResponse@ flag. That flag is set again when a heartbeat ack is handled.
--
-- The flag is cleared /before/ the heartbeat is written. Acks are handled by
-- the shard's message loop, while this is also called from the separate
-- heartbeat thread. Clearing the flag after sending could overwrite an ack
-- processed in between, causing a spurious "no heartbeat response" restart.
sendHeartBeat :: ShardC r => Sem r ()
sendHeartBeat = do
  sn <- P.atomicGets (^. #seqNum)
  debug $ "Sending heartbeat (seq: " +|| sn ||+ ")"
  P.atomicModify (#hbResponse .~ False)
  sendToWs $ HeartBeat sn

-- | Sends a heartbeat every @interval@ milliseconds.
--
-- After each delay, if no heartbeat ack has been recorded since the last
-- heartbeat, the websocket is closed with code 4000 and the loop stops.
-- The connection may already have been removed from the shard state: the
-- inner loop clears it before halting this thread. In that case there is
-- nothing to close, and the loop simply stops.
heartBeatLoop :: ShardC r => Int -> Sem r ()
heartBeatLoop interval = void . P.runError . forever $ do
  sendHeartBeat
  P.embed . threadDelay $ interval * 1000
  unlessM (P.atomicGets (^. #hbResponse)) $ do
    debug "No heartbeat response, restarting shard"
    mConn <- P.atomicGets (^. #wsConn)
    forM_ mConn $ \conn ->
      P.embed $ sendCloseCode conn 4000 ("No heartbeat in time" :: Text)
    -- escape the 'forever' loop; the error value is discarded by 'void'
    P.throw ()