{-# LANGUAGE
MultiParamTypeClasses,
FlexibleInstances, FlexibleContexts,
CPP
#-}
module Data.Random.Distribution.Categorical
( Categorical
, categorical, categoricalT
, weightedCategorical, weightedCategoricalT
, fromList, toList, totalWeight, numEvents
, fromWeightedList, fromObservations
, mapCategoricalPs, normalizeCategoricalPs
, collectEvents, collectEventsBy
) where
import Data.Random.RVar
import Data.Random.Distribution
import Data.Random.Distribution.Uniform
import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Control.Applicative
import Data.Foldable (Foldable(foldMap))
import Data.STRef
import Data.Traversable (Traversable(traverse, sequenceA))
import Data.List
import Data.Function
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
-- | Construct a categorical random variable from a list of
-- @(weight, event)@ pairs.
--
-- The list is converted with 'fromList' (weights are accumulated but not
-- normalized) and the resulting distribution is sampled via 'rvar'.
categorical :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
categorical weightedEvents = rvar (fromList weightedEvents)
-- | Construct a categorical random variable, in an arbitrary 'RVarT'
-- transformer stack, from a list of @(weight, event)@ pairs.
--
-- The list is converted with 'fromList' (weights are accumulated but not
-- normalized) and the resulting distribution is sampled via 'rvarT'.
categoricalT :: (Num p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
categoricalT weightedEvents = rvarT (fromList weightedEvents)
-- | Construct a categorical random variable from a list of
-- @(weight, event)@ pairs, normalizing the weights first.
--
-- The list is converted with 'fromWeightedList' and the resulting
-- distribution is sampled via 'rvar'.
weightedCategorical :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVar a
weightedCategorical weightedEvents = rvar (fromWeightedList weightedEvents)
-- | Construct a categorical random variable, in an arbitrary 'RVarT'
-- transformer stack, from a list of @(weight, event)@ pairs, normalizing
-- the weights first.
--
-- The list is converted with 'fromWeightedList' and the resulting
-- distribution is sampled via 'rvarT'.
weightedCategoricalT :: (Fractional p, Eq p, Distribution (Categorical p) a) => [(p,a)] -> RVarT m a
weightedCategoricalT weightedEvents = rvarT (fromWeightedList weightedEvents)
{-# INLINE fromList #-}
-- | Build a 'Categorical' from a list of @(weight, event)@ pairs.
--
-- The weights are stored as running (cumulative) sums, in list order; the
-- events are kept as given.  No normalization is performed.
fromList :: (Num p) => [(p,a)] -> Categorical p a
fromList weightedEvents =
    Categorical . V.fromList $ scanl1 accumulate weightedEvents
  where
    -- Carry the running total forward; keep the event of the newer entry.
    accumulate (runningTotal, _) (weight, event) =
        (runningTotal + weight, event)
{-# INLINE toList #-}
-- | Convert a 'Categorical' back to a list of @(weight, event)@ pairs.
--
-- The stored cumulative weights are turned back into per-event weights by
-- subtracting each entry's cumulative weight from that of its successor.
-- The first entry is returned unchanged.
toList :: (Num p) => Categorical p a -> [(p,a)]
toList (Categorical ds) = V.foldr' unaccumulate [] ds
  where
    -- 'acc' is the already-processed tail; its head still carries a
    -- cumulative weight, which is replaced by the difference to 'entry'.
    unaccumulate entry acc = case acc of
        []                  -> [entry]
        (nextCum, e) : rest -> case entry of
            (cum, _) -> entry : (nextCum - cum, e) : rest
-- | The total weight of all events: the cumulative weight stored with the
-- last entry, or @0@ when the distribution has no entries.
totalWeight :: Num p => Categorical p a -> p
totalWeight (Categorical ds) =
    if V.null ds
        then 0
        else fst (V.last ds)
-- | The number of entries stored in the distribution (the length of the
-- underlying vector).
numEvents :: Categorical p a -> Int
numEvents (Categorical entries) = V.length entries
-- | Build a 'Categorical' from a list of @(weight, event)@ pairs and
-- normalize it: the list goes through 'fromList' and the result through
-- 'normalizeCategoricalPs'.
fromWeightedList :: (Fractional p, Eq p) => [(p,a)] -> Categorical p a
fromWeightedList weightedEvents =
    normalizeCategoricalPs (fromList weightedEvents)
-- | Build a normalized 'Categorical' from a list of observed events.
--
-- The observations are sorted and grouped into runs of equal events; each
-- run contributes one entry whose weight is the length of the run.  The
-- result is then passed to 'fromWeightedList'.
fromObservations :: (Fractional p, Eq p, Ord a) => [a] -> Categorical p a
fromObservations observations = fromWeightedList
    [ (genericLength run, representative)
    -- 'group' only produces non-empty runs, so this pattern always matches.
    | run@(representative : _) <- group (sort observations)
    ]
-- | A categorical distribution over events of type @a@ with weights of
-- type @p@.
--
-- The payload is a vector of @(p, a)@ pairs.  As constructed by 'fromList',
-- the first component of each pair is the cumulative weight up to and
-- including that entry, not the weight of the entry alone.
newtype Categorical p a = Categorical (V.Vector (p, a))
    deriving (Eq)
-- | Rendered as @fromList@ applied to the list produced by 'toList',
-- parenthesised when the surrounding precedence is above 10.
instance (Num p, Show p, Show a) => Show (Categorical p a) where
    showsPrec prec cat =
        showParen (prec > 10) $
            showString "fromList " . showsPrec 11 (toList cat)
-- | Parses the token @fromList@ followed by a list of @(weight, event)@
-- pairs, and rebuilds the distribution with 'fromList'.  Parentheses are
-- required when the surrounding precedence is above 10.
instance (Num p, Read p, Read a) => Read (Categorical p a) where
    readsPrec prec = readParen (prec > 10) parseBody
      where
        parseBody input =
            [ (fromList pairs, remainder)
            | ("fromList", afterKeyword) <- lex input
            , (pairs, remainder)         <- readsPrec 11 afterKeyword
            ]
-- | Sampling a 'Categorical'.
--
-- * An empty distribution cannot be sampled and results in 'fail'.
-- * A single-entry distribution returns its event without drawing a
--   random number.
-- * Otherwise a value @u@ is drawn uniformly between @0@ and the last
--   stored cumulative weight, and a binary search over the entries selects
--   an event by comparing @u@ against the stored cumulative weights.
instance (Fractional p, Ord p, Distribution Uniform p) => Distribution (Categorical p) a where
    rvarT (Categorical ds) = case V.length ds of
        0 -> fail "categorical distribution over empty set cannot be sampled"
        1 -> return (snd (V.head ds))
        n -> do
            u <- uniformT 0 (fst (V.last ds))
            let cumulativeAt k = fst (ds V.! k)
                eventAt k      = snd (ds V.! k)

                -- Binary search on the index range [lo, hi].  Whenever the
                -- midpoint's cumulative weight is at least 'u' the upper
                -- bound moves down to it; otherwise the lower bound moves
                -- up.  'max mid (lo + 1)' guarantees the lower bound
                -- advances even when the midpoint equals 'lo'.
                -- NOTE(review): assumes the stored cumulative weights are
                -- non-decreasing (true for non-negative input weights).
                search lo hi
                    | hi <= lo              = eventAt hi
                    | u <= cumulativeAt mid = search lo mid
                    | otherwise             = search (max mid (lo + 1)) hi
                  where
                    -- midpoint, rounded down
                    mid = (lo + hi) `div` 2

            -- A non-positive draw always selects the first entry; the
            -- chosen event is forced before being returned.
            return $! if u <= 0 then eventAt 0 else search 0 (n - 1)
-- | Maps a function over the events, leaving the stored weights untouched.
instance Functor (Categorical p) where
    fmap f (Categorical ds) = Categorical (V.map relabel ds)
      where
        -- Lazy pattern: the pair is only inspected when a component of the
        -- result is demanded, matching the laziness of 'second' on
        -- functions.
        relabel ~(weight, event) = (weight, f event)
-- | Folds over the events in storage order, ignoring the weights.
instance Foldable (Categorical p) where
    foldMap f (Categorical ds) = foldMap f events
      where
        events = map snd (V.toList ds)
-- | Traverses the events in storage order; each entry keeps the weight
-- component it was stored with.
instance Traversable (Categorical p) where
    traverse f (Categorical ds) = rebuild <$> traverse visit (V.toList ds)
      where
        rebuild              = Categorical . V.fromList
        visit (weight, event) = (,) weight <$> f event

    sequenceA (Categorical ds) = rebuild <$> traverse visit (V.toList ds)
      where
        rebuild               = Categorical . V.fromList
        visit (weight, action) = (,) weight <$> action
-- | Monad over categorical distributions.
--
-- * 'return' yields a single event with weight @1@.
-- * 'fail' yields the empty distribution, discarding the message.
-- * @xs '>>=' f@ pairs every entry @(p, x)@ of @xs@ with every entry
--   @(q, y)@ of @f x@, giving @y@ the weight @p * q@, and rebuilds the
--   result with 'fromList'.
--
-- NOTE(review): 'fail' is defined here as a 'Monad' method, which only
-- compiles where the 'Monad' class still has that method — confirm the
-- targeted compiler versions.
instance Fractional p => Monad (Categorical p) where
    return x = Categorical $ V.singleton (1, x)
    fail _   = Categorical V.empty
    xs >>= f = fromList
        [ (p * q, y)
        | (p, x) <- toList xs
        , (q, y) <- toList (f x)
        ]
-- | Applicative structure derived from the 'Monad' instance: 'pure' is
-- 'return', and '<*>' binds the function distribution first, then the
-- argument distribution.
instance Fractional p => Applicative (Categorical p) where
    pure x = return x
    fs <*> xs = do
        f <- fs
        x <- xs
        return (f x)
-- | Apply a function to the weight of every event.
--
-- The distribution is unpacked with 'toList' (so the function sees
-- per-event weights, not cumulative ones), the function is applied to each
-- weight, and the result is rebuilt with 'fromList'.
mapCategoricalPs :: (Num p, Num q) => (p -> q) -> Categorical p e -> Categorical q e
mapCategoricalPs f cat = fromList (map (first f) (toList cat))
-- | Rescale the weights of a 'Categorical' so that the total weight is @1@.
--
-- * If the total weight is @0@ (which includes the empty distribution),
--   the result is the empty distribution.
-- * An entry whose stored cumulative weight equals that of the most
--   recently kept entry (or @0@ if none has been kept yet) is dropped;
--   such an entry contributes no weight of its own.
-- * Every kept entry has its cumulative weight multiplied by the
--   reciprocal of the total weight, and the last kept entry's cumulative
--   weight is then overwritten with exactly @1@.
normalizeCategoricalPs :: (Fractional p, Eq p) => Categorical p e -> Categorical p e
normalizeCategoricalPs orig@(Categorical ds)
    | total == 0 = Categorical V.empty
    | otherwise  = runST $ do
        prevCum <- newSTRef 0      -- cumulative weight of the last kept entry
        dropped <- newSTRef 0      -- number of entries skipped so far
        out     <- V.thaw ds       -- mutable copy, compacted in place

        forM_ [0 .. len - 1] $ \i -> do
            let (cum, event) = ds V.! i
            prev <- readSTRef prevCum
            if cum == prev
                then modifySTRef' dropped (1 +)
                else do
                    -- Shift the entry left by the number of entries
                    -- dropped before it.
                    d <- readSTRef dropped
                    MV.write out (i - d) (cum * factor, event)
                    writeSTRef prevCum $! cum

        d <- readSTRef dropped
        let kept = len - d

        -- Pin the final cumulative weight to exactly 1.
        (_, finalEvent) <- MV.read out (kept - 1)
        MV.write out (kept - 1) (1, finalEvent)

        Categorical <$> V.unsafeFreeze (MV.unsafeSlice 0 kept out)
  where
    len    = V.length ds
    total  = totalWeight orig
    factor = recip total
#if __GLASGOW_HASKELL__ < 706
-- | Strict modification of an 'STRef': the function is applied to the
-- current contents and the result is forced to weak head normal form
-- before being written back.
--
-- Compiled only when @__GLASGOW_HASKELL__ < 706@ — presumably because
-- "Data.STRef" does not export this function on those compilers.
modifySTRef' :: STRef s a -> (a -> a) -> ST s ()
modifySTRef' ref f = do
    current <- readSTRef ref
    writeSTRef ref $! f current
#endif
-- | Merge all entries that carry equal events (according to 'compare')
-- into a single entry whose weight is the sum of the merged weights.
--
-- Implemented with 'collectEventsBy'; the event kept for each group is the
-- first one in the group.
collectEvents :: (Ord e, Num p, Ord p) => Categorical p e -> Categorical p e
collectEvents = collectEventsBy compare merge
  where
    -- 'collectEventsBy' obtains its groups from 'groupBy', so 'es' is
    -- non-empty here.
    merge grp = (sum ps, head es)
      where
        (ps, es) = unzip grp
-- | Merge entries with equivalent events using a caller-supplied ordering
-- and combining function.
--
-- The distribution is unpacked with 'toList', sorted by event using
-- @compareE@, and split into runs of adjacent entries whose events compare
-- as 'EQ'.  Each run is collapsed to a single @(weight, event)@ pair by
-- @combine@, and the result is rebuilt with 'fromList'.
collectEventsBy :: Num p => (e -> e -> Ordering) -> ([(p,e)] -> (p,e))-> Categorical p e -> Categorical p e
collectEventsBy compareE combine cat = fromList (map combine grouped)
  where
    sorted  = sortBy byEvent (toList cat)
    grouped = groupBy sameEvent sorted

    byEvent   x y = compareE (snd x) (snd y)
    sameEvent x y = compareE (snd x) (snd y) == EQ