{-# LANGUAGE CPP
           , FlexibleInstances
           , GADTs
           , DataKinds
           , TypeOperators
           , KindSignatures
           , LambdaCase
           , ViewPatterns
           , DeriveDataTypeable
           , StandaloneDeriving
           , OverlappingInstances
           , UndecidableInstances
           , RankNTypes
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
-- |
-- Module      :  Language.Hakaru.Syntax.Transform
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- The internal syntax of Hakaru transformations, which are functions on Hakaru
-- terms which are neither primitive, nor expressible in terms of Hakaru
-- primitives.
----------------------------------------------------------------
module Language.Hakaru.Syntax.Transform
  (
  -- * Transformation internal syntax
    TransformImpl(..)
  , Transform(..)
  -- * Some utilities
  , transformName, allTransforms
  -- * Mapping of input type to output type for transforms
  , typeOfTransform
  -- * Transformation contexts
  , TransformCtx(..), HasTransformCtx(..), unionCtx, minimalCtx
  -- * Transformation tables
  , TransformTable(..), lookupTransform', simpleTable
  , unionTable, someTransformations
  )
  where


import Control.Applicative (Alternative(..), Applicative(..))
import Data.Char (isSpace)
import Data.Data (Data, Typeable)
import Data.List (stripPrefix)
import Data.Monoid (Monoid(..))

#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup
#endif

import Data.Number.Nat
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.SArgs
import Language.Hakaru.Syntax.IClasses
import Language.Hakaru.Syntax.Variable
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing

----------------------------------------------------------------

-- | Some transformations have the same type and /same/ semantics, but are
--   implemented in multiple different ways. Such transformations are
--   distinguished in concrete syntax by differing keywords.
data TransformImpl
  = InMaple    -- ^ The Maple implementation (keyword suffix @_m@).
  | InHaskell  -- ^ The Haskell implementation (no keyword suffix).
  deriving (Eq, Ord, Show, Read, Data, Typeable)

-- | Transformations and their types. Like 'Language.Hakaru.Syntax.AST.SCon'.
--
--   The first index lists the argument types, each a pair of the types of the
--   variables bound in that argument and the argument's own type; the second
--   index is the result type.
data Transform :: [([Hakaru], Hakaru)] -> Hakaru -> * where
  -- | A measure over @a@ and a body binding one @a@ that yields an 'HProb';
  --   the result is an 'HProb'.
  Expect  ::
    Transform
      '[ LC ('HMeasure a), '( '[ a ], 'HProb) ] 'HProb

  -- | A measure over @a@ and a value of type @a@; the result is a measure
  --   over @a@.
  Observe ::
    Transform
      '[ LC ('HMeasure a), LC a ] ('HMeasure a)

  -- | A function @a -> measure a@ and a measure over @a@; the result is a
  --   function from @a@ to a measure over pairs of @a@ and 'HProb'.
  MH ::
    Transform
      '[ LC (a ':-> 'HMeasure a), LC ('HMeasure a) ]
      (a ':-> 'HMeasure (HPair a 'HProb))

  -- | A function @a -> measure a@ and a measure over @a@; the result has the
  --   same type as the first argument.
  MCMC ::
    Transform
      '[ LC (a ':-> 'HMeasure a), LC ('HMeasure a) ]
      (a ':-> 'HMeasure a)

  -- | A measure over pairs @(a, b)@; the result is a function from @a@ to a
  --   measure over @b@. The 'TransformImpl' selects the implementation.
  Disint :: TransformImpl ->
    Transform
      '[ LC ('HMeasure (HPair a b)) ]
      (a ':-> 'HMeasure b)

  -- | A single argument; the result has the same type.
  Summarize ::
    Transform '[ LC a ] a

  -- | A single argument; the result has the same type.
  Simplify ::
    Transform '[ LC a ] a

  -- | A single argument; the result has the same type.
  Reparam ::
    Transform '[ LC a ] a

deriving instance Eq   (Transform args a)
deriving instance Show (Transform args a)

-- | Two packed transformations are equal when they are built with the same
--   constructor; for 'Disint' the 'TransformImpl' must also agree. The
--   (hidden) type indices are not compared.
instance Eq (Some2 Transform) where
  Some2 t0 == Some2 t1 =
    case (t0, t1) of
      (Expect   , Expect   ) -> True
      (Observe  , Observe  ) -> True
      (MH       , MH       ) -> True
      (MCMC     , MCMC     ) -> True
      (Disint k0, Disint k1) -> k0 == k1
      (Summarize, Summarize) -> True
      (Simplify , Simplify ) -> True
      (Reparam  , Reparam  ) -> True
      _                      -> False

-- | Parses the 'show'n form of any transformation in 'allTransforms'.
--
--   Leading whitespace is skipped and optional enclosing parentheses are
--   accepted, as with derived 'Read' instances, so @read (show x)@ also works
--   for values shown at application precedence (e.g. @(Disint InMaple)@).
instance Read (Some2 Transform) where
  readsPrec _ = readParen False readBare
    where
      -- Each transformation paired with its 'show'n name.
      named :: [(String, Some2 Transform)]
      named = map (\t'@(Some2 t) -> (show t, t')) allTransforms

      -- Match a name as a prefix of the input, after skipping whitespace.
      readBare :: ReadS (Some2 Transform)
      readBare s =
        let s0 = dropWhile isSpace s
            readMay (name, t)
              | Just rest <- stripPrefix name s0 = [(t, rest)]
              | otherwise                        = []
        in concatMap readMay named

-- | The concrete syntax names of transformations. The two 'Disint'
--   implementations share the keyword @disint@ and are told apart by a
--   suffix: none for 'InHaskell', @_m@ for 'InMaple'.
transformName :: Transform args a -> String
transformName =
  \case
    Expect    -> "expect"
    Observe   -> "observe"
    MH        -> "mh"
    MCMC      -> "mcmc"
    Disint k  -> "disint" ++ implSuffix k
    Summarize -> "summarize"
    Simplify  -> "simplify"
    Reparam   -> "reparam"
  where
    implSuffix :: TransformImpl -> String
    implSuffix InHaskell = ""
    implSuffix InMaple   = "_m"

-- | All transformations, including both implementations of 'Disint'.
--   The order here is the order in which the 'Read' instance tries them.
allTransforms :: [Some2 Transform]
allTransforms =
  [ Some2 Expect
  , Some2 Observe
  , Some2 MH
  , Some2 MCMC
  , Some2 (Disint InHaskell)
  , Some2 (Disint InMaple)
  , Some2 Summarize
  , Some2 Simplify
  , Some2 Reparam
  ]

-- | Compute the result type of a transformation from the singleton types of
--   its arguments.
--
--   'Expect' always yields 'SProb'. 'Observe', 'MCMC', 'Summarize',
--   'Simplify' and 'Reparam' return the type of their first argument
--   unchanged. 'MH' and 'Disint' rebuild a function type from the components
--   of their first argument.
typeOfTransform
    :: Transform as x
    -> SArgsSing as
    -> Sing x
typeOfTransform t sargs =
  case (t, sargs) of
    (Expect   , _)
      -> SProb
    (Observe  , Pw _ e :* _ :* End)
      -> e
    -- The first argument is a function @a -> measure a@; keep its domain.
    (MH       , Pw _ (fst . sUnFun -> a) :* _ :* End)
      -> SFun a (SMeasure (sPair a SProb))
    (MCMC     , Pw _ a :* _)
      -> a
    -- The argument is @measure (pair a b)@; split off @a@ and @b@.
    (Disint _ , Pw _ (sUnPair . sUnMeasure -> (a, b)) :* End)
      -> SFun a (SMeasure b)
    (Summarize, Pw _ e :* End)
      -> e
    (Simplify , Pw _ e :* End)
      -> e
    (Reparam  , Pw _ e :* End)
      -> e

-- | The context in which a transformation is called.  Currently this is simply
--   the next free variable in the enclosing program, but it could one day be
--   expanded to include more information, e.g., an association of variables to
--   terms in the enclosing program.
newtype TransformCtx = TransformCtx
  { nextFreeVar :: Nat  -- ^ The next free variable ID in the enclosing program.
  }
  deriving (Eq, Ord, Show)

-- | The smallest possible context, i.e. a default context suitable for use when
-- performing induction on terms which may contain transformations as subterms.
-- Its next free variable is 0.
minimalCtx :: TransformCtx
minimalCtx = TransformCtx { nextFreeVar = 0 }

-- | The union of two contexts: the resulting next free variable is the larger
--   of the two inputs' next free variables.
unionCtx :: TransformCtx -> TransformCtx -> TransformCtx
unionCtx ctx0 ctx1 =
  TransformCtx { nextFreeVar = max (nextFreeVar ctx0) (nextFreeVar ctx1) }

-- | Combining two contexts takes their union ('unionCtx').
instance Semigroup TransformCtx where
  (<>) = unionCtx

-- | The empty context is 'minimalCtx'.
instance Monoid TransformCtx where
  mempty  = minimalCtx
#if !(MIN_VERSION_base(4,11,0))
  -- Before base-4.11 'Semigroup' was not a superclass of 'Monoid', so
  -- 'mappend' has to be defined explicitly.
  mappend = (<>)
#endif

-- | Types whose values determine a 'TransformCtx'.
class HasTransformCtx t where
  -- | The context associated with a value.
  ctxOf :: t -> TransformCtx

-- | The context of a variable: its next free variable is one past the
--   variable's own ID.
instance HasTransformCtx (Variable (a :: Hakaru)) where
  ctxOf v = TransformCtx { nextFreeVar = varID v + 1 }

-- | The context of a term: its next free variable is the term's 'nextFree'.
instance ABT syn abt => HasTransformCtx (abt (xs :: [Hakaru]) (a :: Hakaru)) where
  ctxOf t = TransformCtx { nextFreeVar = nextFree t }

-- | A functional lookup table which indicates how to expand
--   transformations. The function returns @Nothing@ when the transformation
--   shouldn't be expanded. When it returns @Just k@, @k@ is passed a
--   @TransformCtx@ followed by the transformation's @SArgs@, and yields (in
--   @m@) an optional expanded term; an inner @Nothing@ presumably means the
--   expansion produced no result.
newtype TransformTable abt m
  =  TransformTable
  {  lookupTransform
  :: forall as b
  .  Transform as b
  -> Maybe (TransformCtx -> SArgs abt as -> m (Maybe (abt '[] b))) }

-- | A variant of @lookupTransform@ which joins the two layers of @Maybe@.
--   A transformation the table does not expand yields @pure Nothing@;
--   otherwise the table's expansion is applied to the context and arguments.
lookupTransform'
  :: (Applicative m)
  => TransformTable abt m
  -> Transform as b
  -> TransformCtx
  -> SArgs abt as -> m (Maybe (abt '[] b))
lookupTransform' tbl tr ctx args =
  case lookupTransform tbl tr of
    Just expand -> expand ctx args
    Nothing     -> pure Nothing

-- | Builds a /simple/ transformation table, i.e. one which doesn't make use of
--  the monadic context. Such a table is valid in every @Applicative@ context:
--  each expansion's result is lifted with 'pure'.
simpleTable
  :: (Applicative m)
  => (forall as b . Transform as b
                 -> Maybe (TransformCtx -> SArgs abt as -> Maybe (abt '[] b)))
  -> TransformTable abt m
simpleTable k =
  -- Pass the lambda directly (not via '$') so its rank-N type is checked
  -- without relying on the special typing rule for '($)'.
  TransformTable (\tr -> fmap (\f ctx args -> pure (f ctx args)) (k tr))

-- | Take the left-biased union of two transformation tables: a transformation
--   is looked up in the first table and, only if it is absent there, in the
--   second.
unionTable :: TransformTable abt m
           -> TransformTable abt m
           -> TransformTable abt m
unionTable tbl0 tbl1 = TransformTable (\tr ->
  lookupTransform tbl0 tr <|> lookupTransform tbl1 tr)

-- | Intersect a transformation table with a list of transformations: the
--   result expands only those transformations that occur in the list (as
--   judged by the 'Eq' instance on @Some2 Transform@) and that the table
--   itself expands.
someTransformations :: [Some2 Transform]
                    -> TransformTable abt m
                    -> TransformTable abt m
someTransformations toExpand tbl = TransformTable (\tr ->
  if Some2 tr `elem` toExpand
    then lookupTransform tbl tr
    else Nothing)