mcmc-samplers-0.1.1.1: Combinators for MCMC sampling

Safe Haskell: None
Language: Haskell2010

MCMC.Examples.GMM

Contents

Description

Sampler for a two-component Gaussian mixture model.

Here is the code in the Hakaru language for generating the data used in this example:

p <- unconditioned (beta 2 2)
[m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
[s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
let makePoint = do        
      b <- unconditioned (bern p)
      unconditioned (ifThenElse b (normal m1 s1)
                                  (normal m2 s2))
replicateM nPoints makePoint

Synopsis

Documentation

data GaussianMixtureState

Constructors

GMM 

Fields

labels :: [Bool]

The list of observation labels

gaussParams :: ((Double, Double), (Double, Double))

The parameters of the two Gaussians, each given as (mean, standard deviation); the likelihood squares the second component before passing it to normal.

bernParam :: Double

The mixture proportion

obs :: [Double]

The observed data

Target

Focus combinators

-- | Lift a target over @(mixture proportion, labels)@ to a target over the
-- whole sampler state, reading 'bernParam' and 'labels' from the state.
focusLabels :: Target (Double, [Bool]) -> Target GaussianMixtureState
focusLabels t = makeTarget $ \s -> density t (bernParam s, labels s)

-- | Lift a target over the two Gaussians' parameters to a target over the
-- whole sampler state by projecting out 'gaussParams'.
focusGaussParams :: Target ((Double, Double), (Double, Double)) -> Target GaussianMixtureState
focusGaussParams t = makeTarget $ \s -> density t (gaussParams s)

-- | Lift a target over the mixture proportion to a target over the whole
-- sampler state by projecting out 'bernParam'.
focusBernParam :: Target Double -> Target GaussianMixtureState
focusBernParam t = makeTarget $ \s -> density t (bernParam s)

-- | Lift a target over @(labels, Gaussian parameters, observations)@ to a
-- target over the whole sampler state.
focusObs :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
         -> Target GaussianMixtureState
focusObs t = makeTarget $ \s -> density t (labels s, gaussParams s, obs s)

Record field targets

-- | Density of the label list given the mixture proportion @p@: the product
-- of the @bern p@ density of every label.
--
-- The lambda needs its leading backslash; without it this is not valid
-- Haskell.
labelsTarget :: Target (Double, [Bool])
labelsTarget = makeTarget $ \(p, ls) -> product $ map (density $ bern p) ls

-- | Prior over the two Gaussians' parameters: each mean is scored with
-- @normal 100 900@, each scale (second component) with @uniform 0 200@, and
-- the four factors are multiplied together.
--
-- NOTE(review): 900 is 30^2 while the Hakaru generator uses @normal 100 30@,
-- which suggests this library's 'normal' takes a variance — confirm.
gaussParamsTarget :: Target ((Double, Double), (Double, Double))
gaussParamsTarget = makeTarget $ \((mu1, sd1), (mu2, sd2)) ->
  meanPrior mu1 * meanPrior mu2 * scalePrior sd1 * scalePrior sd2
  where
    meanPrior mu = density (normal 100 900) mu
    scalePrior sd = density (uniform 0 200) sd

-- | Prior over the mixture proportion: the density of @beta 2 2@, turned
-- into a target with 'fromProposal'.
bernParamTarget :: Target Double
bernParamTarget = fromProposal $ beta 2 2

-- | Likelihood of the observations given labels and Gaussian parameters.
--
-- Observations and labels are paired positionally (extra elements of the
-- longer list are ignored). An observation with label 'True' is scored under
-- @normal m1 (c1*c1)@, otherwise under @normal m2 (c2*c2)@; the per-point
-- densities are multiplied.
--
-- The original rendering had lost the lambda's backslash (a syntax error);
-- 'zipWith' replaces the equivalent @map (\\(o,l) -> ...) (zip os ls)@.
obsTarget :: Target ([Bool], ((Double, Double), (Double, Double)), [Double])
obsTarget = makeTarget dens
  where
    dens (ls, ((m1, c1), (m2, c2)), os) = product $ zipWith pointDensity os ls
      where
        pointDensity o l = density (component l) o
        -- The scale component is squared before being passed to 'normal'.
        component l = if l then normal m1 (c1 * c1) else normal m2 (c2 * c2)

Target density factors

-- | The label prior, viewed as a factor of the full-state density.
labelsFactor :: Target GaussianMixtureState
labelsFactor = focusLabels labelsTarget

-- | The prior over the Gaussians' parameters, viewed as a factor of the
-- full-state density.
gaussParamsFactor :: Target GaussianMixtureState
gaussParamsFactor = focusGaussParams gaussParamsTarget

-- | The prior over the mixture proportion, viewed as a factor of the
-- full-state density.
bernParamFactor :: Target GaussianMixtureState
bernParamFactor = focusBernParam bernParamTarget

-- | The observation likelihood, viewed as a factor of the full-state
-- density.
obsFactor :: Target GaussianMixtureState
obsFactor = focusObs obsTarget

Target density

-- | Target density for the full model: the four factors (label prior,
-- Gaussian-parameter prior, mixture-proportion prior, likelihood) combined
-- with 'productDensity'.
gmmTarget :: Target GaussianMixtureState
gmmTarget = makeTarget (productDensity factors)
  where
    factors = [labelsFactor, gaussParamsFactor, bernParamFactor, obsFactor]

Proposal

Proposal update boilerplate

-- | Lift a proposal on the label list to a proposal on the whole state.
--
-- Sampling replaces only the 'labels' field of @x@. The density of a
-- candidate state @y@ is the density of @y@'s labels under the proposal
-- built from @x@'s current labels.
updateLabels :: ([Bool] -> Proposal [Bool]) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateLabels f x = makeProposal densityAt sampler
  where
    proposal = f (labels x)
    densityAt y = density proposal (labels y)
    sampler gen = fmap (\ls -> x { labels = ls }) (sampleFrom proposal gen)

-- | Lift a proposal on the Gaussians' parameters to a proposal on the whole
-- state.
--
-- Sampling replaces only the 'gaussParams' field of @x@. The density of a
-- candidate @y@ is that of @y@'s parameters under the proposal built from
-- @x@'s current parameters.
updateGaussParams :: (((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double)))
                     -> GaussianMixtureState -> Proposal GaussianMixtureState
updateGaussParams f x = makeProposal densityAt sampler
  where
    proposal = f (gaussParams x)
    densityAt y = density proposal (gaussParams y)
    sampler gen = fmap (\ps -> x { gaussParams = ps }) (sampleFrom proposal gen)

-- | Lift a proposal on the mixture proportion to a proposal on the whole
-- state.
--
-- Sampling replaces only the 'bernParam' field of @x@. The density of a
-- candidate @y@ is that of @y@'s proportion under the proposal built from
-- @x@'s current proportion.
updateBernParam :: (Double -> Proposal Double) -> GaussianMixtureState -> Proposal GaussianMixtureState
updateBernParam f x = makeProposal densityAt sampler
  where
    proposal = f (bernParam x)
    densityAt y = density proposal (bernParam y)
    sampler gen = fmap (\p -> x { bernParam = p }) (sampleFrom proposal gen)

Field proposals

-- | Propose a new label list by choosing one position with 'chooseProposal'
-- and flipping the label there.
--
-- The number of candidate positions is the length of the list being
-- updated. Previously it was the global 'nPoints', which silently assumed
-- every label list has exactly that length. The lambda also needs its
-- backslash to be valid Haskell.
labelsProposal :: [Bool] -> Proposal [Bool]
labelsProposal ls = chooseProposal (length ls) (\n -> updateNth n flipBool ls)
  where
    -- A True label is redrawn from @bern 0@ and a False one from @bern 1@,
    -- i.e. deterministically flipped (assuming @bern p@ yields True with
    -- probability p).
    flipBool bn = if bn then bern 0 else bern 1

-- | Equal-weight mixture of four proposals. Each one perturbs a single
-- parameter (m1, c1, m2, c2, in that order) with @normal v 1@ centred on its
-- current value @v@, presumably leaving the other three untouched via
-- 'updateFirst' / 'updateSecond'.
gaussParamsProposal :: ((Double, Double), (Double, Double)) -> Proposal ((Double, Double), (Double, Double))
gaussParamsProposal params =
  mixProposals [ (prop, 1) | prop <- [mean1, scale1, mean2, scale2] ]
  where
    jitter v = normal v 1
    mean1  = updateFirst  (updateFirst  jitter) params
    scale1 = updateFirst  (updateSecond jitter) params
    mean2  = updateSecond (updateFirst  jitter) params
    scale2 = updateSecond (updateSecond jitter) params

-- | Propose a new mixture proportion uniformly from @[p/2, 1 - p/2]@.
--
-- For any @p@ in @[0, 1]@ the interval is non-empty (@p/2 <= 1 - p/2@ iff
-- @p <= 1@) and stays inside @[0, 1]@. The proposal depends on @p@, so it is
-- not symmetric; 'updateBernParam' supplies its density to
-- Metropolis–Hastings.
bernParamProposal :: Double -> Proposal Double
bernParamProposal p = uniform (p / 2) (1 - p / 2)

Field updaters

-- | Whole-state proposal that changes only the labels, using
-- 'labelsProposal'.
labelsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
labelsUpdater st = updateLabels labelsProposal st

-- | Whole-state proposal that changes only the Gaussians' parameters, using
-- 'gaussParamsProposal'.
gaussParamsUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsUpdater st = updateGaussParams gaussParamsProposal st

-- | Whole-state proposal that changes only the mixture proportion, using
-- 'bernParamProposal'.
bernParamUpdater :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamUpdater st = updateBernParam bernParamProposal st

The combined proposal

-- | The combined proposal: a mixture of the three field updaters with
-- weights 10 (labels), 1 (Gaussian parameters) and 2 (mixture proportion).
gmmProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gmmProposal = mixCondProposals
  [ (labelsUpdater, 10)
  , (gaussParamsUpdater, 1)
  , (bernParamUpdater, 2)
  ]

Running the sampler

Transition kernel

-- | One Metropolis–Hastings transition, targeting 'gmmTarget' and drawing
-- candidates from 'gmmProposal'.
gmmMH :: Step GaussianMixtureState
gmmMH = metropolisHastings gmmTarget $ gmmProposal

Visualization methods

-- | Count how many times each distinct element occurs in the list.
--
-- Built with 'Map.fromListWith'. The original lazy 'foldl' accumulated a
-- chain of unevaluated map insertions proportional to the input length
-- before doing any work.
histogram :: Ord a => [a] -> Map.Map a Int
histogram ls = Map.fromListWith (+) [ (e, 1) | e <- ls ]

-- | Project every state onto its labels, Gaussian parameters and mixture
-- proportion (the observations are omitted).
printFields :: PrintF GaussianMixtureState ([Bool], ((Double, Double), (Double, Double)), Double)
printFields = map $ \s -> (labels s, gaussParams s, bernParam s)

-- | Extract the @n@-th label (1-based) of every state.
--
-- Uses '!!', so it errors if @n < 1@ or @n@ exceeds the number of labels.
printLabelN :: Int -> PrintF GaussianMixtureState Bool
printLabelN n = map ((!! (n - 1)) . labels)

-- | Extract the pair of labels at 1-based positions @n@ and @m@ from every
-- state.
--
-- Uses '!!', so out-of-range positions raise an error.
compareLabels :: Int -> Int -> PrintF GaussianMixtureState (Bool,Bool)
compareLabels n m = map pick
  where
    pick s = let ls = labels s in (ls !! (n - 1), ls !! (m - 1))

-- | Print the histogram of @f@ applied to the states of a batch.
-- Does nothing when the batch holds no states.
printHist :: (Ord s, Show s) => PrintF x s -> Batch x -> IO ()
printHist f (ls, _)
  | null ls   = return ()
  | otherwise = print (histogram (f ls))

-- | Build a batch action that prints histograms of @f@ via 'printHist',
-- combined with 'pack' and 'inBatches' (@n@ is presumably the batch size).
batchHist :: (Ord s, Show s) => PrintF x s -> Int -> BatchAction x IO ()
batchHist f n = pack action (inBatches action n)
  where
    action = printHist f

Main

-- | Number of observed data points.
--
-- Derived from 'sampleData' rather than hard-coded as 6, so the two cannot
-- drift apart if the data set is edited.
nPoints :: Int
nPoints = length sampleData

-- | The observations the sampler conditions on. The six values fall into
-- two groups, around 62–64 and around 132.
sampleData :: [Double]
sampleData = [ 63.13941114139962, 132.02763712240528
             , 62.59642260289356, 132.2616834236893
             , 64.10610391933461, 62.143820541377934 ]

-- | Initial sampler state:
--
-- * the first three points are labelled 'True' and the last three 'False';
-- * the Gaussians start at means 63 and 132, each with scale 100;
-- * the mixture proportion starts at 0.5;
-- * the observations are 'sampleData'.
gmmStart :: GaussianMixtureState
gmmStart = GMM
  { labels      = replicate 3 True ++ replicate 3 False
  , gaussParams = ((63, 100), (132, 100))
  , bernParam   = 0.5
  , obs         = sampleData
  }

-- | Run the sampler and print a summary.
--
-- Creates a system-seeded generator and runs 10^6 Metropolis–Hastings steps
-- from 'gmmStart' with the action @every 50 collect@ (presumably keeping
-- every 50th state). It then prints "Done" followed by the labels of the
-- first 20 collected states.
--
-- The unused histogram bindings (@a@, @e@) were removed. To print
-- label-pair histograms instead, pass
-- @every 50 (batchHist (compareLabels 5 6) 50)@ as the action.
gmmTest :: IO ()
gmmTest = do
  g <- MWC.createSystemRandom
  let c = every 50 collect
  -- Explicit exponent type avoids a type-defaulting warning on (^).
  ls <- walk gmmMH gmmStart (10 ^ (6 :: Int)) g c
  putStrLn "Done"
  print $ take 20 (map labels ls)