{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DeriveDataTypeable    #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE Trustworthy           #-}
{-# LANGUAGE TypeFamilies          #-}
-- | 'InsOrdHashSet' is like 'HashSet', but it folds in insertion order.
--
-- This module interface mimics "Data.HashSet", with some additions.
module Data.HashSet.InsOrd (
    InsOrdHashSet,
    -- * Construction
    empty,
    singleton,
    -- * Basic interface
    null,
    size,
    member,
    insert,
    delete,
    -- * Combine
    union,
    -- * Transformations
    map,
    -- ** Unordered
    -- * Difference and intersection
    difference,
    intersection,
    -- * Folds
    -- ** Unordered
    -- * Filter
    filter,
    -- * Conversions
    toList,
    fromList,
    toHashSet,
    fromHashSet,
    -- * Lenses
    hashSet,
    -- * Debugging
    valid,
    )where

import Prelude ()
import Prelude.Compat hiding (filter, foldr, lookup, map, null)

import           Control.Arrow                   (first)
import           Data.Aeson
import           Data.Data                       (Data, Typeable)
import           Data.Hashable                   (Hashable (..))
import           Data.List                       (nub, sortBy)
import           Data.Ord                        (comparing)
import           Data.Semigroup                  (Semigroup (..))
import           Text.ParserCombinators.ReadPrec (prec)
import           Text.Read
                 (Lexeme (..), Read (..), lexP, parens, readListPrecDefault)

import Control.Lens
       (At (..), Contains (..), Index, Iso', IxValue, Ixed (..), iso, (<&>))
import Control.Monad.Trans.State.Strict (State, runState, state)

import           Data.HashMap.Strict (HashMap)
import qualified Data.HashMap.Strict as HashMap
import           Data.HashSet        (HashSet)
import qualified Data.HashSet        as HashSet

#if MIN_VERSION_base(4,7,0)
import qualified GHC.Exts as Exts
#endif

#if MIN_VERSION_base(4,8,0)
import qualified Data.Foldable
#endif

import Data.HashMap.InsOrd.Internal

-------------------------------------------------------------------------------
-- InsOrdHashSet
-------------------------------------------------------------------------------

-- | 'HashSet' which tries its best to remember insertion order of elements.

data InsOrdHashSet k = InsOrdHashSet
    { _getIndex        :: !Int
      -- ^ Next free index: 'insert' stores the new element under this value
      -- and then increments it.
    , getInsOrdHashSet :: !(HashMap k Int)
      -- ^ Every element paired with the index it was stored under;
      -- 'toList' sorts by this index.
    }
    deriving (Typeable, Data)

-- | Two sets are equal when their underlying element-to-index maps are
-- equal. The next-index counter is not part of the comparison.
instance Eq k => Eq (InsOrdHashSet k) where
    (==) (InsOrdHashSet _ xs) (InsOrdHashSet _ ys) = xs == ys

-- | Rendered as @fromList [..]@ with the elements in 'toList' order.
instance Show k => Show (InsOrdHashSet k) where
    showsPrec p s = showParen (p >= 11) body
      where
        body = showString "fromList " . showsPrec 11 (toList s)

-- | Parses the @fromList [..]@ form and rebuilds the set with 'fromList'.
instance (Eq k, Hashable k, Read k) => Read (InsOrdHashSet k) where
    readListPrec = readListPrecDefault

    readPrec = parens . prec 10 $ do
        -- The literal identifier must come first; any other token fails the parse.
        Ident "fromList" <- lexP
        fromList <$> readPrec

-- | '<>' is 'union'.
instance (Eq k, Hashable k) => Semigroup (InsOrdHashSet k) where
    xs <> ys = union xs ys

-- | 'mempty' is 'empty' and 'mappend' is 'union'.
instance (Eq k, Hashable k) => Monoid (InsOrdHashSet k) where
    mempty        = empty
    mappend xs ys = union xs ys

-- | Folds go through 'toList', so elements are visited in the order that
-- function returns them.
instance Foldable InsOrdHashSet where
    foldMap f s = foldMap f (toList s)

#if MIN_VERSION_base(4,8,0)
    -- Only defined when base is at least 4.8. The right-hand sides are this
    -- module's top-level 'size', 'null' and 'toList', not the class methods.
    length = size
    null   = null
    toList = toList
#endif

-- | @'hashWithSalt' salt . 'toHashSet' = 'hashWithSalt' salt@.
--
-- The hash is computed from the elements alone (via 'toHashSet'), so it does
-- not depend on the insertion indices or on the index counter. Hashing the
-- underlying @'HashMap' k 'Int'@ directly would mix the indices into the
-- result and break the law stated above.
instance Hashable k => Hashable (InsOrdHashSet k) where
    hashWithSalt salt = hashWithSalt salt . toHashSet

#if MIN_VERSION_base(4,7,0)
-- | Converts to and from lists with this module's 'toList' and 'fromList'.
instance (Eq k, Hashable k) => Exts.IsList (InsOrdHashSet k) where
    type Item (InsOrdHashSet k) = k
    toList   = toList
    fromList = fromList
#endif

-------------------------------------------------------------------------------
-- Aeson
-------------------------------------------------------------------------------

-- | Encoded exactly like the list returned by 'toList'.
instance ToJSON a => ToJSON (InsOrdHashSet a) where
    toJSON s     = toJSON (toList s)
    toEncoding s = toEncoding (toList s)

-- | Decodes a list of elements and builds the set from it with 'fromList'.
instance (Eq a, Hashable a, FromJSON a) => FromJSON (InsOrdHashSet a) where
    parseJSON = fmap fromList . parseJSON

-------------------------------------------------------------------------------
-- Lens
-------------------------------------------------------------------------------

-- Elements are addressed by their own value.
type instance Index (InsOrdHashSet a) = a
-- The value type at an index is unit.
type instance IxValue (InsOrdHashSet a) = ()

-- | Traverses the unit value at an element through the underlying map's 'ix'.
-- Neither the stored index nor the counter changes.
instance (Eq k, Hashable k) => Ixed (InsOrdHashSet k) where
    ix k f (InsOrdHashSet n m) = InsOrdHashSet n <$> ix k visit m
      where
        -- Run the action on @()@ and put the stored index back unchanged.
        visit j = j <$ f ()
    {-# INLINE ix #-}

-- | Lens onto the membership of an element, viewed as @'Maybe' ()@.
--
-- * Writing 'Nothing' for a present element deletes it.
-- * Writing @'Just' ()@ for an absent element inserts it (at the end).
-- * Writing back the state that is already there leaves the set untouched.
--   In particular an element that is already present keeps its position;
--   re-inserting it would move it to the end and break the lens law
--   @set l (view l s) s = s@.
instance (Eq k, Hashable k) => At (InsOrdHashSet k) where
  at k f m = f current <&> \r -> case (r, present) of
      (Nothing, True)  -> delete k m
      (Just (), False) -> insert k m
      _                -> m
    where
      present = member k m
      current = if present then Just () else Nothing
  {-# INLINE at #-}

-- | Lens onto the membership of an element, viewed as a 'Bool'.
--
-- The set is only modified when the written flag differs from the current
-- membership: 'True' for an absent element inserts it, 'False' for a present
-- element deletes it. Writing 'True' for an element that is already present
-- keeps its position instead of moving it to the end.
instance (Eq a, Hashable a) => Contains (InsOrdHashSet a) where
  contains k f s = f present <&> \b -> case (b, present) of
      (True, False) -> insert k s
      (False, True) -> delete k s
      _             -> s
    where
      present = member k s
  {-# INLINE contains #-}

-- | Isomorphism between 'InsOrdHashSet' and 'HashSet', built from
-- 'toHashSet' and 'fromHashSet'.
--
-- This is a slight lie: a round trip through 'HashSet' does not preserve
-- ordering, because 'toHashSet' drops the stored indices.
hashSet :: Iso' (InsOrdHashSet a) (HashSet a)
hashSet = iso toHashSet fromHashSet

-------------------------------------------------------------------------------
-- Construction
-------------------------------------------------------------------------------

-- | The set without elements; its index counter starts at zero.
empty :: InsOrdHashSet k
empty = InsOrdHashSet
    { _getIndex        = 0
    , getInsOrdHashSet = HashMap.empty
    }
{-# INLINABLE empty #-}

-- | A set holding exactly one element, stored under index 0.
singleton :: Hashable k => k -> InsOrdHashSet k
singleton k = InsOrdHashSet
    { _getIndex        = 1
    , getInsOrdHashSet = HashMap.singleton k 0
    }
{-# INLINABLE singleton #-}

-------------------------------------------------------------------------------
-- Basic interface
-------------------------------------------------------------------------------

-- | 'True' when the underlying map has no entries.
null :: InsOrdHashSet k -> Bool
null (InsOrdHashSet _ m) = HashMap.null m
{-# INLINABLE null #-}

-- | Number of elements, taken from the underlying map.
size :: InsOrdHashSet k -> Int
size (InsOrdHashSet _ m) = HashMap.size m
{-# INLINABLE size #-}

-- | Whether the element is a key of the underlying map.
member :: (Eq k, Hashable k) => k -> InsOrdHashSet k -> Bool
member k (InsOrdHashSet _ m) = HashMap.member k m
{-# INLINABLE member #-}

-- | Add an element. It is stored under the current counter value and the
-- counter is incremented, so an element that was already present is
-- re-recorded with the newest index.
insert :: (Eq k, Hashable k) => k -> InsOrdHashSet k -> InsOrdHashSet k
insert k s = case s of
    InsOrdHashSet next m -> InsOrdHashSet (next + 1) (HashMap.insert k next m)

-- | Remove an element from the underlying map; the counter is left as is.
delete :: (Eq k, Hashable k) => k -> InsOrdHashSet k -> InsOrdHashSet k
delete k s = case s of
    InsOrdHashSet next m -> InsOrdHashSet next (HashMap.delete k m)

-------------------------------------------------------------------------------
-- Combine
-------------------------------------------------------------------------------

-- | Union of two sets. Elements of the left set keep their indices; elements
-- that occur only in the right set are placed after all of them, in their
-- original relative order. 'HashMap.union' is given the left map first, so
-- an element present in both keeps the left set's index.
union
    :: (Eq k, Hashable k)
    => InsOrdHashSet k -> InsOrdHashSet k -> InsOrdHashSet k
union (InsOrdHashSet i a) (InsOrdHashSet j b) =
    mk $ HashMap.union a b'
  where
    -- Once a counter has grown to the (arbitrary) threshold the indices are
    -- renumbered from scratch with 'fromHashMapInt'. Otherwise the new
    -- counter has to be strictly greater than every stored index: shifted
    -- right-hand indices go up to @(j - 1) + i + 1 = i + j@, hence
    -- @i + j + 1@. Using @i + j@ here would make the next 'insert' reuse an
    -- index that is already taken.
    mk | i >= 0xfffff || j >= 0xfffff = fromHashMapInt
       | otherwise                    = InsOrdHashSet (i + j + 1)

    -- Shift the right-hand indices above every index of the left set.
    b' = fmap (\k -> k + i + 1) b

-------------------------------------------------------------------------------
-- Transformations
-------------------------------------------------------------------------------

-- | Apply a function to every element, keeping each element's index.
--
-- When the function sends several elements to the same result, that result
-- takes the largest of their indices, i.e. the position of the latest of the
-- merged elements. Resolving the clash with 'max' makes the outcome
-- independent of the iteration order of the underlying 'HashMap'.
map :: (Hashable b, Eq b) => (a -> b) -> InsOrdHashSet a -> InsOrdHashSet b
map f (InsOrdHashSet i m) = InsOrdHashSet i
    $ HashMap.fromListWith max . fmap (first f) . HashMap.toList
    $ m

-------------------------------------------------------------------------------
-- Difference and intersection
-------------------------------------------------------------------------------

-- | Elements of the first set that are not in the second. Indices and the
-- counter of the first set are kept.
difference :: (Eq a, Hashable a) => InsOrdHashSet a -> InsOrdHashSet a -> InsOrdHashSet a
difference (InsOrdHashSet n xs) (InsOrdHashSet _ ys) = InsOrdHashSet n kept
  where
    kept = HashMap.difference xs ys

-- | Elements of the first set that are also in the second. Indices and the
-- counter of the first set are kept.
intersection :: (Eq a, Hashable a) => InsOrdHashSet a -> InsOrdHashSet a -> InsOrdHashSet a
intersection (InsOrdHashSet n xs) (InsOrdHashSet _ ys) = InsOrdHashSet n common
  where
    common = HashMap.intersection xs ys

-------------------------------------------------------------------------------
-- Filter
-------------------------------------------------------------------------------

-- | Keep the elements satisfying the predicate; surviving elements keep
-- their indices and the counter is unchanged.
filter :: (a -> Bool) -> InsOrdHashSet a -> InsOrdHashSet a
filter p (InsOrdHashSet n m) = InsOrdHashSet n (HashMap.filterWithKey keep m)
  where
    -- Only the key (the element) matters; the stored index is ignored.
    keep x _ = p x

-------------------------------------------------------------------------------
-- Conversions
-------------------------------------------------------------------------------

-- | Build a set from a list, numbering the list positions 0, 1, 2, ... and
-- storing each element under its number. The counter ends up as the length
-- of the list.
fromList :: (Eq k, Hashable k) => [k] -> InsOrdHashSet k
fromList xs = InsOrdHashSet next (HashMap.fromList numbered)
  where
    (numbered, next) = runState (traverse newInt xs) 0

-- | Elements in ascending order of their stored index.
toList :: InsOrdHashSet k -> [k]
toList s = [ x | (x, _) <- ordered ]
  where
    ordered = sortBy (comparing snd) (HashMap.toList (getInsOrdHashSet s))

-- | Convert a 'HashSet', numbering its elements from 0 in the order in which
-- the underlying map is traversed.
fromHashSet :: HashSet k -> InsOrdHashSet k
fromHashSet s = InsOrdHashSet next numbered
  where
    (numbered, next) = runState (traverse (\_ -> newInt') (HashSet.toMap s)) 0

-- | Forget the insertion indices and keep only the elements.
toHashSet :: InsOrdHashSet k -> HashSet k
toHashSet = dropIndices . getInsOrdHashSet
  where
#if MIN_VERSION_unordered_containers(0,2,10)
    dropIndices = HashMap.keysSet
#else
    -- Fallback for versions below 0.2.10: blank out the values, then reuse the map.
    dropIndices = HashSet.fromMap . fmap (const ())
#endif

-------------------------------------------------------------------------------
-- Internal
-------------------------------------------------------------------------------

-- | Rebuild a set from an element-to-index map, replacing every index with a
-- fresh one drawn from a counter that starts at 0; the final counter value
-- becomes the set's counter.
--
-- NOTE(review): presumably 'liftSortedAp' / 'retractSortedAp' run the
-- numbering effects in ascending order of the old index, so that relative
-- order is kept — confirm against "Data.HashMap.InsOrd.Internal".
fromHashMapInt :: HashMap k Int -> InsOrdHashSet k
fromHashMapInt m = InsOrdHashSet next repacked
  where
    (repacked, next) = runState (retractSortedAp (traverse renumber m)) 0
    renumber old = liftSortedAp old newInt'

-- | Pair a value with the current counter and advance the counter by one.
newInt :: a -> State Int (a, Int)
newInt a = state step
  where
    step n = ((a, n), n + 1)

-- | Return the current counter and advance it by one.
newInt' :: State Int Int
newInt' = state (\n -> (n, n + 1))

-------------------------------------------------------------------------------
-- Valid
-------------------------------------------------------------------------------

-- | Test if the internal map structure is valid: no two elements share an
-- index, and every stored index is strictly below the counter.
valid :: InsOrdHashSet a -> Bool
valid (InsOrdHashSet bound m) = noDuplicates && belowBound
  where
    ixs :: [Int]
    ixs = HashMap.elems m

    -- 'nub' leaves a list unchanged exactly when it has no repeated entry.
    noDuplicates = nub ixs == ixs
    belowBound   = all (bound >) ixs