{-# LANGUAGE TupleSections #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- Module      :  Mcmc.Proposal.Hamiltonian
-- Description :  Hamiltonian Monte Carlo proposal
-- Copyright   :  (c) 2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Mon Jul  5 12:59:42 2021.
--
-- The Hamiltonian Monte Carlo (HMC) proposal.
--
-- For references, see:
--
-- - [1] Chapter 5 of Handbook of Monte Carlo: Neal, R. M., MCMC Using
--   Hamiltonian Dynamics, In S. Brooks, A. Gelman, G. Jones, & X. Meng (Eds.),
--   Handbook of Markov Chain Monte Carlo (2011), CRC press.
--
-- - [2] Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B., Bayesian data
--   analysis (2014), CRC Press.
--
-- - [3] Review by Betancourt and notes: Betancourt, M., A conceptual
--   introduction to Hamiltonian Monte Carlo, arXiv:1701.02434 (2017).
--
-- NOTE on implementation:
--
-- - The implementation assumes the existence of the gradient. This way, the
--   user can use automatic or manual differentiation, depending on the problem
--   at hand.
--
-- - The state needs to be list like or 'Traversable' so that the structure of
--   the state space is available. A 'Traversable' constraint on the data type
--   is nice because it is more general than, for example, a list, and
--   user-defined data structures can be used.
--
-- - The state needs to have a zip-like 'Applicative' instance so that
--   matrix/vector operations can be performed.

module Mcmc.Proposal.Hamiltonian
  ( Gradient,
    Masses,
    LeapfrogTrajectoryLength,
    LeapfrogScalingFactor,
    HTune (..),
    HSettings (..),
    hamiltonian,
  )
where

import Data.Foldable
import qualified Data.Matrix as M
import Data.Maybe
import Data.Traversable
import qualified Data.Vector as VB
import Mcmc.Prior
import Mcmc.Proposal
import Numeric.Log
import Statistics.Distribution
import Statistics.Distribution.Normal
import qualified Statistics.Function as S
import qualified Statistics.Sample as S
import System.Random.MWC

-- TODO: At the moment, the HMC proposal is agnostic of the prior and
-- likelihood, that is, the posterior function. This means that it cannot know
-- when it reaches a point with zero posterior probability. This also affects
-- restricted or constrained parameters. See Gelman p. 303.

-- TODO: No-U-turn sampler.

-- TODO: Riemannian adaptation.

-- | Gradient of the log posterior function.
--
-- The gradient is evaluated at the current positions and has the same
-- structure as the state. Entries belonging to parameters without mass (and
-- hence without momentum) are ignored by the Hamiltonian proposal.
type Gradient f = f Double -> f Double

-- | Function validating the state.
--
-- Useful if parameters are constrained.
--
-- The function is applied to the positions after every leapfrog update of the
-- positions. If it returns 'False', the leapfrog integration is aborted.
type Validate f = f Double -> Bool

-- | Masses of parameters.
--
-- NOTE: Full specification of a mass matrix including off-diagonal elements is
-- not supported.
--
-- NOTE: Parameters without masses ('Nothing') are not changed by the
-- Hamiltonian proposal.
--
-- NOTE: Masses have to be strictly positive.
--
-- The masses roughly describe how reluctantly the particle moves through the
-- state space. Momenta are sampled with variances equal to the masses, and
-- positions move by the momentum divided by the mass. Hence, if a parameter
-- has higher mass, its velocity (momentum divided by mass) will be changed less
-- by the provided gradient than when the same parameter has lower mass.
--
-- The proposal is more efficient if masses are assigned according to the
-- inverse (co)variance structure of the posterior function. That is,
-- parameters changing on larger scales should have lower masses than parameters
-- changing on smaller scales. In particular, and for a diagonal mass matrix,
-- the optimal masses are the inverted variances of the parameters distributed
-- according to the posterior function.
--
-- Of course, the scales of the parameters of the posterior function are usually
-- unknown. Often, it is sufficient to
--
-- - set the masses to identical values roughly scaled with the inverted
--   estimated average variance of the posterior function; or even to
--
-- - set all masses to 1.0, and trust the tuning algorithm (see
--   'HTuneMassesAndLeapfrog') to find the correct values.
type Masses f = f (Maybe Double)

-- | Mean leapfrog trajectory length \(L\).
--
-- Number of leapfrog steps per proposal.
--
-- To avoid problems with ergodicity, the actual number of leapfrog steps is
-- sampled per proposal from a discrete uniform distribution over the interval
-- \([\text{floor}(0.8L),\text{ceiling}(1.2L)]\).
--
-- For a discussion of ergodicity and reasons why randomization is important,
-- see [1] p. 15; also mentioned in [2] p. 304.
--
-- NOTE: To avoid errors, the left bound has an additional hard minimum of 1,
-- and the right bound is required to be larger than or equal to the left
-- bound.
--
-- Usually set to 10, but larger values may be desirable. Must be at least 1.
type LeapfrogTrajectoryLength = Int

-- | Mean of leapfrog scaling factor \(\epsilon\).
--
-- Determines the size of each leapfrog step. Must be strictly positive.
--
-- To avoid problems with ergodicity, the actual leapfrog scaling factor is
-- sampled per proposal from a continuous uniform distribution over the interval
-- \((0.8\epsilon,1.2\epsilon]\).
--
-- For a discussion of ergodicity and reasons why randomization is important,
-- see [1] p. 15; also mentioned in [2] p. 304.
--
-- Usually set such that \( L \epsilon = 1.0 \), but smaller values may be
-- required if acceptance rates are low.
type LeapfrogScalingFactor = Double

-- Target state containing the parameters, interpreted as the positions of the
-- particle in the Hamiltonian dynamics.
type Positions f = f Double

-- Momenta of the parameters. Parameters without mass have no momentum
-- ('Nothing'); the momentum structure mirrors the mass structure.
type Momenta f = f (Maybe Double)

-- | Tuning settings.
--
-- Tuning of leapfrog parameters:
--
-- We expect that the larger the leapfrog step size, the larger the proposal
-- step size and the lower the acceptance rate. Consequently, if the acceptance
-- rate is too low, the leapfrog step size is decreased and vice versa. Further,
-- the leapfrog trajectory length is scaled such that the product of the
-- leapfrog step size and trajectory length stays roughly constant.
--
-- Tuning of masses:
--
-- The variances of all parameters of the posterior distribution obtained over
-- the last auto tuning interval are calculated, and the masses are amended
-- using the old masses and the inverted variances. If, for a specific
-- coordinate, the sample size is too low, or if the calculated variance is out
-- of predefined bounds, the mass of the affected position is not changed.
data HTune
  = -- | Tune masses and leapfrog parameters.
    HTuneMassesAndLeapfrog
  | -- | Tune leapfrog parameters only.
    HTuneLeapfrogOnly
  | -- | Do not tune at all.
    HNoTune
  deriving (Eq, Show)

-- | Specifications for the Hamiltonian Monte Carlo proposal.
data HSettings f = HSettings
  { -- | Gradient of the log posterior function.
    hGradient :: Gradient f,
    -- | Optional function validating the positions. If given, it is applied
    -- after every leapfrog update of the positions, and the leapfrog
    -- integration is aborted when it returns 'False'.
    hMaybeValidate :: Maybe (Validate f),
    -- | Masses of the parameters. Parameters without mass ('Nothing') are not
    -- changed by the proposal.
    hMasses :: Masses f,
    -- | Mean leapfrog trajectory length.
    hLeapfrogTrajectoryLength :: LeapfrogTrajectoryLength,
    -- | Mean leapfrog scaling factor.
    hLeapfrogScalingFactor :: LeapfrogScalingFactor,
    -- | Tuning settings.
    hTune :: HTune
  }

-- Check the Hamiltonian settings for invalid values.
--
-- Returns 'Nothing' if the settings are valid, and 'Just' an error message
-- otherwise. Parameters without mass ('Nothing') are not checked.
--
-- Besides non-positive values, NaN and infinite masses and scaling factors are
-- rejected. Comparisons such as @m <= 0@ are 'False' for NaN, so these values
-- would otherwise slip through and only fail later in the proposal.
checkHSettings :: Foldable f => HSettings f -> Maybe String
checkHSettings (HSettings _ _ masses l eps _)
  | any (maybe False (<= 0)) masses =
      Just "checkHSettings: One or more masses are zero or negative."
  | any (maybe False (not . isFinite)) masses =
      Just "checkHSettings: One or more masses are NaN or infinite."
  | l < 1 =
      Just "checkHSettings: Leapfrog trajectory length is zero or negative."
  | eps <= 0 =
      Just "checkHSettings: Leapfrog scaling factor is zero or negative."
  | not (isFinite eps) =
      Just "checkHSettings: Leapfrog scaling factor is NaN or infinite."
  | otherwise = Nothing
  where
    isFinite :: Double -> Bool
    isFinite x = not (isNaN x || isInfinite x)

-- Sample random momenta for the given masses.
--
-- The momentum of a parameter with mass @m@ is drawn from a normal distribution
-- with mean 0 and standard deviation @sqrt m@ (i.e., variance @m@). Parameters
-- without mass receive no momentum ('Nothing').
generateMomenta ::
  Traversable f =>
  Masses f ->
  GenIO ->
  IO (Momenta f)
generateMomenta masses gen = traverse (traverse sampleMomentum) masses
  where
    -- The inner 'traverse' runs over 'Maybe': 'Nothing' stays 'Nothing' without
    -- consuming randomness, 'Just m' becomes 'Just' a sampled momentum.
    sampleMomentum m = genContVar (normalDistr 0 (sqrt m)) gen

-- Prior probability density of the momenta given the masses.
--
-- The momentum of a parameter with mass @m@ is normally distributed with mean 0
-- and variance @m@. The densities of all parameters are multiplied. Parameters
-- without mass and without momentum contribute a factor of 1. Masses and
-- momenta must agree on which parameters are fixed; otherwise, an error is
-- thrown.
priorMomenta ::
  (Applicative f, Foldable f) =>
  Masses f ->
  Momenta f ->
  Prior
priorMomenta masses phi = foldl' (\acc x -> acc * x) 1.0 densities
  where
    densities = density <$> masses <*> phi
    density mMass mMomentum = case (mMass, mMomentum) of
      (Just m, Just p) -> Exp $ logDensity (normalDistr 0 (sqrt m)) p
      (Nothing, Nothing) -> 1.0
      _ -> error "priorMomenta: Got just a mass and no momentum or the other way around."

-- Integrate Hamilton's equations using the leapfrog method.
--
-- Starting at positions @theta@ and momenta @phi@, this performs
--
-- 1. an initial half step of the momenta;
--
-- 2. @l - 1@ full steps, each consisting of a position step followed by a full
--    momentum step;
--
-- 3. a last full step of the positions; and
--
-- 4. a final half step of the momenta.
--
-- If a validation function is given, it is applied to the positions after
-- every position step. If it returns 'False', the integration is aborted and
-- 'Nothing' is returned.
--
-- A trajectory length smaller than 1 is treated like 1. Previously, the inner
-- loop only stopped at exactly zero remaining steps and would never terminate.
leapfrog ::
  Applicative f =>
  Gradient f ->
  Maybe (Validate f) ->
  Masses f ->
  LeapfrogTrajectoryLength ->
  LeapfrogScalingFactor ->
  Positions f ->
  Momenta f ->
  -- Maybe (Positions', Momenta').
  Maybe (Positions f, Momenta f)
leapfrog grad mVal masses l eps theta phi = do
  let -- The first half step of the momenta.
      phiHalf = leapfrogStepMomenta 0.5 eps grad theta phi
  -- L-1 full steps. This gives the positions theta_{L-1}, and the momenta
  -- phi_{L-1/2}.
  (thetaLM1, phiLM1Half) <- go (l - 1) theta phiHalf
  -- The last full step of the positions.
  thetaL <- valF $ leapfrogStepPositions eps masses thetaLM1 phiLM1Half
  let -- The last half step of the momenta.
      phiL = leapfrogStepMomenta 0.5 eps grad thetaL phiLM1Half
  return (thetaL, phiL)
  where
    -- Validate the positions if a validation function is available.
    valF x = case mVal of
      Nothing -> Just x
      Just f -> if f x then Just x else Nothing
    -- Perform n full steps (position step followed by a full momentum step),
    -- stopping early with 'Nothing' if validation fails. The guard uses '<='
    -- so that the recursion terminates for non-positive n.
    go n t p
      | n <= 0 = Just (t, p)
      | otherwise = do
          t' <- valF $ leapfrogStepPositions eps masses t p
          go (n - 1) t' (leapfrogStepMomenta 1.0 eps grad t' p)

-- Update the momenta using the gradient at the given positions.
--
-- Each momentum @p@ becomes @p + g * (xi * eps)@, where @g@ is the matching
-- entry of the gradient of the log posterior at the current positions and @xi@
-- is the relative step size (e.g., 0.5 for a half step). Parameters without
-- momentum ('Nothing') stay without momentum; their gradient is ignored.
leapfrogStepMomenta ::
  Applicative f =>
  -- Size of step (half or full step).
  Double ->
  LeapfrogScalingFactor ->
  Gradient f ->
  -- Current positions.
  Positions f ->
  -- Current momenta.
  Momenta f ->
  -- New momenta.
  Momenta f
leapfrogStepMomenta xi eps grad theta phi = addKick <$> phi <*> kicks
  where
    -- Scaled gradient, one entry per parameter.
    kicks = (xi * eps) .* grad theta
    addKick mMomentum kick = fmap (+ kick) mMomentum

-- Update the positions using the current momenta.
--
-- Each position @x@ with mass @m@ and momentum @p@ becomes
-- @x + ((m ** (-1)) * eps) * p@. Parameters without mass and momentum are not
-- moved. Masses and momenta must agree on which parameters are fixed;
-- otherwise, an error is thrown when the affected position is evaluated.
leapfrogStepPositions ::
  Applicative f =>
  LeapfrogScalingFactor ->
  Masses f ->
  -- Current positions.
  Positions f ->
  -- Current momenta.
  Momenta f ->
  Positions f
leapfrogStepPositions eps masses theta phi = move <$> theta <*> masses <*> phi
  where
    move x (Just m) (Just p) = x + ((m ** (-1)) * eps) * p
    move x Nothing Nothing = x
    move _ _ _ = error "leapfrogStepPositions: Got just a mass and no momentum or the other way around."

-- Scalar-vector multiplication: multiply every element of a structure by a
-- scalar.
--
-- Only a 'Functor' instance is needed; the former 'Applicative' constraint was
-- redundant. Every 'Applicative' is a 'Functor', so existing callers keep
-- working.
(.*) :: Functor f => Double -> f Double -> f Double
x .* ys = (* x) <$> ys

-- Flatten the masses into a vector of auxiliary tuning parameters, in the
-- element order given by 'toList'.
--
-- NOTE: Fixed parameters without mass are encoded as NaN.
massesToTuningParameters :: Foldable f => Masses f -> AuxiliaryTuningParameters
massesToTuningParameters masses = VB.fromList [fromMaybe nan m | m <- toList masses]
  where
    nan :: Double
    nan = 0 / 0

-- Fill a mass structure with the given auxiliary tuning parameters.
--
-- We need the structure in order to fill it with the given parameters. Only
-- the shape of the given masses is used; their values are ignored. NaN values
-- are interpreted as parameters without mass ('Nothing'). Fails with 'Left' if
-- there are fewer or more tuning parameters than positions in the structure.
tuningParametersToMasses ::
  Traversable f =>
  AuxiliaryTuningParameters ->
  Masses f ->
  Either String (Masses f)
tuningParametersToMasses xs ms = case mapAccumL assign (VB.toList xs) ms of
  ([], massesE) -> sequenceA massesE
  (_ : _, _) -> Left "tuningParametersToMasses: Too many values."
  where
    assign [] _ = ([], Left "tuningParametersToMasses: Too few values.")
    -- NOTE: Recover fixed parameters and unset their mass.
    assign (v : vs) _ = (vs, Right $ if isNaN v then Nothing else Just v)

-- Apply the tuning parameters to the Hamiltonian settings.
--
-- The leapfrog scaling factor is multiplied by the tuning parameter @t@, and
-- the leapfrog trajectory length is divided by @t ** 0.9@ (see below). The
-- masses are replaced by the auxiliary tuning parameters only if masses are
-- tuned ('HTuneMassesAndLeapfrog'). Fails with 'Left' if the auxiliary tuning
-- parameters do not match the mass structure.
hTuningParametersToSettings ::
  Traversable f =>
  TuningParameter ->
  AuxiliaryTuningParameters ->
  HSettings f ->
  Either String (HSettings f)
hTuningParametersToSettings t ts (HSettings g v m l e tn)
  | tn == HTuneMassesAndLeapfrog = withMasses <$> tuningParametersToMasses ts m
  | otherwise = Right $ withMasses m
  where
    withMasses m' = HSettings g v m' lTuned eTuned tn
    -- The larger epsilon, the larger the proposal step size and the lower the
    -- expected acceptance ratio.
    --
    -- Further, we roughly keep \( L * \epsilon = 1.0 \). The equation is not
    -- exact, because we pull L closer to the original value to keep the
    -- runtime somewhat acceptable.
    --
    -- L is clamped to at least 1. For very large t, the quotient rounds up to
    -- 0, which would violate the invariant L >= 1 required by the leapfrog
    -- integrator.
    lTuned :: Int
    lTuned = max 1 $ ceiling $ fromIntegral l / (t ** 0.9)
    eTuned = t * e

-- Build the simple Hamiltonian proposal after applying the given tuning
-- parameters to the settings. Errors from applying the tuning parameters are
-- passed through unchanged.
hamiltonianSimpleWithTuningParameters ::
  (Applicative f, Traversable f) =>
  HSettings f ->
  TuningParameter ->
  AuxiliaryTuningParameters ->
  Either String (ProposalSimple (Positions f))
hamiltonianSimpleWithTuningParameters s t ts =
  hamiltonianSimple <$> hTuningParametersToSettings t ts s

-- | Simple HMC proposal function for fixed settings.
--
-- Each call
--
-- - draws fresh momenta with 'generateMomenta';
--
-- - draws the number of leapfrog steps uniformly from
--   @[max 1 (floor (0.8 L)), max lower (ceiling (1.2 L))]@, where @L@ is the
--   leapfrog trajectory length of the settings and @lower@ is the lower bound;
--
-- - draws the leapfrog scaling factor uniformly from @[0.8 e, 1.2 e]@;
--
-- - runs 'leapfrog' from the current positions.
--
-- Returns the new positions, the kernel ratio (momentum prior after the
-- trajectory divided by momentum prior before), and a Jacobian of 1.0.
--
-- If 'leapfrog' returns 'Nothing', the current positions are returned with a
-- kernel ratio of 0. Presumably this makes the caller reject the proposal.
hamiltonianSimple ::
  (Applicative f, Traversable f) =>
  HSettings f ->
  ProposalSimple (Positions f)
hamiltonianSimple (HSettings gradient mVal masses l e _) theta g = do
  phi <- generateMomenta masses g
  nSteps <- uniformR (lMin, lMax) g
  stepSize <- uniformR (eMin, eMax) g
  case leapfrog gradient mVal masses nSteps stepSize theta phi of
    Nothing -> return (theta, 0.0, 1.0)
    Just (theta', phi') ->
      -- NOTE: Neal page 12: In order for the proposal to be in detailed
      -- balance, the momenta have to be negated before proposing the new
      -- value. This is not required here since the prior involves normal
      -- distributions centered around 0. However, if the multivariate normal
      -- distribution is used, it makes a difference.
      let kernelRatio = priorMomenta masses phi' / priorMomenta masses phi
       in return (theta', kernelRatio, 1.0)
  where
    -- Randomize the trajectory length by +/- 20 %, using at least one step.
    lMin, lMax :: Int
    lMin = max 1 $ floor (0.8 * fromIntegral l :: Double)
    lMax = max lMin $ ceiling (1.2 * fromIntegral l :: Double)
    -- Randomize the leapfrog scaling factor by +/- 20 %.
    eMin, eMax :: LeapfrogScalingFactor
    eMin = 0.8 * e
    eMax = 1.2 * e

-- | Lower bound on the sample variance of a coordinate.
--
-- When the sample variance of a coordinate is below this value, its mass is
-- left unchanged by 'computeAuxiliaryTuningParameters'.
minVariance :: Double
minVariance = 1e-6

-- | Upper bound on the sample variance of a coordinate.
--
-- When the sample variance of a coordinate is above this value, its mass is
-- left unchanged by 'computeAuxiliaryTuningParameters'.
maxVariance :: Double
maxVariance = 1e6

-- | Minimum number of distinct sampled values a coordinate needs before its
-- mass is re-estimated by 'computeAuxiliaryTuningParameters'. With fewer
-- distinct values, the mass is left unchanged.
minSamples :: Int
minSamples = 60

-- | Re-estimate the masses (auxiliary tuning parameters) from a batch of
-- sampled positions.
--
-- For each coordinate, the new mass is @sqrt (m \/ var)@. This is the
-- geometric mean of the old mass @m@ and the inverse sample variance of that
-- coordinate.
--
-- The old mass is kept when
--
-- - the sample variance lies outside @['minVariance', 'maxVariance']@ or is
--   NaN; or
--
-- - the coordinate took fewer than 'minSamples' distinct values.
--
-- If no positions are given, the masses are returned unchanged.
computeAuxiliaryTuningParameters ::
  Foldable f =>
  VB.Vector (Positions f) ->
  AuxiliaryTuningParameters ->
  AuxiliaryTuningParameters
computeAuxiliaryTuningParameters xss ts
  -- Without samples there is nothing to estimate. Building a zero-row matrix
  -- would fail or yield an empty column vector, which would drop all masses.
  | null xss = ts
  | otherwise = VB.zipWith (\t -> rescueWith t . calcSamplesAndVariance) ts xssT
  where
    -- Transpose the samples so that each vector holds all sampled values of
    -- one coordinate.
    -- TODO: Improve matrix transposition.
    xssT = VB.fromList $ M.toColumns $ M.fromLists $ VB.toList $ VB.map toList xss
    -- Number of distinct sampled values and sample variance of one coordinate.
    calcSamplesAndVariance xs = (VB.length $ VB.uniq $ S.gsort xs, S.variance xs)
    -- Keep the old mass if the estimate is unreliable. The range check is
    -- phrased positively so that a NaN variance also keeps the old mass
    -- instead of producing a NaN mass.
    rescueWith t (sampleSize, var)
      | not (minVariance <= var && var <= maxVariance) = t
      | sampleSize < minSamples = t
      | otherwise = sqrt (t * recip var)

-- | Hamiltonian Monte Carlo proposal.
--
-- The 'Applicative' and 'Traversable' instances are used for element-wise
-- operations.
--
-- Assume a zip-like 'Applicative' instance so that cardinality remains
-- constant.
--
-- Calls 'error' with the message produced by 'checkHSettings' if the settings
-- are invalid.
--
-- Tuning depends on 'hTune':
--
-- - 'HNoTune': the proposal is not tuned.
--
-- - 'HTuneMassesAndLeapfrog': the leapfrog parameters are tuned, and the
--   masses are re-estimated from the sampled positions (see
--   'computeAuxiliaryTuningParameters').
--
-- - Any other value: only the leapfrog parameters are tuned; the masses are
--   passed through unchanged.
--
-- NOTE: The desired acceptance rate is 0.65, even though the dimension of the
-- proposal may be high.
--
-- NOTE: The speed of this proposal can change drastically when tuned because
-- the leapfrog trajectory length is changed.
hamiltonian ::
  (Applicative f, Traversable f) =>
  -- | The sample state is used to calculate the dimension of the proposal.
  f Double ->
  -- | Settings of the proposal; validated with 'checkHSettings'.
  HSettings f ->
  -- | Name of the proposal.
  PName ->
  -- | Weight of the proposal.
  PWeight ->
  Proposal (f Double)
hamiltonian x s n w = maybe proposal error (checkHSettings s)
  where
    tuneMode = hTune s
    dim = PSpecial (length x) 0.65
    mkProposal =
      Proposal
        n
        (PDescription "Hamiltonian Monte Carlo (HMC)")
        dim
        w
        (hamiltonianSimple s)
    -- Re-estimate the masses only when requested; otherwise pass the
    -- auxiliary tuning parameters through unchanged.
    tuneAuxiliary = case tuneMode of
      HTuneMassesAndLeapfrog -> computeAuxiliaryTuningParameters
      _ -> \_ xs -> xs
    tuner =
      Tuner
        1.0
        (defaultTuningFunction dim)
        (massesToTuningParameters (hMasses s))
        tuneAuxiliary
        (hamiltonianSimpleWithTuningParameters s)
    proposal = case tuneMode of
      HNoTune -> mkProposal Nothing
      _ -> mkProposal (Just tuner)