{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RankNTypes #-}
module Mcmc.Proposal
(
Proposal (..),
(@~),
ProposalSimple (..),
Tuner (tParam, tFunc),
createProposal,
tune,
Order (..),
Cycle (ccProposals),
fromList,
setOrder,
getNCycles,
tuneCycle,
autotuneCycle,
summarizeCycle,
Acceptance (fromAcceptance),
emptyA,
pushA,
resetA,
transformKeysA,
acceptanceRatios,
)
where
import Data.Aeson
import Data.Default
import Data.Function
import Data.List
import qualified Data.Map.Strict as M
import Data.Map.Strict (Map)
import Data.Maybe
import qualified Data.Text.Lazy as T
import Data.Text.Lazy (Text)
import qualified Data.Text.Lazy.Builder as B
import qualified Data.Text.Lazy.Builder.Int as B
import qualified Data.Text.Lazy.Builder.RealFloat as B
import Lens.Micro
import Mcmc.Tools.Shuffle
import Numeric.Log hiding (sum)
import System.Random.MWC
-- | A named, weighted proposal on states of type @a@, optionally tunable.
data Proposal a = Proposal
  { -- | Name of the proposal. Equality and ordering of proposals are
    -- defined on this field alone.
    pName :: String,
    -- | Weight: the number of times the proposal is executed per iteration
    -- of a cycle.
    pWeight :: Int,
    -- | The simple proposal currently in use.
    pSimple :: ProposalSimple a,
    -- | Tuner used to rebuild 'pSimple'; 'Nothing' if the proposal is not
    -- tunable.
    pTuner :: Maybe (Tuner a)
  }
-- | Proposals are shown as their (quoted) name.
instance Show (Proposal a) where
  show = show . pName

-- | Two proposals are equal if and only if their names are equal.
instance Eq (Proposal a) where
  (==) = (==) `on` pName

-- | Proposals are ordered by name.
instance Ord (Proposal a) where
  compare m n = pName m `compare` pName n
-- | Lift a proposal on a part @a@ to a proposal on a whole @b@ via a lens.
--
-- Name and weight are kept. The simple proposal and the tuner (if any) are
-- converted with the same lens.
convertP :: Lens' b a -> Proposal a -> Proposal b
convertP l (Proposal name weight simple tuner) =
  Proposal
    { pName = name,
      pWeight = weight,
      pSimple = convertS l simple,
      pTuner = fmap (convertT l) tuner
    }
-- | Operator form of 'convertP': @l \@~ p@ lifts proposal @p@ on the part
-- focused by lens @l@ to a proposal on the whole state.
(@~) :: Lens' b a -> Proposal a -> Proposal b
l @~ p = convertP l p
-- | A bare proposal function without name, weight, or tuner.
newtype ProposalSimple a = ProposalSimple
  { -- | Given the current value and a random generator, produce a proposed
    -- value together with a 'Log Double' factor. NOTE(review): presumably
    -- this factor is the Metropolis-Hastings proposal ratio; confirm with
    -- the caller.
    pSample :: a -> GenIO -> IO (a, Log Double)
  }
-- | Lift a simple proposal on a part @a@ to one on a whole @b@.
--
-- The part is read through the lens and sampled. The proposed part is then
-- written back into the whole. The returned factor is passed through unchanged.
convertS :: Lens' b a -> ProposalSimple a -> ProposalSimple b
convertS l (ProposalSimple sample) = ProposalSimple sampleWhole
  where
    sampleWhole whole gen =
      sample (whole ^. l) gen >>= \(part', factor) ->
        pure (set l part' whole, factor)
-- | Tuning state of a proposal.
data Tuner a = Tuner
  { -- | Current tuning parameter.
    tParam :: Double,
    -- | Builds the simple proposal corresponding to a tuning parameter.
    tFunc :: Double -> ProposalSimple a
  }
-- | Lift a tuner on a part @a@ to a tuner on a whole @b@.
--
-- The tuning parameter is kept. Every simple proposal the tuner builds is
-- converted with the lens.
convertT :: Lens' b a -> Tuner a -> Tuner b
convertT l (Tuner p f) = Tuner p (convertS l . f)
-- | Create a proposal.
--
-- The initial simple proposal is built with tuning parameter @1.0@. If the
-- flag is 'True', a 'Tuner' starting at @1.0@ is attached; otherwise the
-- proposal cannot be tuned.
createProposal ::
  -- | Name (must be unique within a cycle).
  String ->
  -- | Weight.
  Int ->
  -- | Builds the simple proposal from a tuning parameter.
  (Double -> ProposalSimple a) ->
  -- | Whether the proposal is tunable.
  Bool ->
  Proposal a
createProposal name weight mkSimple tunable
  | tunable = withTuner (Just $ Tuner 1.0 mkSimple)
  | otherwise = withTuner Nothing
  where
    withTuner = Proposal name weight (mkSimple 1.0)
-- | Lower bound on tuning parameters; 'tune' never goes below this value.
tuningParamMin :: Double
tuningParamMin = 1.0e-12
-- | Multiply the tuning parameter of a proposal by @dt@.
--
-- The new tuning parameter is bounded below by 'tuningParamMin'. The simple
-- proposal is rebuilt from it with the tuner's function. Returns 'Nothing'
-- if the proposal has no tuner.
--
-- Calls 'error' if @dt <= 0@. This check runs before the tuner is inspected.
-- NOTE(review): a NaN @dt@ is not rejected by this check.
tune :: Double -> Proposal a -> Maybe (Proposal a)
tune dt m
  | dt <= 0 = error $ "tune: Tuning parameter not positive: " <> show dt <> "."
  | otherwise = case pTuner m of
      Nothing -> Nothing
      Just (Tuner t f) ->
        let t' = max tuningParamMin (t * dt)
         in Just (m {pSimple = f t', pTuner = Just (Tuner t' f)})
-- | Target acceptance ratio used by 'autotuneCycle'.
--
-- NOTE(review): 0.44 is the value commonly cited as optimal for
-- one-dimensional random-walk proposals; confirm this is the intended source.
ratioOpt :: Double
ratioOpt = 0.44
-- | Order in which the (weighted) proposals of a cycle are executed within
-- one iteration.
data Order
  = -- | Proposals are shuffled anew for every iteration (via 'shuffleN').
    RandomO
  | -- | Proposals are executed in the order given.
    SequentialO
  | -- | A shuffled order, followed by that same order reversed.
    RandomReversibleO
  | -- | The given order, followed by the reversed order.
    SequentialReversibleO
  deriving (Eq, Show)

-- | The default order is 'RandomO'.
instance Default Order where
  def = RandomO
-- | A collection of proposals together with the order of their execution.
--
-- Construct with 'fromList', which ensures the list is non-empty and the
-- proposal names are unique.
data Cycle a = Cycle
  { -- | The proposals of the cycle.
    ccProposals :: [Proposal a],
    -- | Execution order of the proposals.
    ccOrder :: Order
  }
-- | Create a cycle with the default 'Order' from a list of proposals.
--
-- Calls 'error' if the list is empty or if two proposals share a name.
fromList :: [Proposal a] -> Cycle a
fromList [] =
  error "fromList: Received an empty list but cannot create an empty Cycle."
fromList xs
  | namesUnique = Cycle xs def
  | otherwise = error "fromList: Proposals don't have unique names."
  where
    names = pName <$> xs
    namesUnique = length names == length (nub names)
-- | Replace the execution order of a cycle; the proposals are kept.
setOrder :: Order -> Cycle a -> Cycle a
setOrder o (Cycle ps _) = Cycle ps o
-- | Produce the proposal sequences for @n@ iterations of a cycle.
--
-- Each proposal appears 'pWeight' times in the base list. The 'Order' of the
-- cycle determines whether that list is used as given or shuffled per
-- iteration, and whether it is followed by its reverse.
getNCycles :: Cycle a -> Int -> GenIO -> IO [[Proposal a]]
getNCycles (Cycle xs o) n g =
  case o of
    SequentialO -> pure $ replicate n ps
    SequentialReversibleO -> pure $ replicate n (ps ++ reverse ps)
    RandomO -> shuffleN ps n g
    RandomReversibleO -> map (\p -> p ++ reverse p) <$> shuffleN ps n g
  where
    -- Weighted list of proposals; forced before the order is inspected.
    !ps = concatMap (\m -> replicate (pWeight m) m) xs
-- | Tune the proposals of a cycle.
--
-- Each value of the map is a multiplicative factor passed to 'tune' for the
-- corresponding proposal. The keys of the map must be exactly the
-- proposals of the cycle, compared by name. Otherwise this calls 'error';
-- a missing key and an extra key are both rejected.
-- Proposals without a tuner are kept unchanged.
tuneCycle :: Map (Proposal a) Double -> Cycle a -> Cycle a
tuneCycle m c
  | sort (M.keys m) == sort ps = c {ccProposals = map tuneF ps}
  | otherwise =
      error "tuneCycle: Proposals in the map do not match the proposals in the cycle."
  where
    ps = ccProposals c
    -- The 'Nothing' case cannot occur after the check above; it is kept as
    -- a harmless fallback that leaves the proposal untouched.
    tuneF p = maybe p (\x -> fromMaybe p (tune x p)) (m M.!? p)
-- | Tune all proposals of a cycle from their observed acceptance counts.
--
-- A proposal with acceptance ratio @x@ has its tuning parameter multiplied by
-- @exp (x - ratioOpt)@. The parameter grows when the proposal accepts more
-- often than 'ratioOpt' and shrinks otherwise.
--
-- Proposals with no recorded trials (zero accepted and zero rejected) get a
-- factor of 1 and keep their tuning parameter. Their ratio would be
-- 0/0 = NaN, which is not rejected by 'tune' and would silently corrupt the
-- tuning parameter.
autotuneCycle :: Acceptance (Proposal a) -> Cycle a -> Cycle a
autotuneCycle a = tuneCycle (M.map factor $ fromAcceptance a)
  where
    factor :: (Int, Int) -> Double
    factor (0, 0) = 1.0
    factor (as, rs) =
      exp $ fromIntegral as / fromIntegral (as + rs) - ratioOpt
-- | Render one row of the proposal summary table.
--
-- The row starts with a single space. Then come the name, left-justified to
-- 30 characters, and the weight (8), accepted (15), rejected (15), ratio (15)
-- and tuning-parameter (20) columns, each right-justified. Cells longer than
-- their width are not truncated.
renderRow :: Text -> Text -> Text -> Text -> Text -> Text -> Text
renderRow name weight nAccept nReject acceptRatio tuneParam =
  T.concat
    [ T.singleton ' ',
      T.justifyLeft 30 ' ' name,
      T.justifyRight 8 ' ' weight,
      T.justifyRight 15 ' ' nAccept,
      T.justifyRight 15 ' ' nReject,
      T.justifyRight 15 ' ' acceptRatio,
      T.justifyRight 20 ' ' tuneParam
    ]
-- | Header row of the proposal summary table, laid out by 'renderRow'.
proposalHeader :: Text
proposalHeader = renderRow name weight accepted rejected ratio tuning
  where
    name = "Proposal"
    weight = "Weight"
    accepted = "Accepted"
    rejected = "Rejected"
    ratio = "Ratio"
    tuning = "Tuning parameter"
-- | Render the summary row of a single proposal.
--
-- The optional triple holds the accepted count, the rejected count and the
-- acceptance ratio. When it is 'Nothing', those three cells are empty. The
-- ratio and the tuning parameter are printed with three decimals. The tuning
-- cell is empty for proposals without a tuner.
summarizeProposal :: Proposal a -> Maybe (Int, Int, Double) -> Text
summarizeProposal m r =
  renderRow (T.pack $ pName m) weightCell acceptCell rejectCell ratioCell tuningCell
  where
    render = B.toLazyText
    fixed3 = B.formatRealFloat B.Fixed (Just 3)
    weightCell = render $ B.decimal $ pWeight m
    (acceptCell, rejectCell, ratioCell) = case r of
      Nothing -> ("", "", "")
      Just (na, nr, ra) ->
        (render $ B.decimal na, render $ B.decimal nr, render $ fixed3 ra)
    tuningCell = maybe "" (render . fixed3 . tParam) (pTuner m)
-- | Render a table summarizing all proposals of a cycle.
--
-- The table is headed by the total number of proposals executed per
-- iteration (the sum of the weights). Each proposal row shows its
-- acceptance statistics as reported by 'acceptanceRatio'.
summarizeCycle :: Acceptance (Proposal a) -> Cycle a -> Text
summarizeCycle a c = T.intercalate "\n" (preamble ++ rows ++ [rule])
  where
    ps = ccProposals c
    perIteration = B.toLazyText . B.decimal . sum $ map pWeight ps
    rule = " " <> T.replicate (T.length proposalHeader - 3) "─"
    preamble =
      [ "Summary of proposal(s) in cycle. " <> perIteration <> " proposal(s) per iteration.",
        proposalHeader,
        rule
      ]
    rows = map (\m -> summarizeProposal m (acceptanceRatio m a)) ps
-- | For each key, the number of accepted and rejected proposals, in that
-- order.
newtype Acceptance k = Acceptance
  { fromAcceptance :: Map k (Int, Int)
  }

-- | Serialized as the underlying map.
instance ToJSONKey k => ToJSON (Acceptance k) where
  toJSON = toJSON . fromAcceptance
  toEncoding = toEncoding . fromAcceptance

-- | Parsed from the underlying map.
instance (Ord k, FromJSONKey k) => FromJSON (Acceptance k) where
  parseJSON = fmap Acceptance . parseJSON
-- | Acceptance counts with zero accepted and zero rejected for every given
-- key.
emptyA :: Ord k => [k] -> Acceptance k
emptyA ks = Acceptance . M.fromList $ zip ks (repeat (0, 0))
-- | Record the outcome of one proposal for key @k@.
--
-- 'True' increments the accepted count and 'False' the rejected count. The
-- map is left unchanged when @k@ is not one of its keys ('M.adjust').
pushA :: (Ord k, Show k) => k -> Bool -> Acceptance k -> Acceptance k
pushA k True = Acceptance . M.adjust accept k . fromAcceptance
  where
    -- The strict map only evaluates the stored pair to WHNF. Force the
    -- counter itself so repeated pushes do not build a chain of 'succ' thunks.
    accept (a, r) = let !a' = succ a in (a', r)
pushA k False = Acceptance . M.adjust reject k . fromAcceptance
  where
    reject (a, r) = let !r' = succ r in (a, r')
{-# INLINEABLE pushA #-}
-- | Reset all counts to zero, keeping the same keys.
resetA :: Ord k => Acceptance k -> Acceptance k
resetA (Acceptance m) = emptyA (M.keys m)
-- | Re-key a map: the value stored at the i-th key of @ks1@ is stored at the
-- i-th key of @ks2@.
--
-- Keys are paired positionally; surplus keys of the longer list are ignored.
-- If a target key repeats, the last pairing wins. Calls 'error' (via 'M.!')
-- if a key of @ks1@ is not in the map.
transformKeys :: (Ord k1, Ord k2) => [k1] -> [k2] -> Map k1 v -> Map k2 v
transformKeys ks1 ks2 m =
  M.fromList [(k2, m M.! k1) | (k1, k2) <- zip ks1 ks2]
-- | Re-key acceptance counts positionally; see 'transformKeys'.
transformKeysA :: (Ord k1, Ord k2) => [k1] -> [k2] -> Acceptance k1 -> Acceptance k2
transformKeysA ks1 ks2 (Acceptance m) = Acceptance (transformKeys ks1 ks2 m)
-- | Accepted count, rejected count and acceptance ratio for a key.
--
-- Returns 'Nothing' if no proposal was recorded for the key yet. Calls
-- 'error' if the key is unknown.
acceptanceRatio :: (Show k, Ord k) => k -> Acceptance k -> Maybe (Int, Int, Double)
acceptanceRatio k (Acceptance m) =
  maybe notFound summarize (M.lookup k m)
  where
    notFound = error $ "acceptanceRatio: Key not found in map: " ++ show k ++ "."
    summarize (0, 0) = Nothing
    summarize (as, rs) = Just (as, rs, fromIntegral as / fromIntegral (as + rs))
-- | Acceptance ratio (accepted / total) for every key.
--
-- A key without any recorded proposals yields NaN (0/0). Use
-- 'acceptanceRatio' to get 'Nothing' in that case instead.
acceptanceRatios :: Acceptance k -> Map k Double
acceptanceRatios (Acceptance m) = M.map ratio m
  where
    ratio (as, rs) = fromIntegral as / fromIntegral (as + rs)