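{-|
Module      : What4.Utils.AnnotatedMap
Description : A finite map data structure with monoidal annotations
Copyright   : (c) Galois Inc, 2019-2020
Maintainer  : huffman@galois.com

A finite map data structure with monoidal annotations.

Every entry carries a key, an annotation, and a value. The annotations of
all entries are combined in ascending key order using the annotation
type's 'Semigroup' instance, and the result is available via 'annotation'.
-}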

{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}

module What4.Utils.AnnotatedMap
( AnnotatedMap
, null
, empty
, singleton
, size
, lookup
, delete
, annotation
, toList
, fromAscList
, insert
, alter
, alterF
, union
, unionWith
, unionWithKeyMaybe
, filter
, mapMaybe
, traverseMaybeWithKey
, difference
, mergeWithKey
, mergeWithKeyM
, mergeA
, eqBy
) where

import           Data.Functor.Identity
import qualified Data.Foldable as Foldable
import           Data.Foldable (foldl')
import           Prelude hiding (null, filter, lookup)

import qualified Data.FingerTree as FT
import           Data.FingerTree ((><), (<|))

----------------------------------------------------------------------
-- Operations on FingerTrees

-- | Keep only the elements of a finger tree that satisfy the predicate,
-- preserving their left-to-right order.
filterFingerTree ::
  FT.Measured v a =>
  (a -> Bool) -> FT.FingerTree v a -> FT.FingerTree v a
filterFingerTree keep = foldl' step FT.empty
  where
    step acc x
      | keep x    = acc FT.|> x
      | otherwise = acc

-- | Apply a partial function to every element of a finger tree, dropping
-- the elements for which it returns 'Nothing'. Surviving results keep
-- their left-to-right order and are measured under the target measure.
mapMaybeFingerTree ::
  (FT.Measured v1 a1, FT.Measured v2 a2) =>
  (a1 -> Maybe a2) -> FT.FingerTree v1 a1 -> FT.FingerTree v2 a2
mapMaybeFingerTree f = foldl' step FT.empty
  where
    step acc x =
      case f x of
        Nothing -> acc
        Just y  -> acc FT.|> y

-- | Effectful version of 'mapMaybeFingerTree'. Effects run left to right;
-- elements mapped to 'Nothing' are dropped from the result.
traverseMaybeFingerTree ::
  (Applicative f, FT.Measured v1 a1, FT.Measured v2 a2) =>
  (a1 -> f (Maybe a2)) -> FT.FingerTree v1 a1 -> f (FT.FingerTree v2 a2)
traverseMaybeFingerTree f = foldl' step (pure FT.empty)
  where
    -- Sequence the accumulated effect before the effect for this element.
    step acc x = snoc <$> acc <*> f x
    snoc ys = maybe ys (ys FT.|>)

----------------------------------------------------------------------
-- Tags

-- | Measure of a finger-tree segment.
--
-- 'NoTag' measures an empty segment. @'Tag' n k v@ measures a non-empty
-- segment holding @n@ entries, whose rightmost key is @k@ and whose
-- annotations combine (left to right, via '<>') to @v@.
data Tag k v = NoTag | Tag !Int k v
-- The Int is there to support the size function.

instance (Ord k, Semigroup v) => Semigroup (Tag k v) where
  (<>) = unionTag

instance (Ord k, Semigroup v) => Monoid (Tag k v) where
  mempty  = NoTag
  -- Canonical definition: keeps 'mappend' equal to '<>' by construction
  -- and avoids GHC's -Wnoncanonical-monoid-instances warning.
  mappend = (<>)

-- | Combine the measures of two adjacent segments, left then right.
-- Entry counts add, the right segment's key is kept (it is the rightmost
-- one), and annotations are combined left-to-right with '<>'. 'NoTag' is
-- the identity on both sides.
unionTag :: (Ord k, Semigroup v) => Tag k v -> Tag k v -> Tag k v
unionTag x NoTag = x
unionTag NoTag y = y
unionTag (Tag ix _ vx) (Tag iy ky vy) =
  Tag (ix + iy) ky (vx <> vy)

----------------------------------------------------------------------

-- | A finite map from keys @k@ to values @a@ in which every entry also
-- carries an annotation @v@. The finger tree's 'Tag' measure records the
-- entry count, the rightmost (largest) key, and the '<>'-combination of
-- all annotations.
newtype AnnotatedMap k v a =
  AnnotatedMap { annotatedMap :: FT.FingerTree (Tag k v) (Entry k v a) }
-- Invariant: The entries in the fingertree must be sorted by key,
-- strictly increasing from left to right.

-- | One map entry: key, annotation, and value, in that order. The derived
-- 'Functor', 'Foldable', and 'Traversable' instances act only on the value
-- (the last type parameter).
data Entry k v a = Entry k v a
  deriving (Functor, Foldable, Traversable)

-- | Project the key out of an entry.
keyOf :: Entry k v a -> k
keyOf entry = case entry of Entry key _ _ -> key

-- | Project the annotation and value out of an entry, dropping the key.
valOf :: Entry k v a -> (v, a)
valOf entry = case entry of Entry _ ann val -> (ann, val)

-- | A single entry measures as a one-element segment tagged with its own
-- key and annotation. The value is not part of the measure.
instance (Ord k, Semigroup v) => FT.Measured (Tag k v) (Entry k v a) where
  measure entry =
    case entry of
      Entry key ann _ -> Tag 1 key ann

-- | Map over the values; keys and annotations are unchanged.
-- 'FT.unsafeFmap' is sound here because an entry's measure depends only
-- on its key and annotation, never on the value being mapped.
instance (Ord k, Semigroup v) => Functor (AnnotatedMap k v) where
  fmap f = AnnotatedMap . FT.unsafeFmap (fmap f) . annotatedMap

-- | Fold over the values in ascending key order; keys and annotations are
-- ignored.
instance (Ord k, Semigroup v) => Foldable.Foldable (AnnotatedMap k v) where
  foldr f z (AnnotatedMap ft) =
    foldr f z [ a | Entry _ _ a <- Foldable.toList ft ]

  -- The entry count is recorded in the root measure, so 'size' answers
  -- without visiting every entry, unlike the default 'length'.
  length = size

  -- Emptiness of the map is emptiness of the underlying tree.
  null (AnnotatedMap ft) = FT.null ft

-- | Traverse the values in ascending key order. 'FT.unsafeTraverse' is
-- sound because the measure of an entry does not depend on its value.
instance (Ord k, Semigroup v) => Traversable (AnnotatedMap k v) where
  traverse f = fmap AnnotatedMap . FT.unsafeTraverse (traverse f) . annotatedMap

-- | The '<>'-combination of all annotations in ascending key order, read
-- from the root measure. Returns 'Nothing' for an empty map.
annotation :: (Ord k, Semigroup v) => AnnotatedMap k v a -> Maybe v
annotation = fromTag . FT.measure . annotatedMap
  where
    fromTag NoTag         = Nothing
    fromTag (Tag _ _ ann) = Just ann

-- | The key/value pairs of the map in ascending key order; annotations are
-- dropped.
toList :: AnnotatedMap k v a -> [(k, a)]
toList (AnnotatedMap ft) =
  foldr (\(Entry k _ a) rest -> (k, a) : rest) [] ft

-- | Build a map from @(key, annotation, value)@ triples.
--
-- Precondition (not checked here): keys must be strictly increasing, as
-- required by the 'AnnotatedMap' ordering invariant. The key-based
-- operations in this module rely on that ordering.
fromAscList :: (Ord k, Semigroup v) => [(k,v,a)] -> AnnotatedMap k v a
fromAscList triples = AnnotatedMap (FT.fromList (map toEntry triples))
  where
    toEntry (k, v, a) = Entry k v a

-- | Pointwise equality of two lists under a given relation. Both lists
-- must have the same length, and each pair of corresponding elements must
-- be related.
listEqBy :: (a -> a -> Bool) -> [a] -> [a] -> Bool
listEqBy eq = go
  where
    go (x : xs) (y : ys) = eq x y && go xs ys
    go []       []       = True
    go _        _        = False

-- | Equality of two maps: identical keys in the same order, with values
-- related by the supplied function. Annotations are not compared.
eqBy :: Eq k => (a -> a -> Bool) -> AnnotatedMap k v a -> AnnotatedMap k v a -> Bool
eqBy f x y = listEqBy sameEntry (toList x) (toList y)
  where
    sameEntry (kx, ax) (ky, ay) = kx == ky && f ax ay

-- | Does the map have no entries?
null :: AnnotatedMap k v a -> Bool
null = FT.null . annotatedMap

-- | The map with no entries.
empty :: (Ord k, Semigroup v) => AnnotatedMap k v a
empty = AnnotatedMap mempty

-- | A map containing exactly one entry.
singleton :: (Ord k, Semigroup v) => k -> v -> a -> AnnotatedMap k v a
singleton k v = AnnotatedMap . FT.singleton . Entry k v

-- | Number of entries, taken from the count stored in the root measure
-- rather than by walking the map.
size :: (Ord k, Semigroup v) => AnnotatedMap k v a -> Int
size m =
  case FT.measure (annotatedMap m) of
    NoTag     -> 0
    Tag n _ _ -> n

-- | Split a key-ordered tree around @k@, returning @(before, match, after)@.
-- @before@ holds the entries with keys smaller than @k@, @match@ is the
-- entry whose key equals @k@ (if any), and @after@ holds the entries with
-- keys larger than @k@. Relies on the 'AnnotatedMap' sortedness invariant.
splitAtKey ::
  (Ord k, Semigroup v) =>
  k -> FT.FingerTree (Tag k v) (Entry k v a) ->
  ( FT.FingerTree (Tag k v) (Entry k v a)
  , Maybe (Entry k v a)
  , FT.FingerTree (Tag k v) (Entry k v a)
  )
splitAtKey k tree =
  case FT.viewl atOrAbove of
    e FT.:< above
      | k == keyOf e -> (below, Just e, above)
    _ -> (below, Nothing, atOrAbove)
  where
    -- A segment's tag carries its rightmost (largest) key, so this
    -- predicate is monotone along the tree. 'FT.split' therefore stops at
    -- the first entry whose key is >= k.
    (below, atOrAbove) = FT.split reachesKey tree
    reachesKey NoTag            = False
    reachesKey (Tag _ maxKey _) = k <= maxKey

-- | Insert an entry. Any existing entry with the same key (its annotation
-- and value) is replaced.
insert ::
  (Ord k, Semigroup v) =>
  k -> v -> a -> AnnotatedMap k v a -> AnnotatedMap k v a
insert k v a (AnnotatedMap ft) =
  let (before, _, after) = splitAtKey k ft
  in AnnotatedMap (before >< (Entry k v a <| after))

-- | Look up the annotation and value stored under a key.
lookup :: (Ord k, Semigroup v) => k -> AnnotatedMap k v a -> Maybe (v, a)
lookup k (AnnotatedMap ft) =
  let (_, hit, _) = splitAtKey k ft
  in fmap valOf hit

-- | Remove the entry for a key. If the key is absent, the input map is
-- returned unchanged (the same value, not a rebuilt copy).
delete :: (Ord k, Semigroup v) => k -> AnnotatedMap k v a -> AnnotatedMap k v a
delete k m@(AnnotatedMap ft) =
  case splitAtKey k ft of
    (before, Just _, after) -> AnnotatedMap (before >< after)
    _                       -> m

-- | Update, insert, or delete the entry for a key. The callback receives
-- the current annotation and value (if any). Returning 'Nothing' removes
-- the key; returning @'Just' (v, a)@ stores that annotation and value.
alter ::
  (Ord k, Semigroup v) =>
  (Maybe (v, a) -> Maybe (v, a)) -> k -> AnnotatedMap k v a -> AnnotatedMap k v a
alter f k (AnnotatedMap ft) = AnnotatedMap (before >< rest)
  where
    (before, current, after) = splitAtKey k ft
    rest =
      case f (valOf <$> current) of
        Nothing     -> after
        Just (v, a) -> Entry k v a <| after

-- | Functorial version of 'alter'. The callback receives the current
-- annotation and value for @k@ (if any) and returns, inside @f@, the
-- replacement. 'Nothing' deletes the entry; @'Just' (v, a)@ stores it.
alterF ::
  (Functor f, Ord k, Semigroup v) =>
  (Maybe (v, a) -> f (Maybe (v, a))) -> k -> AnnotatedMap k v a -> f (AnnotatedMap k v a)
alterF f k (AnnotatedMap ft) = rebuild <$> f (fmap valOf m)
  where
    (l, m, r) = splitAtKey k ft

    rebuild Nothing       = AnnotatedMap (l >< r)
    -- Parenthesised explicitly so the new entry is visibly prepended to
    -- @r@ before the append, matching 'insert' and 'alter', instead of
    -- relying on the shared infixr 5 fixity of '><' and '<|'.
    rebuild (Just (v, a)) = AnnotatedMap (l >< (Entry k v a <| r))

-- | Left-biased union. Where a key occurs in both maps, the entry
-- (annotation and value) from the first map is kept.
union ::
  (Ord k, Semigroup v) =>
  AnnotatedMap k v a -> AnnotatedMap k v a -> AnnotatedMap k v a
union = unionGeneric (\fromFirst _ -> Just fromFirst)

-- | Union of two maps. For a key present in both, the function receives
-- the (annotation, value) pairs from the first and second map, and its
-- result is stored under that key.
unionWith ::
  (Ord k, Semigroup v) =>
  ((v, a) -> (v, a) -> (v, a)) ->
  AnnotatedMap k v a -> AnnotatedMap k v a -> AnnotatedMap k v a
unionWith f = unionGeneric combine
  where
    combine (Entry k v1 x1) (Entry _ v2 x2) =
      let (v3, x3) = f (v1, x1) (v2, x2)
      in Just (Entry k v3 x3)

-- | Union of two maps. For a key present in both, the function receives
-- the key and both values (the annotations of both entries are ignored).
-- It returns either 'Nothing' to drop the key, or a new annotation and
-- value.
unionWithKeyMaybe ::
  (Ord k, Semigroup v) =>
  (k -> a -> a -> Maybe (v, a)) ->
  AnnotatedMap k v a -> AnnotatedMap k v a -> AnnotatedMap k v a
unionWithKeyMaybe f = unionGeneric combine
  where
    combine (Entry k _ x) (Entry _ _ y) = toEntry k <$> f k x y
    toEntry k (v, z) = Entry k v z

-- | General union of two maps.
--
-- Keys present in only one map keep that map's entry unchanged. For a key
-- present in both, the combinator receives the first map's entry and then
-- the second map's entry. Its result replaces them, or removes the key if
-- it is 'Nothing'. The merge alternates between the two trees, splitting
-- the other tree at each visited key so that whole runs of one-sided keys
-- are copied over in one piece.
unionGeneric ::
  (Ord k, Semigroup v) =>
  (Entry k v a -> Entry k v a -> Maybe (Entry k v a)) ->
  AnnotatedMap k v a -> AnnotatedMap k v a -> AnnotatedMap k v a
unionGeneric combine (AnnotatedMap leftTree) (AnnotatedMap rightTree) =
  AnnotatedMap (stepL leftTree rightTree)
  where
    -- Advance using the next entry of the left tree.
    stepL lefts rights =
      case FT.viewl lefts of
        FT.EmptyL -> rights
        l FT.:< lefts' ->
          let (smaller, match, larger) = splitAtKey (keyOf l) rights
              chosen = maybe (Just l) (combine l) match
          in smaller >< place chosen (stepR lefts' larger)

    -- Advance using the next entry of the right tree.
    stepR lefts rights =
      case FT.viewl rights of
        FT.EmptyL -> lefts
        r FT.:< rights' ->
          let (smaller, match, larger) = splitAtKey (keyOf r) lefts
              chosen = maybe (Just r) (\l -> combine l r) match
          in smaller >< place chosen (stepL larger rights')

    -- Prepend an optional entry to the rest of the result.
    place Nothing  rest = rest
    place (Just e) rest = e <| rest

-- | Keep only the entries whose value satisfies the predicate.
filter ::
  (Ord k, Semigroup v) =>
  (a -> Bool) -> AnnotatedMap k v a -> AnnotatedMap k v a
filter p = AnnotatedMap . filterFingerTree (\(Entry _ _ a) -> p a) . annotatedMap

-- | Map a partial function over the values, dropping entries for which it
-- returns 'Nothing'. Keys and annotations of surviving entries are kept.
mapMaybe ::
  (Ord k, Semigroup v) =>
  (a -> Maybe b) ->
  AnnotatedMap k v a -> AnnotatedMap k v b
mapMaybe f = AnnotatedMap . mapMaybeFingerTree step . annotatedMap
  where
    step (Entry k v a) = fmap (Entry k v) (f a)

-- | Effectful traversal in ascending key order. For each entry the function
-- receives its key, annotation, and value. It returns 'Nothing' to drop the
-- entry, or a new annotation and value to keep under the same key.
traverseMaybeWithKey ::
  (Applicative f, Ord k, Semigroup v1, Semigroup v2) =>
  (k -> v1 -> a1 -> f (Maybe (v2, a2))) ->
  AnnotatedMap k v1 a1 -> f (AnnotatedMap k v2 a2)
traverseMaybeWithKey f = fmap AnnotatedMap . traverseMaybeFingerTree step . annotatedMap
  where
    step (Entry k v1 x1) = fmap (fmap (toEntry k)) (f k v1 x1)
    toEntry k (v2, x2) = Entry k v2 x2

-- | Remove from the first map every key that also occurs in the second.
-- The second map's annotations and values are ignored.
difference ::
  (Ord k, Semigroup v, Semigroup w) =>
  AnnotatedMap k v a -> AnnotatedMap k w b -> AnnotatedMap k v a
difference a b = runIdentity (mergeGeneric dropShared Identity dropRight a b)
  where
    dropShared _ _ = Identity Nothing
    dropRight _    = Identity empty

-- | Pure general merge. Keys present in both maps go through the combining
-- function ('Nothing' drops the key). Runs of keys present in only one map
-- are passed, as sub-maps, to the corresponding subtree function.
mergeWithKey ::
  (Ord k, Semigroup u, Semigroup v, Semigroup w) =>
  (k -> (u, a) -> (v, b) -> Maybe (w, c)) {- ^ for keys present in both maps -} ->
  (AnnotatedMap k u a -> AnnotatedMap k w c) {- ^ for subtrees only in first map -} ->
  (AnnotatedMap k v b -> AnnotatedMap k w c) {- ^ for subtrees only in second map -} ->
  AnnotatedMap k u a -> AnnotatedMap k v b -> AnnotatedMap k w c
mergeWithKey f g1 g2 m1 m2 =
  runIdentity (mergeGeneric both (Identity . g1) (Identity . g2) m1 m2)
  where
    both (Entry k u a) (Entry _ v b) =
      Identity (maybe Nothing (\(w, c) -> Just (Entry k w c)) (f k (u, a) (v, b)))

-- | Applicative merge of two maps with the same types. Keys present in both
-- maps are combined by the effectful function. Keys present in only one
-- map are kept unchanged.
mergeA ::
  (Ord k, Semigroup v, Applicative f) =>
  (k -> (v, a) -> (v, a) -> f (v,a)) ->
  AnnotatedMap k v a -> AnnotatedMap k v a -> f (AnnotatedMap k v a)
mergeA f m1 m2 = mergeGeneric combine pure pure m1 m2
  where
    combine (Entry k v1 x1) (Entry _ v2 x2) =
      fmap (\(v, x) -> Just (Entry k v x)) (f k (v1, x1) (v2, x2))

-- | Effectful merge in which every key survives. @both@ handles keys
-- present in both maps. @left@ and @right@ handle each key present in only
-- the first or only the second map, entry by entry.
mergeWithKeyM :: (Ord k, Semigroup u, Semigroup v, Semigroup w, Applicative m) =>
  (k -> (u, a) -> (v, b) -> m (w, c)) ->
  (k -> (u, a) -> m (w, c)) ->
  (k -> (v, b) -> m (w, c)) ->
  AnnotatedMap k u a -> AnnotatedMap k v b -> m (AnnotatedMap k w c)
mergeWithKeyM both left right = mergeGeneric inBoth onlyLeft onlyRight
  where
    inBoth (Entry k u a) (Entry _ v b) = keep k <$> both k (u, a) (v, b)

    onlyLeft  (AnnotatedMap ft) = AnnotatedMap <$> traverseMaybeFingerTree stepLeft ft
    onlyRight (AnnotatedMap ft) = AnnotatedMap <$> traverseMaybeFingerTree stepRight ft

    stepLeft  (Entry k v x) = keep k <$> left k (v, x)
    stepRight (Entry k v x) = keep k <$> right k (v, x)

    -- Every callback result is kept; none of them can delete a key.
    keep k (w, c) = Just (Entry k w c)

-- | General effectful merge of two maps, used by the other merge-style
-- operations in this module.
--
-- Both maps are walked in ascending key order. Runs of keys present in only
-- one map are passed, as whole (possibly empty) sub-maps, to the matching
-- subtree function. Keys present in both maps go to the entry combinator,
-- which may drop the key by returning 'Nothing'. Effects are sequenced in
-- ascending key order.
mergeGeneric ::
  (Ord k, Semigroup u, Semigroup v, Semigroup w, Applicative m) =>
  (Entry k u a -> Entry k v b -> m (Maybe (Entry k w c))) {- ^ for keys present in both maps -} ->
  (AnnotatedMap k u a -> m (AnnotatedMap k w c)) {- ^ for subtrees only in first map -} ->
  (AnnotatedMap k v b -> m (AnnotatedMap k w c)) {- ^ for subtrees only in second map -} ->
  AnnotatedMap k u a -> AnnotatedMap k v b -> m (AnnotatedMap k w c)
mergeGeneric both onlyL onlyR (AnnotatedMap leftTree) (AnnotatedMap rightTree) =
  fmap AnnotatedMap (fromLeft leftTree rightTree)
  where
    -- The subtree callbacks, lifted to work on raw finger trees.
    onlyL' t = annotatedMap <$> onlyL (AnnotatedMap t)
    onlyR' t = annotatedMap <$> onlyR (AnnotatedMap t)

    -- Join two finished segments, with an optional merged entry between them.
    splice before Nothing  after = before >< after
    splice before (Just e) after = before >< (e <| after)

    -- Advance using the next entry of the left tree. Right entries with
    -- smaller keys are right-only.
    fromLeft lefts rights =
      case FT.viewl lefts of
        FT.EmptyL -> onlyR' rights
        l FT.:< lefts' ->
          case splitAtKey (keyOf l) rights of
            (smaller, Nothing, larger) ->
              (><) <$> onlyR' smaller <*> fromRight lefts larger
            (smaller, Just r, larger) ->
              splice <$> onlyR' smaller <*> both l r <*> fromRight lefts' larger

    -- Advance using the next entry of the right tree. Left entries with
    -- smaller keys are left-only.
    fromRight lefts rights =
      case FT.viewl rights of
        FT.EmptyL -> onlyL' lefts
        r FT.:< rights' ->
          case splitAtKey (keyOf r) lefts of
            (smaller, Nothing, larger) ->
              (><) <$> onlyL' smaller <*> fromLeft larger rights
            (smaller, Just l, larger) ->
              splice <$> onlyL' smaller <*> both l r <*> fromLeft larger rights'