{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE DefaultSignatures #-}

{- | Stochastic memoization for the `LazyPPL` library.
Stochastic memoization is a useful primitive for programming non-parametric models. 
It has type @(a -> Prob b) -> Prob (a -> b)@, and can be thought of as simultaneously sampling results for all possible arguments, in a lazy way. 

When @a@ is enumerable, this amounts to converting a stream of probabilities to a random stream, which we can do by sampling each probability once. 

This module provides:

* A general type class, `MonadMemo`, for monads @m@ that support memoization at certain argument types @a@.

* A default trie-based implementation when @a@ is enumerable, and a curry-based implementation when @a@ is a pair type. 

* A general implementation, `generalmemoize`, for the probability monad `Prob`. It needs only an 'Ord' instance on the argument type and uses a mutable memo-table.

* A memoized recursion combinator, `memrec`. 


For illustrations, see the [graph example](https://lazyppl-team.github.io/GraphDemo.html), [clustering](https://lazyppl-team.github.io/ClusteringDemo.html), [additive clustering](https://lazyppl-team.github.io/AdditiveClusteringDemo.html), or the [infinite relational model](https://lazyppl-team.github.io/IrmDemo.html). 
-}

module LazyPPL.Distributions.Memoization (MonadMemo, memoize, generalmemoize, memrec) where

import Control.Monad
import Control.Monad.Extra
import Control.Monad.State.Lazy (State, get, put, runState, state)
import Data.IORef
import Data.List
import Data.Map (empty, insert, keys, lookup, size)
import Debug.Trace
import LazyPPL 
import System.IO.Unsafe

{-| Argument types @a@ at which the monad @m@ supports stochastic memoization.

    @memoize f@ behaves as if it sampled @f x@ once for every argument @x@
    (lazily), returning the resulting deterministic function.

    The default definition, available whenever @a@ is an 'Enum', tabulates
    @f@ over the non-negative 'Int' codes of @a@ in a lazily built binary trie
    (see 'ini' and 'look'). -}
class (Monad m) => MonadMemo m a where
  memoize :: (a -> m b) -> m (a -> b)
  default memoize :: (Enum a) => (a -> m b) -> m (a -> b)
  -- Build the trie from index 0, then answer each query by converting the
  -- argument to its 'Int' code and looking it up.
  memoize f = fmap (\table -> look table . fromEnum) (ini 0 (f . toEnum))

{-| Basic trie-based, integer-indexed memo table.

    The tree is laid out in heap order. The node holding index @k@ has its
    left child at index @2k+1@ and its right child at index @2k+2@, so a tree
    built with @ini 0@ holds index 0 at the root. Negative indices are not
    supported: 'look' raises an error for them. -}
data BinTree a = Branch a (BinTree a) (BinTree a)

{-| @ini n f@ builds the subtree rooted at heap index @n@. It runs @f i@ for
    every index @i@ of that subtree, with children at @2i+1@ and @2i+2@.

    The recursion has no base case. It therefore only produces a result in
    a monad whose bind is lazy enough to build the tree on demand. -}
ini :: (Monad m) => Int -> (Int -> m a) -> m (BinTree a)
ini n f = do
  x <- f n
  l <- ini (2 * n + 1) f
  r <- ini (2 * n + 2) f
  return $ Branch x l r

{-| @look t n@ returns the value stored at heap index @n@ of a tree built by
    @ini 0 f@, i.e. the result that @f n@ produced.

    The root-to-node path for index @n@ comes from the binary digits of @n + 1@
    after its leading 1, read most significant first: 0 means go left and 1
    means go right. This holds because the left child @2k+1@ has
    @(2k+1)+1 = 2(k+1)@, which appends a 0 bit, and the right child @2k+2@ has
    @2(k+1)+1@, which appends a 1 bit.

    Calls 'error' on a negative index. -}
look :: BinTree a -> Int -> a
look t n
  | n < 0 =
      error ("LazyPPL.Distributions.Memoization.look: negative index " ++ show n)
  | otherwise = descend t (pathFromRoot (toInteger n + 1) [])
  where
    -- Bits of k below its leading 1, most significant first (True = right).
    -- Computed in Integer so that n = maxBound does not overflow.
    pathFromRoot :: Integer -> [Bool] -> [Bool]
    pathFromRoot k acc
      | k <= 1 = acc
      | otherwise = pathFromRoot (k `div` 2) (odd k : acc)

    -- Follow the path; the node reached when it is exhausted holds the value.
    descend :: BinTree b -> [Bool] -> b
    descend (Branch x _ _) [] = x
    descend (Branch _ l r) (goRight : rest) =
      descend (if goRight then r else l) rest

{-| 'Int' arguments are memoized with the class's default, trie-based
    'memoize', which goes through the 'Enum' instance of 'Int'. -}
instance Monad m => MonadMemo m Int where
  -- No override: 'memoize' comes from the Enum-based default definition.

{-| Pairs are memoized by currying. The first component is memoized with a
    function that itself memoizes over the second component, and the
    resulting curried function is uncurried. -}
instance (Monad m, MonadMemo m a, MonadMemo m b) => MonadMemo m (a, b) where
  memoize f = uncurry <$> memoize (memoize . curry f)

{-| General stochastic memoization for 'Prob'. It only needs an 'Ord'
    instance on the argument type.

    Each run of the returned 'Prob' allocates a fresh mutable table (an
    'IORef' holding a @Map@). The first time an argument is queried, its
    result is computed lazily from its own subtree of the random source and
    cached, so later queries for that argument return the same value. The
    table is managed with 'unsafePerformIO'.

    If @a@ is finite, every value could be sampled in advance, avoiding
    'unsafePerformIO'. If @a@ is enumerable, the trie-based 'memoize' can be
    used instead.

    NOTE(review): the subtree assigned to an argument is chosen from the
    table's size at its first query, so it depends on the order in which
    distinct arguments are first queried. -}
generalmemoize :: Ord a => (a -> Prob b) -> Prob (a -> b)
generalmemoize f = Prob $ \(Tree _ subtrees) -> unsafePerformIO $ do
  table <- newIORef Data.Map.empty
  let memoF x = unsafePerformIO $ do
        seen <- readIORef table
        case Data.Map.lookup x seen of
          Just cached -> return cached
          Nothing -> do
            -- The result stays an unevaluated thunk; only the table is updated.
            let fresh = runProb (f x) (subtrees !! (Data.Map.size seen + 1))
            modifyIORef' table (Data.Map.insert x fresh)
            return fresh
  return memoF

{-| Stochastic memoization for recursive functions.

    Applying 'memoize' to a recursively defined function memoizes only the
    top-level call, because the recursive calls go to the non-memoized
    function. 'memrec' instead takes the function in open-recursive form
    (its first argument stands for the recursive call). It passes in the
    memoized function itself, so recursive calls are memoized too.

    The cache works as in 'generalmemoize': a per-run mutable table managed
    with 'unsafePerformIO'. Each argument's lazily computed result is inserted
    into the table before it is returned. -}
memrec :: Ord a => Show a => ((a -> b) -> (a -> Prob b)) -> Prob (a -> b)
memrec f = Prob $ \(Tree _ subtrees) -> unsafePerformIO $ do
  table <- newIORef Data.Map.empty
  let tied x = unsafePerformIO $ do
        seen <- readIORef table
        case Data.Map.lookup x seen of
          Just cached -> return cached
          Nothing -> do
            -- Recursive calls made by f go back through 'tied', hence the cache.
            let fresh = runProb (f tied x) (subtrees !! (Data.Map.size seen + 1))
            modifyIORef' table (Data.Map.insert x fresh)
            return fresh
  return tied