{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}

-- |
-- Module      :  Mcmc.Cycle
-- Description :  A cycle is a list of proposals
-- Copyright   :  2021 Dominik Schrempf
-- License     :  GPL-3.0-or-later
--
-- Maintainer  :  dominik.schrempf@gmail.com
-- Stability   :  experimental
-- Portability :  portable
--
-- Creation date: Thu Jul  8 17:56:03 2021.
module Mcmc.Cycle
  ( -- * Cycles
    Order (..),
    Cycle (ccProposals, ccRequireTrace),
    cycleFromList,
    setOrder,
    IterationMode (..),
    prepareProposals,
    autoTuneCycle,

    -- * Output
    proposalHLine,
    summarizeCycle,
  )
where

import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Default
import Data.List
import qualified Data.Map.Strict as M
import qualified Data.Vector as VB
import Mcmc.Acceptance
import Mcmc.Internal.Shuffle
import Mcmc.Proposal
import System.Random.Stateful

-- | Define the order in which 'Proposal's are executed in a 'Cycle'. The total
-- number of 'Proposal's per 'Cycle' may differ between 'Order's (e.g., compare
-- 'RandomO' and 'RandomReversibleO').
data Order
  = -- | Shuffle the 'Proposal's in the 'Cycle'. The 'Proposal's are replicated
    -- according to their weights and executed in random order. If a 'Proposal' has
    -- weight @w@, it is executed exactly @w@ times per iteration.
    RandomO
  | -- | The 'Proposal's are executed sequentially, in the order they appear in the
    -- 'Cycle'. 'Proposal's with weight @w>1@ are repeated immediately @w@ times
    -- (and not appended to the end of the list).
    SequentialO
  | -- | Similar to 'RandomO'. However, a reversed copy of the list of shuffled
    -- 'Proposal's is appended such that the resulting Markov chain is
    -- reversible.
    --
    -- Note: the total number of 'Proposal's executed per cycle is twice that of
    -- 'RandomO'.
    RandomReversibleO
  | -- | Similar to 'SequentialO'. However, a reversed copy of the list of
    -- sequentially ordered 'Proposal's is appended such that the resulting Markov
    -- chain is reversible.
    --
    -- Note: the total number of 'Proposal's executed per cycle is twice that of
    -- 'SequentialO'.
    SequentialReversibleO
  deriving (Eq, Show)

instance Default Order where
  -- A freshly created 'Cycle' executes its proposals in random order.
  def = RandomO

-- Human-readable description of an 'Order', as printed in the cycle summary.
-- Reversible orders reuse the description of their base order and add a
-- second line (separated by a newline, without a trailing newline).
describeOrder :: Order -> BL.ByteString
describeOrder order = case order of
  RandomO -> "The proposals are executed in random order."
  SequentialO -> "The proposals are executed sequentially."
  RandomReversibleO ->
    withNote
      RandomO
      "A reversed copy of the shuffled proposals is appended to ensure reversibility."
  SequentialReversibleO ->
    withNote
      SequentialO
      "A reversed copy of the sequential proposals is appended to ensure reversibility."
  where
    -- Base description, a newline, then the extra note.
    withNote base note = describeOrder base <> "\n" <> note

-- | In brief, a 'Cycle' is a list of proposals.
--
-- The state of the Markov chain will be logged only after all 'Proposal's in
-- the 'Cycle' have been completed, and the iteration counter will be increased
-- by one. The order in which the 'Proposal's are executed is specified by
-- 'Order'. The default is 'RandomO'.
--
-- No proposals with the same name and description are allowed in a 'Cycle', so
-- that they can be uniquely identified.
data Cycle a = Cycle
  { -- | The 'Proposal's of the cycle, each listed once (replication according to
    -- the weights happens when the proposals are prepared for an iteration).
    ccProposals :: [Proposal a],
    -- | The order in which the 'Proposal's are executed; see 'setOrder'.
    ccOrder :: Order,
    -- | Does the cycle require the trace when auto tuning? See 'tRequireTrace'.
    ccRequireTrace :: Bool
  }

-- | Create a 'Cycle' from a list of 'Proposal's.
--
-- The order is set to the default order, and the 'Cycle' requires the trace
-- if at least one 'Proposal' has a 'Tuner' that requires it.
--
-- Calls 'error' if the list is empty, or if it contains duplicates (as judged
-- by the 'Eq' instance of 'Proposal'); in the latter case the message lists the
-- name and description of every superfluous copy.
cycleFromList :: [Proposal a] -> Cycle a
cycleFromList [] =
  error "cycleFromList: Received an empty list but cannot create an empty Cycle."
cycleFromList xs
  | length distinct == length xs = Cycle xs def (any requiresTrace xs)
  | otherwise = error $ "\n" ++ report ++ "cycleFromList: Proposals are not unique."
  where
    distinct = nub xs
    -- Everything that 'nub' dropped, i.e., the duplicate entries.
    duplicates = xs \\ distinct
    describeDuplicate p = show (prName p) ++ " " ++ show (prDescription p)
    report = unlines (map describeDuplicate duplicates)
    -- Proposals without a tuner never require the trace.
    requiresTrace = maybe False tRequireTrace . prTuner

-- | Set the order of 'Proposal's in a 'Cycle'.
--
-- The proposals and the trace requirement of the 'Cycle' are kept.
setOrder :: Order -> Cycle a -> Cycle a
setOrder newOrder (Cycle proposals _ requireTrace) = Cycle proposals newOrder requireTrace

-- | Use all proposals, or use fast proposals only?
data IterationMode
  = -- | Use every 'Proposal' of the 'Cycle'.
    AllProposals
  | -- | Use only the 'Proposal's whose speed is 'PFast'.
    FastProposals
  deriving (Eq)

-- | Replicate 'Proposal's according to their weights and possibly shuffle them.
--
-- With 'FastProposals', only 'Proposal's with speed 'PFast' are used. A selected
-- 'Proposal' with weight @w@ appears @w@ times in a row before the 'Order' is
-- applied; the reversible orders append a reversed copy of the list. Calls
-- 'error' if no 'Proposal' is selected.
prepareProposals :: StatefulGen g m => IterationMode -> Cycle a -> g -> m [Proposal a]
prepareProposals mode (Cycle xs o _) g
  | null ps = error "prepareProposals: No proposals found."
  | otherwise = case o of
      RandomO -> shuffle ps g
      SequentialO -> pure ps
      RandomReversibleO -> mirrored <$> shuffle ps g
      SequentialReversibleO -> pure (mirrored ps)
  where
    -- Bang pattern: the selected and replicated list is evaluated to WHNF eagerly.
    !ps = concatMap replicateByWeight (filter isSelected xs)
    isSelected p = mode == AllProposals || prSpeed p == PFast
    replicateByWeight p = replicate (fromPWeight (prWeight p)) p
    -- Append a reversed copy so that the sequence reads the same backwards.
    mirrored ys = ys ++ reverse ys

-- The number of proposals performed per cycle. It is the summed weight of the
-- proposals selected by the 'IterationMode', doubled for the reversible
-- 'Order's (which append a reversed copy of the list).
getNProposalsPerCycle :: IterationMode -> Cycle a -> Int
getNProposalsPerCycle mode (Cycle xs o _) = passes * sum weights
  where
    passes = case o of
      RandomO -> 1
      SequentialO -> 1
      RandomReversibleO -> 2
      SequentialReversibleO -> 2
    selected = case mode of
      AllProposals -> xs
      FastProposals -> filter ((== PFast) . prSpeed) xs
    weights = [fromPWeight (prWeight p) | p <- selected]

-- Tune a single 'Proposal' using quantities gathered from the chain: the
-- acceptance rate and, optionally, the trace. See 'tuneWithTuningParameters'
-- and 'Tuner'.
--
-- A 'Proposal' without a 'Tuner' is returned unchanged. If the 'Tuner' requires
-- the trace but none is given, 'Left' is returned. Otherwise the tuner's tuning
-- function computes new tuning parameters from the tuning type, the proposal
-- dimension, the acceptance rate, the trace, and the current parameters; the
-- result of 'tuneWithTuningParameters' is returned.
tuneWithChainParameters :: TuningType -> AcceptanceRate -> Maybe (VB.Vector a) -> Proposal a -> Either String (Proposal a)
tuneWithChainParameters b ar mxs p = case prTuner p of
  Nothing -> Right p
  Just (Tuner t ts rt fT _) -> case (rt, mxs) of
    -- Report the missing trace through the 'Either' result instead of calling
    -- 'error', consistent with the return type of this function.
    (True, Nothing) -> Left "tuneWithChainParameters: trace required"
    _ ->
      let (t', ts') = fT b (prDimension p) ar mxs (t, ts)
       in tuneWithTuningParameters t' ts' p

-- | Calculate acceptance rates and auto tune the 'Proposal's in the 'Cycle'.
--
-- Do not change 'Proposal's that are not tuneable, or 'Proposal's without an
-- acceptance rate.
--
-- Calls 'error' if a trace is given although the 'Cycle' does not require one,
-- if the 'Cycle' requires a trace but none is given, or if the 'Proposal's of
-- the acceptance map and of the 'Cycle' differ.
autoTuneCycle :: TuningType -> Acceptance (Proposal a) -> Maybe (VB.Vector a) -> Cycle a -> Cycle a
autoTuneCycle b a mxs c
  | not traceRequired, Just _ <- mxs = error "autoTuneCycle: trace not required"
  | traceRequired, Nothing <- mxs = error "autoTuneCycle: trace required"
  | sort (M.keys rates) /= sort ps =
      error "autoTuneCycle: Proposals in map and cycle do not match."
  | otherwise = c {ccProposals = map tune ps}
  where
    traceRequired = ccRequireTrace c
    rates = acceptanceRates a
    ps = ccProposals c
    -- Only proposals with a known acceptance rate are tuned.
    tune p = case M.lookup p rates of
      Just (Just x) -> either error id $ tuneWithChainParameters b x mxs p
      _ -> p

-- | Horizontal line of proposal summaries.
--
-- A row of dashes exactly as wide as 'proposalHeader'.
proposalHLine :: BL.ByteString
proposalHLine = BL.replicate headerWidth '-'
  where
    headerWidth = BL.length proposalHeader

-- | Summarize the 'Proposal's in the 'Cycle'. Also report acceptance rates.
--
-- The summary consists of a title, the number of proposals performed per
-- iteration (which depends on the 'IterationMode' and the 'Order'), a
-- description of the 'Order', and one line per 'Proposal' of the 'Cycle',
-- framed by horizontal lines. All 'Proposal's of the 'Cycle' are listed,
-- irrespective of the 'IterationMode'. Lines are joined with newlines; there
-- is no trailing newline.
summarizeCycle :: IterationMode -> Acceptance (Proposal a) -> Cycle a -> BL.ByteString
summarizeCycle m a c =
  BL.intercalate "\n" $
    [ "Summary of proposal(s) in cycle.",
      nProposalsFullStr,
      describeOrder (ccOrder c),
      proposalHeader,
      proposalHLine
    ]
      ++ [ summarizeProposal
             (prName p)
             (prDescription p)
             (prWeight p)
             (tTuningParameter <$> prTuner p)
             (prDimension p)
             (ar p)
           | p <- ps
         ]
      ++ [proposalHLine]
  where
    ps = ccProposals c
    nProposals = getNProposalsPerCycle m c
    nProposalsStr = BB.toLazyByteString $ BB.intDec nProposals
    -- Singular/plural agreement of the count line.
    nProposalsFullStr = case nProposals of
      1 -> nProposalsStr <> " proposal is performed per iteration."
      _ -> nProposalsStr <> " proposals are performed per iteration."
    ar pr = acceptanceRate pr a