{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.ConvexHullTrick
-- Description : uses convex hull trick. / convex hull trick を使います。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
--
-- \[
--     \newcommand\int{\mathbf{int}}
--     \newcommand\bool{\mathbf{bool}}
--     \newcommand\list{\mathbf{list}}
-- \]
module Jikka.Core.Convert.ConvexHullTrick
( run,

-- * internal rules
rule,
parseLinearFunctionBody,
parseLinearFunctionBody',
)
where

import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticalExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | Lifts a pure 'Maybe' into 'MaybeT' without performing any effect in @m@:
-- @Nothing@ becomes a failed computation and @Just x@ a successful one.
hoistMaybe :: Applicative m => Maybe a -> MaybeT m a
hoistMaybe mx = MaybeT (pure mx)

-- | Adds two products @a1 * c1@ and @a2 * c2@ and tries to express the sum as
-- one product @a * c@.
--
-- A pair whose @a@ or @c@ component is zero contributes nothing, so the other
-- pair is returned as-is. Otherwise the integer constant factor is split off
-- each @c@. When the remaining parts coincide, the constants are folded into
-- the @a@ side. Any other combination yields 'Nothing'.
--
-- This is something commutative because only one kind of @c@ is allowed.
plusPair :: (ArithmeticalExpr, ArithmeticalExpr) -> (ArithmeticalExpr, ArithmeticalExpr) -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
plusPair lhs@(a1, c1) rhs@(a2, c2)
  | isZeroArithmeticalExpr a2 || isZeroArithmeticalExpr c2 = Just lhs
  | isZeroArithmeticalExpr a1 || isZeroArithmeticalExpr c1 = Just rhs
  | core1 == core2 = Just (plusArithmeticalExpr (scaleBy k1 a1) (scaleBy k2 a2), core1)
  | otherwise = Nothing
  where
    (k1, core1) = splitConstantFactorArithmeticalExpr c1
    (k2, core2) = splitConstantFactorArithmeticalExpr c2
    -- multiply an expression by an integer constant
    scaleBy k = multArithmeticalExpr (integerArithmeticalExpr k)

-- | Sums a list of @(a, c)@ pairs, each read as the product @a * c@, into a
-- single pair using 'plusPair'.
--
-- The fold starts from @(1, 0)@, i.e. the product zero. It returns 'Nothing'
-- if any partial sum cannot be expressed as a single product.
sumPairs :: [(ArithmeticalExpr, ArithmeticalExpr)] -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
sumPairs = foldr (\e1 e2 -> plusPair e1 =<< e2) (Just (integerArithmeticalExpr 1, integerArithmeticalExpr 0))

-- | parseLinearFunctionBody' parses the body of a linear function which can be decomposed to convex hull trick.
--
-- @parseLinearFunctionBody' f i j e@ finds a 4-tuple @(a, c, b, d)@ where
-- @e = a(f[j], j) c(f[< i], i) + b(f[j], j) + d(f[< i], i)@.
-- Note the tuple order: the slope @a@ and the query point @c@ come first.
-- Any integer constant factor of @a@ is moved into @c@.
--
-- Leaves are classified as follows:
--
-- * expressions using neither @f@ nor @j@ go to @d@;
-- * expressions using neither @f@ nor @i@ go to @b@;
-- * @f[i + k]@ with a negative literal @k@ goes to @d@;
-- * @f[j]@ goes to @b@.
--
-- Any other use of @f@ makes the parse fail with 'Nothing'.
--
-- TODO: Generalize the relation between @j@ and @k@; 'rule' currently only
-- accepts @k = 1@.
parseLinearFunctionBody' :: VarName -> VarName -> VarName -> Expr -> Maybe (Expr, Expr, Expr, Expr)
parseLinearFunctionBody' f i j e = result <$> go e
  where
    zero = integerArithmeticalExpr 0
    one = integerArithmeticalExpr 1

    -- Move the constant factor of @a@ into @c@ and convert back to 'Expr'.
    result (a, c, b, d) =
      let (k, a') = splitConstantFactorArithmeticalExpr a
          c' = multArithmeticalExpr (integerArithmeticalExpr k) c
       in (formatArithmeticalExpr a', formatArithmeticalExpr c', formatArithmeticalExpr b, formatArithmeticalExpr d)

    -- Invariant: @go e = Just (a, c, b, d)@ means @e = a * c + b + d@.
    go = \case
      Negate' e -> do
        (a, c, b, d) <- go e
        return (a, negateArithmeticalExpr c, negateArithmeticalExpr b, negateArithmeticalExpr d)
      Plus' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        (a, c) <- plusPair (a1, c1) (a2, c2)
        return (a, c, plusArithmeticalExpr b1 b2, plusArithmeticalExpr d1 d2)
      Minus' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        (a, c) <- plusPair (a1, c1) (negateArithmeticalExpr a2, c2)
        return (a, c, minusArithmeticalExpr b1 b2, minusArithmeticalExpr d1 d2)
      Mult' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        -- Expand (a1 c1 + b1 + d1)(a2 c2 + b2 + d2). Every cross term that
        -- mixes a j-side factor with an i-side factor must collapse into one
        -- product a * c.
        (a, c) <-
          sumPairs
            [ (multArithmeticalExpr a1 a2, multArithmeticalExpr c1 c2),
              (multArithmeticalExpr b2 a1, c1),
              (multArithmeticalExpr b1 a2, c2),
              (a1, multArithmeticalExpr c1 d2),
              (a2, multArithmeticalExpr c2 d1),
              (b2, d1),
              (b1, d2)
            ]
        return (a, c, multArithmeticalExpr b1 b2, multArithmeticalExpr d1 d2)
      e
        | f `isUnusedVar` e && j `isUnusedVar` e ->
          -- NOTE: Put constants to @d@ and simplify @a, b@
          return (one, zero, zero, parseArithmeticalExpr e)
      e
        | f `isUnusedVar` e && i `isUnusedVar` e ->
          return (one, zero, parseArithmeticalExpr e, zero)
      e@(At' _ (Var f') index)
        | f' == f -> case unNPlusKPattern (parseArithmeticalExpr index) of
          Just (i', k) | i' == i && k < 0 -> return (one, zero, zero, parseArithmeticalExpr e)
          Just (j', 0) | j' == j -> return (one, zero, parseArithmeticalExpr e, zero)
          _ -> Nothing
      _ -> Nothing

-- | @parseLinearFunctionBody f i k e@ recognizes
-- @e = sign * min (map (fun j -> a(f[j], j) c(f[< i], i) + b(f[j], j)) (range (i + k))) + d@,
-- or the same with @max@.
--
-- It returns @(sign, a, c, b, d)@, with @j@ renamed to @i@ in @a@ and @b@.
--
-- * @sign@ and @d@ are accumulated from constant-time @+@, @-@, @*@ and
--   negation wrapped around the @min@ / @max@.
-- * @max@ is turned into a @min@ by negating @sign@, the query point @c@ and
--   the intercept @b@.
parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr))
parseLinearFunctionBody f i k = runMaybeT . go
  where
    -- Parses the body @step@ of @map (fun j -> step) (range size)@, requiring
    -- @size = i + k@.
    parseInner j step size = case unNPlusKPattern (parseArithmeticalExpr size) of
      Just (i', k') | i' == i && k' == k -> do
        (a, c, b, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
        -- rename @j@ to @i@ in the j-dependent parts (slope and intercept)
        a' <- lift $ substitute j (Var i) a
        b' <- lift $ substitute j (Var i) b
        return (a', c, b', d)
      _ -> hoistMaybe Nothing

    go = \case
      Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> do
        (a, c, b, d) <- parseInner j step size
        return (LitInt' 1, a, c, b, d)
      Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> do
        (a, c, b, d) <- parseInner j step size
        -- max (a x + b) = - min (a (-x) + (-b))
        return (LitInt' (-1), a, Negate' c, Negate' b, d)
      Negate' e -> do
        (sign, a, c, b, d) <- go e
        return (Negate' sign, a, c, b, Negate' d)
      Plus' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d) <- go e1
        return (sign, a, c, b, Plus' d e2)
      Plus' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d) <- go e2
        return (sign, a, c, b, Plus' e1 d)
      Minus' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d) <- go e1
        return (sign, a, c, b, Minus' d e2)
      Minus' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d) <- go e2
        return (Negate' sign, a, c, b, Minus' e1 d)
      Mult' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d) <- go e1
        return (Mult' sign e2, a, c, b, Mult' d e2)
      Mult' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d) <- go e2
        return (Mult' e1 sign, a, c, b, Mult' e1 d)
      _ -> hoistMaybe Nothing

-- | Returns the length of a list literal built from @nil@, @cons@ and @snoc@,
-- or 'Nothing' when the expression is not such a literal.
getLength :: Expr -> Maybe Integer
getLength = \case
  Nil' _ -> Just 0
  Cons' _ _ xs -> succ <$> getLength xs
  Snoc' _ xs _ -> succ <$> getLength xs
  _ -> Nothing

-- | Rewrites @build (fun f -> step(f)) base n@ into a fold that maintains a
-- convex hull trick structure alongside the list being built.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = RewriteRule $ \_ -> \case
  -- build (fun f -> step(f)) base n
  Build' IntTy (Lam f _ step) base n -> runMaybeT $ do
    i <- lift genVarName'
    k <- hoistMaybe $ getLength base
    -- NOTE: Each iteration of step' below inserts exactly one line, the one for
    -- j = i (see the renaming in parseLinearFunctionBody). The original min/max,
    -- however, ranges over all j < len(f) = i + k.
    -- These agree only when the base list has exactly one element. For other k
    -- the rewrite would skip lines, or, for k = 0, read f[i] before it exists.
    -- So such builds are left untouched.
    hoistMaybe $ if k == 1 then Just () else Nothing
    step <- replaceLenF f i k step
    -- step(f) = sign(f) * min (map (fun j -> a(f, j) c(f) + b(f, j)) (range (i + k))) + d(f)
    (sign, a, c, b, d) <- MaybeT $ parseLinearFunctionBody f i k step
    x <- lift genVarName'
    y <- lift genVarName'
    f' <- lift $ genVarName f
    let ts = [ConvexHullTrickTy, ListTy IntTy]
    -- base' = (empty, base)
    let base' = uncurryApp (Tuple' ts) [ConvexHullTrickInit', base]
    -- step' = fun (cht, f) i ->
    --     let cht' = insert cht a(i) b(i)
    --     in let f' = snoc f (sign * getmin cht' c(i) + d(i))
    --     in (cht', f')
    let step' =
          Lam2 x (TupleTy ts) i IntTy $
            Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $
              Let y ConvexHullTrickTy (ConvexHullTrickInsert' (Proj' ts 0 (Var x)) a b) $
                Let f' (ListTy IntTy) (Snoc' IntTy (Var f) (Plus' (Mult' sign (ConvexHullTrickGetMin' (Var y) c)) d)) $
                  uncurryApp (Tuple' ts) [Var y, Var f']
    -- proj 1 (foldl step' base' (range n))
    return $ Proj' ts 1 (Foldl' IntTy (TupleTy ts) step' base' (Range1' n))
  _ -> return Nothing

-- | Applies 'rule' to the whole program.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule

-- | run optimizes a DP which has the recurrence relation
-- \[
--     \mathrm{dp}(i) = \min \lbrace a(j) x(i) + b(j) \mid j \lt i \rbrace + c(i)
-- \]
-- where only appropriate elements of \(\mathrm{dp}\) are used in \(a, x, b, c\).
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.ConvexHullTrick" $ do
  -- the input must type-check before rewriting ...
  precondition (ensureWellTyped prog)
  prog' <- runProgram prog
  -- ... and the rewritten program must still type-check afterwards
  postcondition (ensureWellTyped prog')
  pure prog'