{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Mcmc.Proposal
(
Proposal (..),
(@~),
ProposalSimple (..),
Tuner (tParam, tFunc),
createProposal,
tune,
Order (..),
Cycle (ccProposals),
fromList,
setOrder,
getNIterations,
tuneCycle,
autotuneCycle,
summarizeCycle,
Acceptance (fromAcceptance),
emptyA,
pushA,
resetA,
transformKeysA,
acceptanceRatios,
)
where
import Data.Aeson
import qualified Data.ByteString.Builder as BB
import qualified Data.ByteString.Lazy.Char8 as BL
import Data.Default
import qualified Data.Double.Conversion.ByteString as BC
import Data.Function
import Data.List
import Data.Map.Strict (Map)
import qualified Data.Map.Strict as M
import Data.Maybe
import Lens.Micro
import Mcmc.Internal.ByteString
import Mcmc.Tools.Shuffle
import Numeric.Log hiding (sum)
import System.Random.MWC
-- | A named, weighted proposal (move) acting on states of type @a@.
--
-- Proposals are identified by their name: the 'Eq' and 'Ord' instances
-- compare 'pName' only, and 'fromList' rejects cycles with duplicate names.
data Proposal a = Proposal
  { -- | Name of the proposal; used as its identity.
    pName :: String,
    -- | Weight; 'getNIterations' replicates the proposal this many times
    -- within one iteration of a 'Cycle'.
    pWeight :: Int,
    -- | Sampler for the current tuning parameter.
    pSimple :: ProposalSimple a,
    -- | Tuning information, or 'Nothing' if the proposal cannot be tuned.
    pTuner :: Maybe (Tuner a)
  }
-- | Shows only the name of the proposal (rendered with 'show' on the
-- 'String').
instance Show (Proposal a) where
  show = show . pName

-- | Two proposals are equal if and only if their names are equal.
instance Eq (Proposal a) where
  (==) = (==) `on` pName

-- | Proposals are ordered by their names.
instance Ord (Proposal a) where
  compare p q = pName p `compare` pName q
-- | Lift a proposal acting on a part @a@ of the state to a proposal acting on
-- the whole state @b@, using a lens from @b@ to @a@.
--
-- Name and weight are kept; the sampler and the optional tuner are converted
-- with 'convertS' and 'convertT', respectively.
(@~) :: Lens' b a -> Proposal a -> Proposal b
l @~ Proposal {pName = nm, pWeight = wt, pSimple = smpl, pTuner = mTuner} =
  Proposal
    { pName = nm,
      pWeight = wt,
      pSimple = convertS l smpl,
      pTuner = fmap (convertT l) mTuner
    }
-- | A bare proposal without name, weight or tuning information.
newtype ProposalSimple a = ProposalSimple
  { -- | Given the current state and a random number generator, sample a new
    -- state. The second component is a 'Log' 'Double' returned alongside the
    -- new state (NOTE(review): presumably the proposal density ratio used in
    -- the acceptance probability; confirm against the sampler).
    pSample :: a -> GenIO -> IO (a, Log Double)
  }
-- | Convert a sampler acting on a part @a@ of the state to a sampler acting
-- on the whole state @b@.
--
-- The part is read through the lens, the inner sampler is run on it, and the
-- sampled part is written back; the 'Log' 'Double' value is passed through
-- unchanged.
convertS :: Lens' b a -> ProposalSimple a -> ProposalSimple b
convertS l (ProposalSimple sampleInner) =
  ProposalSimple $ \whole gen -> do
    (part', q) <- sampleInner (whole ^. l) gen
    pure (set l part' whole, q)
-- | Tuning information of a proposal.
data Tuner a = Tuner
  { -- | Current value of the tuning parameter.
    tParam :: Double,
    -- | Build the sampler for a given tuning parameter.
    tFunc :: Double -> ProposalSimple a
  }
-- | Convert a tuner acting on a part @a@ of the state to a tuner acting on
-- the whole state @b@. The tuning parameter is kept; every sampler built by
-- the tuner is converted with 'convertS'.
convertT :: Lens' b a -> Tuner a -> Tuner b
convertT l (Tuner param mkSampler) = Tuner param (convertS l . mkSampler)
-- | Create a proposal.
--
-- The sampler is built with tuning parameter 1.0. If tuning is activated, the
-- proposal also gets a 'Tuner' starting at 1.0 so that it can later be tuned
-- with 'tune'; otherwise 'pTuner' is 'Nothing'.
createProposal ::
  -- | Name.
  String ->
  -- | Weight.
  Int ->
  -- | Build the sampler for a given tuning parameter.
  (Double -> ProposalSimple a) ->
  -- | Activate tuning?
  Bool ->
  Proposal a
createProposal name weight mkSampler tunable
  | tunable = withTuner (Just (Tuner initial mkSampler))
  | otherwise = withTuner Nothing
  where
    initial :: Double
    initial = 1.0
    withTuner = Proposal name weight (mkSampler initial)
-- | Lower bound for tuning parameters; 'tune' clamps smaller values to it.
tuningParamMin :: Double
tuningParamMin = 1.0e-12
-- | Tune a proposal by multiplying its tuning parameter with a factor @dt@.
--
-- The new tuning parameter is @max tuningParamMin (t * dt)@, where @t@ is the
-- current one, and the sampler is rebuilt from it. Returns 'Nothing' if the
-- proposal has no tuner. Calls 'error' if @dt <= 0@ (NOTE(review): a NaN
-- factor is not rejected by this check).
tune :: Double -> Proposal a -> Maybe (Proposal a)
tune dt m
  | dt <= 0 = error $ "tune: Tuning parameter not positive: " <> show dt <> "."
  | otherwise = case pTuner m of
      Nothing -> Nothing
      Just (Tuner current build) ->
        let next = max tuningParamMin (current * dt)
         in Just (m {pSimple = build next, pTuner = Just (Tuner next build)})
-- | Target acceptance ratio used by 'autotuneCycle'.
ratioOpt :: Double
ratioOpt = 4.4e-1

-- | Acceptance ratios below this value are flagged as "ratio too low" in the
-- proposal summary.
ratioMin :: Double
ratioMin = 1.0e-1

-- | Acceptance ratios above this value are flagged as "ratio too high" in the
-- proposal summary.
ratioMax :: Double
ratioMax = 9.0e-1
-- | Order in which the proposals of a 'Cycle' are executed within one
-- iteration (see 'getNIterations').
data Order
  = -- | Random order (the weighted list is passed through 'shuffleN').
    RandomO
  | -- | The order of the proposals in the cycle.
    SequentialO
  | -- | Random order, followed by the same list reversed.
    RandomReversibleO
  | -- | Cycle order, followed by the same list reversed.
    SequentialReversibleO
  deriving (Eq, Show)
-- | The default order is 'RandomO'.
instance Default Order where
  def = RandomO
-- | A list of proposals together with the order in which they are executed.
-- Build cycles with 'fromList' and change the order with 'setOrder'.
data Cycle a = Cycle
  { -- | The proposals of the cycle.
    ccProposals :: [Proposal a],
    -- | The execution order.
    ccOrder :: Order
  }
-- | Create a cycle from a list of proposals, using the default 'Order'.
--
-- Calls 'error' if the list is empty or if two proposals share a name.
fromList :: [Proposal a] -> Cycle a
fromList [] =
  error "fromList: Received an empty list but cannot create an empty Cycle."
fromList xs
  | namesUnique = Cycle xs def
  | otherwise = error "fromList: Proposals don't have unique names."
  where
    -- After sorting, duplicate names end up in the same group.
    namesUnique = all isSingleton $ group $ sort $ map pName xs
    isSingleton [_] = True
    isSingleton _ = False
-- | Replace the order of a cycle; the proposals are kept.
setOrder :: Order -> Cycle a -> Cycle a
setOrder newOrder (Cycle proposals _) = Cycle proposals newOrder
-- | Create the lists of proposals to execute in each of the next @n@
-- iterations.
--
-- Each proposal is first replicated according to its weight (a non-positive
-- weight yields no copies). The weighted list is then arranged per iteration
-- according to the order of the cycle; the reversible orders append the
-- reversed list, so their iterations are twice as long.
getNIterations :: Cycle a -> Int -> GenIO -> IO [[Proposal a]]
getNIterations (Cycle proposals order) n gen =
  let !weighted = concatMap (\p -> replicate (pWeight p) p) proposals
      andBack xs = xs ++ reverse xs
   in case order of
        RandomO -> shuffleN weighted n gen
        SequentialO -> pure (replicate n weighted)
        RandomReversibleO -> map andBack <$> shuffleN weighted n gen
        SequentialReversibleO -> pure (replicate n (andBack weighted))
-- | Tune the proposals of a cycle.
--
-- Every proposal that is a key of the map is tuned with 'tune', using the
-- associated value as the multiplicative factor. Proposals without an entry in
-- the map, and proposals that are not tunable, are left unchanged. Proposals
-- are matched by name (see the 'Eq' and 'Ord' instances of 'Proposal').
--
-- Calls 'error' if the map contains proposals that are not in the cycle.
tuneCycle :: Map (Proposal a) Double -> Cycle a -> Cycle a
tuneCycle m c
  -- Only reject keys that are unknown to the cycle; a map covering a subset
  -- of the proposals is valid (see the 'Nothing' case of 'tuneF').
  | all (`elem` ps) (M.keys m) = c {ccProposals = map tuneF ps}
  | otherwise = error "tuneCycle: Map contains proposals that are not in the cycle."
  where
    ps = ccProposals c
    tuneF p = case m M.!? p of
      Nothing -> p
      Just x -> fromMaybe p (tune x p)
-- | Automatically tune the proposals of a cycle based on their acceptance
-- ratios.
--
-- A proposal with acceptance ratio @r@ is tuned with the factor
-- @exp (r - ratioOpt)@. The tuning parameter therefore grows if the ratio is
-- above 'ratioOpt' and shrinks if it is below.
--
-- Proposals without any recorded acceptance or rejection have ratio @0 / 0@
-- (NaN). They are tuned with factor 1, which keeps their tuning parameter.
autotuneCycle :: Acceptance (Proposal a) -> Cycle a -> Cycle a
autotuneCycle a = tuneCycle (M.map factor $ acceptanceRatios a)
  where
    factor r
      -- No data for this proposal: do not change it. A NaN factor would slip
      -- through the positivity check of 'tune' and corrupt the tuning
      -- parameter.
      | isNaN r = 1
      | otherwise = exp (r - ratioOpt)
-- | Render one row of the proposal summary table: a leading indentation
-- followed by seven columns. The first column goes through 'alignLeft' and
-- the others through 'alignRight', with the widths given below.
renderRow ::
  -- | Proposal name.
  BL.ByteString ->
  -- | Weight.
  BL.ByteString ->
  -- | Number of accepted proposals.
  BL.ByteString ->
  -- | Number of rejected proposals.
  BL.ByteString ->
  -- | Acceptance ratio.
  BL.ByteString ->
  -- | Tuning parameter.
  BL.ByteString ->
  -- | Manual adjustment hint.
  BL.ByteString ->
  BL.ByteString
renderRow name weight nAccept nReject acceptRatio tuneParam manualAdjustment =
  BL.concat
    [ " ",
      alignLeft 30 name,
      alignRight 8 weight,
      alignRight 15 nAccept,
      alignRight 15 nReject,
      alignRight 15 acceptRatio,
      alignRight 20 tuneParam,
      alignRight 30 manualAdjustment
    ]
-- | Header row of the proposal summary table.
proposalHeader :: BL.ByteString
proposalHeader =
  renderRow
    "Proposal"
    "Weight"
    "Accepted"
    "Rejected"
    "Ratio"
    "Tuning parameter"
    "Consider manual adjustment"
-- | Render the summary table row of a proposal.
--
-- The optional triple holds the numbers of accepted and rejected proposals
-- and the acceptance ratio (as returned by 'acceptanceRatio'). If it is
-- 'Nothing', these columns are left empty. The tuning parameter column is
-- empty for proposals without tuner. The last column flags ratios below
-- 'ratioMin' or above 'ratioMax'.
summarizeProposal :: Proposal a -> Maybe (Int, Int, Double) -> BL.ByteString
summarizeProposal m r =
  renderRow
    (BL.pack (pName m))
    (intToBL (pWeight m))
    nAccept
    nReject
    acceptRatio
    tuneParamStr
    manualAdjustmentStr
  where
    intToBL = BB.toLazyByteString . BB.intDec
    (nAccept, nReject, acceptRatio, manualAdjustmentStr) = case r of
      Nothing -> ("", "", "", "")
      Just (nA, nR, ratio) ->
        ( intToBL nA,
          intToBL nR,
          BL.fromStrict (BC.toFixed 3 ratio),
          BL.fromStrict (check ratio)
        )
    tuneParamStr = maybe "" (BL.fromStrict . BC.toFixed 3 . tParam) (pTuner m)
    check v
      | v < ratioMin = "ratio too low"
      | v > ratioMax = "ratio too high"
      | otherwise = ""
-- | Horizontal rule for a table row.
--
-- The row's leading spaces are kept and every remaining character is
-- replaced by a dash, so the rule is exactly as long as the row.
hLine :: BL.ByteString -> BL.ByteString
hLine s = indent <> BL.map (const '-') rest
  where
    -- Derive the indentation from the row instead of hard-coding its width.
    -- A fixed width gets out of sync with the prefix written by 'renderRow'
    -- and yields a rule of the wrong length.
    (indent, rest) = BL.span (== ' ') s
-- | Render a summary table of the proposals of a cycle, including their
-- acceptance counts, acceptance ratios and tuning parameters.
--
-- The first line reports how many proposals are executed per iteration. This
-- matches the length of the per-iteration lists built by 'getNIterations':
--
-- * the sum of the weights, where non-positive weights contribute nothing;
-- * doubled for the reversible orders, which append the reversed list.
summarizeCycle :: Acceptance (Proposal a) -> Cycle a -> BL.ByteString
summarizeCycle a c =
  BL.intercalate "\n" $
    [ "Summary of proposal(s) in cycle. " <> mpi <> " proposal(s) per iteration.",
      proposalHeader,
      hLine proposalHeader
    ]
      ++ [summarizeProposal m (ar m) | m <- ps]
      ++ [hLine proposalHeader]
  where
    ps = ccProposals c
    -- 'replicate' with a non-positive count yields no copies.
    perPass = sum [max 0 (pWeight p) | p <- ps]
    perIteration = case ccOrder c of
      RandomReversibleO -> 2 * perPass
      SequentialReversibleO -> 2 * perPass
      _ -> perPass
    mpi = BB.toLazyByteString $ BB.intDec perIteration
    ar m = acceptanceRatio m a
-- | For each key, the numbers of accepted and rejected proposals.
newtype Acceptance k = Acceptance
  { -- | Map from keys to (accepted, rejected) counts; 'pushA' increments the
    -- first component on acceptance and the second on rejection.
    fromAcceptance :: Map k (Int, Int)
  }
-- | Serialized exactly like the underlying map of counts.
instance ToJSONKey k => ToJSON (Acceptance k) where
  toJSON = toJSON . fromAcceptance
  toEncoding = toEncoding . fromAcceptance
-- | Parsed exactly like the underlying map of counts.
instance (Ord k, FromJSONKey k) => FromJSON (Acceptance k) where
  parseJSON = fmap Acceptance . parseJSON
-- | Acceptance counts for the given keys, all set to zero.
emptyA :: Ord k => [k] -> Acceptance k
emptyA ks = Acceptance $ M.fromList $ zip ks (repeat (0, 0))
-- | Record one accepted ('True') or rejected ('False') proposal for a key.
--
-- Keys that are not present are left as they are ('M.adjust' does not insert
-- missing keys).
pushA :: (Ord k, Show k) => k -> Bool -> Acceptance k -> Acceptance k
pushA k accepted = case accepted of
  True -> onCounts (\(a, r) -> (succ a, r))
  False -> onCounts (\(a, r) -> (a, succ r))
  where
    onCounts f = Acceptance . M.adjust f k . fromAcceptance
{-# INLINEABLE pushA #-}
-- | Reset all counts to zero, keeping the keys.
resetA :: Ord k => Acceptance k -> Acceptance k
resetA (Acceptance counts) = emptyA (M.keys counts)
-- | Re-key a map: for each pair @(k1, k2)@ of @zip ks1 ks2@, the value stored
-- under the old key @k1@ is stored under the new key @k2@.
--
-- Only the new keys from these pairs appear in the result. If a new key
-- occurs more than once, the value of its last occurrence wins. Calls 'error'
-- (via 'M.!') if an old key is not in the map.
transformKeys :: (Ord k1, Ord k2) => [k1] -> [k2] -> Map k1 v -> Map k2 v
transformKeys ks1 ks2 m = M.fromList [(k2, m M.! k1) | (k1, k2) <- zip ks1 ks2]
-- | Re-key acceptance counts: the counts of each old key in the first list
-- are moved to the new key at the same position in the second list (see
-- 'transformKeys').
transformKeysA :: (Ord k1, Ord k2) => [k1] -> [k2] -> Acceptance k1 -> Acceptance k2
transformKeysA ks1 ks2 (Acceptance counts) = Acceptance (transformKeys ks1 ks2 counts)
-- | Numbers of accepted and rejected proposals, and the acceptance ratio, for
-- one key.
--
-- Returns 'Nothing' if no proposal has been recorded for the key yet. Calls
-- 'error' if the key is missing.
acceptanceRatio :: (Show k, Ord k) => k -> Acceptance k -> Maybe (Int, Int, Double)
acceptanceRatio k a = maybe keyMissing summarize (M.lookup k (fromAcceptance a))
  where
    keyMissing = error $ "acceptanceRatio: Key not found in map: " ++ show k ++ "."
    summarize (0, 0) = Nothing
    summarize (nA, nR) = Just (nA, nR, fromIntegral nA / fromIntegral (nA + nR))
-- | Acceptance ratio @accepted / (accepted + rejected)@ for every key.
--
-- Keys without any recorded proposal yield @0 / 0@, i.e., NaN. Use
-- 'acceptanceRatio' to handle this case explicitly.
acceptanceRatios :: Acceptance k -> Map k Double
acceptanceRatios (Acceptance counts) = M.map ratio counts
  where
    ratio (nA, nR) = fromIntegral nA / fromIntegral (nA + nR)