{-# LANGUAGE BangPatterns #-}
module LazyPPL.Distributions.Counter (Counter,newCounter,readAndIncrement) where

import LazyPPL
import Data.IORef
import System.IO.Unsafe

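{--
Some "unsafe" functions for a hidden counter.
This is useful for implementing some nonparametric models
like the Indian buffet process.
The functions are deterministic, but we perform some redundant sampling:
each call draws a fresh uniform value and mentions it (without affecting
the result) inside the unsafePerformIO action. Because that action then
depends on a freshly bound variable, the compiler cannot share one
evaluation between calls, so the action is re-evaluated every time.

The implementation of the counter (IORef Int) is encapsulated.
--}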

-- | An opaque mutable counter backed by an 'IORef' holding an 'Int'.
--
-- The constructor 'C' is not exported, so clients can only create and use a
-- counter through 'newCounter' and 'readAndIncrement'. A @newtype@ is used
-- because the wrapper carries exactly one field and should cost nothing at
-- runtime.
newtype Counter = C (IORef Int)

-- | Allocate a fresh counter whose initial value is 0.
--
-- The sampled value @seed@ never changes the result, because
-- @round (seed - seed)@ is 0. It is mentioned inside the 'unsafePerformIO'
-- action so that the allocation depends on a variable bound by this bind.
-- That keeps GHC from hoisting the 'newIORef' out and sharing a single
-- 'IORef' between different counters.
newCounter :: Prob Counter
newCounter = do
  seed <- uniform
  let start = round (seed - seed)
      cell  = unsafePerformIO (newIORef start)
  pure (C cell)

-- | Return the counter's current value and advance the counter by one.
--
-- The uniform draw is not used for its value: @round (r - r)@ is always 0.
-- Mentioning @r@ makes the 'unsafePerformIO' action depend on a variable
-- bound by this bind, so GHC cannot float the action out and share one
-- read/increment between different calls.
--
-- The returned 'Int' is produced via 'unsafePerformIO' inside 'return'. The
-- read-and-increment therefore runs only when that result is demanded. The
-- order in which values are handed out follows evaluation order, not the
-- order of binds in 'Prob'.
--
-- 'atomicModifyIORef'' performs the read and the write as one atomic step
-- and forces both the stored and the returned value. This avoids leaving a
-- thunk in the 'IORef' and avoids two forced thunks observing the same value.
readAndIncrement :: Counter -> Prob Int
readAndIncrement (C ref) = do
    r <- uniform
    -- New state first, old value (the one handed out) second.
    let bump i = (i + 1 + round (r - r), i)
    return $ unsafePerformIO $ atomicModifyIORef' ref bump