-- | Generalized tries. \"Normal\" tries encode finite maps from lists to arbitrary values,
-- sharing the common prefixes of the keys. Here we do the same for trees (@Mu f@), generically.
--
-- See also
--
-- * Connelly, Morris: A generalization of the trie data structure
--
-- * Ralf Hinze: Generalizing Generalized Tries
--
-- This module should be imported qualified.
--

{-# LANGUAGE CPP #-}
module Data.Generics.Fixplate.Trie
  ( Trie
    -- * Construction \/ deconstruction
  , empty , singleton
  , fromList , toList
    -- * Multisets
  , bag , universeBag
  , christmasTree
    -- * Lookup
  , lookup
    -- * Insertion \/ deletion
  , insert , insertWith
  , delete , update
    -- * Set operations
  , intersection , intersectionWith
  , union        , unionWith
  , difference   , differenceWith
  )
  where

---------------------------------------------------------------------------------

import Prelude hiding ( lookup )

import Data.Generics.Fixplate.Base
import Data.Generics.Fixplate.Open hiding ( toList )
import Data.Generics.Fixplate.Traversals ( universe )

import qualified Data.Foldable as Foldable

import Data.Foldable    ()
import Data.Traversable ()

import qualified Data.Map as Map ; import Data.Map (Map)

---------------------------------------------------------------------------------

-- | Creates a trie-multiset from a list of trees: every distinct tree is mapped to
-- the number of times it occurs in the list.
--
-- A strict left fold is used so that each intermediate trie is built as we go,
-- instead of accumulating one long chain of suspended insertions.
bag :: (Functor f, Foldable f, OrdF f) => [Mu f] -> Trie f Int
bag ts = Foldable.foldl' worker emptyTrie ts where
  -- first occurrence stores 1 (via 'id'), later ones add 1 to the old count
  worker trie tree = trieInsertWith id (+) tree 1 trie

-- | The multiset of all subtrees of a tree, as enumerated by 'universe', with each
-- subtree mapped to its number of occurrences. It is defined as
--
-- > universeBag == bag . universe
--
-- TODO: more efficient implementation? and better name
universeBag :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f Int
universeBag tree = bag (universe tree)

-- | Attributes every node of a tree with the multiset of all subtrees rooted at or
-- below that node (the node itself included), counted with multiplicity.
--
-- Each node's multiset is obtained by merging the multisets already computed for its
-- children (adding the counts) and then inserting the node itself once more.
--
-- TODO: better name
christmasTree :: (Functor f, Foldable f, OrdF f) => Mu f -> Attr f (Trie f Int)
christmasTree = go where
  go this@(Fix t) = Fix (Ann (ins us) sub) where
    sub = fmap go t
    -- strict left fold: merge the children's bags eagerly rather than as nested thunks
    us  = Foldable.foldl' (trieUnionWith (+)) emptyTrie (fmap attribute sub)
    ins = trieInsertWith id (+) this 1

---------------------------------------------------------------------------------

-- | The trie with no entries.
empty :: (Functor f, Foldable f, OrdF f) => Trie f a
empty = Trie Map.empty

-- | A trie with exactly one entry, mapping the given tree to the given value.
singleton :: (Functor f, Foldable f, OrdF f) => Mu f -> a -> Trie f a
singleton tree value = trieSingleton tree value

-- | The value stored for a tree, or 'Nothing' if the tree is not a key of the trie.
lookup :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f a -> Maybe a
lookup tree trie = trieLookup tree trie

-- | Inserts a tree together with its value; if the tree is already a key, its old
-- value is replaced. The value is forced to weak head normal form on insertion.
insert :: (Functor f, Foldable f, OrdF f) => Mu f -> a -> Trie f a -> Trie f a
insert tree value trie = trieInsertWith id (\new _old -> new) tree value trie

-- | @insertWith new combine tree x trie@ inserts @tree@ with a value derived from @x@.
--
-- * If @tree@ is not yet a key, it is stored with @new x@.
--
-- * Otherwise its existing value @old@ is replaced by @combine x old@.
insertWith :: (Functor f, Foldable f, OrdF f) => (a -> b) -> (a -> b -> b) -> Mu f -> a -> Trie f b -> Trie f b
insertWith new combine tree x trie = trieInsertWith new combine tree x trie

-- | Applies a function to the value stored for a tree. If the function returns
-- 'Nothing', the entry is removed. Trees that are not keys are left untouched.
update :: (Functor f, Foldable f, OrdF f) => (a -> Maybe a) -> Mu f -> Trie f a -> Trie f a
update f tree trie = trieUpdate f tree trie

-- | Removes a tree and its value from the trie; a no-op if the tree is absent.
delete :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f a -> Trie f a
delete tree trie = trieUpdate (\_ -> Nothing) tree trie

-- | Builds a trie from a list of @(tree, value)@ pairs. When the same tree occurs
-- more than once, the value from the later pair wins (as with repeated 'insert').
--
-- A strict left fold keeps the intermediate tries evaluated as the list is consumed.
--
-- TODO: more efficient implementation?
fromList :: (Traversable f, OrdF f) => [(Mu f, a)] -> Trie f a
fromList ts = Foldable.foldl' worker emptyTrie ts where
  worker trie (tree,value) = trieInsertWith id const tree value trie

-- | All @(tree, value)@ entries of the trie. The order is determined by the trie's
-- internal structure (ascending 'OrdF' order of root shapes), not by insertion order.
toList :: (Traversable f, OrdF f) => Trie f a -> [(Mu f, a)]
toList trie = trieToList trie

-- | Left-biased intersection: keeps the trees that are keys of both tries, with the
-- values taken from the first one.
--
-- > intersection == intersectionWith const
intersection :: (Functor f, Foldable f, OrdF f) => Trie f a -> Trie f b -> Trie f a
intersection t1 t2 = trieIntersectionWith (\x _ -> x) t1 t2

-- | Intersection of two tries. For every tree that is a key of both, the two values
-- are combined with the given function (first trie's value as first argument).
intersectionWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> c) -> Trie f a -> Trie f b -> Trie f c
intersectionWith f t1 t2 = trieIntersectionWith f t1 t2

-- | Left-biased union: when a tree is a key of both tries, the value from the
-- first trie is kept.
--
-- > union == unionWith const
--
union :: (Functor f, Foldable f, OrdF f) => Trie f a -> Trie f a -> Trie f a
union t1 t2 = trieUnionWith (\x _ -> x) t1 t2

-- | Union of two tries. Values of trees that are keys of both are combined with the
-- given function (first trie's value as first argument).
unionWith :: (Functor f, Foldable f, OrdF f) => (a -> a -> a) -> Trie f a -> Trie f a -> Trie f a
unionWith f t1 t2 = trieUnionWith f t1 t2

-- | Removes from the first trie every tree that is also a key of the second one.
difference :: (Functor f, Foldable f, OrdF f) => Trie f a -> Trie f b -> Trie f a
difference t1 t2 = trieDifferenceWith dropBoth t1 t2 where
  dropBoth _ _ = Nothing

-- | @differenceWith f t1 t2@ keeps unchanged the trees of @t1@ that are absent from @t2@.
-- For trees present in both, @f a b@ decides:
--
-- * 'Nothing' removes the entry.
--
-- * @Just a'@ replaces its value with @a'@.
differenceWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> Maybe a) -> Trie f a -> Trie f b -> Trie f a
differenceWith f t1 t2 = trieDifferenceWith f t1 t2

---------------------------------------------------------------------------------

-- | 'Trie' is an efficient(?) implementation of finite maps from @(Mu f)@ to an arbitrary type @v@.
--
-- A key tree is split into two parts:
--
-- * its root shape: the root layer with every child replaced by 'Hole', wrapped in 'HoleF';
--
-- * the list of its children, in 'Foldable.toList' order.
--
-- The outer 'Map' is indexed by the shape. The associated 'Chain' then indexes the
-- children one at a time, with one nested 'Trie' level per child.
newtype Trie f v = Trie { unTrie :: Map (HoleF f) (Chain f v) }

-- | What remains of a key once its root shape has been matched.
--
-- A chain stored under a shape with @n@ children consists of @n@ nested 'Chain' levels
-- ending in a 'Value'. Each level is a 'Trie' keyed by the next child, whose values
-- are the rest of the chain. The chain functions in this module treat any other
-- depth as an internal error (their @shouldn't happen@ cases).
data Chain f v
  = Value v                      -- all children matched: the stored value
  | Chain (Trie f (Chain f v))   -- match the next child, then continue with the inner chain

-- | The shape of a tree layer: the layer with each child replaced by 'Hole'.
-- The wrapper exists only so that shapes get 'Eq' / 'Ord' instances (via 'EqF' /
-- 'OrdF') and can therefore be used as 'Map' keys.
newtype HoleF f = HoleF { unHoleF :: f Hole }

instance EqF f => Eq (HoleF f) where
  HoleF a == HoleF b = equalF a b

instance OrdF f => Ord (HoleF f) where
  compare (HoleF a) (HoleF b) = compareF a b

-- | Internal empty trie: no root shapes, hence no entries.
emptyTrie :: (Functor f, Foldable f, OrdF f) => Trie f v
emptyTrie = Trie Map.empty

---------------------------------------------------------------------------------

-- | Finds the value stored for a tree. It first selects the chain by the root's
-- shape, then follows that chain with the root's children.
trieLookup :: (Functor f, Foldable f, OrdF f) => Mu f -> Trie f v -> Maybe v
trieLookup (Fix layer) (Trie byShape) =
    Map.lookup (HoleF shape) byShape >>= chainLookup (Foldable.toList layer)
  where
    shape = fmap (const Hole) layer

-- | Follows a 'Chain' through the remaining children of a key, one nested 'Trie'
-- per child. The errors signal a chain whose depth does not match the number of
-- children, which is an internal invariant violation.
chainLookup :: (Functor f, Foldable f, OrdF f) => [Mu f] -> Chain f v -> Maybe v
chainLookup []     (Value x)   = Just x
chainLookup []     _           = error "chainLookup: shouldn't happen #1"
chainLookup (k:ks) (Chain sub) = trieLookup k sub >>= chainLookup ks
chainLookup (_:_)  (Value _)   = error "chainLookup: shouldn't happen #2"

---------------------------------------------------------------------------------

-- | Builds the chain for a fresh key whose remaining children are @trees@. The chain
-- has one single-entry nested trie per child and ends in @Value x@.
chainSingleton :: (Functor f, Foldable f, OrdF f) => [Mu f] -> a -> Chain f a
chainSingleton trees x = foldr wrap (Value x) trees where
  wrap child rest = Chain (trieSingleton child rest)

-- | A trie holding one entry. The root shape of the tree is the only outer key, and
-- the children are encoded in a fresh chain.
trieSingleton :: (Functor f, Foldable f, OrdF f) => Mu f -> a -> Trie f a
trieSingleton (Fix layer) x = Trie (Map.singleton key chain)
  where
    key   = HoleF (fmap (const Hole) layer)
    chain = chainSingleton (Foldable.toList layer) x

---------------------------------------------------------------------------------

-- | Variant of 'Map.insertWith' that uses a separate function for a fresh key:
-- @onNew x@ when @key@ is absent, @onOld x old@ when it is present.
--
-- Strict: the inserted value, the existing value and the newly stored value are
-- all evaluated to weak head normal form.
mapInsertWith :: Ord k => (a -> v) -> (a -> v -> v) -> k -> a -> Map k v ->  Map k v
mapInsertWith onNew onOld key new = new `seq` Map.alter step key where
  step existing = case existing of
    Nothing  -> Just $! onNew new
    Just old -> old `seq` (Just $! onOld new old)

-- | Core insertion, keyed on the root shape of the tree.
--
-- * For a new shape, a fresh chain holding @uf value@ is created.
--
-- * For an existing shape, the insertion continues into its chain, where @ug@
--   combines the value with any existing one.
trieInsertWith :: (Functor f, Foldable f, OrdF f) => (a -> b) -> (a -> b -> b) -> Mu f -> a -> Trie f b -> Trie f b
trieInsertWith uf ug (Fix layer) value (Trie m) =
    Trie (mapInsertWith fresh combine key value m)
  where
    children      = Foldable.toList layer
    key           = HoleF (fmap (const Hole) layer)
    fresh z       = chainSingleton children (uf z)
    combine z old = chainInsertWith uf ug children z old

-- | Inserts into a 'Chain' by walking the remaining children @trees@ of the key, one
-- nested 'Trie' level per child.
--
-- * A missing suffix is created from @uf x@.
--
-- * An existing value @old@ at the end is replaced by @ug x old@.
--
-- The rebuilt 'Chain' node and the combined value are forced to WHNF. Otherwise,
-- repeated 'insertWith' calls on the same key (e.g. the counting in 'bag') would
-- accumulate a growing chain of unevaluated @ug@ applications inside the trie.
--
-- The @shouldn't happen@ cases correspond to a chain whose depth does not match the
-- number of children of its key. That should never arise from the operations in
-- this module.
chainInsertWith :: (Functor f, Foldable f, OrdF f) => (a -> b) -> (a -> b -> b) -> [Mu f] -> a -> Chain f b -> Chain f b
chainInsertWith uf ug trees x chain = go trees chain where
  go []       (Value y)    = Value $! ug x y
  go []       (Chain _)    = error "chainInsertWith: shouldn't happen #1"
  go (t:rest) (Chain trie) = Chain $! trieInsertWith fresh combine t x trie
    where
      fresh z     = chainSingleton rest (uf z)
      combine z c = chainInsertWith uf ug rest z c
  go (_:_)    (Value _)    = error "chainInsertWith: shouldn't happen #2"

---------------------------------------------------------------------------------

-- | Updates (or deletes, when @user@ yields 'Nothing') the entry for a tree.
-- The lookup uses the root shape first, then the root's children.
-- An absent tree leaves the trie unchanged.
trieUpdate :: (Functor f, Foldable f, OrdF f) => (a -> Maybe a) -> Mu f -> Trie f a -> Trie f a
trieUpdate user (Fix layer) (Trie m) = Trie (Map.update step key m)
  where
    key  = HoleF (fmap (const Hole) layer)
    step = chainUpdate user (Foldable.toList layer)

-- | Updates the value at the end of a 'Chain', following the remaining children
-- @trees@ of the key one nested 'Trie' level at a time.
--
-- Returns 'Nothing' when the entry is removed. If the removal leaves a nested trie
-- empty, that level is reported as removed too ('Nothing'), so the caller's
-- 'Map.update' drops it. Without this, 'delete' would keep an empty sub-trie alive.
--
-- The @shouldn't happen@ cases correspond to a chain whose depth does not match the
-- number of children of its key.
chainUpdate :: (Functor f, Foldable f, OrdF f) => (a -> Maybe a) -> [Mu f] -> Chain f a -> Maybe (Chain f a)
chainUpdate user = go where
  go []       (Value x)    = fmap Value (user x)
  go []       (Chain _)    = error "chainUpdate: shouldn't happen #1"
  go (t:rest) (Chain trie) =
    let trie' = trieUpdate (go rest) t trie
    in  if Map.null (unTrie trie') then Nothing else Just (Chain trie')
  go (_:_)    (Value _)    = error "chainUpdate: shouldn't happen #2"

---------------------------------------------------------------------------------

-- | Enumerates all entries. For each root shape, every path through its chain is
-- listed, and the key tree is rebuilt by filling the shape's holes with that path's
-- children.
trieToList :: (Traversable f, OrdF f) => Trie f a -> [(Mu f, a)]
trieToList (Trie m) = do
  (HoleF shape, chain) <- Map.toList m
  (children, val)      <- chainToList chain
  return (Fix (builder shape children), val)

-- | Enumerates every path through a 'Chain' as the list of children matched along
-- the way, paired with the value stored at its end.
chainToList :: (Traversable f, OrdF f) => Chain f a -> [([Mu f], a)]
chainToList (Value x)    = [([], x)]
chainToList (Chain trie) =
  concatMap (\(t, sub) -> [ (t : rest, val) | (rest, val) <- chainToList sub ])
            (trieToList trie)

---------------------------------------------------------------------------------

-- | Intersection of two chains stored under the same root shape. Values at the end
-- are combined with @f@; nested levels are intersected recursively. Chains of
-- different depth cannot occur for the same shape, hence the error.
chainIntersectionWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> c) -> Chain f a -> Chain f b -> Chain f c
chainIntersectionWith f c1 c2 = case (c1, c2) of
  (Value x , Value y ) -> Value (f x y)
  (Chain t1, Chain t2) -> Chain (trieIntersectionWith (chainIntersectionWith f) t1 t2)
  _                    -> error "chainIntersectionWith: shouldn't happen"

-- | Intersection on the root shapes, recursing into the chains of the shapes both
-- tries share.
trieIntersectionWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> c) -> Trie f a -> Trie f b -> Trie f c
trieIntersectionWith f (Trie m1) (Trie m2) =
  Trie $ Map.intersectionWith (chainIntersectionWith f) m1 m2

---------------------------------------------------------------------------------

-- | Union of two chains stored under the same root shape. Values at the end are
-- combined with @f@ (left chain's value first); nested levels are merged recursively.
chainUnionWith :: (Functor f, Foldable f, OrdF f) => (a -> a -> a) -> Chain f a -> Chain f a -> Chain f a
chainUnionWith f c1 c2 = case (c1, c2) of
  (Value x , Value y ) -> Value (f x y)
  (Chain t1, Chain t2) -> Chain (trieUnionWith (chainUnionWith f) t1 t2)
  _                    -> error "chainUnionWith: shouldn't happen"

-- | Union on the root shapes. Shapes present in both tries have their chains merged.
trieUnionWith :: (Functor f, Foldable f, OrdF f) => (a -> a -> a) -> Trie f a -> Trie f a -> Trie f a
trieUnionWith f (Trie m1) (Trie m2) =
  Trie $ Map.unionWith (chainUnionWith f) m1 m2

---------------------------------------------------------------------------------

-- | Difference of two chains stored under the same root shape.
--
-- At the end of the chains, @f@ decides the outcome: 'Nothing' drops the entry and
-- @Just z@ keeps @z@. At a nested level, if nothing survives in the resulting
-- sub-trie, 'Nothing' is returned. The caller's 'Map.differenceWith' then drops the
-- key instead of keeping an empty sub-trie around.
--
-- Chains of different depth cannot occur for the same shape, hence the error.
chainDifferenceWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> Maybe a) -> Chain f a -> Chain f b -> Maybe (Chain f a)
chainDifferenceWith f (Value x ) (Value y ) = fmap Value (f x y)
chainDifferenceWith f (Chain t1) (Chain t2)
  | Map.null (unTrie rest) = Nothing
  | otherwise              = Just (Chain rest)
  where
    rest = trieDifferenceWith (chainDifferenceWith f) t1 t2
chainDifferenceWith _ _ _ = error "chainDifferenceWith: shouldn't happen"

-- | Difference on the root shapes.
--
-- * Shapes found only in the first trie are kept as they are.
--
-- * Shapes found in both tries have their chains differenced, and the shape is
--   dropped when its chain yields 'Nothing'.
trieDifferenceWith :: (Functor f, Foldable f, OrdF f) => (a -> b -> Maybe a) -> Trie f a -> Trie f b -> Trie f a
trieDifferenceWith f (Trie m1) (Trie m2) =
  Trie $ Map.differenceWith (chainDifferenceWith f) m1 m2

---------------------------------------------------------------------------------