{-# LANGUAGE TupleSections, TypeFamilies, UnboxedTuples #-}

module Data.TrieMap.TrieKey where

import Data.TrieMap.Applicative
import Data.TrieMap.Sized

import Control.Applicative
import Control.Arrow

import Data.Monoid

-- | Per-entry function for 'mapEitherM': given a key and its value, yield an
-- optional value of type @b@ and an optional value of type @c@ (as an
-- unboxed pair).
type EitherMap k a b c = k -> a -> (# Maybe b, Maybe c #)
-- | Per-value function for 'splitLookupM': given a value, yield an optional
-- value, an optional result @x@, and another optional value (as an unboxed
-- triple). NOTE(review): presumably the "left"/"found"/"right" pieces of a
-- split at the looked-up key -- confirm against the instances.
type SplitMap a x = a -> (# Maybe a, Maybe x, Maybe a #)
-- | Combining function for 'unionM': called with a key and the two values
-- stored under it; the result is optional.
type UnionFunc k a = k -> a -> a -> Maybe a
-- | Combining function for 'isectM': called with a key and the values of the
-- two input maps under it; the result is optional.
type IsectFunc k a b c = k -> a -> b -> Maybe c
-- | Combining function for 'diffM': called with a key, the value from the
-- first map and the value from the second; yields an optional value of the
-- first map's value type.
type DiffFunc k a b = k -> a -> b -> Maybe a
-- | Shape of 'extractM': takes a callback that maps a key/value pair to an
-- effectful result paired with an optional replacement value, and a map @m@,
-- and yields the result paired with a (possibly updated) map.
type ExtractFunc f m k a x = (k -> a -> f (x, Maybe a)) -> m -> f (x, m)
-- | A two-argument Boolean predicate, possibly across two types; the
-- argument and result shape of 'isSubmapM'.
type LEq a b = a -> b -> Bool

-- | A key/value pair tagged with a strict, unpacked 'Int'. Within this module
-- only 'onIndexA' changes the 'Int'; every other @on…A@ helper carries it
-- through. NOTE(review): the 'Int' is presumably the entry's index/position
-- in the map -- confirm against the code that builds 'Asc' values.
data Assoc k a = Asc {-# UNPACK #-} !Int k a
-- | Unboxed triple of an 'Assoc' wrapped in 'Last', an optional 'Assoc', and
-- an 'Assoc' wrapped in 'First'. NOTE(review): by the 'Last'/'First'
-- wrappers this presumably means "nearest entry before / entry at / nearest
-- entry after" some position -- confirm with its producers.
type IndexPos k a = (# Last (Assoc k a), Maybe (Assoc k a), First (Assoc k a) #)

-- | Rewrite the 'Int' tag of an 'Assoc' with the given function; the key and
-- value are passed through untouched.
onIndexA :: (Int -> Int) -> Assoc k a -> Assoc k a
onIndexA g assoc = case assoc of
	Asc ix key val -> Asc (g ix) key val

-- | Apply a function to the key of an 'Assoc', keeping its tag and value.
onKeyA :: (k -> k') -> Assoc k a -> Assoc k' a
onKeyA f = onValueA $ first f

-- | Apply a function to the value of an 'Assoc', keeping its tag and key.
onValA :: (a -> a') -> Assoc k a -> Assoc k a'
onValA f = onValueA $ second f

{-# INLINE onValueA #-}
-- | Transform the key and value of an 'Assoc' together, via a function on
-- the @(key, value)@ pair; the 'Int' tag is preserved. The result pair is
-- taken apart lazily, so the new key and value are not forced here.
onValueA :: ((k, a) -> (k', a')) -> Assoc k a -> Assoc k' a'
onValueA f (Asc tag key val) =
	let (key', val') = f (key, val)
	in Asc tag key' val'

-- | Post-compose a function onto the second component of an unboxed-pair
-- result; the first component is returned unchanged.
onUnboxed :: (c -> d) -> (a -> (# b, c #)) -> a -> (# b, d #)
onUnboxed post step x = case step x of
	(# kept, snd' #) -> (# kept, post snd' #)

-- | Operations a key type must provide so that maps keyed by it can be
-- stored in a key-specific representation (the associated data family
-- 'TrieMap'). Many operations take a @Sized@ argument; the default 'sizeM'
-- below applies it to each value to obtain an 'Int' and sums the results.
-- NOTE(review): presumably @Sized a@ is a per-value size measure -- confirm
-- against "Data.TrieMap.Sized".
class Ord k => TrieKey k where
	-- | Map representation chosen by each key type.
	data TrieMap k :: * -> *
	-- | The map used as the starting point of the default 'fromListM'.
	emptyM :: TrieMap k a
	-- | Build a map from a single key and value.
	singletonM :: Sized a -> k -> a -> TrieMap k a
	-- | Emptiness test; 'guardNullM' maps a map for which this is 'True'
	-- to 'Nothing'.
	nullM :: TrieMap k a -> Bool
	-- | Total of the @Sized@ measure over all values (see the default below).
	sizeM :: Sized a -> TrieMap k a -> Int
	-- | Look up the value stored under a key, if any.
	lookupM :: k -> TrieMap k a -> Maybe a
	-- | Update the entry for a key via a @Maybe a -> Maybe a@ function
	-- ('insertWithKeyM' relies on this).
	alterM :: Sized a -> (Maybe (a) -> Maybe (a)) -> k -> TrieMap k a -> TrieMap k a
	-- | Like 'alterM', but the update function also returns an extra result
	-- @x@, which is passed back alongside the new map.
	alterLookupM :: Sized a -> (Maybe a -> (# x, Maybe a #)) -> k -> TrieMap k a -> (# x, TrieMap k a #)
	-- NOTE(review): this SPECIALIZE signature omits the leading @Sized b@
	-- argument that 'traverseWithKeyM' takes (next declaration), so it does
	-- not match the method's type; GHC may also refuse SPECIALIZE pragmas
	-- inside a class declaration. Confirm with the compiler in use.
	{-# SPECIALIZE traverseWithKeyM :: (k -> a -> Id (b)) -> TrieMap k a -> Id (TrieMap k b) #-}
	-- | Applicative traversal over key/value pairs ('mapWithKeyM' uses it
	-- with the identity functor). NOTE(review): the @TrieMap k ~ m@
	-- constraint binds a type variable @m@ that is not used anywhere else.
	traverseWithKeyM :: (TrieMap k ~ m, Applicative f) => Sized b ->
		(k -> a -> f (b)) -> TrieMap k a -> f (TrieMap k b)
	-- | Right-style fold over key/value pairs; 'sizeM' and 'assocsM' are
	-- written in terms of it.
	foldWithKeyM :: (k -> a -> b -> b) -> TrieMap k a -> b -> b
	-- | Left-style fold over key/value pairs.
	foldlWithKeyM :: (k -> b -> a -> b) -> TrieMap k a -> b -> b
	-- | Map over key/value pairs with a function whose result is optional.
	mapMaybeM :: Sized b -> (k -> a -> Maybe b) -> TrieMap k a -> TrieMap k b
	-- | Map each entry into up to two output maps via an 'EitherMap' function.
	mapEitherM :: Sized b -> Sized c -> EitherMap k (a) (b) (c) -> TrieMap k a -> (# TrieMap k b, TrieMap k c #)
	-- | Split a map at a key, using a 'SplitMap' function on the value found
	-- there (if any); returns two maps and an optional result.
	splitLookupM :: Sized a -> SplitMap a x -> k -> TrieMap k a -> (# TrieMap k a, Maybe x, TrieMap k a #)
	-- | Union of two maps, combining shared keys with a 'UnionFunc'.
	unionM :: Sized a -> UnionFunc k (a) -> TrieMap k a -> TrieMap k a -> TrieMap k a
	-- | Intersection of two maps, combining shared keys with an 'IsectFunc'.
	isectM :: Sized c -> IsectFunc k (a) (b) (c) -> TrieMap k a -> TrieMap k b -> TrieMap k c
	-- | Difference of two maps, resolving shared keys with a 'DiffFunc'.
	diffM :: Sized a -> DiffFunc k (a) (b) -> TrieMap k a -> TrieMap k b -> TrieMap k a
	-- | Extraction in an 'Alternative' functor (see 'ExtractFunc'); 'aboutM'
	-- uses it with a callback that never replaces the value.
	extractM :: (Alternative f) => Sized a -> ExtractFunc f (TrieMap k a) k a x
	-- | Submap test, parameterised by a predicate on values.
	isSubmapM :: LEq (a) (b) -> LEq (TrieMap k a) (TrieMap k b)
	-- | Build a map from a key/value list, combining duplicate keys with the
	-- given function.
	fromListM, fromAscListM :: Sized a -> (k -> a -> a -> a) -> [(k, a)] -> TrieMap k a
	-- | Build a map from a key/value list without a combining function.
	fromDistAscListM :: Sized a -> [(k, a)] -> TrieMap k a
	
	-- Default: sum the @Sized@ measure of every value.
	sizeM s m = foldWithKeyM (\ _ a n -> s a + n) m 0
	-- Default: 'foldr' inserts the LAST list element first. For duplicate
	-- keys, 'insertWithKeyM' then calls @f k earlier later@ (the earlier list
	-- element comes first), assuming 'alterM' passes in the stored value.
	fromListM s f = foldr (uncurry (insertWithKeyM s f)) emptyM
	-- Default: ascending input gets no special treatment.
	fromAscListM = fromListM
	-- Default: with @const const@, a duplicate key keeps the earlier element's
	-- value (the input is expected to have distinct keys anyway).
	fromDistAscListM s = fromAscListM s (const const)

-- | 'Nothing' for a map that 'nullM' reports as empty, otherwise the map
-- itself wrapped in 'Just'.
guardNullM :: TrieKey k => TrieMap k a -> Maybe (TrieMap k a)
guardNullM tm = if nullM tm then Nothing else Just tm

-- | Post-compose a function onto the outer two components of an
-- unboxed-triple result; the middle component is returned unchanged.
sides :: (b -> d) -> (a -> (# b, c, b #)) -> a -> (# d, c, d #)
sides post step x = case step x of
	(# lhs, mid, rhs #) -> (# post lhs, mid, post rhs #)

-- | Post-compose one function onto each component of an unboxed-pair
-- result.
both :: (b -> b') -> (c -> c') -> (a -> (# b, c #)) -> a -> (# b', c' #)
both onFst onSnd step x = case step x of
	(# l, r #) -> (# onFst l, onSnd r #)

{-# INLINE [1] mapWithKeyM #-}
-- | Pure map over key/value pairs, obtained by running 'traverseWithKeyM'
-- in the identity functor and unwrapping the result.
mapWithKeyM :: TrieKey k => Sized b -> (k -> a -> b) -> TrieMap k a -> TrieMap k b
mapWithKeyM s f = \ m -> unId (traverseWithKeyM s (Id .: f) m)

-- | Map over values only; the key is ignored. (Clashes by name with
-- 'Prelude.mapM', so callers must qualify it.)
mapM :: TrieKey k => Sized b -> (a -> b) -> TrieMap k a -> TrieMap k b
mapM s f = mapWithKeyM s (\ _ -> f)

-- | All key/value pairs, consed onto a list in the order 'foldWithKeyM'
-- visits them.
assocsM :: TrieKey k => TrieMap k a -> [(k, a)]
assocsM tm = foldWithKeyM (curry (:)) tm []

-- | Insert a key/value pair; an existing value under the key is replaced by
-- the new one.
insertM :: TrieKey k => Sized a -> k -> a -> TrieMap k a -> TrieMap k a
insertM s = insertWithKeyM s (\ _ new _ -> new)

-- | Insert a key/value pair via 'alterM'. If a value is already stored, the
-- combining function is called as @f k new old@. The update always returns
-- 'Just', and it inspects the old value lazily (only once the new value is
-- demanded).
insertWithKeyM :: TrieKey k => Sized a -> (k -> a -> a -> a) -> k -> a -> TrieMap k a -> TrieMap k a
insertWithKeyM s f k a = alterM s update k
  where
	update old = Just $ case old of
		Nothing   -> a
		Just prev -> f k a prev

-- | 'fromListM' with a combining function that keeps its second argument.
-- Given the default 'fromListM', that is the value from the earlier list
-- element when keys repeat.
fromListM' :: TrieKey k => Sized a -> [(k, a)] -> TrieMap k a
fromListM' s = fromListM s (\ _ x _ -> x)

-- | Union of two optional values.
--
-- * If both are present, they are combined with @f@.
-- * Otherwise whichever one is present is returned.
--
-- The second argument is not inspected when the first is 'Nothing'.
unionMaybe :: (a -> a -> Maybe a) -> Maybe a -> Maybe a -> Maybe a
unionMaybe f mx my = case (mx, my) of
	(Just x, Just y) -> f x y
	(Nothing, _)     -> my
	(_, Nothing)     -> mx

-- | Intersection of two optional values: @f@ is applied only when both are
-- present; otherwise the result is 'Nothing'.
isectMaybe :: (a -> b -> Maybe c) -> Maybe a -> Maybe b -> Maybe c
isectMaybe f mx my = mx >>= \ x -> my >>= f x

-- | Difference of two optional values.
--
-- * An absent first value stays absent.
-- * A present first value survives unchanged when the second is absent.
-- * Otherwise @f@ decides what remains.
diffMaybe :: (a -> b -> Maybe a) -> Maybe a -> Maybe b -> Maybe a
diffMaybe f mx my = mx >>= \ x -> maybe (Just x) (f x) my

-- | Submap-style comparison of two optional values.
--
-- * An absent left value is always included.
-- * A present left value requires a present right value satisfying the
--   predicate.
subMaybe :: (a -> b -> Bool) -> Maybe a -> Maybe b -> Bool
subMaybe le mx my = case mx of
	Nothing -> True
	Just x  -> maybe False (le x) my

-- | Run 'extractM' with a zero size measure and a callback that never
-- replaces the visited value, keeping only the callback's result (the map
-- half of the pair is discarded).
aboutM :: (TrieKey k, Alternative t) => (k -> a -> t z) -> TrieMap k a -> t z
aboutM f = fst <.> extractM (const 0) visit
  where
	visit k a = (\ z -> (z, Nothing)) <$> f k a

{-# RULES
-- 	"lookupM/emptyM" forall k . lookupM k emptyM = Nothing;
-- 	"sizeM/emptyM" forall s . sizeM s emptyM = 0;
-- 	"traverseWithKeyM/emptyM" forall s f . traverseWithKeyM s f emptyM = pure emptyM;
-- 	"extractM/emptyM" forall s f . extractM s f emptyM = empty;
-- 	"foldWithKeyM/emptyM" forall f . foldWithKeyM f emptyM z = z;
-- 	"foldlWithKeyM/emptyM" forall f . foldlWithKeyM f emptyM z = z;
-- 	"lookupIxM/emptyM" forall s k . lookupIxM s k emptyM = (empty, empty, empty);
-- 	"mapEitherM/emptyM" forall s1 s2 f . mapEitherM s1 s2 f emptyM = (emptyM, emptyM);
	#-}