{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}
module Jikka.Core.Convert.SegmentTree
( run,
rule,
reduceCumulativeSum,
reduceMin,
)
where
import Control.Arrow
import Control.Monad.Trans.Maybe
import Data.List
import qualified Data.Map as M
import Data.Maybe
import Jikka.Common.Alpha
import Jikka.Common.Error
import qualified Jikka.Core.Convert.Alpha as Alpha
import Jikka.Core.Language.Beta
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util
-- | Bidirectional pattern synonym for @Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x1)) (Var x2))) e es@.
--
-- As a pattern it matches only when all four type annotations are equal, the
-- lambda body applies @f@ to the two bound variables in binding order, and
-- neither bound variable is used inside @f@. It then binds the element type
-- @t@, the function @f@, the initial value @e@ and the list @es@.
--
-- As an expression it builds the same shape, choosing the two binder names
-- with 'findUnusedVarName' so that they are not used in @f@.
pattern CumulativeSum :: Type -> Expr -> Expr -> Expr -> Expr
pattern CumulativeSum t f e es <-
  ( \case
      Scanl' t t' (Lam2 x1 t'' x2 t''' (App (App f (Var x1')) (Var x2'))) e es
        | t == t' && t' == t'' && t'' == t''' && x1 == x1' && x1 `isUnusedVar` f && x2 == x2' && x2 `isUnusedVar` f -> Just (t, f, e, es)
      _ -> Nothing ->
      Just (t, f, e, es)
    )
  where
    CumulativeSum t f e es =
      let x1 = findUnusedVarName (VarName "y") f
          x2 = findUnusedVarName (VarName "x") f
       in Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x1)) (Var x2))) e es
-- | Bidirectional pattern synonym for @Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x2)) (Var x1))) e es@,
-- i.e. the same shape as 'CumulativeSum' but with the two arguments of @f@ swapped.
--
-- As a pattern it matches only when all four type annotations are equal, the
-- lambda body applies @f@ to the two bound variables in reverse binding order,
-- and neither bound variable is used inside @f@.
--
-- As an expression it builds the same shape, choosing the two binder names
-- with 'findUnusedVarName' so that they are not used in @f@.
pattern CumulativeSumFlip :: Type -> Expr -> Expr -> Expr -> Expr
pattern CumulativeSumFlip t f e es <-
  ( \case
      Scanl' t t' (Lam2 x1 t'' x2 t''' (App (App f (Var x2')) (Var x1'))) e es
        | t == t' && t' == t'' && t'' == t''' && x2 == x2' && x2 `isUnusedVar` f && x1 == x1' && x1 `isUnusedVar` f -> Just (t, f, e, es)
      _ -> Nothing ->
      Just (t, f, e, es)
    )
  where
    CumulativeSumFlip t f e es =
      let x1 = findUnusedVarName (VarName "y") f
          x2 = findUnusedVarName (VarName "x") f
       in Scanl' t t (Lam2 x1 t x2 t (App (App f (Var x2)) (Var x1))) e es
-- | Map a builtin binary operator to the semigroup it corresponds to.
--
-- Only 'Plus', @'Min2' 'IntTy'@ and @'Max2' 'IntTy'@ are recognised; every
-- other builtin yields 'Nothing'.
builtinToSemigroup :: Builtin -> Maybe Semigroup'
builtinToSemigroup = \case
  Plus -> Just SemigroupIntPlus
  Min2 IntTy -> Just SemigroupIntMin
  Max2 IntTy -> Just SemigroupIntMax
  _ -> Nothing
-- | Map a semigroup back to the builtin binary operator that combines its
-- elements. For the three constructors handled here this is the inverse of
-- 'builtinToSemigroup'.
semigroupToBuiltin :: Semigroup' -> Builtin
semigroupToBuiltin = \case
  SemigroupIntPlus -> Plus
  SemigroupIntMin -> Min2 IntTy
  SemigroupIntMax -> Max2 IntTy
-- | Recognise a cumulative sum taken over the expression @a@.
--
-- The second argument must match 'CumulativeSum' or 'CumulativeSumFlip' with
-- a literal builtin as the combining function and a list expression equal to
-- @a@. If that builtin is accepted by 'builtinToSemigroup', the result is the
-- semigroup paired with the initial value of the scan; otherwise 'Nothing'.
unCumulativeSum :: Expr -> Expr -> Maybe (Semigroup', Expr)
unCumulativeSum a = \case
  CumulativeSum _ (Lit (LitBuiltin op)) b a' | a' == a -> withBase op b
  CumulativeSumFlip _ (Lit (LitBuiltin op)) b a' | a' == a -> withBase op b
  _ -> Nothing
  where
    -- Pair the semigroup of @op@ (if there is one) with the scan's initial value.
    withBase op b = fmap (\semigrp -> (semigrp, b)) (builtinToSemigroup op)
-- | Collect every cumulative sum over @a@ found among the sub-expressions
-- (as enumerated by 'listSubExprs') of the given expression, each reported as
-- in 'unCumulativeSum'.
listCumulativeSum :: Expr -> Expr -> [(Semigroup', Expr)]
listCumulativeSum a = mapMaybe (unCumulativeSum a) . listSubExprs
-- | Rewrite indexed accesses into cumulative sums of the variable @a@ so that
-- they use the supplied segment-tree expressions instead.
--
-- @segtrees@ associates each semigroup with the expression that denotes its
-- segment tree. An access @At' _ c i@, where @c@ is either a cumulative sum
-- over @Var a@ or a let-bound variable previously recognised as one, becomes
-- @AppBuiltin2 op b (SegmentTreeGetRange' semigrp segtree (LitInt' 0) i)@
-- with @b@ the initial value of the scan and @op@ the builtin of the semigroup.
--
-- A @Let@ whose bound expression is recognised as such a cumulative sum is
-- removed, and its variable is remembered for the rest of the traversal.
replaceWithSegtrees :: VarName -> [(Semigroup', Expr)] -> Expr -> Expr
replaceWithSegtrees a segtrees = go M.empty
  where
    -- @env@ maps let-bound variables that stood for a cumulative sum to
    -- (segment tree, initial value, semigroup).
    go :: M.Map VarName (Expr, Expr, Semigroup') -> Expr -> Expr
    go env = \case
      At' _ (check env -> Just (e, b, semigrp)) i ->
        -- The index is traversed too: it may contain further accesses, and it
        -- may mention a let-bound variable whose binding has been dropped.
        let e' = SegmentTreeGetRange' semigrp e (LitInt' 0) (go env i)
         in AppBuiltin2 (semigroupToBuiltin semigrp) b e'
      Var x -> Var x
      Lit lit -> Lit lit
      App e1 e2 -> App (go env e1) (go env e2)
      Lam x t e -> Lam x t $ go (M.delete x env) e
      Let x t e1 e2 ->
        let e1' = go env e1
         in case check env e1' of
              Just entry -> go (M.insert x entry env) e2
              -- @x@ is rebound to something else here, so any remembered entry
              -- for the same name must not leak into the body.
              Nothing -> Let x t e1' (go (M.delete x env) e2)
    -- Decide whether an expression denotes a tracked cumulative sum.
    check :: M.Map VarName (Expr, Expr, Semigroup') -> Expr -> Maybe (Expr, Expr, Semigroup')
    check env = \case
      Var x -> M.lookup x env
      CumulativeSum _ (Lit (LitBuiltin op)) b (Var a')
        | a' == a -> do
          e <- lookup op (map (first semigroupToBuiltin) segtrees)
          semigrp <- builtinToSemigroup op
          return (e, b, semigrp)
      _ -> Nothing
-- | Rewrite rule for folds of the shape
-- @Foldl' t1 t2 (Lam2 a _ i _ (SetAt' t (Var a) index e)) base indices@
-- in which @a@ is not used in @index@ and @e@ contains at least one cumulative
-- sum over @a@ (see 'listCumulativeSum').
--
-- The accumulator is widened from the list alone to a tuple holding the list
-- followed by one segment tree per distinct semigroup found. Accesses into the
-- cumulative sums are replaced via 'replaceWithSegtrees'; each step stores the
-- new element with @SetAt'@ and applies @SegmentTreeSetPoint'@ to every tree,
-- and the final result is the first component of the folded tuple.
--
-- The rule does not fire (yields 'Nothing') when no cumulative sum is found or
-- when 'replaceWithSegtrees' leaves @e@ unchanged.
reduceCumulativeSum :: (MonadAlpha m, MonadError Error m) => RewriteRule m
reduceCumulativeSum = RewriteRule $ \_ -> \case
  Foldl' t1 t2 (Lam2 a _ i _ (SetAt' t (Var a') index e)) base indices
    | a' == a && a `isUnusedVar` index -> runMaybeT $ do
      let sums = listCumulativeSum (Var a) e
      guard $ not (null sums)
      -- One segment tree per distinct semigroup, in sorted order.
      let semigrps = nub (sort (map fst sums))
      -- Component 0 is the list itself; components 1.. are the segment trees.
      let ts = t2 : map SegmentTreeTy semigrps
      c <- lift $ genVarName a
      let proj k = Proj' ts k (Var c)
      let replaced = replaceWithSegtrees a (zip semigrps (map proj [1 ..])) e
      guard $ replaced /= e
      -- Remaining uses of @a@ now refer to component 0 of the tuple.
      body <- lift $ substitute a (proj 0) replaced
      b' <- lift $ genVarName a
      let updateSegtrees k semigrp = SegmentTreeSetPoint' semigrp (proj k) index (At' t (Var b') index)
      let step =
            Lam2 c (TupleTy ts) i t1 $
              Let b' t2 (SetAt' t (proj 0) index body) $
                uncurryApp (Tuple' ts) (Var b' : zipWith updateSegtrees [1 ..] semigrps)
      b <- lift $ genVarName a
      let base' = Var b : map (\semigrp -> SegmentTreeInitList' semigrp (Var b)) semigrps
      return $ Let b t2 base (Proj' ts 0 (Foldl' t1 (TupleTy ts) step (uncurryApp (Tuple' ts) base') indices))
  _ -> return Nothing
-- | Rewrite rule that currently never fires: it returns 'Nothing' for every
-- expression.
reduceMin :: MonadAlpha m => RewriteRule m
reduceMin = RewriteRule $ \_ _ -> return Nothing
-- | All rewrite rules of this module combined with 'mconcat':
-- 'reduceCumulativeSum' followed by 'reduceMin'.
rule :: (MonadAlpha m, MonadError Error m) => RewriteRule m
rule =
  mconcat
    [ reduceCumulativeSum,
      reduceMin
    ]
-- | Apply 'rule' to a whole program via 'applyRewriteRuleProgram''.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram = applyRewriteRuleProgram' rule
-- | Entry point of the pass.
--
-- Checks that the input program is well typed, applies the rewrite rules of
-- this module ('runProgram'), runs "Jikka.Core.Convert.Alpha" on the result,
-- and checks that the output is well typed again. The whole computation is
-- wrapped with 'wrapError'' under the label @Jikka.Core.Convert.SegmentTree@.
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.SegmentTree" $ do
  precondition $ do
    ensureWellTyped prog
  rewritten <- runProgram prog
  renamed <- Alpha.run rewritten
  postcondition $ do
    ensureWellTyped renamed
  return renamed