{-# LANGUAGE UndecidableInstances,TypeApplications #-}
-- | Definitions for working with exponential families.
module Goal.Probability.ExponentialFamily
    ( -- * Exponential Families
    ExponentialFamily (sufficientStatistic, averageSufficientStatistic, logBaseMeasure)
    , LegendreExponentialFamily
    , DuallyFlatExponentialFamily
    , exponentialFamilyLogDensities
    , unnormalizedLogDensities
    -- ** Coordinate Systems
    , Natural
    , Mean
    , Source
    -- ** Coordinate Transforms
    , toNatural
    , toMean
    , toSource
    -- ** Entropies
    , relativeEntropy
    , crossEntropy
    -- ** Differentials
    , relativeEntropyDifferential
    , stochasticRelativeEntropyDifferential
    , stochasticInformationProjectionDifferential
    -- *** Maximum Likelihood Instances
    , exponentialFamilyLogLikelihood
    , exponentialFamilyLogLikelihoodDifferential
    ) where

--- Imports ---


-- Package --

import Goal.Probability.Statistical

import Goal.Core
import Goal.Geometry

import qualified Goal.Core.Vector.Storable as S

import Foreign.Storable

--- Exponential Families ---


-- | Coordinates given by the natural parameters \(\theta\) of an exponential
-- family, i.e. the coefficients of the 'sufficientStatistic' in the exponent
-- of the density.
data Natural

-- | Coordinates given by the expected value of the 'sufficientStatistic' of an
-- exponential family.
data Mean

-- | The conventional parameterization of a given manifold, e.g. the rate of a
-- Poisson distribution, or the mean and standard deviation of a Normal
-- distribution.
data Source

-- 'Mean' and 'Natural' coordinates are each other's duals.
instance Primal Mean where
    type Dual Mean = Natural

instance Primal Natural where
    type Dual Natural = Mean

-- | Expresses an exponential family distribution in 'Natural' coordinates.
--
-- This is 'transition' with its target coordinate system fixed to 'Natural'.
toNatural :: Transition c Natural x => c # x -> Natural # x
toNatural = transition

-- | Expresses an exponential family distribution in 'Mean' coordinates.
--
-- This is 'transition' with its target coordinate system fixed to 'Mean'.
toMean :: Transition c Mean x => c # x -> Mean # x
toMean = transition

-- | Expresses an exponential family distribution in 'Source' coordinates.
--
-- This is 'transition' with its target coordinate system fixed to 'Source'.
toSource :: Transition c Source x => c # x -> Source # x
toSource = transition

-- | An 'ExponentialFamily' is a 'Statistical' 'Manifold' \( \mathcal M \)
-- determined by a fixed-length 'sufficientStatistic' \(s_i\) and a base
-- measure \(\mu\) (given through 'logBaseMeasure'). Each distribution
-- \(P \in \mathcal M\) may then be identified with 'Natural' parameters
-- \(\theta_i\) such that
-- \(p(x) \propto e^{\sum_{i=1}^n \theta_i s_i(x)}\mu(x)\). 'ExponentialFamily'
-- distributions theoretically have a 'Riemannian' geometry, with 'metric'
-- 'Tensor' given by the Fisher information metric. However, not all
-- distributions (e.g. the von Mises distribution) afford closed-form
-- expressions for all the relevant structures.
class Statistical x => ExponentialFamily x where
    -- | The sufficient statistic \(s(x)\) of a single sample point, expressed
    -- in 'Mean' coordinates.
    sufficientStatistic :: SamplePoint x -> Mean # x
    -- | The average 'sufficientStatistic' of a 'Sample'. By default this is the
    -- 'average' of 'sufficientStatistic' over every point of the sample;
    -- instances may override it.
    averageSufficientStatistic :: Sample x -> Mean # x
    averageSufficientStatistic = average . map sufficientStatistic
    -- | The logarithm of the base measure, \(\log \mu(x)\), at a sample point.
    -- The 'Proxy' carries no data; it only selects the instance.
    logBaseMeasure :: Proxy x -> SamplePoint x -> Double

-- | An 'ExponentialFamily' whose log-partition function and its derivative can
-- be computed in closed form is called a 'LegendreExponentialFamily'.
--
-- The log-partition function is the 'potential' of the 'Legendre' class, and
-- its derivative maps 'Natural' coordinates to 'Mean' coordinates (the
-- 'Transition' constraint below).
type LegendreExponentialFamily x =
    ( ExponentialFamily x
    , Legendre x
    , PotentialCoordinates x ~ Natural
    , Transition (PotentialCoordinates x) (Dual (PotentialCoordinates x)) x
    )

-- | A 'LegendreExponentialFamily' whose (negative) entropy and its derivative
-- can additionally be computed in closed form is called a
-- 'DuallyFlatExponentialFamily'.
--
-- The negative entropy is the 'dualPotential' of the 'DuallyFlat' class, and
-- its derivative maps 'Mean' coordinates back to 'Natural' coordinates (the
-- 'Transition' constraint below).
type DuallyFlatExponentialFamily x =
    ( DuallyFlat x
    , LegendreExponentialFamily x
    , Transition (Dual (PotentialCoordinates x)) (PotentialCoordinates x) x
    )

-- | The relative entropy \(D(P \parallel Q)\), also known as the KL-divergence,
-- with \(P\) given in 'Mean' coordinates and \(Q\) in 'Natural' coordinates.
-- This is simply the 'canonicalDivergence' with its arguments flipped.
relativeEntropy :: DuallyFlatExponentialFamily x => Mean # x -> Natural # x -> Double
relativeEntropy = flip canonicalDivergence

-- | The cross-entropy of \(P\) relative to \(Q\), which is the relative entropy
-- plus the entropy of the first distribution.
--
-- It is computed as \(\psi(\theta_Q) - \eta_P \cdot \theta_Q\), where
-- \(\psi\) is the 'potential' (log-partition function). No base-measure term
-- appears in this expression.
crossEntropy
    :: DuallyFlatExponentialFamily x
    => Mean # x -- ^ \(P\), in 'Mean' coordinates
    -> Natural # x -- ^ \(Q\), in 'Natural' coordinates
    -> Double
crossEntropy mp nq = potential nq - (mp <.> nq)

-- | The differential of the relative entropy \(D(P \parallel Q)\) with respect
-- to the 'Natural' parameters of the second argument. This is the 'Mean'
-- coordinates of \(Q\) minus those of \(P\).
relativeEntropyDifferential
    :: LegendreExponentialFamily x
    => Mean # x -- ^ \(P\)
    -> Natural # x -- ^ \(Q\)
    -> Mean # x
relativeEntropyDifferential mp nq = transition nq - mp

-- | Monte Carlo estimate of the differential of the relative entropy with
-- respect to the 'Natural' parameters of the second argument, based on samples
-- from the two distributions.
--
-- The estimate is the average sufficient statistic of the model samples minus
-- that of the true samples.
stochasticRelativeEntropyDifferential
    :: ExponentialFamily x
    => Sample x -- ^ True Samples
    -> Sample x -- ^ Model Samples
    -> Mean # x -- ^ Differential Estimate
stochasticRelativeEntropyDifferential pxs qxs =
    averageSufficientStatistic qxs - averageSufficientStatistic pxs

-- | Estimate of the differential of relative entropy with respect to the
-- 'Natural' parameters of the first argument, based on a 'Sample' from the
-- first argument and the unnormalized log-density of the second.
--
-- The estimate is the sample covariance, normalised by \(n - 1\), between the
-- 'sufficientStatistic' \(s(x)\) and \(s(x) \cdot \theta - f(x)\) over the
-- given samples. At least two samples are needed; with fewer, the
-- normalisation divides by zero.
stochasticInformationProjectionDifferential
    :: ExponentialFamily x
    => Natural # x -- ^ Model Distribution
    -> Sample x -- ^ Model Samples
    -> (SamplePoint x -> Double) -- ^ Unnormalized log-density of target distribution
    -> Mean # x -- ^ Differential Estimate
stochasticInformationProjectionDifferential px xs f =
    (ln - 1) /> sum [ (my - myht) .> (mx - mxht) | (mx,my) <- zip mxs mys ]
  where
    -- Sufficient statistic of each sample.
    mxs = sufficientStatistic <$> xs
    -- s(x) . theta minus the target's unnormalized log-density, per sample.
    mys = (\x -> sufficientStatistic x <.> px - f x) <$> xs
    ln = fromIntegral $ length xs
    -- Sample means of the two quantities above.
    mxht = ln /> sum mxs
    myht = sum mys / ln

-- | The log-densities of the points of a 'Sample' under an exponential family
-- distribution that has an exact expression for the log-partition function.
-- Each value is the corresponding 'unnormalizedLogDensities' value minus the
-- 'potential' (log-partition function) of the distribution.
exponentialFamilyLogDensities
    :: (ExponentialFamily x, Legendre x, PotentialCoordinates x ~ Natural)
    => Natural # x -> Sample x -> [Double]
exponentialFamilyLogDensities p xs =
    subtract (potential p) <$> unnormalizedLogDensities p xs

-- | The unnormalized log-density of an arbitrary exponential family
-- distribution at each point of a 'Sample', i.e.
-- \(\theta \cdot s(x) + \log \mu(x)\) without the log-partition term.
unnormalizedLogDensities :: forall x . ExponentialFamily x => Natural # x -> Sample x -> [Double]
unnormalizedLogDensities p xs =
    zipWith (+) (dotMap p $ sufficientStatistic <$> xs) (logBaseMeasure (Proxy @x) <$> xs)

-- | 'logLikelihood' for a 'LegendreExponentialFamily': the average log-density
-- of the sample, computed from the average sufficient statistic \(\bar s\) as
-- \(\bar s \cdot \theta - \psi(\theta) + \overline{\log \mu}\), where
-- \(\psi\) is the 'potential'.
exponentialFamilyLogLikelihood
    :: forall x . LegendreExponentialFamily x
    => Sample x -> Natural # x -> Double
exponentialFamilyLogLikelihood xs nq =
    -potential nq + (mp <.> nq) + bm
  where
    -- Average sufficient statistic of the sample.
    mp = averageSufficientStatistic xs
    -- Average log-base-measure of the sample.
    bm = average $ logBaseMeasure (Proxy :: Proxy x) <$> xs

-- | 'logLikelihoodDifferential' for a 'LegendreExponentialFamily': the
-- average sufficient statistic of the sample minus the 'Mean' coordinates of
-- the model.
exponentialFamilyLogLikelihoodDifferential
    :: LegendreExponentialFamily x
    => Sample x -> Natural # x -> Mean # x
exponentialFamilyLogLikelihoodDifferential xs nq =
    averageSufficientStatistic xs - transition nq


--- Internal ---


-- | Log-base-measure of a 'Replicated' manifold: the sum of the component
-- log-base-measures. The 'Proxy' for @Replicated k x@ is ignored; with the
-- first argument applied, the remaining type matches 'logBaseMeasure' for
-- @Replicated k x@.
replicatedlogBaseMeasure0
    :: (ExponentialFamily x, Storable (SamplePoint x), KnownNat k)
    => Proxy x -> Proxy (Replicated k x) -> S.Vector k (SamplePoint x) -> Double
replicatedlogBaseMeasure0 prxym _ xs = S.sum $ S.map (logBaseMeasure prxym) xs

-- | Log-base-measure of a pair manifold: the sum of the log-base-measures of
-- the two components. The 'Proxy' for the pair itself is ignored.
pairlogBaseMeasure
    :: (ExponentialFamily x, ExponentialFamily y)
    => Proxy x
    -> Proxy y
    -> Proxy (x,y)
    -> SamplePoint (x,y)
    -> Double
pairlogBaseMeasure prxym prxyn _ (xm,xn) =
    logBaseMeasure prxym xm + logBaseMeasure prxyn xn


--- Instances ---


-- Identity Transitions --

-- | Transitioning from a coordinate system to itself is the identity.
instance Transition Natural Natural x where
    transition = id

instance Transition Mean Mean x where
    transition = id

instance Transition Source Source x where
    transition = id

-- Replicated --

-- | A 'Replicated' exponential family: the sufficient statistic combines the
-- component statistics with 'joinReplicated', and the log-base-measure is the
-- sum of the component log-base-measures.
instance (ExponentialFamily x, Storable (SamplePoint x), KnownNat k)
  => ExponentialFamily (Replicated k x) where
    sufficientStatistic xs = joinReplicated $ S.map sufficientStatistic xs
    logBaseMeasure = replicatedlogBaseMeasure0 Proxy

-- Sum --

-- | The pair of exponential families @x@ and @y@: the sufficient statistic
-- 'join's the component statistics, and the log-base-measure is the sum of
-- the component log-base-measures.
instance (ExponentialFamily x, ExponentialFamily y) => ExponentialFamily (x,y) where
    sufficientStatistic (xm,xn) = join (sufficientStatistic xm) (sufficientStatistic xn)
    logBaseMeasure = pairlogBaseMeasure Proxy Proxy


-- Source Coordinates --

-- | 'Source' coordinates are self-dual.
instance Primal Source where
    type Dual Source
        = Source