{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module      : Jikka.Core.Convert.SpecializeFoldl
-- Description : specializes @foldl@ with concrete functions like @sum@ and @product@. / @sum@ や @product@ のような具体的な関数で @foldl@ を特殊化します。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.SpecializeFoldl
  ( run,
  )
where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.ModuloExpr
import Jikka.Core.Language.RewriteRules

-- | Rewrites an integer-valued @foldl@ whose step function only adds a term
-- to the accumulator into @init + sum (map (fun x1 -> term) xs)@.
--
-- Only folds of the shape @foldl (fun x2 x1 -> body) init xs@ with an
-- 'IntTy' accumulator are considered. Returns 'Nothing' when the expression
-- is not such a fold, when the body cannot be split by
-- 'makeAffineFunctionFromArithmeticExpr', or when the first component of the
-- split is not one according to 'isOneArithmeticExpr'.
convertToSum :: Expr -> Maybe Expr
convertToSum = \case
  Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
    -- Split the body with respect to the accumulator @x2@.
    -- NOTE(review): presumably this means @body = a * x2 + b@ with @x2@ free
    -- in neither @a@ nor @b@; confirm against "Jikka.Core.Language.ArithmeticExpr".
    (a, b) <- makeAffineFunctionFromArithmeticExpr x2 (parseArithmeticExpr body)
    -- The accumulator must be carried over unscaled, otherwise this is not a plain sum.
    guard $ isOneArithmeticExpr a
    return $ Plus' init (Sum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs))
  _ -> Nothing

-- | Modular counterpart of 'convertToSum': rewrites an integer-valued
-- @foldl@ whose body parses as a 'ModuloExpr' and only adds a term to the
-- accumulator into a 'ModPlus'' of @init@ and a 'ModSum'' over the mapped list.
--
-- Returns 'Nothing' when the expression is not a fold of the shape
-- @foldl (fun x2 x1 -> body) init xs@ with an 'IntTy' accumulator, when
-- 'parseModuloExpr' rejects the body, when the body cannot be split by
-- 'makeAffineFunctionFromArithmeticExpr', or when the first component of the
-- split is not one according to 'isOneArithmeticExpr'.
convertToModSum :: Expr -> Maybe Expr
convertToModSum = \case
  Foldl' t1 IntTy (Lam2 x2 _ x1 _ body) init xs -> do
    modBody <- parseModuloExpr body
    -- Split the body with respect to the accumulator @x2@.
    -- NOTE(review): presumably this means @body = a * x2 + b@; confirm
    -- against "Jikka.Core.Language.ArithmeticExpr".
    (a, b) <- makeAffineFunctionFromArithmeticExpr x2 (arithmeticExprFromModuloExpr modBody)
    guard $ isOneArithmeticExpr a

    let m :: Expr
        m = moduloOfModuloExpr modBody

    -- `if` is required for cases like `foldl (fun y x -> y % 2) 3 xs`, which is the same as `if xs == nil then 3 else 1`.
    -- When `init` is already known to be reduced by `m`, the plain rewrite is
    -- used as is. Otherwise `init` is returned untouched for the empty list and
    -- the rewritten expression becomes the else-branch.
    let wrap :: Expr -> Expr
        wrap =
          if init `isModulo` Modulo m
            then id
            else If' IntTy (Equal' (ListTy t1) xs (Nil' t1)) init

    return . wrap $
      ModPlus' init (ModSum' (Map' t1 IntTy (Lam x1 t1 (formatArithmeticExpr b)) xs) m) m
  _ -> Nothing

-- | The rewrite rule of this module.
--
-- It first tries 'convertToSum' and 'convertToModSum'. Otherwise, for a fold
-- @foldl (fun x2 x1 -> body) init xs@ whose body combines the accumulator
-- @x2@ with an expression @e@ that does not mention @x2@ (in either operand
-- order), the fold is replaced by the corresponding list-level builtin
-- applied to @cons init (map (fun x1 -> e) xs)@:
--
-- * 'Mult'' becomes 'Product''
-- * 'And'' becomes 'All''
-- * 'Or'' becomes 'Any''
-- * 'Max2'' becomes 'Max1''
-- * 'Min2'' becomes 'Min1''
-- * 'Lcm'' becomes 'Lcm1''
-- * 'Gcd'' becomes 'Gcd1''
-- * 'ModMult'' under an outer 'FloorMod'' with the same modulus becomes 'ModProduct''
--
-- Every other expression is left alone (the function returns 'Nothing').
rule :: MonadAlpha m => RewriteRule m
rule = simpleRewriteRule "Jikka.Core.Convert.SpecializeFoldl" $ \case
  (convertToSum -> Just e) -> return e
  (convertToModSum -> Just e) -> return e
  -- TODO: Replace these operators with the better implementation like sum.
  Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs ->
    -- The list handed to the specialized builtin: the initial value followed
    -- by the per-element operands.
    let items e = Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)
     in case body of
          -- Product
          Mult' l r | Just e <- operandBeside x2 l r -> Just $ Product' (items e)
          -- All
          And' l r | Just e <- operandBeside x2 l r -> Just $ All' (items e)
          -- Any
          Or' l r | Just e <- operandBeside x2 l r -> Just $ Any' (items e)
          -- Max1
          Max2' _ l r | Just e <- operandBeside x2 l r -> Just $ Max1' t2 (items e)
          -- Min1
          Min2' _ l r | Just e <- operandBeside x2 l r -> Just $ Min1' t2 (items e)
          -- Lcm1
          Lcm' l r | Just e <- operandBeside x2 l r -> Just $ Lcm1' t2 (items e)
          -- Gcd1
          Gcd' l r | Just e <- operandBeside x2 l r -> Just $ Gcd1' t2 (items e)
          -- others
          _ -> Nothing
  -- The outer floor-mod is required because foldl for empty lists returns values without modulo.
  FloorMod' (Foldl' t1 t2 (Lam2 x2 _ x1 _ body) init xs) m ->
    let items e = Cons' t2 init (Map' t1 t2 (Lam x1 t1 e) xs)
     in case body of
          -- ModProduct
          -- The inner modulus must be the same expression as the outer one.
          ModMult' l r m' | Just e <- operandBeside x2 l r, m' == m -> Just $ ModProduct' (items e) m
          -- others
          _ -> Nothing
  -- others
  _ -> Nothing
  where
    -- Given the two operands of a binary operation, returns the operand that
    -- is combined with the accumulator @acc@: one side must be exactly
    -- @Var acc@ and the other side must not mention @acc@ at all. The two
    -- alternatives cannot both hold, because @Var acc@ itself mentions @acc@.
    operandBeside :: VarName -> Expr -> Expr -> Maybe Expr
    operandBeside acc l r = case (l, r) of
      (Var v, e) | v == acc && acc `isUnusedVar` e -> Just e
      (e, Var v) | v == acc && acc `isUnusedVar` e -> Just e
      _ -> Nothing

-- | Applies 'rule' to a whole program via 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | `run` reduces summations and products.
--
-- The program is checked with 'lint' both before and after the rewrite, and
-- errors are wrapped with the name of this module.
--
-- == Example
--
-- Before:
--
-- > foldl (fun x y -> x + y) 0 xs
--
-- After:
--
-- > sum xs
--
-- Note that this pass itself emits the unsimplified form
-- @0 + sum (map (fun y -> y) xs)@; reaching @sum xs@ presumably relies on
-- other simplification passes.
--
-- == List of builtin functions which are reduced
--
-- === Source functions
--
-- * `Foldl` \(: \forall \alpha \beta. (\beta \to \alpha \to \beta) \to \beta \to \list(\alpha) \to \beta\)
--
-- === Destination functions
--
-- * `Sum` \(: \list(\int) \to \int\)
-- * `Product` \(: \list(\int) \to \int\)
-- * `ModSum` \(: \list(\int) \to \int \to \int\)
-- * `ModProduct` \(: \list(\int) \to \int \to \int\)
-- * `All` \(: \list(\bool) \to \bool\)
-- * `Any` \(: \list(\bool) \to \bool\)
-- * `Max1` \(: \forall \alpha. \list(\alpha) \to \alpha\)
-- * `Min1` \(: \forall \alpha. \list(\alpha) \to \alpha\)
-- * `Lcm1`
-- * `Gcd1`
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.SpecializeFoldl" $ do
  precondition $ do
    lint prog
  prog <- runProgram prog
  postcondition $ do
    lint prog
  return prog