module Language.Hakaru.Sample where
import Numeric.SpecFunctions (logGamma, logBeta, logFactorial)
import qualified Data.Number.LogFloat as LF
import qualified System.Random.MWC as MWC
import qualified System.Random.MWC.Distributions as MWCD
import qualified Data.Vector as V
import Data.Sequence (Seq)
import qualified Data.Foldable as F
import qualified Data.List.NonEmpty as L
import Data.List.NonEmpty (NonEmpty(..))
import Data.Maybe (fromMaybe)
#if __GLASGOW_HASKELL__ < 710
import Control.Applicative (Applicative(..), (<$>))
#endif
import Control.Monad.Identity
import Control.Monad.Trans.Maybe
import Control.Monad.State.Strict
import qualified Data.IntMap as IM
import Data.Number.Nat (fromNat, unsafeNat)
import Data.Number.Natural (fromNatural, fromNonNegativeRational)
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Coercion
import Language.Hakaru.Types.Sing
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.Value
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.ABT
-- | A binding of a variable to its value. The variable's type index is
-- existentially hidden so that bindings of different types can live in
-- the same 'Env'.
data EAssoc =
    forall a. EAssoc !(Variable a) !(Value a)

-- | An evaluation environment. Bindings are stored in an 'IM.IntMap'
-- keyed by the numeric ID of the bound variable.
newtype Env = Env (IM.IntMap EAssoc)

-- | The environment with no bindings.
emptyEnv :: Env
emptyEnv = Env IM.empty
-- | Insert a binding, keyed by the variable's ID. An existing binding
-- for the same ID is replaced.
updateEnv :: EAssoc -> Env -> Env
updateEnv assoc (Env bindings) =
    case assoc of
        EAssoc var _ -> Env (IM.insert (fromNat (varID var)) assoc bindings)
-- | Find the value bound to a variable.
--
-- Returns 'Nothing' if no binding has the variable's ID, or if the stored
-- variable does not match it according to 'varEq'.
lookupVar :: Variable a -> Env -> Maybe (Value a)
lookupVar x (Env env) =
    case IM.lookup (fromNat (varID x)) env of
        Nothing -> Nothing
        Just (EAssoc x' e') ->
            case varEq x x' of
                Nothing   -> Nothing
                Just Refl -> Just e'
-- | Draw a sample from the Poisson distribution with mean @lambda@.
--
-- For @lambda >= 10@ this uses Hörmann's transformed rejection with
-- squeeze (PTRS). Its constants are only valid in that regime: for small
-- means, @b - 2@ and @b - 3.4@ approach zero and the acceptance test breaks
-- down. Smaller means therefore use Knuth's multiplication method, which
-- needs about @lambda + 1@ uniform draws on average.
poisson_rng :: Double -> MWC.GenIO -> IO Int
poisson_rng lambda g'
    | lambda < 10 = poisson_small g'
    | otherwise   = make_poisson g'
  where
    -- PTRS set-up constants (Hörmann 1993).
    smu   = sqrt lambda
    b     = 0.931 + 2.53 * smu
    a     = -0.059 + 0.02483 * b
    vr    = 0.9277 - 3.6224 / (b - 2)
    arep  = 1.1239 + 1.1368 / (b - 3.4)
    lnlam = log lambda

    -- One PTRS trial; retries until a candidate is accepted.
    make_poisson :: MWC.GenIO -> IO Int
    make_poisson g = do
        u <- MWC.uniformR (-0.5, 0.5) g
        v <- MWC.uniformR (0, 1) g
        let us = 0.5 - abs u
            k  = floor $ (2 * a / us + b) * u + lambda + 0.43
        case () of
            () | us >= 0.07 && v <= vr -> return k          -- fast accept
               | k < 0                 -> make_poisson g
               | us <= 0.013 && v > us -> make_poisson g
               | accept_region us v k  -> return k
               | otherwise             -> make_poisson g

    -- Exact acceptance test: compares against the log Poisson pmf at k.
    accept_region :: Double -> Double -> Int -> Bool
    accept_region us v k =
        log (v * arep / (a / (us * us) + b))
            <= negate lambda + fromIntegral k * lnlam - logFactorial k

    -- Knuth: count how many uniforms can be multiplied together before the
    -- product drops to exp(-lambda) or below.
    poisson_small :: MWC.GenIO -> IO Int
    poisson_small g = go 0 1
      where
        limit = exp (negate lambda)
        go :: Int -> Double -> IO Int
        go k prod = do
            u <- MWC.uniformR (0, 1) g
            let prod' = prod * u
            if prod' > limit then go (k + 1) prod' else return k
-- | Rescale a list of probability weights.
--
-- Returns @(m, y, ys)@: @m@ is the largest weight, @ys@ is each weight
-- divided by @m@ (a plain 'Double' in @[0, 1]@), and @y@ is the sum of
-- @ys@. The total weight is therefore @m * y@.
--
-- If every weight is zero (this includes a single zero weight), the
-- result is @(0, 0, zeros)@. Dividing by @m@ would be a LogFloat @0/0@, so
-- the zero total is reported instead and callers can detect it with
-- @y > 0@.
normalize :: [Value 'HProb] -> (LF.LogFloat, Double, [Double])
normalize [] = (0, 0, [])
normalize xs
    | m == 0    = (0, 0, map (const 0) xs')
    | otherwise =
        case xs' of
            [x] -> (x, 1, [1])
            _   -> (m, sum ys, ys)
  where
    xs' = map (\(VProb x) -> x) xs
    m   = maximum xs'
    ys  = [ LF.fromLogFloat (x / m) | x <- xs' ]
-- | Vector version of 'normalize'.
--
-- Returns @(m, y, ys)@: @m@ is the largest weight, @ys@ holds the weights
-- divided by @m@, and @y@ is their sum. An all-zero input (including a
-- single zero weight) yields @(0, 0, zeros)@ instead of dividing
-- @0 / 0@, so that callers' @y > 0@ checks can reject it.
normalizeVector
    :: Value ('HArray 'HProb) -> (LF.LogFloat, Double, V.Vector Double)
normalizeVector (VArray xs) =
    let xs' = V.map (\(VProb x) -> x) xs
        -- Only forced for non-empty input (see the first alternative).
        m   = V.maximum xs'
        ys  = V.map (\x -> LF.fromLogFloat (x / m)) xs'
    in case V.length xs' of
        0          -> (0, 0, V.empty)
        _ | m == 0 -> (0, 0, V.map (const 0) xs')
        1          -> (m, 1, V.singleton 1)
        _          -> (m, V.sum ys, ys)
-- | Evaluate a closed program, starting from the empty environment.
runEvaluate
    :: (ABT Term abt)
    => abt '[] a
    -> Value a
runEvaluate = (`evaluate` emptyEnv)
-- | Evaluate an expression in an environment. Variables are looked up with
-- 'evaluateVar'; every other node is handled by 'evaluateTerm'.
evaluate
    :: (ABT Term abt)
    => abt '[] a
    -> Env
    -> Value a
evaluate e env =
    caseVarSyn e
        (\v -> evaluateVar env v)
        (\t -> evaluateTerm t env)
-- | Look up a variable's value. Calls 'error' if the variable is unbound.
evaluateVar :: Env -> Variable a -> Value a
evaluateVar env v =
    fromMaybe (error "variable not found!") (lookupVar v env)
-- | Evaluate one syntactic node by dispatching on its constructor.
-- 'Reject_' becomes a measure whose sampler always returns 'Nothing'.
evaluateTerm
    :: (ABT Term abt)
    => Term abt a
    -> Env
    -> Value a
evaluateTerm (o :$ es)       env = evaluateSCon o es env
evaluateTerm (NaryOp_ o es)  env = evaluateNaryOp o es env
evaluateTerm (Literal_ v)    _   = evaluateLiteral v
evaluateTerm (Empty_ _)      _   = evaluateEmpty
evaluateTerm (Array_ n es)   env = evaluateArray n es env
evaluateTerm (Datum_ d)      env = evaluateDatum d env
evaluateTerm (Case_ o es)    env = evaluateCase o es env
evaluateTerm (Superpose_ es) env = evaluateSuperpose es env
evaluateTerm (Reject_ _)     _   = VMeasure $ \_ _ -> return Nothing
-- | Evaluate a special constructor ('SCon') applied to its arguments.
--
-- * Binding forms ('Lam_', 'Let_', 'Plate', 'Chain', 'Summate', 'Product')
--   open their body with 'caseBind' and evaluate it with the bound
--   variable added to the environment.
-- * Measure forms build a 'VMeasure' sampler. The sampler takes the
--   incoming weight and a generator, and returns 'Nothing' on rejection or
--   a sample paired with its updated weight.
-- * Primitive, array and measure operators are delegated to
--   'evaluatePrimOp', 'evaluateArrayOp' and 'evaluateMeasureOp'.
--
-- Constructors not handled here fail with a @TODO@ error.
evaluateSCon
    :: (ABT Term abt)
    => SCon args a
    -> SArgs abt args
    -> Env
    -> Value a
evaluateSCon Lam_ (e1 :* End) env =
    caseBind e1 $ \x e1' ->
        VLam $ \v -> evaluate e1' (updateEnv (EAssoc x v) env)
evaluateSCon App_ (e1 :* e2 :* End) env =
    case evaluate e1 env of
        VLam f -> f (evaluate e2 env)
        v      -> case v of {}
evaluateSCon Let_ (e1 :* e2 :* End) env =
    let v = evaluate e1 env
    in  caseBind e2 $ \x e2' ->
            evaluate e2' (updateEnv (EAssoc x v) env)
evaluateSCon (CoerceTo_ c) (e1 :* End) env =
    coerceTo c $ evaluate e1 env
evaluateSCon (UnsafeFrom_ c) (e1 :* End) env =
    coerceFrom c $ evaluate e1 env
evaluateSCon (PrimOp_ o)    es env = evaluatePrimOp    o es env
evaluateSCon (ArrayOp_ o)   es env = evaluateArrayOp   o es env
evaluateSCon (MeasureOp_ m) es env = evaluateMeasureOp m es env
evaluateSCon Dirac (e1 :* End) env =
    VMeasure $ \p _ -> return $ Just (evaluate e1 env, p)
evaluateSCon MBind (e1 :* e2 :* End) env =
    case evaluate e1 env of
        VMeasure m1 -> VMeasure $ \p g -> do
            x <- m1 p g
            case x of
                Nothing -> return Nothing
                Just (a, p') ->
                    -- Continue with the second measure, with the sample
                    -- bound and the weight produced by the first one.
                    caseBind e2 $ \x' e2' ->
                        case evaluate e2' (updateEnv (EAssoc x' a) env) of
                            VMeasure y -> y p' g
                            v -> case v of {}
        v -> case v of {}
evaluateSCon Plate (n :* e2 :* End) env =
    case evaluate n env of
        VNat n' -> caseBind e2 $ \x e' ->
            VMeasure $ \(VProb p) g -> runMaybeT $ do
                -- Sample each element independently. If any element
                -- rejects, MaybeT makes the whole plate reject.
                (v', ps) <- fmap V.unzip . V.mapM (performMaybe g) $
                    V.generate (fromNat n') $ \i ->
                        evaluate e' $
                            updateEnv (EAssoc x . VNat $ unsafeNat i) env
                return
                    ( VArray v'
                    , VProb $ p * V.product (V.map (\(VProb w) -> w) ps)
                    )
        v -> case v of {}
  where
    -- Run a sampler at unit weight; the element weights are multiplied
    -- into the plate's total weight above.
    performMaybe
        :: MWC.GenIO
        -> Value ('HMeasure a)
        -> MaybeT IO (Value a, Value 'HProb)
    performMaybe g (VMeasure m) = MaybeT $ m (VProb 1) g
evaluateSCon Chain (n :* s :* e :* End) env =
    case (evaluate n env, evaluate s env) of
        (VNat n', start) ->
            caseBind e $ \x e' ->
                let s' = VLam $ \v -> evaluate e' (updateEnv (EAssoc x v) env)
                in  VMeasure $ \(VProb p) g -> runMaybeT $ do
                        -- Thread the chain state through n' steps,
                        -- collecting each step's output and weight.
                        (evaluates, sout) <-
                            runStateT (replicateM (fromNat n') $ convert g s') start
                        let (v', ps) = unzip evaluates
                            bodyType :: Sing ('HMeasure (HPair a b)) -> Sing ('HArray a)
                            bodyType = SArray . fst . sUnPair . sUnMeasure
                        return
                            ( VDatum $ dPair_ (bodyType $ caseBind e (const typeOf)) (typeOf s)
                                (VArray . V.fromList $ v') sout
                            , VProb $ p * product (map (\(VProb w) -> w) ps)
                            )
        v -> case v of {}
  where
    -- One chain step as a state transition. The state is the current chain
    -- value; the step returns its output and weight.
    convert
        :: MWC.GenIO
        -> Value (s ':-> 'HMeasure (HPair a s))
        -> StateT (Value s) (MaybeT IO) (Value a, Value 'HProb)
    convert g (VLam f) = StateT $ \s' ->
        case f s' of
            VMeasure f' -> do
                (as'', p') <- MaybeT (f' (VProb 1) g)
                let (a, s'') = unPair as''
                return ((a, p'), s'')
            v -> case v of {}

    unPair :: Value (HPair a b) -> (Value a, Value b)
    unPair (VDatum (Datum "pair" _typ
            (Inl (Et (Konst a)
                (Et (Konst b) Done))))) = (a, b)
    unPair x = case x of {}
evaluateSCon (Summate hd hs) (e1 :* e2 :* e3 :* End) env =
    -- Strict left fold: avoids building a chain of thunks over long ranges.
    caseBind e3 $ \x e3' ->
        F.foldl'
            (\acc i -> evalOp (Sum hs) acc $
                evaluate e3' (updateEnv (EAssoc x i) env))
            (identityElement $ Sum hs)
            (enumFromUntilValue hd (evaluate e1 env) (evaluate e2 env))
evaluateSCon (Product hd hs) (e1 :* e2 :* e3 :* End) env =
    caseBind e3 $ \x e3' ->
        F.foldl'
            (\acc i -> evalOp (Prod hs) acc $
                evaluate e3' (updateEnv (EAssoc x i) env))
            (identityElement $ Prod hs)
            (enumFromUntilValue hd (evaluate e1 env) (evaluate e2 env))
evaluateSCon s _ _ = error $ "TODO: evaluateSCon{" ++ show s ++ "}"
-- | Evaluate a primitive operator applied to its arguments.
--
-- Comparisons ('Equal', 'Less') return Hakaru booleans ('dTrue'/'dFalse')
-- for Nat, Int, Prob and Real operands. Operators or operand types
-- without a case here fail with a @TODO@ error naming the operator.
evaluatePrimOp
    :: ( ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => PrimOp typs a
    -> SArgs abt args
    -> Env
    -> Value a
evaluatePrimOp Not (e1 :* End) env =
    case evaluate e1 env of
        VDatum a -> VDatum (if a == dTrue then dFalse else dTrue)
        v -> case v of {}
evaluatePrimOp Pi End _ = VProb . LF.logFloat $ pi
evaluatePrimOp Cos (e1 :* End) env =
    case evaluate e1 env of
        VReal v1 -> VReal . cos $ v1
        v -> case v of {}
evaluatePrimOp RealPow (e1 :* e2 :* End) env =
    case (evaluate e1 env, evaluate e2 env) of
        (VProb v1, VReal v2) -> VProb $ LF.pow v1 v2
        v -> case v of {}
evaluatePrimOp Exp (e1 :* End) env =
    case evaluate e1 env of
        VReal v1 -> VProb . LF.logToLogFloat $ v1
        v -> case v of {}
evaluatePrimOp (Infinity h) End _ =
    case h of
        HIntegrable_Nat  -> error "Can not evaluate infinity for natural numbers"
        HIntegrable_Prob -> VProb $ LF.logFloat LF.infinity
evaluatePrimOp (Equal _) (e1 :* e2 :* End) env =
    case (evaluate e1 env, evaluate e2 env) of
        (VNat  v1, VNat  v2) -> VDatum $ if v1 == v2 then dTrue else dFalse
        (VInt  v1, VInt  v2) -> VDatum $ if v1 == v2 then dTrue else dFalse
        (VProb v1, VProb v2) -> VDatum $ if v1 == v2 then dTrue else dFalse
        (VReal v1, VReal v2) -> VDatum $ if v1 == v2 then dTrue else dFalse
        _ -> error "TODO: evaluatePrimOp{Equal}"
evaluatePrimOp (Less _) (e1 :* e2 :* End) env =
    case (evaluate e1 env, evaluate e2 env) of
        (VNat  v1, VNat  v2) -> VDatum $ if v1 < v2 then dTrue else dFalse
        -- Int was missing here although 'Equal' supports it.
        (VInt  v1, VInt  v2) -> VDatum $ if v1 < v2 then dTrue else dFalse
        (VProb v1, VProb v2) -> VDatum $ if v1 < v2 then dTrue else dFalse
        (VReal v1, VReal v2) -> VDatum $ if v1 < v2 then dTrue else dFalse
        _ -> error "TODO: evaluatePrimOp{Less}"
evaluatePrimOp (NatPow _) (e1 :* e2 :* End) env =
    case evaluate e2 env of
        VNat v2 ->
            let v2' = fromNat v2 in
            case evaluate e1 env of
                VNat  v1 -> VNat  (v1 ^ v2')
                VInt  v1 -> VInt  (v1 ^ v2')
                VProb v1 -> VProb (v1 ^ v2')
                VReal v1 -> VReal (v1 ^ v2')
        v2 -> case v2 of {}
evaluatePrimOp (Negate _) (e1 :* End) env =
    case evaluate e1 env of
        VInt  v -> VInt  (negate v)
        VReal v -> VReal (negate v)
        v -> case v of {}
evaluatePrimOp (Abs _) (e1 :* End) env =
    case evaluate e1 env of
        VInt  v -> VNat . unsafeNat $ abs v
        VReal v -> VProb . LF.logFloat $ abs v
        v -> case v of {}
evaluatePrimOp (Recip _) (e1 :* End) env =
    case evaluate e1 env of
        VProb v -> VProb (recip v)
        VReal v -> VReal (recip v)
        v -> case v of {}
evaluatePrimOp (NatRoot _) (e1 :* e2 :* End) env =
    case (evaluate e1 env, evaluate e2 env) of
        (VProb v1, VNat v2) -> VProb $ LF.pow v1 (recip . fromIntegral $ v2)
        v -> case v of {}
evaluatePrimOp prim _ _ =
    error ("TODO: evaluatePrimOp{" ++ show prim ++ "}")
-- | Evaluate an array operator.
--
-- * 'Index': the element at a Nat position.
-- * 'Size': the array length, as a Nat.
-- * 'Reduce': a strict left fold of a two-argument Hakaru function over
--   the array, starting from the given initial value.
evaluateArrayOp
    :: ( ABT Term abt
       , typs ~ UnLCs args
       , args ~ LCs typs)
    => ArrayOp typs a
    -> SArgs abt args
    -> Env
    -> Value a
evaluateArrayOp (Index _) (e1 :* e2 :* End) env =
    case (evaluate e1 env, evaluate e2 env) of
        (VArray arr, VNat i) -> arr V.! fromNat i
        _ -> error "evaluateArrayOp: the impossible happened"
evaluateArrayOp (Size _) (e1 :* End) env =
    case evaluate e1 env of
        VArray arr -> VNat (unsafeNat (V.length arr))
        _ -> error "evaluateArrayOp: the impossible happened"
evaluateArrayOp (Reduce _) (e1 :* e2 :* e3 :* End) env =
    case (evaluate e1 env, evaluate e2 env, evaluate e3 env) of
        (f, z, VArray arr) -> V.foldl' (lam2 f) z arr
        _ -> error "evaluateArrayOp: the impossible happened"
-- | Build the sampler for a primitive measure.
--
-- Each sampler takes the incoming weight and a generator and returns the
-- sample together with its weight. Normalized distributions pass the
-- weight through unchanged. 'Lebesgue' and 'Counting' are sampled from a
-- proposal distribution, and the weight is multiplied by the inverse of
-- the proposal's density.
evaluateMeasureOp
    :: ( ABT Term abt
       , typs ~ UnLCs args
       , args ~ LCs typs)
    => MeasureOp typs a
    -> SArgs abt args
    -> Env
    -> Value ('HMeasure a)
evaluateMeasureOp Lebesgue = \End _ ->
    VMeasure $ \(VProb p) g -> do
        -- Proposal: magnitude n = -log u ~ Exp(1), with a fair random sign.
        -- Its density is e^(-|x|)/2, so the weight is 2 * e^n.
        (u, b) <- MWC.uniform g
        let l = log u
            n = negate l
        return $ Just
            ( VReal $ if b then n else l
            , VProb $ p * 2 * LF.logToLogFloat n
            )
evaluateMeasureOp Counting = \End _ ->
    VMeasure $ \(VProb p) g -> do
        -- Proposal: geometric magnitude with success probability e^-3,
        -- plus a fair random sign. Negative results are shifted by one
        -- (-1, -2, ...) so that every integer has exactly one
        -- representation.
        let success = LF.logToLogFloat (-3 :: Double)
        let pow x y = LF.logToLogFloat (LF.logFromLogFloat x *
                                        (fromIntegral y :: Double))
        u <- MWCD.geometric0 (LF.fromLogFloat success) g
        b <- MWC.uniform g
        return $ Just
            ( VInt $ if b then negate (u + 1) else u
            , VProb $ p * 2 / pow (1 - success) u / success)
evaluateMeasureOp Categorical = \(e1 :* End) env ->
    -- Normalize the weights once for this measure, not on every draw.
    let (_, y, ys) = normalizeVector (evaluate e1 env)
    in VMeasure $ \p g ->
        if not (y > (0::Double))
            then error "Categorical needs positive weights"
            else do
                u <- MWC.uniformR (0, y) g
                return $ Just
                    ( VNat
                    . unsafeNat
                    . fromMaybe 0
                    . V.findIndex (u <=)
                    . V.scanl1' (+)
                    $ ys
                    , p)
evaluateMeasureOp Uniform = \(e1 :* e2 :* End) env ->
    case (evaluate e1 env, evaluate e2 env) of
        (VReal v1, VReal v2) -> VMeasure $ \p g -> do
            x <- MWC.uniformR (v1, v2) g
            return $ Just (VReal x, p)
        _ -> error "evaluateMeasureOp: the impossible happened"
evaluateMeasureOp Normal = \(e1 :* e2 :* End) env ->
    case (evaluate e1 env, evaluate e2 env) of
        (VReal v1, VProb v2) -> VMeasure $ \p g -> do
            x <- MWCD.normal v1 (LF.fromLogFloat v2) g
            return $ Just (VReal x, p)
        _ -> error "evaluateMeasureOp: the impossible happened"
evaluateMeasureOp Poisson = \(e1 :* End) env ->
    case evaluate e1 env of
        VProb v1 -> VMeasure $ \p g -> do
            x <- poisson_rng (LF.fromLogFloat v1) g
            return $ Just (VNat $ unsafeNat x, p)
        _ -> error "evaluateMeasureOp: the impossible happened"
evaluateMeasureOp Gamma = \(e1 :* e2 :* End) env ->
    case (evaluate e1 env, evaluate e2 env) of
        (VProb v1, VProb v2) -> VMeasure $ \p g -> do
            x <- MWCD.gamma (LF.fromLogFloat v1) (LF.fromLogFloat v2) g
            return $ Just (VProb $ LF.logFloat x, p)
        _ -> error "evaluateMeasureOp: the impossible happened"
evaluateMeasureOp Beta = \(e1 :* e2 :* End) env ->
    case (evaluate e1 env, evaluate e2 env) of
        (VProb v1, VProb v2) -> VMeasure $ \p g -> do
            x <- MWCD.beta (LF.fromLogFloat v1) (LF.fromLogFloat v2) g
            return $ Just (VProb $ LF.logFloat x, p)
        _ -> error "evaluateMeasureOp: the impossible happened"
-- | Evaluate an n-ary operator. Every operand is evaluated, and the results
-- are combined with a right fold that starts from the operator's identity.
evaluateNaryOp
    :: (ABT Term abt)
    => NaryOp a -> Seq (abt '[] a) -> Env -> Value a
evaluateNaryOp s es env = F.foldr combine unit (mapEvaluate es env)
  where
    combine = evalOp s
    unit    = identityElement s
-- | The identity element of an n-ary operator; it is the starting value of
-- the folds in 'evaluateNaryOp' and of 'Summate' and 'Product'.
--
-- Operator/type combinations without a case here fail with a @TODO@ error
-- naming the operator, matching 'evalOp', instead of an anonymous
-- pattern-match failure.
identityElement :: NaryOp a -> Value a
identityElement And                  = VDatum dTrue
identityElement (Sum  HSemiring_Nat)  = VNat  0
identityElement (Sum  HSemiring_Int)  = VInt  0
identityElement (Sum  HSemiring_Prob) = VProb 0
identityElement (Sum  HSemiring_Real) = VReal 0
identityElement (Prod HSemiring_Nat)  = VNat  1
identityElement (Prod HSemiring_Int)  = VInt  1
identityElement (Prod HSemiring_Prob) = VProb 1
identityElement (Prod HSemiring_Real) = VReal 1
identityElement (Max  HOrd_Prob)      = VProb 0
identityElement (Max  HOrd_Real)      = VReal LF.negativeInfinity
identityElement (Min  HOrd_Prob)      = VProb (LF.logFloat LF.infinity)
identityElement (Min  HOrd_Real)      = VReal LF.infinity
identityElement op =
    error ("TODO: identityElement{" ++ show op ++ "}")
-- | Apply an n-ary operator to two values: the binary step of the folds
-- that use it. Operator/type combinations without a case here fail with
-- a @TODO@ error naming the operator.
evalOp
    :: NaryOp a -> Value a -> Value a -> Value a
evalOp op =
    case op of
        And -> \(VDatum a) (VDatum b) ->
            VDatum $ if a == dTrue && b == dTrue then dTrue else dFalse
        Sum  HSemiring_Nat  -> \(VNat  a) (VNat  b) -> VNat  (a + b)
        Sum  HSemiring_Int  -> \(VInt  a) (VInt  b) -> VInt  (a + b)
        Sum  HSemiring_Prob -> \(VProb a) (VProb b) -> VProb (a + b)
        Sum  HSemiring_Real -> \(VReal a) (VReal b) -> VReal (a + b)
        Prod HSemiring_Nat  -> \(VNat  a) (VNat  b) -> VNat  (a * b)
        Prod HSemiring_Int  -> \(VInt  a) (VInt  b) -> VInt  (a * b)
        Prod HSemiring_Prob -> \(VProb a) (VProb b) -> VProb (a * b)
        Prod HSemiring_Real -> \(VReal a) (VReal b) -> VReal (a * b)
        Max  HOrd_Prob      -> \(VProb a) (VProb b) -> VProb (max a b)
        Max  HOrd_Real      -> \(VReal a) (VReal b) -> VReal (max a b)
        Min  HOrd_Prob      -> \(VProb a) (VProb b) -> VProb (min a b)
        Min  HOrd_Real      -> \(VReal a) (VReal b) -> VReal (min a b)
        _ -> \_ _ -> error ("TODO: evalOp{" ++ show op ++ "}")
-- | Evaluate every expression in a sequence in the same environment.
mapEvaluate
    :: (ABT Term abt)
    => Seq (abt '[] a) -> Env -> Seq (Value a)
mapEvaluate es env = (\e -> evaluate e env) <$> es
-- | Convert a literal into the corresponding runtime value.
evaluateLiteral :: Literal a -> Value a
evaluateLiteral lit =
    case lit of
        LNat  n -> VNat  (fromInteger (fromNatural n))
        LInt  n -> VInt  (fromInteger n)
        LProb n -> VProb (fromRational (fromNonNegativeRational n))
        LReal n -> VReal (fromRational n)
-- | The empty array value.
evaluateEmpty :: Value ('HArray a)
evaluateEmpty = VArray (V.fromList [])
-- | Build an array of the given length. Element @i@ is the body evaluated
-- with the bound index variable set to @i@.
evaluateArray
    :: (ABT Term abt)
    => (abt '[] 'HNat)
    -> (abt '[ 'HNat ] a)
    -> Env
    -> Value ('HArray a)
evaluateArray n e env =
    case evaluate n env of
        VNat size -> caseBind e $ \x body ->
            let element i =
                    evaluate body (updateEnv (EAssoc x (VNat (unsafeNat i))) env)
            in VArray (V.generate (fromNat size) element)
-- | Evaluate every component expression of a datum, keeping its shape.
evaluateDatum
    :: (ABT Term abt)
    => Datum (abt '[]) (HData' a)
    -> Env
    -> Value (HData' a)
evaluateDatum d env = VDatum $ fmap11 (\e -> evaluate e env) d
-- | Evaluate a @case@ expression. The scrutinee is evaluated and matched
-- against the branches with 'matchBranches'. The matching branch body is
-- evaluated with the pattern variables bound. Calls 'error' if no branch
-- matches.
evaluateCase
    :: forall abt a b
    .  (ABT Term abt)
    => abt '[] a
    -> [Branch a abt b]
    -> Env
    -> Value b
evaluateCase o es env =
    case runIdentity (matchBranches evaluateDatum' (evaluate o env) es) of
        Just (Matched rho body) ->
            evaluate body (extendFromMatch (fromAssocs rho) env)
        _ -> error "Missing cases in match expression"
  where
    -- Add the bindings left to right, so a later binding for the same
    -- variable replaces an earlier one.
    extendFromMatch :: [Assoc Value] -> Env -> Env
    extendFromMatch xvs env' =
        F.foldl' (\acc (Assoc x v) -> updateEnv (EAssoc x v) acc) env' xvs

    -- Scrutinee values are already evaluated, so just expose the datum.
    evaluateDatum' :: DatumEvaluator Value Identity
    evaluateDatum' (VDatum d) = return (Just d)
-- | Evaluate a superposition (weighted sum) of measures.
--
-- A single component samples its measure with the incoming weight scaled
-- by the component's weight. With several components, the sampler picks
-- one with probability proportional to its weight. It then samples that
-- component with the incoming weight multiplied by the total weight. If
-- the total weight is not positive, the sampler rejects (returns
-- 'Nothing').
evaluateSuperpose
    :: (ABT Term abt)
    => NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
    -> Env
    -> Value ('HMeasure a)
evaluateSuperpose ((q, m) :| []) env =
    case evaluate m env of
        VMeasure m' ->
            let VProb q' = evaluate q env
            in VMeasure (\(VProb p) g -> m' (VProb $ p * q') g)
evaluateSuperpose pms env =
    VMeasure $ \(VProb p) g ->
        if not (y > (0::Double)) then return Nothing else do
            u <- MWC.uniformR (0, y) g
            let chosen =
                    case [ meas | (c, meas) <- zip (scanl1 (+) ys) measures
                                , u <= c ] of
                        hit : _ -> hit
                        -- u can exceed every cumulative sum only by rounding,
                        -- so the nearest branch is the last one.
                        []      -> last measures
            case chosen of
                VMeasure m' -> m' (VProb $ p * x * LF.logFloat y) g
  where
    pms'       = L.toList pms
    -- Computed once per superposition, shared by every draw.
    (x, y, ys) = normalize (map (flip evaluate env . fst) pms')
    -- Non-empty because pms is a NonEmpty, so 'last' above is safe.
    measures   = map (flip evaluate env . snd) pms'