{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module Data.EnumSet ( EnumSet,
                      empty, singleton,
                      insert, delete,
                      intersection, union,
                      size, null,
                      toList )

where

import Prelude hiding ( null )

import Data.IntSet ( IntSet )
import qualified Data.IntSet as IntSet

-- | A set of values of an 'Enum' type, represented by the 'IntSet' of the
-- elements' 'fromEnum' codes. Equality, ordering, '<>' and 'mempty' are
-- all inherited directly from that underlying 'IntSet'.
--
-- NOTE(review): @a@ is a phantom parameter, so GHC infers the phantom role
-- for it and 'Data.Coerce.coerce' can turn an @EnumSet X@ into an
-- @EnumSet Y@ even though 'ES' itself is not exported. A
-- @type role EnumSet nominal@ annotation (needs RoleAnnotations) would
-- rule that out — consider adding.
newtype EnumSet a = ES { unES :: IntSet }
  deriving (Eq, Ord, Semigroup, Monoid)


-- | The set containing no elements.
empty :: Enum a => EnumSet a
empty = ES mempty

-- | The set holding exactly one element, keyed by its 'fromEnum' code.
singleton :: Enum a => a -> EnumSet a
singleton x = ES (IntSet.singleton (fromEnum x))

-- | Add an element to the set. Elements are identified by their
-- 'fromEnum' code.
insert :: Enum a => a -> EnumSet a -> EnumSet a
insert x (ES codes) = ES (IntSet.insert (fromEnum x) codes)

-- | Remove an element (matched by its 'fromEnum' code) from the set.
delete :: Enum a => a -> EnumSet a -> EnumSet a
delete x (ES codes) = ES (IntSet.delete (fromEnum x) codes)

-- | The elements present in both sets.
intersection :: Enum a => EnumSet a -> EnumSet a -> EnumSet a
intersection (ES xs) (ES ys) = ES (xs `IntSet.intersection` ys)

-- | The elements present in either set.
union :: Enum a => EnumSet a -> EnumSet a -> EnumSet a
union lhs rhs = ES (IntSet.union (unES lhs) (unES rhs))

-- | Number of distinct element codes stored in the set.
size :: EnumSet a -> Int
size (ES codes) = IntSet.size codes

-- | 'True' exactly when the set has no elements.
null :: EnumSet a -> Bool
null (ES codes) = IntSet.null codes

-- | The elements of the set, recovered from their codes with 'toEnum',
-- in ascending code order.
toList :: Enum a => EnumSet a -> [a]
toList = map toEnum . IntSet.toList . unES

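-- ---------------------
-- Internal helpers (not exported): lift IntSet operations to EnumSet
-- ---------------------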

-- | Build a set by applying an 'IntSet' constructor to an element's
-- 'fromEnum' code.
lift :: Enum a => (Int -> IntSet) -> a -> EnumSet a
lift build x = ES (build (fromEnum x))

-- | Turn a binary 'IntSet' operation into the corresponding operation on
-- 'EnumSet's by unwrapping both arguments and re-wrapping the result.
liftS :: Enum a => (IntSet -> IntSet -> IntSet)
                -> EnumSet a -> EnumSet a -> EnumSet a
liftS op lhs rhs = ES (op (unES lhs) (unES rhs))

-- | Turn an "element code into set" 'IntSet' update (such as insert or
-- delete) into the corresponding 'EnumSet' update.
liftC :: Enum a => (Int -> IntSet -> IntSet) -> a -> EnumSet a -> EnumSet a
liftC op x set = ES (op (fromEnum x) (unES set))


-- ---------------------
-- Read + Show instances
-- ---------------------

-- | Renders the set exactly as its underlying 'IntSet' of 'fromEnum' codes
-- (e.g. @fromList [0,1]@ for the set of both 'Bool's), not through the
-- elements' own 'Show' instance. The output format is the one the 'Read'
-- instance for 'EnumSet' parses.
--
-- No @Show a@ context is required, because elements are never shown
-- individually.
instance Show (EnumSet a) where
    showsPrec p = showsPrec p . unES

-- | Parses the underlying 'IntSet' of element codes (the format produced by
-- the 'Show' instance, e.g. @fromList [0,1]@) and wraps it. No @Read a@
-- context is required, because elements are never read individually.
--
-- NOTE(review): the parsed codes are not checked against @a@. A code that
-- is not a valid 'fromEnum' result for @a@ is accepted here, and the
-- problem may only surface later when 'toList' applies 'toEnum' to it.
instance Read (EnumSet a) where
    readsPrec p = map (\(a,s) -> (ES a,s)) . readsPrec p