{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{-| Abstract types for the [Dirichlet Process](https://en.wikipedia.org/wiki/Dirichlet_process) viewed through the interface of the [Chinese Restaurant Process](https://en.wikipedia.org/wiki/Chinese_restaurant_process).

Ideas following [S. Staton, H. Yang, N. L. Ackerman, C. Freer, D. Roy. Exchangeable random process and data abstraction. Workshop on probabilistic programming semantics (PPS 2017).](https://www.cs.ox.ac.uk/people/hongseok.yang/paper/pps17a.pdf)

Our implementation here uses stick breaking, with a lazily broken stick. Other urn-based implementations are possible with hidden state, and they should be observationally equivalent.

For illustrations, see [non-parametric clustering](https://lazyppl-team.github.io/ClusteringDemo.html) and [relational inference](https://lazyppl-team.github.io/IrmDemo.html).

-}

module LazyPPL.Distributions.DirichletP (
{- * Chinese Restaurant Process interface -}
{- | For clustering, we regard each data point as a "customer" in a "restaurant", and they are in the same cluster if they sit at the same `Table`. 
-}
Restaurant, Table, newRestaurant, newCustomer, 
-- * Random distribution interface
dp) where

import Data.List
import Data.Maybe
import LazyPPL
import LazyPPL.Distributions
import LazyPPL.Distributions.Memoization (MonadMemo)


-- | Abstract type of restaurants.
--
-- Internally a restaurant wraps the list of stick weights that
-- 'newRestaurant' obtains from @stickBreaking@; 'newCustomer' turns those
-- weights into running sums to choose a table. The constructor is not part
-- of the module's export list, so values can only be built with
-- 'newRestaurant'.
newtype Restaurant = R [Double]

-- | Abstract type of tables. This supports `Eq` so that we can ask whether
-- customers are at the same table (i.e. whether points are in the same
-- cluster).
--
-- A table is represented by the index of the stick it corresponds to (see
-- 'newCustomer'). The @MonadMemo Prob@ instance is not a stock-derivable
-- class, so it is obtained from the underlying 'Int' through
-- @GeneralizedNewtypeDeriving@ (presumably so that tables can be used as
-- arguments of memoized functions -- confirm against
-- "LazyPPL.Distributions.Memoization").
newtype Table = T Int deriving (Eq, Show, MonadMemo Prob)

{-| A customer enters the restaurant and is assigned a table.

A threshold @r@ is drawn from 'uniform', and the customer is seated at the
first table whose cumulative stick weight (the running sum of the
restaurant's weights) is strictly greater than @r@. The table is identified
by the zero-based index of that stick.
-}
newCustomer :: Restaurant -> Prob Table
newCustomer (R restaurant) = do
  r <- uniform
  -- 'findIndex' stops at the first running sum above the threshold, so only a
  -- finite prefix of the (lazily generated) weight list is ever inspected.
  return $ case findIndex (> r) (scanl1 (+) restaurant) of
    Just i -> T i
    -- Only reachable if the weight list is finite and its total never exceeds
    -- @r@; 'newRestaurant' builds the list with an unbounded recursion.
    Nothing ->
      error
        "LazyPPL.Distributions.DirichletP.newCustomer: \
        \no cumulative stick weight exceeds the uniform draw"

{-| Create a new restaurant with concentration parameter alpha.

The restaurant stores the stick weights produced by @stickBreaking alpha 0@,
i.e. the stick-breaking process started with no mass used up yet.
-}
newRestaurant :: Double -- ^ Concentration parameter, alpha
              -> Prob Restaurant
newRestaurant alpha = do
  sticks <- stickBreaking alpha 0
  return (R sticks)

{- | Stick breaking breaks the unit interval into an
    infinite number of parts (lazily).

    Each step draws a proportion @r@ from @beta 1 alpha@, takes that
    proportion of the stick that is still unbroken (@1 - lower@) as the next
    piece, and recurses with the piece added to @lower@. The recursion has no
    base case, so the result is only usable because the list is produced
    lazily.
-}
stickBreaking :: Double -- ^ Concentration parameter, alpha
              -> Double -- ^ Total length already broken off (start with 0)
              -> Prob [Double]
stickBreaking alpha lower = do
  r <- beta 1 alpha
  -- Length of this piece: fraction @r@ of what remains of the unit stick.
  let v = r * (1 - lower)
  vs <- stickBreaking alpha (lower + v)
  return (v : vs)

{-| [Dirichlet Process](https://en.wikipedia.org/wiki/Dirichlet_process) as a random distribution.

The outer computation draws a list of atoms from the base distribution (via
'iid') together with a list of stick weights (via @stickBreaking alpha 0@).
The returned distribution draws a threshold from 'uniform' and yields the
atom at the index of the first cumulative stick weight strictly greater than
that threshold, so repeated draws from the same returned distribution share
the same atoms and weights.
-}
dp :: Double -- ^ Concentration parameter, alpha
   -> Prob a -- ^ Base distribution
   -> Prob (Prob a)
dp alpha p = do
  xs <- iid p
  vs <- stickBreaking alpha 0
  -- The running sums depend only on @vs@, so they are built once here and
  -- shared (lazily) by every draw from the returned distribution.
  let cumulative = scanl1 (+) vs
  return $ do
    r <- uniform
    return $ case findIndex (> r) cumulative of
      Just i -> xs !! i
      -- Only reachable if the weight list is finite and its total never
      -- exceeds @r@; @stickBreaking@ recurses without a base case.
      Nothing ->
        error
          "LazyPPL.Distributions.DirichletP.dp: \
          \no cumulative stick weight exceeds the uniform draw"