{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ImportQualifiedPost #-}
{-# OPTIONS_GHC -Wno-type-defaults #-}

module Control.Monad.Bayes.Inference.TUI where

import Brick
import Brick qualified as B
import Brick.BChan qualified as B
import Brick.Widgets.Border
import Brick.Widgets.Border.Style
import Brick.Widgets.Center
import Brick.Widgets.ProgressBar qualified as B
import Control.Arrow (Arrow (..))
import Control.Concurrent (forkIO, killThread)
import Control.Exception (bracket)
import Control.Foldl qualified as Fold
import Control.Monad (void)
import Control.Monad.Bayes.Enumerator (toEmpirical)
import Control.Monad.Bayes.Inference.MCMC
import Control.Monad.Bayes.Sampler.Strict (SamplerIO, sampleIO)
import Control.Monad.Bayes.Traced (Traced)
import Control.Monad.Bayes.Traced.Common
import Control.Monad.Bayes.Weighted
import Control.Monad.State.Class (put)
import Data.Scientific (FPFormat (Exponent), formatScientific, fromFloatDigits)
import Data.Text qualified as T
import Data.Text.Lazy qualified as TL
import Data.Text.Lazy.IO qualified as TL
import GHC.Float (double2Float)
import Graphics.Vty
import Graphics.Vty qualified as V
import Numeric.Log (Log (ln))
import Pipes (runEffect, (>->))
import Pipes qualified as P
import Pipes.Prelude qualified as P
import Text.Pretty.Simple (pShow, pShowNoColor)

-- | Running summary of an MCMC chain. One value is produced per chain state
-- and pushed to the TUI, where it becomes the displayed state.
data MCMCData a = MCMCData
  { numSteps :: Int
  -- ^ Number of chain results folded in so far.
  , numSuccesses :: Int
  -- ^ How many of those results reported 'success' (the acceptance count).
  , samples :: [a]
  -- ^ Sampled values; 'tui' prepends each new one, so newest first.
  , lk :: [Double]
  -- ^ Density of each sample's trace, in linear (not log) space; parallel
  -- to 'samples' and likewise newest first.
  , totalSteps :: Int
  -- ^ Number of states the run is capped at; drives the progress bar.
  }
  deriving stock (Show)

-- Brick is a terminal user interface (TUI) library, which we use to display
-- inference algorithms while they are running.

-- | Draw the Brick app.
--
-- The single layer shows, top to bottom:
--
-- * a progress bar for the steps taken out of 'totalSteps';
-- * a bar for the mean of the newest 1000 entries of 'lk', relative to the
--   largest value recorded in 'lk';
-- * counts of successful and failed steps;
-- * a warning when the acceptance rate is low;
-- * the caller's rendering (@handleSamples@) of 'samples'.
drawUI :: ([a] -> Widget n) -> MCMCData a -> [Widget n]
drawUI handleSamples state = [ui]
  where
    -- Point the generic progress-bar attributes at the ones defined in 'theMap'.
    withBarAttrs =
      updateAttrMap
        ( B.mapAttrNames
            [ (doneAttr, B.progressCompleteAttr),
              (toDoAttr, B.progressIncompleteAttr)
            ]
        )

    -- Fill fraction for a progress bar. A non-positive denominator (no
    -- likelihoods yet, all-zero likelihoods, or totalSteps == 0) yields an
    -- empty bar instead of feeding NaN or Infinity to 'B.progressBar'.
    fraction :: Double -> Double -> Float
    fraction num den
      | den > 0 = double2Float (num / den)
      | otherwise = 0

    -- The label shows the step count as an integer ("Step 42", not "Step 42.0").
    completionBar =
      withBarAttrs $
        B.progressBar
          (Just ("Step " <> show (numSteps state)))
          (fraction (fromIntegral (numSteps state)) (fromIntegral (totalSteps state)))

    -- 'tui' prepends each new likelihood, so this keeps the newest 1000.
    recent = take 1000 (lk state)

    recentMean :: Double
    recentMean
      | null recent = 0
      | otherwise = Fold.fold Fold.mean recent

    -- The label reports the same mean that the bar's fill is computed from.
    likelihoodBar =
      withBarAttrs $
        B.progressBar
          (Just ("Mean likelihood for last 1000 samples: " <> take 10 (show recentMean)))
          (fraction recentMean (maximum (0 : lk state)))

    numFailures :: Int
    numFailures = numSteps state - numSuccesses state

    displaySuccessesAndFailures =
      withBorderStyle unicode $
        borderWithLabel (str "Successes and failures") $
          center (str (show $ numSuccesses state))
            <+> vBorder
            <+> center (str (show numFailures))

    -- The warning is suppressed for the first 1000 steps.
    warning =
      if numSteps state > 1000
        && (fromIntegral (numSuccesses state) / fromIntegral (numSteps state)) < 0.1
        then
          withAttr (attrName "highlight") $
            str "Warning: acceptance rate is rather low.\nThis probably means that your proposal isn't good."
        else str ""

    ui =
      (str "Progress: " <+> completionBar)
        <=> (str "Likelihood: " <+> likelihoodBar)
        <=> str "\n"
        <=> displaySuccessesAndFailures
        <=> warning
        <=> handleSamples (samples state)

-- | A visualizer that ignores the samples and renders nothing.
noVisual :: b -> Widget n
noVisual _ = emptyWidget

-- | Render the samples as the @(value, weight)@ pairs produced by
-- 'toEmpirical'. Each weight is formatted in exponent notation with three
-- decimal digits, and the whole list is pretty-printed.
--
-- 'pShowNoColor' is used rather than 'pShow' because 'pShow' embeds ANSI
-- colour escape sequences. Brick's 'txt' renders text literally rather than
-- interpreting them, so the escapes would corrupt the TUI layout.
showEmpirical :: (Show a, Ord a) => [a] -> Widget n
showEmpirical =
  txt
    . TL.toStrict
    . pShowNoColor
    . fmap (second (formatScientific Exponent (Just 3) . fromFloatDigits))
    . toEmpirical

-- | Render only the head of the sample list via its 'Show' instance. When
-- given 'samples' (which 'tui' builds by prepending), this is the newest
-- sample. An empty list renders as empty text.
showVal :: Show a => [a] -> Widget n
showVal [] = txt (T.pack "")
showVal (latest : _) = txt (T.pack (show latest))

-- | Handler for events received by the TUI.
--
-- * Pressing @q@ with no modifiers halts the app.
-- * A custom event carries a complete new state, which replaces the current
--   one (when run from 'tui', each event is an 'MCMCData' snapshot).
-- * Every other event, whether another terminal event or one of the remaining
--   'B.BrickEvent' constructors such as mouse events, is ignored. Such an
--   event is harmless and must not crash the whole UI.
appEvent :: B.BrickEvent n s -> B.EventM n s ()
appEvent (B.VtyEvent (V.EvKey (V.KChar 'q') [])) = B.halt
appEvent (B.AppEvent d) = put d
appEvent _ = pure ()

-- | Attribute names for the completed ('doneAttr') and remaining ('toDoAttr')
-- portions of the progress bars. Both are nested under @theBase@.
doneAttr, toDoAttr :: B.AttrName
(doneAttr, toDoAttr) = (underBase "done", underBase "remaining")
  where
    underBase leaf = B.attrName "theBase" <> B.attrName leaf

-- | Attribute map for the app, with 'V.defAttr' as the default.
--
-- * @theBase@ (the parent of 'doneAttr' and 'toDoAttr') has a bright-black
--   background.
-- * The completed part of a bar is black on white.
-- * The remaining part of a bar is white on black.
-- * @highlight@ text, used for the acceptance-rate warning, is yellow.
theMap :: B.AttrMap
theMap = B.attrMap V.defAttr entries
  where
    entries =
      [ (B.attrName "theBase", bg V.brightBlack)
      , (doneAttr, V.black `on` V.white)
      , (toDoAttr, V.white `on` V.black)
      , (B.attrName "highlight", fg V.yellow)
      ]

-- | Run single-site Metropolis-Hastings on @distribution@ in a background
-- thread and display the chain's progress in a Brick TUI.
--
-- * @burnIn@ is passed to 'mcmcP' as 'numBurnIn'.
-- * @visualizer@ renders the samples collected so far (newest first).
--
-- At most @n@ (100000) summaries are streamed to the UI. Press @q@ to quit.
-- The sampling thread is killed as soon as the UI exits, whether normally or
-- via an exception, so quitting early does not leave it running. The final
-- UI state is then pretty-printed to @data/tui_output.txt@.
tui :: Show a => Int -> Traced (Weighted SamplerIO) a -> ([a] -> Widget ()) -> IO ()
tui burnIn distribution visualizer = do
  eventChan <- B.newBChan 10
  initialVty <- buildVty
  finalState <-
    bracket
      (forkIO (run chain eventChan n))
      killThread
      (\_ -> B.customMain initialVty buildVty (Just eventChan) app (initialState n))
  -- NOTE(review): 'TL.writeFile' does not create missing directories, so this
  -- throws if @data/@ does not exist in the working directory.
  TL.writeFile "data/tui_output.txt" (pShowNoColor finalState)
  where
    buildVty :: IO Vty
    buildVty = V.mkVty V.defaultConfig

    -- Number of summaries (including the initial one) streamed to the UI.
    n :: Int
    n = 100000

    -- numMCMCSteps = -1 presumably means "unbounded"; 'run' caps the stream
    -- with 'P.take'.
    chain =
      mcmcP
        MCMCConfig {numBurnIn = burnIn, proposal = SingleSiteMH, numMCMCSteps = -1}
        distribution

    app =
      B.App
        { B.appDraw = drawUI visualizer,
          B.appChooseCursor = B.showFirstCursor,
          B.appHandleEvent = appEvent,
          B.appStartEvent = return (),
          B.appAttrMap = const theMap
        }

    initialState total =
      MCMCData {numSteps = 0, numSuccesses = 0, samples = [], lk = [], totalSteps = total}

    -- Sample the chain in IO, fold each result into a running summary, and
    -- push every summary onto the Brick event channel, stopping after @limit@.
    run producer chan limit =
      runEffect $
        P.hoist (sampleIO . unweighted) producer
          >-> P.scan accumulate (initialState limit) id
          >-> P.take limit
          >-> P.mapM_ (B.writeBChan chan)

    -- Fold one MH result into the summary, prepending its sample and density.
    accumulate acc@(MCMCData steps accepted xs densities _) result =
      acc
        { numSteps = steps + 1,
          numSuccesses = accepted + if success result then 1 else 0,
          samples = output (trace result) : xs,
          lk = exp (ln (probDensity (trace result))) : densities
        }