{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE DefaultSignatures #-}
module LazyPPL.Distributions.Memoization (MonadMemo, memoize, generalmemoize, memrec) where
import Control.Monad
import Control.Monad.Extra
import Control.Monad.State.Lazy (State, get, put, runState, state)
import Data.IORef
import Data.List
import Data.Map (empty, insert, keys, lookup, size)
import Debug.Trace
import LazyPPL
import System.IO.Unsafe
-- | Monads @m@ in which functions out of @a@ can be memoized.
--
-- @memoize f@ returns, inside @m@, an ordinary function on @a@. Its results
-- come from running @f@ ahead of time rather than on each application.
--
-- The default method handles 'Enum' domains. It prepares the runs
-- @f (toEnum 0)@, @f (toEnum 1)@, ... as an infinite tree via @ini 0@. It
-- answers an argument @x@ by looking up position @fromEnum x@ with 'look'.
-- Because that tree is infinite, the default is only usable when @m@ binds
-- lazily. Only arguments whose 'fromEnum' is non-negative have a position
-- in the tree.
class (Monad m) => MonadMemo m a where
  memoize :: (a -> m b) -> m (a -> b)
  default memoize :: (Enum a) => (a -> m b) -> m (a -> b)
  memoize f =
    ini 0 (f . toEnum) >>= \tree ->
      return (look tree . fromEnum)
-- | An infinite binary tree with a value at every node.
--
-- The fields are deliberately lazy. 'ini' produces an infinite tree, and
-- only the nodes that 'look' actually visits are ever evaluated.
data BinTree a = Branch a (BinTree a) (BinTree a)

-- | @ini n f@ builds the tree rooted at position @n@ in heap numbering:
--
-- * the node at position @k@ holds the result of @f k@;
-- * its left and right children are at positions @2k+1@ and @2k+2@.
--
-- So @ini 0 f@ has exactly one node for every position @0, 1, 2, ...@.
--
-- The recursion has no base case. It is therefore only productive in a monad
-- whose bind does not demand the rest of the computation (for example
-- 'Data.Functor.Identity.Identity'); in a strict monad it never returns.
ini :: (Monad m) => Int -> (Int -> m a) -> m (BinTree a)
ini n f = do
  x <- f n
  l <- ini (2 * n + 1) f
  r <- ini (2 * n + 2) f
  return $ Branch x l r

-- | @look t n@ returns the value stored at heap position @n@ of a tree built
-- by @ini 0 f@, i.e. the result produced by @f n@.
--
-- Navigation follows the same numbering as 'ini'. A position @k > 0@ is:
--
-- * the left child of @(k - 1) `div` 2@ when @k@ is odd;
-- * the right child of @(k - 2) `div` 2@ when @k@ is even.
--
-- (Halving @k@ at each level does not invert this numbering and lands on the
-- wrong node, e.g. position 2 would reach the node holding @f 5@.)
--
-- Negative positions do not exist in the tree. They raise an error instead
-- of descending forever.
look :: BinTree a -> Int -> a
look tree n
  | n < 0 = error ("look: negative index " ++ show n)
  | otherwise = follow tree (pathTo n [])
  where
    -- Directions from the root down to position k (True = go right). The
    -- path is built bottom-up by stepping to the parent, so every
    -- intermediate position is at most k and nothing can overflow.
    pathTo 0 acc = acc
    pathTo k acc
      | odd k = pathTo ((k - 1) `div` 2) (False : acc)
      | otherwise = pathTo ((k - 2) `div` 2) (True : acc)

    -- Walk the directions from the root and return the node's value.
    follow (Branch x _ _) [] = x
    follow (Branch _ l r) (goRight : rest) =
      follow (if goRight then r else l) rest
-- | 'Int' keys use the default 'Enum'-based 'memoize' (an infinite tree
-- indexed by 'fromEnum').
instance Monad m => MonadMemo m Int
-- | Pairs are memoized by currying, in three steps:
--
-- 1. the outer memoized function is indexed by the first component;
-- 2. for each first component it holds a memoized function of the second
--    component;
-- 3. the result is uncurried back into a function on pairs.
instance (Monad m, MonadMemo m a, MonadMemo m b) => MonadMemo m (a, b) where
  memoize f = uncurry <$> memoize (memoize . curry f)
-- | Memoize a stochastic function over any 'Ord' domain.
--
-- A fresh result table is created for each random tree @Tree _ gs@. The
-- first time an argument is demanded, @f@ is run on its own subtree of the
-- random source:
--
-- * the k-th distinct argument (counting from 1) gets @gs !! k@;
-- * which argument receives which subtree therefore depends on the order in
--   which arguments are first evaluated.
--
-- Later applications at the same argument return the cached result.
generalmemoize :: Ord a => (a -> Prob b) -> Prob (a -> b)
generalmemoize f = Prob $ \(Tree _ gs) -> unsafePerformIO $ do
  -- State: results so far, plus the subtrees not yet handed out. Keeping the
  -- unused suffix makes each miss O(1), instead of re-walking @gs@ with (!!)
  -- at an ever larger index. Subtree 0 is skipped, matching the numbering
  -- above.
  ref <- newIORef (Data.Map.empty, drop 1 gs)
  -- Lookup and insertion happen in one atomic step. Because the memoized
  -- function looks pure, its applications may be forced from any thread.
  -- Separate read/write steps could hand one subtree to two arguments, or
  -- lose an entry. atomicModifyIORef' installs the new state before
  -- evaluating the returned result, which the caller demands anyway.
  return $ \x -> unsafePerformIO $ atomicModifyIORef' ref (step x)
  where
    step x (table, fresh) = case Data.Map.lookup x table of
      Just y -> ((table, fresh), y)
      Nothing -> case fresh of
        g : rest ->
          let y = runProb (f x) g
              table' = Data.Map.insert x y table
           in table' `seq` ((table', rest), y)
        [] -> error "generalmemoize: random tree has run out of subtrees"
-- | Memoized fixpoint of a recursive stochastic function.
--
-- @memrec f@ ties the knot: @f@ receives the memoized function itself as its
-- first argument, so recursive calls go through the same result table.
--
-- Subtree allocation works as in 'generalmemoize':
--
-- * a fresh table is created per random tree @Tree _ gs@;
-- * the k-th distinct argument (counting from 1) is run on @gs !! k@;
-- * that assignment depends on first-evaluation order.
--
-- The 'Show' constraint is not used by this implementation.
memrec :: Ord a => Show a => ((a -> b) -> (a -> Prob b)) -> Prob (a -> b)
memrec f = Prob $ \(Tree _ gs) -> unsafePerformIO $ do
  -- State: results so far, plus the subtrees not yet handed out. This keeps
  -- each miss O(1) instead of indexing @gs@ with (!!).
  ref <- newIORef (Data.Map.empty, drop 1 gs)
  let -- One atomic lookup-or-insert per application, so applications
      -- forced from different threads cannot share a subtree or drop an
      -- entry.
      memoizedFixpoint x = unsafePerformIO $ atomicModifyIORef' ref (step x)

      -- Data.Map here is value-lazy. The entry for x is installed while y is
      -- still unevaluated, and atomicModifyIORef' only evaluates y after the
      -- new state is in place. So recursive calls made while computing y
      -- find the existing entry rather than drawing another subtree.
      step x (table, fresh) = case Data.Map.lookup x table of
        Just y -> ((table, fresh), y)
        Nothing -> case fresh of
          g : rest ->
            let y = runProb (f memoizedFixpoint x) g
                table' = Data.Map.insert x y table
             in table' `seq` ((table', rest), y)
          [] -> error "memrec: random tree has run out of subtrees"
  return memoizedFixpoint