{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module Data.SNMap
    ( SNMapReaderT
    , runSNMapReaderT
    , memoizeM
    , scopedM
    ) where

import           Control.Applicative              (Applicative)
import           Control.Monad.Exception          (MonadAsyncException,
                                                   MonadException)
import           Control.Monad.IO.Class           (MonadIO, liftIO)
import           Control.Monad.Trans.Class        (MonadTrans (..))
import           Control.Monad.Trans.State.Strict (StateT (StateT), evalStateT,
                                                   get, put)
import qualified Data.HashMap.Strict              as HT
import           System.Mem.StableName            (StableName, makeStableName)

-- | A map (\"SN\" stands for stable name) caching the results of
-- computations. Each computation @m a@ is identified by its 'StableName', so
-- the functions that use this map end up constraining @m@ to @'MonadIO' m@
-- (creating a stable name is an 'IO' action).
newtype SNMap m a = SNMap
    { unSNMap :: HT.HashMap (StableName (m a)) a
      -- ^ Cached results, keyed by the stable name of the action that
      -- produced them.
    }

-- | A cache holding no results yet.
newSNMap :: SNMap m a
newSNMap = SNMap { unSNMap = HT.empty }

-- | Run an action, or return its cached result if the very same action
-- (identified by its 'StableName') has already been run and cached.
--
-- The cache itself is not held here: it is read through @getter@ and written
-- back through @putter@, so the caller decides where it lives.
memoize :: MonadIO m
    => m (SNMap m a)        -- ^ The "IO call" to retrieve our cache.
    -> (SNMap m a -> m ())  -- ^ The "IO call" to store an updated cache.
    -> m a                  -- ^ The "IO call" to execute and cache the result.
    -> m a                  -- ^ The result being naturally also returned.
memoize getter putter m = do
    -- Force the action to WHNF before naming it: 'makeStableName' may return
    -- a different name for an object once it has been evaluated, and running
    -- @m@ below evaluates it. Naming the unevaluated thunk would make later
    -- lookups for the same action miss the entry stored here.
    s <- liftIO (makeStableName $! m)
    cached <- HT.lookup s . unSNMap <$> getter
    case cached of
        Just a -> return a
        Nothing -> do
            a <- m
            -- Re-read the cache instead of reusing the one fetched above:
            -- running @m@ may itself have read or replaced it (e.g. through
            -- nested memoized calls, or 'scopedM' restoring a saved cache).
            getter >>= putter . SNMap . HT.insert s a . unSNMap
            return a

-- | An (IO) action producing a @b@ value while caching @a@ values along the
-- way. The cache of already computed @a@ values is the 'StateT' state.
newtype SNMapReaderT a m b =
    SNMapReaderT (StateT (SNMap (SNMapReaderT a m) a) m b)
    deriving
        ( Functor
        , Applicative
        , Monad
        , MonadIO
        , MonadException
        , MonadAsyncException
        )

-- | Run a 'SNMapReaderT' computation, starting from an empty cache. The final
-- cache is discarded; only the result is returned.
runSNMapReaderT :: MonadIO m => SNMapReaderT a m b -> m b
runSNMapReaderT (SNMapReaderT action) = evalStateT action newSNMap

-- | Lifting goes through the underlying 'StateT' and leaves the cache as is.
instance MonadTrans (SNMapReaderT a) where
    lift inner = SNMapReaderT (lift inner)

-- | Simplified 'memoize' for 'SNMapReaderT': the cache is the transformer's
-- own state, so it is read with 'get' and stored back with 'put'.
memoizeM :: MonadIO m => SNMapReaderT a m a -> SNMapReaderT a m a
memoizeM = memoize readCache writeCache
  where
    readCache = SNMapReaderT get
    writeCache = SNMapReaderT . put

-- | Run a subcomputation in a scope, where nothing memoized inside will be
-- remembered afterwards: the cache is saved before the subcomputation runs
-- and put back once it has returned.
scopedM :: MonadIO m => SNMapReaderT a m x -> SNMapReaderT a m x
scopedM m = do
    saved <- SNMapReaderT get
    m <* SNMapReaderT (put saved)