{-# LANGUAGE UndecidableInstances,TypeApplications #-}
module Goal.Probability.ExponentialFamily
(
ExponentialFamily (sufficientStatistic, averageSufficientStatistic, logBaseMeasure)
, LegendreExponentialFamily
, DuallyFlatExponentialFamily
, exponentialFamilyLogDensities
, unnormalizedLogDensities
, Natural
, Mean
, Source
, toNatural
, toMean
, toSource
, relativeEntropy
, crossEntropy
, relativeEntropyDifferential
, stochasticRelativeEntropyDifferential
, stochasticInformationProjectionDifferential
, exponentialFamilyLogLikelihood
, exponentialFamilyLogLikelihoodDifferential
) where
import Goal.Probability.Statistical
import Goal.Core
import Goal.Geometry
import qualified Goal.Core.Vector.Storable as S
import Foreign.Storable
-- | Type-level tag for natural coordinates; used as the coordinate index @c@
-- in points of the form @c # x@.
data Natural

-- | Type-level tag for mean coordinates; used as the coordinate index @c@
-- in points of the form @c # x@.
data Mean

-- | Type-level tag for source coordinates; used as the coordinate index @c@
-- in points of the form @c # x@.
data Source
-- | 'Mean' coordinates are dual to 'Natural' coordinates.
instance Primal Mean where
    type Dual Mean = Natural

-- | 'Natural' coordinates are dual to 'Mean' coordinates.
instance Primal Natural where
    type Dual Natural = Mean
-- | Re-express a point given in coordinates @c@ in 'Natural' coordinates,
-- using whichever 'Transition' instance relates the two.
toNatural :: (Transition c Natural x) => c # x -> Natural # x
toNatural p = transition p
-- | Re-express a point given in coordinates @c@ in 'Mean' coordinates,
-- using whichever 'Transition' instance relates the two.
toMean :: (Transition c Mean x) => c # x -> Mean # x
toMean p = transition p
-- | Re-express a point given in coordinates @c@ in 'Source' coordinates,
-- using whichever 'Transition' instance relates the two.
toSource :: (Transition c Source x) => c # x -> Source # x
toSource p = transition p
-- | 'Statistical' models equipped with a sufficient statistic (valued in
-- 'Mean' coordinates) and a log base measure.
class Statistical x => ExponentialFamily x where
    -- | Sufficient statistic of a single sample point.
    sufficientStatistic :: SamplePoint x -> Mean # x
    -- | Sufficient statistic averaged over a sample.
    averageSufficientStatistic :: Sample x -> Mean # x
    -- Default: the 'average' of the per-point sufficient statistics.
    averageSufficientStatistic xs = average (map sufficientStatistic xs)
    -- | Logarithm of the base measure at a sample point; the 'Proxy' only
    -- selects the family.
    logBaseMeasure :: Proxy x -> SamplePoint x -> Double
-- | Exponential families with a 'Legendre' potential over 'Natural'
-- coordinates, together with a 'Transition' from those potential coordinates
-- to their 'Dual'.
type LegendreExponentialFamily x =
    ( ExponentialFamily x
    , Legendre x
    , PotentialCoordinates x ~ Natural
    , Transition (PotentialCoordinates x) (Dual (PotentialCoordinates x)) x
    )
-- | 'LegendreExponentialFamily's that are also 'DuallyFlat', with a
-- 'Transition' back from the dual coordinates to the potential coordinates.
type DuallyFlatExponentialFamily x =
    ( DuallyFlat x
    , LegendreExponentialFamily x
    , Transition (Dual (PotentialCoordinates x)) (PotentialCoordinates x) x
    )
-- | Relative entropy between a distribution given in 'Mean' coordinates
-- (first argument) and one given in 'Natural' coordinates (second argument).
-- It is evaluated as 'canonicalDivergence' with the natural point supplied
-- first and the mean point second.
relativeEntropy :: DuallyFlatExponentialFamily x => Mean # x -> Natural # x -> Double
relativeEntropy mp nq = canonicalDivergence nq mp
-- | Cross-entropy between a distribution in 'Mean' coordinates and one in
-- 'Natural' coordinates, evaluated as @potential nq - (mp <.> nq)@.
--
-- NOTE(review): no base-measure term appears in this expression; confirm
-- that callers expect the value without it.
crossEntropy :: DuallyFlatExponentialFamily x => Mean # x -> Natural # x -> Double
crossEntropy mp nq = psi - pairing
  where
    psi = potential nq
    pairing = mp <.> nq
-- | Differential of 'relativeEntropy' with respect to the natural point:
-- the mean coordinates of @nq@ minus @mp@.
relativeEntropyDifferential :: LegendreExponentialFamily x => Mean # x -> Natural # x -> Mean # x
relativeEntropyDifferential mp nq = mq - mp
  where
    -- mean coordinates of the natural point
    mq = transition nq
-- | Sample-based counterpart of 'relativeEntropyDifferential'. It returns the
-- average sufficient statistic of @qxs@ minus that of @pxs@.
stochasticRelativeEntropyDifferential
    :: ExponentialFamily x
    => Sample x -- ^ sample from the first distribution
    -> Sample x -- ^ sample from the second distribution
    -> Mean # x
stochasticRelativeEntropyDifferential pxs qxs = qAvg - pAvg
  where
    pAvg = averageSufficientStatistic pxs
    qAvg = averageSufficientStatistic qxs
-- | Sample covariance between the sufficient statistics @s(x)@ and the scalar
-- @y(x) = (s(x) <.> px) - f x@ over the sample @xs@. The covariance is
-- normalised by @n - 1@, where @n@ is the sample size.
--
-- Presumably a Monte Carlo estimate of the gradient of an information
-- projection onto @px@, with @xs@ drawn from @px@ and @f@ an unnormalised
-- log-density. NOTE(review): confirm against callers.
--
-- Requires at least two sample points, because the final normaliser is
-- @n - 1@.
stochasticInformationProjectionDifferential
    :: ExponentialFamily x
    => Natural # x -- ^ current natural parameters
    -> Sample x -- ^ sample points
    -> (SamplePoint x -> Double) -- ^ target (unnormalised) log-density
    -> Mean # x
stochasticInformationProjectionDifferential px xs f =
    let -- sufficient statistics, computed once and reused below
        mxs = sufficientStatistic <$> xs
        mys = zipWith (\mx x -> (mx <.> px) - f x) mxs xs
        ln :: Double
        ln = fromIntegral $ length xs
        -- sample means of the statistics and of the scalar target
        mxht = ln /> sum mxs
        myht = sum mys / ln
     in (ln - 1) /> sum [ (my - myht) .> (mx - mxht) | (mx, my) <- zip mxs mys ]
-- | Normalised log-densities of each sample point: the
-- 'unnormalizedLogDensities' with the 'potential' of @p@ subtracted.
exponentialFamilyLogDensities
    :: (ExponentialFamily x, Legendre x, PotentialCoordinates x ~ Natural)
    => Natural # x -> Sample x -> [Double]
exponentialFamilyLogDensities p xs =
    [ l - psi | l <- unnormalizedLogDensities p xs ]
  where
    -- the potential is shared by every sample point
    psi = potential p
-- | Unnormalised log-densities of each sample point. Each value is the
-- pairing of @p@ with the point's sufficient statistic plus the point's log
-- base measure.
unnormalizedLogDensities :: forall x . ExponentialFamily x => Natural # x -> Sample x -> [Double]
unnormalizedLogDensities p xs =
    let dots = dotMap p (sufficientStatistic <$> xs)
        lbms = logBaseMeasure (Proxy @x) <$> xs
     in zipWith (+) dots lbms
-- | Average log-likelihood of a sample under the natural parameters @nq@.
-- It is @-potential nq@, plus the pairing of the average sufficient statistic
-- with @nq@, plus the average log base measure.
exponentialFamilyLogLikelihood
    :: forall x . LegendreExponentialFamily x
    => Sample x -> Natural # x -> Double
exponentialFamilyLogLikelihood xs nq =
    negate (potential nq) + (mp <.> nq) + bm
  where
    mp :: Mean # x
    mp = averageSufficientStatistic xs
    bm :: Double
    bm = average (logBaseMeasure (Proxy :: Proxy x) <$> xs)
-- | Differential of 'exponentialFamilyLogLikelihood' with respect to @nq@:
-- the average sufficient statistic of the sample minus the mean coordinates
-- of @nq@.
exponentialFamilyLogLikelihoodDifferential
    :: LegendreExponentialFamily x
    => Sample x -> Natural # x -> Mean # x
exponentialFamilyLogLikelihoodDifferential xs nq =
    averageSufficientStatistic xs - transition nq
-- | Log base measure of a replicated sample point: the sum of the component
-- log base measures. The 'Replicated' proxy is ignored and only fixes types.
replicatedlogBaseMeasure0 :: (ExponentialFamily x, Storable (SamplePoint x), KnownNat k)
    => Proxy x -> Proxy (Replicated k x) -> S.Vector k (SamplePoint x) -> Double
replicatedlogBaseMeasure0 prxym _ = S.sum . S.map (logBaseMeasure prxym)
-- | Log base measure of a pair sample point: the sum of the two components'
-- log base measures. The pair proxy is ignored and only fixes types.
pairlogBaseMeasure
    :: (ExponentialFamily x, ExponentialFamily y)
    => Proxy x
    -> Proxy y
    -> Proxy (x,y)
    -> SamplePoint (x,y)
    -> Double
pairlogBaseMeasure prxym prxyn _ (xm, xn) = lbmX + lbmY
  where
    lbmX = logBaseMeasure prxym xm
    lbmY = logBaseMeasure prxyn xn
-- | Moving from 'Natural' to 'Natural' coordinates leaves the point unchanged.
instance Transition Natural Natural x where
    transition p = p
-- | Moving from 'Mean' to 'Mean' coordinates leaves the point unchanged.
instance Transition Mean Mean x where
    transition p = p
-- | Moving from 'Source' to 'Source' coordinates leaves the point unchanged.
instance Transition Source Source x where
    transition p = p
-- | A 'Replicated' family takes the component sufficient statistic of each
-- element and joins the results. Its log base measure is the sum over the
-- elements (see 'replicatedlogBaseMeasure0').
instance (ExponentialFamily x, Storable (SamplePoint x), KnownNat k)
  => ExponentialFamily (Replicated k x) where
    sufficientStatistic xs = joinReplicated (S.map sufficientStatistic xs)
    logBaseMeasure prxy = replicatedlogBaseMeasure0 Proxy prxy
-- | A product family joins the component sufficient statistics. Its log base
-- measure is the sum of the components' (see 'pairlogBaseMeasure').
instance (ExponentialFamily x, ExponentialFamily y) => ExponentialFamily (x,y) where
    sufficientStatistic (xm, xn) = join smx smy
      where
        smx = sufficientStatistic xm
        smy = sufficientStatistic xn
    logBaseMeasure = pairlogBaseMeasure Proxy Proxy
-- | 'Source' coordinates are declared to be their own 'Dual'.
instance Primal Source where
    type Dual Source = Source