{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.KubaruToMorau
-- Description : converts Kubaru DP to Morau DP. / 配る DP を貰う DP に変換します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.KubaruToMorau
  ( run,

    -- * internal rules
    rule,
    runFunctionBody,
  )
where

import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | @runFunctionBody c i j step y x k@ returns @step'(y, x, i, k)@ s.t. @step(c, i, j) = step'(c[i + j + 1], c[i], i, i + j + 1)@.
--
-- Arguments:
--
-- * @c@ is the list updated by the inner loop.
-- * @i@ and @j@ are the outer and inner loop counters.
-- * @step@ is the value written to @c[i + j + 1]@.
-- * @y@, @x@ and @k@ are the names that @step'@ uses for @c[i + j + 1]@, @c[i]@ and @i + j + 1@, respectively.
--
-- The rewrite first replaces @j@ with @k - i - 1@. It then turns each @c[e]@ into @x@ when @e@ is
-- arithmetically equal to @i@, and into @y@ when @e@ equals @k@. Any other occurrence of @c@ makes
-- the whole result 'Nothing'. Re-binding @c@, @i@ or @j@ inside @step@ is reported with
-- 'throwRuntimeError'.
runFunctionBody :: (MonadAlpha m, MonadError Error m) => VarName -> VarName -> VarName -> Expr -> VarName -> VarName -> VarName -> MaybeT m Expr
runFunctionBody c i j step y x k = do
  -- the written index is k = i + j + 1, so j can be expressed as k - i - 1
  stepK <- lift $ substitute j (Minus' (Minus' (Var k) (Var i)) (LitInt' 1)) step
  rewrite stepK
  where
    -- names whose re-binding inside the body would invalidate the index analysis
    isReserved v = v `elem` [c, i, j]

    -- compares an index expression with a variable after arithmetic normalization
    e `sameIndexAs` v = parseArithmeticExpr e == parseArithmeticExpr (Var v)

    rewrite expr = case expr of
      Var v
        | v == c -> hoistMaybe Nothing -- c is used other than as c[i] / c[k]
        | otherwise -> pure (Var v)
      Lit lit -> pure (Lit lit)
      At' _ (Var c') index
        | c' == c, index `sameIndexAs` i -> pure (Var x)
        | c' == c, index `sameIndexAs` k -> pure (Var y)
        | c' == c -> hoistMaybe Nothing
      App f arg -> do
        f' <- rewrite f
        arg' <- rewrite arg
        pure (App f' arg')
      Let v t bound body
        | isReserved v -> throwRuntimeError "name confliction found"
        | otherwise -> do
          bound' <- rewrite bound
          body' <- rewrite body
          pure (Let v t bound' body')
      Lam v t body
        | isReserved v -> throwRuntimeError "name confliction found"
        | otherwise -> Lam v t <$> rewrite body
      Assert cond body -> do
        cond' <- rewrite cond
        body' <- rewrite body
        pure (Assert cond' body')

-- | Rewrites a Kubaru (push-style) DP written as two nested @foldl@ loops into a
-- Morau (pull-style) DP written with @build@.
--
-- The recognized input shape is
--
-- > foldl (fun b i -> foldl (fun c j -> setAt c (i + j + 1) step(c, i, j)) b (range (n - i - 1))) a (range n)
--
-- where @step@ may read @c@ only as @c[i]@ and @c[i + j + 1]@ (checked by 'runFunctionBody').
-- It is rewritten to
--
-- > build (fun c -> foldl (fun y i -> step'(y, c[i], i, len c)) a[len c] (range (len c))) [] n
--
-- TODO: remove the assumption that the length of @a@ equals @n@.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = makeRewriteRule "Jikka.Core.Convert.KubaruToMorau" $ \_ -> \case
  -- foldl (fun b i -> foldl (fun c j -> setAt c index(i, j) step(c, i, j)) b (range m(i))) a (range n)
  Foldl' IntTy (ListTy t2) (Lam2 b _ i _ (Foldl' IntTy (ListTy t2') (Lam2 c _ j _ (SetAt' _ (Var c') index step)) (Var b') (Range1' m))) a (Range1' n)
    | t2' == t2
        && b' == b
        && c == c'
        && b `isUnusedVar` m
        && b `isUnusedVar` index
        && b `isUnusedVar` step
        && c `isUnusedVar` index
        -- the result places `a` under a binder named `c`; a free `c` in `a` would be captured
        && c `isUnusedVar` a
        -- `n` is compared with `m`, which lives under the binder `i`; a free `i` in `n` would make that comparison unsound
        && i `isUnusedVar` n ->
      runMaybeT $ do
        -- m(i) = n - i - 1
        guard $ parseArithmeticExpr m == parseArithmeticExpr (Minus' (Minus' n (Var i)) (LitInt' 1))
        -- index(i, j) = i + j + 1
        guard $ parseArithmeticExpr index == parseArithmeticExpr (Plus' (Var i) (Plus' (Var j) (LitInt' 1)))
        x <- lift genVarName'
        y <- lift genVarName'
        k <- lift genVarName'
        -- get step'(y, x, i, k) s.t. step(c, i, j) = step'(c[i + j + 1], c[i], i, i + j + 1)
        stepYXIK <- runFunctionBody c i j step y x k
        -- in the pull form, the source value c[i] is read from the list being built ...
        stepYIK <- lift $ substitute x (At' t2 (Var c) (Var i)) stepYXIK
        -- ... and the target index k is the number of elements built so far
        pullStep <- lift $ substitute k (Len' t2 (Var c)) stepYIK
        let base = At' t2 a (Len' t2 (Var c))
        return $ Build' t2 (Lam c (ListTy t2) (Foldl' IntTy t2 (Lam2 y t2 i IntTy pullStep) base (Range1' (Len' t2 (Var c))))) (Nil' t2) n
  _ -> return Nothing

-- | Applies the single rewrite rule 'rule' of this module to a whole program,
-- using 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` converts Kubaru DP
-- (for each \(i\), updates \(
--     \mathrm{dp}(j) \gets f(\mathrm{dp}(j), \mathrm{dp}(i))
-- \) for each \(j \gt i\))
-- to Morau DP
-- (for each \(i\), computes \(
--     \mathrm{dp}(i) = F(\lbrace \mathrm{dp}(j) \mid j \lt i \rbrace)
-- \)).
--
-- The initial list is assumed to have exactly @n@ elements.
--
-- == Examples
--
-- Before:
--
-- > foldl (fun dp i ->
-- >     foldl (fun dp2 j ->
-- >         setAt dp2 (i + j + 1) (
-- >             f dp2[i + j + 1] dp2[i])
-- >         ) dp (range (n - i - 1))
-- >     ) dp (range n)
--
-- After:
--
-- > build (fun dp' ->
-- >     foldl (fun dp_k j ->
-- >         f dp_k dp'[j]
-- >         ) dp[length dp'] (range (length dp'))
-- >     ) [] n
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.KubaruToMorau" $ do
  precondition $ do
    lint prog
  prog' <- runProgram prog
  postcondition $ do
    lint prog'
  return prog'