{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE NoFieldSelectors #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Test.Database.CQL.IO.Replication (tests) where

import Control.Monad (replicateM)
import Data.IP (IP (..), toIPv4, toIPv6)
import Data.Int (Int64)
import Data.List (sort)
import Data.Map.Strict (Map)
import Data.Map.Strict qualified as Map (fromList, unionWith, singleton, unionsWith, lookup, lookupGE, findMin, foldMapWithKey, size)
import Data.Maybe (mapMaybe)
import Data.Set (Set)
import Data.Set qualified as Set (fromList, lookupGE, findMin, size, unions)
import Data.Text (Text)
import Data.Traversable (forM)
import Data.UUID (UUID)
import Data.Word (Word64)
import Database.CQL.IO.Client (buildTokenMap)
import Database.CQL.IO.Cluster.Host (Host (..), ip2inet)
import Database.CQL.IO.Replication (buildMasterReplicaMaps)
import Formatting (sformat, (%), int)
import Test.QuickCheck (Arbitrary (..), chooseInt, chooseAny, elements, Property, (===), Large (..), Every (..), (.&&.), suchThat)
import Test.Tasty (TestTree, testGroup)
import Test.Tasty.QuickCheck (testProperty)

-- | A randomly generated cluster together with the maps derived from it; the
-- first four fields are exactly the inputs handed to 'buildMasterReplicaMaps'
-- by the ordering property.
data HostReplicationProblem = HostReplicationProblem
  { simpleTokenMap :: Map Int64 UUID
  -- ^ First component of 'buildTokenMap' over the generated hosts.
  , dcTokenMap :: Map Text (Map Int64 UUID)
  -- ^ Second component of 'buildTokenMap'; presumably keyed by data-centre
  -- name (TODO confirm against 'buildTokenMap').
  , dcHostCount :: Map Text (Int, Map Text Int)
  -- ^ Per data centre: total host count, and host count per rack.
  , fullHostMap :: Map UUID Host
  -- ^ Every generated host, keyed by its host id.
  , ownedTokenMap :: Map UUID (Set Int64)
  -- ^ The tokens generated for each host, keyed by host id.
  }
  deriving (Eq, Ord, Show)

-- | Orphan generator for 'IP' (hence @-fno-warn-orphans@ at the top of the
-- module). With equal probability it yields either an IPv4 address under
-- 10.0.0.0/8 or an IPv6 address under 2001:db8::/32, with the remaining
-- octets / hextets chosen uniformly.
instance Arbitrary IP where
  arbitrary = chooseInt (0, 1) >>= byFamily
   where
    -- 0 selects IPv4; anything else selects IPv6.
    byFamily 0 = IPv4 . toIPv4 . (10 :) <$> replicateM 3 octet
    byFamily _ = IPv6 . toIPv6 . ([0x2001, 0x0db8] ++) <$> replicateM 6 hextet
    octet = chooseInt (0, 255)
    hextet = chooseInt (0, 0xffff)

-- | Generates a random cluster and the maps derived from it.
--
-- The topology is uniform: 1-5 data centres ("DC1", "DC2", ...), each with the
-- same number (1-5) of racks ("RAC1", ...), each rack holding the same number
-- (1-10) of hosts, and every host owning the same number (8, 16, 32 or 64) of
-- random tokens. The whole host list is regenerated until no token value is
-- shared, so the cluster-wide token count equals the product of those sizes.
instance Arbitrary HostReplicationProblem where
  arbitrary = do
    dcCount <- chooseInt (1, 5)
    rackCount <- chooseInt (1, 5)
    perRackCount <- chooseInt (1, 10)
    tokenCount <- elements [8, 16, 32, 64]
    -- Helpers below draw randomness in this order per host: id, address,
    -- tokens; hosts are produced rack by rack, data centre by data centre.
    let genHost dcName rackName = do
          hostUUID <- chooseAny
          hostIP <- arbitrary
          tokens <- replicateM @_ @Int64 tokenCount chooseAny
          let host =
                Host
                  { _hostAddr = ip2inet 9042 hostIP
                  , _broadcastAddr = ip2inet 7000 hostIP
                  , _hostId = hostUUID
                  , _dataCentre = dcName
                  , _rack = rackName
                  }
          pure (host, Set.fromList tokens)
        genDataCentre dcNum = do
          let dcName = sformat ("DC" % int) dcNum
          racks <- forM [1 .. rackCount] $ \rackNum ->
            replicateM perRackCount (genHost dcName (sformat ("RAC" % int) rackNum))
          pure (concat racks)
        expectedTokenTotal = dcCount * rackCount * perRackCount * tokenCount
        -- Distinct tokens overall <=> the union is as large as the total drawn.
        noTokenCollisions perDc =
          Set.size (Set.unions [ts | dcHosts <- perDc, (_, ts) <- dcHosts])
            == expectedTokenTotal
    hostsPerDc <- forM [1 .. dcCount] genDataCentre `suchThat` noTokenCollisions
    let allHosts = concat hostsPerDc
        tokensByHost = Map.fromList [(h._hostId, ts) | (h, ts) <- allHosts]
        hostsById = Map.fromList [(h._hostId, h) | (h, _) <- allHosts]
        (simpleMap, dcMap) = buildTokenMap tokensByHost hostsById
        -- Strictly adds two (host count, per-rack host count) summaries.
        addCounts :: (Int, Map Text Int) -> (Int, Map Text Int) -> (Int, Map Text Int)
        addCounts (!hostsA, !racksA) (!hostsB, !racksB) =
          let !hostSum = hostsA + hostsB
              !rackSum = Map.unionWith (+) racksA racksB
           in (hostSum, rackSum)
        countsByDc =
          Map.unionsWith addCounts
            [ Map.singleton h._dataCentre (1, Map.singleton h._rack 1)
            | (h, _) <- allHosts
            ]
    pure
      HostReplicationProblem
        { simpleTokenMap = simpleMap
        , dcTokenMap = dcMap
        , dcHostCount = countsByDc
        , fullHostMap = hostsById
        , ownedTokenMap = tokensByHost
        }

-- | Distance travelled forward around the token ring from @queryVal@ to the
-- token in @tokenSet@ that owns it: the smallest token @>= queryVal@, or the
-- smallest token overall when every token is below @queryVal@.
--
-- Both values are reinterpreted as 'Word64' before subtracting; unsigned
-- arithmetic wraps modulo 2^64, so the wrap-around case yields the correct
-- forward distance rather than a negative number. @tokenSet@ must be
-- non-empty ('Set.findMin' is partial).
tokenOffsetDistance :: Int64 -> Set Int64 -> Word64
tokenOffsetDistance queryVal tokenSet = asWord owner - asWord queryVal
 where
  asWord :: Int64 -> Word64
  asWord = fromIntegral
  owner = case Set.lookupGE queryVal tokenSet of
    Just atOrAfter -> atOrAfter
    Nothing -> Set.findMin tokenSet

-- | Ring distance from @queryVal@ to the host's owning token, i.e.
-- 'tokenOffsetDistance' applied to @hostId@'s entry in @ownedTokenMap@.
-- Calls 'error' if @hostId@ has no entry in the map.
distanceForHost :: Int64 -> Map UUID (Set Int64) -> UUID -> Word64
distanceForHost queryVal ownedTokenMap hostId =
  maybe unknownHost (tokenOffsetDistance queryVal) (Map.lookup hostId ownedTokenMap)
 where
  unknownHost = error ("distance requested for unknown host: " ++ show hostId)

-- | Property: for the query token, the replica list chosen from the simple
-- replica map is in ascending ring distance and has no duplicates, and every
-- per-DC replica map satisfies 'dcSortedCorrectly'.
correctlySorted :: Large Int64 -> HostReplicationProblem -> Property
correctlySorted (Large queryVal) hrp =
  isSorted (map distanceTo replicas)
    .&&. allUnique replicas
    .&&. Map.foldMapWithKey checkDc nonSimpleReplicaMap
 where
  (simpleReplicaMap, nonSimpleReplicaMap) =
    buildMasterReplicaMaps
      hrp.simpleTokenMap
      hrp.dcTokenMap
      hrp.dcHostCount
      hrp.fullHostMap
  -- The entry for the first key >= the query token, wrapping to the
  -- smallest key when the query lies past the last one.
  replicas = case Map.lookupGE queryVal simpleReplicaMap of
    Nothing -> snd (Map.findMin simpleReplicaMap)
    Just (_, owned) -> owned
  distanceTo = distanceForHost queryVal hrp.ownedTokenMap
  checkDc dc localMap = Every (dcSortedCorrectly queryVal hrp dc localMap)

-- | Property: the list is already in ascending order. On failure the
-- counterexample shows the sorted list alongside the actual one.
isSorted :: (Ord a, Eq a, Show a) => [a] -> Property
isSorted vals = ascending === vals
 where
  ascending = sort vals

-- | Property: no element occurs more than once in the list.
--
-- On failure the counterexample lists the offending values (a value repeated
-- k times appears k-1 times), instead of just two mismatching lengths, which
-- makes a duplicated replica immediately identifiable.
allUnique :: (Ord a, Eq a, Show a) => [a] -> Property
allUnique vals = duplicates === []
 where
  sorted = sort vals
  -- After sorting, equal elements are adjacent, so every adjacent equal
  -- pair marks one surplus occurrence.
  duplicates = [x | (x, y) <- zip sorted (drop 1 sorted), x == y]

-- | Rack names of the given hosts, in input order. Host ids that are not in
-- @hostMap@ are silently skipped.
hostRacks :: Map UUID Host -> [UUID] -> [Text]
hostRacks hostMap hostIds =
  [host._rack | Just host <- map (`Map.lookup` hostMap) hostIds]

-- | Property for one data centre's replica map at the query token.
--
-- The replica list is split after as many entries as the DC has racks
-- ("primaries") and the rest ("secondaries"). It checks that:
--
-- * primaries and secondaries are each in ascending ring distance,
-- * no replica appears twice,
-- * no two primaries share a rack,
-- * within each rack, the replicas are in ascending ring distance.
--
-- Calls 'error' if @dc@ is missing from 'dcHostCount' or a replica is
-- missing from 'fullHostMap'.
dcSortedCorrectly :: Int64 -> HostReplicationProblem -> Text -> Map Int64 [UUID] -> Property
dcSortedCorrectly queryVal hrp dc localReplicaMap =
  distancesAscending primaries
    .&&. distancesAscending secondaries
    .&&. allUnique replicas
    .&&. allUnique (hostRacks hrp.fullHostMap primaries)
    .&&. foldMap (Every . distancesAscending) byRack
 where
  -- Entry for the first key >= the query token, wrapping to the smallest key.
  replicas =
    maybe (snd (Map.findMin localReplicaMap)) snd (Map.lookupGE queryVal localReplicaMap)
  racksInDc =
    maybe
      (error ("could not find dc in dcHostCount map: " ++ show dc))
      (Map.size . snd)
      (Map.lookup dc hrp.dcHostCount)
  (primaries, secondaries) = splitAt racksInDc replicas
  rackOf :: UUID -> Map Text [UUID]
  rackOf hostId =
    maybe
      (error ("could not find host in fullHostMap by UUID: " ++ show hostId))
      (\host -> Map.singleton host._rack [hostId])
      (Map.lookup hostId hrp.fullHostMap)
  -- Folding the reversed list with 'flip (++)' prepends each host, so every
  -- rack's list keeps the original replica order.
  byRack :: Map Text [UUID]
  byRack = Map.unionsWith (flip (++)) [rackOf hostId | hostId <- reverse replicas]
  distancesAscending :: [UUID] -> Property
  distancesAscending = isSorted . map (distanceForHost queryVal hrp.ownedTokenMap)

-- | All replication-ordering properties for this module.
tests :: TestTree
tests = testGroup "Replication" [ordering]
 where
  ordering =
    testGroup
      "Ordering"
      [testProperty "Simple and DC based replicas are sorted correctly." correctlySorted]
