{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
module Jikka.Core.Convert.ConvexHullTrick
( run,
rule,
parseLinearFunctionBody,
parseLinearFunctionBody',
)
where
import Control.Monad.Trans.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.ArithmeticalExpr
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util
-- | Lift a pure 'Maybe' value into 'MaybeT' over any applicative @m@.
--
-- 'Nothing' becomes a failed computation and @'Just' x@ becomes a
-- computation that succeeds with @x@; no effect of @m@ is performed.
--
-- NOTE(review): newer releases of @transformers@ (0.6 and later, as far as I
-- know) export their own @hoistMaybe@ from "Control.Monad.Trans.Maybe". With
-- the unqualified import used by this module that could make uses of this
-- name ambiguous -- confirm against the pinned @transformers@ version.
hoistMaybe :: Applicative m => Maybe a -> MaybeT m a
hoistMaybe = MaybeT . pure
-- | Add two pairs, each pair @(a, c)@ being treated as the product @a * c@,
-- and return the sum again as a single pair if that is possible.
--
-- * If either component of one pair is zero, the other pair is returned
--   unchanged (the right-hand pair is checked first).
-- * Otherwise the integer constant factors @k1@, @k2@ are split off from
--   @c1@ and @c2@. When the remaining parts are equal, the result is
--   @(k1 * a1 + k2 * a2, c1')@.
-- * When the remaining parts differ, the sum cannot be written as one pair
--   and the result is 'Nothing'.
plusPair :: (ArithmeticalExpr, ArithmeticalExpr) -> (ArithmeticalExpr, ArithmeticalExpr) -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
plusPair (a1, c1) (a2, c2)
  | isZeroArithmeticalExpr a2 = Just (a1, c1)
  | isZeroArithmeticalExpr c2 = Just (a1, c1)
  | isZeroArithmeticalExpr a1 = Just (a2, c2)
  | isZeroArithmeticalExpr c1 = Just (a2, c2)
  | c1' == c2' = Just (plusArithmeticalExpr a1' a2', c1')
  | otherwise = Nothing
  where
    -- These bindings are lazy, so they are only evaluated once all four
    -- zero checks above have failed.
    (k1, c1') = splitConstantFactorArithmeticalExpr c1
    (k2, c2') = splitConstantFactorArithmeticalExpr c2
    -- Move the constant factor of each @c@ over to the matching @a@.
    a1' = multArithmeticalExpr (integerArithmeticalExpr k1) a1
    a2' = multArithmeticalExpr (integerArithmeticalExpr k2) a2
-- | Sum a list of pairs with 'plusPair', folding from the right.
--
-- The fold starts from the pair @(1, 0)@, which 'plusPair' treats as a zero
-- term. The result is 'Nothing' as soon as one of the additions fails.
sumPairs :: [(ArithmeticalExpr, ArithmeticalExpr)] -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
sumPairs = foldr step (Just (integerArithmeticalExpr 1, integerArithmeticalExpr 0))
  where
    step pair acc = plusPair pair =<< acc
-- | Try to decompose an expression into the shape @a * c + b + d@.
--
-- @parseLinearFunctionBody' f i j e@ returns the 4-tuple @(a, c, b, d)@
-- (note the order). The leaves of @e@ are classified as follows:
--
-- * a subexpression using neither @f@ nor @j@ contributes to @d@;
-- * a subexpression using neither @f@ nor @i@ contributes to @b@;
-- * @f[i + k]@ with a negative constant @k@ contributes to @d@;
-- * @f[j]@ contributes to @b@.
--
-- Negation, addition, subtraction and multiplication are combined
-- algebraically. The product part @a * c@ is tracked as a pair and merged
-- with 'plusPair' / 'sumPairs'. The result is 'Nothing' when a leaf fits
-- none of the cases above or when the product parts cannot be merged into a
-- single pair.
--
-- Before formatting, the integer constant factor of @a@ is moved over to @c@.
parseLinearFunctionBody' :: VarName -> VarName -> VarName -> Expr -> Maybe (Expr, Expr, Expr, Expr)
parseLinearFunctionBody' f i j e = result <$> go e
  where
    zero = integerArithmeticalExpr 0
    one = integerArithmeticalExpr 1

    result :: (ArithmeticalExpr, ArithmeticalExpr, ArithmeticalExpr, ArithmeticalExpr) -> (Expr, Expr, Expr, Expr)
    result (a, c, b, d) =
      let (k, a') = splitConstantFactorArithmeticalExpr a
          c' = multArithmeticalExpr (integerArithmeticalExpr k) c
       in (formatArithmeticalExpr a', formatArithmeticalExpr c', formatArithmeticalExpr b, formatArithmeticalExpr d)

    go :: Expr -> Maybe (ArithmeticalExpr, ArithmeticalExpr, ArithmeticalExpr, ArithmeticalExpr)
    go = \case
      Negate' e -> do
        (a, c, b, d) <- go e
        -- Negating only @c@ negates the product @a * c@.
        return (a, negateArithmeticalExpr c, negateArithmeticalExpr b, negateArithmeticalExpr d)
      Plus' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        (a, c) <- plusPair (a1, c1) (a2, c2)
        return (a, c, plusArithmeticalExpr b1 b2, plusArithmeticalExpr d1 d2)
      Minus' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        (a, c) <- plusPair (a1, c1) (negateArithmeticalExpr a2, c2)
        return (a, c, minusArithmeticalExpr b1 b2, minusArithmeticalExpr d1 d2)
      Mult' e1 e2 -> do
        (a1, c1, b1, d1) <- go e1
        (a2, c2, b2, d2) <- go e2
        -- Expansion of (a1 c1 + b1 + d1) * (a2 c2 + b2 + d2). Every term except
        -- b1 * b2 and d1 * d2 goes into the product part.
        (a, c) <-
          sumPairs
            [ (multArithmeticalExpr a1 a2, multArithmeticalExpr c1 c2),
              (multArithmeticalExpr b2 a1, c1),
              (multArithmeticalExpr b1 a2, c2),
              (a1, multArithmeticalExpr c1 d2),
              (a2, multArithmeticalExpr c2 d1),
              (b2, d1),
              (b1, d2)
            ]
        return (a, c, multArithmeticalExpr b1 b2, multArithmeticalExpr d1 d2)
      e
        | f `isUnusedVar` e && j `isUnusedVar` e ->
          return (one, zero, zero, parseArithmeticalExpr e)
      e
        | f `isUnusedVar` e && i `isUnusedVar` e ->
          return (one, zero, parseArithmeticalExpr e, zero)
      e@(At' _ (Var f') index)
        | f' == f -> case unNPlusKPattern (parseArithmeticalExpr index) of
          Just (i', k) | i' == i && k < 0 -> return (one, zero, zero, parseArithmeticalExpr e)
          Just (j', 0) | j' == j -> return (one, zero, parseArithmeticalExpr e, zero)
          _ -> Nothing
      _ -> Nothing
-- | Recognise a step expression built around a single
-- @min (map (\\j -> body) (range (i + k)))@ or the matching @max@.
--
-- On success, @parseLinearFunctionBody f i k e@ yields
-- @(sign, a, c, b, d)@, where @(a, c, b, d)@ is the decomposition of @body@
-- produced by 'parseLinearFunctionBody''. In @a@ and @b@ the bound variable
-- @j@ has been substituted by @i@.
--
-- * For @min@, @sign@ is @1@.
-- * For @max@, @sign@ is @-1@ and both @c@ and @b@ are negated, so that the
--   maximum is expressed through a minimum.
-- * Around the reduction, negation is accepted, as are addition,
--   subtraction and multiplication where the other operand satisfies
--   'isConstantTimeExpr'. These are folded into @sign@ and @d@.
--
-- Anything else yields 'Nothing'.
parseLinearFunctionBody :: MonadAlpha m => VarName -> VarName -> Integer -> Expr -> m (Maybe (Expr, Expr, Expr, Expr, Expr))
parseLinearFunctionBody f i k = runMaybeT . go
  where
    -- Shared part of the min and max cases: check that the range is exactly
    -- @i + k@, decompose the lambda body, and rename @j@ to @i@ in the two
    -- components that 'parseLinearFunctionBody'' builds from @j@.
    parseReduction j step size = case unNPlusKPattern (parseArithmeticalExpr size) of
      Just (i', k') | i' == i && k' == k -> do
        (a, c, b, d) <- hoistMaybe $ parseLinearFunctionBody' f i j step
        a' <- lift $ substitute j (Var i) a
        b' <- lift $ substitute j (Var i) b
        return (a', c, b', d)
      _ -> hoistMaybe Nothing

    go = \case
      Min1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> do
        (a, c, b, d) <- parseReduction j step size
        return (LitInt' 1, a, c, b, d)
      Max1' _ (Map' _ _ (Lam j _ step) (Range1' size)) -> do
        (a, c, b, d) <- parseReduction j step size
        -- max (a c + b) = - min (a (- c) + (- b))
        return (LitInt' (-1), a, Negate' c, Negate' b, d)
      Negate' e -> do
        (sign, a, c, b, d) <- go e
        return (Negate' sign, a, c, b, Negate' d)
      Plus' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d) <- go e1
        return (sign, a, c, b, Plus' d e2)
      Plus' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d) <- go e2
        return (sign, a, c, b, Plus' e1 d)
      Minus' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d) <- go e1
        return (sign, a, c, b, Minus' d e2)
      Minus' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d) <- go e2
        return (Negate' sign, a, c, b, Minus' e1 d)
      Mult' e1 e2 | isConstantTimeExpr e2 -> do
        (sign, a, c, b, d) <- go e1
        return (Mult' sign e2, a, c, b, Mult' d e2)
      Mult' e1 e2 | isConstantTimeExpr e1 -> do
        (sign, a, c, b, d) <- go e2
        return (Mult' e1 sign, a, c, b, Mult' e1 d)
      _ -> hoistMaybe Nothing
-- | Compute the length of a list expression that is written out literally.
--
-- Only expressions built from @Nil'@, @Cons'@ and @Snoc'@ are understood;
-- any other expression (including one nested in the tail) gives 'Nothing'.
getLength :: Expr -> Maybe Integer
getLength = \case
  Nil' _ -> Just 0
  Cons' _ _ xs -> succ <$> getLength xs
  Snoc' _ xs _ -> succ <$> getLength xs
  _ -> Nothing
-- | Rewrite rule that replaces a @Build'@ over integers with a left fold
-- carrying a convex-hull-trick structure next to the list being built.
--
-- The rule fires on @Build' IntTy (Lam f _ step) base n@ when
--
-- * @base@ is a literal list (see 'getLength') of length exactly 1, and
-- * @step@, after 'replaceLenF', is accepted by 'parseLinearFunctionBody'.
--
-- The fold state is the pair @(cht, f)@. At index @i@ the step inserts the
-- line @(a, b)@ into @cht@, appends @sign * getMin cht c + d@ to @f@, and
-- returns the new pair. The overall result is the second component of the
-- final state.
--
-- Why the length must be 1: every step inserts exactly one line, so at index
-- @i@ the structure holds the lines for @j = 0 .. i@, while the reduction
-- being replaced ranges over @j < i + k@. These coincide only for @k == 1@.
-- For other lengths the rewritten program would query a structure with
-- missing lines, or build a line from an element of @f@ that does not exist
-- yet, so the rule declines and leaves the expression untouched.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule = RewriteRule $ \_ -> \case
  Build' IntTy (Lam f _ step) base n -> runMaybeT $ do
    i <- lift genVarName'
    k <- hoistMaybe $ case getLength base of
      Just 1 -> Just 1
      _ -> Nothing
    body <- replaceLenF f i k step
    (sign, a, c, b, d) <- MaybeT $ parseLinearFunctionBody f i k body
    x <- lift genVarName'
    y <- lift genVarName'
    f' <- lift $ genVarName f
    let ts = [ConvexHullTrickTy, ListTy IntTy]
    -- Initial state: an empty structure paired with the given base list.
    let base' = uncurryApp (Tuple' ts) [ConvexHullTrickInit', base]
    let step' =
          Lam2 x (TupleTy ts) i IntTy $
            Let f (ListTy IntTy) (Proj' ts 1 (Var x)) $
              Let y ConvexHullTrickTy (ConvexHullTrickInsert' (Proj' ts 0 (Var x)) a b) $
                Let f' (ListTy IntTy) (Snoc' IntTy (Var f) (Plus' (Mult' sign (ConvexHullTrickGetMin' (Var y) c)) d)) $
                  uncurryApp (Tuple' ts) [Var y, Var f']
    return $ Proj' ts 1 (Foldl' IntTy (TupleTy ts) step' base' (Range1' n))
  _ -> return Nothing
-- | Apply 'rule' to a whole program via 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule
-- | Entry point of the convex-hull-trick conversion.
--
-- The input program is checked with 'ensureWellTyped' as a precondition,
-- rewritten with 'runProgram', and the result is checked again as a
-- postcondition. The whole computation is wrapped with 'wrapError'' under
-- the name of this module.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.ConvexHullTrick" $ do
  precondition $ do
    ensureWellTyped prog
  prog' <- runProgram prog
  postcondition $ do
    ensureWellTyped prog'
  return prog'