{-# LANGUAGE TypeFamilies, FlexibleContexts, UnboxedTuples #-}

module Data.TrieMap (
	-- * Map type
	TKey,
	TMap,
	-- * Operators
	(!),
	(\\),
	-- * Query
	null,
	size,
	member,
	notMember,
	lookup,
	findWithDefault,
	-- * Construction
	empty,
-- 	showMap,
	singleton,
	-- ** Insertion
	insert,
	insertWith,
	insertWithKey,
	-- ** Delete/Update
	delete,
	adjust,
	adjustWithKey,
	update,
	updateWithKey,
	alter,
	-- * Combine
	-- ** Union
	union,
	unionWith,
	unionWithKey,
	unionMaybeWith,
	unionMaybeWithKey,
	symmetricDifference,
	-- ** Difference
	difference,
	differenceWith,
	differenceWithKey,
	-- ** Intersection
	intersection,
	intersectionWith,
	intersectionWithKey,
	intersectionMaybeWith,
	intersectionMaybeWithKey,
	-- * Traversal
	-- ** Map
	map,
	mapWithKey,
	mapKeys,
	mapKeysWith,
	mapKeysMonotonic,
	-- ** Traverse
	traverseWithKey,
	-- ** Fold
	fold,
	foldWithKey,
	foldrWithKey,
	foldlWithKey,
	-- * Conversion
	elems,
	keys,
	keysSet,
	assocs,
	-- ** Lists
	fromList,
	fromListWith,
	fromListWithKey,
	-- ** Ordered lists
	fromAscList,
	fromAscListWith,
	fromAscListWithKey,
	fromDistinctAscList,
	-- * Filter
	filter,
	filterWithKey,
	partition,
	partitionWithKey,
	mapMaybe,
	mapMaybeWithKey,
	mapEither,
	mapEitherWithKey,
	split,
	splitLookup,
	-- * Submap
	isSubmapOf,
	isSubmapOfBy,
	-- * Min/Max
	findMin,
	findMax,
	deleteMin,
	deleteMax,
	deleteFindMin,
	deleteFindMax,
	updateMin,
	updateMax,
	updateMinWithKey,
	updateMaxWithKey,
	minView,
	maxView,
	minViewWithKey,
	maxViewWithKey
	) where

import Data.TrieMap.Class
import Data.TrieMap.Class.Instances()
import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative
import Data.TrieMap.Rep
import Data.TrieMap.Rep.Instances ()
import Data.TrieMap.Sized

import Control.Applicative hiding (empty)
import Control.Arrow
import Control.Monad
import Data.Maybe hiding (mapMaybe)
import Data.Monoid(Monoid(..), First(..), Last(..))

import GHC.Exts (build)

import Prelude hiding (lookup, foldr, null, map, filter, reverse)

-- | Renders the map as @fromList [(k1,a1),...]@, listing the pairs in 'assocs' order.
--
-- 'showsPrec' is defined (rather than only 'show') so the text is wrapped in
-- parentheses when shown at application precedence: @show (Just m)@ yields
-- @Just (fromList [...])@ instead of the unparseable @Just fromList [...]@.
-- At precedence 0 (plain 'show') the output is unchanged.
instance (Show k, Show a, TKey k) => Show (TMap k a) where
    showsPrec d m = showParen (d > 10) $ showString "fromList " . shows (assocs m)

-- | Two maps are equal when the association lists produced by 'assocs' are equal.
instance (Eq k, TKey k, Eq a) => Eq (TMap k a) where
    (==) x y = assocs x == assocs y

-- | Maps are ordered lexicographically by their association lists ('assocs').
instance (Ord k, TKey k, Ord a) => Ord (TMap k a) where
    compare x y = compare (assocs x) (assocs y)

-- | The identity is the empty map and the combining operation is 'union'.
instance TKey k => Monoid (TMap k a) where
    mempty = empty
    mappend x y = union x y

-- | The map with no associations at all.
empty :: TKey k => TMap k a
empty = TMap emptyM

-- | @'singleton' k a@ is the map whose only association binds @k@ to @a@.
singleton :: TKey k => k -> a -> TMap k a
singleton key val = insert key val empty

-- | Test whether the map contains no associations.
null :: TKey k => TMap k a -> Bool
null (TMap trie) = nullM trie

-- | Look up the value bound to a key.
--
-- Yields @('Just' value)@ when the key is present and 'Nothing' when it is not.
lookup :: TKey k => k -> TMap k a -> Maybe a
lookup k (TMap trie) = fmap getElem (lookupM (toRep k) trie)

-- | @('findWithDefault' def k map)@ returns the value bound to @k@, falling back
-- to @def@ when @k@ is not in the map.
findWithDefault :: TKey k => a -> k -> TMap k a -> a
findWithDefault def k m = maybe def id (lookup k m)

-- | Look up a key that is expected to be present. Calls 'error' when the key is
-- absent; 'lookup' is the total alternative.
(!) :: TKey k => TMap k a -> k -> a
m ! k = case lookup k m of
  Just a  -> a
  Nothing -> error "Element not found"

-- | @('alter' f k map)@ replaces whatever is stored at @k@ (a value, or its absence)
-- with the result of @f@. It subsumes insertion, deletion and update:
-- @'lookup' k ('alter' f k m) = f ('lookup' k m)@.
alter :: TKey k => (Maybe a -> Maybe a) -> k -> TMap k a -> TMap k a
alter f k (TMap m) = TMap (alterM elemSize onSlot (toRep k) m)
  where
    -- Strip the stored 'Elem' wrapper, apply the caller's function, re-wrap the result.
    onSlot old = fmap Elem (f (fmap getElem old))

-- | 'extractA' specialised to a 'MonadPlus'. The monad is wrapped in
-- 'WrappedMonad' (which supplies the 'Alternative' instance) for the call and
-- unwrapped afterwards.
extract :: (TKey k, MonadPlus m) => (k -> a -> m (x, Maybe a)) -> TMap k a -> m (x, TMap k a)
extract f = unwrapMonad . extractA (\ k a -> WrapMonad (f k a))

-- | Projects information out of, and modifies or deletes, an individual association pair,
-- producing one alternative per association in the map (the alternatives are combined with '<|>').
-- 
-- If @assocs m == [(k1, a1), ..., (kn, an)]@, then
-- 
-- > extractA f m = let upd k (x, maybeA) = (x, alter (const maybeA) k m) in
-- >   (upd k1 <$> f k1 a1) <|> ... <|> (upd kn <$> f kn an)
-- 
-- This generalizes a large number of operations, including
-- 
-- > minViewWithKey == getFirst . extract (\ k a -> return ((k, a), Nothing))
-- > updateMaxWithKey f m == maybe m snd (getLast (extract (\ k a -> return ((), f k a)) m))
-- 
-- where 'extract' is the 'MonadPlus' form of this function. In addition,
-- 
-- > getFirst (extract (\ k a -> if p k a then return ((k, a), Nothing) else mzero) m)
-- 
-- finds and removes the first association pair satisfying the predicate @p@.
extractA :: (TKey k, Alternative f) => (k -> a -> f (x, Maybe a)) -> TMap k a -> f (x, TMap k a)
-- The callback gets the key converted back with 'fromRep' and the value unwrapped from 'Elem';
-- the triple 'fmap' re-wraps the optional replacement value (inside @f (x, Maybe _)@) as an 'Elem',
-- and @fmap TMap <$>@ rewraps the remaining trie (second component of the pair) as a 'TMap'.
extractA f (TMap m) = fmap TMap <$> extractM elemSize (\ k (Elem a) -> fmap (fmap (fmap Elem)) (f (fromRep k) a)) m

-- | Insert a key and value; a value already bound to the key is replaced by the new one.
insert :: TKey k => k -> a -> TMap k a -> TMap k a
insert k a = insertWith (\ new _ -> new) k a

-- | @'insertWith' f k new m@ binds @k@ to @new@ when @k@ is absent, and to
-- @f new old@ when @k@ was already bound to @old@.
insertWith :: TKey k => (a -> a -> a) -> k -> a -> TMap k a -> TMap k a
insertWith f = insertWithKey (\ _ new old -> f new old)

-- | Like 'insertWith', but the combining function also receives the key:
-- an existing value @old@ at @k@ becomes @f k new old@.
insertWithKey :: TKey k => (k -> a -> a -> a) -> k -> a -> TMap k a -> TMap k a
insertWithKey f k new = alter (\ old -> Just (maybe new (f k new) old)) k

-- | Remove a key together with its value; nothing happens if the key is absent.
delete :: TKey k => k -> TMap k a -> TMap k a
delete k m = alter (\ _ -> Nothing) k m

-- | Apply a function to the value at a key, if that key is present.
adjust :: TKey k => (a -> a) -> k -> TMap k a -> TMap k a
adjust f = adjustWithKey (\ _ a -> f a)

-- | Like 'adjust', with the key also passed to the function.
adjustWithKey :: TKey k => (k -> a -> a) -> k -> TMap k a -> TMap k a
adjustWithKey f = updateWithKey (\ k a -> Just (f k a))

-- | @'update' f k m@ applies @f@ to the value at @k@, if any. A 'Nothing'
-- result deletes the key, and an absent key stays absent.
update :: TKey k => (a -> Maybe a) -> k -> TMap k a -> TMap k a
update f = alter (maybe Nothing f)

-- | Like 'update', with the key also passed to the function.
updateWithKey :: TKey k => (k -> a -> Maybe a) -> k -> TMap k a -> TMap k a
updateWithKey f k = alter (>>= f k) k

-- | Fold over the values only, ignoring keys; see 'foldWithKey'.
fold :: TKey k => (a -> b -> b) -> b -> TMap k a -> b
fold f = foldWithKey (\ _ a acc -> f a acc)

-- | Fold the associations with a function taking key, value and accumulator.
-- 'foldrWithKey' is a synonym for 'foldWithKey'.
foldWithKey, foldrWithKey :: TKey k => (k -> a -> b -> b) -> b -> TMap k a -> b
foldWithKey f z (TMap m) = foldWithKeyM step m z
  where
    step k (Elem a) = f (fromRep k) a
foldrWithKey f = foldWithKey f

-- | Left fold over the associations: @f@ receives the accumulator first,
-- then the key and the value.
foldlWithKey :: TKey k => (b -> k -> a -> b) -> b -> TMap k a -> b
foldlWithKey f z (TMap m) = foldlWithKeyM step m z
  where
    -- The accumulator is named @acc@ so it does not shadow the initial value @z@.
    step k acc (Elem a) = f acc (fromRep k) a

-- | Run an applicative action on every association and gather the results
-- into a map over the same keys.
traverseWithKey :: (TKey k, Applicative f) => (k -> a -> f b) -> TMap k a -> f (TMap k b)
traverseWithKey f (TMap m) = fmap TMap (traverseWithKeyM elemSize visit m)
  where
    visit k (Elem a) = fmap Elem (f (fromRep k) a)

-- | Apply a function to every value; this is 'fmap' for 'TMap'.
map :: TKey k => (a -> b) -> TMap k a -> TMap k b
map f m = fmap f m

-- | Apply a function to every value, also passing that value's key.
mapWithKey :: TKey k => (k -> a -> b) -> TMap k a -> TMap k b
mapWithKey f (TMap m) = TMap (mapWithKeyM elemSize relabel m)
  where
    relabel k (Elem a) = Elem (f (fromRep k) a)

-- | Rebuild the map with every key transformed by @f@. The result is built with
-- 'fromList', so keys that @f@ sends to the same target collide exactly as
-- duplicate keys do in 'fromList'.
mapKeys :: (TKey k, TKey k') => (k -> k') -> TMap k a -> TMap k' a
mapKeys f m = fromList (fmap (\ (k, a) -> (f k, a)) (assocs m))

-- | Like 'mapKeys', but values whose new keys collide are combined with @g@
-- (via 'fromListWith').
mapKeysWith :: (TKey k, TKey k') => (a -> a -> a) -> (k -> k') -> TMap k a -> TMap k' a
mapKeysWith g f m = fromListWith g (fmap (\ (k, a) -> (f k, a)) (assocs m))

-- | Like 'mapKeys', but the transformed pairs are passed straight to
-- 'fromDistinctAscList'. @f@ therefore presumably has to be strictly monotonic
-- (distinct keys map to distinct keys in the same order); this is not checked.
mapKeysMonotonic :: (TKey k, TKey k') => (k -> k') -> TMap k a -> TMap k' a
mapKeysMonotonic f m = fromDistinctAscList (fmap (\ (k, a) -> (f k, a)) (assocs m))

-- | Union of two maps. For a key present in both, the values are combined with
-- @const@, so the value from the first (left) map is kept.
union :: TKey k => TMap k a -> TMap k a -> TMap k a
union m1 m2 = unionWith (\ a _ -> a) m1 m2

-- | Union of two maps, combining the values of shared keys with @f@.
unionWith :: TKey k => (a -> a -> a) -> TMap k a -> TMap k a -> TMap k a
unionWith f = unionWithKey (\ _ a b -> f a b)

-- | Union of two maps. Values of shared keys are combined with @f@, which also
-- receives the key.
unionWithKey :: TKey k => (k -> a -> a -> a) -> TMap k a -> TMap k a -> TMap k a
unionWithKey f = unionMaybeWithKey (\ k a -> Just . f k a)

-- | Union of two maps in which @f@, for each shared key, returns either the
-- combined value or 'Nothing' to drop the key.
unionMaybeWith :: TKey k => (a -> a -> Maybe a) -> TMap k a -> TMap k a -> TMap k a
unionMaybeWith f = unionMaybeWithKey (\ _ -> f)

-- | The most general union. For every key bound in both maps, @f k a b@ yields
-- the value to keep, or 'Nothing' to drop the key from the result.
unionMaybeWithKey :: TKey k => (k -> a -> a -> Maybe a) -> TMap k a -> TMap k a -> TMap k a
unionMaybeWithKey f (TMap m1) (TMap m2) =
  TMap (unionM elemSize (\ k (Elem a) (Elem b) -> fmap Elem (f (fromRep k) a b)) m1 m2)

-- | Union of two maps in which every key shared by both maps is dropped.
symmetricDifference :: TKey k => TMap k a -> TMap k a -> TMap k a
symmetricDifference m1 m2 = unionMaybeWithKey (\ _ _ _ -> Nothing) m1 m2

-- | Intersection of two maps, keeping the values from the first map.
intersection :: TKey k => TMap k a -> TMap k b -> TMap k a
intersection m1 m2 = intersectionWith (\ a _ -> a) m1 m2

-- | Intersection of two maps, combining the values at each shared key with @f@.
intersectionWith :: TKey k => (a -> b -> c) -> TMap k a -> TMap k b -> TMap k c
intersectionWith f = intersectionWithKey (\ _ -> f)

-- | Like 'intersectionWith', with the key also passed to the combining function.
intersectionWithKey :: TKey k => (k -> a -> b -> c) -> TMap k a -> TMap k b -> TMap k c
intersectionWithKey f = intersectionMaybeWithKey (\ k a -> Just . f k a)

-- | Intersection in which @f@ may return 'Nothing' to drop a shared key.
intersectionMaybeWith :: TKey k => (a -> b -> Maybe c) -> TMap k a -> TMap k b -> TMap k c
intersectionMaybeWith f = intersectionMaybeWithKey (\ _ a b -> f a b)

-- | The most general intersection. For each shared key, @f k a b@ gives the
-- result value, or 'Nothing' to leave the key out.
intersectionMaybeWithKey :: TKey k => (k -> a -> b -> Maybe c) -> TMap k a -> TMap k b -> TMap k c
intersectionMaybeWithKey f (TMap m1) (TMap m2) =
  TMap (isectM elemSize (\ k (Elem a) (Elem b) -> fmap Elem (f (fromRep k) a b)) m1 m2)

-- | @'difference' m1 m2@ removes from @m1@ every key that is also bound in @m2@.
-- The operator form is an infix synonym for 'difference'.
difference, (\\) :: TKey k => TMap k a -> TMap k b -> TMap k a
difference m1 m2 = differenceWithKey (\ _ _ _ -> Nothing) m1 m2
m1 \\ m2 = difference m1 m2

-- | For keys bound in both maps, @f a b@ gives the value to keep in the result,
-- or 'Nothing' to remove the key.
differenceWith :: TKey k => (a -> b -> Maybe a) -> TMap k a -> TMap k b -> TMap k a
differenceWith f = differenceWithKey (\ _ -> f)

-- | Like 'differenceWith', with the key also passed to the function.
differenceWithKey :: TKey k => (k -> a -> b -> Maybe a) -> TMap k a -> TMap k b -> TMap k a
differenceWithKey f (TMap m1) (TMap m2) =
  TMap (diffM elemSize (\ k (Elem a) (Elem b) -> fmap Elem (f (fromRep k) a b)) m1 m2)

-- | Detach the value of the minimal (resp. maximal) association together with
-- the rest of the map, dropping its key; 'Nothing' for an empty map.
minView, maxView :: TKey k => TMap k a -> Maybe (a, TMap k a)
minView = fmap (first snd) . minViewWithKey
maxView = fmap (first snd) . maxViewWithKey

-- | The minimal (resp. maximal) association. Calls 'error' on an empty map.
findMin, findMax :: TKey k => TMap k a -> (k, a)
findMin m = case minViewWithKey m of
  Just (kv, _) -> kv
  Nothing      -> error "empty map has no minimal element"
findMax m = case maxViewWithKey m of
  Just (kv, _) -> kv
  Nothing      -> error "empty map has no maximal element"

-- | Remove the minimal (resp. maximal) association; an empty map is returned as-is.
deleteMin, deleteMax :: TKey k => TMap k a -> TMap k a
deleteMin m = case minViewWithKey m of
  Nothing        -> m
  Just (_, rest) -> rest
deleteMax m = case maxViewWithKey m of
  Nothing        -> m
  Just (_, rest) -> rest

-- | Update the value at the minimal (resp. maximal) key; 'Nothing' deletes it.
updateMin, updateMax :: TKey k => (a -> Maybe a) -> TMap k a -> TMap k a
updateMin f = updateMinWithKey (\ _ -> f)
updateMax f = updateMaxWithKey (\ _ -> f)

-- | Like 'updateMin' / 'updateMax', with the key also passed to the function.
-- Both go through 'extract': 'First' selects the first candidate association and
-- 'Last' the last one. An empty map is returned unchanged.
updateMinWithKey, updateMaxWithKey :: TKey k => (k -> a -> Maybe a) -> TMap k a -> TMap k a
updateMinWithKey f m = maybe m snd $ getFirst $ extract (\ k a -> return ((), f k a)) m
updateMaxWithKey f m = maybe m snd $ getLast $ extract (\ k a -> return ((), f k a)) m

-- | Remove the minimal (resp. maximal) association and return it together with
-- the remaining map. Calls 'error' on an empty map.
deleteFindMin, deleteFindMax :: TKey k => TMap k a -> ((k, a), TMap k a)
deleteFindMin m = maybe (error "Cannot return the minimal element of an empty map") id (minViewWithKey m)
deleteFindMax m = maybe (error "Cannot return the maximal element of an empty map") id (maxViewWithKey m)

-- | Split off the first (resp. last) association chosen by 'extract', returning
-- it with the remaining map, or 'Nothing' when the map is empty.
minViewWithKey, maxViewWithKey :: TKey k => TMap k a -> Maybe ((k, a), TMap k a)
minViewWithKey m = getFirst (extract (\ k a -> return ((k, a), Nothing)) m)
maxViewWithKey m = getLast (extract (\ k a -> return ((k, a), Nothing)) m)

-- | All values of the map, in the same order as 'assocs'.
elems :: TKey k => TMap k a -> [a]
elems m = snd <$> assocs m

-- | All keys of the map, in the same order as 'assocs'.
keys :: TKey k => TMap k a -> [k]
keys m = fst <$> assocs m

-- | All key/value pairs, in the order in which 'foldWithKey' visits them.
-- The list is produced with 'build' so it can take part in foldr/build fusion.
assocs :: TKey k => TMap k a -> [(k, a)]
assocs m = build (\ cons nil -> foldWithKey (\ k a rest -> cons (k, a) rest) nil m)

-- | Split the map by an 'Either'-valued function: 'Left' results go to the first
-- map and 'Right' results to the second.
mapEither :: TKey k => (a -> Either b c) -> TMap k a -> (TMap k b, TMap k c)
mapEither f = mapEitherWithKey (\ _ -> f)

-- | Like 'mapEither', with the key also passed to the function.
-- A 'Left' result is stored in the first map and a 'Right' result in the second.
mapEitherWithKey :: TKey k => (k -> a -> Either b c) -> TMap k a -> (TMap k b, TMap k c)
mapEitherWithKey f (TMap m) =
  case mapEitherM elemSize elemSize route m of
    (# onLeft, onRight #) -> (TMap onLeft, TMap onRight)
  where
    -- Each association lands in exactly one of the two result tries.
    route k (Elem a) = case f (fromRep k) a of
      Left b  -> (# Just (Elem b), Nothing #)
      Right c -> (# Nothing, Just (Elem c) #)

-- | Map every value through @f@, keeping only the 'Just' results.
mapMaybe :: TKey k => (a -> Maybe b) -> TMap k a -> TMap k b
mapMaybe f = mapMaybeWithKey (\ _ -> f)

-- | Like 'mapMaybe', with the key also passed to the function.
mapMaybeWithKey :: TKey k => (k -> a -> Maybe b) -> TMap k a -> TMap k b
mapMaybeWithKey f (TMap m) = TMap (mapMaybeM elemSize keep m)
  where
    keep k (Elem a) = fmap Elem (f (fromRep k) a)

-- | Split the map into the associations whose values satisfy @p@ (first) and
-- those that do not (second).
partition :: TKey k => (a -> Bool) -> TMap k a -> (TMap k a, TMap k a)
partition p = partitionWithKey (\ _ -> p)

-- | Like 'partition', with the key also passed to the predicate. Associations
-- satisfying @p@ go to the first map and all others to the second.
partitionWithKey :: TKey k => (k -> a -> Bool) -> TMap k a -> (TMap k a, TMap k a)
partitionWithKey p = mapEitherWithKey (\ k a -> if p k a then Left a else Right a)

-- | Keep only the associations whose values satisfy @p@.
filter :: TKey k => (a -> Bool) -> TMap k a -> TMap k a
filter p = filterWithKey (\ _ -> p)

-- | Keep only the associations satisfying @p@, which receives both key and value.
filterWithKey :: TKey k => (k -> a -> Bool) -> TMap k a -> TMap k a
filterWithKey p = mapMaybeWithKey (\ k a -> a <$ guard (p k a))

-- | @'split' k m@ is 'splitLookup' without the middle component: the two pieces
-- of @m@ on either side of @k@. The value at @k@, if any, is in neither piece.
split :: TKey k => k -> TMap k a -> (TMap k a, TMap k a)
split k m = dropMiddle (splitLookup k m)
  where
    dropMiddle (lo, _, hi) = (lo, hi)

-- | Split the map at key @k@. The result is the piece before @k@, the value bound
-- to @k@ (if any), and the piece after @k@, presumably in the trie's key order.
-- The value at @k@ is placed only in the middle component, never in a piece.
splitLookup :: TKey k => k -> TMap k a -> (TMap k a, Maybe a, TMap k a)
splitLookup k (TMap m) =
  case splitLookupM elemSize atKey (toRep k) m of
    (# below, found, above #) -> (TMap below, found, TMap above)
  where
    atKey (Elem x) = (# Nothing, Just x, Nothing #)

-- | 'isSubmapOfBy' with '==' as the value comparison.
isSubmapOf :: (TKey k, Eq a) => TMap k a -> TMap k a -> Bool
isSubmapOf m1 m2 = isSubmapOfBy (==) m1 m2

-- | Submap test in which values are compared with the supplied relation @rel@
-- (value from the first map on the left).
isSubmapOfBy :: TKey k => (a -> b -> Bool) -> TMap k a -> TMap k b -> Bool
isSubmapOfBy rel (TMap m1) (TMap m2) = isSubmapM related m1 m2
  where
    related (Elem a) (Elem b) = rel a b

-- | Build a map from a list of pairs; duplicate keys are combined with @const@
-- via 'fromListWith'. 'fromAscList' is the counterpart built on 'fromAscListWith'.
fromList, fromAscList :: TKey k => [(k, a)] -> TMap k a
fromList xs = fromListWith (\ a _ -> a) xs
fromAscList xs = fromAscListWith (\ a _ -> a) xs

-- | Build a map from a list of pairs, combining the values of duplicate keys with @f@.
fromListWith, fromAscListWith :: TKey k => (a -> a -> a) -> [(k, a)] -> TMap k a
fromListWith f = fromListWithKey (\ _ -> f)
fromAscListWith f = fromAscListWithKey (\ _ -> f)

-- | Build a map from a list of pairs. The values of duplicate keys are combined
-- with @f@, which also receives the key.
fromListWithKey, fromAscListWithKey :: TKey k => (k -> a -> a -> a) -> [(k, a)] -> TMap k a
fromListWithKey f xs = TMap (fromListM elemSize combine (fmap (\ (k, a) -> (toRep k, Elem a)) xs))
  where
    combine k (Elem a) (Elem b) = Elem (f (fromRep k) a b)
fromAscListWithKey f xs = TMap (fromAscListM elemSize combine (fmap (\ (k, a) -> (toRep k, Elem a)) xs))
  where
    combine k (Elem a) (Elem b) = Elem (f (fromRep k) a b)

-- | Build a map from a list of pairs that the caller guarantees is in ascending
-- order with no repeated keys. The pairs are handed to 'fromDistAscListM'.
fromDistinctAscList :: TKey k => [(k, a)] -> TMap k a
fromDistinctAscList xs = TMap (fromDistAscListM elemSize (fmap (\ (k, a) -> (toRep k, Elem a)) xs))

-- | Size of the map, as computed by 'sizeM' with 'elemSize' as the per-value size.
size :: TKey k => TMap k a -> Int
size (TMap trie) = sizeM elemSize trie

-- | Whether the key is bound in the map.
member :: TKey k => k -> TMap k a -> Bool
member k m = isJust (lookup k m)

-- | Whether the key is absent from the map; the negation of 'member'.
notMember :: TKey k => k -> TMap k a -> Bool
notMember k m = not (member k m)

-- | The set of keys of the map: every value is replaced by @()@ and the result
-- is wrapped in 'TSet'.
keysSet :: TKey k => TMap k a -> TSet k
keysSet = TSet . (() <$)