{-# LANGUAGE CPP                        #-}
{-# LANGUAGE DataKinds                  #-}
{-# LANGUAGE FlexibleContexts           #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE GADTs                      #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses      #-}
{-# LANGUAGE Rank2Types                 #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE ViewPatterns               #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.04.02
-- |
-- Module      :  Language.Hakaru.Evaluation.ConstantPropagation
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
--
----------------------------------------------------------------
module Language.Hakaru.Evaluation.ConstantPropagation
    ( constantPropagation
    ) where

#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative                  (Applicative (..))
import           Data.Functor                         ((<$>))
#endif

import           Control.Monad.Reader
import           Data.Monoid                          (All (..))
import           Language.Hakaru.Evaluation.EvalMonad (runPureEvaluate)
import           Language.Hakaru.Syntax.ABT           (ABT (..), View (..))
import           Language.Hakaru.Syntax.AST
import           Language.Hakaru.Syntax.IClasses      (Foldable21 (..),
                                                       Traversable21 (..))
import           Language.Hakaru.Syntax.Variable

-- | The propagation environment: associations from variables to the
-- 'Literal' values they are known to be bound to.
type Env = Assocs Literal

-- | The constant propagation monad. It is a thin wrapper around
-- @'Reader' 'Env'@: computations can read the environment mapping
-- variables to known constant values (via 'ask') and run a
-- sub-computation under an extended environment (via 'local'), but
-- cannot otherwise modify it.
--
-- All instances are obtained from the underlying 'Reader' through
-- @GeneralizedNewtypeDeriving@.
newtype PropM a = PropM { runPropM :: Reader Env a }
  deriving (Functor, Applicative, Monad, MonadReader Env)

----------------------------------------------------------------
----------------------------------------------------------------
-- TODO: try evaluating certain things even if not all their immediate
-- subterms are literals. For example:
-- (1) evaluate beta-redexes where the argument is a literal
-- (2) evaluate case-of-constructor if we can
-- (3) handle identity elements for NaryOps
-- (4) Recognize trivial cases for looping constructs:
--     summate a b (const 0) == 0
--     summate a b id        == (b - a) * (a + b - 1) / 2
--     summate a b (const x) == x * (b - a)
--
-- | Perform basic constant propagation.
--
-- Runs the 'PropM' traversal ('constantProp'') over a closed term,
-- starting from the empty environment ('emptyAssocs'), and returns the
-- rewritten term.
constantPropagation
  :: forall abt a . (ABT Term abt)
  => abt '[] a
  -> abt '[] a
constantPropagation abt = runReader (runPropM $ constantProp' abt) emptyAssocs

-- | Constant-propagate through a term that may have bound variables
-- in scope (@xs@), reading known variable values from the 'PropM'
-- environment.
constantProp'
  :: forall abt a xs . (ABT Term abt)
  => abt xs a
  -> PropM (abt xs a)
constantProp' = start
  where

    -- Expose the top-level view of the term and dispatch on it.
    start :: forall b ys . abt ys b -> PropM (abt ys b)
    start = loop . viewABT

    loop :: forall b ys . View (Term abt) ys b -> PropM (abt ys b)
    -- A variable: if the environment has a literal for it, substitute
    -- that literal; otherwise leave the variable as it is.
    loop (Var v)    = (maybe (var v) (syn . Literal_) . lookupAssoc v) <$> ask
    -- A syntactic term: handled by 'constantPropTerm'.
    loop (Syn s)    = constantPropTerm s
    -- A binder: process the body and rebind the same variable.
    loop (Bind v b) = bind v <$> loop b

-- | 'True' exactly when the top-level view of the term is a
-- 'Literal_' node; variables, binders, and every other kind of term
-- yield 'False'.
isLiteral :: forall abt b ys . (ABT Term abt) => abt ys b -> Bool
isLiteral abt =
  case viewABT abt of
    Syn (Literal_ _) -> True
    _                -> False

-- | A term is considered foldable when every subterm visited by
-- 'foldMap21' satisfies 'isLiteral'. The per-subterm results are
-- combined with the 'All' monoid (logical conjunction).
isFoldable :: forall abt b . (ABT Term abt) => Term abt b -> Bool
isFoldable = getAll . foldMap21 (All . isLiteral)

-- | Extract the literal from a term whose top-level view is a
-- 'Literal_' node; returns 'Nothing' for anything else.
getLiteral :: forall abt ys b. (ABT Term abt) => abt ys b -> Maybe (Literal b)
getLiteral e =
  case viewABT e of
    Syn (Literal_ l) -> Just l
    _                -> Nothing

-- | Wrap a term with 'syn', first passing it through
-- 'runPureEvaluate' when 'isFoldable' holds for it. Terms that are
-- not foldable are wrapped unchanged.
tryEval :: forall abt b . (ABT Term abt) => Term abt b -> abt '[] b
tryEval term
  | isFoldable term = runPureEvaluate (syn term)
  | otherwise       = syn term

-- | Constant-propagate through a single syntactic term.
--
-- * For a let binding, the right-hand side is processed first. If the
--   result is a literal, the bound variable is recorded against that
--   literal in the environment while the body is processed, and the
--   processed body alone is returned (the binding itself is dropped).
--   Otherwise the body is processed under the unchanged environment
--   and the let is rebuilt from the processed pieces.
--
-- * For every other term, each subterm is processed with
--   'constantProp'' and the rebuilt term is handed to 'tryEval'.
constantPropTerm
  :: (ABT Term abt)
  => Term abt a
  -> PropM (abt '[] a)
constantPropTerm (Let_ :$ rhs :* body :* End) =
  caseBind body $ \v body' -> do
    rhs' <- constantProp' rhs
    case getLiteral rhs' of
      -- Known constant: extend the environment only for the body.
      Just l  -> local (insertAssoc (Assoc v l)) (constantProp' body')
      -- Not a constant: keep the binding, with both parts rewritten.
      Nothing -> do
        body'' <- constantProp' body'
        return $ syn (Let_ :$ rhs' :* bind v body'' :* End)

constantPropTerm term = tryEval <$> traverse21 constantProp' term