{-# LANGUAGE UnboxedTuples, BangPatterns, TypeFamilies, PatternGuards, MagicHash, CPP #-}
{-# OPTIONS -funbox-strict-fields #-}
module Data.TrieMap.IntMap () where

import Data.TrieMap.TrieKey
import Data.TrieMap.Sized

import Control.Applicative
import Control.Monad hiding (join)

import Data.Bits
import Data.Maybe hiding (mapMaybe)
import Data.Word

import Prelude hiding (lookup, null, foldl, foldr)

#include "MachDeps.h"
-- complement32: bitwise complement of a Word32.  The implementation is picked
-- by CPP according to the machine word size from MachDeps.h.
#if WORD_SIZE_IN_BITS == 32
-- 32-bit words: apply the not# primop directly to the underlying Word#.
import GHC.Prim
import GHC.Word

complement32 (W32# w#) = W32# (not# w#)
#elif WORD_SIZE_IN_BITS > 32
-- Wider words: xor with an all-ones Word32 mask.  NOTE(review): this relies on
-- (bit 32 :: Word32) overflowing to 0, so that (bit 32 - 1) wraps round to
-- all ones; maxBound would state the intent directly -- confirm.
complement32 = xor (bit 32 - 1)
#else
-- Words narrower than 32 bits: use the dedicated 32-bit primop not32#.
import GHC.Prim
import GHC.IntWord32
complement32 (W32# w#) = W32# (not32# w#)
#endif
complement32 :: Word32 -> Word32

-- Rewrite uses of complement at Word32 (for example in maskW) to complement32.
-- NOTE(review): newer GHC warns that rules whose left-hand side is a class
-- method may never fire; confirm this rule still has an effect.
{-# RULES
	"complement/Word32" complement = complement32
	#-}

-- Unsigned word used by the bit-manipulation helpers (shiftRL, maskW,
-- highestBitMask).  It is Word32 regardless of the machine word size.
type Nat = Word32

-- Keys, branch prefixes and branch masks are all plain Word32 values.
-- A Prefix keeps only the key bits above a branching bit (see mask/maskW);
-- a Mask is the single branching bit produced by branchMask.
type Prefix = Word32
type Mask   = Word32
type Key    = Word32
-- Unboxed size cached in every Tip (getSize# of its value) and Bin (the sum
-- of both children, see bin).
type Size   = Int#

-- Zipper context leading from a focused position in a trie back to the root.
-- Each frame records the Bin (prefix and mask) that was descended through,
-- the remaining path upwards, and the sibling subtree that was not entered:
--   LeftBin  : descended into the left child; the right sibling is kept.
--   RightBin : descended into the right child; the left sibling is kept.
data Path a = Root 
	| LeftBin !Prefix !Mask !(Path a) !(TrieMap Word32 a)
	| RightBin !Prefix !Mask !(TrieMap Word32 a) !(Path a)

instance TrieKey Word32 where
	-- Big-endian Patricia trie on Word32 keys.  Every Tip and Bin caches its
	-- Size; a Bin also stores the common Prefix of its keys and the single-bit
	-- Mask it branches on.  Because bin never builds a Bin with a Nil child,
	-- Nil only ever appears as the whole (empty) map.
	data TrieMap Word32 a = Nil
              | Tip !Size !Key a
              | Bin !Size !Prefix !Mask !(TrieMap Word32 a) !(TrieMap Word32 a)
	-- A hole is a key plus the zipper Path from its position up to the root;
	-- the key may or may not currently be present in the map.
        data Hole Word32 a = Hole !Key !(Path a)
	-- Whole-map operations delegate to the top-level functions of this module.
	emptyM = Nil
	singletonM = singleton
	nullM = null
	sizeM = size
	lookupM = lookup
	traverseWithKeyM = traverseWithKey
	foldrWithKeyM = foldr
	foldlWithKeyM = foldl
	mapWithKeyM = mapWithKey
	mapMaybeM = mapMaybe
	mapEitherM = mapEither
	unionM = unionWithKey
	isectM = intersectionWithKey
	diffM = differenceWithKey
-- 	extractM  f = extract  f
	isSubmapM = isSubmapOfBy
	
	-- A hole for k with an empty (Root) context; keyM reads its key back.
	singleHoleM k = Hole k Root
	keyM (Hole k _) = k
	-- beforeM keeps everything ordered before the hole: siblings stored in
	-- RightBin frames lie to the left and are re-attached, LeftBin siblings
	-- are dropped.  The optional value a, when present, sits at the key.
	-- afterM is the mirror image.
	beforeM  a (Hole k path) = before (singletonMaybe  k a) path where
		before t Root = t
		before t (LeftBin _ _ path _) = before t path
		before t (RightBin p m l path) = before (bin p m l t) path
	afterM  a (Hole k path) = after (singletonMaybe  k a) path where
		after t Root = t
		after t (RightBin _ _ _ path) = after t path
		after t (LeftBin p m path r) = after (bin p m t r) path
	-- searchM walks down from the root towards k.  A hit returns the value and
	-- the path to its Tip.  A miss at a Bin whose prefix does not match k, or
	-- at a Tip holding a different key, returns a path extended by branchHole,
	-- so that assigning into the hole places k beside that subtree.  The
	-- empty map yields the Root path.  NOTE(review): onUnboxed is defined
	-- outside this file; presumably it pairs the result with Hole k.
	searchM !k = onUnboxed (Hole k) (search Root) where
		search path t@(Bin _ p m l r)
			| nomatch k p m	= (# Nothing, branchHole k p path t #)
			| zero k m
				= search (LeftBin p m path r) l
			| otherwise
				= search (RightBin p m l path) r
		search path t@(Tip _ ky y)
			| k == ky	= (# Just y, path #)
			| otherwise	= (# Nothing, branchHole k ky path t #)
		search path _ = (# Nothing, path #)
	-- indexM finds the element at position i#, weighting each subtree by its
	-- cached size and going left while i# is below the left size.  It returns
	-- the offset remaining inside the chosen Tip, its value, and a hole at its
	-- key.  There is no bounds check: a too-large i# ends at the rightmost
	-- Tip, and an empty map yields three error thunks.
	-- NOTE(review): the guard uses (<#) as a Bool, as in GHC before 7.8;
	-- newer GHC returns Int# from (<#) and would need isTrue#.
	indexM i# t = indexT i# t Root where
		indexT _ Nil _ = (# error err, error err, error err #) where
			err = "Error: empty trie"
		indexT i# (Tip _ kx x) path = (# i#, x, Hole kx path #)
		indexT i# (Bin _ p m l r) path
			| i# <# sl#	= indexT i# l (LeftBin p m path r)
			| otherwise	= indexT (i# -# sl#) r (RightBin p m l path)
			where !sl# = size l
	-- extractHoleM offers every (value, hole) pair in an arbitrary MonadPlus,
	-- the left subtree before the right one; Nil contributes mzero.
	extractHoleM = extractHole Root where
		extractHole _ Nil = mzero
		extractHole path (Tip _ kx x) = return (x, Hole kx path)
		extractHole path (Bin _ p m l r) =
			extractHole (LeftBin p m path r) l `mplus`
				extractHole (RightBin p m l path) r
	-- assignM stores v at the hole key and rebuilds the spine up to the root.
	assignM v (Hole kx path) = assign (singleton kx v) path where
		assign t Root = t
		assign t (LeftBin p m path r) = assign (bin p m t r) path
		assign t (RightBin p m l path) = assign (bin p m l t) path
	
	-- clearM empties the hole position and rebuilds the spine; bin collapses
	-- any frame whose child became Nil into its sibling.
	clearM (Hole _ path) = clear Nil path where
		clear t Root = t
		clear t (LeftBin p m path r) = clear (bin p m t r) path
		clear t (RightBin p m l path) = clear (bin p m l t) path

-- | Push a fresh branch frame that separates key @k@ from the existing
-- subtree @t@ whose prefix (or key, for a Tip) is @p@.  The frame branches on
-- the highest bit where @k@ and @p@ differ; @k@ takes the side selected by
-- that bit and @t@ becomes the sibling stored in the frame.
branchHole :: Key -> Prefix -> Path a -> TrieMap Word32 a -> Path a
branchHole !k !p path t =
  let bitM = branchMask k p
      pre  = mask k bitM
  in if zero k bitM
       then LeftBin pre bitM path t
       else RightBin pre bitM t path

-- | View a key-sized word as a 'Nat'.  Both types are 'Word32', so this is
-- the identity (presumably kept from Data.IntMap, where they differ).
natFromInt :: Word32 -> Nat
natFromInt w = w

-- | View a 'Nat' as a key-sized word; the identity, since both are 'Word32'.
intFromNat :: Nat -> Word32
intFromNat n = n

-- | Logical right shift of a 'Nat'.  The shift amount arrives as a 'Key' and
-- is converted with 'fromIntegral' before being handed to 'shiftR'.
-- (An unboxed shiftRL# variant was once sketched here and left disabled.)
shiftRL :: Nat -> Key -> Nat
shiftRL x = shiftR x . fromIntegral

-- | Cached unboxed size of a trie: 0# for Nil, otherwise the size stored in
-- the root node.
size :: TrieMap Word32 a -> Int#
size t = case t of
  Tip sz _ _     -> sz
  Bin sz _ _ _ _ -> sz
  Nil            -> 0#

-- | True only for the empty trie.  Because 'bin' never builds a Bin with a
-- Nil child, Nil is the sole empty representation.
null :: TrieMap Word32 a -> Bool
null t = case t of
  Nil -> True
  _   -> False

-- | Find the value stored at a key.  The descent follows only the branching
-- bits (prefixes are not checked on the way down); the key comparison at the
-- final Tip rejects misses.
lookup :: Nat -> TrieMap Word32 a -> Maybe a
lookup k = go
  where
    go (Bin _ _ m l r)
      | zeroN k m = go l
      | otherwise = go r
    go (Tip _ kx x)
      | kx == k = Just x
    go _ = Nothing

-- | One-element trie; the Tip caches the value's 'getSize#'.
singleton :: Sized a => Key -> a -> TrieMap Word32 a
singleton key val = case getSize# val of
  sz# -> Tip sz# key val

-- | One-element trie for @Just v@; the empty trie for 'Nothing'.
singletonMaybe :: Sized a => Key -> Maybe a -> TrieMap Word32 a
singletonMaybe k mv = case mv of
  Nothing -> Nil
  Just v  -> singleton k v

-- | Applicative traversal over (key, value) pairs.  Effects run for the left
-- subtree before the right one, and the result is rebuilt with 'bin' and
-- 'singleton' so that sizes are recomputed for the new values.
traverseWithKey :: (Applicative f, Sized b) => (Key -> a -> f b) -> TrieMap Word32 a -> f (TrieMap Word32 b)
traverseWithKey f = go
  where
    go Nil             = pure Nil
    go (Tip _ kx x)    = fmap (singleton kx) (f kx x)
    go (Bin _ p m l r) = liftA2 (bin p m) (go l) (go r)

-- | Right fold over (key, value) pairs with the left subtree outermost:
-- @foldr f t z = f k1 v1 (f k2 v2 (... z))@.
foldr :: (Key -> a -> b -> b) -> TrieMap Word32 a -> b -> b
foldr f = go
  where
    go (Bin _ _ _ l r) = go l . go r
    go (Tip _ k x)     = f k x
    go Nil             = id

-- | Left fold over (key, value) pairs, visiting the left subtree first:
-- @foldl f t z = f kn (... (f k1 z v1) ...) vn@.
foldl :: (Key -> b -> a -> b) -> TrieMap Word32 a -> b -> b
foldl f = go
  where
    go (Bin _ _ _ l r) = go r . go l
    go (Tip _ k x)     = \z -> f k z x
    go Nil             = id

-- | Map every value (with its key), keeping the trie shape; node sizes are
-- recomputed from the new values.
mapWithKey :: Sized b => (Key -> a -> b) -> TrieMap Word32 a -> TrieMap Word32 b
mapWithKey f = go
  where
    go t = case t of
      Bin _ p m l r -> bin p m (go l) (go r)
      Tip _ kx x    -> singleton kx (f kx x)
      _             -> Nil

-- | Map every value (with its key) and drop entries for which the function
-- returns 'Nothing'; 'bin' prunes branches that become empty.
mapMaybe :: Sized b => (Key -> a -> Maybe b) -> TrieMap Word32 a -> TrieMap Word32 b
mapMaybe f = go
  where
    go (Bin _ p m l r) = bin p m (go l) (go r)
    go (Tip _ kx x)    = singletonMaybe kx (f kx x)
    go _               = Nil

-- | Split a trie in one pass into two tries, as directed per entry by the
-- 'EitherMap' function (applied through 'both' at each Tip).  Both results
-- keep the original branch structure, pruned by 'bin' where a side is empty.
mapEither :: (Sized b, Sized c) => EitherMap Key a b c ->
  TrieMap Word32 a -> (# TrieMap Word32 b, TrieMap Word32 c #)
mapEither f t = case t of
  Bin _ p m l r -> case mapEither f l of
    (# lL, lR #) -> case mapEither f r of
      (# rL, rR #) -> (# bin p m lL rL, bin p m lR rR #)
  Tip _ kx x -> both (singletonMaybe kx) (singletonMaybe kx) (f kx) x
  Nil -> (# Nil, Nil #)

-- | Union of two tries.  A key found on only one side is kept.  For a key
-- found on both sides, the combining function is called as
-- @f key leftValue rightValue@ and its 'Maybe' result is passed to 'alterM'
-- (presumably 'Nothing' drops the key -- confirm against 'alterM').
unionWithKey :: Sized a => UnionFunc Key a -> TrieMap Word32 a -> TrieMap Word32 a -> TrieMap Word32 a
unionWithKey f = go
  where
    go Nil t2 = t2
    go t1 Nil = t1
    -- A singleton on either side is merged by a point update of the other.
    go (Tip _ k x) t2 = alterM (maybe (Just x) (f k x)) k t2
    go t1 (Tip _ k y) = alterM (maybe (Just y) (flip (f k) y)) k t1
    go t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
      | shorter m1 m2 = intoFirst
      | shorter m2 m1 = intoSecond
      | p1 /= p2      = join p1 t1 p2 t2
      | otherwise     = bin p1 m1 (go l1 l2) (go r1 r2)
      where
        -- t1 branches on a more significant bit: t2 goes into one of t1's
        -- children, or beside t1 when the prefixes disagree.
        intoFirst
          | nomatch p2 p1 m1 = join p1 t1 p2 t2
          | zero p2 m1       = bin p1 m1 (go l1 t2) r1
          | otherwise        = bin p1 m1 l1 (go r1 t2)
        -- Mirror image: t1 goes into one of t2's children, or beside t2.
        intoSecond
          | nomatch p1 p2 m2 = join p1 t1 p2 t2
          | zero p1 m2       = bin p2 m2 (go t1 l2) r2
          | otherwise        = bin p2 m2 l2 (go t1 r2)

-- | Intersection of two tries.  Only keys present on both sides survive; for
-- each, @f key leftValue rightValue@ decides the result, and a 'Nothing'
-- drops the key.
intersectionWithKey :: Sized c => IsectFunc Key a b c -> TrieMap Word32 a -> TrieMap Word32 b -> TrieMap Word32 c
intersectionWithKey f = go
  where
    go Nil _ = Nil
    go _ Nil = Nil
    -- A singleton on either side reduces to a lookup in the other.
    go (Tip _ k x) t2 = singletonMaybe k (lookup (natFromInt k) t2 >>= f k x)
    go t1 (Tip _ k y) = singletonMaybe k (lookup (natFromInt k) t1 >>= flip (f k) y)
    go t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
      | shorter m1 m2 =
          -- t2 can only overlap the child of t1 selected by its prefix.
          if nomatch p2 p1 m1 then Nil
          else go (if zero p2 m1 then l1 else r1) t2
      | shorter m2 m1 =
          if nomatch p1 p2 m2 then Nil
          else go t1 (if zero p1 m2 then l2 else r2)
      | p1 == p2  = bin p1 m1 (go l1 l2) (go r1 r2)
      | otherwise = Nil

-- | Difference of two tries.  Keys found only in the first trie are kept
-- unchanged.  For a key found in both, @f key leftValue rightValue@ either
-- supplies a replacement value or, with 'Nothing', removes the key.
differenceWithKey :: Sized a => (Key -> a -> b -> Maybe a) -> TrieMap Word32 a -> TrieMap Word32 b -> TrieMap Word32 a
differenceWithKey f = go
  where
    go Nil _ = Nil
    go t1 Nil = t1
    go t1@(Tip _ k x) t2 = case lookup (natFromInt k) t2 of
      Nothing -> t1
      Just y  -> singletonMaybe k (f k x y)
    -- A singleton on the right is a point update of the left trie.
    go t1 (Tip _ k y) = alterM (>>= flip (f k) y) k t1
    go t1@(Bin _ p1 m1 l1 r1) t2@(Bin _ p2 m2 l2 r2)
      | shorter m1 m2 =
          if nomatch p2 p1 m1 then t1
          else if zero p2 m1 then bin p1 m1 (go l1 t2) r1
          else bin p1 m1 l1 (go r1 t2)
      | shorter m2 m1 =
          if nomatch p1 p2 m2 then t1
          else go t1 (if zero p1 m2 then l2 else r2)
      | p1 == p2  = bin p1 m1 (go l1 l2) (go r1 r2)
      | otherwise = t1

-- | @isSubmapOfBy le t1 t2@: every key of @t1@ is also in @t2@, and for each
-- such key @le v1 v2@ holds for the two values.
isSubmapOfBy :: LEq a b -> LEq (TrieMap Word32 a) (TrieMap Word32 b)
isSubmapOfBy le = go
  where
    go Nil _ = True
    go (Tip _ k x) t = maybe False (le x) (lookup (natFromInt k) t)
    go t1@(Bin _ p1 m1 l1 r1) t2 = case t2 of
      Bin _ p2 m2 l2 r2
        -- t1 spans a wider key range than t2: it cannot fit inside.
        | shorter m1 m2 -> False
        | shorter m2 m1 ->
            match p1 p2 m2 && go t1 (if zero p1 m2 then l2 else r2)
        | otherwise     -> p1 == p2 && go l1 l2 && go r1 r2
      -- A Bin holds at least two entries, so it never fits in a Tip or Nil.
      _ -> False

-- extract :: Alternative f => Sized a -> (Key -> a -> f (x, Maybe a)) -> TrieMap Word32 a -> f (x, TrieMap Word32 a)
-- extract  f (Bin _ p m l r)	= 
-- 	fmap (\ l' -> bin p m l' r) <$> extract  f l <|> fmap (bin p m l) <$> extract  f r
-- extract  f (Tip _ k x)		= fmap (singletonMaybe  k) <$> f k x
-- extract _ _ _			= empty

-- | Prefix of key @i@ relative to branching bit @m@ (see 'maskW').
mask :: Key -> Mask -> Prefix
mask i = maskW (natFromInt i) . natFromInt

-- | Does key @i@ have the branching bit(s) of @m@ clear, i.e. belong to the
-- left child?
zero :: Key -> Mask -> Bool
zero i m = 0 == natFromInt i .&. natFromInt m

-- | @match i p m@: key @i@, cut down to the bits above branching bit @m@,
-- equals prefix @p@.  'nomatch' is its negation.
nomatch, match :: Key -> Prefix -> Mask -> Bool
match i p m = p == mask i m

nomatch i p m = not (match i p m)

-- | True when @i@ and @m@ share no set bits.
zeroN :: Nat -> Nat -> Bool
zeroN i m = case i .&. m of
  0 -> True
  _ -> False

-- | For a single-bit @m@, keep only the bits of @i@ strictly above @m@,
-- clearing @m@ itself and everything below it.
maskW :: Nat -> Nat -> Prefix
maskW i m = intFromNat (i .&. highBits)
  where
    highBits = complement (m - 1) `xor` m

-- | @shorter m1 m2@: does @m1@ branch on a more significant bit than @m2@
-- (so that the node it belongs to spans a larger key range)?
shorter :: Mask -> Mask -> Bool
shorter m1 m2 = natFromInt m2 < natFromInt m1

-- | Single-bit mask of the most significant bit in which two prefixes
-- differ (0 when they are equal).
branchMask :: Prefix -> Prefix -> Mask
branchMask p1 p2 = intFromNat . highestBitMask $ natFromInt p1 `xor` natFromInt p2

-- | Isolate the most significant set bit of a 'Nat'; 0 maps to 0.
--
-- The shift-and-or steps smear the highest set bit into every lower position.
-- Xor-ing the smeared value with itself shifted right by one then clears all
-- but that top bit.  'Nat' is 'Word32' regardless of the machine word size,
-- so the five doubling shifts (1, 2, 4, 8, 16) always cover every bit.  No
-- platform-specific step is needed; a shift by 32 on a 'Word32' could only
-- ever contribute 0.
highestBitMask :: Nat -> Nat
highestBitMask x0 = x5 `xor` shiftRL x5 1
  where
    x1 = x0 .|. shiftRL x0 1
    x2 = x1 .|. shiftRL x1 2
    x3 = x2 .|. shiftRL x2 4
    x4 = x3 .|. shiftRL x3 8
    x5 = x4 .|. shiftRL x4 16

-- | Put two tries with prefixes @p1@ and @p2@ under a new Bin that branches on
-- the highest bit where the prefixes differ.  The trie whose prefix has that
-- bit clear goes on the left.  Intended for distinct prefixes.
join :: Prefix -> TrieMap Word32 a -> Prefix -> TrieMap Word32 a -> TrieMap Word32 a
join p1 t1 p2 t2 = case branchMask p1 p2 of
  m | zero p1 m -> bin (mask p1 m) m t1 t2
    | otherwise -> bin (mask p1 m) m t2 t1

-- | Smart Bin constructor.  A Nil child is elided by returning the other
-- child; otherwise the new node caches the sum of both child sizes.
bin :: Prefix -> Mask -> TrieMap Word32 a -> TrieMap Word32 a -> TrieMap Word32 a
bin p m l r = case r of
  Nil -> l
  _   -> case l of
    Nil -> r
    _   -> Bin (size l +# size r) p m l r