-- | Exponential family 'Harmonium's and Gibbs sampling.
module Goal.Probability.Graphical.Harmonium
    ( -- * Harmoniums
      Harmonium (Harmonium)
      -- ** Type Synonyms
    , NaturalFunction
      -- ** Structural Manipulation
    , splitHarmonium
    , joinHarmonium
    , harmoniumTranspose
      -- ** Conditional Distribution Functions
    , conditionalLatentDistribution
    , conditionalObservableDistribution
    , conditionalLatentDistributions
    , conditionalObservableDistributions
      -- ** Gibbs Sampling
    , bulkGibbsSampling
    , bulkGibbsSampling0
      -- * Transducers
    , buildNormalTransducer
    , buildReplicatedNormalTransducer
    , modulateTransducerGain
    , modulateHarmoniumBelief
    ) where

--- Imports ---

-- Goal --

import Goal.Geometry
import Goal.Probability.Statistical
import Goal.Probability.ExponentialFamily
import Goal.Probability.Distributions
import Goal.Probability.Graphical

import System.Random.MWC.Monad

import qualified Data.Vector.Storable as C

--- Types ---

-- | A quadratic function in the product space of two exponential families.
-- The first type parameter is the latent 'Manifold' and the second is the
-- observable 'Manifold'.
data Harmonium m n = Harmonium m n deriving (Eq, Read, Show)

-- Datatype manipulation --

splitHarmonium
    :: (Manifold m, Manifold n)
    => Function c d :#: Harmonium m n
    -> (d :#: m, Function c d :#: Tensor m n, Dual c :#: n)
-- | Splits a 'Harmonium' into its three component parts: the bias over the
-- latent manifold @m@, the interaction 'Tensor', and the bias over the
-- observable manifold @n@. The coordinates are stored in exactly that order.
splitHarmonium qdc =
    let (Harmonium m n) = manifold qdc
        tns = Tensor m n
        (mcs, css') = C.splitAt (dimension m) $ coordinates qdc
        (mtxcs, ncs) = C.splitAt (dimension tns) css'
     in (fromCoordinates m mcs, fromCoordinates tns mtxcs, fromCoordinates n ncs)

joinHarmonium
    :: (Manifold m, Manifold n)
    => d :#: m
    -> Function c d :#: Tensor m n
    -> Dual c :#: n
    -> Function c d :#: Harmonium m n
-- | Assembles a 'Harmonium' from a latent bias, an interaction 'Tensor' and an
-- observable bias, concatenating their coordinates in that order. This is the
-- inverse of 'splitHarmonium'.
joinHarmonium dm mtx cn =
    let (Tensor m n) = manifold mtx
     in fromCoordinates (Harmonium m n) $
            coordinates dm C.++ coordinates mtx C.++ coordinates cn

harmoniumTranspose
    :: (Manifold n, Manifold m, Primal c, Primal d)
    => Function c d :#: Harmonium m n
    -> Function (Dual d) (Dual c) :#: Harmonium n m
-- | Transposes the interaction 'Tensor' of a 'Harmonium' and swaps its two
-- biases, exchanging the roles of the latent and observable manifolds.
harmoniumTranspose qdc =
    let (dm, mtx, dn) = splitHarmonium qdc
     in joinHarmonium dn (matrixTranspose mtx) dm

--- Functions ---

conditionalLatentDistributions
    :: (Manifold m, ExponentialFamily n)
    => NaturalFunction :#: Harmonium m n
    -> [Sample n]
    -> [Natural :#: m]
-- | Calculates the latent distributions given some observations.
conditionalLatentDistributions p os =
    let (Harmonium _ n) = manifold p
     in p >$> (sufficientStatistic n <$> os)

conditionalObservableDistributions
    :: (ExponentialFamily m, Manifold n)
    => NaturalFunction :#: Harmonium m n
    -> [Sample m]
    -> [Natural :#: n]
-- | Calculates the observable distributions given some latent states.
conditionalObservableDistributions p ls =
    let (Harmonium m _) = manifold p
     in harmoniumTranspose p >$> (sufficientStatistic m <$> ls)

conditionalLatentDistribution
    :: (Manifold m, ExponentialFamily n)
    => NaturalFunction :#: Harmonium m n
    -> Sample n
    -> Natural :#: m
-- | Calculates the latent distribution given an observation.
conditionalLatentDistribution p o =
    let (Harmonium _ n) = manifold p
     in p >.> sufficientStatistic n o

conditionalObservableDistribution
    :: (ExponentialFamily m, Manifold n)
    => NaturalFunction :#: Harmonium m n
    -> Sample m
    -> Natural :#: n
-- | Calculates the observable distribution given a latent state.
conditionalObservableDistribution p l =
    let (Harmonium m _) = manifold p
     in harmoniumTranspose p >.> sufficientStatistic m l

bulkGibbsSampling
    :: ( ExponentialFamily m, Generative Natural m
       , ExponentialFamily n, Generative Natural n )
    => Int
    -> NaturalFunction :#: Harmonium m n
    -> [Sample n]
    -> RandST s [[(Sample m, Sample n)]]
-- | Returns a Markov chain over the latent and observable states generated by
-- Gibbs sampling, with one chain per initial observation.
--
-- Each element of the result holds one @(latent, observable)@ pair per chain.
-- The first element pairs latent states sampled from the initial observations
-- with those observations, and each of the following @k0@ elements is one
-- further Gibbs step. A non-positive @k0@ yields only the first element.
bulkGibbsSampling k0 p o0s = do
    l0s <- mapM generate $ conditionalLatentDistributions p o0s
    gbs <- gibbsSampler k0 l0s []
    return $ zip l0s o0s : gbs
  where
    (Harmonium m n) = manifold p
    -- Accumulates the chain in reverse. Testing @k <= 0@ (rather than only
    -- matching @0@) stops a negative step count from recursing forever.
    gibbsSampler k ls acc
        | k <= 0 = return $ reverse acc
        | otherwise = do
            let mls = sufficientStatistic m <$> ls
            os' <- mapM generate $ harmoniumTranspose p >$> mls
            let mos' = sufficientStatistic n <$> os'
            ls' <- mapM generate $ p >$> mos'
            gibbsSampler (k - 1) ls' (zip ls' os' : acc)

bulkGibbsSampling0
    :: ( ExponentialFamily m, Generative Natural m
       , ExponentialFamily n, Generative Natural n )
    => Int
    -> NaturalFunction :#: Harmonium m n
    -> [Mixture :#: n]
    -> RandST s [[(Mixture :#: m, Mixture :#: n)]]
-- | Returns a Markov chain over the latent and observable exponential families
-- generated by Gibbs sampling, expressed as sufficient statistics.
--
-- The chain has @k0 + 1@ elements (just one for a non-positive @k0@). In the
-- last element, each observable point is paired with the 'potentialMapping' of
-- its latent conditional distribution rather than with a sampled latent state.
bulkGibbsSampling0 k0 p mo0s = gibbsSampler k0 mo0s []
  where
    (Harmonium m n) = manifold p
    -- Testing @k <= 0@ (rather than only matching @0@) stops a negative step
    -- count from recursing forever.
    gibbsSampler k mos acc
        | k <= 0 =
            return . reverse $ zip (potentialMapping <$> (p >$> mos)) mos : acc
        | otherwise = do
            ls <- mapM generate $ p >$> mos
            let mls = sufficientStatistic m <$> ls
            os' <- mapM generate $ harmoniumTranspose p >$> mls
            let mos' = sufficientStatistic n <$> os'
            gibbsSampler (k - 1) mos' (zip mls mos : acc)

modulateHarmoniumBelief
    :: (Manifold m, Manifold n)
    => Mixture :#: m
    -> NaturalFunction :#: Harmonium m n
    -> NaturalFunction :#: Harmonium m n
-- | Projects the given latent belief through the transposed interaction
-- 'Tensor' and adds the result to the observable biases. The latent biases and
-- the interaction 'Tensor' are left unchanged.
modulateHarmoniumBelief z trns =
    let (lb, mtx, ob) = splitHarmonium trns
     in joinHarmonium lb mtx $ ob <+> (matrixTranspose mtx >.> z)

--- Transducers ---

normalBias :: (Standard :#: Normal) -> Double
-- | Computes @-mu^2/(2*vr)@ from the standard coordinates @(mu, vr)@ of a
-- 'Normal'. The transducers below use it as the unit-gain latent bias of each
-- 'Poisson' unit.
normalBias sp = case listCoordinates sp of
    [mu, vr] -> negate $ mu^2 / (2*vr)
    -- Fail with a descriptive message instead of an anonymous pattern-match
    -- failure if the coordinates are not exactly (mu, vr).
    _ -> error "normalBias: expected the two standard coordinates of a Normal"

buildNormalTransducer
    :: [Standard :#: Normal]
    -> NaturalFunction :#: Harmonium (Replicated Poisson) Normal
-- | Builds a transducer (i.e. a population code): a 'Harmonium' with a
-- 'Replicated' 'Poisson' latent 'Manifold' and a 'Normal' observable
-- 'Manifold'.
--
-- Each input element yields one 'Poisson' unit. Its natural coordinates are
-- concatenated, in order, into the interaction 'Tensor', and its 'normalBias'
-- becomes the unit's latent bias. The observable bias is zero.
buildNormalTransducer sps =
    let nps = chart Natural . transition <$> sps
        rp = Replicated Poisson $ length nps
        lb = fromList rp $ normalBias <$> sps
        ob = fromList Normal $ replicate (dimension Normal) 0
        tns = fromCoordinates (Tensor rp Normal) . C.concat $ coordinates <$> nps
     in joinHarmonium lb tns ob

buildReplicatedNormalTransducer
    :: [Standard :#: Replicated Normal]
    -> NaturalFunction :#: Harmonium (Replicated Poisson) (Replicated Normal)
-- | Builds a transducer (i.e. a population code): a 'Harmonium' with a
-- 'Replicated' 'Poisson' latent 'Manifold' and a 'Replicated' 'Normal'
-- observable 'Manifold'.
--
-- The observable manifold is taken from the first input element, so all
-- elements are expected to live on that same manifold. Each element's latent
-- bias is the sum of 'normalBias' over its replicated components. The
-- observable bias is zero. An empty input list is rejected with an explicit
-- error.
buildReplicatedNormalTransducer [] =
    error "buildReplicatedNormalTransducer: the input list must be non-empty"
buildReplicatedNormalTransducer sps@(sp0 : _) =
    let nps = chart Natural . transition <$> sps
        m = manifold sp0
        rp = Replicated Poisson $ length nps
        lb = fromList rp $ sum . mapReplicated normalBias <$> sps
        ob = fromList m $ replicate (dimension m) 0
        tns = fromCoordinates (Tensor rp m) . C.concat $ coordinates <$> nps
     in joinHarmonium lb tns ob

modulateTransducerGain
    :: Manifold n
    => Double
    -> NaturalFunction :#: Harmonium (Replicated Poisson) n
    -> NaturalFunction :#: Harmonium (Replicated Poisson) n
-- | Multiplies the current gain of the transducer by the given value, by
-- adding its logarithm to the latent biases. Transducers are initially
-- constructed with a gain of 1, and so initially this will simply set the
-- gain.
modulateTransducerGain gn trns =
    let (lb, mtx, ob) = splitHarmonium trns
        lb' = alterCoordinates (+ log gn) lb
     in joinHarmonium lb' mtx ob

--- Instances ---

-- Harmoniums --

instance (Manifold m, Manifold n) => Manifold (Harmonium m n) where
    -- Latent bias, interaction matrix and observable bias, as laid out by
    -- 'splitHarmonium'.
    dimension (Harmonium m n) = dimension m * dimension n + dimension m + dimension n

instance (Manifold m, Manifold n) => Map (Harmonium m n) where
    type Domain (Harmonium m n) = n
    domain (Harmonium _ n) = n
    type Codomain (Harmonium m n) = m
    codomain (Harmonium m _) = m

instance (Manifold m, Manifold n) => Apply c d (Harmonium m n) where
    -- Applies the interaction 'Tensor' and adds the latent bias. The
    -- observable bias plays no part.
    (>.>) p x =
        let (lb, mtxp, _) = splitHarmonium p
         in lb <+> (mtxp >.> x)
    (>$>) p xs =
        let (lb, mtxp, _) = splitHarmonium p
         in (lb <+>) <$> (mtxp >$> xs)