{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Data.EnumMap ( EnumMap,
                      empty,
                      insert, insertWith, insertLookupWithKey,
                      delete, update,
                      lookup, (!), member,
                      null, size,
                      elems )

where

import Prelude hiding ( lookup, null )

import Data.IntMap ( IntMap )
import qualified Data.IntMap as IntMap

-- | A finite map from keys of an 'Enum' type @a@ to values of type @b@.
--
-- This is a thin wrapper over an 'IntMap': the operations in this module
-- convert each key to an 'Int' with 'fromEnum' before consulting the
-- underlying map. Every derived instance is derived from, and behaves like,
-- the corresponding instance of the wrapped @IntMap b@.
newtype EnumMap a b = EM { unEM :: IntMap b }
    deriving (Eq, Ord, Foldable, Functor, Semigroup, Monoid)

-- | The empty map.
--
-- No 'Enum' constraint is required: no key is ever converted here. This
-- matches 'null', 'size' and 'elems', which are likewise unconstrained.
empty :: EnumMap a b
empty = EM IntMap.empty

-- | Store @val@ under @key@, delegating to 'IntMap.insert' on the key's
-- 'fromEnum' image. As with 'IntMap.insert', an existing value is replaced.
insert :: Enum a => a -> b -> EnumMap a b -> EnumMap a b
insert key val = EM . IntMap.insert (fromEnum key) val . unEM

-- | Insert @val@ under @key@, combining it with any existing value via
-- @combine@. The argument order follows 'IntMap.insertWith': new value
-- first, old value second.
insertWith :: Enum a => (b -> b -> b) -> a -> b -> EnumMap a b -> EnumMap a b
insertWith combine key val (EM im) =
    EM (IntMap.insertWith combine (fromEnum key) val im)

-- | Combined lookup-and-insert, delegating to 'IntMap.insertLookupWithKey'.
--
-- Returns the lookup result from 'IntMap.insertLookupWithKey' together with
-- the updated map. The combining function receives the key converted back
-- from its 'Int' form with 'toEnum', so it sees a value of the key type @a@.
insertLookupWithKey :: Enum a
                    => (a -> b -> b -> b) -> a -> b -> EnumMap a b
                    -> (Maybe b, EnumMap a b)
insertLookupWithKey combine key val (EM im) =
    let intKeyed i          = combine (toEnum i)
        (previous, updated) =
            IntMap.insertLookupWithKey intKeyed (fromEnum key) val im
    in  (previous, EM updated)

-- | Remove @key@ from the map, delegating to 'IntMap.delete' on its
-- 'fromEnum' image.
delete :: Enum a => a -> EnumMap a b -> EnumMap a b
delete key = EM . IntMap.delete (fromEnum key) . unEM

-- | Adjust or drop the value at @key@ with @g@, delegating to
-- 'IntMap.update'. A @Nothing@ result from @g@ removes the entry.
update :: Enum a => (b -> Maybe b) -> a -> EnumMap a b -> EnumMap a b
update g key = EM . IntMap.update g (fromEnum key) . unEM

-- | Find the value stored under @key@, if any (total counterpart of '(!)').
lookup :: Enum a => a -> EnumMap a b -> Maybe b
lookup key = IntMap.lookup (fromEnum key) . unEM

-- | Unsafe indexing, delegating to 'IntMap.!'.
--
-- This is partial: it fails with 'IntMap.!''s error when the key is absent.
-- Prefer 'lookup' unless the key is known to be present.
(!) :: Enum a => EnumMap a b -> a -> b
em ! key = unEM em IntMap.! fromEnum key

-- | Whether @key@ has an entry in the map.
member :: Enum a => a -> EnumMap a b -> Bool
member key = IntMap.member (fromEnum key) . unEM

-- | Whether the map has no entries.
null :: EnumMap a b -> Bool
null (EM im) = IntMap.null im

-- | Number of entries in the map (as reported by 'IntMap.size').
size :: EnumMap a b -> Int
size (EM im) = IntMap.size im

-- | All stored values, ordered by their 'Int' keys exactly as
-- 'IntMap.elems' orders them.
elems :: EnumMap a b -> [b]
elems (EM im) = IntMap.elems im

-- ---------------------
-- Read + Show instances
-- ---------------------

-- | Renders exactly as the wrapped 'IntMap' does, so keys appear in their
-- 'Int' form rather than as values of @a@. The @Show a@ constraint is not
-- used by this implementation.
instance (Show a, Show b) => Show (EnumMap a b) where
    showsPrec d (EM im) = showsPrec d im

-- | Parses the wrapped 'IntMap''s textual form and wraps each parse in 'EM',
-- which mirrors the 'Show' instance above. The @Read a@ constraint is not
-- used by this implementation.
instance (Read a, Read b) => Read (EnumMap a b) where
    readsPrec d = fmap wrap . readsPrec d
      where
        wrap (im, rest) = (EM im, rest)