-- Creation date: Thu Jun  9 15:12:39 2022.
--
-- See "Mcmc.Proposal.Hamiltonian.Hamiltonian".
--
-- References:
--
-- - [1] Chapter 5 of Handbook of Monte Carlo: Neal, R. M., MCMC Using
--   Hamiltonian Dynamics, In S. Brooks, A. Gelman, G. Jones, & X. Meng (Eds.),
--   Handbook of Markov Chain Monte Carlo (2011), CRC press.
--
-- - [2] Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B., Bayesian data
--   analysis (2014), CRC Press.
--
-- - [3] Review by Betancourt and notes: Betancourt, M., A conceptual
--   introduction to Hamiltonian Monte Carlo, arXiv:1701.02434 (2017).
--
-- - [4] Matthew D. Hoffman, Andrew Gelman (2014) The No-U-Turn Sampler:
--   Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, Journal of
--   Machine Learning Research.
{-# LANGUAGE BangPatterns #-}

-- |
-- Module      :  Mcmc.Proposal.Hamiltonian.Internal
-- Description :  Internal definitions related to Hamiltonian dynamics
-- Copyright   :  2022 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
module Mcmc.Proposal.Hamiltonian.Internal
  ( -- * Parameters
    HParamsI (..),
    hParamsIWith,

    -- * Tuning
    toAuxiliaryTuningParameters,
    fromAuxiliaryTuningParameters,
    findReasonableEpsilon,
    hTuningFunctionWith,

    -- * Structure of state
    checkHStructureWith,

    -- * Hamiltonian dynamics
    generateMomenta,
    exponentialKineticEnergy,

    -- * Leapfrog integrator
    Target,
    leapfrog,
  )
where

import Control.Monad
import Control.Monad.ST
import Data.Foldable
import Data.Maybe
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Masses
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful

-- Variable tuning parameters. These are the values that are updated at every
-- tuning step.
--
-- See Algorithm 5 or 6 in [4].
data TParamsVar = TParamsVar
  { -- \bar{eps} of Algorithm 5 or 6.
    tpvLeapfrogScalingFactorMean :: LeapfrogScalingFactor,
    -- H_i of Algorithm 5 or 6.
    tpvHStatistics :: Double,
    -- m of Algorithm 5 or 6.
    tpvCurrentTuningStep :: Double
  }
  deriving (Show)

-- Initial variable tuning parameters: a mean leapfrog scaling factor of 1.0,
-- an H statistic of 0.0, and a current tuning step of 1.0.
tParamsVar :: TParamsVar
tParamsVar = TParamsVar 1.0 0.0 1.0

-- Fixed tuning parameters. These are set once (see 'tParamsFixedWith') and are
-- only passed along by the tuning function.
--
-- See Algorithm 5 and 6 in [4].
data TParamsFixed = TParamsFixed
  { -- Initial leapfrog scaling factor.
    tpfEps0 :: Double,
    -- Point towards which the scaling factor is shrunk.
    tpfMu :: Double,
    -- Amount of shrinkage towards mu.
    tpfGa :: Double,
    -- Stabilizer of the initial tuning steps.
    tpfT0 :: Double,
    -- Exponent controlling the weight of recent tuning steps.
    tpfKa :: Double
  }
  deriving (Show)

-- Create the fixed tuning parameters from the initial leapfrog scaling factor.
--
-- The default tuning parameters in [4] which have been tweaked for tuning the
-- proposals after every iteration are:
--
--   mu = log $ 10 * eps
--   ga = 0.05
--   t0 = 10
--   ka = 0.75
--
-- For reference, I used the following default parameters with longer auto
-- tuning intervals.
--
--   mu = log $ 10 * eps
--   ga = 0.1
--   t0 = 3
--   ka = 0.5
--
-- Another good resource:
-- https://mc-stan.org/docs/2_29/reference-manual/hmc-algorithm-parameters.html.
--
-- NOTE: In theory, we could expose these internal tuning parameters to the
-- user.
tParamsFixedWith :: LeapfrogScalingFactor -> TParamsFixed
tParamsFixedWith eps = TParamsFixed eps mu ga t0 ka
  where
    -- "Mu is a freely chosen point that the iterators are shrunk towards". I am
    -- not exactly sure what this means. The parameter does not seem to have
    -- much of an effect.
    mu = log $ 10 * eps
    -- Gamma "controls the amount of shrinkage towards mu". The larger gamma is,
    -- the less variant epsilon is.
    --
    -- I changed this parameter from 0.05 to get better results in test runs.
    ga = 0.15
    -- "Free parameter that stabilizes the initial iterations". The larger t0
    -- is, the stabler epsilon is in the first iterations.
    t0 = 10
    -- "Setting the parameter ka < 1 allows us to give higher weight to more
    -- recent iterates and to more quickly forget the iterates produced during
    -- the early warmup stages."
    --
    -- I changed this parameter from 0.75 to get better results in test runs.
    --
    -- NOTE(review): The value below equals 0.75, so the remark above looks
    -- stale; confirm which value is intended.
    ka = 0.75

-- All internal parameters.
--
-- The inverted masses and the mean vector are derived from the masses (see
-- 'hParamsIWith' and 'fromAuxiliaryTuningParameters'). They are not stored when
-- converting to auxiliary tuning parameters.
data HParamsI = HParamsI
  { hpsLeapfrogScalingFactor :: LeapfrogScalingFactor,
    hpsLeapfrogSimulationLength :: LeapfrogSimulationLength,
    hpsMasses :: Masses,
    hpsTParamsVar :: TParamsVar,
    hpsTParamsFixed :: TParamsFixed,
    -- Obtained from the masses with 'getMassesI'.
    hpsMassesI :: MassesI,
    -- Obtained from the masses with 'getMus'; used as mean when sampling
    -- momenta.
    hpsMu :: Mu
  }
  deriving (Show)

-- Fallback leapfrog scaling factor; used by 'findReasonableEpsilon' when the
-- first leapfrog step fails.
--
-- NOTE: If changed, amend help text of 'defaultHParams', and 'defaultNParams'.
defaultLeapfrogScalingFactor :: LeapfrogScalingFactor
defaultLeapfrogScalingFactor = 0.1

-- Leapfrog simulation length used by 'hParamsIWith' when none is provided.
--
-- NOTE: If changed, amend help text of 'defaultHParams'.
defaultLeapfrogSimulationLength :: LeapfrogSimulationLength
defaultLeapfrogSimulationLength = 0.5

-- Default masses: the identity matrix of the given dimension. The identity is
-- symmetric, so it is safe to skip the symmetry check with 'L.trustSym'.
--
-- NOTE: If changed, amend help text of 'defaultHParams'.
defaultMassesWith :: Int -> Masses
defaultMassesWith d = L.trustSym $ L.ident d

-- Instantiate all internal parameters.
--
-- Arguments: target function, current positions, and optional leapfrog scaling
-- factor, leapfrog simulation length, and masses. Missing optional values are
-- replaced by defaults; a missing scaling factor is determined with
-- 'findReasonableEpsilon'.
--
-- Return 'Left' with a message prefixed by "hParamsIWith: " if
--
-- - the position vector is empty;
--
-- - a diagonal entry of the (cleaned) mass matrix is zero or negative;
--
-- - the mass matrix is not square;
--
-- - the provided simulation length or scaling factor is zero or negative.
hParamsIWith ::
  Target ->
  Positions ->
  Maybe LeapfrogScalingFactor ->
  Maybe LeapfrogSimulationLength ->
  Maybe Masses ->
  Either String HParamsI
hParamsIWith htarget p mEps mLa mMs = do
  d <- case VS.length p of
    0 -> eWith "Empty position vector."
    d -> Right d
  ms <- case mMs of
    Nothing -> Right $ defaultMassesWith d
    Just ms -> do
      let ms' = cleanMatrix $ L.unSym ms
          diagonalMs = L.toList $ L.takeDiag ms'
      when (any (<= 0) diagonalMs) $ eWith "Some diagonal masses are zero or negative."
      let nrows = L.rows ms'
          ncols = L.cols ms'
      when (nrows /= ncols) $ eWith "Mass matrix is not square."
      -- The checks use the cleaned matrix, but the masses are stored as given.
      Right ms
  let msI = getMassesI ms
      mus = getMus ms
  la <- case mLa of
    Nothing -> Right defaultLeapfrogSimulationLength
    Just l
      | l <= 0 -> eWith "Leapfrog simulation length is zero or negative."
      | otherwise -> Right l
  eps <- case mEps of
    Nothing -> Right $ runST $ do
      -- NOTE: This is not random. However, I do not want to provide a generator
      -- when creating the proposal.
      g <- newSTGenM $ mkStdGen 42
      findReasonableEpsilon htarget ms p g
    Just e
      | e <= 0 -> eWith "Leapfrog scaling factor is zero or negative."
      | otherwise -> Right e
  let tParamsFixed = tParamsFixedWith eps
  pure $ HParamsI eps la ms tParamsVar tParamsFixed msI mus
  where
    -- Fail with a message tagged with the function name.
    eWith m = Left $ "hParamsIWith: " <> m

-- Save internal parameters.
--
-- Layout of the resulting vector: the ten scalars
--
--   eps, la, epsMean, h, m, eps0, mu, ga, t0, ka
--
-- followed by the masses as given by 'massesToVector'. The inverted masses and
-- the mean vector are not stored; 'fromAuxiliaryTuningParameters' recomputes
-- them from the masses.
toAuxiliaryTuningParameters :: HParamsI -> AuxiliaryTuningParameters
toAuxiliaryTuningParameters (HParamsI eps la ms tpv tpf _ _) =
  -- Put masses to the end. Like so, conversion is easier.
  VU.fromList $ eps : la : epsMean : h : m : eps0 : mu : ga : t0 : ka : msL
  where
    (TParamsVar epsMean h m) = tpv
    (TParamsFixed eps0 mu ga t0 ka) = tpf
    msL = VU.toList $ massesToVector ms

-- Load internal parameters.
--
-- Inverse of 'toAuxiliaryTuningParameters': the first ten elements are the
-- scalar parameters, and the remaining elements are the masses. For dimension
-- d, the vector is expected to have length d * d + 10; otherwise, 'Left' is
-- returned. The inverted masses and the mean vector are recomputed from the
-- masses.
fromAuxiliaryTuningParameters :: Dimension -> AuxiliaryTuningParameters -> Either String HParamsI
fromAuxiliaryTuningParameters d xs
  | (d * d) + 10 /= len = Left "fromAuxiliaryTuningParameters: Dimension mismatch."
  | fromIntegral (d * d) /= lenMs = Left "fromAuxiliaryTuningParameters: Masses dimension mismatch."
  | otherwise = case VU.toList $ VU.take 10 xs of
      [eps, la, epsMean, h, m, eps0, mu, ga, t0, ka] ->
        let tpv = TParamsVar epsMean h m
            tpf = TParamsFixed eps0 mu ga t0 ka
         in Right $ HParamsI eps la ms tpv tpf msI mus
      -- To please the exhaustive pattern match checker.
      _ -> Left "fromAuxiliaryTuningParameters: Impossible dimension mismatch."
  where
    len = VU.length xs
    -- Everything after the ten scalar parameters belongs to the masses.
    msV = VU.drop 10 xs
    lenMs = VU.length msV
    ms = vectorToMasses d msV
    msI = getMassesI ms
    mus = getMus ms

-- Heuristic to find a reasonable leapfrog scaling factor.
--
-- See Algorithm 4 in [4].
--
-- Sample momenta once, and perform single leapfrog steps from the given
-- positions. Start with a scaling factor of 1.0, and repeatedly double (or
-- halve) it while the ratio of the joint densities after and before the step
-- stays above (or below) 0.5.
--
-- If the very first leapfrog step fails, return
-- 'defaultLeapfrogScalingFactor'. If a later leapfrog step fails, return the
-- scaling factor used for that step.
findReasonableEpsilon ::
  StatefulGen g m =>
  Target ->
  Masses ->
  Positions ->
  g ->
  m LeapfrogScalingFactor
findReasonableEpsilon t ms q g = do
  p <- generateMomenta mu ms g
  case leapfrog t msI 1 eI q p of
    Nothing -> pure defaultLeapfrogScalingFactor
    Just (_, p', prQ, prQ') -> do
      let expEKin = exponentialKineticEnergy msI p
          expEKin' = exponentialKineticEnergy msI p'
          -- Ratio of the joint densities after and before the initial step.
          rI :: Double
          rI = exp $ ln $ prQ' * expEKin' / (prQ * expEKin)
          -- Direction of the search: 1 doubles, -1 halves the scaling factor.
          a :: Double
          a = if rI > 0.5 then 1 else (-1)
          -- NOTE(review): The ratio handed to the next iteration is computed
          -- with the scaling factor of the current iteration, and not with the
          -- updated one. Algorithm 4 in [4] updates the scaling factor before
          -- performing the leapfrog step. Confirm that this lag is intended.
          go e r =
            if r ** a > 2 ** negate a
              then case leapfrog t msI 1 e q p of
                Nothing -> e
                Just (_, p'', _, prQ'') ->
                  let expEKin'' = exponentialKineticEnergy msI p''
                      r' :: Double
                      r' = exp $ ln $ prQ'' * expEKin'' / (prQ * expEKin)
                      e' = (2 ** a) * e
                   in go e' r'
              else e
      pure $ go eI rI
  where
    -- Initial scaling factor.
    eI = 1.0
    msI = getMassesI ms
    mu = getMus ms

-- Create the tuning function of a Hamiltonian proposal.
--
-- Return 'Nothing' if neither the leapfrog parameters nor the masses are to be
-- tuned.
--
-- Otherwise, the tuning function decodes the internal parameters from the
-- auxiliary tuning parameters, optionally tunes the masses using the trace and
-- the leapfrog scaling factor using the acceptance rate (dual averaging;
-- Algorithm 6 in [4]), and encodes the result again. The returned tuning
-- parameter is the new scaling factor relative to the initial one (eps0).
--
-- The tuning function calls 'error' if
--
-- - the tuning type is an intermediate or normal tuning step for fast
--   proposals only;
--
-- - the auxiliary tuning parameters do not match the dimension;
--
-- - the trace or the acceptance rate is required but not available.
hTuningFunctionWith ::
  Dimension ->
  -- Conversion from value to vector.
  (a -> Positions) ->
  HTuningConf ->
  Maybe (TuningFunction a)
hTuningFunctionWith _ _ (HTuningConf HNoTuneLeapfrog HNoTuneMasses) = Nothing
hTuningFunctionWith n toVec (HTuningConf lc mc) = Just $ \tt pdim mar mxs (_, !ts) ->
  case tt of
    IntermediateTuningFastProposalsOnly -> err "fast intermediate tuning step but slow proposal"
    NormalTuningFastProposalsOnly -> err "fast normal tuning step but slow proposal"
    _ ->
      let (HParamsI eps la ms tpv tpf msI mus) =
            -- NOTE: Use error here, because a dimension mismatch is a serious bug.
            either error id $ fromAuxiliaryTuningParameters n ts
          (TParamsVar epsMean h m) = tpv
          (TParamsFixed eps0 mu ga t0 ka) = tpf
          m' = SmoothingParameter $ round m
          -- Masses are left untouched during intermediate tuning steps.
          (ms', msI') = case tt of
            IntermediateTuningAllProposals -> (ms, msI)
            _ ->
              let xs = fromMaybe (err "empty trace") mxs
               in case mc of
                    HNoTuneMasses -> (ms, msI)
                    HTuneDiagonalMassesOnly -> tuneDiagonalMassesOnly m' toVec xs (ms, msI)
                    HTuneAllMasses -> tuneAllMasses m' toVec xs (ms, msI)
          (eps'', epsMean'', h'') = case tt of
            LastTuningFastProposalsOnly -> (eps, epsMean, h)
            _ -> case lc of
              HNoTuneLeapfrog -> (eps, epsMean, h)
              HTuneLeapfrog ->
                let ar = fromMaybe (err "no acceptance rate") mar
                    delta = getOptimalRate pdim
                    -- Algorithm 6; explained in Section 3.2.
                    --
                    -- Another good resource is the Tensorflow API
                    -- documentation:
                    -- https://www.tensorflow.org/probability/api_docs/python/tfp/mcmc/DualAveragingStepSizeAdaptation.
                    --
                    -- See also Nesterov (2007) Primal-dual subgradient methods
                    -- for convex problems, Mathematical Programming.
                    c = recip $ m + t0
                    h' = (1.0 - c) * h + c * (delta - ar)
                    eps' = exp $ mu - (sqrt m / ga) * h'
                    mMKa = m ** negate ka
                    -- Original formula is:
                    -- epsMean' = exp $ mMKa * logEps' + (1 - mMKa) * log epsMean
                    -- Which is the same as:
                    epsMean' = (eps' ** mMKa) * (epsMean ** (1 - mMKa))
                    -- Use the averaged scaling factor at the last tuning step.
                    epsF = if tt == LastTuningAllProposals then epsMean' else eps'
                 in (epsF, epsMean', h')
          -- Advance the tuning step counter.
          tpv' = TParamsVar epsMean'' h'' (m + 1.0)
       in (eps'' / eps0, toAuxiliaryTuningParameters $ HParamsI eps'' la ms' tpv' tpf msI' mus)
  where
    -- Fail with a message tagged with the function name.
    err msg = error $ "hTuningFunctionWith: " <> msg

-- Check the structure of the state against the masses.
--
-- Return 'Nothing' if all checks pass. Otherwise, return an error message
-- prefixed by "checkHStructureWith: " if
--
-- - converting the sample state to a vector and back does not reproduce the
--   elements of the sample state; or
--
-- - the size of the vector differs from the number of rows of the mass matrix.
checkHStructureWith :: Foldable s => Masses -> HStructure s -> Maybe String
checkHStructureWith ms (HStructure x toVec fromVec)
  | toList (fromVec x xVec) /= toList x = eWith "'fromVectorWith x (toVector x) /= x' for sample state."
  | L.size xVec /= nrows = eWith "Mass matrix and 'toVector x' have different sizes for sample state."
  | otherwise = Nothing
  where
    eWith m = Just $ "checkHStructureWith: " <> m
    nrows = L.rows $ L.unSym ms
    xVec = toVec x

-- Generate momenta for a new iteration.
--
-- Draw a seed from the given generator, and use it to obtain one sample from
-- the multivariate normal distribution with the given mean vector and the
-- masses as covariance matrix.
generateMomenta ::
  StatefulGen g m =>
  Mu ->
  Masses ->
  g ->
  m Momenta
generateMomenta mu masses gen = do
  seed <- uniformM gen
  -- A single sample is a matrix with one row.
  let momenta = L.gaussianSample seed 1 mu masses
  return $ L.flatten momenta

-- Compute exponent of kinetic energy.
--
-- For inverted masses M^{-1} and momenta p, the result is the log domain value
-- with logarithm -0.5 * (p . (M^{-1} p)).
--
-- Use a general matrix which has special representations for diagonal and
-- sparse matrices, both of which are really useful here.
exponentialKineticEnergy ::
  MassesI ->
  Momenta ->
  Log Double
exponentialKineticEnergy msI xs =
  -- NOTE: Because of numerical errors, the following formulas exhibit different
  -- traces (although the posterior appears to be the same):
  -- - This one we cannot use with general matrices:
  --   Exp $ (-0.5) * ((xs L.#> msI) L.<.> xs)
  Exp $ (-0.5) * (xs L.<.> (msI L.!#> xs))

-- Function calculating target value and gradient.
--
-- Given the positions, the function returns a pair: the value of the target
-- function, and a vector that the leapfrog integrator scales and adds to the
-- momenta (presumably the gradient of the log target; confirm).
--
-- The function acts on the subset of the state manipulated by the proposal but
-- the value and gradient have to be calculated for the complete state. The
-- reason is that parameters untouched by the Hamiltonian proposal may affect
-- the result or the gradient.
--
-- Make sure that the value is calculated lazily because many times, only the
-- gradient is required.
type Target = Positions -> (Log Double, Positions)

-- Leapfrog integrator.
--
-- Perform a half step of the momenta, L-1 alternating full steps of the
-- positions and the momenta, a last full step of the positions, and a final
-- half step of the momenta.
--
-- Arguments: target function, inverted masses, number of leapfrog steps L,
-- leapfrog scaling factor, current positions, and current momenta.
--
-- The integration fails with 'Nothing' as soon as the target function
-- evaluates to a value that is not strictly positive.
leapfrog ::
  Target ->
  MassesI ->
  --
  LeapfrogTrajectoryLength ->
  LeapfrogScalingFactor ->
  --
  Positions ->
  Momenta ->
  -- | (New positions, new momenta, old target, new target).
  --
  -- Fail if state is not valid.
  Maybe (Positions, Momenta, Log Double, Log Double)
leapfrog tF msI l eps q p = do
  -- The first half step of the momenta.
  (x, pHalf) <-
    let (x, pHalf) = leapfrogStepMomenta (0.5 * eps) tF q p
     in if x > 0.0
          then Just (x, pHalf)
          else Nothing
  -- L-1 full steps for positions and momenta. This gives the positions q_{L-1},
  -- and the momenta p_{L-1/2}.
  (qLM1, pLM1Half) <- go (l - 1) $ Just (q, pHalf)
  -- The last full step of the positions.
  let qL = leapfrogStepPositions msI eps qLM1 pLM1Half
  -- The last half step of the momenta.
  (x', pL) <-
    let (x', pL) = leapfrogStepMomenta (0.5 * eps) tF qL pLM1Half
     in if x' > 0.0
          then Just (x', pL)
          else Nothing
  return (qL, pL, x, x')
  where
    go _ Nothing = Nothing
    go n (Just (qs, ps))
      | n <= 0 = Just (qs, ps)
      | otherwise =
          let qs' = leapfrogStepPositions msI eps qs ps
              -- Advance the current momenta 'ps'. Using the initial momenta
              -- 'p' here would discard the momentum accumulated so far.
              (x, ps') = leapfrogStepMomenta eps tF qs' ps
           in if x > 0.0
                then go (n - 1) $ Just (qs', ps')
                else Nothing

-- Step of the momenta: evaluate the target function at the current positions,
-- and add the scaled gradient to the current momenta.
leapfrogStepMomenta ::
  LeapfrogScalingFactor ->
  Target ->
  -- Current positions.
  Positions ->
  -- Current momenta.
  Momenta ->
  -- New momenta; also return value target function to be collected at the end
  -- of the leapfrog integration.
  (Log Double, Momenta)
leapfrogStepMomenta eps tf q p = (x, p + L.scale eps g)
  where
    (x, g) = tf q

-- Step of the positions: scale the current momenta, multiply them with the
-- inverted masses, and add the result to the current positions.
leapfrogStepPositions ::
  MassesI ->
  LeapfrogScalingFactor ->
  -- Current positions.
  Positions ->
  -- Current momenta.
  Momenta ->
  -- New positions.
  Positions
-- NOTE: Because of numerical errors, the following formulas exhibit different
-- traces (although the posterior appears to be the same):
-- 1. This one we cannot use with general matrices:
--    leapfrogStepPositions msI eps q p = q + (L.scale eps msI L.!#> p)
-- 2. This one seems to be more numerically unstable:
--    leapfrogStepPositions msI eps q p = q + L.scale eps (msI L.!#> p)
leapfrogStepPositions msI eps q p = q + (msI L.!#> L.scale eps p)