{-# LANGUAGE GeneralizedNewtypeDeriving #-}

{-|
An implementation of the Indian buffet process by [Griffiths and Ghahramani](https://papers.nips.cc/paper_files/paper/2005/file/2ef35a8b78b572a47f56846acbeef5d3-Paper.pdf).

We are using abstract types to hide the implementation details, inspired by [Exchangeable Random Processes and Data Abstraction](https://www.cs.ox.ac.uk/people/hongseok.yang/paper/pps17a.pdf). 

Illustration: [Feature extraction example](https://lazyppl-team.github.io/AdditiveClusteringDemo.html). 
-} 

module LazyPPL.Distributions.IBP where

import LazyPPL
import LazyPPL.Distributions
import LazyPPL.Distributions.Counter
import LazyPPL.Distributions.Memoization

import Data.List


 
-- Some abstract types 
-- | An Indian buffet process restaurant: the customer-by-dish matrix (one
-- row of flags per customer) paired with the 'Counter' that 'newCustomer'
-- reads to decide which row belongs to the next customer.
newtype Restaurant = R ([[Bool]], Counter)  
-- | A dish (feature) of the buffet, identified by its column index in the
-- restaurant's matrix.
--
-- 'Eq', 'Ord' and 'Show' are derived as usual; the 'MonadMemo' 'Prob'
-- instance is lifted from the underlying 'Int' through
-- @GeneralizedNewtypeDeriving@.
-- NOTE(review): presumably this instance is what lets dishes be used as
-- memoisation keys -- confirm against "LazyPPL.Distributions.Memoization".
newtype Dish = D Int
  deriving (Eq, Ord, Show, MonadMemo Prob)

-- | Seat the next customer and report which dishes they take.
--
-- The restaurant's counter supplies the row index (through
-- 'readAndIncrement'); the result contains, in increasing column order, a
-- 'Dish' for every entry that is 'True' in that row.
newCustomer :: Restaurant -> Prob [Dish]
newCustomer (R (rows, ref)) = do
    -- NOTE(review): presumably yields the current count and bumps the
    -- counter, so successive customers read successive rows -- confirm.
    i <- readAndIncrement ref
    -- Fetch the row once and scan it in a single pass with 'findIndices'
    -- rather than re-indexing it with '!!' for every column.
    -- '!!' itself cannot run off the end for restaurants built by
    -- 'newRestaurant', because 'ibp' never stops producing rows (assuming
    -- the counter only yields non-negative indices).
    let row = rows !! i
    return (map D (findIndices id row))

            
-- | Open a fresh restaurant for an Indian buffet process with parameter
-- @alpha@: a new customer counter paired with a matrix drawn from 'ibp'.
newRestaurant :: Double -> Prob Restaurant
newRestaurant alpha = do
    -- NOTE(review): the value of this draw is never used.  It is kept so
    -- that the sequence of monadic actions (and hence how randomness is
    -- consumed) stays exactly as before -- confirm whether it can go.
    _ <- uniform
    ref <- newCounter
    rows <- ibp alpha
    return (R (rows, ref))


-- | Generate the rows of the buffet matrix for customer number @index@ and
-- every customer after it.
--
-- * @alpha@ is the IBP parameter; customer @index@ samples
--   @Poisson(alpha / index)@ brand-new dishes.
-- * @index@ is the 1-based number of the customer being seated.  It is used
--   as a divisor, so it must be at least 1 ('ibp' starts it at 1).
-- * @features@ holds one count per dish served so far.  A new dish enters
--   with count 1 and the count goes up each time a later customer takes it,
--   so it is the number of customers who have taken that dish.
--
-- The recursion has no base case: the result is an unbounded list of rows.
-- NOTE(review): this presumably relies on 'Prob' being lazy enough for
-- callers to consume only a prefix -- confirm against the 'Prob' monad.
matrix :: Double -> Int -> [Int] -> Prob [[Bool]]
matrix alpha index features = do
    let i = fromIntegral index :: Double
    -- An existing dish is taken with probability (its count) / index.
    existingDishes <- mapM (\m -> bernoulli (fromIntegral m / i)) features
    let newFeatures = zipWith bump features existingDishes
    nNewDishes <- fmap fromIntegral $ poisson (alpha / i)
    -- When no dish exists yet and none was sampled, force one new dish so
    -- that the row (and the feature list passed on) is never empty.
    let fixZero
          | null features && nNewDishes == 0 = 1
          | otherwise = nNewDishes
        newRow = existingDishes ++ replicate fixZero True
    rest <- matrix alpha (index + 1) (newFeatures ++ replicate fixZero 1)
    return (newRow : rest)
  where
    -- Increment a dish's count when the current customer took it.
    bump :: Int -> Bool -> Int
    bump count taken = if taken then count + 1 else count

-- | The distribution on matrices: an Indian buffet process with parameter
-- @alpha@, as a list of customer rows.  It starts 'matrix' at customer 1
-- with no dishes served yet.
ibp :: Double -> Prob [[Bool]]
ibp alpha = matrix alpha 1 []



{--
Another possible implementation of the Indian buffet process,
which uses a truncated stick-breaking construction.
It is only an approximation to the true IBP, but doesn't need IO.

See also 
Stick-breaking Construction for the Indian Buffet Process
Teh, Gorur, Ghahramani. AISTATS 2007.

A stochastic programming perspective on nonparametric Bayes
Daniel M. Roy, Vikash Mansinghka, Noah Goodman, and Joshua Tenenbaum
ICML Workshop on Nonparametric Bayesian, 2008. 
--}
-- | A restaurant for the truncated stick-breaking construction: one
-- probability per dish.  'newCustomerS' flips a 'bernoulli' coin with each
-- of these probabilities to decide whether a customer takes that dish.
data RestaurantS = RS [Double] 

-- | A dish of the stick-breaking restaurant, identified by its position in
-- the probability list of the 'RestaurantS' it came from.
data DishS = DS Int
  deriving (Eq, Ord, Show)

-- | Seat a customer in a stick-breaking restaurant.
--
-- Flips one 'bernoulli' coin per dish, using that dish's probability, and
-- returns a 'DishS' (tagged with its list position, counted from 0) for
-- every coin that came up 'True'.
newCustomerS :: RestaurantS -> Prob [DishS]
newCustomerS (RS rs) = do
    fs <- mapM bernoulli rs
    return (map DS (findIndices id fs))

-- | Build a stick-breaking restaurant with parameter @a@.
--
-- The dish probabilities are running products of @beta a 1@ draws,
-- starting from 1.  The list ends with the first product that falls below
-- 0.01 (that final value is still included), so it is always non-empty.
newRestaurantS :: Double -> Prob RestaurantS
newRestaurantS a = fmap RS $ stickScale 1
  where
    -- Extend the list of probabilities, given the previous probability @p@.
    stickScale :: Double -> Prob [Double]
    stickScale p = do
        r' <- beta a 1
        let r = p * r'
        -- Truncate when the probabilities are getting small
        rs <- if r < 0.01 then return [] else stickScale r
        return (r : rs)