{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Mcmc.Cycle
(
Order (..),
Cycle (ccProposals, ccRequireTrace, ccHasIntermediateTuners),
cycleFromList,
setOrder,
IterationMode (..),
prepareProposals,
autoTuneCycle,
summarizeCycle,
)
where
import Control.Applicative
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.List
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Vector as VB
import Mcmc.Acceptance
import Mcmc.Internal.Shuffle
import Mcmc.Proposal
import System.Random.Stateful
-- | Order in which the proposals of a 'Cycle' are executed.
--
-- The order is interpreted by 'prepareProposals'.
data Order
  = -- | The proposals are shuffled.
    RandomO
  | -- | The proposals are executed in the order in which they are stored.
    SequentialO
  | -- | The proposals are shuffled, and a reversed copy of the shuffled list
    -- is appended.
    RandomReversibleO
  | -- | The proposals are executed in stored order, and a reversed copy of
    -- that list is appended.
    SequentialReversibleO
  deriving (Eq, Show)
-- | Human-readable description of an 'Order'.
--
-- The reversible orders reuse the description of their non-reversible
-- counterpart and add a second line about the appended reversed copy.
describeOrder :: Order -> BL.ByteString
describeOrder order = case order of
  RandomO -> "The proposals are executed in random order."
  SequentialO -> "The proposals are executed sequentially."
  RandomReversibleO ->
    twoLines
      (describeOrder RandomO)
      "A reversed copy of the shuffled proposals is appended to ensure reversibility."
  SequentialReversibleO ->
    twoLines
      (describeOrder SequentialO)
      "A reversed copy of the sequential proposals is appended to ensure reversibility."
  where
    -- Join two lines with a single newline (no trailing newline).
    twoLines :: BL.ByteString -> BL.ByteString -> BL.ByteString
    twoLines firstLine secondLine = BL.intercalate "\n" [firstLine, secondLine]
-- | A collection of proposals together with the 'Order' in which they are
-- executed.
--
-- The module does not export the data constructor; use 'cycleFromList' to
-- create a cycle and 'setOrder' to change the order.
data Cycle a = Cycle
  { -- | Proposals of the cycle, in the order given to 'cycleFromList'.
    ccProposals :: [Proposal a],
    -- | Execution order; 'cycleFromList' initializes it to 'RandomO'.
    ccOrder :: Order,
    -- | 'True' if at least one proposal has a tuner with 'tRequireTrace' set
    -- (computed by 'cycleFromList').
    ccRequireTrace :: Bool,
    -- | 'True' if at least one proposal has a tuner with
    -- 'tSuitableForIntermediateTuning' set (computed by 'cycleFromList').
    ccHasIntermediateTuners :: Bool
  }
-- | Create a 'Cycle' from a list of proposals.
--
-- The resulting cycle uses 'RandomO'; use 'setOrder' to change the order.
--
-- Calls 'error' if
--
-- - the list is empty, or
--
-- - the list contains duplicate proposals (according to the 'Eq' instance
--   of 'Proposal'); the error message lists name and description of every
--   surplus proposal, one per line.
cycleFromList :: [Proposal a] -> Cycle a
cycleFromList [] =
  error "cycleFromList: Received an empty list but cannot create an empty Cycle."
cycleFromList proposals
  | length distinct == length proposals =
      Cycle
        { ccProposals = proposals,
          ccOrder = RandomO,
          ccRequireTrace = any needsTrace proposals,
          ccHasIntermediateTuners = any isIntermediate proposals
        }
  | otherwise =
      error $ "\n" ++ duplicateReport ++ "cycleFromList: Proposals are not unique."
  where
    -- First occurrence of every proposal.
    distinct = nub proposals
    -- What remains after removing one occurrence of each distinct proposal,
    -- i.e., the surplus copies.
    duplicates = proposals \\ distinct
    duplicateReport =
      unlines
        [ show (prName p) ++ " " ++ show (prDescription p)
          | p <- duplicates
        ]
    -- Check a tuner flag; proposals without tuner never satisfy it.
    tunerHas :: (Tuner b -> Bool) -> Proposal b -> Bool
    tunerHas flag = maybe False flag . prTuner
    needsTrace :: Proposal b -> Bool
    needsTrace = tunerHas tRequireTrace
    isIntermediate :: Proposal b -> Bool
    isIntermediate = tunerHas tSuitableForIntermediateTuning
-- | Replace the 'Order' of a 'Cycle'; all other fields are kept as they are.
setOrder :: Order -> Cycle a -> Cycle a
setOrder newOrder (Cycle proposals _ requireTrace hasIntermediateTuners) =
  Cycle proposals newOrder requireTrace hasIntermediateTuners
-- | Which proposals of a 'Cycle' take part in an iteration.
data IterationMode
  = -- | Use all proposals of the cycle.
    AllProposals
  | -- | Use only proposals whose 'prSpeed' is 'PFast'.
    FastProposals
  deriving (Eq)
-- | Assemble the list of proposals to be executed, honoring the iteration
-- mode, the proposal weights, and the order of the cycle.
--
-- Each selected proposal is replicated according to its weight. Depending on
-- the 'Order', the replicated list is returned as is, shuffled using the
-- given generator, and/or extended by a reversed copy of itself.
--
-- Calls 'error' if no proposal is left after filtering and replication.
prepareProposals :: StatefulGen g m => IterationMode -> Cycle a -> g -> m [Proposal a]
prepareProposals mode (Cycle proposals order _ _) gen
  | null weighted = error $ "prepareProposals: " <> emptyReason
  | otherwise = case order of
      RandomO -> shuffle weighted gen
      SequentialO -> return weighted
      RandomReversibleO -> do
        shuffled <- shuffle weighted gen
        return (mirror shuffled)
      SequentialReversibleO -> return (mirror weighted)
  where
    emptyReason :: String
    emptyReason = case mode of
      FastProposals -> "no fast proposals found"
      AllProposals -> "no proposals found"
    -- Append a reversed copy of the list to itself.
    mirror ys = ys ++ reverse ys
    -- Does the proposal take part in the given iteration mode?
    selected p = case mode of
      AllProposals -> True
      FastProposals -> prSpeed p == PFast
    -- Selected proposals, each repeated according to its weight. The binding
    -- is strict, as in the original definition.
    !weighted =
      concatMap
        (\p -> replicate (fromPWeight (prWeight p)) p)
        (filter selected proposals)
-- | Number of proposals executed per cycle.
--
-- This is the sum of the weights of the proposals taking part in the given
-- iteration mode; the sum is doubled for the reversible orders.
getNProposalsPerCycle :: IterationMode -> Cycle a -> Int
getNProposalsPerCycle mode (Cycle proposals order _ _) = passes * weightSum
  where
    -- Reversible orders run through the weighted proposals twice.
    passes :: Int
    passes = case order of
      RandomO -> 1
      SequentialO -> 1
      RandomReversibleO -> 2
      SequentialReversibleO -> 2
    counted = case mode of
      AllProposals -> proposals
      FastProposals -> [p | p <- proposals, prSpeed p == PFast]
    weightSum :: Int
    weightSum = sum [fromPWeight (prWeight p) | p <- counted]
-- | Tune a single proposal using information collected from the chain.
--
-- Arguments: the tuning type, an optional acceptance rate, an optional trace,
-- and the proposal to be tuned.
--
-- A proposal without tuner is returned unchanged. Otherwise, the tuning type,
-- the fourth field of the 'Tuner' (presumably the
-- suitable-for-intermediate-tuning flag; confirm against the definition of
-- 'Tuner'), and the proposal speed decide what happens:
--
-- - @Intermediate*@ tuning types: tune only if the flag is set; the
--   @FastProposalsOnly@ variant additionally requires 'PFast'.
--
-- - @Normal*@ and @Last*@ tuning types: tune regardless of the flag; the
--   @FastProposalsOnly@ variants additionally require 'PFast'.
--
-- - All other combinations: return the proposal unchanged.
--
-- Returns 'Left' with a message prefixed by @tuneWithChainParameters: @ if
--
-- - intermediate tuning is requested but a trace is provided, or
--
-- - normal or last tuning is requested for a tuner whose third field is
--   'True' (presumably the trace-required flag) but no trace is provided.
--
-- Errors of 'tuneWithTuningParameters' are passed through.
tuneWithChainParameters ::
  TuningType ->
  Maybe AcceptanceRate ->
  Maybe (VB.Vector a) ->
  Proposal a ->
  Either String (Proposal a)
tuneWithChainParameters tt mar mxs p = case prTuner p of
  Nothing -> Right p
  Just (Tuner t ts rt it fT _) -> case (tt, it, prSpeed p) of
    (IntermediateTuningFastProposalsOnly, True, PFast) -> tuneIntermediate
    (IntermediateTuningAllProposals, True, _) -> tuneIntermediate
    (NormalTuningFastProposalsOnly, _, PFast) -> tuneNormally
    (NormalTuningAllProposals, _, _) -> tuneNormally
    -- Only fast proposals are tuned, in line with the other
    -- @FastProposalsOnly@ tuning types. Slow proposals fall through to the
    -- catch-all case and are left untouched.
    (LastTuningFastProposalsOnly, _, PFast) -> tuneNormally
    (LastTuningAllProposals, _, _) -> tuneNormally
    _ -> Right p
    where
      hasTrace = isJust mxs
      err m = Left $ "tuneWithChainParameters: " <> m
      -- Intermediate tuning must not be given a trace.
      tuneIntermediate =
        if hasTrace
          then err "intermediate tuning but trace provided"
          else tune
      -- Normal tuning needs a trace if the tuner asks for one.
      tuneNormally =
        if rt && not hasTrace
          then err "trace required"
          else tune
      -- Compute new tuning parameters with the tuning function of the tuner,
      -- and apply them to the proposal.
      tune =
        let (t', ts') = fT tt (prDimension p) mar mxs (t, ts)
         in tuneWithTuningParameters t' ts' p
-- | Tune all proposals of a 'Cycle' using the given acceptance information
-- and, optionally, a trace.
--
-- For each proposal, the acceptance rate handed to the tuner is the fourth
-- component of the result of 'acceptanceRate' if available, and the third
-- component otherwise.
--
-- Calls 'error' if
--
-- - a trace is provided but the cycle does not require one ('ccRequireTrace'),
--
-- - the keys of the acceptance map differ from the proposals of the cycle, or
--
-- - tuning of any proposal fails.
autoTuneCycle :: TuningType -> Acceptances (Proposal a) -> Maybe (VB.Vector a) -> Cycle a -> Cycle a
autoTuneCycle tt a mxs c
  | traceSuperfluous = failWith "trace provided but not required"
  | keysMatch = c {ccProposals = [tuneOne p | p <- current]}
  | otherwise = failWith "proposals in map and cycle do not match"
  where
    failWith msg = error $ "autoTuneCycle: " <> msg
    current = ccProposals c
    traceSuperfluous = isJust mxs && not (ccRequireTrace c)
    -- Compare as sorted lists so that the order of the cycle is irrelevant.
    keysMatch = sort (M.keys (fromAcceptances a)) == sort current
    tuneOne p = case tuneWithChainParameters tt rate mxs p of
      Left msg -> error msg
      Right tuned -> tuned
      where
        (_, _, fallbackRate, preferredRate) = acceptanceRate p a
        rate = preferredRate <|> fallbackRate
-- | Summarize the proposals of a 'Cycle' as a table.
--
-- The lines of the summary are joined by newlines (no trailing newline):
--
-- 1. a title line,
--
-- 2. the number of proposals performed per iteration for the given
--    iteration mode,
--
-- 3. the description of the order of the cycle,
--
-- 4. the table header ('proposalHeader') followed by a horizontal line of
--    the same length,
--
-- 5. one line per proposal of the cycle ('summarizeProposal'), and
--
-- 6. a closing horizontal line.
--
-- Note that the table lists all proposals of the cycle, regardless of the
-- iteration mode; the mode only affects the proposal count in line 2.
summarizeCycle :: IterationMode -> Acceptances (Proposal a) -> Cycle a -> BL.ByteString
summarizeCycle m a c =
  BL.intercalate "\n" $
    [ "Summary of proposal(s) in cycle.",
      nProposalsFullStr,
      describeOrder (ccOrder c),
      proposalHeader,
      proposalHLine
    ]
      ++ [ summarizeProposal
             (prName p)
             (prDescription p)
             (prWeight p)
             (tTuningParameter <$> prTuner p)
             (prDimension p)
             (ar p)
           | p <- ps
         ]
      ++ [proposalHLine]
  where
    ps = ccProposals c
    nProposals = getNProposalsPerCycle m c
    nProposalsStr = BB.toLazyByteString $ BB.intDec nProposals
    -- Singular and plural wording; both end with "per iteration.".
    nProposalsFullStr = case nProposals of
      1 -> nProposalsStr <> " proposal is performed per iteration."
      _ -> nProposalsStr <> " proposals are performed per iteration."
    -- Acceptance counts and rates of a proposal.
    ar pr = acceptanceRate pr a
    -- Horizontal line as wide as the table header.
    proposalHLine = BL.replicate (BL.length proposalHeader) '-'