{-# OPTIONS_GHC -fno-warn-type-defaults #-}
{-# LANGUAGE RecordWildCards #-}

-- |
-- Module: Numeric.MCMC.Anneal
-- Copyright: (c) 2015 Jared Tobin
-- License: MIT
--
-- Maintainer: Jared Tobin <jared@jtobin.ca>
-- Stability: unstable
-- Portability: ghc
--
-- Transition operators can easily be tweaked to operate over an /annealed/
-- parameter space, which can be useful when sampling from bumpy landscapes
-- with isolated modes.
--
-- This library exports a single 'anneal' function that allows one to run a
-- /declarative/-compatible transition operator over a space that has been
-- annealed to a specified inverse temperature.
--
-- > import Numeric.MCMC
-- >
-- > annealingTransition = do
-- >   anneal 0.70 (metropolis 1)
-- >   anneal 0.05 (metropolis 1)
-- >   anneal 0.05 (metropolis 1)
-- >   anneal 0.70 (metropolis 1)
-- >   metropolis 1
--
-- These annealed operators can then just be used like any other:
--
-- > himmelblau :: Target [Double]
-- > himmelblau = Target lHimmelblau Nothing where
-- >   lHimmelblau :: [Double] -> Double
-- >   lHimmelblau [x0, x1] =
-- >     (-1) * ((x0 * x0 + x1 - 11) ^ 2 + (x0 + x1 * x1 - 7) ^ 2)
-- >
-- > main :: IO ()
-- > main = withSystemRandom . asGenIO $
-- >   mcmc 10000 [0, 0] annealingTransition himmelblau

module Numeric.MCMC.Anneal (
    anneal
  ) where

import Control.Monad.Trans.State.Strict (get, modify)
import Data.Sampling.Types (Transition, Chain(..), Target(..))

-- | An annealing transformer.
--
--   When executed, the supplied transition operator will execute over the
--   parameter space annealed to the supplied inverse temperature.
--
--   Concretely:
--
--   1. The chain's target is temporarily replaced by one whose log density
--      (and gradient, when present) is scaled by the inverse temperature.
--   2. The base transition then runs against that annealed target.
--   3. Afterwards the original target is restored, and the chain's score is
--      recomputed against it.
--
--   The inverse temperature must be non-negative. A negative or NaN value
--   results in an 'error' when the transition is evaluated.
--
--   > let annealedTransition = anneal 0.30 (slice 0.5)
anneal
  :: (Monad m, Functor f)
  => Double                             -- ^ inverse temperature (>= 0)
  -> Transition m (Chain (f Double) b)  -- ^ base transition operator
  -> Transition m (Chain (f Double) b)
anneal invTemp baseTransition
  -- NaN compares false against everything, so it must be rejected explicitly;
  -- otherwise it would silently turn every annealed log density into NaN.
  | isNaN invTemp || invTemp < 0 = error "anneal: invalid temperature"
  | otherwise = do
      Chain {..} <- get
      let annealedTarget = annealer invTemp chainTarget
      modify $ useTarget annealedTarget
      baseTransition
      -- Put the untempered target back so later operators, and the stored
      -- score, see the original density again.
      modify $ useTarget chainTarget

-- | Build a tempered copy of a target: its log density, and its optional
--   'glTarget' function (the gradient, judging by the name) when one is
--   supplied, are both multiplied by the given inverse temperature.
annealer :: Functor f => Double -> Target (f Double) -> Target (f Double)
annealer invTemp target = Target temperedDensity temperedGradient
  where
    temperedDensity xs = invTemp * lTarget target xs

    -- A missing gradient stays missing; a present one has every component
    -- scaled by the same factor as the density.
    temperedGradient = fmap scaleGradient (glTarget target)

    scaleGradient grad = fmap (* invTemp) . grad

-- | Point a chain at a different target. The chain's current position is
--   re-scored under the new target; the position and tunables are carried
--   over unchanged.
useTarget :: Target a -> Chain a b -> Chain a b
useTarget newTarget chain =
  chain { chainTarget = newTarget
        , chainScore  = lTarget newTarget (chainPosition chain)
        }