{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.CumulativeSum
-- Description : processes queries like range sum query using cumulative sums. / 累積和を用いて range sum query のようなクエリを処理します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.CumulativeSum
  ( run,

    -- * internal rules
    rule,
  )
where

import Data.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.ArithmeticalExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | Builds @let b = scanl (fun x1 x2 -> max2 x1 x2) a0' a in b[n]@ using
-- fresh variable names, so that @b[n]@ folds @max2@ over the initial value
-- and the first @n@ elements of @a@.
--
-- Despite its name, 'rule' uses this for both 'Max2'' and 'Min2''.
--
-- Arguments:
--
-- * @max2@ builds the binary combining expression from two operands.
-- * @t@ is the element type of @a@.
-- * @a0@ is an optional initial value. When 'Nothing', @a[0]@ is used
--   instead. That is only sound because max/min are idempotent
--   (@max2 a[0] a[0] = a[0]@).
-- * @a@ is the list being scanned.
-- * @n@ is the number of leading elements to fold.
--
-- NOTE(review): with no initial value and @n = 0@ the result reads @a[0]@;
-- presumably the max/min of an empty list is already undefined upstream.
cumulativeMax :: MonadAlpha m => (Expr -> Expr -> Expr) -> Type -> Maybe Expr -> Expr -> Expr -> m Expr
cumulativeMax max2 t a0 a n = do
  -- The order of these three calls determines the generated names; keep it.
  table <- genVarName'
  lhs <- genVarName'
  rhs <- genVarName'
  let combine = Lam2 lhs t rhs t (max2 (Var lhs) (Var rhs))
      seed = fromMaybe (At' t a (LitInt' 0)) a0
      prefixTable = Scanl' t t combine seed a
      query = At' t (Var table) n
  pure (Let table (ListTy t) prefixTable query)

-- | Views an expression as @map (fun x -> a[x]) (range n)@ where the lambda
-- body indexes @a@ by exactly the bound variable @x@ and @x@ does not occur
-- free in @a@. Returns @(a, n)@ on success. Shared by the max/min cases of
-- 'rule'.
viewMapAtRange :: Expr -> Maybe (Expr, Expr)
viewMapAtRange = \case
  Map' _ _ (Lam x _ (At' _ a (Var x'))) (Range1' n)
    | x' == x && x `isUnusedVar` a -> Just (a, n)
  _ -> Nothing

-- | Rewrites range queries over a list into lookups on a @scanl@ table.
--
-- * @sum (map (fun x -> a[x + k]) (range n))@ becomes
--   @let b = scanl (+) 0 a in b[n + k] - b[k]@, or just @b[n]@ when @k@ is
--   zero. The index must be affine in @x@ with coefficient exactly 1, and
--   @x@ must not occur free in @a@.
-- * @max (map (fun x -> a[x]) (range n))@ and
--   @max (cons a0 (map (fun x -> a[x]) (range n)))@, and the same two shapes
--   for @min@, become a lookup into a prefix max/min table built by
--   'cumulativeMax'.
--
-- Any other expression is left unchanged ('Nothing').
rule :: MonadAlpha m => RewriteRule m
rule = RewriteRule $ \_ -> \case
  Sum' (Map' _ _ (Lam x _ (At' _ a index)) (Range1' n))
    | x `isUnusedVar` a ->
      case makeAffineFunctionFromArithmeticalExpr x (parseArithmeticalExpr index) of
        Just (coeff, shift) | isOneArithmeticalExpr coeff -> do
          b <- genVarName'
          let lookupB i = At' IntTy (Var b) i
              k = formatArithmeticalExpr shift
              -- b[i] = a[0] + ... + a[i - 1], so
              -- sum_{x=0}^{n-1} a[x + k] = b[n + k] - b[k].
              e =
                if isZeroArithmeticalExpr shift
                  then lookupB n
                  else Minus' (lookupB (Plus' n k)) (lookupB k)
          return . Just $
            Let b (ListTy IntTy) (Scanl' IntTy IntTy (Lit (LitBuiltin Plus)) Lit0 a) e
        _ -> return Nothing
  -- A failed view on the Cons' form falls through to the bare-map form,
  -- which cannot match a Cons', so it ends at the final wildcard.
  Max1' t (Cons' _ a0 xs)
    | Just (a, n) <- viewMapAtRange xs ->
      Just <$> cumulativeMax (Max2' t) t (Just a0) a n
  Max1' t xs
    | Just (a, n) <- viewMapAtRange xs ->
      Just <$> cumulativeMax (Max2' t) t Nothing a n
  Min1' t (Cons' _ a0 xs)
    | Just (a, n) <- viewMapAtRange xs ->
      Just <$> cumulativeMax (Min2' t) t (Just a0) a n
  Min1' t xs
    | Just (a, n) <- viewMapAtRange xs ->
      Just <$> cumulativeMax (Min2' t) t Nothing a n
  _ -> return Nothing

-- | Applies 'rule' across the whole program via 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` introduces cumulative sums, and prefix maxima/minima, for range
-- queries over a list.
--
-- The input program must be well-typed, and so is the output (both are
-- checked). Alpha conversion is run first, so the rewrite works on a
-- program with unique bound names.
--
-- == Examples
--
-- Before:
--
-- > sum (map (fun i -> a[i]) (range n))
--
-- After:
--
-- > let b = scanl (+) 0 a in b[n]
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.CumulativeSum" $ do
  precondition (ensureWellTyped prog)
  rewritten <- runProgram =<< Alpha.run prog
  postcondition (ensureWellTyped rewritten)
  return rewritten