{-# LANGUAGE Rank2Types, TypeOperators, KindSignatures, FlexibleInstances, FlexibleContexts, UndecidableInstances, TypeFamilies, GADTs, MultiParamTypeClasses #-}

module Data.TrieMap.MultiRec.TagMap where

import Data.TrieMap.MultiRec.Class
import Data.TrieMap.MultiRec.Eq
import Data.TrieMap.MultiRec.Sized
import Data.TrieMap.Applicative
import Data.TrieMap.TrieKey

import Control.Applicative
import Control.Arrow
import Control.Monad

import Data.Maybe
import Data.Monoid
import Data.Foldable
import Generics.MultiRec

-- | Value wrapper used inside 'TagMap'. The single constructor forces the last
-- type index to equal @ix@, so pattern matching on 'TagF' at type
-- @TagF a ix xi@ brings @xi ~ ix@ into scope.
data TagF (a :: * -> *) ix xi where
	TagF :: a ix -> TagF a ix ix

-- | Extract the payload of a 'TagF'. The match refines @xi ~ ix@, so the
-- payload can be returned at type @a xi@.
unTagF :: TagF a ix xi -> a xi
unTagF wrapped = case wrapped of
	TagF payload -> payload

-- | Map for keys built with the 'Tag' constructor of @f :>: ix@.
-- It wraps a map @m@ keyed on the untagged @f r xi@ keys, whose values
-- are stored as @TagF a ix@, i.e. each value remembers the index @ix@.
newtype TagMap (phi :: * -> *) m ix (r :: * -> *) a xi = TagMap (m r (TagF a ix) xi)
-- For a tagged functor @f :>: ix@, 'HTrieMapT' is 'TagMap' over the
-- 'HTrieMapT' of the underlying functor @f@.
type instance HTrieMapT phi (f :>: ix) = TagMap phi (HTrieMapT phi f) ix
-- A map keyed by the applied tagged functor @(f :>: ix) r@ is the
-- 'HTrieMapT' for @f :>: ix@ applied to @r@.
type instance HTrieMap phi ((f :>: ix) r) = HTrieMapT phi (f :>: ix) r

-- | Adapt a combining function that expects 'Tag'-wrapped keys and plain
-- values so it works on untagged keys and 'TagF'-wrapped values.
-- The key is re-wrapped with 'Tag' before the call. The call's result is
-- mapped back into 'TagF' through its 'Functor' instance.
combineTag
	:: IsectFunc ((f :>: ix) r xi) (a xi) (b xi) (c xi)
	-> IsectFunc (f r xi) (TagF a ix xi) (TagF b ix xi) (TagF c ix xi)
combineTag g key (TagF x) (TagF y) = fmap TagF (g (Tag key) x y)

-- | Adapt a key-aware action on 'Tag'-wrapped keys and plain values so it
-- works on untagged keys and 'TagF'-wrapped values. The key is re-wrapped
-- with 'Tag', and the action's result is mapped back into 'TagF'.
mapTag
	:: Functor t
	=> ((f :>: ix) r xi -> a xi -> t (b xi))
	-> f r xi
	-> TagF a ix xi
	-> t (TagF b ix xi)
mapTag g key (TagF x) = fmap TagF (g (Tag key) x)

-- | Lift a sizing function on plain values to 'TagF'-wrapped values.
-- It sizes the unwrapped payload.
sizeTag :: HSized phi a -> HSized phi (TagF a ix)
sizeTag sizer wrapped = sizer (unTagF wrapped)

-- | 'HTrieKeyT' instance for tagged functors @f :>: ix@.
--
-- Every method follows the same scheme:
--
-- * strip the 'Tag' constructor from keys and the 'TagMap' newtype from maps;
-- * wrap stored values in 'TagF', lifting sizing functions with 'sizeTag';
-- * delegate to the same method of the underlying @f@ map @m@;
-- * undo the wrapping on the results.
instance (HTrieKeyT phi f m, m ~ HTrieMapT phi f) => HTrieKeyT phi (f :>: ix) (TagMap phi m ix) where
	emptyT = TagMap . emptyT
	nullT pf (TagMap m) = nullT pf m
	sizeT s (TagMap m) = sizeT (sizeTag s) m
	lookupT pf (Tag k) (TagMap m) = unTagF <$> lookupT pf k m
	lookupIxT pf s (Tag k) (TagMap m) = fmap unTagF <$> lookupIxT pf (sizeTag s) k m
	-- Re-tag the key and unwrap the value of the association found at index i.
	assocAtT pf s i (TagMap m) = unTagger (assocAtT pf (sizeTag s) i m)
		where	unTagger :: (Int, f r ix, TagF a xi ix) -> (Int, (f :>: xi) r ix, a ix)
			unTagger (i', k, TagF a) = (i', Tag k, a)
	updateAtT pf s f i (TagMap m) = TagMap (updateAtT pf (sizeTag s) (f' f) i m) where
		f' :: (Int -> (f :>: xi) r ix -> a ix -> Maybe (a ix)) -> Int -> f r ix -> TagF a xi ix -> Maybe (TagF a xi ix)
		f' f i k (TagF a) = TagF <$> f i (Tag k) a
	-- The alteration function sees and returns unwrapped values.
	alterT pf s f (Tag k) (TagMap m) = TagMap (alterT pf (sizeTag s) (fmap TagF . f . fmap unTagF) k m)
	-- 'mapTag' re-tags each key and wraps each traversal result back into 'TagF'.
	traverseWithKeyT pf s f (TagMap m) = TagMap <$> traverseWithKeyT pf (sizeTag s) (mapTag f) m
	foldWithKeyT pf f (TagMap m) = foldWithKeyT pf (f' f) m where
		f' :: ((f :>: ix) r xi -> a xi -> b -> b) -> f r xi -> TagF a ix xi -> b -> b
		f' f k (TagF a) = f (Tag k) a
	foldlWithKeyT pf f (TagMap m) = foldlWithKeyT pf (f' f) m where
		f' :: ((f :>: ix) r xi -> b -> a xi  -> b) -> f r xi -> b -> TagF a ix xi -> b
		f' f k z (TagF a) = f (Tag k) z a
	-- Both result maps are re-wrapped in 'TagMap'; both outcomes of f are re-wrapped in 'TagF'.
	mapEitherT pf s1 s2 f (TagMap m) = (TagMap *** TagMap) (mapEitherT pf (sizeTag s1) (sizeTag s2) (f' f) m) where
		f' :: EitherMap ((f :>: ix) r xi) (a xi) (b xi) (c xi) -> EitherMap (f r xi) (TagF a ix xi) (TagF b ix xi) (TagF c ix xi)
		f' f k (TagF a) = (fmap TagF *** fmap TagF) (f (Tag k) a)
	splitLookupT pf s f (Tag k) (TagMap m) = TagMap `sides` splitLookupT pf (sizeTag s) (f' f) k m where
		f' :: SplitMap (a ix) x -> SplitMap (TagF a xi ix) x
		f' f (TagF a) = fmap TagF `sides` f a
	unionT pf s f (TagMap m1) (TagMap m2) = TagMap (unionT pf (sizeTag s) (combineTag f) m1 m2)
	isectT pf s f (TagMap m1) (TagMap m2) = TagMap (isectT pf (sizeTag s) (combineTag f) m1 m2)
	diffT pf s f (TagMap m1) (TagMap m2) = TagMap (diffT pf (sizeTag s) (combineTag f) m1 m2)
	-- extractMin' / extractMax' are extractMinT / extractMaxT re-bound at an
	-- explicit signature. NOTE(review): presumably this pins the map type
	-- (m ~ HTrieMapT phi f) so the inner call is not ambiguous under the type
	-- family; confirm before simplifying.
	extractMinT pf s (TagMap m) = do
		((k, TagF a), m') <- extractMin' pf ((sizeTag :: HSized phi a -> HSized phi (TagF a ix)) s) m
		return ((Tag k, a), TagMap m')
	 where	extractMin' :: (HTrieKeyT phi f m, m ~ HTrieMapT phi f, HTrieKey phi r (HTrieMap phi r)) => 
	 		phi ix -> HSized phi (TagF a xi) -> m r (TagF a xi) ix ->
	 			First ((f r ix, TagF a xi ix), m r (TagF a xi) ix)
	 	extractMin' = extractMinT
	extractMaxT pf s (TagMap m) = do
		((k, TagF a), m') <- extractMax' pf ((sizeTag :: HSized phi a -> HSized phi (TagF a ix)) s) m
		return ((Tag k, a), TagMap m')
	 where	extractMax' :: (HTrieKeyT phi f m, m ~ HTrieMapT phi f, HTrieKey phi r (HTrieMap phi r)) => 
	 		phi ix -> HSized phi (TagF a xi) -> m r (TagF a xi) ix ->
	 			Last ((f r ix, TagF a xi ix), m r (TagF a xi) ix)
	 	extractMax' = extractMaxT
	alterMinT pf s f (TagMap m) = TagMap (alterMinT pf (sizeTag s) (mapTag f) m)
	alterMaxT pf s f (TagMap m) = TagMap (alterMaxT pf (sizeTag s) (mapTag f) m)
	-- Lift the caller's value comparison to 'TagF'-wrapped values.
	isSubmapT pf leq (TagMap m1) (TagMap m2) = isSubmapT pf (le leq) m1 m2 where
		le :: LEq (a ix) (b ix) -> LEq (TagF a xi ix) (TagF b xi ix)
		le cmp (TagF a) (TagF b) = cmp a b
	-- List builders strip 'Tag' from keys and wrap values in 'TagF' up front.
	-- The combining function is handed re-tagged keys and unwrapped values.
	fromListT pf s f xs = TagMap (fromListT pf (sizeTag s) (f' f) [(k, TagF a) | (Tag k, a) <- xs]) where
		f' :: ((f :>: ix) r xi -> a xi -> a xi -> a xi) -> f r xi -> TagF a ix xi -> TagF a ix xi -> TagF a ix xi
		f' f k (TagF a) (TagF b) = TagF (f (Tag k) a b)
	fromAscListT pf s f xs = TagMap (fromAscListT pf (sizeTag s) (f' f) [(k, TagF a) | (Tag k, a) <- xs]) where
		f' :: ((f :>: ix) r xi -> a xi -> a xi -> a xi) -> f r xi -> TagF a ix xi -> TagF a ix xi -> TagF a ix xi
		f' f k (TagF a) (TagF b) = TagF (f (Tag k) a b)
	fromDistAscListT pf s xs = TagMap (fromDistAscListT pf (sizeTag s) (map f xs)) where
		f :: ((f :>: ix) r xi, a xi) -> (f r xi, TagF a ix xi)
		f (Tag k, a) = (k, TagF a)

-- | 'HTrieKey' instance for maps keyed by an applied tagged functor
-- @(f :>: ix) r@. Every method is the same-named @...T@ method of the
-- 'HTrieKeyT' instance for @TagMap phi m ix@, used unchanged.
instance (HTrieKeyT phi f m, m ~ HTrieMapT phi f, HTrieKey phi r (HTrieMap phi r)) =>
		HTrieKey phi ((f :>: ix) r) (TagMap phi m ix r) where
	-- Construction, emptiness and size.
	emptyH           = emptyT
	nullH            = nullT
	sizeH            = sizeT
	-- Keyed lookup and alteration.
	lookupH          = lookupT
	alterH           = alterT
	-- Index-based access.
	lookupIxH        = lookupIxT
	assocAtH         = assocAtT
	updateAtH        = updateAtT
	-- Traversals and folds.
	traverseWithKeyH = traverseWithKeyT
	foldWithKeyH     = foldWithKeyT
	foldlWithKeyH    = foldlWithKeyT
	-- Splitting.
	mapEitherH       = mapEitherT
	splitLookupH     = splitLookupT
	-- Combining and comparing two maps.
	unionH           = unionT
	isectH           = isectT
	diffH            = diffT
	isSubmapH        = isSubmapT
	-- Minimum / maximum.
	extractMinH      = extractMinT
	extractMaxH      = extractMaxT
	alterMinH        = alterMinT
	alterMaxH        = alterMaxT
	-- Construction from lists.
	fromListH        = fromListT
	fromAscListH     = fromAscListT
	fromDistAscListH = fromDistAscListT