{-# LANGUAGE
MultiParamTypeClasses,
FlexibleInstances, FlexibleContexts,
CPP
#-}
{-# OPTIONS_GHC -fno-warn-simplifiable-class-constraints #-}
module Data.Random.Distribution.Categorical
( Categorical
, categorical, categoricalT
, weightedCategorical, weightedCategoricalT
, fromList, toList, totalWeight, numEvents
, fromWeightedList, fromObservations
, mapCategoricalPs, normalizeCategoricalPs
, collectEvents, collectEventsBy
) where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Data.Foldable (Foldable(foldMap))
import Data.STRef
import Data.Traversable (Traversable(traverse, sequenceA))
import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | Random variable drawing one event from the categorical distribution
-- built by 'fromList' from the given @(weight, event)@ pairs.  The weights
-- are used as relative weights; they are not normalized first.
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical weighted = rvar (fromList weighted)
-- | Monad-transformer version of 'categorical': samples from the
-- distribution built by 'fromList' from the given @(weight, event)@ pairs.
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT weighted = rvarT (fromList weighted)
-- | Like 'categorical', but the distribution is built with
-- 'fromWeightedList'.  Its weights are normalized to sum to 1, and
-- zero-weight entries are dropped before sampling.
weightedCategorical :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
weightedCategorical weighted = rvar (fromWeightedList weighted)
-- | Monad-transformer version of 'weightedCategorical': samples from the
-- distribution built by 'fromWeightedList' (normalized weights).
weightedCategoricalT :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
weightedCategoricalT weighted = rvarT (fromWeightedList weighted)
{-# INLINE fromList #-}
-- | Build a categorical distribution from @(weight, event)@ pairs.
--
-- No normalization, merging or validation is done.  The pairs are stored in
-- the given order with running (cumulative) weights: the i-th stored weight
-- is the sum of the first i+1 input weights.
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList xs = Categorical (V.fromList runningTotals)
    where
        runningTotals = scanl1 accumulate xs
        -- carry the total so far forward, keeping the newer event
        accumulate (total, _) (w, e) = (total + w, e)
{-# INLINE toList #-}
-- | Recover the per-event @(weight, event)@ pairs, in stored order.
--
-- Each weight is obtained by subtracting the previous cumulative weight from
-- the current one, so for inexact number types the result may differ from
-- the list originally passed to 'fromList' by rounding error.
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical ds) = V.foldr' prepend [] ds
    where
        -- 'foldr'' works from the back.  The head of the accumulator still
        -- carries its cumulative weight, so it is differenced against the
        -- entry being prepended, which in turn keeps its own cumulative
        -- weight until the next step.
        prepend entry [] = [entry]
        prepend entry@(lo, _) ((hi, e) : rest) = entry : (hi - lo, e) : rest
-- | Sum of all event weights.  This is the last cumulative weight, or 0 for
-- an empty distribution.
totalWeight :: Num p => Categorical p a -> p
totalWeight (Categorical ds) =
    if V.null ds then 0 else fst (V.last ds)
-- | Number of stored entries.  Zero-weight and repeated events are counted
-- individually, because nothing here merges them (see 'collectEvents').
numEvents :: Categorical p a -> Int
numEvents (Categorical entries) = V.length entries
-- | Like 'fromList', but the weights are rescaled to sum to 1 by
-- 'normalizeCategoricalPs'.  Zero-weight entries are dropped, and an input
-- whose total weight is 0 yields the empty distribution.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList weighted = normalizeCategoricalPs (fromList weighted)
-- | Empirical distribution of a sample.  Each distinct observation (in
-- ascending order) is weighted by how often it occurs, and the weights are
-- then normalized by 'fromWeightedList'.
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations obs =
    fromWeightedList
        -- 'group' never yields an empty run, so the pattern always matches
        [ (genericLength run, x) | run@(x : _) <- group (sort obs) ]
-- | A categorical distribution over events of type @a@ with weights of
-- type @p@.
--
-- Entries are stored with /cumulative/ weights: the weight component of
-- entry i is the sum of the weights of entries 0..i (see 'fromList' and
-- 'toList').  The derived 'Eq' compares this stored representation.
-- Distributions that list the same events in a different order, or with
-- extra zero-weight or split entries, therefore compare unequal.
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving (Eq)
-- | Rendered as @fromList [(w, e), ...]@ using the per-event (not
-- cumulative) weights, so the output reads back via the 'Read' instance.
instance (Num p, Show p, Show a) => Show (Categorical p a) where
    showsPrec d cat =
        showParen (d > 10) $
            showString "fromList " . showsPrec 11 (toList cat)
-- | Parses the @fromList [(w, e), ...]@ form produced by 'Show' and rebuilds
-- the distribution with 'fromList'.
instance (Num p, Read p, Read a) => Read (Categorical p a) where
    readsPrec d = readParen (d > 10) parseBody
        where
            parseBody input =
                [ (fromList entries, rest)
                | ("fromList", afterKeyword) <- lex input
                , (entries, rest)            <- readsPrec 11 afterKeyword
                ]
-- | Sampling: draw @u@ uniformly between 0 and the total weight, then pick
-- the first entry whose cumulative weight is at least @u@, via binary search.
-- Sampling an empty distribution is an error.
instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds)
        | V.null ds  = error "categorical distribution over empty set cannot be sampled"
        | count == 1 = return (snd (V.head ds))
        | otherwise  = do
            u <- uniformT 0 (fst (V.last ds))
            -- force the chosen event before handing it back
            return $! pick u
        where
            count = V.length ds
            cumWeight k = fst (ds V.! k)
            eventAt k = snd (ds V.! k)

            -- NOTE(review): when u is exactly 0 the first entry is returned
            -- even if its own weight is 0.  Whether uniformT can yield exactly
            -- 0 depends on the Uniform instance for p; confirm.
            pick u
                | u <= 0    = eventAt 0
                | otherwise = search 0 (count - 1)
                where
                    -- Assumes cumulative weights are non-decreasing (i.e.
                    -- non-negative weights).  Narrows [lo, hi] to the first
                    -- index with u <= cumWeight, falling back to the last
                    -- index if there is none.
                    search lo hi
                        | hi <= lo           = eventAt hi
                        | u <= cumWeight mid = search lo mid
                        -- 'max' guarantees progress when mid == lo
                        | otherwise          = search (max mid (lo + 1)) hi
                        where
                            -- rounds down, so lo < hi implies mid < hi
                            mid = (lo + hi) `div` 2
-- | Applies a function to every event, leaving the weights untouched.
-- Events that become equal are not merged.
instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map (\ ~(w, e) -> (w, f e)) ds)
-- | Folds over the events in stored order; the weights are ignored.
instance Foldable (Categorical p) where
    foldMap f (Categorical ds) = foldMap f (map snd (V.toList ds))
-- | Traverses the events in stored order, keeping each entry's weight.  The
-- inner 'traverse' / 'sequenceA' act on the event component of each
-- @(weight, event)@ pair.
instance Traversable (Categorical p) where
    traverse f (Categorical ds) = Categorical <$> traverse (traverse f) ds
    sequenceA (Categorical ds) = Categorical <$> traverse sequenceA ds
-- | Monadic bind.  For every event @x@ (weight @p@) of the first distribution
-- and every event @y@ (weight @q@) of @f x@, the result contains @y@ with
-- weight @p * q@.  The result is neither normalized nor merged; use
-- 'collectEvents' to combine duplicate events.
--
-- 'return' is defined canonically as 'pure'.
instance Fractional p => Monad (Categorical p) where
    return = pure
#if __GLASGOW_HASKELL__ < 808
    -- Before GHC 8.8 'fail' was a 'Monad' method; a failure yields the
    -- empty distribution.
    fail _ = Categorical V.empty
#endif
    xs >>= f = fromList $ do
        (p, x) <- toList xs
        (q, y) <- toList (f x)
        return (p * q, y)

-- | 'pure' is the point-mass distribution: a single event with weight 1.
-- '<*>' is derived from the 'Monad' instance via 'ap'.
instance Fractional p => Applicative (Categorical p) where
    pure x = Categorical (V.singleton (1, x))
    (<*>) = ap
-- | Apply a function to each event's own (non-cumulative) weight, then
-- rebuild the cumulative representation with 'fromList'.
mapCategoricalPs :: (Num p, Num q) => (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f cat = fromList (map (\ ~(w, e) -> (f w, e)) (toList cat))
-- | Rescale the weights so they sum to 1, dropping zero-weight entries.
--
-- An entry is dropped when its cumulative weight equals that of the last
-- entry considered (starting from 0).  This means its own weight is zero.
-- If the total weight is 0 the result is the empty distribution.
-- Otherwise the last kept entry's cumulative weight is set to exactly 1,
-- presumably so rounding in @total * recip total@ cannot leave it off 1.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | ps == 0   = Categorical V.empty
    | otherwise = Categorical (pinLast (V.fromList (compact 0 (V.toList ds))))
    where
        ps = totalWeight orig
        scale = recip ps

        -- Walk the cumulative weights, skipping an entry whose weight equals
        -- the previous one and rescaling the rest.  Because the total is
        -- non-zero, at least one entry survives.
        compact _ [] = []
        compact prevP ((p, x) : rest)
            | p == prevP = compact prevP rest
            | otherwise  = (p * scale, x) : compact p rest

        -- overwrite the final cumulative weight with exactly 1
        pinLast v = v V.// [(V.length v - 1, (1, snd (V.last v)))]
-- | Merge entries whose events are equal (by 'compare'), summing their
-- weights.  The result lists the distinct events in ascending order.
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare merge
    where
        -- groups handed over by 'collectEventsBy' are never empty
        merge grp = (sum (map fst grp), snd (head grp))
-- | Generalized 'collectEvents'.  The per-event @(weight, event)@ pairs are
-- stably sorted by event using @compareE@, and adjacent pairs that
-- @compareE@ reports as 'EQ' are grouped.  Each non-empty group is reduced
-- to one pair with @combine@, and the distribution is rebuilt with
-- 'fromList'.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e))-> Categorical p e -> Categorical p e
collectEventsBy compareE combine cat =
    fromList (map combine grouped)
    where
        byEvent = compareE `on` snd
        sorted  = sortBy byEvent (toList cat)
        grouped = groupBy (\a b -> byEvent a b == EQ) sorted