{-# OPTIONS_GHC -Wall #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE FlexibleContexts #-} {-# LANGUAGE TypeFamilies #-} -- | -- Module: Numeric.MCMC.Hamiltonian -- Copyright: (c) 2015 Jared Tobin -- License: MIT -- -- Maintainer: Jared Tobin -- Stability: unstable -- Portability: ghc -- -- This implementation performs Hamiltonian Monte Carlo using an identity mass -- matrix. -- -- The 'mcmc' function streams a trace to stdout to be processed elsewhere, -- while the `slice` transition can be used for more flexible purposes, such as -- working with samples in memory. -- -- See for the definitive -- reference of the algorithm. module Numeric.MCMC.Hamiltonian ( mcmc , hamiltonian -- * Re-exported , Target(..) , MWC.create , MWC.createSystemRandom , MWC.withSystemRandom , MWC.asGenIO ) where import Control.Lens hiding (index) import Control.Monad.Trans.State.Strict hiding (state) import Control.Monad.Primitive (PrimState, PrimMonad, RealWorld) import qualified Data.Foldable as Foldable (sum) import Data.Maybe (fromMaybe) import Data.Sampling.Types import Data.Traversable (for) import Pipes hiding (for, next) import qualified Pipes.Prelude as Pipes import System.Random.MWC.Probability (Prob, Gen) import qualified System.Random.MWC.Probability as MWC -- | Trace 'n' iterations of a Markov chain and stream them to stdout. -- -- >>> withSystemRandom . asGenIO $ mcmc 3 1 [0, 0] target mcmc :: (Num (IxValue (t Double)), Show (t Double), Traversable t , FunctorWithIndex (Index (t Double)) t, Ixed (t Double) , IxValue (t Double) ~ Double) => Int -> Double -> Int -> t Double -> Target (t Double) -> Gen RealWorld -> IO () mcmc n step leaps chainPosition chainTarget gen = runEffect $ chain step leaps Chain {..} gen >-> Pipes.take n >-> Pipes.mapM_ print where chainScore = lTarget chainTarget chainPosition chainTunables = Nothing -- A Markov chain driven by the Metropolis transition operator. chain :: (Num (IxValue (t Double)), Traversable t , FunctorWithIndex (Index (t Double)) t, Ixed (t Double) , PrimMonad m, IxValue (t Double) ~ Double) => Double -> Int -> Chain (t Double) b -> Gen (PrimState m) -> Producer (Chain (t Double) b) m () chain step leaps = loop where loop state prng = do next <- lift (MWC.sample (execStateT (hamiltonian step leaps) state) prng) yield next loop next prng -- | A Hamiltonian transition operator. hamiltonian :: (Num (IxValue (t Double)), Traversable t , FunctorWithIndex (Index (t Double)) t, Ixed (t Double), PrimMonad m , IxValue (t Double) ~ Double) => Double -> Int -> Transition m (Chain (t Double) b) hamiltonian e l = do Chain {..} <- get r0 <- lift (for chainPosition (const MWC.standard)) zc <- lift (MWC.uniform :: PrimMonad m => Prob m Double) let (q, r) = leapfrogIntegrator chainTarget e l (chainPosition, r0) perturbed = nextState chainTarget (chainPosition, q) (r0, r) zc perturbedScore = lTarget chainTarget perturbed put (Chain chainTarget perturbedScore perturbed chainTunables) -- Calculate the next state of the chain. nextState :: (Foldable s, Foldable t, FunctorWithIndex (Index (t Double)) t , FunctorWithIndex (Index (s Double)) s, Ixed (s Double) , Ixed (t Double), IxValue (t Double) ~ Double , IxValue (s Double) ~ Double) => Target b -> (b, b) -> (s Double, t Double) -> Double -> b nextState target position momentum z | z < pAccept = snd position | otherwise = fst position where pAccept = acceptProb target position momentum -- Calculate the acceptance probability of a proposed moved. acceptProb :: (Foldable t, Foldable s, FunctorWithIndex (Index (t Double)) t , FunctorWithIndex (Index (s Double)) s, Ixed (t Double) , Ixed (s Double), IxValue (t Double) ~ Double , IxValue (s Double) ~ Double) => Target a -> (a, a) -> (s Double, t Double) -> Double acceptProb target (q0, q1) (r0, r1) = exp . min 0 $ auxilliaryTarget target (q1, r1) - auxilliaryTarget target (q0, r0) -- A momentum-augmented target. auxilliaryTarget :: (Foldable t, FunctorWithIndex (Index (t Double)) t , Ixed (t Double), IxValue (t Double) ~ Double) => Target a -> (a, t Double) -> Double auxilliaryTarget target (t, r) = f t - 0.5 * innerProduct r r where f = lTarget target innerProduct :: (Num (IxValue s), Foldable t, FunctorWithIndex (Index s) t, Ixed s) => t (IxValue s) -> s -> IxValue s innerProduct xs ys = Foldable.sum $ gzipWith (*) xs ys -- A container-generic zipwith. gzipWith :: (FunctorWithIndex (Index s) f, Ixed s) => (a -> IxValue s -> b) -> f a -> s -> f b gzipWith f xs ys = imap (\j x -> f x (fromMaybe err (ys ^? ix j))) xs where err = error "gzipWith: invalid index" -- The leapfrog or Stormer-Verlet integrator. leapfrogIntegrator :: (Num (IxValue (f Double)), Num (IxValue (t Double)) , FunctorWithIndex (Index (f Double)) t , FunctorWithIndex (Index (t Double)) f , Ixed (f Double), Ixed (t Double) , IxValue (f Double) ~ Double , IxValue (t Double) ~ Double) => Target (f Double) -> Double -> Int -> (f Double, t (IxValue (f Double))) -> (f Double, t (IxValue (f Double))) leapfrogIntegrator target e l (q0, r0) = go q0 r0 l where go q r 0 = (q, r) go q r n = go q1 r1 (pred n) where (q1, r1) = leapfrog target e (q, r) -- A single leapfrog step. leapfrog :: (Num (IxValue (f Double)), Num (IxValue (t Double)) , FunctorWithIndex (Index (f Double)) t , FunctorWithIndex (Index (t Double)) f , Ixed (t Double), Ixed (f Double) , IxValue (f Double) ~ Double, IxValue (t Double) ~ Double) => Target (f Double) -> Double -> (f Double, t (IxValue (f Double))) -> (f Double, t (IxValue (f Double))) leapfrog target e (q, r) = (qf, rf) where rm = adjustMomentum target e (q, r) qf = adjustPosition e (rm, q) rf = adjustMomentum target e (qf, rm) adjustMomentum :: (Functor f, Num (IxValue (f Double)) , FunctorWithIndex (Index (f Double)) t, Ixed (f Double)) => Target (f Double) -> Double -> (f Double, t (IxValue (f Double))) -> t (IxValue (f Double)) adjustMomentum target e (q, r) = r .+ ((0.5 * e) .* g q) where g = fromMaybe err (glTarget target) err = error "adjustMomentum: no gradient provided" adjustPosition :: (Functor f, Num (IxValue (f Double)) , FunctorWithIndex (Index (f Double)) t, Ixed (f Double)) => Double -> (f Double, t (IxValue (f Double))) -> t (IxValue (f Double)) adjustPosition e (r, q) = q .+ (e .* r) -- Scalar-vector product. (.*) :: (Num a, Functor f) => a -> f a -> f a z .* xs = fmap (* z) xs -- Vector addition. (.+) :: (Num (IxValue t), FunctorWithIndex (Index t) f, Ixed t) => f (IxValue t) -> t -> f (IxValue t) (.+) = gzipWith (+)