{-# LANGUAGE GADTs               #-}
{-# LANGUAGE NamedFieldPuns      #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}


module Driver (Driver.init, Driver.allConnections, Driver.RetryInterval(..)) where


import           Codec
import           Common
import           Control.Applicative
import           Control.Concurrent
import           Control.Concurrent.MVar
import           Control.Concurrent.STM
import           Control.Concurrent.STM.TBQueue
import           Control.Exception.Safe         (bracket, catchAny, mask)
import           Control.Monad                  (forM_, forever, replicateM)
import           Control.Monad.Except
import           Data.Binary
import           Data.Binary.Get
import           Data.Binary.IEEE754
import           Data.Binary.Put
import           Data.Bits
import           Data.ByteString
import qualified Data.ByteString.Char8          as C8
import qualified Data.ByteString.Lazy           as DBL
import           Data.Either
import           Data.Int
import qualified Data.IntMap                    as IM
import           Data.List
import qualified Data.Map.Strict                as DMS
import           Data.Maybe
import           Data.Monoid                    as DM
import           Data.Set
import           Data.Traversable
import           Data.UUID
import           Debug.Trace
import           Encoding
import           GHC.Generics                   (Generic)
import           GHC.IO.Handle                  (Handle, hClose, hFlush)
import           Network
import           Numeric                        (showHex)
import           System.IO.Unsafe


-- | Handshake a freshly opened connection.
--
-- Writes a STARTUP frame: header bytes 4, 0, 0, 1, 1, then the Int32 body
-- length, then a string-map body requesting @CQL_VERSION@ @3.0.0@. It then
-- reads the 9-byte reply header and discards the reply body. The reply is not
-- inspected, so a server-side error is not reported here.
startUp :: Handle -> IO ()
startUp h = do
  hPut h frame
  replyHeader <- hGet h 9
  let hdr = runGet (get :: Get Header) (DBL.fromStrict replyHeader)
  _ <- hGet h (fromIntegral (len hdr))
  return ()
  where
    options = encode $ StringMap (DMS.fromList [(ShortStr "CQL_VERSION", ShortStr "3.0.0")])
    frame = DBL.toStrict (DBL.pack [4, 0, 0, 1, 1] <> encode (fromIntegral (DBL.length options) :: Int32) <> options)


-- | Subscribe the connection to @TOPOLOGY_CHANGE@ and @STATUS_CHANGE@ events.
--
-- Sends a frame with header bytes 4, 0, 0, 1, 11, then the Int32 body length
-- and the encoded event list. It then reads the 9-byte reply header and
-- discards the reply body without inspecting it. Currently unused: the call in
-- 'init' is commented out.
registerForEvents :: Handle -> IO ()
registerForEvents h = hPut h frame >> skipReply
  where
    eventTypes = encode $ CQLList [CQLString "TOPOLOGY_CHANGE", CQLString "STATUS_CHANGE"]
    frame = DBL.toStrict (DBL.pack [4, 0, 0, 1, 11] <> encode (fromIntegral (DBL.length eventTypes) :: Int32) <> eventTypes)
    skipReply = do
      rawHeader <- hGet h 9
      let hdr = runGet (get :: Get Header) (DBL.fromStrict rawHeader)
      _ <- hGet h (fromIntegral (len hdr))
      return ()


{-# NOINLINE prepareQueries #-}
-- | Global set of every request payload that 'sendThread' has broadcast with
-- opcode 9 (prepare). It is created once through 'unsafePerformIO'. The
-- NOINLINE pragma keeps GHC from inlining the definition and so creating more
-- than one MVar.
prepareQueries :: MVar (Set LongStr)
prepareQueries = unsafePerformIO (newMVar Data.Set.empty)


-- | Fork the reader thread for one node connection.
--
-- The thread loops forever reading frames from @h@: a 9-byte header, then a
-- body of @len@ bytes. Each response is handed to the MVar that 'writeQuery'
-- registered for its stream id, looked up via 'getResHolder':
--
-- * opcode 0: @Left@ with the server's error message;
-- * opcode 8: result kind 2 gives 'RRows' with the decoded rows, and kind 4
--   gives 'RPrepared'. Any other kind (1, 3, 5, ...) gives @RRows []@;
-- * any other opcode on a non-negative stream id: a @Left@ error, so the
--   caller does not block forever;
-- * any other opcode on a negative stream id is ignored. 'writeQuery' never
--   allocates such ids, so nobody is waiting for those frames.
--
-- If anything throws (e.g. the connection drops), it prints a diagnostic.
-- Every request still pending on this connection then gets a
-- "connection to node ... was lost" error and the thread exits. Reconnection
-- is done by 'sendThread'. @port@ and @ri@ are currently unused.
receiveThread :: HostName -> PortID -> RetryInterval -> Handle -> MVar [Int16] -> MVar(IM.IntMap (MVar (Either ShortStr Result))) -> IO ()
receiveThread host port ri h streams streamMap = do
  _ <- forkIO $ forever receiveFrame `catchAny` \_ -> failPending
  return ()
  where
    -- Read one frame and dispatch it on its opcode.
    receiveFrame = do
      header <- hGet h 9
      let he = runGet (get :: Get Header) $ DBL.fromStrict header
      p <- hGet h (fromIntegral $ len he)
      case opcode he of
        0 -> do
          let (_erCode, erMsg) = runGet getErr (DBL.fromStrict p)
          deliver he (Left erMsg)
        8 ->
          -- '$!' decodes the result kind here, in the receive thread; the
          -- rows / prepared payload itself stays lazy for the caller.
          deliver he $! decodeResult p
        _ | stream he >= 0 ->
              -- Unexpected response to one of our requests: fail it instead
              -- of leaving the caller blocked on its MVar.
              deliver he (Left $ ShortStr "unexpected response opcode from node")
          | otherwise ->
              return ()

    -- Map a RESULT body to the caller's value, keyed on its leading Int32
    -- result kind. The catch-all matters: listing only 1 and 5 would let any
    -- other kind (e.g. 3) crash the receive loop.
    decodeResult body = case runGet (get :: Get Int32) (DBL.fromStrict body) of
      2 -> Right $ RRows $ content $ runGet getRows (DBL.fromStrict $ C8.drop 4 body)
      4 -> Right $ RPrepared $ runGet (get :: Get ShortBytes) (DBL.fromStrict $ C8.drop 4 body)
      _ -> Right $ RRows []

    -- Release the frame's stream id and complete the waiting request.
    deliver he result = do
      waiting <- getResHolder he streams streamMap
      -- sendThread's failure handler may already have filled this MVar with
      -- a "connection ... was lost" error without removing it from the map.
      -- A plain putMVar would then block this thread forever.
      _ <- tryPutMVar waiting result
      return ()

    -- Fail every request still registered on this connection.
    failPending = do
      print "error in receiving"
      withMVar streamMap $ \pending ->
        forM_ pending $ \waiting ->
          tryPutMVar waiting (Left $ ShortStr $ DBL.fromStrict $ C8.pack $ "connection to node " <> show host <> " was lost")


-- | Claim the response slot for the stream id in a frame header.
--
-- Removes the MVar registered under @stream he1@ from @streamMap@. It then
-- appends that stream id to the back of the free list @streams@ (FIFO reuse)
-- and returns the MVar so the caller can deliver the response.
--
-- Throws an 'IOError' if no request is registered under that stream id. In
-- that case neither the map nor the free list is modified, so an unknown id is
-- never added to the pool. Both MVars are updated with 'modifyMVar' /
-- 'modifyMVar_', so an exception cannot leave either of them empty.
getResHolder :: Header -> MVar [Int16] -> MVar(IM.IntMap (MVar (Either ShortStr Result))) -> IO (MVar (Either ShortStr Result))
getResHolder he1 streams streamMap = do
  let strm = stream he1
      key = fromIntegral strm :: Int
  claimed <- modifyMVar streamMap $ \m ->
    return (IM.delete key m, IM.lookup key m)
  case claimed of
    Nothing ->
      ioError $ userError $ "getResHolder: no pending request for stream " <> show strm
    Just mvar -> do
      modifyMVar_ streams (return . (++ [strm]))
      return mvar


-- | Send one request frame on @h@.
--
-- Takes the first free stream id from @streams@ and registers @mvar@ under it
-- in @streamMap@, so the receive side can deliver the response. It then writes
-- the frame: bytes 4 and 0, the 2-byte stream id, opcode @opc@, and the
-- encoded @bs@.
--
-- When no stream id is free, it sleeps briefly and retries rather than
-- busy-spinning. Ids return to the pool when responses arrive (see
-- 'getResHolder'). Both MVars are updated with 'modifyMVar' / 'modifyMVar_',
-- so an exception cannot leave either of them empty. If 'hPut' throws, @mvar@
-- stays registered so a connection-failure handler can complete it.
writeQuery :: Handle -> MVar [Int16] -> MVar (IM.IntMap (MVar (Either ShortStr Result))) -> (LongStr, Word8, MVar (Either ShortStr Result)) -> IO ()
writeQuery h streams streamMap (bs, opc, mvar) = do
  claimed <- modifyMVar streams $ \strs ->
    return $ case Data.List.uncons strs of
      Just (i, strsTail) -> (strsTail, Just i)
      Nothing            -> (strs, Nothing)
  case claimed of
    Nothing -> do
      -- Every stream id is in flight; back off 100 microseconds so the
      -- receive thread can recycle one, instead of spinning on the MVar.
      threadDelay 100
      writeQuery h streams streamMap (bs, opc, mvar)
    Just i -> do
      -- Strict insert so the map never accumulates thunks.
      modifyMVar_ streamMap $ \m -> return $! IM.insert (fromIntegral i :: Int) mvar m
      hPut h $ Data.ByteString.pack [4, 0] <> DBL.toStrict (encode i <> DBL.pack [opc] <> encode bs)


-- | Fork the writer thread for one node connection.
--
-- The thread takes requests off the shared @driverQ@ and writes them to @h@
-- with 'writeQuery'. Requests with opcode 9 (prepare) are handled differently:
--
-- * they are recorded in 'prepareQueries';
-- * they are sent to every connection in 'allConnections';
-- * the caller's MVar receives the last error reported by any node, or the
--   first node's result if every node succeeded.
--
-- A write failure to one peer during such a broadcast only fails that peer's
-- reply; it does not tear down this connection.
--
-- If writing on this connection fails:
--
-- * every request still pending on it gets a "connection to node ... was
--   lost" error;
-- * the old handle is closed in the background;
-- * the node is reconnected with 'setupNode', which retries every @ri@;
-- * its entry in 'allConnections' is replaced;
-- * fresh send/receive threads are started.
--
-- Reconnection happens after the exception handler returns, not inside it.
sendThread :: HostName -> PortID -> RetryInterval -> Handle -> MVar [Int16] -> MVar(IM.IntMap (MVar (Either ShortStr Result))) -> TBQueue (LongStr, Word8, MVar (Either ShortStr Result)) -> IO ()
sendThread host port ri h streams streamMap driverQ = do
  _ <- forkIO $ do
    forever sendNext `catchAny` \_ -> failPending
    reconnect
  return ()
  where
    -- Take one request off the queue and send it.
    sendNext = do
      (bs, opc, mvar) <- atomically $ readTBQueue driverQ
      if opc == 9
        then broadcastPrepare bs opc mvar
        else writeQuery h streams streamMap (bs, opc, mvar)

    -- Send a prepare request to every known connection and combine replies.
    broadcastPrepare bs opc mvar = do
      modifyMVar_ prepareQueries (return . Data.Set.insert bs)
      cons <- readMVar allConnections
      replies <- forM cons $ \(node, _, h', str, strMap) -> do
        mv <- newEmptyMVar
        -- A dead peer must only fail its own reply, not this connection.
        writeQuery h' str strMap (bs, opc, mv) `catchAny` \_ -> do
          _ <- tryPutMVar mv (Left $ lostMessage node)
          return ()
        return mv
      _ <- forkIO $ do
        results <- mapM takeMVar replies
        putMVar mvar $ case results of
          [] -> Left $ ShortStr "no connections available to prepare the query on"
          (r : rs) -> Data.List.foldl' (\b a -> if isLeft a then a else b) r rs
      return ()

    -- Fail every request still registered on this connection.
    failPending =
      withMVar streamMap $ \pending ->
        forM_ pending $ \waiting -> tryPutMVar waiting (Left $ lostMessage host)

    lostMessage node = ShortStr $ DBL.fromStrict $ C8.pack $ "connection to node " <> show node <> " was lost"

    -- Replace this connection with a fresh one and restart its threads.
    reconnect = do
      -- hClose may have to wait while the receive thread is blocked reading
      -- this handle. Close it in the background so reconnecting is not
      -- stalled, while still releasing the dead socket.
      _ <- forkIO $ hClose h `catchAny` \_ -> return ()
      conn@(host', port', h', streams', streamMap') <- setupNode host port ri
      -- Point 'allConnections' at the new connection so later prepare
      -- broadcasts reach it instead of the dead handle.
      modifyMVar_ allConnections $
        return . fmap (\c@(_, _, old, _, _) -> if old == h then conn else c)
      sendThread host' port' ri h' streams' streamMap' driverQ
      receiveThread host' port' ri h' streams' streamMap'


-- | Ask the connected node for its peers.
--
-- Sends a QUERY frame (opcode 7) on stream 0 for
-- @select peer from system.peers@, with consistency LOCAL_ONE and flags 0x00.
-- It then reads the reply frame synchronously. The result is:
--
-- * @Left@ with the server's message when the reply opcode is 0;
-- * @Right@ rows when the reply is result kind 2;
-- * @Left@ with an "unsupported version" message for any other reply.
getPeers :: Handle -> IO (Either ShortStr [Row])
getPeers h = do
  hPut h request
  rawHeader <- hGet h 9
  let hdr = runGet (get :: Get Header) (DBL.fromStrict rawHeader)
  body <- hGet h (fromIntegral (len hdr))
  if opcode hdr == 0
    then return $ Left $ snd $ runGet getErr (DBL.fromStrict body)
    else return $! decodePeers body
  where
    cql = "select peer from system.peers"
    request =
      let queryLen = fromIntegral (DBL.length cql) :: Int32
          payload = LongStr $ encode queryLen <> cql <> encode LOCAL_ONE <> encode (0x00 :: Int8)
      in Data.ByteString.pack [4, 0] <> DBL.toStrict (encode (0 :: Int16) <> DBL.pack [7] <> encode payload)
    -- The Int32 result kind is decoded eagerly (via '$!' above), as before.
    decodePeers body = case runGet (get :: Get Int32) (DBL.fromStrict body) of
      2 -> Right $ content $ runGet getRows (DBL.fromStrict $ C8.drop 4 body)
      _ -> Left $ ShortStr "Could not get list of peers. Please check if this cassandra version is supported."


{-# NOINLINE allConnections #-}
-- | Process-wide list of node connections. Each entry holds the host, port,
-- socket handle, free stream-id pool and pending-request map for one node.
-- It starts empty, is filled by 'init', and is read by 'sendThread' to
-- broadcast prepare requests. NOINLINE keeps GHC from inlining (and thereby
-- duplicating) the 'unsafePerformIO' below.
allConnections :: MVar [(HostName, PortID, Handle, MVar [Int16], MVar(IM.IntMap (MVar (Either ShortStr Result))))]
allConnections = unsafePerformIO (newMVar mempty)


-- | Connect to @peer@:@port@, run the STARTUP handshake, and create a fresh
-- stream-id pool (0..32767) and an empty pending-request map for the node.
--
-- Any synchronous failure (connect or handshake) is retried after
-- @threadDelay ri@, indefinitely. The function therefore only returns once the
-- node accepts a connection. A socket whose handshake fails is closed before
-- retrying. The retry is a plain loop, not a call from inside the exception
-- handler, so repeated failures do not nest handlers.
setupNode :: HostName -> PortID -> RetryInterval -> IO (HostName, PortID, Handle, MVar [Int16], MVar (IM.IntMap (MVar (Either ShortStr Result))))
setupNode peer port rInt@(RetryInterval ri) = do
  attempt <- connectOnce `catchAny` \_ -> return Nothing
  case attempt of
    Just conn -> return conn
    Nothing -> do
      threadDelay ri
      setupNode peer port rInt
  where
    -- One connection attempt; Nothing means "sleep and try again".
    connectOnce = do
      streams <- newMVar ([0..32767] :: [Int16])
      streamMap <- newMVar (IM.empty :: (IM.IntMap (MVar (Either ShortStr Result))))
      h <- connectTo peer port
      ready <- (startUp h >> return True) `catchAny` \_ -> return False
      if ready
        then return $ Just (peer, port, h, streams, streamMap)
        else do
          hClose h `catchAny` \_ -> return ()
          return Nothing



-- | Delay between attempts to (re)connect to a node. 'setupNode' passes the
-- wrapped 'Int' directly to 'threadDelay', so it is in microseconds.
newtype RetryInterval = RetryInterval Int


-- | The first function you need to call. It initializes the driver and connects to the cluster.
-- You only need to specify one node from your cluster here. The other nodes
-- are read from @system.peers@ on that node, and the driver connects to them
-- as well.
-- The 'RetryInterval' is the delay (passed to 'threadDelay', i.e.
-- microseconds) between attempts to reconnect to a node after a
-- disconnection.
--
-- Fails with the server's error message if the peer list cannot be read; the
-- connection to the seed node is closed in that case. Connecting to each
-- discovered peer retries until it succeeds, so this blocks while any listed
-- peer is unreachable.
init :: HostName -> PortID -> RetryInterval -> ExceptT ShortStr IO Candle
init host port ri = do
    streams <- liftIO $ newMVar ([0..32767] :: [Int16])
    streamMap <- liftIO $ newMVar (IM.empty :: (IM.IntMap (MVar (Either ShortStr Result))))
    driverQ <- liftIO $ atomically $ newTBQueue 32768

    h <- liftIO $ connectTo host port
    liftIO $ startUp h
    -- liftIO $ registerForEvents h

    res <- liftIO $ getPeers h
    rows <- case res of
      Left err -> do
        -- Initialization is aborted; don't leak the seed socket.
        liftIO $ hClose h `catchAny` \_ -> return ()
        throwError err
      Right rs -> return rs
    let peers = formatInet <$> catMaybes (fmap (\r -> fromCQL r (CQLString "peer")::Maybe CQLString) rows)
    oCons <- liftIO $ mapM (\peer -> setupNode peer port ri) peers
    let connections = (host, port, h, streams, streamMap) : oCons
    _ <- liftIO $ swapMVar allConnections connections
    forM_ connections $ \(host', port', h', streams', streamMap') -> liftIO $ do
      receiveThread host' port' ri h' streams' streamMap'
      sendThread host' port' ri h' streams' streamMap' driverQ
    return $ Candle driverQ
  where
    -- Render the raw bytes of a peer address as a host string. 16 bytes
    -- (IPv6) become colon-separated hex groups; any other length (IPv4)
    -- becomes dotted decimal.
    formatInet (CQLString p) = case unpack p of
      bytes | Data.List.length bytes == 16 -> Data.List.intercalate ":" (hexGroups bytes)
      bytes -> Data.List.intercalate "." (fmap show bytes)
    -- Pair bytes into big-endian 16-bit words, formatted in lowercase hex.
    hexGroups (hi : lo : rest) = showHex (fromIntegral hi * 256 + fromIntegral lo :: Int) "" : hexGroups rest
    hexGroups _ = []