{-# LANGUAGE GADTs, DeriveAnyClass, DeriveGeneric, OverloadedStrings, DeriveLift
    , QuasiQuotes, TemplateHaskell, DeriveDataTypeable #-}
{- |
  Module      : Text.ANTLR.Set
  Description : Entrypoint for swapping out different underlying set representations
  Copyright   : (c) Karl Cronburg, 2018
  License     : BSD3
  Maintainer  : karl@cs.tufts.edu
  Stability   : experimental
  Portability : POSIX
-}

module Text.ANTLR.Set
  ( Set, null, size, member, notMember
  , empty, singleton, insert, delete, union, unions
  , difference, intersection, filter, map, foldr, foldl', fold
  , toList, fromList, (\\), findMin, maybeMin
  , Hashable(..), Generic(..)
  ) where
import Text.ANTLR.Pretty

import GHC.Generics (Generic, Rep)
import Data.Hashable (Hashable(..))
import Language.Haskell.TH.Syntax (Lift(..))

import qualified Data.Functor         as F
import qualified Control.Applicative  as A
import qualified Data.Foldable        as Foldable
import Data.List (sort)

import Data.Map ( Map(..) )
import qualified Data.Map as M

import qualified Data.HashSet as S
import Data.HashSet as S
  ( HashSet(..), member, toList, union
  , null, empty, map, size, singleton, insert
  , delete, unions, difference, intersection, foldl'
  , fromList
  )

import Prelude hiding (null, filter, map, foldr, foldl)

-- | The set type used internally during parsing: an alias for the
--   hash-based 'HashSet', so most operations on it need 'Hashable' and
--   'Eq' instances for the element type.
type Set = HashSet

-- | Is @e@ not a member of the set @s@? This is the negation of 'member'.
notMember :: (Hashable a, Eq a) => a -> Set a -> Bool
notMember e s = not (member e s)

-- | Right fold over the elements of a set. Elements are visited in the
--   underlying hash set's iteration order, not in sorted order.
--   This is the same function as 'foldr'.
fold :: (a -> b -> b) -> b -> Set a -> b
fold = S.foldr

-- | Right fold over the elements of a set. Elements are visited in the
--   underlying hash set's iteration order, not in sorted order.
foldr :: (a -> b -> b) -> b -> Set a -> b
foldr = S.foldr

-- | Find the minimum element of an orderable set.
--
--   Partial: calls 'error' on an empty set. Use 'maybeMin' for a total
--   variant. The explicit message replaces the generic one from 'minimum',
--   so a failure points back at this function.
findMin :: (Ord a, Hashable a) => Set a -> a
findMin s = case S.toList s of
  [] -> error "Text.ANTLR.Set.findMin: empty set"
  xs -> minimum xs

-- | Get the minimum element of a set, or 'Nothing' for the empty set.
--
--   Converts the set to a list once and pattern-matches on it. This avoids
--   computing 'S.size' (a full traversal of a hash set) just to test for
--   emptiness and then traversing again to find the minimum.
maybeMin :: (Ord a, Hashable a) => Set a -> Maybe a
maybeMin s = case S.toList s of
  [] -> Nothing
  xs -> Just (minimum xs)

infixl 9 \\

-- | Set difference as an infix operator: the elements of the left set that
--   do not occur in the right set. This is an alias for 'difference'.
(\\) :: (Hashable a, Eq a) => Set a -> Set a -> Set a
(\\) = difference

-- | Lift a hash set into Template Haskell. The set is turned into its
--   element list when the quote runs, and the generated code rebuilds it
--   with 'S.fromList'.
--
--   'S.fromList' and 'S.toList' are qualified on purpose. This module also
--   defines its own unqualified @fromList@ / @toList@ for its 'Set' type,
--   and the quoted expression must rebuild an 'S.HashSet', not that type.
instance (Hashable a, Eq a, Lift a) => Lift (S.HashSet a) where
  lift set = [| S.fromList $(lift (S.toList set)) |]

-- | Hash a 'Map' through its association list. 'M.toList' lists entries in
--   ascending key order, so equal maps always hash equally.
instance (Hashable k, Hashable v) => Hashable (Map k v) where
  hashWithSalt salt = hashWithSalt salt . M.toList

-- | Pretty-print a hash set. It prints a @Set: @ label followed by the
--   elements, one per line.
instance (Prettify a, Hashable a, Eq a) => Prettify (S.HashSet a) where
  prettify set = do
    pStr "Set: "
    -- Indent by the width of the "Set: " label (5 characters) while the
    -- elements are printed, then undo that increment.
    incrIndent 5
    pListLines (S.toList set)
    incrIndent (-5)
    pLine ""

-- | Keep only the elements of a set that satisfy the predicate.
filter :: (Hashable a, Eq a) => (a -> Bool) -> Set a -> Set a
filter = S.filter

--instance (Hashable a, Eq a) => Hashable (S.HashSet a) where
--  hashWithSalt salt set = salt `hashWithSalt` S.toList (run set)


--import Data.Set.Monad (Set(..), member, toList, union, notMember)
--import qualified Data.Set.Monad as Set

import Data.Monoid
import Data.Foldable (Foldable)
import Control.Arrow
import Control.Monad
import Control.DeepSeq

import Data.Data (Data(..))

-- | Hash a set by its elements, regardless of the order in which the
--   underlying 'S.HashSet' lists them.
--
--   The 'Eq' instance compares normalized hash sets, so equal sets must hash
--   equally. 'S.toList' makes no ordering promise: for example, elements
--   whose hashes collide can be listed in insertion order. Hashing that list
--   could therefore give equal sets different hashes. Summing the element
--   hashes is commutative and so order-independent; the size is mixed in
--   as well.
instance (Hashable a, Eq a) => Hashable (Set a) where
  hashWithSalt salt set = (salt `hashWithSalt` S.size hs) `hashWithSalt` total
    where
      hs    = run set
      total = S.foldl' (\acc x -> acc + hash x) (0 :: Int) hs

-- | Order sets lexicographically by their elements in ascending order.
--
--   Sorting first makes the order depend only on the elements and their
--   'Ord' instance. It no longer depends on hash values or on how the
--   underlying 'S.HashSet' happens to lay out its elements. It also makes
--   'compare' agree with the 'Eq' instance, which compares normalized hash
--   sets.
instance (Hashable a, Ord a) => Ord (Set a) where
  compare s1 s2 = compare (sortedElems s1) (sortedElems s2)
    where
      sortedElems = sort . S.toList . run

-- | A set represented as an unevaluated monadic computation.
--
--   The 'Monad', 'A.Alternative' and 'MonadPlus' instances in this module
--   place no 'Hashable' / 'Eq' constraint on the element type. They only
--   build a tree of 'Return', 'Bind', 'Zero' and 'Plus' nodes. 'run' later
--   interprets that tree into a concrete 'S.HashSet', where those
--   constraints are available.
data Set a where
  -- | A concrete hash set, packaged with the element constraints needed to
  --   manipulate it.
  Prim   :: (Hashable a, Eq a) => S.HashSet a -> Set a
  -- | A single element (what 'return' builds).
  Return :: a -> Set a
  -- | Monadic bind, kept unevaluated until 'run'.
  Bind   :: Set a -> (a -> Set b) -> Set b
  -- | The empty set (what 'mzero' builds).
  Zero   :: Set a
  -- | Union of two sets, kept unevaluated until 'run'.
  Plus   :: Set a -> Set a -> Set a

-- | 'Data' instance with no method definitions.
--
--   NOTE(review): "Data.Data" gives 'gfoldl' a default, but 'gunfold',
--   'toConstr' and 'dataTypeOf' have none. Calling those on a 'Set' will
--   fail at runtime, and GHC warns about the missing methods. A derived
--   instance is not possible for this GADT (the existential in 'Bind' and
--   the constrained 'Prim'). TODO confirm whether any caller relies on it.
instance (Data a) => Data (Set a)

-- | Lift a 'Set' into Template Haskell. The set is normalized to its element
--   list when the quote runs, and the generated code rebuilds it with
--   'fromList'.
instance (Hashable a, Eq a, Lift a) => Lift (Set a) where
  lift set =
    let elems = toList set
    in  [| fromList $(lift elems) |]

-- | Interpret a 'Set' computation into a concrete 'S.HashSet'.
--
--   The element constraints ('Hashable', 'Eq') are available here but not in
--   the 'Monad' / 'MonadPlus' instances. Those instances therefore only build
--   a tree of 'Return' / 'Bind' / 'Zero' / 'Plus' nodes, and this function
--   evaluates it.
--
--   The 'Bind' equations rewrite the left operand until it is a 'Prim',
--   'Return' or 'Zero'. They flatten nested 'Plus' and 'Bind' nodes and
--   absorb 'Prim' operands of a 'Plus'. Equation order matters: the specific
--   @Bind (Plus ...)@ patterns must be tried before the general one.
run :: (Hashable a, Eq a) => Set a -> S.HashSet a
run (Prim s)                        = s
run (Return a)                      = S.singleton a
run (Zero)                          = S.empty
run (Plus ma mb)                    = run ma `S.union` run mb
-- Union the images of all elements directly. Folding over the images avoids
-- materializing an intermediate set of sets, which would hash every inner
-- set only to union them afterwards.
run (Bind (Prim s) f)               =
  S.foldl' (\acc x -> acc `S.union` run (f x)) S.empty s
run (Bind (Return a) f)             = run (f a)
run (Bind Zero _)                   = S.empty
-- A 'Prim' on either side of a 'Plus' can be evaluated now, because its
-- element constraints come with the constructor, and merged into one 'Prim'.
run (Bind (Plus (Prim s) ma) f)     = run (Bind (Prim (s `S.union` run ma)) f)
run (Bind (Plus ma (Prim s)) f)     = run (Bind (Prim (run ma `S.union` s)) f)
run (Bind (Plus (Return a) ma) f)   = run (Plus (f a) (Bind ma f))
run (Bind (Plus ma (Return a)) f)   = run (Plus (Bind ma f) (f a))
run (Bind (Plus Zero ma) f)         = run (Bind ma f)
run (Bind (Plus ma Zero) f)         = run (Bind ma f)
-- Re-associate left-nested unions to the right, so the patterns above get
-- to inspect the leftmost operand.
run (Bind (Plus (Plus ma mb) mc) f) = run (Bind (Plus ma (Plus mb mc)) f)
run (Bind (Plus ma mb) f)           = run (Plus (Bind ma f) (Bind mb f))
-- Monad associativity: (m >>= f) >>= g  ==  m >>= (\a -> f a >>= g).
run (Bind (Bind ma f) g)            = run (Bind ma (\a -> Bind (f a) g))

-- | Mapping over a 'Set' binds each element to a singleton of its image.
--   The work is deferred until 'run'.
instance F.Functor Set where
  fmap f m = Bind m (Return . f)

-- | Applicative structure for 'Set'; '<*>' is derived from the 'Monad'
--   instance via 'ap'.
--
--   'pure' is defined directly as 'Return' rather than as @pure = return@.
--   'return' defaults to 'pure', so @pure = return@ is the non-canonical
--   form that GHC flags with @-Wnoncanonical-monad-instances@.
instance A.Applicative Set where
  pure  = Return
  (<*>) = ap

-- | Choice is set union. The empty alternative is 'Zero', and '<|>' records
--   a deferred 'Plus' node.
instance A.Alternative Set where
  empty   = Zero
  a <|> b = Plus a b

-- | Binding is deferred: '>>=' only records a 'Bind' node, which 'run'
--   evaluates once element constraints are available.
instance Monad Set where
  return  = Return
  m >>= k = Bind m k

-- | 'mzero' is the empty set ('Zero') and 'mplus' records a deferred union
--   ('Plus').
instance MonadPlus Set where
  mzero       = Zero
  mplus s1 s2 = Plus s1 s2

-- | Sets form a semigroup under union.
instance (Hashable a, Eq a) => Semigroup (Set a) where
  (<>) = union

-- | Sets form a monoid under union, with the empty set as identity.
--
--   Since base-4.11 (GHC 8.4), 'Semigroup' is a superclass of 'Monoid'. A
--   'Monoid' instance without the 'Semigroup' instance above does not
--   compile. 'mappend' is defined as '<>' so the two always agree.
instance (Hashable a, Eq a) => Monoid (Set a) where
  mempty  = empty
  mappend = (<>)
  mconcat = unions

-- | Fold over the elements of an unevaluated 'Set' tree without normalizing
--   it, since no 'Hashable' / 'Eq' constraint is available here.
--
--   Both branches of a 'Plus' and every result of a 'Bind' are folded as-is.
--   An element that occurs in more than one branch is therefore visited once
--   per occurrence, so e.g. 'length' may exceed 'size' for such a tree.
instance Foldable Set where
  foldr f z (Prim s)     = S.foldr f z s
  foldr f z (Return a)   = f a z
  foldr _ z Zero         = z
  foldr f z (Plus ma mb) = Foldable.foldr f (Foldable.foldr f z ma) mb
  foldr f z (Bind s g)   = Foldable.foldr step z s
    where
      -- Fold the image of each outer element into the running accumulator.
      step x acc = Foldable.foldr f acc (g x)

-- | Two sets are equal when they normalize to equal hash sets.
instance (Hashable a, Eq a) => Eq (Set a) where
  (==) s1 = (run s1 ==) . run

--instance (Hashable a, Eq a, Ord a) => Ord (Set a) where
--  compare s1 s2 = compare (run s1) (run s2)

-- | Show a set as its normalized 'S.HashSet'.
--
--   Defined through 'showsPrec' rather than only 'show', so the surrounding
--   precedence is passed on. When the set is nested in another value
--   (e.g. @Just s@), it is parenthesized as @Just (fromList [...])@. That
--   keeps the output parseable by the 'Read' instance. The default
--   'showsPrec' derived from 'show' ignores precedence and would print
--   @Just fromList [...]@.
instance (Show a, Hashable a, Eq a) => Show (Set a) where
  showsPrec d = showsPrec d . run

-- | Pretty-print a set by its normalized elements. It prints a @Set: @ label
--   followed by the elements, one per line.
instance (Prettify a, Hashable a, Eq a) => Prettify (Set a) where
  prettify set = do
    pStr "Set: "
    -- Indent by the width of the "Set: " label (5 characters) while the
    -- elements are printed, then undo that increment.
    incrIndent 5
    pListLines (S.toList (run set))
    incrIndent (-5)
    pLine ""

-- | Read a set in the 'S.HashSet' text format, which is the format the
--   'Show' instance produces, and wrap each successful parse in 'Prim'.
--
--   Uses 'fmap' instead of @L.map@, because no module is imported under the
--   @L@ qualifier (and Prelude's 'map' is hidden in this module).
instance (Read a, Hashable a, Eq a) => Read (Set a) where
  readsPrec d str = fmap (first Prim) (readsPrec d str)

-- | Fully evaluating a 'Set' means fully evaluating its normalized hash set.
instance (NFData a, Hashable a, Eq a) => NFData (Set a) where
  rnf s = rnf (run s)

infixl 9 \\

-- | Set difference as an infix operator: the elements of the left set that
--   do not occur in the right set. This is an alias for 'difference'.
(\\) :: (Hashable a, Eq a) => Set a -> Set a -> Set a
(\\) = difference

-- | Is the set empty (after normalization)?
null :: (Hashable a, Eq a) => Set a -> Bool
null s = S.null (run s)

-- | Number of distinct elements in the normalized set.
size :: (Hashable a, Eq a) => Set a -> Int
size s = S.size (run s)

-- | Is the element contained in the set?
member :: (Hashable a, Eq a) => a -> Set a -> Bool
member x = S.member x . run

-- | Is the element absent from the set? This is the negation of 'member'.
notMember :: (Hashable a, Eq a) => a -> Set a -> Bool
notMember x = not . member x

-- | The empty set, as a concrete (already normalized) hash set.
empty :: (Hashable a, Eq a) => Set a
empty = Prim mempty

-- | A set containing exactly one element.
singleton :: (Hashable a, Eq a) => a -> Set a
singleton = Prim . S.singleton

-- | Add an element to a set, normalizing the set first.
insert :: (Hashable a, Eq a) => a -> Set a -> Set a
insert x = Prim . S.insert x . run

-- | Remove an element from a set, normalizing the set first.
delete :: (Hashable a, Eq a) => a -> Set a -> Set a
delete x = Prim . S.delete x . run

-- | Union of two sets, computed eagerly into a concrete hash set.
union :: (Hashable a, Eq a) => Set a -> Set a -> Set a
union s1 s2 = Prim (S.union (run s1) (run s2))

-- | Union of a list of sets, computed eagerly into a concrete hash set.
--
--   Uses 'fmap' instead of @L.map@, because no module is imported under the
--   @L@ qualifier (and Prelude's 'map' is hidden in this module).
unions :: (Hashable a, Eq a) => [Set a] -> Set a
unions = Prim . S.unions . fmap run

-- | Elements of the first set that are not in the second.
difference :: (Hashable a, Eq a) => Set a -> Set a -> Set a
difference s1 s2 = Prim (run s1 `S.difference` run s2)

-- | Elements present in both sets.
intersection :: (Hashable a, Eq a) => Set a -> Set a -> Set a
intersection s1 s2 = Prim (run s1 `S.intersection` run s2)

-- | Keep only the elements that satisfy the predicate.
filter :: (Hashable a, Eq a) => (a -> Bool) -> Set a -> Set a
filter p = Prim . S.filter p . run

-- | Apply a function to every element; images that coincide are merged.
map :: (Hashable a, Eq a, Hashable b, Eq b) => (a -> b) -> Set a -> Set b
map f = Prim . S.map f . run

-- | Right fold over the distinct elements of the normalized set, in the
--   hash set's iteration order.
foldr :: (Hashable a, Eq a) => (a -> b -> b) -> b -> Set a -> b
foldr f z = S.foldr f z . run

-- | Right fold over the distinct elements of the normalized set. This
--   behaves the same as 'foldr'.
fold :: (Hashable a, Eq a) => (a -> b -> b) -> b -> Set a -> b
fold f z = S.foldr f z . run

-- | Strict left fold over the distinct elements of the normalized set.
foldl' :: (Hashable a, Eq a) => (b -> a -> b) -> b -> Set a -> b
foldl' f z = S.foldl' f z . run

-- | The distinct elements of the set, in the hash set's iteration order
--   (not sorted).
toList :: (Hashable a, Eq a) => Set a -> [a]
toList s = S.toList (run s)

-- | Build a set from a list of elements; duplicates are dropped.
fromList :: (Hashable a, Eq a) => [a] -> Set a
fromList = Prim . S.fromList

-- | Find the minimum element of an orderable set.
--
--   Partial: calls 'error' on an empty set. Use 'maybeMin' for a total
--   variant. The explicit message replaces the generic one from 'minimum',
--   so a failure points back at this function.
findMin :: (Ord a, Hashable a) => Set a -> a
findMin s = case S.toList (run s) of
  [] -> error "Text.ANTLR.Set.findMin: empty set"
  xs -> minimum xs

-- | Get the minimum element of a set, or 'Nothing' for the empty set.
--
--   Normalizes the set once and pattern-matches on its element list. The
--   previous version normalized twice: once for 'size' (itself a full
--   traversal of the hash set) and again in 'findMin'.
maybeMin :: (Ord a, Hashable a) => Set a -> Maybe a
maybeMin s = case S.toList (run s) of
  [] -> Nothing
  xs -> Just (minimum xs)
