-- |
-- Module      : MonusWeightedSearch.Examples.Sampling
-- Copyright   : (c) Donnacha Oisín Kidney 2021
-- Maintainer  : mail@doisinkidney.com
-- Stability   : experimental
-- Portability : non-portable
--
-- Random sampling from the 'Heap' monad.
--
-- The 'Heap' monad can function as a probability monad, and it implements an
-- efficient sampling algorithm, based on reservoir sampling.

module MonusWeightedSearch.Examples.Sampling where

import Control.Monad.Heap
import Data.Monus.Prob
import Data.Ratio
import System.Random

-- | Sample a single value from the heap, treating it as a probability
-- distribution.
--
-- The outcomes produced by 'search' are visited in order. An outcome @x@ with
-- probability @p@ is accepted with probability @p / (1 - m)@, where @m@ is the
-- total probability of the outcomes already rejected. This makes the overall
-- chance of returning @x@ exactly @p@.
--
-- The heap is expected to describe a complete distribution (weights summing
-- to 1). In that case the last outcome has conditional probability 1 and is
-- always accepted. If the outcomes run out without one being accepted (an
-- empty heap, or weights summing to less than 1), this calls 'error'.
sample :: Heap Prob a -> IO a
sample = go 1 . search
  where
    -- Invariant: @r = 1 / (1 - m)@ is the reciprocal of the probability mass
    -- not yet rejected. So @f = r * px@ is the probability of @x@ conditioned
    -- on having reached it.
    go r ((x, Prob px) : xs) = do
      let f = r * px
      -- @c@ is uniform on [1 .. denominator f], so @c <= numerator f@ holds
      -- with probability exactly @f@ (assuming f <= 1).
      c <- randomRIO (1, toInteger (denominator f))
      -- On rejection, rescale so that the next outcome is conditioned on this
      -- one having been rejected: 1 / (1 - m - px) == r / (1 - f).
      -- When f == 1 the draw always succeeds, so 1 - f is never 0 here.
      if fromInteger c <= numerator f
        then pure x
        else go (r / (1 - f)) xs
    go _ [] =
      error "sample: no outcome selected (empty heap or weights summing to less than 1)"