module Language.Synthesis.MCMC (mhList) where
import Control.Monad.Random (Rand, RandomGen, getRandom,
getSplit, runRand)
import Data.Functor ((<$>))
import Language.Synthesis.Distribution (Distr)
import qualified Language.Synthesis.Distribution as Distr
-- | Run a Metropolis–Hastings Markov chain starting from @startValue@.
--
-- Produces an infinite, lazily generated list of chain states. Each element
-- is @(value, aux, score)@ where @(aux, score) = density value@. The first
-- element is the start state. Every later element is either the accepted
-- proposal or a repeat of the previous state when the proposal was rejected.
--
-- The second component returned by @density@ is treated as a log-density:
-- it is combined additively with 'Distr.logProbability' values and compared
-- against the log of a random draw.
--
-- The generator is split once up front ('getSplit'). The chain then runs on
-- its own generator, so the list can be consumed lazily to any length.
mhList :: RandomGen g =>
          a                     -- ^ initial state of the chain
       -> (a -> (b, Double))    -- ^ scoring function: (auxiliary result, log-density)
       -> (a -> Distr a)        -- ^ proposal (jump) distribution given the current state
       -> Rand g [(a, b, Double)]
mhList startValue density jump = go (startValue, startAux, startDensity) <$> getSplit
  where
    (startAux, startDensity) = density startValue

    -- Emit the current state, then lazily continue the chain using the
    -- generator left over from this step.
    go orig g = let (next, g') = runRand (mhNext orig) g in orig : go next g'

    -- One Metropolis–Hastings step: propose, score, then accept or reject.
    mhNext (orig, origAux, origDensity) = do
      let proposal = jump orig
      next <- Distr.sample proposal
      let origToNext = Distr.logProbability proposal next      -- log q(next | orig)
          nextToOrig = Distr.logProbability (jump next) orig   -- log q(orig | next)
          (nextAux, nextDensity) = density next
          -- Log of the MH acceptance ratio
          --   [p(next) * q(orig | next)] / [p(orig) * q(next | orig)]
          score = nextDensity - origDensity + nextToOrig - origToNext
      acceptance <- getRandom
      -- Assuming getRandom yields a uniform Double in [0, 1), this accepts
      -- with probability min(1, exp score). A NaN score is always rejected.
      return $ if score >= log acceptance
        then (next, nextAux, nextDensity)
        else (orig, origAux, origDensity)