{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.ConvexHullTrick
-- Description : uses convex hull trick. / convex hull trick を使います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.ConvexHullTrick
  ( run,

    -- * internal rules
    rule,
    parseLinearFunctionBody,
    parseLinearFunctionBody',
  )
where

import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticalExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | Lifts a pure 'Maybe' into 'MaybeT' by wrapping it with 'pure' of the underlying monad.
hoistMaybe :: Applicative m => Maybe a -> MaybeT m a
hoistMaybe x = MaybeT (pure x)

-- | @plusPair (a1, c1) (a2, c2)@ merges the two products @a1 * c1@ and @a2 * c2@
-- into a single product @a * c@ with the same sum.
--
-- A pair whose @a@ or @c@ is zero contributes nothing, so the other pair is returned.
-- Otherwise the integer factors of @c1@ and @c2@ are moved into @a1@ and @a2@.
-- The two pairs are merged only when the remaining @c@ parts are equal; if not, the result is 'Nothing'.
-- This is essentially commutative because only one kind of @c@ is allowed.
plusPair :: (ArithmeticalExpr, ArithmeticalExpr) -> (ArithmeticalExpr, ArithmeticalExpr) -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
plusPair lhs@(a1, c1) rhs@(a2, c2)
  | isZeroArithmeticalExpr a2 || isZeroArithmeticalExpr c2 = Just lhs
  | isZeroArithmeticalExpr a1 || isZeroArithmeticalExpr c1 = Just rhs
  | c1' == c2' = Just (plusArithmeticalExpr (scale k1 a1) (scale k2 a2), c1')
  | otherwise = Nothing
  where
    (k1, c1') = splitConstantFactorArithmeticalExpr c1
    (k2, c2') = splitConstantFactorArithmeticalExpr c2
    -- multiply by the integer constant factor that was split off a @c@
    scale k = multArithmeticalExpr (integerArithmeticalExpr k)

-- | Sums a list of products @a * c@ with 'plusPair'.
-- The accumulation runs from the right, starting from the zero term @(1, 0)@.
-- The result is 'Nothing' if any merge step fails.
sumPairs :: [(ArithmeticalExpr, ArithmeticalExpr)] -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
sumPairs [] = Just (integerArithmeticalExpr 1, integerArithmeticalExpr 0)
sumPairs (term : rest) = sumPairs rest >>= plusPair term

-- | `parseLinearFunctionBody'` parses the body of a linear function which can be decomposed to convex hull trick.
-- @parseLinearFunctionBody' f i j e@ decomposes @e@ as
-- @e = a(f[j], j) c(f[< i], i) + b(f[j], j) + d(f[< i], i)@
-- and returns the four parts in the order @(a, c, b, d)@ (note: not @(a, b, c, d)@).
-- It returns 'Nothing' when @e@ is not of this form.
--
-- An access to @f@ belongs to the @i@-side only when its index is @i + k@ for an integer @k < 0@.
-- It belongs to the @j@-side only when its index is exactly @j@.
-- Any integer constant factor of @a@ is moved into @c@.
--
-- TODO: What is the relation between @j@ and @k@?
parseLinearFunctionBody' :: VarName -> VarName -> VarName -> Expr -> Maybe (Expr, Expr, Expr, Expr)
parseLinearFunctionBody' f i j e = fmap finish (decompose e)
  where
    zero = integerArithmeticalExpr 0
    one = integerArithmeticalExpr 1

    -- Terms without a @j@-dependent part (including constants) go to @d@.
    -- Terms without an @i@-dependent part go to @b@.
    -- In both cases @(a, c) = (1, 0)@, so the product @a c@ vanishes.
    asD x = (one, zero, zero, parseArithmeticalExpr x)
    asB x = (one, zero, parseArithmeticalExpr x, zero)

    -- Pull the integer factor of @a@ into @c@, then convert everything back to 'Expr'.
    finish (a, c, b, d) =
      let (factor, a') = splitConstantFactorArithmeticalExpr a
          c' = multArithmeticalExpr (integerArithmeticalExpr factor) c
       in (formatArithmeticalExpr a', formatArithmeticalExpr c', formatArithmeticalExpr b, formatArithmeticalExpr d)

    decompose = \case
      Negate' e' -> do
        (a, c, b, d) <- decompose e'
        Just (a, negateArithmeticalExpr c, negateArithmeticalExpr b, negateArithmeticalExpr d)
      Plus' e1 e2 -> do
        (a1, c1, b1, d1) <- decompose e1
        (a2, c2, b2, d2) <- decompose e2
        (a, c) <- plusPair (a1, c1) (a2, c2)
        Just (a, c, plusArithmeticalExpr b1 b2, plusArithmeticalExpr d1 d2)
      Minus' e1 e2 -> do
        (a1, c1, b1, d1) <- decompose e1
        (a2, c2, b2, d2) <- decompose e2
        -- the sign of the subtracted product is folded into its @a@
        (a, c) <- plusPair (a1, c1) (negateArithmeticalExpr a2, c2)
        Just (a, c, minusArithmeticalExpr b1 b2, minusArithmeticalExpr d1 d2)
      Mult' e1 e2 -> do
        (a1, c1, b1, d1) <- decompose e1
        (a2, c2, b2, d2) <- decompose e2
        -- Expand (a1 c1 + b1 + d1) (a2 c2 + b2 + d2).
        -- Every product mixing a j-side and an i-side factor becomes one (a, c) pair.
        -- All such pairs must merge into a single (a, c).
        (a, c) <-
          sumPairs
            [ (multArithmeticalExpr a1 a2, multArithmeticalExpr c1 c2),
              (multArithmeticalExpr b2 a1, c1),
              (multArithmeticalExpr b1 a2, c2),
              (a1, multArithmeticalExpr c1 d2),
              (a2, multArithmeticalExpr c2 d1),
              (b2, d1),
              (b1, d2)
            ]
        Just (a, c, multArithmeticalExpr b1 b2, multArithmeticalExpr d1 d2)
      x
        | f `isUnusedVar` x && j `isUnusedVar` x -> Just (asD x)
        | f `isUnusedVar` x && i `isUnusedVar` x -> Just (asB x)
      x@(At' _ (Var g) index)
        | g == f -> case unNPlusKPattern (parseArithmeticalExpr index) of
            Just (v, offset) | v == i && offset < 0 -> Just (asD x)
            Just (v, 0) | v == j -> Just (asB x)
            _ -> Nothing
      _ -> Nothing

-- | `parseLinearFunctionBody` parses @step@, the body of @build (fun f -> step) base n@, after @len(f)@ has been rewritten to @i + k@.
--
-- It recognizes @min (map (fun j -> body) (range (i + k)))@ or the same with @max@.
-- That expression may be wrapped in negations, and constant-time expressions may be added, subtracted or multiplied around it.
-- It returns @(sign, a, b, c, d)@ such that @step = sign * min_j (a(j) b + c(j)) + d@.
-- @j@ has already been renamed to @i@ in @a@ and @c@.
-- Here @(a, b, c)@ are the first three components of 'parseLinearFunctionBody'' for @body@, and @d@ is its fourth component plus the surrounding terms.
parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr))
parseLinearFunctionBody f i k = runMaybeT . go
  where
    -- Shared by the min and max cases.
    -- Accepts @fun j -> step@ over @range size@ only when @size@ is exactly @i + k@.
    parseBody j step size = case unNPlusKPattern (parseArithmeticalExpr size) of
      Just (i', k') | i' == i && k' == k -> do
        (a, b, c, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
        -- rename @j@ to @i@
        a' <- lift $ substitute j (Var i) a
        c' <- lift $ substitute j (Var i) c
        return (a', b, c', d)
      _ -> hoistMaybe Nothing

    go = \case
      Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> do
        (a, b, c, d) <- parseBody j step size
        return (LitInt' 1, a, b, c, d)
      -- max_j (a(j) b + c(j)) = -1 * min_j (a(j) (-b) + (-c(j)))
      Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> do
        (a, b, c, d) <- parseBody j step size
        return (LitInt' (-1), a, Negate' b, Negate' c, d)
      Negate' e -> do
        (sign, a, b, c, d) <- go e
        return (Negate' sign, a, b, c, Negate' d)
      Plus' e1 e2
        | isConstantTimeExpr e2 -> do
            (sign, a, b, c, d) <- go e1
            return (sign, a, b, c, Plus' d e2)
        | isConstantTimeExpr e1 -> do
            (sign, a, b, c, d) <- go e2
            return (sign, a, b, c, Plus' e1 d)
      Minus' e1 e2
        | isConstantTimeExpr e2 -> do
            (sign, a, b, c, d) <- go e1
            return (sign, a, b, c, Minus' d e2)
        | isConstantTimeExpr e1 -> do
            (sign, a, b, c, d) <- go e2
            return (Negate' sign, a, b, c, Minus' e1 d)
      Mult' e1 e2
        | isConstantTimeExpr e2 -> do
            (sign, a, b, c, d) <- go e1
            return (Mult' sign e2, a, b, c, Mult' d e2)
        | isConstantTimeExpr e1 -> do
            (sign, a, b, c, d) <- go e2
            return (Mult' e1 sign, a, b, c, Mult' e1 d)
      _ -> hoistMaybe Nothing

-- | Counts the elements of a list literal built only from @nil@, @cons@ and @snoc@.
-- Returns 'Nothing' for any other expression.
getLength :: Expr -> Maybe Integer
getLength = go 0
  where
    go acc = \case
      Nil' _ -> Just acc
      Cons' _ _ rest -> go (acc + 1) rest
      Snoc' _ rest _ -> go (acc + 1) rest
      _ -> Nothing

-- | `rule` rewrites @build (fun f -> step(f)) base n@, where
-- @step(f) = sign(f) * min (map (fun j -> a(f, j) c(f) + b(f, j)) (range (len(f)))) + d(f)@.
-- The result is a left fold over @range n@ whose state pairs a convex hull trick structure with the list built so far.
--
-- The rewrite fires only when @base@ is a list literal of length 1; see the note in the body.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = RewriteRule $ \_ -> \case
  -- build (fun f -> step(f)) base n
  Build' IntTy (Lam f _ step) base n -> runMaybeT $ do
    i <- lift genVarName'
    k <- hoistMaybe $ getLength base
    -- NOTE: `step'` below inserts exactly one line per iteration (the one for @j = i@) before querying.
    -- At iteration @i@ the structure therefore holds the lines for @j < i + 1@.
    -- The original @min@ ranges over @j < len(f) = i + k@, so the two agree only when @k = 1@.
    -- Otherwise candidates would be silently dropped (k >= 2) or @f[i]@ read out of range (k = 0).
    if k == 1 then return () else hoistMaybe Nothing
    step <- replaceLenF f i k step
    -- step(f) = sign(f) * min (map (fun j -> a(f, j) c(f) + b(f, j)) (range (i + k))) + d(f)
    -- Here @a@ and @b@ (already renamed from @j@ to @i@) are given to the insert, and @c@ to the query.
    (sign, a, c, b, d) <- MaybeT $ parseLinearFunctionBody f i k step
    x <- lift genVarName'
    y <- lift genVarName'
    f' <- lift $ genVarName f
    let ts = [ConvexHullTrickTy, ListTy IntTy]
    -- base' = (empty, base)
    let base' = uncurryApp (Tuple' ts) [ConvexHullTrickInit', base]
    -- step' = fun (cht, f) i ->
    --     let cht' = insert cht a(i) b(i)
    --     in let f' = snoc f (sign(f) * getmin cht' c(f) + d(f))
    --     in (cht', f')
    let step' =
          Lam2 x (TupleTy ts) i IntTy $
            Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $
              Let y ConvexHullTrickTy (ConvexHullTrickInsert' (Proj' ts 0 (Var x)) a b) $
                Let f' (ListTy IntTy) (Snoc' IntTy (Var f) (Plus' (Mult' sign (ConvexHullTrickGetMin' (Var y) c)) d)) $
                  uncurryApp (Tuple' ts) [Var y, Var f']
    -- proj 1 (foldl step' base' (range n))
    return $ Proj' ts 1 (Foldl' IntTy (TupleTy ts) step' base' (Range1' n))
  _ -> return Nothing

-- | Applies 'rule' to the whole program with 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` optimizes a DP which has the recurrence relation
-- \[
--     \mathrm{dp}(i) = \min \lbrace a(j) x(i) + b(j) \mid j \lt i \rbrace + c(i)
-- \]
-- where only appropriate elements of \(\mathrm{dp}\) are used in \(a, x, b, c\).
-- The program is checked to be well-typed both before and after the rewrite.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.ConvexHullTrick" $ do
  precondition (ensureWellTyped prog)
  optimized <- runProgram prog
  postcondition (ensureWellTyped optimized)
  pure optimized