{-# LANGUAGE UnboxedTuples, TupleSections, PatternGuards, TypeFamilies #-}

module Data.TrieMap.ProdMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Applicative

import Control.Applicative

import Data.Maybe
import Data.Foldable

import Data.Sequence ((|>))
import qualified Data.Sequence as Seq

-- | Tries keyed by pairs, stored as a nested trie. The outer trie is keyed by
-- the first component, and each of its values is a trie keyed by the second
-- component.
--
-- Conventions used by the methods below:
--
-- * Operations on the outer trie receive @sizeM s@ as their size function,
--   so an outer entry is measured by the size of its inner trie.
-- * Callbacks on an inner trie are given @f . (a,)@, so they see the full key
--   @(a, b)@ rather than just the second component.
-- * Inner tries produced by an operation pass through 'guardNullM' (defined
--   elsewhere) before the outer trie stores them. Presumably this maps an
--   empty inner trie to 'Nothing' so that no empty inner tries are kept.
--   TODO confirm against the definition of 'guardNullM'.
instance (TrieKey k1, TrieKey k2) => TrieKey (k1, k2) where
	-- Outer trie on k1 whose values are tries on k2.
	newtype TrieMap (k1, k2) a = PMap (TrieMap k1 (TrieMap k2 a))
	emptyM = PMap emptyM
	-- An outer singleton under k1 holding an inner singleton under k2. The
	-- outer singleton is sized via sizeM s, i.e. by the inner trie's size.
	singletonM s (k1, k2) a = PMap (singletonM (sizeM s) k1 (singletonM s k2 a))
	-- Only checks the outer trie. This is correct only if no empty inner tries
	-- are ever stored (see the guardNullM note above).
	nullM (PMap m) = nullM m
	-- Size of the outer trie, where each inner trie is measured with sizeM s.
	sizeM s (PMap m) = sizeM (sizeM s) m
	-- Look up the first component, then look up the second component in the
	-- resulting inner trie. The two lookups are chained with >>=.
	lookupM (k1, k2) (PMap m) = lookupM k1 m >>= lookupM k2
	-- Alter the inner trie under a (starting from emptyM when there is none),
	-- then let guardNullM decide what the outer trie keeps for a.
	alterM s f (a, b) (PMap m) = PMap (alterM (sizeM s) g a m) where
		g = guardNullM . alterM s f b . fromMaybe emptyM
	-- Same shape as alterM, but the extra result of the inner alterLookupM is
	-- threaded out through onUnboxed (defined elsewhere).
	alterLookupM s f (a, b) (PMap m) = onUnboxed PMap (alterLookupM (sizeM s) g a) m where
		-- An inner trie already exists under a. This m shadows the outer m.
		g (Just m) = onUnboxed guardNullM (alterLookupM s f b) m
		-- No inner trie under a yet, so alter an empty one.
		g _ = onUnboxed guardNullM (alterLookupM s f b) emptyM
	-- Nested traversal. The inner callback sees the full key (a, b).
	traverseWithKeyM s f (PMap m) = PMap <$> traverseWithKeyM (sizeM s) (\ a -> traverseWithKeyM s (f . (a,))) m
	-- Nested fold. Each inner trie is folded with the outer key a paired onto
	-- its keys.
	foldWithKeyM f (PMap m) = foldWithKeyM (\ a -> foldWithKeyM (f . (a,))) m
	-- Nested left fold. flip swaps the inner fold's last two arguments to fit
	-- the callback shape the outer fold expects.
	foldlWithKeyM f (PMap m) = foldlWithKeyM (\ a -> flip (foldlWithKeyM (f . (a,)))) m
	-- Filter or map each inner trie. Inner tries that end up empty are handled
	-- by guardNullM.
	mapMaybeM s f (PMap m) = PMap (mapMaybeM (sizeM s) g m) where
		g a = guardNullM . mapMaybeM s (f . (a,))
	-- Split each inner trie in two (sized by s1 and s2 respectively), guarding
	-- both halves. This m shadows the outer m.
	mapEitherM s1 s2 f (PMap m) = both PMap PMap (mapEitherM (sizeM s1) (sizeM s2) g) m where
		g a m = both guardNullM guardNullM (mapEitherM s1 s2 (f . (a,))) m
	-- Split the outer trie at a. The inner trie under a is itself split at b,
	-- and both of its sides are guarded.
	splitLookupM s f (a, b) (PMap m) = sides PMap (splitLookupM (sizeM s) g a) m where
		g = sides guardNullM (splitLookupM s f b)
	-- (<=) is a locally bound value comparison here and shadows Prelude's.
	-- It is lifted to compare inner tries with isSubmapM.
	isSubmapM (<=) (PMap m1) (PMap m2) = isSubmapM (isSubmapM (<=)) m1 m2
	-- The next three functions share one shape. The outer callback receives
	-- the first key a and two inner tries, combines them with the same
	-- operation, and guards the result. (.: is defined elsewhere; presumably
	-- it composes guardNullM after the two-argument inner call.)
	unionM s f (PMap m1) (PMap m2) = PMap (unionM (sizeM s) (\ a -> guardNullM .: unionM s (f . (a,))) m1 m2)
	isectM s f (PMap m1) (PMap m2) = PMap (isectM (sizeM s) (\ a -> guardNullM .: isectM s (f . (a,))) m1 m2)
	diffM s f (PMap m1) (PMap m2) = PMap (diffM (sizeM s) (\ a -> guardNullM .: diffM s (f . (a,))) m1 m2)
	-- Extract through the inner trie and guard what remains of it. The
	-- remaining outer trie in each result is rewrapped with PMap.
	-- (<.> is defined elsewhere; presumably it is composition under a functor.)
	extractM s f (PMap m) = fmap PMap <$> extractM (sizeM s) g m where
		g a = fmap guardNullM <.> extractM s (f . (a,))
	-- breakFst groups adjacent entries that share a first component. An
	-- intermediate outer trie of entry lists is then built: each list counts
	-- as size 1, and lists for a repeated first key are merged with (++).
	-- Finally each list becomes an inner trie via fromListM s f.
	-- NOTE(review): which side (++) puts the later group on depends on the
	-- argument order fromListM uses for its combining function. That order
	-- decides which duplicate (a, b) entry the inner fromListM sees last, so
	-- confirm it matches the intended "later wins" semantics.
	fromListM s f xs = PMap (mapWithKeyM (sizeM s) (\ a -> fromListM s (f . (a,)))
		(fromListM (const 1) (const (++)) (breakFst xs)))
	-- For input ascending by first component, breakFst yields each first key
	-- once and in order. fromDistAscListM presumably requires exactly that
	-- (distinct, ascending keys). This assumes "ascending" orders pairs by
	-- their first component first -- confirm.
	fromAscListM s f xs = PMap (fromDistAscListM (sizeM s)
		[(a, fromAscListM s (f . (a,)) ys) | (a, ys) <- breakFst xs])

-- | Split a list of pair-keyed entries into runs of adjacent entries whose
-- first key components are equal.
--
-- Each entry is compared with its immediate predecessor, and joins the current
-- run while they compare equal. A run is reported as the first key component
-- of its /last/ entry, together with the run's @(second component, value)@
-- pairs in their original order. Runs that are not adjacent are never merged,
-- so a first key can appear more than once in the result unless the input is
-- already grouped by first component.
breakFst :: Eq k1 => [((k1, k2), a)] -> [(k1, [(k2, a)])]
breakFst [] = []
breakFst (((k0, b0), v0) : rest) = collect k0 [(b0, v0)] rest
  where
    -- @prev@ is the first key component of the most recently consumed entry.
    -- @run@ holds the current run's entries in reverse order.
    collect prev run entries = case entries of
      [] -> [(prev, reverse run)]
      (((k, b), v) : more)
        | prev == k -> collect k ((b, v) : run) more
        | otherwise -> (prev, reverse run) : collect k [(b, v)] more