module Data.Stochastic.Types where
import Control.Monad
import Control.Monad.Writer
import Data.Stochastic.Internal
import qualified Data.Sequence as S
import System.Random
-- | A probability distribution over values of type @a@.
--
-- Written as a GADT so that constructors which only make sense for one
-- result type ('Normal' produces 'Double', 'Bernoulli' produces 'Bool')
-- pin @a@ to that type, while the remaining constructors stay polymorphic.
data Distribution a where
  -- | Normal distribution, parameterised by its mean and standard deviation.
  Normal    :: Mean -> StDev -> Distribution Double
  -- | Coin flip; the 'Double' is the probability of yielding 'True'.
  Bernoulli :: Double -> Distribution Bool
  -- | Weighted choice over outcomes; the weights are expected to sum to 1.
  Discrete  :: [(a, Double)] -> Distribution a
  -- | Equally likely choice among the listed outcomes.
  Uniform   :: [a] -> Distribution a
  -- | Point mass on a single value.
  Certain   :: a -> Distribution a
-- | Distribution-like type constructors that can be sampled by threading
-- a random generator.
class Sampleable d where
  -- | Draw one value from the distribution, returning it together with the
  -- generator to use for subsequent draws.
  sampleFrom :: (RandomGen g) => d a -> g -> (a, g)
  -- | Wrap a plain value as a distribution (intended to always yield that
  -- value).
  certainDist :: a -> d a
-- | Sampling for 'Distribution'.
--
-- Every branch draws from 'decentRandom' (from "Data.Stochastic.Internal").
-- NOTE(review): it is assumed to return a uniform 'Double' in the unit
-- interval -- TODO confirm its exact range (open vs. closed ends).
instance Sampleable Distribution where
  certainDist = Certain

  sampleFrom da g =
    case da of
      Normal mean stdev ->
        -- Two uniforms are consumed, so the generator handed back must be
        -- the one produced by the *second* draw; returning g' would make
        -- the next sample reuse randomness already spent here.
        -- NOTE(review): boxMuller presumably maps two uniforms to a
        -- standard normal variate -- confirm in Internal.
        let (u1, g') = decentRandom g
            (u2, g'') = decentRandom g'
        in (stdev * boxMuller u1 u2 + mean, g'')
      Bernoulli prob ->
        -- NOTE(review): with '<=', a draw of exactly 0 yields True even for
        -- prob 0, if decentRandom can return 0 -- confirm intended.
        let (u, g') = decentRandom g
        in (u <= prob, g')
      Discrete l ->
        let (u, g') = decentRandom g
        in (pickDiscrete u l, g')
      Uniform l ->
        let (u, g') = decentRandom g
        in (pickUniform u l, g')
      Certain val ->
        -- The generator is still advanced by one draw, like the other
        -- branches.
        (val, snd (decentRandom g))
    where
      -- Choose an outcome of a weighted list. An empty list and a list
      -- whose weights run out before covering the draw are reported
      -- separately.
      pickDiscrete _ [] = error "empty discrete dist"
      pickDiscrete u xs = scanWeights u xs

      -- Subtract successive weights from the draw until it falls inside
      -- one; exhausting the list means the weights sum to less than the
      -- draw, i.e. the distribution is not normalised.
      scanWeights _ [] = error "not normalized discrete dist"
      scanWeights lim ((x, w) : rest)
        | lim <= w = x
        | otherwise = scanWeights (lim - w) rest

      -- Map the draw to an index into the list. The index is clamped to
      -- the last element because u / (1 / n) can round up to n (or u may
      -- be exactly 1), which would otherwise index one past the end.
      pickUniform _ [] = error "empty uniform dist"
      pickUniform u xs =
        let n = length xs
        in xs !! min (n - 1) (floor (u / (1 / fromIntegral n)))
-- | Renders a distribution as its constructor applied to its arguments.
--
-- Implemented via 'showsPrec', so precedence is respected the way a derived
-- instance would respect it:
--
-- * Arguments are shown at application precedence (11), so negative
--   numbers are parenthesised (@Certain (-1)@ rather than @Certain -1@).
-- * The whole expression is parenthesised when it is itself an argument
--   (e.g. @Just (Bernoulli 0.5)@).
--
-- For arguments whose 'show' output is already atomic (non-negative
-- numbers, lists, strings), top-level output is identical to the previous
-- @show@-only instance.
instance (Show a) => Show (Distribution a) where
  showsPrec d da = showParen (d > 10) $ case da of
    Normal mean stdev ->
      showString "Normal "
        . showsPrec 11 mean
        . showChar ' '
        . showsPrec 11 stdev
    Bernoulli prob -> showString "Bernoulli " . showsPrec 11 prob
    Discrete l -> showString "Discrete " . showsPrec 11 l
    Uniform l -> showString "Uniform " . showsPrec 11 l
    Certain val -> showString "Certain " . showsPrec 11 val
-- | A computation that, given a random generator, produces a distribution
-- of type @d a@ together with the updated generator.
--
-- The 'RandomGen' and 'Sampleable' constraints live inside the field's
-- type, so they must be satisfied wherever 'runSample' is applied.
newtype Sample g d a = Sample
  { runSample :: (RandomGen g, Sampleable d) => g -> (d a, g)
  }
-- | A stochastic process: a 'Double'-valued computation in the
-- @Sample StdGen Distribution@ monad that also logs a sequence of
-- 'Double's through 'WriterT'.
type StochProcess =
  WriterT (S.Seq Double) (Sample StdGen Distribution) Double
-- | 'fmap' is derived from the monad structure below.
instance (RandomGen g, Sampleable d) => Functor (Sample g d) where
  fmap = liftM

-- | 'pure' is defined here, the canonical location. 'return' is left to its
-- default (@return = pure@), which avoids GHC's
-- @-Wnoncanonical-monad-instances@ warning and removes the
-- @pure = return@ indirection.
instance (RandomGen g, Sampleable d) => Applicative (Sample g d) where
  -- Wrap the value in a 'certainDist'. The generator is still advanced
  -- once via 'next', exactly as the previous 'return' did.
  pure x = Sample $ \g -> (certainDist x, snd (next g))
  (<*>) = ap

-- | Binding runs the first computation to obtain a distribution, draws a
-- concrete value from it with 'sampleFrom', and feeds that value to the
-- continuation, threading the generator through each step.
instance (RandomGen g, Sampleable d) => Monad (Sample g d) where
  ma >>= f = Sample $ \g ->
    let (dist, g') = runSample ma g
        (a, g'') = sampleFrom dist g'
    in runSample (f a) g''
-- | Mean of a 'Normal' distribution.
type Mean = Double

-- | Standard deviation of a 'Normal' distribution.
type StDev = Double

-- | The concrete sampling monad: 'Distribution's driven by 'StdGen'.
--
-- Eta-reduced so that 'Sampler' can also be used unapplied (for example as
-- the base monad of a transformer). Every existing @Sampler T@ still
-- expands to @Sample StdGen Distribution T@.
type Sampler = Sample StdGen Distribution