-- Creation date: Thu Jun  9 15:12:39 2022.
--
-- See "Mcmc.Proposal.Hamiltonian.Hamiltonian".
--
-- References:
--
-- - [1] Chapter 5 of Handbook of Monte Carlo: Neal, R. M., MCMC Using
--   Hamiltonian Dynamics, In S. Brooks, A. Gelman, G. Jones, & X. Meng (Eds.),
--   Handbook of Markov Chain Monte Carlo (2011), CRC press.
--
-- - [2] Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B., Bayesian data
--   analysis (2014), CRC Press.
--
-- - [3] Review by Betancourt and notes: Betancourt, M., A conceptual
--   introduction to Hamiltonian Monte Carlo, arXiv:1701.02434 (2017).
--
-- - [4] Matthew D. Hoffman, Andrew Gelman (2014) The No-U-Turn Sampler:
--   Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, Journal of
--   Machine Learning Research.

-- |
-- Module      :  Mcmc.Proposal.Hamiltonian.Internal
-- Description :  Internal definitions related to Hamiltonian dynamics
-- Copyright   :  2022 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
module Mcmc.Proposal.Hamiltonian.Internal
  ( -- * Parameters
    HParamsI (..),
    hParamsIWith,

    -- * Tuning
    toAuxiliaryTuningParameters,
    fromAuxiliaryTuningParameters,
    findReasonableEpsilon,
    hTuningFunctionWith,

    -- * Structure of state
    checkHStructureWith,

    -- * Hamiltonian dynamics
    generateMomenta,
    exponentialKineticEnergy,

    -- * Leapfrog integrator
    Target,
    leapfrog,
  )
where

import Control.Monad
import Control.Monad.ST
import Data.Foldable
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Masses
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful

-- Variable tuning parameters.
--
-- These values are updated by the tuning function at every tuning step (see
-- 'hTuningFunctionWith').
--
-- See Algorithm 5 or 6 in [4].
data TParamsVar = TParamsVar
  { -- \bar{eps} of Algorithm 5 or 6.
    tpvLeapfrogScalingFactorMean :: LeapfrogScalingFactor,
    -- H_i of Algorithm 5 or 6.
    tpvHStatistics :: Double,
    -- m of Algorithm 5 or 6. Stored as a 'Double' because it is used directly
    -- in floating point arithmetic during tuning.
    tpvCurrentTuningStep :: Double
  }
  deriving (Show)

-- Initial variable tuning parameters: a mean leapfrog scaling factor of 1.0,
-- an H statistics of 0.0, and a current tuning step of 1.0.
tParamsVar :: TParamsVar
tParamsVar = TParamsVar 1.0 0.0 1.0

-- Fixed tuning parameters.
--
-- These values are set once (see 'tParamsFixedWith') and are passed through
-- unchanged by the tuning function.
--
-- See Algorithm 5 and 6 in [4].
data TParamsFixed = TParamsFixed
  { -- Leapfrog scaling factor the fixed parameters were created with. The
    -- tuning function reports the tuned scaling factor relative to this value.
    tpfEps0 :: Double,
    -- mu of Algorithm 5 or 6.
    tpfMu :: Double,
    -- ga (presumably gamma) of Algorithm 5 or 6.
    tpfGa :: Double,
    -- t0 of Algorithm 5 or 6.
    tpfT0 :: Double,
    -- ka (presumably kappa) of Algorithm 5 or 6.
    tpfKa :: Double
  }
  deriving (Show)

-- The default tuning parameters in [4] are:
--
--   mu = log $ 10 * eps
--   ga = 0.05
--   t0 = 10
--   ka = 0.75
--
-- However, these default tuning parameters will be off, because the authors
-- suggesting these values tune the proposal after every single iteration.
--
-- The following values are tweaked for our case, where tuning does not happen
-- after each iteration. Of course, we could tune the leapfrog parameters after
-- each generation. Even the mass parameters could be tuned each iteration when
-- the masses are estimated from more past iterations spanning many tuning
-- intervals.
--
-- NOTE: In theory, we could expose these internal tuning parameters to the
-- user.
--
-- The given leapfrog scaling factor is stored as 'tpfEps0' and determines
-- 'tpfMu'; the other values are constants.
tParamsFixedWith :: LeapfrogScalingFactor -> TParamsFixed
tParamsFixedWith eps = TParamsFixed eps mu ga t0 ka
  where
    mu = log $ 10 * eps
    ga = 0.1
    t0 = 3
    ka = 0.5

-- All internal parameters.
--
-- The fields 'hpsMassesI' and 'hpsMu' are derived from 'hpsMasses' (using
-- 'getMassesI' and 'getMus') and are stored alongside the masses. They are not
-- saved by 'toAuxiliaryTuningParameters' but recomputed when loading.
data HParamsI = HParamsI
  { hpsLeapfrogScalingFactor :: LeapfrogScalingFactor,
    hpsLeapfrogSimulationLength :: LeapfrogSimulationLength,
    hpsMasses :: Masses,
    hpsTParamsVar :: TParamsVar,
    hpsTParamsFixed :: TParamsFixed,
    hpsMassesI :: MassesI,
    hpsMu :: Mu
  }
  deriving (Show)

-- Fallback leapfrog scaling factor; used by 'findReasonableEpsilon' when the
-- initial leapfrog step fails.
--
-- NOTE: If changed, amend help text of 'defaultHParams', and 'defaultNParams'.
defaultLeapfrogScalingFactor :: LeapfrogScalingFactor
defaultLeapfrogScalingFactor = 0.1

-- Leapfrog simulation length used by 'hParamsIWith' when none is provided.
--
-- NOTE: If changed, amend help text of 'defaultHParams'.
defaultLeapfrogSimulationLength :: LeapfrogSimulationLength
defaultLeapfrogSimulationLength = 0.5

-- Default masses for the given dimension: the identity matrix.
--
-- NOTE: If changed, amend help text of 'defaultHParams'.
defaultMassesWith :: Int -> Masses
defaultMassesWith d = L.trustSym $ L.ident d

-- Instantiate all internal parameters.
--
-- Arguments: the target, the current positions, and optional user-provided
-- values for the leapfrog scaling factor, the leapfrog simulation length, and
-- the masses.
--
-- Return 'Left' with an error message prefixed by "hParamsIWith: " if
--
-- - the position vector is empty;
--
-- - a diagonal element of the provided (cleaned) mass matrix is zero or
--   negative, or the provided mass matrix is not square;
--
-- - the provided leapfrog simulation length or leapfrog scaling factor is zero
--   or negative.
--
-- Missing values are filled in as follows: the masses default to
-- 'defaultMassesWith', the leapfrog simulation length defaults to
-- 'defaultLeapfrogSimulationLength', and the leapfrog scaling factor is
-- determined using 'findReasonableEpsilon'.
hParamsIWith ::
  Target ->
  Positions ->
  Maybe LeapfrogScalingFactor ->
  Maybe LeapfrogSimulationLength ->
  Maybe Masses ->
  Either String HParamsI
hParamsIWith htarget p mEps mLa mMs = do
  d <- case VS.length p of
    0 -> eWith "Empty position vector."
    d -> Right d
  ms <- case mMs of
    Nothing -> Right $ defaultMassesWith d
    Just ms -> do
      -- The cleaned matrix is used for validation only; the masses are stored
      -- as provided.
      let ms' = cleanMatrix $ L.unSym ms
          diagonalMs = L.toList $ L.takeDiag ms'
      when (any (<= 0) diagonalMs) $ eWith "Some diagonal masses are zero or negative."
      let nrows = L.rows ms'
          ncols = L.cols ms'
      when (nrows /= ncols) $ eWith "Mass matrix is not square."
      Right ms
  let msI = getMassesI ms
      mus = getMus ms
  la <- case mLa of
    Nothing -> Right defaultLeapfrogSimulationLength
    Just l
      | l <= 0 -> eWith "Leapfrog simulation length is zero or negative."
      | otherwise -> Right l
  eps <- case mEps of
    Nothing -> Right $ runST $ do
      -- NOTE: This is not random. However, I do not want to provide a generator
      -- when creating the proposal.
      g <- newSTGenM $ mkStdGen 42
      findReasonableEpsilon htarget ms p g
    Just e
      | e <= 0 -> eWith "Leapfrog scaling factor is zero or negative."
      | otherwise -> Right e
  let tParamsFixed = tParamsFixedWith eps
  pure $ HParamsI eps la ms tParamsVar tParamsFixed msI mus
  where
    eWith m = Left $ "hParamsIWith: " <> m

-- Save internal parameters.
--
-- The resulting vector contains, in order: the leapfrog scaling factor, the
-- leapfrog simulation length, the three variable tuning parameters, the five
-- fixed tuning parameters, and finally the masses (converted using
-- 'massesToVector'). The inverted masses and the 'Mu' vector are not saved;
-- 'fromAuxiliaryTuningParameters' recomputes them from the masses.
toAuxiliaryTuningParameters :: HParamsI -> AuxiliaryTuningParameters
toAuxiliaryTuningParameters (HParamsI eps la ms tpv tpf _ _) =
  -- Put masses to the end. Like so, conversion is easier.
  VU.fromList $ eps : la : epsMean : h : m : eps0 : mu : ga : t0 : ka : msL
  where
    (TParamsVar epsMean h m) = tpv
    (TParamsFixed eps0 mu ga t0 ka) = tpf
    msL = VU.toList $ massesToVector ms

-- Load internal parameters.
--
-- Inverse of 'toAuxiliaryTuningParameters'. The vector is expected to contain
-- ten scalar parameters followed by the @d * d@ elements of the masses, where
-- @d@ is the given dimension. Return 'Left' with an error message if the length
-- of the vector does not match.
--
-- The inverted masses and the 'Mu' vector are recomputed from the loaded
-- masses using 'getMassesI' and 'getMus'.
fromAuxiliaryTuningParameters :: Dimension -> AuxiliaryTuningParameters -> Either String HParamsI
fromAuxiliaryTuningParameters d xs
  | (d * d) + 10 /= len = Left "fromAuxiliaryTuningParameters: Dimension mismatch."
  | d * d /= lenMs = Left "fromAuxiliaryTuningParameters: Masses dimension mismatch."
  | otherwise = case VU.toList $ VU.take 10 xs of
      [eps, la, epsMean, h, m, eps0, mu, ga, t0, ka] ->
        let tpv = TParamsVar epsMean h m
            tpf = TParamsFixed eps0 mu ga t0 ka
         in Right $ HParamsI eps la ms tpv tpf msI mus
      -- To please the exhaustive pattern match checker.
      _ -> Left "fromAuxiliaryTuningParameters: Impossible dimension mismatch."
  where
    len = VU.length xs
    -- The masses are stored after the ten scalar parameters.
    msV = VU.drop 10 xs
    lenMs = VU.length msV
    ms = vectorToMasses d msV
    msI = getMassesI ms
    mus = getMus ms

-- Heuristically find a reasonable leapfrog scaling factor.
--
-- Sample momenta, perform a single leapfrog step with a scaling factor of 1.0,
-- and then repeatedly double or halve the scaling factor depending on the ratio
-- of the joint densities (target times exponential kinetic energy) after and
-- before the step.
--
-- Return 'defaultLeapfrogScalingFactor' if the initial leapfrog step fails. If
-- a later leapfrog step fails, return the scaling factor reached so far.
--
-- See Algorithm 4 in [4].
findReasonableEpsilon ::
  StatefulGen g m =>
  Target ->
  Masses ->
  Positions ->
  g ->
  m LeapfrogScalingFactor
findReasonableEpsilon t ms q g = do
  p <- generateMomenta mu ms g
  case leapfrog t msI 1 eI q p of
    Nothing -> pure defaultLeapfrogScalingFactor
    Just (_, p', prQ, prQ') -> do
      let expEKin = exponentialKineticEnergy msI p
          expEKin' = exponentialKineticEnergy msI p'
          -- Ratio of the joint densities after and before the initial step.
          rI :: Double
          rI = exp $ ln $ prQ' * expEKin' / (prQ * expEKin)
          -- Direction of the search: 1 doubles, -1 halves the scaling factor.
          a :: Double
          a = if rI > 0.5 then 1 else (-1)
          -- NOTE(review): The leapfrog step below is performed with the
          -- current scaling factor @e@ and not with the updated factor @e'@.
          -- Consequently, the ratio passed to the recursive call belongs to
          -- @e@, and the first iteration recomputes the ratio of the initial
          -- step. Algorithm 4 in [4] appears to update the scaling factor
          -- before the leapfrog step; confirm that this lag is intended.
          go e r =
            if r ** a > 2 ** negate a
              then case leapfrog t msI 1 e q p of
                Nothing -> e
                Just (_, p'', _, prQ'') ->
                  let expEKin'' = exponentialKineticEnergy msI p''
                      r' :: Double
                      r' = exp $ ln $ prQ'' * expEKin'' / (prQ * expEKin)
                      e' = (2 ** a) * e
                   in go e' r'
              else e
      pure $ go eI rI
  where
    -- Initial leapfrog scaling factor.
    eI = 1.0
    msI = getMassesI ms
    mu = getMus ms

-- Create the tuning function for the given dimension and tuning configuration.
--
-- Return 'Nothing' if neither the leapfrog parameters nor the masses are to be
-- tuned.
--
-- Otherwise, the tuning function
--
-- - loads the internal parameters from the auxiliary tuning parameters (calls
--   'error' on failure);
--
-- - tunes the masses according to the configuration, using the trace;
--
-- - tunes the leapfrog scaling factor using the acceptance rate and the
--   variable and fixed tuning parameters (see Algorithm 5 and 6 in [4]);
--
-- - returns the new leapfrog scaling factor relative to the initial scaling
--   factor 'tpfEps0', together with the updated auxiliary tuning parameters.
--
-- The tuning function calls 'error' if no trace is available.
hTuningFunctionWith ::
  Dimension ->
  -- Conversion from value to vector.
  (a -> Positions) ->
  HTuningConf ->
  Maybe (TuningFunction a)
hTuningFunctionWith n toVec (HTuningConf lc mc) = case (lc, mc) of
  (HNoTuneLeapfrog, HNoTuneMasses) -> Nothing
  (_, _) -> Just $
    \tt pdim ar mxs (_, ts) ->
      case mxs of
        Nothing -> error "hTuningFunctionWith: empty trace"
        Just xs ->
          let (HParamsI eps la ms tpv tpf msI mus) =
                -- NOTE: Use error here, because a dimension mismatch is a serious bug.
                either error id $ fromAuxiliaryTuningParameters n ts
              (TParamsVar epsMean h m) = tpv
              (TParamsFixed eps0 mu ga t0 ka) = tpf
              (ms', msI') = case mc of
                HNoTuneMasses -> (ms, msI)
                HTuneDiagonalMassesOnly -> tuneDiagonalMassesOnly toVec xs (ms, msI)
                HTuneAllMasses -> tuneAllMasses toVec xs (ms, msI)
              (eps'', epsMean'', h'') = case lc of
                HNoTuneLeapfrog -> (eps, epsMean, h)
                HTuneLeapfrog ->
                  let delta = getOptimalRate pdim
                      c = recip $ m + t0
                      h' = (1.0 - c) * h + c * (delta - ar)
                      logEps' = mu - (sqrt m / ga) * h'
                      eps' = exp logEps'
                      mMKa = m ** negate ka
                      epsMean' = exp $ mMKa * logEps' + (1 - mMKa) * log epsMean
                   in (eps', epsMean', h')
              -- At the last tuning step, use the mean leapfrog scaling factor.
              eps''' = case tt of
                NormalTuningStep -> eps''
                LastTuningStep -> epsMean''
              -- Advance the tuning step counter.
              tpv' = TParamsVar epsMean'' h'' (m + 1.0)
           in (eps''' / eps0, toAuxiliaryTuningParameters $ HParamsI eps''' la ms' tpv' tpf msI' mus)

-- Sanity checks for the structure of the state.
--
-- Return 'Nothing' if all checks pass; otherwise return an error message
-- describing the first failed check. The following is checked, in order:
--
-- 1. Converting the sample state to a vector and back reproduces the elements
--    of the sample state.
--
-- 2. The length of the vector obtained from the sample state equals the number
--    of rows of the mass matrix.
checkHStructureWith :: Foldable s => Masses -> HStructure s -> Maybe String
checkHStructureWith ms (HStructure x toVec fromVec)
  | roundTripFails =
      failWith "'fromVectorWith x (toVector x) /= x' for sample state."
  | sizesDiffer =
      failWith "Mass matrix and 'toVector x' have different sizes for sample state."
  | otherwise = Nothing
  where
    -- Vector representation of the sample state.
    v = toVec x
    -- Compare element lists; the container itself may lack an 'Eq' instance.
    roundTripFails = toList (fromVec x v) /= toList x
    sizesDiffer = L.size v /= L.rows (L.unSym ms)
    -- Prefix all messages with the name of this function.
    failWith msg = Just ("checkHStructureWith: " <> msg)

-- Generate momenta for a new iteration.
--
-- A seed is drawn from the given generator and handed over to
-- 'L.gaussianSample', which is asked for a single sample using the given mean
-- vector and masses. The resulting one-row matrix is flattened into a vector.
generateMomenta ::
  StatefulGen g m =>
  Mu ->
  Masses ->
  g ->
  m Momenta
generateMomenta mu masses gen = sampleWith <$> uniformM gen
  where
    -- Draw exactly one sample (one row) and convert it to a vector.
    sampleWith seed = L.flatten (L.gaussianSample seed 1 mu masses)

-- Compute exponent of kinetic energy.
--
-- The result is the log domain value with exponent @-0.5 * p^T (M^{-1} p)@,
-- where @p@ are the momenta and @M^{-1}@ is the given inverted mass matrix.
--
-- Use a general matrix which has special representations for diagonal and
-- sparse matrices, both of which are really useful here.
exponentialKineticEnergy ::
  MassesI ->
  Momenta ->
  Log Double
exponentialKineticEnergy msI ps = Exp (scaling * quadraticForm)
  where
    scaling = -0.5
    -- NOTE: Because of numerical errors, the following formulas exhibit
    -- different traces (although the posterior appears to be the same):
    -- - This one we cannot use with general matrices:
    --   Exp $ (-0.5) * ((xs L.#> msI) L.<.> xs)
    --
    -- Do not change the order of operations.
    quadraticForm = ps L.<.> (msI L.!#> ps)

-- Function calculating target value and gradient.
--
-- Given the positions, return a pair containing (1) the value of the target
-- function in the log domain, and (2) the gradient, which has the same type as
-- the positions.
--
-- The function acts on the subset of the state manipulated by the proposal but
-- the value and gradient have to be calculated for the complete state. The
-- reason is that parameters untouched by the Hamiltonian proposal may affect
-- the result or the gradient.
--
-- Make sure that the value is calculated lazily because many times, only the
-- gradient is required.
type Target = Positions -> (Log Double, Positions)

-- Leapfrog integrator.
--
-- Starting from positions @q@ and momenta @p@, perform a half step of the
-- momenta, then @l-1@ full steps of positions and momenta, then a final full
-- step of the positions and a final half step of the momenta. All steps use the
-- leapfrog scaling factor @eps@ (half steps use @0.5 * eps@).
--
-- The target function is evaluated after each momenta step. The integration is
-- aborted with 'Nothing' as soon as the target value is not strictly positive.
--
-- On success, return the new positions, the new momenta, the target value
-- observed during the first half step of the momenta (old target), and the
-- target value observed during the last half step of the momenta (new target).
leapfrog ::
  Target ->
  MassesI ->
  --
  LeapfrogTrajectoryLength ->
  LeapfrogScalingFactor ->
  --
  Positions ->
  Momenta ->
  -- | (New positions, new momenta, old target, new target).
  --
  -- Fail if state is not valid.
  Maybe (Positions, Momenta, Log Double, Log Double)
leapfrog tF msI l eps q p = do
  -- The first half step of the momenta.
  (x, pHalf) <- stepMomenta (0.5 * eps) q p
  -- L-1 full steps for positions and momenta. This gives the positions q_{L-1},
  -- and the momenta p_{L-1/2}.
  (qLM1, pLM1Half) <- go (l - 1) (q, pHalf)
  -- The last full step of the positions.
  let qL = leapfrogStepPositions msI eps qLM1 pLM1Half
  -- The last half step of the momenta.
  (x', pL) <- stepMomenta (0.5 * eps) qL pLM1Half
  return (qL, pL, x, x')
  where
    -- Step the momenta with the given step size; fail if the target function
    -- evaluated at the given positions is not strictly positive.
    stepMomenta ::
      LeapfrogScalingFactor ->
      Positions ->
      Momenta ->
      Maybe (Log Double, Momenta)
    stepMomenta h qs ps =
      let (y, ps') = leapfrogStepMomenta h tF qs ps
       in if y > 0.0 then Just (y, ps') else Nothing
    -- Perform the given number of full steps of positions and momenta.
    go ::
      LeapfrogTrajectoryLength ->
      (Positions, Momenta) ->
      Maybe (Positions, Momenta)
    go n (qs, ps)
      | n <= 0 = Just (qs, ps)
      | otherwise = do
          let qs' = leapfrogStepPositions msI eps qs ps
          -- Update the CURRENT momenta 'ps', not the initial momenta 'p' of
          -- the enclosing function; otherwise the momentum accumulated in
          -- previous steps is lost.
          (_, ps') <- stepMomenta eps qs' ps
          go (n - 1) (qs', ps')

-- Single step of the momenta.
--
-- Evaluate the target function at the current positions, and add the gradient
-- scaled by the given step size to the current momenta.
leapfrogStepMomenta ::
  LeapfrogScalingFactor ->
  Target ->
  -- Current positions.
  Positions ->
  -- Current momenta.
  Momenta ->
  -- New momenta; also return value target function to be collected at the end
  -- of the leapfrog integration.
  (Log Double, Momenta)
leapfrogStepMomenta eps tf q p =
  let (targetValue, gradient) = tf q
      kick = L.scale eps gradient
   in (targetValue, p + kick)

-- Single step of the positions.
--
-- Scale the current momenta by the given step size, multiply the result with
-- the inverted mass matrix, and add the product to the current positions.
leapfrogStepPositions ::
  MassesI ->
  LeapfrogScalingFactor ->
  -- Current positions.
  Positions ->
  -- Current momenta.
  Momenta ->
  -- New positions.
  Positions
-- NOTE: Because of numerical errors, the following formulas exhibit different
-- traces (although the posterior appears to be the same):
-- 1. This one we cannot use with general matrices:
--    leapfrogStepPositions msI eps q p = q + (L.scale eps msI L.!#> p)
-- 2. This one seems to be more numerically unstable:
--    leapfrogStepPositions msI eps q p = q + L.scale eps (msI L.!#> p)
--
-- Consequently, the momenta are scaled BEFORE the matrix-vector product.
leapfrogStepPositions msI eps q p = q + displacement
  where
    scaledMomenta = L.scale eps p
    displacement = msI L.!#> scaledMomenta