{-# LANGUAGE CPP
, FlexibleInstances
, GADTs
, DataKinds
, TypeOperators
, KindSignatures
, LambdaCase
, ViewPatterns
, DeriveDataTypeable
, StandaloneDeriving
, OverlappingInstances
, UndecidableInstances
, RankNTypes
#-}
{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
module Language.Hakaru.Syntax.Transform
(
TransformImpl(..)
, Transform(..)
, transformName, allTransforms
, typeOfTransform
, TransformCtx(..), HasTransformCtx(..), unionCtx, minimalCtx
, TransformTable(..), lookupTransform', simpleTable
, unionTable, someTransformations
)
where
import Control.Applicative (Alternative(..), Applicative(..))
import Data.Char (isSpace)
import Data.Data (Data, Typeable)
import Data.List (stripPrefix)
import Data.Monoid (Monoid(..))
#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup
#endif

import Data.Number.Nat
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.SArgs
import Language.Hakaru.Syntax.Variable
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
-- | Selects which implementation of a transformation to use, for the
-- transformations that have more than one (currently only 'Disint').
-- Going by the constructor names, the choice is between a Maple-based and
-- a Haskell-based implementation.
data TransformImpl
  = InMaple
  | InHaskell
  deriving (Eq, Ord, Show, Read, Data, Typeable)
-- | The program transformations, indexed by their argument types and their
-- result type.
--
-- In @'Transform' args res@, each element of @args@ describes one argument
-- as a pair of the variables it binds and the argument's own type. 'LC' is
-- presumably shorthand for an argument that binds no variables; it is
-- defined outside this module. @res@ is the type of the transformed term.
data Transform :: [([Hakaru], Hakaru)] -> Hakaru -> * where
  -- | Arguments: a measure over @a@, and an @'HProb@-valued body that binds
  -- one variable of type @a@. Result: an @'HProb@.
  Expect ::
    Transform
      '[ LC ('HMeasure a), '( '[ a ], 'HProb) ] 'HProb
  -- | Arguments: a measure over @a@ and a value of type @a@.
  -- Result: a measure over @a@.
  Observe ::
    Transform
      '[ LC ('HMeasure a), LC a ] ('HMeasure a)
  -- | Arguments: a function @a -> measure a@ and a measure over @a@.
  -- Result: a function from @a@ to a measure over @(a, 'HProb)@ pairs.
  MH ::
    Transform
      '[ LC (a ':-> 'HMeasure a), LC ('HMeasure a) ]
      (a ':-> 'HMeasure (HPair a 'HProb))
  -- | Arguments: a function @a -> measure a@ and a measure over @a@.
  -- Result: a function @a -> measure a@.
  MCMC ::
    Transform
      '[ LC (a ':-> 'HMeasure a), LC ('HMeasure a) ]
      (a ':-> 'HMeasure a)
  -- | Argument: a measure over pairs @(a, b)@. Result: a function from @a@
  -- to a measure over @b@. The 'TransformImpl' field selects the
  -- implementation.
  Disint :: TransformImpl ->
    Transform
      '[ LC ('HMeasure (HPair a b)) ]
      -- Ticked like every other signature here: this is the promoted
      -- ':-> constructor (the type 'SFun' builds in 'typeOfTransform').
      (a ':-> 'HMeasure b)
  -- | Type-preserving transformation of a single term.
  Summarize ::
    Transform '[ LC a ] a
  -- | Type-preserving transformation of a single term.
  Simplify ::
    Transform '[ LC a ] a
  -- | Type-preserving transformation of a single term.
  Reparam ::
    Transform '[ LC a ] a

-- Structural equality and printing at a fixed index. Comparing transforms
-- whose indices differ goes through @Eq (Some2 Transform)@.
deriving instance Eq (Transform args a)
deriving instance Show (Transform args a)
-- Two existentially wrapped transforms are equal when they use the same
-- constructor. For 'Disint', the 'TransformImpl' must also match. The type
-- indices are not compared.
instance Eq (Some2 Transform) where
  Some2 x == Some2 y = sameTransform x y
    where
      -- Independently quantified indices let the two sides differ in type.
      sameTransform :: Transform xs r -> Transform ys s -> Bool
      sameTransform Expect     Expect     = True
      sameTransform Observe    Observe    = True
      sameTransform MH         MH         = True
      sameTransform MCMC       MCMC       = True
      sameTransform (Disint i) (Disint j) = i == j
      sameTransform Summarize  Summarize  = True
      sameTransform Simplify   Simplify   = True
      sameTransform Reparam    Reparam    = True
      sameTransform _          _          = False
-- Parses a transform written the way the derived 'Show' instance of
-- 'Transform' prints it (e.g. @Expect@, @Disint InMaple@). The candidate
-- spellings are the 'show' of each element of 'allTransforms'.
--
-- Like derived 'Read' instances, leading whitespace is skipped and the
-- transform may be wrapped in parentheses. Without this, @read "(Expect,
-- Observe)"@ fails on the space after the comma, and @(Disint InMaple)@,
-- which is what 'showsPrec' 11 produces, cannot be read back. The
-- precedence argument is otherwise ignored.
instance Read (Some2 Transform) where
  readsPrec _ = readParen False readBare
    where
      readBare :: ReadS (Some2 Transform)
      readBare s =
        let s' = dropWhile isSpace s
            candidates = map (\t'@(Some2 t) -> (show t, t')) allTransforms
            readMay (name, t')
              | Just rest <- stripPrefix name s' = [(t', rest)]
              | otherwise                        = []
        in concatMap readMay candidates
-- | The lower-case name of a transform. 'Disint' is named @"disint"@ for
-- 'InHaskell' and @"disint_m"@ for 'InMaple'.
transformName :: Transform args a -> String
transformName Expect        = "expect"
transformName Observe       = "observe"
transformName MH            = "mh"
transformName MCMC          = "mcmc"
transformName (Disint impl) = "disint" ++ implSuffix
  where
    -- Inspected only once the prefix has been consumed.
    implSuffix = case impl of
      InHaskell -> ""
      InMaple   -> "_m"
transformName Summarize     = "summarize"
transformName Simplify      = "simplify"
transformName Reparam       = "reparam"
-- | Every transform, in a fixed order, with one 'Disint' entry per
-- 'TransformImpl' (Haskell first, then Maple).
allTransforms :: [Some2 Transform]
allTransforms =
  [Some2 Expect, Some2 Observe, Some2 MH, Some2 MCMC]
    ++ [Some2 (Disint impl) | impl <- [InHaskell, InMaple]]
    ++ [Some2 Summarize, Some2 Simplify, Some2 Reparam]
-- | The result type of applying a transform to arguments whose types are
-- given as singletons. 'Expect' does not inspect its arguments. Every other
-- transform reads the type of its first argument.
typeOfTransform
  :: Transform as x
  -> SArgsSing as
  -> Sing x
typeOfTransform tr args =
  case tr of
    Expect ->
      SProb
    Observe ->
      case args of
        Pw _ e :* _ :* End -> e
    MH ->
      case args of
        Pw _ (fst . sUnFun -> a) :* _ :* End ->
          SFun a (SMeasure (sPair a SProb))
    MCMC ->
      -- The first argument's type (a kernel a -> measure a) is the result.
      case args of
        Pw _ a :* _ -> a
    Disint _ ->
      case args of
        Pw _ (sUnPair . sUnMeasure -> (a, b)) :* End ->
          SFun a (SMeasure b)
    Summarize ->
      case args of
        Pw _ e :* End -> e
    Simplify ->
      case args of
        Pw _ e :* End -> e
    Reparam ->
      case args of
        Pw _ e :* End -> e
-- | Context passed to a transformation implementation.
newtype TransformCtx
  = TransformCtx
      { nextFreeVar :: Nat
        -- ^ A variable identifier the transformation may presume unused.
        -- 'ctxOf' on a variable sets it to one past that variable's id.
      }
  deriving (Eq, Ord, Show)
-- | The context with 'nextFreeVar' set to 0.
minimalCtx :: TransformCtx
minimalCtx = TransformCtx 0
-- | Merges two contexts by keeping the larger 'nextFreeVar'. The result is
-- therefore at least as large as either input.
unionCtx :: TransformCtx -> TransformCtx -> TransformCtx
unionCtx (TransformCtx n0) (TransformCtx n1) = TransformCtx (max n0 n1)
-- Combining contexts keeps the larger 'nextFreeVar' (via 'unionCtx').
instance Semigroup TransformCtx where
  x <> y = unionCtx x y
-- 'minimalCtx' (nextFreeVar = 0) is the unit of 'unionCtx', assuming 'Nat'
-- values are never negative. On base < 4.11, 'mappend' has no default in
-- terms of '<>', so it is given explicitly.
instance Monoid TransformCtx where
  mempty = minimalCtx
#if !(MIN_VERSION_base(4,11,0))
  mappend = unionCtx
#endif
-- | Values from which a 'TransformCtx' can be derived.
class HasTransformCtx t where
  -- | The context implied by the value.
  ctxOf :: t -> TransformCtx
-- A variable's context sets 'nextFreeVar' to one past its own identifier.
instance HasTransformCtx (Variable (a :: Hakaru)) where
  ctxOf = TransformCtx . (+ 1) . varID
-- A term's context takes 'nextFreeVar' from the ABT's 'nextFree'.
instance ABT syn abt => HasTransformCtx (abt (xs :: [Hakaru]) (a :: Hakaru)) where
  ctxOf = TransformCtx . nextFree
-- | A partial table of transformation implementations.
newtype TransformTable abt m = TransformTable
  { lookupTransform
      :: forall xs r
       . Transform xs r
      -> Maybe (TransformCtx -> SArgs abt xs -> m (Maybe (abt '[] r)))
    -- ^ 'Nothing' when the table has no entry for the transform. Otherwise
    -- a function from a context and the arguments to a monadic result,
    -- which is itself 'Nothing' when the implementation declines.
  }
-- | Runs a transform through the table. A missing entry is reported the
-- same way as an implementation that declines: @pure Nothing@.
lookupTransform'
  :: (Applicative m)
  => TransformTable abt m
  -> Transform as b
  -> TransformCtx
  -> SArgs abt as -> m (Maybe (abt '[] b))
lookupTransform' tbl tr ctx args =
  maybe (pure Nothing) (\run -> run ctx args) (lookupTransform tbl tr)
-- | Builds a table from pure implementations. Each result is lifted into
-- @m@ with 'pure'.
simpleTable
  :: (Applicative m)
  => (forall as b . Transform as b
      -> Maybe (TransformCtx -> SArgs abt as -> Maybe (abt '[] b)))
  -> TransformTable abt m
simpleTable k = TransformTable (\tr -> fmap liftEntry (k tr))
  where
    -- Turn a pure two-argument implementation into an effectful one.
    liftEntry run ctx args = pure (run ctx args)
-- | Left-biased union. An entry in the first table shadows any entry for
-- the same transform in the second.
unionTable :: TransformTable abt m
           -> TransformTable abt m
           -> TransformTable abt m
unionTable (TransformTable primary) (TransformTable fallback) =
  TransformTable (\tr -> primary tr <|> fallback tr)
-- | Restricts a table to the listed transforms, compared with
-- @Eq (Some2 Transform)@. Every other transform has no entry.
someTransformations :: [Some2 Transform]
                    -> TransformTable abt m
                    -> TransformTable abt m
someTransformations toExpand tbl = TransformTable (\tr ->
  if Some2 tr `notElem` toExpand
    then Nothing
    else lookupTransform tbl tr)