module Mcmc.Proposal.Hamiltonian.Nuts
( NParams (..),
defaultNParams,
nuts,
)
where
import Data.Bifunctor
import Mcmc.Acceptance
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Internal
import Mcmc.Proposal.Hamiltonian.Masses
import Numeric.AD.Double
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful
-- | Slice variable @u@ of the NUTS algorithm, stored in the log domain.
type SliceVariable = Log Double

-- | Direction in which the trajectory is extended. 'True' makes the leapfrog
-- integrator use the scaling factor as is, 'False' uses its negation.
type Direction = Bool

-- | Depth of the (sub)tree being built. A depth of zero or less corresponds
-- to a single leapfrog step.
type DoublingStep = Int

-- | Number of states of a (sub)tree that lie inside the slice.
type NStepsOk = Int

-- | Sum of the per-step acceptance statistics of a (sub)tree.
type Alpha = Log Double

-- | Number of acceptance statistics summed up in 'Alpha'.
type NAlpha = Int

-- | Result of building a (sub)tree.
type BuildTreeReturnType =
  ( -- Position and momentum at the backward edge of the tree.
    Positions,
    Momenta,
    -- Position and momentum at the forward edge of the tree.
    Positions,
    Momenta,
    -- Candidate position selected from the tree.
    Positions,
    NStepsOk,
    Alpha,
    NAlpha
  )
-- | Bound on the simulation error of a single leapfrog step, stored in the
-- log domain (the value represented is @exp 1000@).
--
-- 'buildTreeWith' treats a step as failed unless the slice variable is
-- strictly smaller than @deltaMax@ times the energy term of the new state.
deltaMax :: Log Double
deltaMax = Exp logDeltaMax
  where
    logDeltaMax = 1000
-- | Build a balanced binary tree of leapfrog states (procedure @BuildTree@ of
-- Hoffman and Gelman, 2014).
--
-- A tree of depth @j@ consists of @2^j@ leapfrog steps taken in direction
-- @v@, starting from @(q, p)@. The result is 'Nothing' when
--
-- * a leapfrog step fails,
-- * the simulation error of a step is too large (see 'deltaMax'), or
-- * a subtree makes a U-turn.
--
-- Otherwise the result contains:
--
-- * the backward and forward edges of the tree;
-- * a candidate position, chosen among the subtrees' candidates in proportion
--   to their numbers of in-slice states;
-- * the number of in-slice states;
-- * the summed acceptance statistics, together with their count.
buildTreeWith ::
  -- | Target value times 'exponentialKineticEnergy' at the start of the whole
  -- trajectory. The acceptance statistic of each step is the ratio of the
  -- step's value to this one, capped at 1.
  Log Double ->
  -- | Masses passed to 'leapfrog' and 'exponentialKineticEnergy'.
  MassesI ->
  Target ->
  IOGenM StdGen ->
  -- | Starting position.
  Positions ->
  -- | Starting momentum.
  Momenta ->
  SliceVariable ->
  Direction ->
  DoublingStep ->
  LeapfrogScalingFactor ->
  IO (Maybe BuildTreeReturnType)
buildTreeWith expETot0 msI tfun g q p u v j e
  | j <= 0 =
      -- Base case: a single leapfrog step in direction v.
      let e' = if v then e else negate e
       in pure $ case leapfrog tfun msI 1 e' q p of
            Nothing -> Nothing
            Just (q', p', _, expEPot') ->
              let expETot' = expEPot' * exponentialKineticEnergy msI p'
                  -- The new state counts if it lies inside the slice.
                  n' = if u <= expETot' then 1 else 0
                  -- Acceptance statistic relative to the trajectory start.
                  alpha = min 1.0 (expETot' / expETot0)
                  errorIsSmall = u < deltaMax * expETot'
               in if errorIsSmall
                    then Just (q', p', q', p', q', n', alpha, 1)
                    else Nothing
  | otherwise = do
      mFirst <- buildTree q p u v (j - 1) e
      case mFirst of
        Nothing -> pure Nothing
        Just (qm, pm, qp, pp, q', n', a', na') -> do
          -- Grow the second subtree from the edge of the first subtree that
          -- faces direction v; the opposite edge is kept.
          mSecond <-
            if v
              then fmap (withBackwardEdge qm pm) <$> buildTree qp pp u v (j - 1) e
              else fmap (withForwardEdge qp pp) <$> buildTree qm pm u v (j - 1) e
          case mSecond of
            Nothing -> pure Nothing
            Just (qm'', pm'', qp'', pp'', q'', n'', a'', na'') -> do
              b <- uniformRM (0, 1) g :: IO Double
              let nTot = n' + n''
                  -- Take the candidate of the second subtree with probability
                  -- n'' / (n' + n''). With no in-slice states at all, keep
                  -- the first candidate (explicit instead of relying on a NaN
                  -- comparison being False).
                  takeSecond = nTot > 0 && b < fromIntegral n'' / fromIntegral nTot
                  qSel = if takeSecond then q'' else q'
              pure $
                if isUTurn qm'' pm'' qp'' pp''
                  then Nothing
                  else Just (qm'', pm'', qp'', pp'', qSel, nTot, a' + a'', na' + na'')
  where
    buildTree = buildTreeWith expETot0 msI tfun g
    -- Combine the kept backward edge with the result of a forward subtree.
    withBackwardEdge qm pm (_, _, qp, pp, qc, nc, ac, nac) = (qm, pm, qp, pp, qc, nc, ac, nac)
    -- Combine the kept forward edge with the result of a backward subtree.
    withForwardEdge qp pp (qm, pm, _, _, qc, nc, ac, nac) = (qm, pm, qp, pp, qc, nc, ac, nac)
    -- U-turn criterion: (q+ - q-) . p- < 0 or (q+ - q-) . p+ < 0. This must be
    -- an inner product; element-wise multiplication followed by the
    -- lexicographic vector ordering only inspects the first component.
    isUTurn :: Positions -> Momenta -> Positions -> Momenta -> Bool
    isUTurn qm pm qp pp =
      let dq = qp - qm
       in L.dot dq pm < 0 || L.dot dq pp < 0
-- | Parameters of the NUTS proposal.
--
-- Both fields are optional. 'nuts' forwards them to 'hParamsIWith' when the
-- proposal is constructed.
data NParams = NParams
  { -- | Scaling factor of the leapfrog integrator. 'Nothing' is passed
    -- through to 'hParamsIWith' (presumably selecting a default).
    nLeapfrogScalingFactor :: Maybe LeapfrogScalingFactor,
    -- | Masses. 'Nothing' is passed through to 'hParamsIWith' (presumably
    -- selecting a default).
    nMasses :: Maybe Masses
  }
  deriving (Show)
-- | Default NUTS parameters: neither a leapfrog scaling factor nor masses
-- are specified.
defaultNParams :: NParams
defaultNParams =
  NParams
    { nLeapfrogScalingFactor = Nothing,
      nMasses = Nothing
    }
-- | Rebuild the NUTS proposal function from auxiliary tuning parameters.
--
-- The scalar tuning parameter is ignored. The internal Hamiltonian parameters
-- are read with 'fromAuxiliaryTuningParameters', and its error message is
-- passed through unchanged.
nutsPFunctionWithTuningParameters ::
  Traversable s =>
  Dimension ->
  HStructure s ->
  (s Double -> Target) ->
  TuningParameter ->
  AuxiliaryTuningParameters ->
  Either String (PFunction (s Double))
nutsPFunctionWithTuningParameters d hstruct targetWith _ ts =
  fmap mkPFunction (fromAuxiliaryTuningParameters d ts)
  where
    mkPFunction hParamsI = nutsPFunction hParamsI hstruct targetWith
-- | Status of the position selected while building the NUTS trajectory.
-- 'nutsPFunction' turns it into the final proposal result.
data IsNew
  = -- | Nothing was recorded; the proposal is rejected with acceptance counts
    -- 0 (accepted) and 100 (rejected).
    Old
  | -- | The proposal is rejected; carries the acceptance counts to report.
    OldWith {_acceptanceCountsOld :: AcceptanceCounts}
  | -- | The selected position is proposed with 'ForceAccept'; carries the
    -- acceptance counts to report.
    NewWith {_acceptanceCountsNew :: AcceptanceCounts}
-- | Proposal function of the NUTS proposal, given internal Hamiltonian
-- parameters (compare Algorithm 3 of Hoffman and Gelman, 2014).
--
-- One invocation proceeds as follows:
--
-- 1. It draws momenta with 'generateMomenta' and a slice variable
--    @u = expETot * U(0, 1)@. Here @expETot@ is the target value times
--    'exponentialKineticEnergy' at the current state.
--
-- 2. It repeatedly doubles the trajectory in a uniformly chosen direction
--    with 'buildTreeWith'. The candidate of each new subtree is taken with
--    probability @min(1, n_new / n)@, where @n@ counts the in-slice states
--    collected so far.
--
-- 3. It stops when 'buildTreeWith' returns 'Nothing' or when the whole
--    trajectory makes a U-turn.
--
-- If a candidate was taken at any doubling, the selected position is proposed
-- with 'ForceAccept'; otherwise the result is 'ForceReject'. Acceptance counts
-- are derived from the acceptance statistic of the last successful doubling.
--
-- NOTE: no maximum tree depth is enforced; doubling continues until a U-turn
-- or a stop is detected.
nutsPFunction ::
  HParamsI ->
  HStructure s ->
  (s Double -> Target) ->
  PFunction (s Double)
nutsPFunction hparamsi hstruct targetWith x g = do
  p <- generateMomenta mus ms g
  uZeroOne <- uniformRM (0, 1) g :: IO Double
  let q = toVec x
      expEPot = fst $ target q
      expEKin = exponentialKineticEnergy msI p
      expETot = expEPot * expEKin
      -- Slice variable, uniform on (0, expETot), stored in the log domain.
      u = expETot * Exp (log uZeroOne)
      -- One doubling per call. The arguments are:
      --   * backward edge (qm, pm) and forward edge (qp, pp);
      --   * doubling step j;
      --   * currently selected position y;
      --   * number n of in-slice states so far;
      --   * status of y.
      go ::
        Positions ->
        Momenta ->
        Positions ->
        Momenta ->
        Int ->
        Positions ->
        Int ->
        IsNew ->
        IO (Positions, IsNew)
      go qm pm qp pp j y n isNew = do
        v <- uniformM g :: IO Bool
        mr' <-
          if v
            then do
              mr <- buildTreeWith expETot msI target g qp pp u v j e
              pure $ case mr of
                Nothing -> Nothing
                Just (_, _, qp', pp', y', n', a, na) ->
                  Just (qm, pm, qp', pp', y', n', a, na)
            else do
              mr <- buildTreeWith expETot msI target g qm pm u v j e
              pure $ case mr of
                Nothing -> Nothing
                Just (qm', pm', _, _, y', n', a, na) ->
                  Just (qm', pm', qp, pp, y', n', a, na)
        case mr' of
          -- The new subtree is invalid: stop and keep the current selection.
          Nothing -> pure (y, isNew)
          Just (qm'', pm'', qp'', pp'', y'', n'', a, na) -> do
            let r = fromIntegral n'' / fromIntegral n :: Double
                ar = exp (ln a) / fromIntegral na :: Double
                ac
                  | ar >= 0 = let cs = getCounts ar in AcceptanceCounts cs (100 - cs)
                  | otherwise = error "nutsPFunction: Acceptance rate negative."
            isAccept <-
              if r > 1.0
                then pure True
                else do
                  b <- uniformRM (0, 1) g
                  pure $ b < r
            -- If the candidate of this subtree is not taken, keep the current
            -- selection together with its status. A position selected in an
            -- earlier doubling must still be proposed; marking it 'OldWith'
            -- would make the whole proposal a 'ForceReject'.
            let (y''', isNew') =
                  if isAccept
                    then (y'', NewWith ac)
                    else (y, refreshCounts ac isNew)
            if isUTurn qm'' pm'' qp'' pp''
              then pure (y''', isNew')
              else go qm'' pm'' qp'' pp'' (j + 1) y''' (n + n'') isNew'
  (x', isNew) <- go q p q p 0 q 1 Old
  pure $ case isNew of
    Old -> (ForceReject, Just $ AcceptanceCounts 0 100)
    OldWith ac -> (ForceReject, Just ac)
    NewWith ac -> (ForceAccept $ fromVec x', Just ac)
  where
    (HParamsI e _ ms _ _ msI mus) = hparamsi
    (HStructure _ toVec fromVecWith) = hstruct
    fromVec = fromVecWith x
    target = targetWith x
    -- Convert an acceptance rate to a number of accepted proposals out of 100.
    getCounts :: Double -> Int
    getCounts rate = max 0 $ min 100 $ round $ rate * 100
    -- Replace the acceptance counts but keep whether the selection is new.
    refreshCounts :: AcceptanceCounts -> IsNew -> IsNew
    refreshCounts ac (NewWith _) = NewWith ac
    refreshCounts ac _ = OldWith ac
    -- U-turn criterion: (q+ - q-) . p- < 0 or (q+ - q-) . p+ < 0, using the
    -- inner product (not element-wise multiplication with lexicographic
    -- vector ordering).
    isUTurn :: Positions -> Momenta -> Positions -> Momenta -> Bool
    isUTurn qm pm qp pp =
      let dq = qp - qm
       in L.dot dq pm < 0 || L.dot dq pp < 0
-- | No U-turn sampler (NUTS) proposal.
--
-- The target combines the likelihood with the prior and the Jacobian,
-- whichever of the latter two are provided. It is differentiated with
-- reverse-mode automatic differentiation (@grad'@).
--
-- Internal Hamiltonian parameters are derived from the sample stored in the
-- 'HStructure' via 'hParamsIWith'. Two kinds of failure are reported with
-- 'error':
--
-- * a failure of 'hParamsIWith';
-- * a mismatch between the masses and the structure, as reported by
--   'checkHStructureWith'.
--
-- The proposal is tunable whenever 'hTuningFunctionWith' yields a tuning
-- function.
nuts ::
  Traversable s =>
  NParams ->
  HTuningConf ->
  HStructure s ->
  HTarget s ->
  PName ->
  PWeight ->
  Proposal (s Double)
nuts nparams htconf hstruct htarget n w =
  maybe (nutsWith tuner) error $ checkHStructureWith (hpsMasses hParamsI) hstruct
  where
    (HStructure sample toVec fromVec) = hstruct
    (HTarget mPrF lhF mJcF) = htarget
    (NParams mEps mMs) = nparams
    dim = L.size $ toVec sample
    -- Unnormalized target: likelihood times prior and Jacobian, if given.
    tF y = case (mPrF, mJcF) of
      (Nothing, Nothing) -> lhF y
      (Just prF, Nothing) -> prF y * lhF y
      (Nothing, Just jcF) -> lhF y * jcF y
      (Just prF, Just jcF) -> prF y * lhF y * jcF y
    -- Log target value and its gradient.
    tFnG = grad' (ln . tF)
    targetWith x = bimap Exp toVec . tFnG . fromVec x
    hParamsI =
      either error id $
        hParamsIWith (targetWith sample) (toVec sample) mEps Nothing mMs
    nutsWith =
      Proposal n (PDescription "No U-turn sampler (NUTS)") PSlow (PSpecial dim 0.6) w $
        nutsPFunction hParamsI hstruct targetWith
    tuner =
      (\tfun -> Tuner 1.0 (toAuxiliaryTuningParameters hParamsI) True tfun pfun)
        <$> hTuningFunctionWith dim toVec htconf
      where
        pfun = nutsPFunctionWithTuningParameters dim hstruct targetWith