module Numeric.MCMC.Hamiltonian (
mcmc
, chain
, hamiltonian
, Target(..)
, MWC.create
, MWC.createSystemRandom
, MWC.withSystemRandom
, MWC.asGenIO
) where
import Control.Lens hiding (index)
import Control.Monad (replicateM)
import Control.Monad.Codensity (lowerCodensity)
import Control.Monad.IO.Class (MonadIO, liftIO)
import Control.Monad.Primitive (PrimState, PrimMonad)
import Control.Monad.Trans.State.Strict hiding (state)
import qualified Data.Foldable as Foldable (sum)
import Data.Maybe (fromMaybe)
import Data.Sampling.Types
import Data.Traversable (for)
import Pipes hiding (for, next)
import qualified Pipes.Prelude as Pipes
import System.Random.MWC.Probability (Prob, Gen)
import qualified System.Random.MWC.Probability as MWC
mcmc
:: ( MonadIO m, PrimMonad m
, Num (IxValue (t Double)), Show (t Double), Traversable t
, FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
, IxValue (t Double) ~ Double)
=> Int
-> Double
-> Int
-> t Double
-> Target (t Double)
-> Gen (PrimState m)
-> m ()
mcmc n step leaps chainPosition chainTarget gen = runEffect $
drive step leaps Chain {..} gen
>-> Pipes.take n
>-> Pipes.mapM_ (liftIO . print)
where
chainScore = lTarget chainTarget chainPosition
chainTunables = Nothing
chain
:: (PrimMonad m, Traversable f
, FunctorWithIndex (Index (f Double)) f, Ixed (f Double)
, IxValue (f Double) ~ Double)
=> Int
-> Double
-> Int
-> f Double
-> Target (f Double)
-> Gen (PrimState m)
-> m [Chain (f Double) b]
chain n step leaps position target gen = runEffect $
drive step leaps origin gen
>-> collect n
where
origin = Chain {
chainScore = lTarget target position
, chainTunables = Nothing
, chainTarget = target
, chainPosition = position
}
collect :: Monad m => Int -> Consumer a m [a]
collect size = lowerCodensity $
replicateM size (lift Pipes.await)
drive
:: (Num (IxValue (t Double)), Traversable t
, FunctorWithIndex (Index (t Double)) t, Ixed (t Double)
, PrimMonad m, IxValue (t Double) ~ Double)
=> Double
-> Int
-> Chain (t Double) b
-> Gen (PrimState m)
-> Producer (Chain (t Double) b) m c
drive step leaps = loop where
loop state prng = do
next <- lift (MWC.sample (execStateT (hamiltonian step leaps) state) prng)
yield next
loop next prng
hamiltonian
:: (Num (IxValue (t Double)), Traversable t
, FunctorWithIndex (Index (t Double)) t, Ixed (t Double), PrimMonad m
, IxValue (t Double) ~ Double)
=> Double -> Int -> Transition m (Chain (t Double) b)
hamiltonian e l = do
Chain {..} <- get
r0 <- lift (for chainPosition (const MWC.standardNormal))
zc <- lift (MWC.uniform :: PrimMonad m => Prob m Double)
let (q, r) = leapfrogIntegrator chainTarget e l (chainPosition, r0)
perturbed = nextState chainTarget (chainPosition, q) (r0, r) zc
perturbedScore = lTarget chainTarget perturbed
put (Chain chainTarget perturbedScore perturbed chainTunables)
nextState
:: (Foldable s, Foldable t, FunctorWithIndex (Index (t Double)) t
, FunctorWithIndex (Index (s Double)) s, Ixed (s Double)
, Ixed (t Double), IxValue (t Double) ~ Double
, IxValue (s Double) ~ Double)
=> Target b
-> (b, b)
-> (s Double, t Double)
-> Double
-> b
nextState target position momentum z
| z < pAccept = snd position
| otherwise = fst position
where
pAccept = acceptProb target position momentum
acceptProb
:: (Foldable t, Foldable s, FunctorWithIndex (Index (t Double)) t
, FunctorWithIndex (Index (s Double)) s, Ixed (t Double)
, Ixed (s Double), IxValue (t Double) ~ Double
, IxValue (s Double) ~ Double)
=> Target a
-> (a, a)
-> (s Double, t Double)
-> Double
acceptProb target (q0, q1) (r0, r1) = exp . min 0 $
auxilliaryTarget target (q1, r1) auxilliaryTarget target (q0, r0)
auxilliaryTarget
:: (Foldable t, FunctorWithIndex (Index (t Double)) t
, Ixed (t Double), IxValue (t Double) ~ Double)
=> Target a
-> (a, t Double)
-> Double
auxilliaryTarget target (t, r) = f t 0.5 * innerProduct r r where
f = lTarget target
innerProduct
:: (Num (IxValue s), Foldable t, FunctorWithIndex (Index s) t, Ixed s)
=> t (IxValue s) -> s -> IxValue s
innerProduct xs ys = Foldable.sum $ gzipWith (*) xs ys
gzipWith
:: (FunctorWithIndex (Index s) f, Ixed s)
=> (a -> IxValue s -> b) -> f a -> s -> f b
gzipWith f xs ys = imap (\j x -> f x (fromMaybe err (ys ^? ix j))) xs where
err = error "gzipWith: invalid index"
leapfrogIntegrator
:: (Num (IxValue (f Double))
, FunctorWithIndex (Index (f Double)) t
, FunctorWithIndex (Index (t Double)) f
, Ixed (f Double), Ixed (t Double)
, IxValue (f Double) ~ Double
, IxValue (t Double) ~ Double)
=> Target (f Double)
-> Double
-> Int
-> (f Double, t (IxValue (f Double)))
-> (f Double, t (IxValue (f Double)))
leapfrogIntegrator target e l (q0, r0) = go q0 r0 l where
go q r 0 = (q, r)
go q r n = go q1 r1 (pred n) where
(q1, r1) = leapfrog target e (q, r)
leapfrog
:: (Num (IxValue (f Double))
, FunctorWithIndex (Index (f Double)) t
, FunctorWithIndex (Index (t Double)) f
, Ixed (t Double), Ixed (f Double)
, IxValue (f Double) ~ Double, IxValue (t Double) ~ Double)
=> Target (f Double)
-> Double
-> (f Double, t (IxValue (f Double)))
-> (f Double, t (IxValue (f Double)))
leapfrog target e (q, r) = (qf, rf) where
rm = adjustMomentum target e (q, r)
qf = adjustPosition e (rm, q)
rf = adjustMomentum target e (qf, rm)
adjustMomentum
:: (Functor f, Num (IxValue (f Double))
, FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
=> Target (f Double)
-> Double
-> (f Double, t (IxValue (f Double)))
-> t (IxValue (f Double))
adjustMomentum target e (q, r) = r .+ ((0.5 * e) .* g q) where
g = fromMaybe err (glTarget target)
err = error "adjustMomentum: no gradient provided"
adjustPosition
:: (Functor f, Num (IxValue (f Double))
, FunctorWithIndex (Index (f Double)) t, Ixed (f Double))
=> Double
-> (f Double, t (IxValue (f Double)))
-> t (IxValue (f Double))
adjustPosition e (r, q) = q .+ (e .* r)
(.*) :: (Num a, Functor f) => a -> f a -> f a
z .* xs = fmap (* z) xs
(.+)
:: (Num (IxValue t), FunctorWithIndex (Index t) f, Ixed t)
=> f (IxValue t)
-> t
-> f (IxValue t)
(.+) = gzipWith (+)