module Agda.Utils.Memo where

import Control.Monad.State
import System.IO.Unsafe
import Data.IORef
import qualified Data.Map as Map
import qualified Data.HashMap.Strict as HMap
import Data.Hashable

import Agda.Utils.Lens

-- Simple memoisation in a state monad

-- | Simple, non-reentrant memoisation.
--
--   The lens @tbl@ focuses on a @Maybe a@ slot inside the state @s@.
--
--   * If the slot holds @Just x@, that @x@ is returned and @compute@ is
--     not run.
--
--   * If the slot holds @Nothing@, @compute@ is run, its result is stored
--     in the slot as @Just x@, and the result is returned.
--
--   Non-reentrant: the slot is only filled /after/ @compute@ has finished,
--   so a nested call through the same lens made while @compute@ is running
--   still finds @Nothing@.  Use 'memoRec' when @compute@ may call back.
memo :: MonadState s m => Lens' (Maybe a) s -> m a -> m a
memo tbl compute = do
  mv <- use tbl
  case mv of
    Just x  -> return x
    Nothing -> do
      x <- compute
      -- Store the fresh result, then hand it back.
      x <$ (tbl .= Just x)

-- | Recursive memoisation, second argument is the value you get
--   on recursive calls.
--
--   The lens @tbl@ focuses on a @Maybe a@ slot inside the state @s@.
--
--   * If the slot holds @Just x@, that @x@ is returned and @compute@ is
--     not run.
--
--   * If the slot holds @Nothing@, the slot is first set to @Just ih@,
--     then @compute@ is run, and finally the slot is overwritten with the
--     result of @compute@, which is also the return value.
--
--   Because the slot already contains @ih@ while @compute@ runs, any
--   lookup through the same lens made from inside @compute@ yields @ih@
--   instead of starting the computation again.
memoRec :: MonadState s m => Lens' (Maybe a) s -> a -> m a -> m a
memoRec tbl ih compute = do
  mv <- use tbl
  case mv of
    Just x  -> return x
    Nothing -> do
      -- Seed the slot so that recursive lookups terminate with @ih@.
      tbl .= Just ih
      x <- compute
      -- Replace the seed by the real result and return it.
      x <$ (tbl .= Just x)

{-# NOINLINE memoUnsafe #-}
-- | Memoise a pure function using a hidden mutable table keyed by the
--   argument (ordered with 'Ord').
--
--   Every evaluation of @memoUnsafe f@ allocates its own fresh, empty
--   table; the table is shared by all applications of the /resulting/
--   function.  To benefit from the cache the result must therefore be
--   bound once and reused, e.g. @let g = memoUnsafe f in ...@.
--
--   Results are stored unevaluated: the table maps @x@ to the thunk
--   @f x@, so looking an argument up never forces a cached value.
--
--   The implementation relies on 'unsafePerformIO'.  The @NOINLINE@
--   pragma is required: as the "System.IO.Unsafe" documentation explains,
--   an inlined call may have its I/O (here: the table allocation)
--   performed more than once.
--
--   NOTE(review): the table is read and written in two separate steps
--   ('readIORef' followed by 'writeIORef'), i.e. the update is not atomic.
memoUnsafe :: Ord a => (a -> b) -> (a -> b)
memoUnsafe f = unsafePerformIO $ do
  tbl <- newIORef Map.empty
  return (unsafePerformIO . f' tbl)
  where
    -- Look @x@ up in the table; on a miss, record the (lazy) result.
    f' tbl x = do
      m <- readIORef tbl
      case Map.lookup x m of
        Just y  -> return y
        Nothing -> do
          let y = f x
          writeIORef tbl (Map.insert x y m)
          return y

{-# NOINLINE memoUnsafeH #-}
-- | Memoise a pure function using a hidden mutable hash table keyed by
--   the argument (via 'Eq' and 'Hashable').
--
--   Every evaluation of @memoUnsafeH f@ allocates its own fresh, empty
--   table; the table is shared by all applications of the /resulting/
--   function.  To benefit from the cache the result must therefore be
--   bound once and reused, e.g. @let g = memoUnsafeH f in ...@.
--
--   The table is a "Data.HashMap.Strict" map, whose @insert@ evaluates
--   the stored value to weak head normal form.  For that reason @f x@ is
--   evaluated /before/ it is recorded:
--
--   * a memoised function may call itself through the table without the
--     table re-entering a result that is still being computed;
--
--   * if @f x@ throws, nothing is recorded and the table stays usable.
--
--   This does not make the memoised function any stricter: the lookup
--   only runs when the result is demanded, and that same demand would
--   force @f x@ anyway.
--
--   The implementation relies on 'unsafePerformIO'.  The @NOINLINE@
--   pragma is required: as the "System.IO.Unsafe" documentation explains,
--   an inlined call may have its I/O (here: the table allocation)
--   performed more than once.
--
--   NOTE(review): lookup and update are separate steps, so the cache
--   update as a whole is not atomic.
memoUnsafeH :: (Eq a, Hashable a) => (a -> b) -> (a -> b)
memoUnsafeH f = unsafePerformIO $ do
  tbl <- newIORef HMap.empty
  return (unsafePerformIO . f' tbl)
  where
    -- Look @x@ up in the table; on a miss, compute and record the result.
    f' tbl x = do
      m <- readIORef tbl
      case HMap.lookup x m of
        Just y  -> return y
        Nothing -> do
          let y = f x
          -- Evaluate @y@ first, then insert into the *current* table so
          -- that entries added while computing @y@ are kept.
          y `seq` modifyIORef' tbl (HMap.insert x y)
          return y