{-# LANGUAGE CPP
           , DataKinds
           , FlexibleContexts
           , GADTs
           , GeneralizedNewtypeDeriving
           , MultiParamTypeClasses
           , RankNTypes
           , ScopedTypeVariables
           , TypeOperators
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2017.02.01
-- |
-- Module      :  Language.Hakaru.Syntax.Unroll
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Unrolls the first iteration of summation ('Summate') and product ('Product')
-- loops in Hakaru expressions, re-binding bound variables along the way.
--
----------------------------------------------------------------
module Language.Hakaru.Syntax.Unroll (renameInEnv, unroll) where

import           Control.Monad.Reader
import           Data.Maybe                     (fromMaybe)
import           Language.Hakaru.Syntax.ABT
import           Language.Hakaru.Syntax.AST
import           Language.Hakaru.Syntax.AST.Eq  (Varmap)
import           Language.Hakaru.Syntax.Prelude hiding ((>>=))
import           Language.Hakaru.Types.HClasses
import           Prelude                        hiding (product, (*), (+), (-),
                                                 (==), (>=))

#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative
#endif

-- | The monad the unrolling pass runs in: a read-only environment ('Varmap')
-- that maps each original binder to the variable it was re-bound to.
-- All class instances are inherited from the underlying 'Reader'.
newtype Unroll a = Unroll
  { runUnroll :: Reader Varmap a  -- ^ Unwrap to the underlying reader.
  }
  deriving ( Functor
           , Applicative
           , Monad
           , MonadFix
           , MonadReader Varmap
           )

-- | Re-create a binder for @source@: 'binderM' is asked for a new variable
-- with the same name hint and type, and that variable is handed to the
-- continuation @f@, which builds the body under the new binder.
rebind
  :: (ABT Term abt, MonadFix m)
  => Variable a
  -> (Variable a -> m (abt xs b))
  -> m (abt (a ': xs) b)
rebind source f =
  binderM (varHint source) (varType source) $ \var' ->
    -- 'binderM' passes the bound variable as a term; unwrap it back to a
    -- 'Variable'. It is expected to always be a variable, so the syntax
    -- branch should be unreachable — but if it is ever hit, say where and why
    -- instead of the previous uninformative "oops".
    let v = caseVarSyn var' id
              (\_ -> error "Language.Hakaru.Syntax.Unroll.rebind: binderM supplied a non-variable term")
    in f v

-- | Re-bind @source@ (see 'rebind') and run @action@ with the environment
-- extended by the association @source ↦ fresh@. Variable occurrences looked up
-- in the environment while building the body therefore resolve to the new
-- binder.
renameInEnv
  :: (ABT Term abt, MonadReader Varmap m, MonadFix m)
  => Variable a
  -> m (abt xs b)
  -> m (abt (a ': xs) b)
renameInEnv source action =
  rebind source (\fresh -> local (insertAssoc (Assoc source fresh)) action)

-- | Unroll the first iteration of every 'Summate' and 'Product' in a term.
-- The traversal starts from an empty renaming environment.
unroll :: forall abt xs a . (ABT Term abt) => abt xs a -> abt xs a
unroll = flip runReader emptyAssocs . runUnroll . unroll'

-- | Monadic catamorphism that performs the unrolling:
--
--   * variables are remapped through the environment;
--   * binders are re-bound with 'renameInEnv';
--   * each rebuilt term node is passed to 'unrollTerm'.
unroll' :: forall abt xs a . (ABT Term abt) => abt xs a -> Unroll (abt xs a)
unroll' = cataABTM remapVar renameInEnv (unrollTerm =<<)
  where
    -- Replace a variable by whatever its binder was re-bound to. Variables
    -- with no entry in the environment are kept unchanged.
    remapVar :: Variable b -> Unroll (abt '[] b)
    remapVar v = do
      env <- ask
      return (var (fromMaybe v (lookupAssoc v env)))

-- | Build a 'Let_' node: bind the right-hand side @e@ to the leading binder
-- of @b@ and use @b@ as the body.
mklet :: ABT Term abt => abt '[] b -> abt '[b] a -> abt '[] a
mklet e b = syn $ Let_ :$ e :* b :* End

-- | Smart constructors for 'Summate' / 'Product' nodes. Arguments are the
-- index's discrete type, the result semiring, the lower bound, the upper
-- bound, and a body that binds the index variable.
mksummate, mkproduct
  :: (ABT Term abt)
  => HDiscrete a
  -> HSemiring b
  -> abt '[] a
  -> abt '[] a
  -> abt '[a] b
  -> abt '[] b
mksummate disc semi lower upper f = syn $ Summate disc semi :$ lower :* upper :* f :* End
mkproduct disc semi lower upper f = syn $ Product disc semi :$ lower :* upper :* f :* End

-- | Process one term node whose children have already been traversed.
--
-- 'Summate' and 'Product' nodes go to 'unrollSummate' / 'unrollProduct'.
-- Every (index type, semiring) combination is matched explicitly. Each match
-- fixes the concrete Hakaru types, which is what allows the class constraints
-- of those helpers to be resolved; the only constraint available here is
-- @ABT Term abt@. Any other node is rebuilt unchanged with 'syn'.
unrollTerm
  :: (ABT Term abt)
  => Term abt a
  -> Unroll (abt '[] a)
unrollTerm (Summate disc semi :$ lo :* hi :* body :* End) =
  case disc of
    HDiscrete_Nat ->
      case semi of
        HSemiring_Nat  -> unrollSummate disc semi lo hi body
        HSemiring_Int  -> unrollSummate disc semi lo hi body
        HSemiring_Prob -> unrollSummate disc semi lo hi body
        HSemiring_Real -> unrollSummate disc semi lo hi body
    HDiscrete_Int ->
      case semi of
        HSemiring_Nat  -> unrollSummate disc semi lo hi body
        HSemiring_Int  -> unrollSummate disc semi lo hi body
        HSemiring_Prob -> unrollSummate disc semi lo hi body
        HSemiring_Real -> unrollSummate disc semi lo hi body
unrollTerm (Product disc semi :$ lo :* hi :* body :* End) =
  case disc of
    HDiscrete_Nat ->
      case semi of
        HSemiring_Nat  -> unrollProduct disc semi lo hi body
        HSemiring_Int  -> unrollProduct disc semi lo hi body
        HSemiring_Prob -> unrollProduct disc semi lo hi body
        HSemiring_Real -> unrollProduct disc semi lo hi body
    HDiscrete_Int ->
      case semi of
        HSemiring_Nat  -> unrollProduct disc semi lo hi body
        HSemiring_Int  -> unrollProduct disc semi lo hi body
        HSemiring_Prob -> unrollProduct disc semi lo hi body
        HSemiring_Real -> unrollProduct disc semi lo hi body
unrollTerm term = return (syn term)

-- | Pass @e@ to the continuation @f@ directly when it is already a variable or
-- a literal. Otherwise, bind it with 'letM' first. Be careful that the
-- provided variable has been remapped to its equivalent in the target term if
-- altering the binding structure of the program.
letM' :: (Functor m, MonadFix m, ABT Term abt)
      => abt '[] a
      -> (abt '[] a -> m (abt '[] b))
      -> m (abt '[] b)
letM' e f =
    case viewABT e of
      Var {}            -> useDirectly
      Syn (Literal_ {}) -> useDirectly
      _                 -> letM e f
  where
    -- Atomic expressions need no binding of their own.
    useDirectly = f e

-- | Peel the first iteration off a summation over the index range @[lo, hi)@:
--
-- > summate lo hi body
-- >   ==>  if lo >= hi then zero else body[lo] + summate (lo + one) hi body
--
-- @lo@ and @hi@ go through 'letM'' first. Unless they are already variables
-- or literals, they are bound via 'letM' rather than being copied into the
-- guard, the preamble and the residual loop.
--
-- NOTE(review): @body@ appears twice in the result (under 'mklet' and inside
-- the residual 'Summate'), so its binder is duplicated. Confirm that later
-- passes do not require binders to be globally unique.
unrollSummate
  :: (ABT Term abt, HSemiring_ a, HSemiring_ b, HOrd_ a)
  => HDiscrete a
  -> HSemiring b
  -> abt '[] a
  -> abt '[] a
  -> abt '[a] b
  -> Unroll (abt '[] b)
unrollSummate disc semi lo hi body =
  letM' lo $ \loVar ->
    letM' hi $ \hiVar ->
      let preamble = mklet loVar body  -- body instantiated at index = lo
          loop     = mksummate disc semi (loVar + one) hiVar body
      -- Test @lo >= hi@ rather than @lo == hi@: with the equality test, an
      -- inverted range (lo > hi) would wrongly produce body[lo] instead of
      -- an empty sum. This assumes, as the equality test itself implies,
      -- that the range is half-open and that an empty summation is 'zero'.
      in return $ if_ (loVar >= hiVar) zero (preamble + loop)

-- | Peel the first iteration off a product over the index range @[lo, hi)@:
--
-- > product lo hi body
-- >   ==>  if lo >= hi then one else body[lo] * product (lo + one) hi body
--
-- @lo@ and @hi@ go through 'letM'' first. Unless they are already variables
-- or literals, they are bound via 'letM' rather than being copied into the
-- guard, the preamble and the residual loop.
--
-- NOTE(review): @body@ appears twice in the result (under 'mklet' and inside
-- the residual 'Product'), so its binder is duplicated. Confirm that later
-- passes do not require binders to be globally unique.
unrollProduct
  :: (ABT Term abt, HSemiring_ a, HSemiring_ b, HOrd_ a)
  => HDiscrete a
  -> HSemiring b
  -> abt '[] a
  -> abt '[] a
  -> abt '[a] b
  -> Unroll (abt '[] b)
unrollProduct disc semi lo hi body =
  letM' lo $ \loVar ->
    letM' hi $ \hiVar ->
      let preamble = mklet loVar body  -- body instantiated at index = lo
          loop     = mkproduct disc semi (loVar + one) hiVar body
      -- Test @lo >= hi@ rather than @lo == hi@: with the equality test, an
      -- inverted range (lo > hi) would wrongly produce body[lo] instead of
      -- an empty product. This assumes, as the equality test itself implies,
      -- that the range is half-open and that an empty product is 'one'.
      in return $ if_ (loVar >= hiVar) one (preamble * loop)