{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.TrivialLetElimination
-- Description : removes let-exprs whose variables are referenced at most once. / その変数が高々 1 回しか参照されないような let 式を消去します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.TrivialLetElimination
  ( run,
    run',
  )
where

import Data.Functor
import qualified Data.Map as M
import Data.Maybe (fromMaybe)
import Jikka.Common.Error
import Jikka.Core.Language.Expr
import Jikka.Core.Language.Lint

-- | Merges the reference summaries of two independent subexpressions.
--
-- A summary is 'Nothing' when the variable does not occur, @'Just' 'True'@
-- when it occurs exactly once at a position where it may be inlined, and
-- @'Just' 'False'@ when it occurs but must not be inlined. If both sides
-- mention the variable, it is referenced at least twice, which always
-- disqualifies inlining.
plus :: Maybe Bool -> Maybe Bool -> Maybe Bool
plus lhs rhs = case (lhs, rhs) of
  (Just _, Just _) -> Just False
  (Just _, Nothing) -> lhs
  (Nothing, Just _) -> rhs
  (Nothing, Nothing) -> Nothing

-- | Summarises how the variable @x@ is referenced inside an expression.
--
-- Returns 'Nothing' if @x@ does not occur free, @'Just' 'True'@ if it occurs
-- exactly once and that occurrence may be replaced by its definition, and
-- @'Just' 'False'@ otherwise (several occurrences, or an occurrence under a
-- lambda). Binders named @x@ hide the variable in their scope.
isEliminatable :: VarName -> Expr -> Maybe Bool
isEliminatable x = go
  where
    go expr = case expr of
      Var y
        | y == x -> Just True
        | otherwise -> Nothing
      Lit _ -> Nothing
      App f e -> go f `plus` go e
      -- Inlining into a lambda body may re-evaluate the definition on every
      -- call, which can increase the time complexity; so any use there is
      -- downgraded to "must keep".
      Lam y _ body
        | y == x -> Nothing
        | otherwise -> go body $> False
      Let y _ e1 e2 ->
        let inBody = if y == x then Nothing else go e2
         in go e1 `plus` inBody

-- | Summarises how the variable @x@ is referenced inside a toplevel
-- expression, using the same encoding as 'isEliminatable':
-- 'Nothing' (not referenced), @'Just' 'True'@ (referenced once and
-- inlinable) or @'Just' 'False'@ (the binding must be kept).
isEliminatableToplevelExpr :: VarName -> ToplevelExpr -> Maybe Bool
isEliminatableToplevelExpr x = \case
  ResultExpr e -> isEliminatable x e
  ToplevelLet y _ e cont ->
    let inCont = if x == y then Nothing else isEliminatableToplevelExpr x cont
     in isEliminatable x e `plus` inCont
  ToplevelLetRec f args _ body cont
    -- The recursive function name hides @x@ in both its body and the continuation.
    | x == f -> Nothing
    | otherwise ->
      let inCont = isEliminatableToplevelExpr x cont
          -- A function body behaves like a lambda: it may run many times, so
          -- moving a definition into it may increase the time complexity.
          -- Any use there is therefore downgraded to "must keep".
          inBody =
            if x `elem` map fst args
              then Nothing
              else isEliminatable x body $> False
       in inCont `plus` inBody

-- | Inlines let-bindings whose variable is referenced at most once.
--
-- @env@ maps each eliminated variable to its (already rewritten) definition,
-- and occurrences of such a variable are replaced by that definition. A
-- binder that reuses a name removes the name from @env@, so a shadowing
-- binding is never rewritten. Capture of the free variables of inlined
-- definitions is not checked; the input is assumed to be alpha-converted.
runExpr :: M.Map VarName Expr -> Expr -> Expr
runExpr env = \case
  Var x -> fromMaybe (Var x) (M.lookup x env)
  Lit lit -> Lit lit
  App f e -> App (runExpr env f) (runExpr env e)
  Lam x t body -> Lam x t (runExpr (M.delete x env) body)
  Let x t e1 e2 ->
    let e1' = runExpr env e1
     in -- Drop the binding when x is unused or used exactly once outside any
        -- lambda; otherwise keep it, and make x refer to it inside e2.
        if isEliminatable x e2 /= Just False
          then runExpr (M.insert x e1' env) e2
          else Let x t e1' (runExpr (M.delete x env) e2)

-- | Toplevel counterpart of 'runExpr': inlines toplevel let-bindings whose
-- variable is referenced at most once, and rewrites every nested expression
-- with the current substitution @env@.
--
-- Names rebound by a kept 'ToplevelLet', or by a 'ToplevelLetRec' (the
-- function name and its arguments), are removed from @env@ in their scope,
-- so shadowing bindings are never rewritten. The input is assumed to be
-- alpha-converted (no capture check is done).
runToplevelExpr :: M.Map VarName Expr -> ToplevelExpr -> ToplevelExpr
runToplevelExpr env = \case
  ResultExpr e -> ResultExpr (runExpr env e)
  ToplevelLet x t e cont ->
    let e' = runExpr env e
     in if isEliminatableToplevelExpr x cont /= Just False
          then runToplevelExpr (M.insert x e' env) cont
          else ToplevelLet x t e' (runToplevelExpr (M.delete x env) cont)
  ToplevelLetRec f args ret body cont ->
    let -- f is in scope in both the body (recursion) and the continuation.
        envCont = M.delete f env
        -- The arguments are additionally in scope in the body.
        envBody = foldr (M.delete . fst) envCont args
     in ToplevelLetRec f args ret (runExpr envBody body) (runToplevelExpr envCont cont)

-- | Pure core of this pass: applies 'runToplevelExpr' to the whole program,
-- starting with no pending substitutions and without any type checking.
run' :: Program -> Program
run' prog = runToplevelExpr initialEnv prog
  where
    initialEnv = M.empty

-- | `run` removes let-exprs whose assigned variables are used at most once.
-- A variable whose only use is inside a lambda is not inlined, because
-- moving its definition into the lambda may increase the time complexity.
-- This assumes that the program is alpha-converted.
--
-- For example, this converts the following:
--
-- > let f = fun y -> y
-- > in let x = 1
-- > in f(x + x)
--
-- to:
--
-- > let x = 1
-- > in (fun y -> y) (x + x)
--
-- NOTE: this doesn't do constant folding.
run :: MonadError Error m => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.TrivialLetElimination" $ do
  precondition $ do
    ensureWellTyped prog
  let prog' = run' prog
  postcondition $ do
    ensureWellTyped prog'
  return prog'