{-# LANGUAGE BangPatterns #-}
module Mcmc.Proposal.Hamiltonian.Internal
(
HParamsI (..),
hParamsIWith,
toAuxiliaryTuningParameters,
fromAuxiliaryTuningParameters,
findReasonableEpsilon,
hTuningFunctionWith,
checkHStructureWith,
generateMomenta,
exponentialKineticEnergy,
Target,
leapfrog,
)
where
import Control.Monad
import Control.Monad.ST
import Data.Foldable
import Data.Maybe
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Masses
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful
-- Tuning parameters that change from one tuning step to the next.
data TParamsVar = TParamsVar
  { -- Smoothed mean of the leapfrog scaling factor; updated as a weighted
    -- geometric mean during leapfrog tuning.
    tpvLeapfrogScalingFactorMean :: LeapfrogScalingFactor,
    -- Running statistic accumulating the difference between the optimal and
    -- the observed acceptance rate.
    tpvHStatistics :: Double,
    -- Counter of the current tuning step; stored as 'Double' because it is
    -- used directly in floating point formulas and is incremented by 1.0 per
    -- tuning step.
    tpvCurrentTuningStep :: Double
  }
  deriving (Show)
-- Initial values of the variable tuning parameters: a scaling factor mean of
-- 1.0, an H statistic of 0.0, and a tuning step counter of 1.0.
tParamsVar :: TParamsVar
tParamsVar =
  TParamsVar
    { tpvLeapfrogScalingFactorMean = 1.0,
      tpvHStatistics = 0.0,
      tpvCurrentTuningStep = 1.0
    }
-- Tuning parameters that stay constant across tuning steps.
data TParamsFixed = TParamsFixed
  { -- Leapfrog scaling factor the tuning was initialized with; the tuning
    -- function reports the tuned scaling factor relative to this value.
    tpfEps0 :: Double,
    -- Value towards which the logarithm of the scaling factor is pulled.
    tpfMu :: Double,
    -- Divisor controlling how strongly the H statistic moves the scaling
    -- factor.
    tpfGa :: Double,
    -- Offset added to the tuning step counter when weighting new
    -- acceptance rate information.
    tpfT0 :: Double,
    -- Exponent governing the decay of the weights used for the scaling
    -- factor mean.
    tpfKa :: Double
  }
  deriving (Show)
-- Fixed tuning parameters for a given initial leapfrog scaling factor.
--
-- The initial scaling factor is stored as is; @mu@ is @log (10 * eps)@; the
-- remaining parameters are constants.
--
-- NOTE(review): Hoffman & Gelman (2014) suggest gamma = 0.05; confirm that
-- 0.15 is the intended value here.
tParamsFixedWith :: LeapfrogScalingFactor -> TParamsFixed
tParamsFixedWith eps =
  TParamsFixed
    { tpfEps0 = eps,
      tpfMu = log (10 * eps),
      tpfGa = 0.15,
      tpfT0 = 10,
      tpfKa = 0.75
    }
-- | Complete, internal parameter set of a Hamiltonian proposal.
--
-- Besides the user-facing parameters (scaling factor, simulation length and
-- masses), the record carries the tuning state and values derived from the
-- masses via 'getMassesI' and 'getMus'.
data HParamsI = HParamsI
  { -- | Current leapfrog scaling factor.
    hpsLeapfrogScalingFactor :: LeapfrogScalingFactor,
    -- | Leapfrog simulation length.
    hpsLeapfrogSimulationLength :: LeapfrogSimulationLength,
    -- | Mass matrix.
    hpsMasses :: Masses,
    -- | Tuning parameters that change during tuning.
    hpsTParamsVar :: TParamsVar,
    -- | Tuning parameters that are fixed at construction.
    hpsTParamsFixed :: TParamsFixed,
    -- | Value of 'getMassesI' for the masses (cached).
    hpsMassesI :: MassesI,
    -- | Value of 'getMus' for the masses (cached); used as the mean when
    -- sampling momenta.
    hpsMu :: Mu
  }
  deriving (Show)
-- Fallback leapfrog scaling factor; used by 'findReasonableEpsilon' when the
-- very first leapfrog step fails.
defaultLeapfrogScalingFactor :: LeapfrogScalingFactor
defaultLeapfrogScalingFactor = 0.1
-- Leapfrog simulation length used when the caller does not provide one.
defaultLeapfrogSimulationLength :: LeapfrogSimulationLength
defaultLeapfrogSimulationLength = 0.5
-- Default mass matrix: the identity matrix of the given dimension, tagged as
-- symmetric without a check (an identity matrix is trivially symmetric).
defaultMassesWith :: Int -> Masses
defaultMassesWith = L.trustSym . L.ident
-- | Assemble and validate the internal Hamiltonian parameters.
--
-- Arguments left as 'Nothing' are filled in as follows:
--
-- - masses: identity matrix of the dimension of the position vector;
--
-- - simulation length: 'defaultLeapfrogSimulationLength';
--
-- - scaling factor: result of 'findReasonableEpsilon', run with a generator
--   seeded with the fixed value 42 so that the result is reproducible.
--
-- Returns 'Left' with a message prefixed by @hParamsIWith: @ if the position
-- vector is empty, if a diagonal entry of the (cleaned) mass matrix is zero or
-- negative, if the mass matrix is not square, or if a provided simulation
-- length or scaling factor is zero or negative. The checks are performed in
-- this order and the first failure is reported.
hParamsIWith ::
  Target ->
  Positions ->
  Maybe LeapfrogScalingFactor ->
  Maybe LeapfrogSimulationLength ->
  Maybe Masses ->
  Either String HParamsI
hParamsIWith htarget p mEps mLa mMs = do
  dim <- checkDimension $ VS.length p
  masses <- maybe (Right $ defaultMassesWith dim) checkMasses mMs
  simLength <-
    maybe
      (Right defaultLeapfrogSimulationLength)
      (positiveOr "Leapfrog simulation length is zero or negative.")
      mLa
  scaling <-
    maybe
      (Right $ guessEpsilon masses)
      (positiveOr "Leapfrog scaling factor is zero or negative.")
      mEps
  pure $
    HParamsI
      { hpsLeapfrogScalingFactor = scaling,
        hpsLeapfrogSimulationLength = simLength,
        hpsMasses = masses,
        hpsTParamsVar = tParamsVar,
        hpsTParamsFixed = tParamsFixedWith scaling,
        hpsMassesI = getMassesI masses,
        hpsMu = getMus masses
      }
  where
    -- Fail with a message tagged with the name of this function.
    failWith :: String -> Either String b
    failWith msg = Left $ "hParamsIWith: " <> msg
    -- The position vector must not be empty.
    checkDimension :: Int -> Either String Int
    checkDimension n
      | n == 0 = failWith "Empty position vector."
      | otherwise = Right n
    -- The checks inspect the cleaned matrix, but the masses are returned as
    -- given.
    checkMasses :: Masses -> Either String Masses
    checkMasses ms = do
      let cleaned = cleanMatrix $ L.unSym ms
      when (any (<= 0) . L.toList $ L.takeDiag cleaned) $
        failWith "Some diagonal masses are zero or negative."
      when (L.rows cleaned /= L.cols cleaned) $
        failWith "Mass matrix is not square."
      Right ms
    -- Accept strictly positive values only.
    positiveOr :: String -> Double -> Either String Double
    positiveOr msg x
      | x <= 0 = failWith msg
      | otherwise = Right x
    -- Heuristic scaling factor obtained with a fixed seed.
    guessEpsilon :: Masses -> LeapfrogScalingFactor
    guessEpsilon ms =
      runST (newSTGenM (mkStdGen 42) >>= findReasonableEpsilon htarget ms p)
-- | Serialize the internal parameters into a flat vector.
--
-- Layout: ten scalars (scaling factor, simulation length, the three variable
-- tuning parameters, the five fixed tuning parameters), followed by the masses
-- as produced by 'massesToVector'. The cached values derived from the masses
-- are not stored.
toAuxiliaryTuningParameters :: HParamsI -> AuxiliaryTuningParameters
toAuxiliaryTuningParameters (HParamsI eps la ms tpv tpf _ _) =
  VU.fromList $ scalars ++ VU.toList (massesToVector ms)
  where
    (TParamsVar epsMean h m) = tpv
    (TParamsFixed eps0 mu ga t0 ka) = tpf
    scalars = [eps, la, epsMean, h, m, eps0, mu, ga, t0, ka]
-- | Restore the internal parameters from a flat vector; inverse of
-- 'toAuxiliaryTuningParameters'.
--
-- For dimension @d@ the vector must contain exactly @d * d + 10@ elements: ten
-- leading scalars followed by the @d * d@ mass entries. Returns 'Left' with a
-- message if the lengths do not match. The values derived from the masses are
-- recomputed with 'getMassesI' and 'getMus'.
fromAuxiliaryTuningParameters :: Dimension -> AuxiliaryTuningParameters -> Either String HParamsI
fromAuxiliaryTuningParameters d xs
  | nMasses + nScalars /= VU.length xs =
      Left "fromAuxiliaryTuningParameters: Dimension mismatch."
  | nMasses /= VU.length massParams =
      Left "fromAuxiliaryTuningParameters: Masses dimension mismatch."
  | otherwise = case VU.toList scalarParams of
      [eps, la, epsMean, h, m, eps0, mu, ga, t0, ka] ->
        Right $
          HParamsI
            eps
            la
            masses
            (TParamsVar epsMean h m)
            (TParamsFixed eps0 mu ga t0 ka)
            (getMassesI masses)
            (getMus masses)
      -- Unreachable if the length checks above passed.
      _ -> Left "fromAuxiliaryTuningParameters: Impossible dimension mismatch."
  where
    -- Number of scalar parameters stored in front of the masses.
    nScalars :: Int
    nScalars = 10
    -- Number of mass entries expected for dimension d.
    nMasses :: Int
    nMasses = d * d
    (scalarParams, massParams) = VU.splitAt nScalars xs
    masses = vectorToMasses d massParams
-- | Heuristically find a reasonable leapfrog scaling factor.
--
-- Momenta are sampled once. Starting from a scaling factor of 1.0, a single
-- leapfrog step is performed and the ratio of the joint density (target times
-- exponential kinetic energy) after and before the step is computed. If the
-- ratio exceeds 0.5 the scaling factor is repeatedly doubled, otherwise it is
-- repeatedly halved, until the ratio crosses the threshold or a leapfrog step
-- fails.
--
-- If the very first leapfrog step fails, 'defaultLeapfrogScalingFactor' is
-- returned.
findReasonableEpsilon ::
  StatefulGen g m =>
  Target ->
  Masses ->
  Positions ->
  g ->
  m LeapfrogScalingFactor
findReasonableEpsilon t ms q g = do
  p <- generateMomenta (getMus ms) ms g
  pure $ case leapfrog t msI 1 epsInit q p of
    Nothing -> defaultLeapfrogScalingFactor
    Just (_, pEnd, prStart, prEnd) ->
      let -- Joint density ratio relative to the starting point, given the
          -- target value and the momenta at the end of a step.
          ratioWith prQ' p' =
            exp $
              ln $
                prQ' * exponentialKineticEnergy msI p'
                  / (prStart * exponentialKineticEnergy msI p)
          rInit :: Double
          rInit = ratioWith prEnd pEnd
          -- Direction of the search: 1 doubles, -1 halves the scaling factor.
          a :: Double
          a = if rInit > 0.5 then 1 else (-1)
          search e r
            | r ** a > 2 ** negate a = case leapfrog t msI 1 e q p of
                Nothing -> e
                Just (_, p', _, prQ') -> search ((2 ** a) * e) (ratioWith prQ' p')
            | otherwise = e
       in search epsInit rInit
  where
    -- Scaling factor the search starts with.
    epsInit = 1.0
    msI = getMassesI ms
-- | Tuning function of a Hamiltonian proposal.
--
-- Returns 'Nothing' if neither the leapfrog parameters nor the masses are to
-- be tuned.
--
-- Otherwise, the returned function decodes the auxiliary tuning parameters
-- (calling 'error' on a dimension mismatch), optionally re-estimates the
-- masses from the trace, optionally updates the leapfrog scaling factor from
-- the acceptance rate, and encodes the result again. The returned tuning
-- parameter is the new scaling factor divided by the initial one.
--
-- The function calls 'error' for fast-proposal-only intermediate and normal
-- tuning steps, and when a required trace or acceptance rate is missing.
hTuningFunctionWith ::
  Dimension ->
  (a -> Positions) ->
  HTuningConf ->
  Maybe (TuningFunction a)
hTuningFunctionWith _ _ (HTuningConf HNoTuneLeapfrog HNoTuneMasses) = Nothing
hTuningFunctionWith n toVec (HTuningConf lc mc) = Just tune
  where
    err :: String -> b
    err msg = error $ "hTuningFunctionWith: " <> msg
    tune tt pdim mar mxs (_, !ts) = case tt of
      IntermediateTuningFastProposalsOnly ->
        err "fast intermediate tuning step but slow proposal"
      NormalTuningFastProposalsOnly ->
        err "fast normal tuning step but slow proposal"
      _ ->
        let -- A dimension mismatch is treated as fatal.
            (HParamsI eps la ms tpv tpf msI mus) =
              either error id $ fromAuxiliaryTuningParameters n ts
            (TParamsVar epsMean h m) = tpv
            (TParamsFixed eps0 mu ga t0 ka) = tpf
            smoothing = SmoothingParameter $ round m
            -- Only forced when the masses are actually tuned.
            trace = fromMaybe (err "empty trace") mxs
            -- Masses are left alone during intermediate tuning steps.
            (msNew, msINew) = case (tt, mc) of
              (IntermediateTuningAllProposals, _) -> (ms, msI)
              (_, HNoTuneMasses) -> (ms, msI)
              (_, HTuneDiagonalMassesOnly) ->
                tuneDiagonalMassesOnly smoothing toVec trace (ms, msI)
              (_, HTuneAllMasses) ->
                tuneAllMasses smoothing toVec trace (ms, msI)
            -- The leapfrog parameters are left alone during the last tuning
            -- step of fast proposals.
            (epsNew, epsMeanNew, hNew) = case (tt, lc) of
              (LastTuningFastProposalsOnly, _) -> (eps, epsMean, h)
              (_, HNoTuneLeapfrog) -> (eps, epsMean, h)
              (_, HTuneLeapfrog) ->
                let rate = fromMaybe (err "no acceptance rate") mar
                    target = getOptimalRate pdim
                    weight = recip $ m + t0
                    hNext = (1.0 - weight) * h + weight * (target - rate)
                    epsNext = exp $ mu - (sqrt m / ga) * hNext
                    decay = m ** negate ka
                    epsMeanNext = (epsNext ** decay) * (epsMean ** (1 - decay))
                    -- The very last tuning step settles on the smoothed mean.
                    epsFinal =
                      if tt == LastTuningAllProposals
                        then epsMeanNext
                        else epsNext
                 in (epsFinal, epsMeanNext, hNext)
            tpvNew = TParamsVar epsMeanNew hNew (m + 1.0)
            hParamsNew = HParamsI epsNew la msNew tpvNew tpf msINew mus
         in (epsNew / eps0, toAuxiliaryTuningParameters hParamsNew)
-- | Sanity checks for the structure of the state.
--
-- Verifies for the sample state that converting it to a vector and back
-- yields the same elements, and that the length of the vector matches the
-- number of rows of the mass matrix.
--
-- Returns 'Nothing' if all checks pass, and 'Just' an error message prefixed
-- with @checkHStructureWith: @ otherwise.
checkHStructureWith :: Foldable s => Masses -> HStructure s -> Maybe String
checkHStructureWith ms (HStructure x toVec fromVec)
  | toList roundTrip /= toList x =
      complain "'fromVectorWith x (toVector x) /= x' for sample state."
  | L.size sampleVec /= L.rows (L.unSym ms) =
      complain "Mass matrix and 'toVector x' have different sizes for sample state."
  | otherwise = Nothing
  where
    complain :: String -> Maybe String
    complain msg = Just $ "checkHStructureWith: " <> msg
    sampleVec = toVec x
    roundTrip = fromVec x sampleVec
-- | Sample momenta from a multivariate normal distribution with the given mean
-- and the masses as covariance matrix.
--
-- A seed is drawn from the given generator and handed to 'L.gaussianSample',
-- which draws a single sample; the one-row result is flattened into a vector.
generateMomenta ::
  StatefulGen g m =>
  Mu ->
  Masses ->
  g ->
  m Momenta
generateMomenta mu masses gen = L.flatten . sampleWith <$> uniformM gen
  where
    sampleWith seed = L.gaussianSample seed 1 mu masses
-- | Exponential of the negative kinetic energy, in the log domain.
--
-- The stored logarithm is @-0.5 * (p . (msI * p))@, where @msI@ is the first
-- argument and @p@ are the momenta.
exponentialKineticEnergy ::
  MassesI ->
  Momenta ->
  Log Double
exponentialKineticEnergy msI xs = Exp (negate 0.5 * quadraticForm)
  where
    quadraticForm = xs L.<.> (msI L.!#> xs)
-- | Target function of the leapfrog integrator.
--
-- Given positions, returns a log-domain value, which 'leapfrog' requires to be
-- strictly positive, together with a vector that is scaled and added to the
-- momenta (presumably the gradient of the log target -- confirm against the
-- callers).
type Target = Positions -> (Log Double, Positions)
-- | Leapfrog integrator.
--
-- Performs @l@ leapfrog steps with scaling factor @eps@, starting at positions
-- @q@ with momenta @p@:
--
-- 1. half step of the momenta;
--
-- 2. @l - 1@ alternating full steps of positions and momenta;
--
-- 3. final full step of the positions;
--
-- 4. final half step of the momenta.
--
-- Returns the new positions, the new momenta, the target value at the initial
-- positions, and the target value at the final positions.
--
-- Returns 'Nothing' as soon as the target evaluates to a value that is not
-- strictly positive at any visited position.
leapfrog ::
  Target ->
  MassesI ->
  LeapfrogTrajectoryLength ->
  LeapfrogScalingFactor ->
  Positions ->
  Momenta ->
  Maybe (Positions, Momenta, Log Double, Log Double)
leapfrog tF msI l eps q p = do
  -- First half step of the momenta.
  (x, pHalf) <- stepMomenta (0.5 * eps) q p
  -- L-1 full steps; yields the positions q_{L-1} and the momenta p_{L-1/2}.
  (qLM1, pLM1Half) <- go (l - 1) (q, pHalf)
  -- Last full step of the positions.
  let qL = leapfrogStepPositions msI eps qLM1 pLM1Half
  -- Last half step of the momenta.
  (x', pL) <- stepMomenta (0.5 * eps) qL pLM1Half
  return (qL, pL, x, x')
  where
    -- Momentum step that fails if the target value is not strictly positive.
    stepMomenta :: LeapfrogScalingFactor -> Positions -> Momenta -> Maybe (Log Double, Momenta)
    stepMomenta e qs ps =
      let r@(y, _) = leapfrogStepMomenta e tF qs ps
       in if y > 0.0 then Just r else Nothing
    -- Alternating full steps of positions and momenta.
    go :: LeapfrogTrajectoryLength -> (Positions, Momenta) -> Maybe (Positions, Momenta)
    go n (qs, ps)
      | n <= 0 = Just (qs, ps)
      | otherwise = do
          let qs' = leapfrogStepPositions msI eps qs ps
          -- The momentum step has to continue from the momenta of the
          -- previous step ('ps'). Starting from the initial momenta would
          -- discard everything accumulated so far.
          (_, ps') <- stepMomenta eps qs' ps
          go (n - 1) (qs', ps')
-- Single momentum step of the leapfrog integrator.
--
-- Evaluates the target at the given positions and adds the second component
-- of the result, scaled by the step size, to the momenta. The target value is
-- returned alongside so that callers can check it.
leapfrogStepMomenta ::
  LeapfrogScalingFactor ->
  Target ->
  Positions ->
  Momenta ->
  (Log Double, Momenta)
leapfrogStepMomenta eps tf q p =
  let (targetValue, direction) = tf q
   in (targetValue, p + L.scale eps direction)
-- Single position step of the leapfrog integrator.
--
-- The momenta are scaled by the step size, multiplied from the left with the
-- first argument, and added to the positions.
leapfrogStepPositions ::
  MassesI ->
  LeapfrogScalingFactor ->
  Positions ->
  Momenta ->
  Positions
leapfrogStepPositions msI eps q p = q + displacement
  where
    displacement = msI L.!#> L.scale eps p