Safe Haskell | None |
---|---|
Language | Haskell2010 |
Optimized sampler for a Gaussian mixture model (GMM).
Here is the code in the Hakaru language for generating the data used in this example:
```haskell
p <- unconditioned (beta 2 2)
[m1,m2] <- replicateM 2 $ unconditioned (normal 100 30)
[s1,s2] <- replicateM 2 $ unconditioned (uniform 0 2)
let makePoint = do
      b <- unconditioned (bern p)
      unconditioned (ifThenElse b (normal m1 s1) (normal m2 s2))
replicateM nPoints makePoint
```
Documentation
data GaussianMixtureState
Focused targets
-- | Target for the label of the @i@-th observation (1-based).
targetLabel :: Int -> Target GaussianMixtureState
targetLabel i = makeTarget (densityLabel i)

-- | Prior probability of the @i@-th label (1-based) given the mixing weight
-- @p@: @p@ when the label is 'True', @1 - p@ when it is 'False'.
-- Uses '!!', so @i@ must lie in @[1 .. length labels]@.
densityLabel :: Int -> GaussianMixtureState -> Double
densityLabel i (GMM l _ p) = if (l !! (i-1)) then p else 1-p

-- | Target for the parameters of both Gaussian components.
targetGaussParams :: Target GaussianMixtureState
targetGaussParams = makeTarget densityGaussParams

-- | Prior density of the component parameters @((m1, c1), (m2, c2))@:
-- each mean is scored under @normal 100 900@ and each second parameter
-- under @uniform 0 2@; the four factors are multiplied.
--
-- NOTE(review): the Hakaru data generator uses @normal 100 30@ and
-- 900 = 30^2, which suggests this library's 'normal' takes a variance
-- rather than a standard deviation — confirm.
densityGaussParams :: GaussianMixtureState -> Double
densityGaussParams state = mdens m1 * mdens m2 * cdens c1 * cdens c2
  where
    ((m1, c1), (m2, c2)) = gaussParams state
    mdens m = density (normal 100 900) m
    cdens c = density (uniform 0 2) c

-- | Target for the Bernoulli mixing weight.
targetBernParam :: Target GaussianMixtureState
targetBernParam = makeTarget densityBernParam

-- | Prior density of the mixing weight under @beta 2 2@.
densityBernParam :: GaussianMixtureState -> Double
densityBernParam state = density (beta 2 2) (bernParam state)

-- | Target for the likelihood of the @i@-th observation in @obs@ (1-based).
targetObs :: Int -> [Double] -> Target GaussianMixtureState
targetObs i obs = makeTarget (densityObs i obs)

-- | Likelihood of the @i@-th observation (1-based). If its label is 'True'
-- it is scored under component 1, @normal m1 c1@; otherwise under
-- component 2, @normal m2 c2@. This matches 'densityLabel', where a 'True'
-- label has probability @p@.
-- Uses '!!' on both @obs@ and the labels, so @i@ must be in range for both.
densityObs :: Int -> [Double] -> GaussianMixtureState -> Double
densityObs i obs state =
  if labels state !! (i-1)
    then density (normal m1 c1) oi
    else density (normal m2 c2) oi
  where
    oi = obs !! (i-1)
    ((m1, c1), (m2, c2)) = gaussParams state
Focused proposals
-- | Proposal that changes only the labels of @x@, using 'updateLabel' @i@.
-- The other fields of the sampled state are copied from @x@ by the record
-- update.
labelsProposal :: Int -> GaussianMixtureState -> Proposal GaussianMixtureState
labelsProposal i x = makeProposal dens sf
  where
    -- Proposal density of a candidate state: only its labels matter.
    dens y = density (updateLabel i $ labels x) (labels y)
    sf g = do
      newLabels <- sampleFrom (updateLabel i $ labels x) g
      return x { labels = newLabels }

-- | Flip the @i@-th label. @bern 0@ is proposed for a 'True' label and
-- @bern 1@ for a 'False' one, so the flip is presumably deterministic.
--
-- NOTE(review): assumes 'updateNth' counts from 1, which would match the
-- @(i-1)@ indexing in 'densityLabel' and 'densityObs' — confirm.
updateLabel :: Int -> [Bool] -> Proposal [Bool]
updateLabel i ls = updateNth i flipBool ls
  where
    flipBool bn = if bn then bern 0 else bern 1

-- | Proposal that changes only the Gaussian parameters of @x@, using
-- 'updateGaussParams'. All other fields are copied from @x@.
gaussParamsProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
gaussParamsProposal x = makeProposal dens sf
  where
    dens y = density (updateGaussParams $ gaussParams x) (gaussParams y)
    sf g = do
      newParams <- sampleFrom (updateGaussParams $ gaussParams x) g
      return x { gaussParams = newParams }

-- | An equally weighted mixture of four single-coordinate proposals, one
-- for each of @m1@, @c1@, @m2@ and @c2@. Each proposes @normal c 1@
-- centred on the coordinate's current value @c@.
updateGaussParams
  :: ((Double, Double), (Double, Double))
  -> Proposal ((Double, Double), (Double, Double))
updateGaussParams params = mixProposals $ zip [m1p, c1p, m2p, c2p] (repeat 1)
  where
    condProp c = normal c 1
    -- Presumably updateFirst/updateSecond apply a proposal to the first/
    -- second component of a pair; nesting them picks one coordinate.
    m1p = updateFirst (updateFirst condProp) params
    c1p = updateFirst (updateSecond condProp) params
    m2p = updateSecond (updateFirst condProp) params
    c2p = updateSecond (updateSecond condProp) params

-- | Proposal that changes only the mixing weight of @x@, using
-- 'updateBernParam'. All other fields are copied from @x@.
bernParamProposal :: GaussianMixtureState -> Proposal GaussianMixtureState
bernParamProposal x = makeProposal dens sf
  where
    dens y = density (updateBernParam $ bernParam x) (bernParam y)
    sf g = do
      newParam <- sampleFrom (updateBernParam $ bernParam x) g
      return x { bernParam = newParam }

-- | Propose a new mixing weight uniformly on @[p/2, 1 - p/2]@. This
-- interval is centred on 1/2 and has width @1 - p@.
--
-- NOTE(review): reconstructed from a listing in which '/' characters were
-- dropped (it read @uniform (p2) (1-p2)@); verify against upstream.
updateBernParam :: Double -> Proposal Double
updateBernParam p = uniform (p/2) (1-p/2)
Focused steps
Each step computes only those factors of the density ratio that its proposal affects; the remaining factors are identical in the numerator and the denominator, so they cancel out.
-- | Metropolis–Hastings update of a single label. 'chooseStep' receives
-- @nPoints@ and presumably picks which label index @i@ to update — confirm.
-- The target keeps only the factors that depend on label @i@: its
-- Bernoulli prior and the likelihood of observation @i@.
stepLabels :: [Double] -> Step GaussianMixtureState
stepLabels obs =
  chooseStep nPoints (\i -> makeTarget $ dens i) labelsProposal metropolisHastings
  where
    dens i state = density (targetLabel i) state * density (targetObs i obs) state

-- | Metropolis–Hastings update of the Gaussian parameters. The target is
-- their prior times the likelihood of every observation. The label and
-- mixing-weight priors do not depend on these parameters, so they are
-- omitted.
--
-- This could be optimized further if we knew which labels belong to the
-- Gaussian whose parameter was updated: only those observations' factors
-- change.
stepGaussParams :: [Double] -> Step GaussianMixtureState
stepGaussParams obs = metropolisHastings (makeTarget dens) gaussParamsProposal
  where
    dens state =
      density targetGaussParams state
        * product [density (targetObs i obs) state | i <- [1..nPoints]]

-- | Metropolis–Hastings update of the mixing weight. The target is its
-- Beta prior times the prior of every label. The observation likelihoods
-- do not depend on the weight, so they are omitted.
stepBernParam :: Step GaussianMixtureState
stepBernParam = metropolisHastings (makeTarget dens) bernParamProposal
  where
    dens state =
      density targetBernParam state
        * product [density (targetLabel i) state | i <- [1..nPoints]]
Optimized sampler
A mixture of the focused (i.e. optimized) steps.
-- | The complete optimized sampler for the observations @obs@: the three
-- focused steps combined by 'mixSteps', each with weight 1. Presumably one
-- step is chosen per transition in proportion to its weight — confirm.
gmmSampler :: [Double] -> Step GaussianMixtureState
gmmSampler obs = mixSteps (zip focusedSteps (repeat 1))
  where
    focusedSteps = [stepLabels obs, stepGaussParams obs, stepBernParam]
Main
-- | Number of observations. It is derived from 'sampleData' so the two
-- cannot drift apart. ('startState' must still supply one label per point.)
nPoints :: Int
nPoints = length sampleData

-- | The observations the sampler conditions on.
sampleData :: [Double]
sampleData =
  [ 63.13941114139962, 132.02763712240528
  , 62.59642260289356, 132.2616834236893
  , 64.10610391933461, 62.143820541377934
  ]

-- | Initial state of the chain.
--
-- NOTE(review): the second parameter of each Gaussian is 100, but
-- 'densityGaussParams' scores it under @uniform 0 2@. If that density is
-- zero outside [0, 2], the chain starts at a zero-density state and, with
-- @normal c 1@ proposals, the Gaussian-parameter step may never accept.
-- Consider a start value inside (0, 2) — confirm intent.
startState :: GaussianMixtureState
startState = GMM
  { -- labels = [True, True, True, False, False, False],
    labels = [False, False, False, True, True, True]
  , gaussParams = ((63, 100), (132, 100))
  , bernParam = 0.5
  }

-- | Run the sampler for 100 steps from 'startState', displaying the labels
-- at every step. The commented-out lines show an alternative long run
-- that collects every 50th state.
test :: IO ()
test = do
  g <- MWC.createSystemRandom
  let c = every 50 collect
      p = every 1 (display labels)
  -- ls <- walk (gmmSampler sampleData) startState (10^6) g c
  -- print $ take 20 (map labels ls)
  walk (gmmSampler sampleData) startState (10^2) g p