-- |
-- Module      :  Mcmc.Proposal.Hamiltonian.Nuts
-- Description :  No-U-Turn sampler (NUTS)
-- Copyright   :  2022 Dominik Schrempf
-- License     :  GPL-3.0-or-later
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
-- Creation date: Fri May 27 09:58:23 2022.
-- For a general introduction to Hamiltonian proposals, see
-- "Mcmc.Proposal.Hamiltonian.Hamiltonian".
-- This module implements the No-U-Turn Sampler (NUTS), as described in [4].
-- Work in progress.
-- References:
-- - [1] Chapter 5 of Handbook of Markov Chain Monte Carlo: Neal, R. M., MCMC Using
--   Hamiltonian Dynamics, In S. Brooks, A. Gelman, G. Jones, & X. Meng (Eds.),
--   Handbook of Markov Chain Monte Carlo (2011), CRC press.
-- - [2] Gelman, A., Carlin, J. B., Stern, H. S., & Rubin, D. B., Bayesian data
--   analysis (2014), CRC Press.
-- - [3] Review by Betancourt and notes: Betancourt, M., A conceptual
--   introduction to Hamiltonian Monte Carlo, arXiv:1701.02434 (2017).
-- - [4] Matthew D. Hoffman, Andrew Gelman (2014) The No-U-Turn Sampler:
--   Adaptively Setting Path Lengths in Hamiltonian Monte Carlo, Journal of
--   Machine Learning Research.
module Mcmc.Proposal.Hamiltonian.Nuts
  ( NParams (..),
    defaultNParams,
    nuts,
  )
where

import Data.Bifunctor
import Mcmc.Acceptance
import Mcmc.Proposal
import Mcmc.Proposal.Hamiltonian.Common
import Mcmc.Proposal.Hamiltonian.Internal
import Mcmc.Proposal.Hamiltonian.Masses
import Numeric.AD.Double
import qualified Numeric.LinearAlgebra as L
import Numeric.Log
import System.Random.Stateful

-- Internal; Slice variable 'u', stored in the log domain. A state is inside
-- the slice when 'u' does not exceed the exponent of its total energy.
type SliceVariable = Log Double

-- Internal; Direction of leapfrog integration; 'True' means forwards in time,
-- 'False' means backwards.
type Direction = Bool

-- Internal; Doubling step number 'j' (the height of the tree being built).
type DoublingStep = Int

-- Internal; Number of leapfrog steps within the slice 'n'.
type NStepsOk = Int

-- Internal; Sum of the acceptance probabilities \(\alpha\) of the leapfrog
-- steps of a tree. Divided by 'NAlpha', it estimates the acceptance rate.
type Alpha = Log Double

-- Internal; Number of leapfrog steps whose acceptance probabilities have been
-- summed up in 'Alpha' (each base case contributes one step).
type NAlpha = Int

-- Internal; Result of building a tree, see Algorithm 3 in [4]. The fields are:
-- backwards-most positions and momenta, forwards-most positions and momenta,
-- the candidate positions, the number of steps inside the slice, the summed
-- acceptance probabilities, and the number of summed acceptance probabilities.
type BuildTreeReturnType = (Positions, Momenta, Positions, Momenta, Positions, NStepsOk, Alpha, NAlpha)

-- Constant determining the largest allowed leapfrog integration error. See the
-- discussion around Equation (3) in [4], which suggests a value of 1000. The
-- constant is stored in the log domain, that is, @Exp 1000@ represents
-- \(e^{1000}\); a leapfrog step is considered erroneous when the slice
-- variable exceeds the exponent of the total energy multiplied by this value.
deltaMax :: Log Double
deltaMax = Exp 1000

-- Second function in Algorithm 3 and Algorithm 6, respectively, in [4].
--
-- Recursively builds a tree of @2^j@ leapfrog steps starting at positions @q@
-- and momenta @p@, integrating forwards if the direction @v@ is 'True', and
-- backwards otherwise.
--
-- Returns 'Nothing' if 'leapfrog' returns 'Nothing', if the integration error
-- of a step is too large (see 'deltaMax'), or if a sub-tree makes a U-turn.
-- Otherwise, returns a 'BuildTreeReturnType'.
buildTreeWith ::
  -- The exponent of the total energy of the starting state is used to
  -- calculate the expected acceptance rate 'Alpha'.
  Log Double ->
  MassesI ->
  Target ->
  IOGenM StdGen ->
  Positions ->
  Momenta ->
  SliceVariable ->
  Direction ->
  DoublingStep ->
  LeapfrogScalingFactor ->
  IO (Maybe BuildTreeReturnType)
buildTreeWith expETot0 msI tfun g q p u v j e
  -- Base case: a single leapfrog step.
  | j <= 0 =
      -- Move backwards or forwards?
      let e' = if v then e else negate e
       in case leapfrog tfun msI 1 e' q p of
            Nothing -> pure Nothing
            Just (q', p', _, expEPot') ->
              let expEKin' = exponentialKineticEnergy msI p'
                  expETot' = expEPot' * expEKin'
                  -- Is the new state inside the slice?
                  n' = if u <= expETot' then 1 else 0
                  -- Is the integration error acceptable? See Equation (3).
                  errorIsSmall = u < deltaMax * expETot'
                  -- Acceptance probability of the new state.
                  alpha = min 1.0 (expETot' / expETot0)
               in pure $
                    if errorIsSmall
                      then Just (q', p', q', p', q', n', alpha, 1)
                      else Nothing
  -- Recursive case: build two sub-trees of height @j - 1@ and merge them. This
  -- is complicated because the algorithm is written for an imperative
  -- language, and because we have two stacked monads.
  | otherwise = do
      mr <- buildTree q p u v (j - 1) e
      case mr of
        Nothing -> pure Nothing
        -- Here, the suffixes 'm' and 'p' stand for minus and plus, respectively.
        Just (qm, pm, qp, pp, q', n', a', na') -> do
          -- Extend the trajectory at the end pointing in direction v.
          mr' <-
            if v
              then do
                -- Forwards.
                mr'' <- buildTree qp pp u v (j - 1) e
                pure $ case mr'' of
                  Nothing -> Nothing
                  Just (_, _, qp', pp', q'', n'', a'', na'') ->
                    Just (qm, pm, qp', pp', q'', n'', a'', na'')
              else do
                -- Backwards.
                mr'' <- buildTree qm pm u v (j - 1) e
                pure $ case mr'' of
                  Nothing -> Nothing
                  Just (qm', pm', _, _, q'', n'', a'', na'') ->
                    Just (qm', pm', qp, pp, q'', n'', a'', na'')
          case mr' of
            Nothing -> pure Nothing
            Just (qm'', pm'', qp'', pp'', q''', n''', a''', na''') -> do
              b <- uniformRM (0, 1) g :: IO Double
              let -- Take the candidate of the second sub-tree with
                  -- probability n''' / (n' + n''').
                  q'''' =
                    if b < fromIntegral n''' / fromIntegral (n' + n''')
                      then q'''
                      else q'
                  -- Important: Check for a U-turn of the merged trajectory
                  -- using the dot products of Equation (4) in [4]. An
                  -- element-wise product compared against the vector @0@
                  -- would only test the first coordinate (hmatrix vectors
                  -- are ordered lexicographically).
                  dq = qp'' - qm''
                  isUTurn = (dq L.<.> pm'') < 0 || (dq L.<.> pp'') < 0
              pure $
                if isUTurn
                  then Nothing
                  else Just (qm'', pm'', qp'', pp'', q'''', n' + n''', a' + a''', na' + na''')
  where
    buildTree = buildTreeWith expETot0 msI tfun g

-- | Parameters of the NUTS proposal.
--
-- The tuning configuration is not part of these parameters; it is passed to
-- 'nuts' separately as an 'HTuningConf'. See also 'defaultNParams'.
data NParams = NParams
  { -- | Leapfrog scaling factor; 'Nothing' lets the proposal choose a value.
    nLeapfrogScalingFactor :: Maybe LeapfrogScalingFactor,
    -- | Masses; 'Nothing' lets the proposal choose the masses.
    nMasses :: Maybe Masses
  }
  deriving (Show)

-- | Default parameters.
--
-- - Estimate a reasonable leapfrog scaling factor using Algorithm 4 [4]. If
--   that fails, use 0.1.
--
-- - The mass matrix is set to the identity matrix.
defaultNParams :: NParams
defaultNParams = NParams Nothing Nothing

-- Internal; Rebuild the NUTS proposal function from auxiliary tuning
-- parameters. The scalar tuning parameter is ignored; all internal parameters
-- are read from the auxiliary tuning parameters. Fails with the error message
-- of 'fromAuxiliaryTuningParameters'.
nutsPFunctionWithTuningParameters ::
  Traversable s =>
  Dimension ->
  HStructure s ->
  (s Double -> Target) ->
  TuningParameter ->
  AuxiliaryTuningParameters ->
  Either String (PFunction (s Double))
nutsPFunctionWithTuningParameters d hstruct targetWith _ ts = do
  hParamsI <- fromAuxiliaryTuningParameters d ts
  pure $ nutsPFunction hParamsI hstruct targetWith

-- Internal; Outcome of a NUTS step with respect to the current state. The
-- constructors carrying 'AcceptanceCounts' store the counts reported to the
-- tuner.
data IsNew
  = -- No valid sub-tree has been built.
    Old
  | -- The current state is kept.
    OldWith {_acceptanceCountsOld :: AcceptanceCounts}
  | -- A new state has been accepted.
    NewWith {_acceptanceCountsNew :: AcceptanceCounts}

-- First function in Algorithm 3 in [4].
--
-- Draws momenta and a slice variable, then repeatedly doubles the trajectory
-- in a random direction (see 'buildTreeWith') until a sub-tree is invalid or
-- the whole trajectory makes a U-turn. The last accepted candidate is
-- proposed with 'ForceAccept'; if no candidate was accepted, the proposal is
-- rejected with 'ForceReject'.
--
-- NOTE (review): there is no maximum tree depth; the doubling loop only stops
-- on an invalid sub-tree or a U-turn.
nutsPFunction ::
  HParamsI ->
  HStructure s ->
  (s Double -> Target) ->
  PFunction (s Double)
nutsPFunction hparamsi hstruct targetWith x g = do
  p <- generateMomenta mus ms g
  uZeroOne <- uniformRM (0, 1) g :: IO Double
  -- NOTE (runtime): Here we need the target function value from the previous
  -- step. For now, I just recalculate the value, but this is, of course, slow!
  -- However, if other proposals have changed the state in between, we do need
  -- to recalculate this value.
  let q = toVec x
      expEPot = fst $ target q
      expEKin = exponentialKineticEnergy msI p
      expETot = expEPot * expEKin
      -- Slice variable: a uniform fraction of the exponent of the total
      -- energy, computed in the log domain.
      u = expETot * Exp (log uZeroOne)
  let -- One doubling step. This is complicated because the algorithm is written
      -- for an imperative language, and because we have two stacked monads.
      -- Here, the suffixes 'm' and 'p' stand for minus and plus,
      -- respectively; 'y' is the current candidate, and 'n' is the number of
      -- steps inside the slice so far.
      go qm pm qp pp j y n isNew = do
        v <- uniformM g :: IO Direction
        mr' <-
          if v
            then do
              -- Forwards.
              mr <- buildTreeWith expETot msI target g qp pp u v j e
              pure $ case mr of
                Nothing -> Nothing
                Just (_, _, qp', pp', y', n', a, na) ->
                  Just (qm, pm, qp', pp', y', n', a, na)
            else do
              -- Backwards.
              mr <- buildTreeWith expETot msI target g qm pm u v j e
              pure $ case mr of
                Nothing -> Nothing
                Just (qm', pm', _, _, y', n', a, na) ->
                  Just (qm', pm', qp, pp, y', n', a, na)
        case mr' of
          -- The new sub-tree is invalid; stop and keep the current candidate.
          Nothing -> pure (y, isNew)
          Just (qm'', pm'', qp'', pp'', y'', n'', a, na) -> do
            let r = fromIntegral n'' / fromIntegral n :: Double
                -- Estimated acceptance rate of the new sub-tree.
                ar = exp (ln a) / fromIntegral na :: Double
                getCounts s = max 0 $ min 100 $ round $ s * 100
                ac =
                  if ar >= 0
                    then let cs = getCounts ar in AcceptanceCounts cs (100 - cs)
                    else error "nutsPFunction: Acceptance rate negative."
            -- Accept the new candidate with probability min(1, r).
            isAccept <-
              if r > 1
                then pure True
                else do
                  b <- uniformRM (0, 1) g
                  pure $ b < r
            let -- A rejection must not discard a candidate that was accepted
                -- during an earlier doubling step.
                rejectedWith c = case isNew of
                  NewWith _ -> NewWith c
                  _ -> OldWith c
                (y''', isNew') =
                  if isAccept
                    then (y'', NewWith ac)
                    else (y, rejectedWith ac)
                -- Important: Check for a U-turn using the dot products of
                -- Equation (4) in [4].
                dq = qp'' - qm''
                isUTurn = (dq L.<.> pm'') < 0 || (dq L.<.> pp'') < 0
            if isUTurn
              then pure (y''', isNew')
              else go qm'' pm'' qp'' pp'' (j + 1) y''' (n + n'') isNew'
  (x', isNew) <- go q p q p 0 q 1 Old
  pure $ case isNew of
    -- NOTE (review): the second count was lost in extraction; 100 reconstructed
    -- to match the 100-sum convention of 'ac' above. Confirm.
    Old -> (ForceReject, Just $ AcceptanceCounts 0 100)
    OldWith ac -> (ForceReject, Just ac)
    NewWith ac -> (ForceAccept $ fromVec x', Just ac)
  where
    (HParamsI e _ ms _ _ msI mus) = hparamsi
    (HStructure _ toVec fromVecWith) = hstruct
    fromVec = fromVecWith x
    target = targetWith x

-- | No U-turn Hamiltonian Monte Carlo sampler (NUTS).
--
-- The structure of the state is denoted as @s@.
--
-- May call 'error' during initialization: when the internal parameters cannot
-- be computed, or when 'checkHStructureWith' reports a problem with the masses
-- and the structure.
nuts ::
  Traversable s =>
  NParams ->
  HTuningConf ->
  HStructure s ->
  HTarget s ->
  PName ->
  PWeight ->
  Proposal (s Double)
nuts nparams htconf hstruct htarget n w =
  let -- Misc.
      desc = PDescription "No U-turn sampler (NUTS)"
      (HStructure sample toVec fromVec) = hstruct
      dim = L.size $ toVec sample
      -- See bottom of page 1616 in [4].
      --
      -- NOTE (review): the numeric literal was lost in extraction; 0.6 is
      -- reconstructed from [4]. Confirm.
      pDim = PSpecial dim 0.6
      -- Vectorize and derive the target function.
      (HTarget mPrF lhF mJcF) = htarget
      tF y = case (mPrF, mJcF) of
        (Nothing, Nothing) -> lhF y
        (Just prF, Nothing) -> prF y * lhF y
        (Nothing, Just jcF) -> lhF y * jcF y
        (Just prF, Just jcF) -> prF y * lhF y * jcF y
      tFnG = grad' (ln . tF)
      targetWith x = bimap Exp toVec . tFnG . fromVec x
      (NParams mEps mMs) = nparams
      hParamsI =
        either error id $
          hParamsIWith (targetWith sample) (toVec sample) mEps Nothing mMs
      ps = nutsPFunction hParamsI hstruct targetWith
      nutsWith = Proposal n desc PSlow pDim w ps
      -- Tuning.
      ts = toAuxiliaryTuningParameters hParamsI
      tuner = do
        tfun <- hTuningFunctionWith dim toVec htconf
        let pfun = nutsPFunctionWithTuningParameters dim hstruct targetWith
        pure $ Tuner 1.0 ts True tfun pfun
   in case checkHStructureWith (hpsMasses hParamsI) hstruct of
        Just err -> error err
        Nothing -> nutsWith tuner