{-# LANGUAGE ExistentialQuantification, NamedFieldPuns, ScopedTypeVariables, RecordWildCards, ApplicativeDo #-}
module Data.MultiKeyedMap
( MKMap
, at, (!)
, mkMKMap, fromList, toList
, insert
, flattenKeys, keys, values
) where
import qualified Data.Map.Strict as M
import Data.Monoid (All(..))
import Data.Semigroup ((<>))
import Data.Foldable (foldl')
import qualified Data.List.NonEmpty as NE
import Data.Proxy (Proxy(..))
import qualified Data.Tuple as Tuple
import GHC.Stack (HasCallStack)
import qualified Text.Show as Show
-- | A map in which several keys can share a single stored value.
--
-- Every key is mapped to an internal id of a hidden type @ik@ (chosen when
-- the map is created, see 'mkMKMap'), and each value is stored once per
-- internal id. Keys that map to the same internal id therefore share one
-- value.
data MKMap k v = forall ik. (Ord ik, Enum ik)
  => MKMap
  { keyMap :: M.Map k ik
    -- ^ the internal id each key points at
  , highestIk :: ik
    -- ^ the most recently allocated internal id; 'mkMKMap' starts it at
    -- 'minBound' and new ids are produced with 'succ'
  , valMap :: M.Map ik v
    -- ^ the value stored under each internal id
  }
-- | Two maps are equal when all of the following hold:
--
-- * they store the same number of values;
-- * they have the same keys;
-- * they group those keys into the same sets of keys sharing a value;
-- * every key is associated with an equal value in both maps.
--
-- The grouping check matters because 'insert' overwrites the value shared by
-- a whole group. Two maps that agree key-by-key but group keys differently
-- behave differently after an 'insert', and they 'show' differently.
instance (Eq k, Ord k, Eq v) => Eq (MKMap k v) where
  (==) m1@(MKMap { keyMap = km1
                 , valMap = vm1 })
       m2@(MKMap { keyMap = km2
                 , valMap = vm2 })
    = getAll $ foldMap All $
        [ M.size vm1 == M.size vm2
        , ks1 == M.keys km2
        , sharing km1 == sharing km2
        ]
        ++ map (\k -> m1 ! k == m2 ! k) ks1
    where
      ks1 = M.keys km1

      -- For every key, the ascending list of keys that point at the same
      -- internal id. The result does not mention the (hidden, per-map)
      -- internal-id type, so it can be compared across the two maps.
      sharing :: forall kk ik. (Ord kk, Ord ik) => M.Map kk ik -> M.Map kk [kk]
      sharing km = M.map (groups M.!) km
        where
          -- fromListWith applies its function as (f new old); flipping (++)
          -- keeps each group in ascending key order.
          groups = M.fromListWith (flip (++))
                     [ (ik, [k]) | (k, ik) <- M.toAscList km ]
-- | Maps a function over every stored value. The keys, their sharing and the
-- internal-id counter are carried over untouched.
instance Functor (MKMap k) where
  fmap f (MKMap keysToIds lastId idsToVals) =
    MKMap keysToIds lastId (f <$> idsToVals)
  {-# INLINE fmap #-}
-- | Folds over the stored values in internal-id order. A value shared by
-- several keys is visited once.
instance Foldable (MKMap k) where
  foldMap f MKMap{valMap = stored} = foldMap f stored
  {-# INLINE foldMap #-}
-- | Runs an applicative action on every stored value and rebuilds the map
-- with the results. The key structure is reused as-is.
instance Traversable (MKMap k) where
  traverse f (MKMap keysToIds lastId idsToVals) =
    MKMap keysToIds lastId <$> traverse f idsToVals
  {-# INLINE traverse #-}
-- | Look up the value associated with a key.
--
-- Calls 'error' if the key is not present. The error is raised here, not
-- inside "Data.Map", so the reported call stack includes the caller of 'at'
-- (and of '!').
at :: (HasCallStack, Ord k) => MKMap k v -> k -> v
at MKMap{keyMap, valMap} k =
  case M.lookup k keyMap of
    Nothing -> error "Data.MultiKeyedMap.at: given key is not an element in the map"
    Just ik -> case M.lookup ik valMap of
      Just v  -> v
      -- Only reachable if the keyMap/valMap invariant has been broken.
      Nothing -> error "Data.MultiKeyedMap.at: internal error: key points to a missing value"

-- | Operator synonym for 'at'.
(!) :: (HasCallStack, Ord k) => MKMap k v -> k -> v
(!) = at
{-# INLINABLE (!) #-}
{-# INLINABLE at #-}
infixl 9 !
-- | An empty map whose internal ids have type @ik@, selected by the proxy.
-- The proxy argument is never inspected. The internal-id counter starts at
-- 'minBound'.
mkMKMap :: forall k ik v. (Ord k, Ord ik, Enum ik, Bounded ik)
        => (Proxy ik)
        -> MKMap k v
mkMKMap _ =
  MKMap { keyMap    = M.empty
        , highestIk = minBound :: ik
        , valMap    = M.empty
        }
{-# INLINE mkMKMap #-}
-- | Renders as @fromList@ applied to the result of 'toList'.
--
-- The output is parenthesised when it appears as an argument (precedence
-- above 10), as "Data.Map" does, so e.g. @show (Just m)@ yields
-- @Just (fromList [...])@ rather than @Just fromList [...]@.
instance (Show k, Show v) => Show (MKMap k v) where
  showsPrec d m =
    Show.showParen (d > 10) $ Show.showString "fromList " . shows (toList m)
-- | Build a map from groups of keys, where every key in a group shares the
-- group's value.
--
-- Groups are added left to right, each under a freshly allocated internal
-- id. A key listed in several groups therefore ends up bound to the value of
-- the last group containing it.
fromList :: forall ik k v. (Ord k, Ord ik, Enum ik, Bounded ik)
         => (Proxy ik)
         -> [(NE.NonEmpty k, v)]
         -> MKMap k v
fromList p = foldl' (flip (uncurry newVal)) (mkMKMap p)
-- | Convert to a list of key groups, each paired with its shared value.
--
-- There is one entry per internal id that at least one key points at, and
-- entries are ordered by internal id. Within a group, keys are listed in
-- descending order.
toList :: MKMap k v -> [(NE.NonEmpty k, v)]
toList MKMap{keyMap, valMap} =
  map (fmap (valMap M.!) . Tuple.swap) . M.assocs . aggregateIk $ keyMap
  where
    -- Invert keyMap by collecting every key that points at the same internal
    -- id. Keys are visited in ascending order and each is prepended, which is
    -- why groups come out descending. The fold is strict so the accumulator
    -- map is not built up as a chain of suspended inserts.
    aggregateIk :: forall kk ik. (Ord ik)
                => M.Map kk ik
                -> M.Map ik (NE.NonEmpty kk)
    aggregateIk = M.foldlWithKey'
      (\acc k ik -> M.insertWith (<>) ik (pure k) acc) M.empty
-- | Forget the key sharing and give every key its own copy of its value.
-- Each value is forced to weak head normal form in the resulting strict map.
flattenKeys :: (Ord k) => MKMap k v -> M.Map k v
flattenKeys MKMap{keyMap = byKey, valMap = byIk} = M.map resolve byKey
  where
    resolve ik = byIk M.! ik
-- | All keys of the map, in ascending order.
--
-- The keys are read straight from the key index. No intermediate map is
-- rebuilt, and no stored value is evaluated, so this also works when some
-- values are undefined (e.g. after 'fmap' with a partial function).
keys :: (Ord k) => MKMap k v -> [k]
keys MKMap{keyMap} = M.keys keyMap
-- | Every stored value, once per internal id, in internal-id order. A value
-- shared by several keys appears only once.
values :: MKMap k v -> [v]
values MKMap{valMap = stored} = [ val | (_, val) <- M.toAscList stored ]
-- | Set the value for key @k@.
--
-- If @k@ is already present, the value it shares with the other keys of its
-- group is overwritten, so every key in that group sees @v@. Otherwise @k@
-- is added on its own under a freshly allocated internal id.
insert :: (Ord k) => k -> v -> MKMap k v -> MKMap k v
insert k v m@MKMap{keyMap = km, highestIk = hi, valMap = vm} =
  case M.lookup k km of
    Just ik -> MKMap km hi (M.insert ik v vm)
    Nothing -> newVal (pure k) v m
-- | Allocate a fresh internal id (the 'succ' of the current 'highestIk'),
-- point every key in @ks@ at it, and store @v@ under it.
--
-- Keys in @ks@ that already existed are re-pointed to the new id. If this
-- leaves an old internal id with no key pointing at it, that id's value is
-- deleted. Otherwise it would stay in 'valMap', unreachable by key but still
-- returned by 'values' and visited by 'fmap' / 'foldMap' / 'traverse'.
--
-- NOTE(review): 'succ' typically fails at 'maxBound' (e.g. for 'Int'), and
-- ids are never reused, so the number of allocations is bounded by the size
-- of @ik@.
newVal :: (Ord k) => NE.NonEmpty k -> v -> MKMap k v -> MKMap k v
newVal ks v MKMap{keyMap, highestIk, valMap} =
  MKMap { keyMap    = keyMap'
        , highestIk = next
        , valMap    = M.insert next v (foldl' dropIfOrphaned valMap displaced)
        }
  where
    next = succ highestIk
    keyMap' = foldl' (\m k -> M.insert k next m) keyMap ks
    -- Internal ids that keys of ks pointed at before being re-pointed.
    displaced = [ ik | k <- NE.toList ks, Just ik <- [M.lookup k keyMap] ]
    -- Remove an id's value only if no key refers to that id any more.
    dropIfOrphaned vm ik
      | ik `elem` keyMap' = vm
      | otherwise         = M.delete ik vm