{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE CPP #-}
module Haxl.Core.StateStore
( StateKey(..)
, StateStore
, stateGet
, stateSet
, stateEmpty
) where
import Data.Map (Map)
import qualified Data.Map.Strict as Map
#if __GLASGOW_HASKELL__ < 804
import Data.Monoid
#endif
import Data.Typeable
import Unsafe.Coerce
#if __GLASGOW_HASKELL__ >= 708
class Typeable f => StateKey (f :: * -> *) where
#else
class Typeable1 f => StateKey (f :: * -> *) where
#endif
  -- | The state carried for key type @f@; each instance picks its own
  -- representation.
  data State f

  -- | The 'TypeRep' used as the storage key for @State f@ inside a
  -- 'StateStore'. By default this is the 'TypeRep' of @f@ itself; an
  -- instance may override it to choose a different key.
  getStateType :: Proxy f -> TypeRep
  getStateType _ = typeRep (Proxy :: Proxy f)
-- | A heterogeneous store mapping each key 'TypeRep' (as produced by
-- 'getStateType') to one existentially wrapped 'State' value.
newtype StateStore = StateStore (Map TypeRep StateStoreData)

#if __GLASGOW_HASKELL__ >= 804
-- | Left-biased union: when both stores hold an entry for the same key,
-- the entry from the left operand is kept.
instance Semigroup StateStore where
  StateStore m1 <> StateStore m2 = StateStore (Map.union m1 m2)
#endif

-- | 'mempty' is the empty store; combining two stores is the left-biased
-- union of their entries.
instance Monoid StateStore where
  mempty = stateEmpty
#if __GLASGOW_HASKELL__ < 804
  -- Before base-4.11 'Semigroup' is not a superclass of 'Monoid', so
  -- 'mappend' must be defined here. From base-4.11 on it defaults to '(<>)',
  -- which keeps this instance canonical (no -Wnoncanonical-monoid-instances).
  mappend (StateStore m1) (StateStore m2) = StateStore (Map.union m1 m2)
#endif
-- | A 'State' value with its key type hidden. The 'StateKey' dictionary is
-- packed alongside it, so the stored type's 'getStateType' can be consulted
-- again after the value is taken out of the map.
data StateStoreData
  = forall f. StateKey f =>
    StateStoreData (State f)
-- | A store with no entries for any key.
stateEmpty :: StateStore
stateEmpty = StateStore mempty
-- | Store a state value under its key type's 'getStateType', replacing any
-- entry already present under that key.
stateSet :: forall f . StateKey f => State f -> StateStore -> StateStore
stateSet st (StateStore m) = StateStore updated
  where
    key     = getStateType (Proxy :: Proxy f)
    updated = Map.insert key (StateStoreData st) m
-- | Look up the state stored for key type @r@. Returns 'Nothing' when no
-- entry exists under @r@'s 'getStateType', or when the stored value's own
-- 'getStateType' does not match that key.
stateGet :: forall r . StateKey r => StateStore -> Maybe (State r)
stateGet (StateStore m) = Map.lookup wanted m >>= unpack
  where
    wanted = getStateType (Proxy :: Proxy r)

    unpack :: StateStoreData -> Maybe (State r)
    unpack (StateStoreData (st :: State f))
      -- The existential hides @f@, so the value is re-checked against the
      -- requested key before being cast. 'unsafeCoerce' is needed because
      -- the check uses the (overridable) 'getStateType' rather than a real
      -- type-equality proof.
      -- NOTE(review): this is only sound if instances that make
      -- 'getStateType' collide for distinct types also share a compatible
      -- 'State' representation — confirm against the instances in use.
      | getStateType (Proxy :: Proxy f) == wanted = Just (unsafeCoerce st)
      | otherwise = Nothing