-- |
-- Module      :  Mcmc.Proposal.Hamiltonian.Masses
-- Description :  Mass matrices
-- Copyright   :  2022 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Tue Jun 14 10:09:24 2022.
module Mcmc.Proposal.Hamiltonian.Masses
  ( Mu,
    MassesI,
    toGMatrix,
    cleanMatrix,
    getMassesI,
    getMus,
    Dimension,
    vectorToMasses,
    massesToVector,

    -- * Tuning
    SmoothingParameter (..),
    tuneDiagonalMassesOnly,
    tuneAllMasses,
  )
where

import Data.Maybe
import qualified Data.Vector as VB
import qualified Data.Vector.Storable as VS
import qualified Data.Vector.Unboxed as VU
import Mcmc.Proposal.Hamiltonian.Common
import qualified Numeric.LinearAlgebra as L
import Numeric.Natural
import qualified Statistics.Covariance as S
import qualified Statistics.Function as S
import qualified Statistics.Sample as S

-- | Mean vector containing zeroes. We save this vector because it is required
-- when sampling from the multivariate normal distribution (see 'getMus', which
-- builds it with the dimension of the mass matrix).
type Mu = L.Vector Double

-- | General, symmetric, inverted mass matrix. "General" refers to the
-- 'L.GMatrix' representation, which 'toGMatrix' chooses to be diagonal, sparse,
-- or dense depending on the matrix.
type MassesI = L.GMatrix

-- Numerical threshold used throughout this module.
--
-- Off-diagonal entries of masses and inverted masses with absolute value
-- strictly smaller than the precision are purged (set to zero); diagonal
-- entries are bounded below by it (see 'cleanMatrix'). It is also the tolerance
-- used by 'isDiag', 'isSparse', and 'toAssocMatrix'.
--
-- If changed, also change help text of 'HTuneMasses', which is indirectly
-- affected via 'massMin', and 'massMax'.
precision :: Double
precision = 1e-8

-- Is the matrix diagonal up to 'precision'?
--
-- The sum of the absolute values of all entries is compared with the sum of the
-- absolute values of the diagonal entries; the matrix is considered diagonal
-- if they differ by strictly less than 'precision'.
isDiag :: L.Matrix Double -> Bool
isDiag xs = abs (sumDiag - sumFull) < precision
  where
    xsAbs :: L.Matrix Double
    xsAbs = L.cmap abs xs
    sumDiag :: Double
    sumDiag = L.sumElements (L.takeDiag xsAbs)
    sumFull :: Double
    sumFull = L.sumElements xsAbs

-- Is the matrix sparse?
--
-- With @n@ the number of rows, a matrix is considered sparse if strictly fewer
-- than @n * min 5 n@ of its entries are non-zero, where an entry counts as
-- non-zero if its absolute value is at least 'precision'. For @n >= 5@ the
-- bound is @5 * n@; for smaller @n@ it is @n * n@.
isSparse :: L.Matrix Double -> Bool
isSparse xs = nNonZero < fromIntegral nMax
  where
    n :: Int
    n = L.rows xs
    nMax :: Int
    nMax = n * min 5 n
    -- Indicator of non-zero entries; summing it counts them.
    indicator :: Double -> Double
    indicator x = if abs x >= precision then 1 else 0
    nNonZero :: Double
    nNonZero = L.sumElements $ L.cmap indicator xs

-- Convert a square matrix to an association list containing the entries with
-- absolute value at least 'precision', in row-major order.
--
-- Calls 'error' if the matrix is not square.
toAssocMatrix :: L.Matrix Double -> L.AssocMatrix
toAssocMatrix xs
  | n /= m = error "toAssocMatrix: Matrix not square."
  | otherwise =
      [ ((i, j), e)
        | i <- [0 .. (n - 1)],
          j <- [0 .. (n - 1)],
          let e = xs `L.atIndex` (i, j),
          abs e >= precision
      ]
  where
    n :: Int
    n = L.rows xs
    m :: Int
    m = L.cols xs

-- | Convert a square, non-empty matrix to a general matrix.
--
-- The representation is diagonal if 'isDiag' holds, sparse if 'isSparse' holds
-- (only entries with absolute value at least 'precision' are kept), and dense
-- otherwise.
--
-- Calls 'error' if the matrix is empty or not square.
toGMatrix :: L.Matrix Double -> L.GMatrix
toGMatrix xs
  | n == 0 || m == 0 = error "toGMatrix: Matrix empty."
  | n /= m = error "toGMatrix: Matrix not square."
  | isDiag xs = L.mkDiagR n m $ L.takeDiag xs
  | isSparse xs = L.mkSparse $ toAssocMatrix xs
  | otherwise = L.mkDense xs
  where
    n :: Int
    n = L.rows xs
    m :: Int
    m = L.cols xs

-- | Clean a (mass) matrix.
--
-- Diagonal:
-- - NaN values are set to 1.
-- - Negative values are set to 1.
-- - Values strictly smaller than 'precision' are set to 'precision'.
--
-- Off-diagonal:
-- - NaN values are set to 0.
-- - Elements with absolute values strictly smaller than 'precision' are purged.
--
-- We are permissive with negative and NaN values because adequate masses are
-- crucial. The Hamiltonian algorithms also work when the masses are off.
cleanMatrix :: L.Matrix Double -> L.Matrix Double
cleanMatrix xs =
  L.diag (L.cmap cleanDiag xsDiag) + L.cmap cleanOffDiag xsOffDiag
  where
    xsDiag :: L.Vector Double
    xsDiag = L.takeDiag xs
    cleanDiag :: Double -> Double
    cleanDiag x
      | isNaN x = 1
      | x < 0 = 1
      -- The strict comparison is important.
      | x < precision = precision
      | otherwise = x
    -- The diagonal of this matrix is zero, except where the original diagonal
    -- was NaN or infinite (then it is NaN), which 'cleanOffDiag' maps to 0.
    xsOffDiag :: L.Matrix Double
    xsOffDiag = xs - L.diag xsDiag
    cleanOffDiag :: Double -> Double
    cleanOffDiag x
      | isNaN x = 0
      -- The strict comparison is important.
      | abs x < precision = 0
      | otherwise = x

-- | Invert a (symmetric) mass matrix.
--
-- The inverse is computed with 'L.invlndet', cleaned with 'cleanMatrix', and
-- converted with 'toGMatrix'.
--
-- Calls 'error' if the matrix is empty, not square, or if the sign of its
-- determinant reported by 'L.invlndet' is not 1.
getMassesI :: L.Herm Double -> L.GMatrix
getMassesI xs
  | n == 0 || m == 0 = error "getMassesI: Matrix empty."
  | n /= m = error "getMassesI: Matrix not square."
  | sign /= 1.0 = error "getMassesI: Determinant of matrix is negative."
  | otherwise = toGMatrix $ cleanMatrix xsI
  where
    xs' :: L.Matrix Double
    xs' = L.unSym xs
    n :: Int
    n = L.rows xs'
    m :: Int
    m = L.cols xs'
    -- Lazily bound; only evaluated after the emptiness and squareness checks.
    (xsI, (_, sign)) = L.invlndet xs'

-- | Zero mean vector with length equal to the dimension of the mass matrix.
--
-- Calls 'error' if the mass matrix is empty or not square.
getMus :: Masses -> L.Vector Double
getMus xs
  | n == 0 || m == 0 = error "getMus: Matrix empty."
  | n /= m = error "getMus: Matrix not square."
  | otherwise = L.fromList $ replicate n 0.0
  where
    xs' :: L.Matrix Double
    xs' = L.unSym xs
    n :: Int
    n = L.rows xs'
    m :: Int
    m = L.cols xs'

-- | Dimension of the proposal. For example, 'vectorToMasses' uses it as the
-- number of columns when reshaping a flat vector into a mass matrix.
type Dimension = Int

-- | Flatten a mass matrix row by row into an unboxed vector.
--
-- Inverse of 'vectorToMasses' (given the dimension).
massesToVector :: Masses -> VU.Vector Double
massesToVector = VU.convert . L.flatten . L.unSym

-- | Reshape a flat (row-major) vector into a mass matrix with @d@ columns.
--
-- The result is wrapped with 'L.trustSym', so symmetry is assumed, not
-- checked. Inverse of 'massesToVector'.
vectorToMasses :: Dimension -> VU.Vector Double -> Masses
vectorToMasses d = L.trustSym . L.reshape d . VU.convert

-- Smallest absolute value of a tuned mass (see 'rescueAfter').
--
-- If changed, also change help text of 'HTuneMasses'.
massMin :: Double
massMin = precision

-- Largest absolute value of a tuned mass (see 'rescueAfter').
--
-- If changed, also change help text of 'HTuneMasses'.
massMax :: Double
massMax = recip precision

-- Minimal number of unique samples required for tuning the diagonal entries of
-- the mass matrix.
--
-- NOTE: If changed, also change help text of 'HTuneMasses'.
samplesDiagonalMin :: Int
samplesDiagonalMin = 61

-- Minimal number of samples required for tuning all entries of the mass matrix
-- of the given dimension: 'samplesDiagonalMin' plus the larger of
-- 'samplesDiagonalMin' and the dimension.
--
-- NOTE: If changed, also change help text of 'HTuneMasses'.
samplesAllMinWith :: Dimension -> Int
samplesAllMinWith d = samplesDiagonalMin + max samplesDiagonalMin d

-- Number of distinct values in a sample (sort, then drop adjacent duplicates).
getSampleSize :: VS.Vector Double -> Int
getSampleSize = VS.length . VS.uniq . S.gsort

-- Clamp the absolute value of a mass into ['massMin', 'massMax'] while keeping
-- its sign. Zero stays zero (its sign is zero), and NaN passes through
-- unchanged because all comparisons with NaN are false.
rescueAfter :: Double -> Double
rescueAfter y
  | massMin > abs y = signum y * massMin
  | abs y > massMax = signum y * massMax
  | otherwise = y

-- Combine an old and a new value with @f@, but keep the old value if the new
-- one is NaN or infinite.
rescueBefore :: (Double -> Double -> Double) -> Double -> Double -> Double
rescueBefore f old new
  -- Be permissive with NaN and infinite values.
  | isNaN new = old
  | isInfinite new = old
  | otherwise = f old new

-- Combine an old and a new value with @f@ ('rescueBefore' guards against a
-- NaN or infinite new value), then clamp the result ('rescueAfter').
rescue :: (Double -> Double -> Double) -> Double -> Double -> Double
rescue f old new = rescueAfter $ rescueBefore f old new

-- -- I do not know where I got this function from, but it works pretty well!
-- interpolate :: Double -> Double -> Double
-- interpolate old new = interSqrt ** 2
--   where
--     interSqrt = recip 3 * (sqrt old + 2 * sqrt new)

-- | This parameter plays the same role as @m@ in the dual averaging algorithm.
-- In the beginning, when @m@ is zero, the newly estimated masses take almost
-- full precedence over the current masses. After some auto tuning steps, when
-- @m@ is larger, the newly estimated masses have less influence on the current
-- masses.
--
-- NOTE(review): with the interpolation used in this module, the weight of the
-- new estimate (on the square-root scale) is @(nbins - 1) / nbins@ with
-- @nbins = max (100 - m) 2@. It decreases from 99/100 at @m = 0@ to a minimum
-- of 1/2 for @m >= 98@, so the new estimate never influences the masses only
-- slightly.
newtype SmoothingParameter = SmoothingParameter Natural

-- Interpolate between an old and a new (squared-scale) value on the
-- square-root scale, keeping signs.
--
-- With @s x = signum x * sqrt (abs x)@ and @nbins = max (100 - m) 2@, compute
-- @fin = (s old + (nbins - 1) * s new) / nbins@ and return
-- @signum fin * fin ** 2@. The larger the number of bins, the more informative
-- the new value is compared to the old one.
interpolate' :: SmoothingParameter -> Double -> Double -> Double
interpolate' (SmoothingParameter m) oldSquared newSquared = finSign * (fin ** 2)
  where
    oldSign :: Double
    oldSign = signum oldSquared
    newSign :: Double
    newSign = signum newSquared
    sqrt' :: Double -> Double
    sqrt' = sqrt . abs
    old :: Double
    old = oldSign * sqrt' oldSquared
    new :: Double
    new = newSign * sqrt' newSquared
    -- The conversion to 'Double' happens before the subtraction, so there is
    -- no 'Natural' underflow for m > 100.
    nbins :: Double
    nbins = max (100 - fromIntegral m) 2
    fin :: Double
    fin = recip nbins * (old + (nbins - 1) * new)
    finSign :: Double
    finSign = signum fin

-- Like 'getNewMassDiagonal', but keep the old mass if the number of distinct
-- samples is smaller than 'samplesDiagonalMin'.
getNewMassDiagonalWithSampleSize :: SmoothingParameter -> Int -> Double -> Double -> Double
getNewMassDiagonalWithSampleSize m sampleSize massOld massEstimate
  | sampleSize < samplesDiagonalMin = massOld
  | otherwise = getNewMassDiagonal m massOld massEstimate

-- Combine an old diagonal mass with a new estimate using 'interpolate'' and
-- 'rescue'. Non-positive estimates are rejected, and the old mass is kept.
getNewMassDiagonal :: SmoothingParameter -> Double -> Double -> Double
getNewMassDiagonal m massOld massEstimate
  -- Diagonal masses are inverse variances, which are strictly positive.
  | massEstimate <= 0 = massOld
  | otherwise = rescue (interpolate' m) massOld massEstimate

-- Combine an old off-diagonal mass with a new estimate using 'interpolate''
-- and 'rescue'. Unlike diagonal masses, off-diagonal masses may be negative.
getNewMassOffDiagonal :: SmoothingParameter -> Double -> Double -> Double
getNewMassOffDiagonal m = rescue (interpolate' m)

-- The Cholesky decomposition, which is performed when sampling new momenta with
-- 'generateMomenta', requires a positive definite covariance matrix. The
-- Graphical Lasso algorithm finds positive definite covariance matrices, but
-- sometimes positive definiteness is violated because of numerical errors.
-- Further, when non-diagonal masses are already non-zero, the tuning of
-- diagonal masses only may violate positive definiteness.
--
-- Find the closest positive definite matrix of a given matrix (Higham's
-- method). Return the input unchanged if it already is positive definite, as
-- judged by a successful Cholesky decomposition. Otherwise, average the
-- symmetric part B with its symmetric polar factor H, and add increasing
-- multiples of the identity until the result is positive definite.
--
-- Calls 'error' if the matrix is empty or not square.
--
-- See https://gist.github.com/fasiha/fdb5cec2054e6f1c6ae35476045a0bbd.
findClosestPositiveDefiniteMatrix :: L.Matrix Double -> L.Matrix Double
findClosestPositiveDefiniteMatrix a
  | n == 0 || m == 0 = error "findClosestPositiveDefiniteMatrix: Matrix empty."
  | n /= m = error "findClosestPositiveDefiniteMatrix: Matrix not square."
  | isPositiveDefinite a = a
  | otherwise = go a3 1
  where
    n :: Int
    n = L.rows a
    m :: Int
    m = L.cols a
    -- Symmetric part of the input, (a + a^T) / 2.
    b :: L.Matrix Double
    b = L.unSym $ L.sym a
    -- hmatrix returns (u, s, v) with b == u <> diag s <> tr v; that is, v holds
    -- the right singular vectors as columns. NumPy's svd (used in the reference
    -- gist) returns the transpose of this matrix instead.
    (_, s, v) = L.svd b
    -- Symmetric polar factor of b, H = V S V^T (not V^T S V, because of the
    -- hmatrix convention above).
    h :: L.Matrix Double
    h = v L.<> (L.diag s L.<> L.tr v)
    a2 :: L.Matrix Double
    a2 = L.scale 0.5 (b + h)
    a3 :: L.Matrix Double
    a3 = L.unSym $ L.sym a2
    isPositiveDefinite :: L.Matrix Double -> Bool
    isPositiveDefinite = isJust . L.mbChol . L.trustSym
    i :: L.Matrix Double
    i = L.ident n
    -- See https://hackage.haskell.org/package/ieee754-0.8.0/docs/src/Numeric-IEEE.html#line-177.
    eps :: Double
    eps = 2.2204460492503131e-16
    -- Shift the spectrum up by (-minEig * k^2 + eps) until positive definite.
    go :: L.Matrix Double -> Double -> L.Matrix Double
    go x k
      | isPositiveDefinite x = x
      | otherwise =
          let minEig = L.minElement $ L.cmap L.realPart $ L.eigenvalues x
              nu = negate minEig * (k ** 2) + eps
              x' = x + L.scale nu i
           in go x' (k + 1)

-- | Tune the diagonal entries of the mass matrix only.
--
-- If fewer than 'samplesDiagonalMin' values are given, the masses are returned
-- unchanged. Otherwise, each diagonal entry is estimated as the reciprocal
-- sample variance of the corresponding position. It is then combined with the
-- old entry by 'getNewMassDiagonalWithSampleSize', which keeps the old entry if
-- the position has fewer than 'samplesDiagonalMin' distinct values or a
-- non-positive estimate. Off-diagonal entries are kept. The result is cleaned
-- ('cleanMatrix') and projected to a positive definite matrix
-- ('findClosestPositiveDefiniteMatrix').
--
-- Calls 'error' if the dimension of the positions differs from the dimension
-- of the mass matrix.
tuneDiagonalMassesOnly ::
  SmoothingParameter ->
  -- | Conversion from value to vector.
  (a -> Positions) ->
  -- | Value vector.
  VB.Vector a ->
  -- | Old mass matrix, and inverted mass matrix.
  (Masses, MassesI) ->
  -- | New mass matrix, and inverted mass matrix.
  (Masses, MassesI)
-- NOTE: Here, we lose time because we convert the states to vectors again,
-- something that has already been done. But then, auto tuning is not a runtime
-- determining factor.
tuneDiagonalMassesOnly m toVec xs (ms, msI)
  -- If not enough data is available, do not tune.
  | VB.length xs < samplesDiagonalMin = (ms, msI)
  | dimState /= dimMs = error "tuneDiagonalMassesOnly: Dimension mismatch."
  -- Replace the diagonal.
  | otherwise =
      let msDirty = msOld - L.diag msDiagonalOld + L.diag msDiagonalNew
          -- Positive definite matrices are symmetric.
          ms' = L.trustSym $ findClosestPositiveDefiniteMatrix $ cleanMatrix msDirty
          msI' = getMassesI ms'
       in (ms', msI')
  where
    -- xs: Each element contains all parameters of one iteration.
    -- xs': Each element is a vector containing all parameters changed by the
    -- proposal of one iteration.
    xs' = VB.map toVec xs
    -- xs'': Matrix with each row containing all parameter values changed by the
    -- proposal of one iteration.
    xs'' = L.fromRows $ VB.toList xs'
    -- We can safely use 'VB.head' here since the first guard ensures that the
    -- length of 'xs' is at least 'samplesDiagonalMin' (which is positive).
    dimState = VS.length $ VB.head xs'
    -- Number of distinct values per position (column).
    sampleSizes = VS.fromList $ map getSampleSize $ L.toColumns xs''
    msOld = L.unSym ms
    dimMs = L.rows msOld
    msDiagonalOld = L.takeDiag msOld
    msDiagonalEstimate = VS.fromList $ map (recip . S.variance) $ L.toColumns xs''
    msDiagonalNew =
      VS.zipWith3
        (getNewMassDiagonalWithSampleSize m)
        sampleSizes
        msDiagonalOld
        msDiagonalEstimate

-- Penalty passed to the Graphical Lasso algorithm in 'tuneAllMasses'.
--
-- This value was carefully tuned using the example "hamiltonian".
defaultGraphicalLassoPenalty :: Double
defaultGraphicalLassoPenalty = 0.3

-- Interpolate two square matrices of the same size element by element.
--
-- Diagonal elements are combined with 'getNewMassDiagonal', off-diagonal
-- elements with 'getNewMassOffDiagonal'. In both cases the element of the old
-- matrix is the first, and the element of the new matrix the second argument.
--
-- Calls 'error' if the matrices differ in size, are not square, or are empty.
interpolateElementWise :: SmoothingParameter -> L.Matrix Double -> L.Matrix Double -> L.Matrix Double
interpolateElementWise m old new
  | mO /= mN = err "different number of rows"
  | nO /= nN = err "different number of columns"
  | mO /= nO = err "not square"
  | mO < 1 = err "empty matrix"
  | otherwise =
      L.fromLists [[interpolateAt i j | j <- [0 .. nO - 1]] | i <- [0 .. mO - 1]]
  where
    mO = L.rows old
    nO = L.cols old
    mN = L.rows new
    nN = L.cols new
    err msg = error $ "interpolateElementWise: " <> msg
    -- Entries are addressed with integer indices directly, so no rounding of
    -- 'Double'-valued indices is required.
    interpolateAt :: Int -> Int -> Double
    interpolateAt i j =
      let g = if i == j then getNewMassDiagonal m else getNewMassOffDiagonal m
       in g (old `L.atIndex` (i, j)) (new `L.atIndex` (i, j))

-- | Tune all entries of the mass matrix.
--
-- - If fewer than 'samplesDiagonalMin' values are given, the masses are
--   returned unchanged.
-- - If fewer than 'samplesAllMinWith' values are given, or if the matrix of
--   positions does not have full column rank, only the diagonal masses are
--   tuned (see 'tuneDiagonalMassesOnly').
-- - Otherwise, the precision matrix of the standardized positions is
--   estimated with the Graphical Lasso algorithm (penalty
--   'defaultGraphicalLassoPenalty'), rescaled, and interpolated element-wise
--   with the old masses. The result is cleaned ('cleanMatrix') and projected to
--   a positive definite matrix ('findClosestPositiveDefiniteMatrix').
--
-- Calls 'error' if the dimension of the positions differs from the dimension
-- of the mass matrix.
tuneAllMasses ::
  SmoothingParameter ->
  -- | Conversion from value to vector.
  (a -> Positions) ->
  -- | Value vector.
  VB.Vector a ->
  -- | Old mass matrix, and inverted mass matrix.
  (Masses, MassesI) ->
  -- | New mass matrix, and inverted mass matrix.
  (Masses, MassesI)
-- NOTE: Here, we lose time because we convert the states to vectors again,
-- something that has already been done. But then, auto tuning is not a runtime
-- determining factor.
tuneAllMasses m toVec xs (ms, msI)
  -- If not enough data is available, do not tune.
  | VB.length xs < samplesDiagonalMin = (ms, msI)
  -- Check dimensions before any (expensive) estimation is attempted.
  | dimState /= dimMs = error "tuneAllMasses: Dimension mismatch."
  -- If not enough data is available, only the diagonal masses are tuned.
  | VB.length xs < samplesAllMinWith dimMs = fallbackDiagonal
  | L.rank xs'' /= dimState = fallbackDiagonal
  | otherwise = (msNew, msINew)
  where
    fallbackDiagonal = tuneDiagonalMassesOnly m toVec xs (ms, msI)
    -- xs: Each element contains all parameters of one iteration.
    -- xs': Each element is a vector containing all parameters changed by the
    -- proposal of one iteration.
    xs' = VB.map toVec xs
    -- xs'': Matrix with each row containing all parameter values changed by the
    -- proposal of one iteration.
    xs'' = L.fromRows $ VB.toList xs'
    -- We can safely use 'VB.head' here since the first guard ensures that the
    -- length of 'xs' is at least 'samplesDiagonalMin' (which is positive).
    dimState = VS.length $ VB.head xs'
    dimMs = L.rows $ L.unSym ms
    (_, ss, xsNormalized) = S.scale xs''
    -- The first value is the covariance matrix sigma, which is the inverted
    -- mass matrix; the second value is the precision matrix, which is the mass
    -- matrix. We discard sigma because we interpolate the new mass matrix
    -- using the old one and the new estimate, so we have to recalculate the
    -- inverted mass matrix anyway.
    (_, precNormalized) =
      either error id $
        S.graphicalLasso defaultGraphicalLassoPenalty xsNormalized
    ms' = S.rescalePWith ss (L.unSym precNormalized)
    msNewDirty = interpolateElementWise m (L.unSym ms) ms'
    -- Clean NaNs and small values; ensure positive definiteness. The masses
    -- should be positive definite, but sometimes they happen to be not because
    -- of numerical errors. Positive definite matrices are symmetric.
    msNew = L.trustSym $ findClosestPositiveDefiniteMatrix $ cleanMatrix msNewDirty
    msINew = getMassesI msNew