{-# LANGUAGE GeneralizedNewtypeDeriving #-}
module LazyPPL.Distributions.IBP where
import LazyPPL
import LazyPPL.Distributions
import LazyPPL.Distributions.Counter
import LazyPPL.Distributions.Memoization
import Data.List
-- | An Indian Buffet Process restaurant, as built by 'newRestaurant'.
--
-- It pairs the feature matrix generated by 'ibp' (row @i@ records, per dish,
-- whether customer @i@ took it) with a 'Counter' that 'newCustomer' reads to
-- choose which row the next customer gets.
newtype Restaurant
  = R ([[Bool]], Counter)
-- | A dish served by a 'Restaurant': the 0-based column index of a feature
-- in the IBP matrix.
--
-- 'MonadMemo' is derived from the underlying 'Int' instance via
-- @GeneralizedNewtypeDeriving@ (enabled at the top of this module).
newtype Dish
  = D Int
  deriving (Show, Eq, Ord, MonadMemo Prob)
-- | Seat the next customer at a restaurant and return the dishes they take.
--
-- The customer's row number comes from the restaurant's counter via
-- 'readAndIncrement' (presumably also advancing it, per its name — confirm in
-- "LazyPPL.Distributions.Counter").  The dishes are the 0-based column
-- indices of the @True@ entries in that row of the feature matrix, in
-- increasing order.
newCustomer :: Restaurant -> Prob [Dish]
newCustomer (R (rows, ref)) = do
  i <- readAndIncrement ref
  -- Look the row up once and scan it in a single pass; indexing it per
  -- column with (!!) would make the scan quadratic in the row length.
  -- ('rows' is also named so it no longer shadows the top-level 'matrix'.)
  return (map D (findIndices id (rows !! i)))
-- | Create an Indian Buffet Process restaurant with parameter @alpha@.
--
-- The whole feature matrix is bound once here via 'ibp'.  A fresh counter is
-- attached so that 'newCustomer' can hand out successive rows of it.
newRestaurant :: Double -> Prob Restaurant
newRestaurant alpha = do
  -- The result of this draw is never used.  NOTE(review): it still takes part
  -- in the sequence of random draws, so deleting it may change the concrete
  -- samples produced later; confirm before removing.
  _ <- uniform
  counter <- newCounter
  rows <- ibp alpha
  return (R (rows, counter))
-- | Generate the rows of an Indian Buffet Process feature matrix, starting
-- with customer number @index@ (1-based, as called from 'ibp').
--
-- @features@ holds, for each dish created so far (in column order), how many
-- earlier customers have taken it.  Customer @index@:
--
--   * takes each existing dish with probability @count / index@;
--   * then creates @Poisson(alpha / index)@ brand-new dishes, all of which
--     they take.
--
-- The recursion has no base case, so the result is an unbounded list of rows.
-- This relies on 'Prob' producing the list lazily.
matrix :: Double -> Int -> [Int] -> Prob [[Bool]]
matrix alpha index features = do
  let i :: Double
      i = fromIntegral index
  -- One Bernoulli draw per existing dish, weighted by its popularity so far.
  existingDishes <- mapM (\m -> bernoulli (fromIntegral m / i)) features
  let bump count taken = if taken then count + 1 else count
      newFeatures = zipWith bump features existingDishes
  nNewDishes <- fmap fromIntegral $ poisson (alpha / i)
  -- 'features' is empty only until the first dish exists.  Whenever it is
  -- empty, at least one dish is created here, so in practice this only
  -- affects the first customer, guaranteeing them at least one dish.
  -- NOTE(review): this departs from the textbook IBP; presumably intentional.
  let fixZero = if null features && nNewDishes == 0 then 1 else nNewDishes
      newRow = existingDishes ++ replicate fixZero True
  -- Each new dish has been taken once, by this customer.
  rest <- matrix alpha (index + 1) (newFeatures ++ replicate fixZero 1)
  return (newRow : rest)
-- | Sample an Indian Buffet Process binary feature matrix with parameter
-- @alpha@.
--
-- The result is an unbounded list of rows, one per customer; row @i@ marks
-- which dishes (columns) that customer took.  Generation starts at customer 1,
-- before any dish has been created.
ibp :: Double -> Prob [[Bool]]
ibp alpha = matrix alpha firstCustomer noDishesYet
  where
    firstCustomer :: Int
    firstCustomer = 1

    noDishesYet :: [Int]
    noDishesYet = []
-- | A stick-breaking IBP restaurant: a list of dish probabilities, where entry
-- @k@ is the chance that a customer takes dish @k@.  When built by
-- 'newRestaurantS', the list is truncated, not infinite.
data RestaurantS
  = RS [Double]
-- | A dish served by a 'RestaurantS': the 0-based index of its probability in
-- the restaurant's weight list.
data DishS
  = DS Int
  deriving (Show, Eq, Ord)
-- | Seat a customer at a stick-breaking restaurant.
--
-- Every dish @k@ gets its own Bernoulli draw with probability equal to the
-- @k@-th weight.  The result lists the dishes whose draw came up @True@, in
-- increasing index order.
newCustomerS :: RestaurantS -> Prob [DishS]
newCustomerS (RS rs) = do
  taken <- mapM bernoulli rs
  return [DS k | (k, True) <- zip [0 ..] taken]
-- | Build a stick-breaking IBP restaurant with parameter @a@.
--
-- Dish weights are running products of Beta(a, 1) draws,
-- @w_k = v_1 * v_2 * ... * v_k@, so they never increase.  Generation stops
-- after the first weight below the hard-coded cutoff 0.01; that final small
-- weight is still included in the list.
newRestaurantS :: Double -> Prob RestaurantS
newRestaurantS a = RS <$> breakStick 1
  where
    -- Break off a fraction of the remaining stick, then recurse on the rest
    -- until the piece is small enough.
    breakStick :: Double -> Prob [Double]
    breakStick remaining = do
      v <- beta a 1
      let w = remaining * v
      ws <- if w < 0.01 then return [] else breakStick w
      return (w : ws)