{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# OPTIONS_GHC -Wno-type-defaults #-}
module Control.Monad.Bayes.Inference.TUI where
import Brick
import Brick qualified as B
import Brick.BChan qualified as B
import Brick.Widgets.Border
import Brick.Widgets.Border.Style
import Brick.Widgets.Center
import Brick.Widgets.ProgressBar qualified as B
import Control.Arrow (Arrow (..))
import Control.Concurrent (forkIO, killThread)
import Control.Foldl qualified as Fold
import Control.Monad (void)
import Control.Monad.Bayes.Enumerator (toEmpirical)
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Sampler.Strict (SamplerIO, sampleIO)
import Control.Monad.Bayes.Traced (Traced)
import Control.Monad.Bayes.Traced.Common hiding (burnIn)
import Control.Monad.Bayes.Weighted
import Data.Scientific (FPFormat (Exponent), formatScientific, fromFloatDigits)
import Data.Text qualified as T
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.IO qualified as TL
import GHC.Float (double2Float)
import Graphics.Vty
import Graphics.Vty qualified as V
import Numeric.Log (Log (ln))
import Pipes (runEffect, (>->))
import Pipes qualified as P
import Pipes.Prelude qualified as P
import Text.Pretty.Simple (pShow, pShowNoColor)
-- | Snapshot of an MCMC run, as rendered by the UI.
data MCMCData a = MCMCData
  { -- | Number of steps taken so far.
    numSteps :: Int,
    -- | Number of steps reported as successful; used for the acceptance rate.
    numSuccesses :: Int,
    -- | Samples collected so far.
    samples :: [a],
    -- | Likelihood recorded for each sample.
    lk :: [Double],
    -- | Target number of steps; the denominator of the progress bar.
    totalSteps :: Int
  }
  deriving stock (Show)
-- | Render the progress screen for the current state of an MCMC run.
--
-- From top to bottom the screen shows:
--
-- * the number of completed steps as a fraction of 'totalSteps';
-- * the mean likelihood of the most recent 1000 samples, drawn as a fraction
--   of the largest likelihood seen so far;
-- * the counts of successful and failed steps;
-- * a warning when, after more than 1000 steps, fewer than 10% of them were
--   successful;
-- * whatever @handleSamples@ renders from the samples collected so far.
drawUI :: ([a] -> Widget n) -> MCMCData a -> [Widget n]
drawUI handleSamples state = [ui]
  where
    -- Point brick's progress-bar attributes at this module's bar colours.
    recolourBar :: Widget m -> Widget m
    recolourBar =
      updateAttrMap
        ( B.mapAttrNames
            [ (doneAttr, B.progressCompleteAttr),
              (toDoAttr, B.progressIncompleteAttr)
            ]
        )

    -- Completed-steps bar, labelled with the step count as an integer
    -- (showing a Float here printed "Step 5.0" / exponent notation).
    toBar :: Int -> Widget m
    toBar step =
      B.progressBar
        (Just $ "Step " <> show step)
        (fromIntegral step / fromIntegral (max 1 (totalSteps state)))

    completionBar = recolourBar $ toBar $ numSteps state

    -- The first 1000 entries of 'lk'.
    recentLikelihoods :: [Double]
    recentLikelihoods = take 1000 (lk state)

    -- Report 0 instead of NaN before any likelihood has arrived.
    meanLikelihood :: Double
    meanLikelihood
      | null recentLikelihoods = 0
      | otherwise = Fold.fold Fold.mean recentLikelihoods

    maxLikelihood :: Double
    maxLikelihood = maximum (0 : lk state)

    -- Ratio of the recent mean to the overall maximum. It stays 0 until
    -- there is a positive maximum to divide by, so the bar is never handed
    -- the result of 0 / 0.
    likelihoodFraction :: Float
    likelihoodFraction
      | maxLikelihood > 0 = double2Float (meanLikelihood / maxLikelihood)
      | otherwise = 0

    -- The label now shows the mean it promises rather than the latest value.
    likelihoodBar =
      recolourBar $
        B.progressBar
          (Just $ "Mean likelihood for last 1000 samples: " <> take 10 (show meanLikelihood))
          likelihoodFraction

    numFailures :: Int
    numFailures = numSteps state - numSuccesses state

    displaySuccessesAndFailures =
      withBorderStyle unicode $
        borderWithLabel (str "Successes and failures") $
          center (str (show $ numSuccesses state))
            <+> vBorder
            <+> center (str (show numFailures))

    -- Only warn once enough steps have run for the rate to mean something.
    warning =
      if numSteps state > 1000
        && (fromIntegral (numSuccesses state) / fromIntegral (numSteps state)) < (0.1 :: Double)
        then
          withAttr (attrName "highlight") $
            str "Warning: acceptance rate is rather low.\nThis probably means that your proposal isn't good."
        else str ""

    ui =
      (str "Progress: " <+> completionBar)
        <=> (str "Likelihood: " <+> likelihoodBar)
        <=> str "\n"
        <=> displaySuccessesAndFailures
        <=> warning
        <=> handleSamples (samples state)
-- | A visualizer that ignores its argument and draws nothing.
noVisual :: b -> Widget n
noVisual _ = emptyWidget
-- | Render the empirical distribution of a list of samples.
--
-- Each distinct value is paired with the weight 'toEmpirical' assigns it
-- (presumably its relative frequency — confirm against
-- "Control.Monad.Bayes.Enumerator"). The weight is formatted in exponent
-- notation with three decimal digits, and the list is laid out with
-- pretty-simple's 'pShow'.
showEmpirical :: (Show a, Ord a) => [a] -> Widget n
showEmpirical xs = txt (TL.toStrict (pShow formatted))
  where
    formatted =
      [ (value, formatScientific Exponent (Just 3) (fromFloatDigits (weight :: Double)))
        | (value, weight) <- toEmpirical xs
      ]
-- | Show the first element of the list, or nothing when the list is empty.
showVal :: Show a => [a] -> Widget n
showVal xs = txt (T.pack rendered)
  where
    rendered = case xs of
      (first : _) -> show first
      [] -> ""
-- | Event handler for the MCMC UI.
--
-- * Pressing @q@ with no modifier keys stops the UI.
-- * A custom application event carries a new state, which replaces the
--   current one.
-- * Every other event (other keys, resizes, mouse events) is ignored.
--   Previously, mouse events hit an 'error' and crashed the UI.
appEvent :: B.BrickEvent n s -> B.EventM n s ()
appEvent (B.VtyEvent (V.EvKey (V.KChar 'q') [])) = B.halt
appEvent (B.AppEvent newState) = put newState
appEvent _ = pure ()
-- | Attribute names for the filled ('doneAttr') and unfilled ('toDoAttr')
-- segments of the progress bars. Both are built under the @theBase@ name.
doneAttr, toDoAttr :: B.AttrName
doneAttr = B.attrName "theBase" <> B.attrName "done"
toDoAttr = B.attrName "theBase" <> B.attrName "remaining"
-- | Attribute map for the UI, on top of the default attribute:
--
-- * @theBase@ gets a bright-black background;
-- * the bar segments are black-on-white (done) and white-on-black (to do);
-- * @highlight@ (used for warnings) gets a yellow foreground.
theMap :: B.AttrMap
theMap = B.attrMap V.defAttr entries
  where
    entries =
      [ (B.attrName "theBase", bg V.brightBlack),
        (doneAttr, V.black `on` V.white),
        (toDoAttr, V.white `on` V.black),
        (B.attrName "highlight", fg V.yellow)
      ]
-- | Run single-site Metropolis-Hastings on a model and watch its progress in
-- a terminal UI.
--
-- * @burnIn@ is passed as 'numBurnIn' in the 'MCMCConfig' (with
--   'numMCMCSteps' set to @-1@). The resulting stream is cut off after
--   100000 steps.
-- * Sampling runs on a background thread. That thread pushes every
--   intermediate 'MCMCData' to the UI over a bounded channel of capacity 10.
-- * @visualizer@ renders the samples collected so far (newest first) at the
--   bottom of the screen.
--
-- Press @q@ to quit. When the UI exits, the sampler thread is stopped and the
-- final UI state is pretty-printed to @data/tui_output.txt@. The @data@
-- directory must already exist.
tui :: Show a => Int -> Traced (Weighted SamplerIO) a -> ([a] -> Widget ()) -> IO ()
tui burnIn distribution visualizer = do
  eventChan <- B.newBChan 10
  initialVty <- buildVty
  samplerThread <- forkIO $ run (mcmcP config distribution) eventChan n
  finalState <-
    B.customMain
      initialVty
      buildVty
      (Just eventChan)
      app
      (initialState n)
  -- If the user quit early, the sampler would otherwise linger, blocked on
  -- the full channel. Killing a thread that has already finished does
  -- nothing.
  killThread samplerThread
  TL.writeFile "data/tui_output.txt" (pShowNoColor finalState)
  where
    config = MCMCConfig {numBurnIn = burnIn, proposal = SingleSiteMH, numMCMCSteps = -1}

    app =
      B.App
        { B.appDraw = drawUI visualizer,
          B.appChooseCursor = B.showFirstCursor,
          B.appHandleEvent = appEvent,
          B.appStartEvent = return (),
          B.appAttrMap = const theMap
        }

    buildVty :: IO Vty
    buildVty = V.mkVty V.defaultConfig

    -- Number of MCMC steps to run.
    n :: Int
    n = 100000

    initialState :: Int -> MCMCData b
    initialState total =
      MCMCData {numSteps = 0, numSuccesses = 0, samples = [], lk = [], totalSteps = total}

    -- Fold each MH result into the running state and send every state to
    -- the UI. 'P.scan' yields its seed state first, so @steps + 1@ states
    -- are needed to cover steps 0 through @steps@. With only @steps@ states
    -- the progress bar stopped one step short of full.
    run :: P.Producer (MHResult b) (Weighted SamplerIO) () -> B.BChan (MCMCData b) -> Int -> IO ()
    run producer chan steps =
      runEffect $
        P.hoist (sampleIO . unweighted) producer
          >-> P.scan accumulate (initialState steps) id
          >-> P.take (steps + 1)
          >-> P.mapM_ (B.writeBChan chan)

    -- One MH step: count it and its success, and prepend the new sample and
    -- its likelihood (the trace density taken out of log space).
    accumulate :: MCMCData b -> MHResult b -> MCMCData b
    accumulate st result =
      st
        { numSteps = numSteps st + 1,
          numSuccesses = numSuccesses st + if success result then 1 else 0,
          samples = output (trace result) : samples st,
          lk = exp (ln (probDensity (trace result))) : lk st
        }