{-# LANGUAGE CPP #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE ParallelListComp #-}
{-# LANGUAGE DefaultSignatures #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# OPTIONS_GHC -fno-cse -fno-full-laziness #-}

module Data.Discrimination.Grouping
  ( Group(..)
  , Grouping(..)
  , Grouping1(..)
  -- * Combinators
  , nub, nubWith
  , group, groupWith
  , groupingEq
  -- * Internals
  , groupingBag
  , groupingSet
  , groupingShort
  , groupingNat
  ) where

import Control.Arrow
import Control.Monad
import Data.Bits
import Data.Complex
import Data.Discrimination.Internal
import Data.Foldable hiding (concat)
import Data.Functor
import Data.Functor.Compose
import Data.Functor.Contravariant
import Data.Functor.Contravariant.Divisible
import Data.Functor.Contravariant.Generic
import Data.IORef (IORef, newIORef, atomicModifyIORef)
import Data.Int
import Data.Monoid hiding (Any)
import Data.Proxy
import Data.Ratio
import Data.Typeable
import qualified Data.Vector.Mutable as UM
import Data.Void
import Data.Word
import GHC.Prim (Any, RealWorld)
import Prelude hiding (read, concat)
import System.IO.Unsafe
import Unsafe.Coerce
{-
import Data.Coerce
import Data.Primitive.Types (Addr(..))
import GHC.IO (IO(IO))
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Primitive.Mutable as PM
import Data.Primitive.ByteArray (MutableByteArray(MutableByteArray))
import GHC.Prim (Any, State#, RealWorld, MutableByteArray#, Int#)
import GHC.IORef (IORef(IORef))
import GHC.STRef (STRef(STRef))
-}

-- | Discriminator

-- TODO: use [(a,b)] -> [NonEmpty b] to better indicate safety?
newtype Group a = Group { runGroup :: forall b. [(a,b)] -> [[b]] }
  deriving Typeable

#ifndef HLINT
type role Group representational
#endif

instance Contravariant Group where
  contramap f (Group g) = Group $ g . map (first f)

instance Divisible Group where
  conquer = Group $ return . fmap snd
  divide k (Group l) (Group r) = Group $ \xs ->
    l [ (b, (c, d)) | (a,d) <- xs, let (b, c) = k a] >>= r

instance Decidable Group where
  lose k = Group $ fmap (absurd.k.fst)
  choose f (Group l) (Group r) = Group $ \xs -> let
      ys = zipWith (\n (a,d) -> (f a, (n, d))) [0..] xs
    in l [ (k,p) | (Left k, p) <- ys ] `mix`
       r [ (k,p) | (Right k, p) <- ys ]

mix :: [[(Int,b)]] -> [[(Int,b)]] -> [[b]]
mix [] bs = fmap snd <$> bs
mix as [] = fmap snd <$> as
mix asss@(((n,a):as):ass) bsss@(((m,b):bs):bss)
  | n < m     = (a:fmap snd as) : mix ass bsss
  | otherwise = (b:fmap snd bs) : mix asss bss
mix _ _ = error "bad discriminator"

instance Monoid (Group a) where
  mempty = conquer
  mappend (Group l) (Group r) = Group $ \xs -> l [ (fst x, x) | x <- xs ] >>= r

--------------------------------------------------------------------------------
-- Primitives
--------------------------------------------------------------------------------

-- | Perform stable unordered discrimination by bucket.
--
-- This reuses arrays unlike the more obvious ST implementation, so it wins by
-- a huge margin in a race, especially when we have a large
-- keyspace, sparsely used, with low contention.
-- This will leak a number of arrays equal to the maximum concurrent
-- contention for this resource. If this becomes a bottleneck we can
-- make multiple stacks of working pads and index the stack with the
-- hash of the current thread id to reduce contention at the expense
-- of taking more memory.
--
-- You should create a thunk that holds the discriminator from @groupingNat n@
-- for a known @n@ and then reuse it.
groupingNat :: Int -> Group Int
groupingNat n = unsafePerformIO $ do
    ts <- newIORef ([] :: [UM.MVector RealWorld [Any]])
    return $ Group $ go ts
  where
    step1 t keys (k, v) = UM.read t k >>= \vs -> case vs of
      [] -> (k:keys) <$ UM.write t k [v]
      _  -> keys     <$ UM.write t k (v:vs)
    step2 t vss k = do
      es <- UM.read t k
      (reverse es : vss) <$ UM.write t k []
    go :: IORef [UM.MVector RealWorld [Any]] -> [(Int, b)] -> [[b]]
    go ts xs = unsafePerformIO $ do
      mt <- atomicModifyIORef ts $ \case
        (y:ys) -> (ys, Just y)
        []     -> ([], Nothing)
      t <- maybe (UM.replicate n []) (return . unsafeCoerce) mt
      ys <- foldM (step1 t) [] xs
      zs <- foldM (step2 t) [] ys
      atomicModifyIORef ts $ \ws -> (unsafeCoerce t:ws, ())
      return zs
    {-# NOINLINE go #-}
{-# NOINLINE groupingNat #-}

-- | Shared bucket set for small integers
groupingShort :: Group Int
groupingShort = groupingNat 65536
{-# NOINLINE groupingShort #-}

{-
foreign import prim "walk" walk :: Any -> MutableByteArray# s -> State# s -> (# State# s, Int# #)

groupingSTRef :: Group Addr -> Group (STRef s a)
groupingSTRef (Group f) = Group $ \xs ->
  let force !n !(!(STRef !_,_):ys) = force (n + 1) ys
      force !n [] = n
  in case force 0 xs of
   !n -> unsafePerformIO $ do
     mv@(PM.MVector _ _ (MutableByteArray mba)) <- PM.new n :: IO (PM.MVector RealWorld Addr)
     IO $ \s -> case walk (unsafeCoerce xs) mba s of (# s', _ #) -> (# s', () #)
     ys <- P.freeze mv
     return $ f [ (a,snd kv) | kv <- xs | a <- P.toList ys ]
{-# NOINLINE groupingSTRef #-}

groupingIORef :: forall a. Group Addr -> Group (IORef a)
groupingIORef = coerce (groupingSTRef :: Group Addr -> Group (STRef RealWorld a))
-}

--------------------------------------------------------------------------------
-- * Unordered Discrimination (for partitioning)
--------------------------------------------------------------------------------

-- | 'Eq' equipped with a compatible stable unordered discriminator.
class Grouping a where
  -- | For every surjection @f@,
  --
  -- @
  -- 'contramap' f 'grouping' ≡ 'grouping'
  -- @

  grouping :: Group a
#ifndef HLINT
  default grouping :: Deciding Grouping a => Group a
  grouping = deciding (Proxy :: Proxy Grouping) grouping
#endif

instance Grouping Void where
  grouping = lose id

instance Grouping Word8 where
  grouping = contramap fromIntegral groupingShort

instance Grouping Word16 where
  grouping = contramap fromIntegral groupingShort

instance Grouping Word32 where
  grouping = Group (runs <=< runGroup groupingShort . join . runGroup groupingShort . map radices) where
    radices (x,b) = (fromIntegral x .&. 0xffff, (fromIntegral (unsafeShiftR x 16), (x,b)))

instance Grouping Word64 where
  grouping = Group (runs <=< runGroup groupingShort . join . runGroup groupingShort . join
                          . runGroup groupingShort . join . runGroup groupingShort . map radices)
    where
      radices (x,b) = (fromIntegral x .&. 0xffff, (fromIntegral (unsafeShiftR x 16) .&. 0xffff
                    , (fromIntegral (unsafeShiftR x 32) .&. 0xffff, (fromIntegral (unsafeShiftR x 48)
                    , (x,b)))))


instance Grouping Word where
  grouping
    | (maxBound :: Word) == 4294967295 = contramap (fromIntegral :: Word -> Word32) grouping
    | otherwise                        = contramap (fromIntegral :: Word -> Word64) grouping

instance Grouping Int8 where
  grouping = contramap (\x -> fromIntegral x + 128) groupingShort

instance Grouping Int16 where
  grouping = contramap (\x -> fromIntegral x + 32768) groupingShort

instance Grouping Int32 where
  grouping = contramap (\x -> fromIntegral (x - minBound) :: Word32) grouping

instance Grouping Int64 where
  grouping = contramap (\x -> fromIntegral (x - minBound) :: Word64) grouping

instance Grouping Int where
  grouping = contramap (\x -> fromIntegral (x - minBound) :: Word) grouping

instance Grouping Bool
instance (Grouping a, Grouping b) => Grouping (a, b)
instance (Grouping a, Grouping b, Grouping c) => Grouping (a, b, c)
instance (Grouping a, Grouping b, Grouping c, Grouping d) => Grouping (a, b, c, d)
instance Grouping a => Grouping [a]
instance Grouping a => Grouping (Maybe a)
instance (Grouping a, Grouping b) => Grouping (Either a b)
instance Grouping a => Grouping (Complex a) where
  grouping = divide (\(a :+ b) -> (a, b)) grouping grouping
instance (Grouping a, Integral a) => Grouping (Ratio a) where
  grouping = divide (\r -> (numerator r, denominator r)) grouping grouping
instance (Grouping1 f, Grouping1 g, Grouping a) => Grouping (Compose f g a) where
  grouping = getCompose `contramap` grouping1 (grouping1 grouping)

class Grouping1 f where
  grouping1 :: Group a -> Group (f a)
#ifndef HLINT
  default grouping1 :: Deciding1 Grouping f => Group a -> Group (f a)
  grouping1 = deciding1 (Proxy :: Proxy Grouping) grouping
#endif

instance Grouping1 []
instance Grouping1 Maybe
instance Grouping a => Grouping1 (Either a)
instance Grouping a => Grouping1 ((,) a)
instance (Grouping a, Grouping b) => Grouping1 ((,,) a b)
instance (Grouping a, Grouping b, Grouping c) => Grouping1 ((,,,) a b c)
instance (Grouping1 f, Grouping1 g) => Grouping1 (Compose f g) where
  grouping1 f = getCompose `contramap` grouping1 (grouping1 f)
instance Grouping1 Complex where
  grouping1 f = divide (\(a :+ b) -> (a, b)) f f

-- | Valid definition for @('==')@ in terms of 'Grouping'.
groupingEq :: Grouping a => a -> a -> Bool
groupingEq a b = case runGroup grouping [(a,()),(b,())] of
  _:_:_ -> False
  _ -> True
{-# INLINE groupingEq #-}

--------------------------------------------------------------------------------
-- * Combinators
--------------------------------------------------------------------------------

-- | /O(n)/. Similar to 'Data.List.group', except we do not require groups to be clustered.
--
-- This combinator still operates in linear time, at the expense of productivity.
--
-- The result equivalence classes are _not_ sorted, but the grouping is stable.
--
-- @
-- 'group' = 'groupWith' 'id'
-- @
group :: Grouping a => [a] -> [[a]]
group as = runGroup grouping [(a, a) | a <- as]

-- | /O(n)/. This is a replacement for 'GHC.Exts.groupWith' using discrimination.
--
-- The result equivalence classes are _not_ sorted, but the grouping is stable.
groupWith :: Grouping b => (a -> b) -> [a] -> [[a]]
groupWith f as = runGroup grouping [(f a, a) | a <- as]

-- | /O(n)/. This upgrades 'Data.List.nub' from @Data.List@ from /O(n^2)/ to /O(n)/ by using
-- unordered discrimination.
--
-- @
-- 'nub' = 'nubWith' 'id'
-- 'nub' as = 'head' 'Control.Applicative.<$>' 'group' as
-- @
nub :: Grouping a => [a] -> [a]
nub as = head <$> group as

-- | /O(n)/. 'nub' with a Schwartzian transform.
--
-- @
-- 'nubWith' f as = 'head' 'Control.Applicative.<$>' 'groupWith' f as
-- @
nubWith :: Grouping b => (a -> b) -> [a] -> [a]
nubWith f as = head <$> groupWith f as

--------------------------------------------------------------------------------
-- * Collections
--------------------------------------------------------------------------------

-- | Construct an stable unordered discriminator that partitions into equivalence classes based on the equivalence of keys as a multiset.
groupingBag :: Foldable f => Group k -> Group (f k)
groupingBag = groupingColl updateBag

-- | Construct an stable unordered discriminator that partitions into equivalence classes based on the equivalence of keys as a set.
groupingSet :: Foldable f => Group k -> Group (f k)
groupingSet = groupingColl updateSet

groupingColl :: Foldable f => ([Int] -> Int -> [Int]) -> Group k -> Group (f k)
groupingColl update r = Group $ \xss -> let
    (kss, vs)           = unzip xss
    elemKeyNumAssocs    = groupNum (toList <$> kss)
    keyNumBlocks        = runGroup r elemKeyNumAssocs
    keyNumElemNumAssocs = groupNum keyNumBlocks
    sigs                = bdiscNat (length kss) update keyNumElemNumAssocs
    yss                 = zip sigs vs
  in filter (not . null) $ grouping1 (groupingNat (length keyNumBlocks)) `runGroup` yss