{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

module Jikka.Core.Language.Util where

import Control.Monad.Identity
import Control.Monad.Writer (execWriter, tell)
import Data.Maybe (isJust)
import Data.Monoid (Dual (..))
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr

-- | Generates a fresh type variable named @$@ followed by a new counter value
-- (e.g. @$7@).
genType :: MonadAlpha m => m Type
genType = VarTy . TypeName . ('$' :) . show <$> nextCounter

-- | Generates a fresh variable name derived from @x@.
--
-- The result is @prefix ++ "$" ++ show counter@, where @counter@ is a new
-- counter value. @prefix@ is the part of @x@ before its first @$@, which
-- strips any existing @$@-suffix. The wildcard name @_@ gets an empty prefix.
genVarName :: MonadAlpha m => VarName -> m VarName
genVarName x = do
  counter <- nextCounter
  let name = unVarName x
      prefix
        | name == "_" = ""
        | otherwise = takeWhile (/= '$') name
  return . VarName $ prefix ++ ('$' : show counter)

-- | Generates a fresh variable name with an empty prefix (e.g. @$42@).
genVarName' :: MonadAlpha m => m VarName
genVarName' = genVarName $ VarName "_"

-- | Applies @g@ to every 'Type' annotation stored inside a 'Builtin'.
--
-- Builtins without type annotations are returned unchanged. This includes
-- builtins whose only parameters are 'Int's or a 'Semigroup''. Every
-- constructor is listed explicitly instead of using a catch-all, so an
-- incomplete-pattern check (when enabled) flags any newly added constructor.
mapTypeInBuiltin :: (Type -> Type) -> Builtin -> Builtin
mapTypeInBuiltin g builtin = case builtin of
  -- arithmetical functions
  Negate -> builtin
  Plus -> builtin
  Minus -> builtin
  Mult -> builtin
  FloorDiv -> builtin
  FloorMod -> builtin
  CeilDiv -> builtin
  CeilMod -> builtin
  Pow -> builtin
  -- advanced arithmetical functions
  Abs -> builtin
  Gcd -> builtin
  Lcm -> builtin
  Min2 t -> Min2 (g t)
  Max2 t -> Max2 (g t)
  Iterate t -> Iterate (g t)
  -- logical functions
  Not -> builtin
  And -> builtin
  Or -> builtin
  Implies -> builtin
  If t -> If (g t)
  -- bitwise functions
  BitNot -> builtin
  BitAnd -> builtin
  BitOr -> builtin
  BitXor -> builtin
  BitLeftShift -> builtin
  BitRightShift -> builtin
  -- matrix functions (Int parameters only)
  MatAp {} -> builtin
  MatZero {} -> builtin
  MatOne {} -> builtin
  MatAdd {} -> builtin
  MatMul {} -> builtin
  MatPow {} -> builtin
  VecFloorMod {} -> builtin
  MatFloorMod {} -> builtin
  -- modular functions
  ModNegate -> builtin
  ModPlus -> builtin
  ModMinus -> builtin
  ModMult -> builtin
  ModInv -> builtin
  ModPow -> builtin
  ModMatAp {} -> builtin
  ModMatAdd {} -> builtin
  ModMatMul {} -> builtin
  ModMatPow {} -> builtin
  -- list functions
  Cons t -> Cons (g t)
  Snoc t -> Snoc (g t)
  Foldl t1 t2 -> Foldl (g t1) (g t2)
  Scanl t1 t2 -> Scanl (g t1) (g t2)
  Build t -> Build (g t)
  Len t -> Len (g t)
  Map t1 t2 -> Map (g t1) (g t2)
  Filter t -> Filter (g t)
  At t -> At (g t)
  SetAt t -> SetAt (g t)
  Elem t -> Elem (g t)
  Sum -> builtin
  Product -> builtin
  ModSum -> builtin
  ModProduct -> builtin
  Min1 t -> Min1 (g t)
  Max1 t -> Max1 (g t)
  ArgMin t -> ArgMin (g t)
  ArgMax t -> ArgMax (g t)
  All -> builtin
  Any -> builtin
  Sorted t -> Sorted (g t)
  Reversed t -> Reversed (g t)
  Range1 -> builtin
  Range2 -> builtin
  Range3 -> builtin
  -- tuple functions
  Tuple ts -> Tuple (map g ts)
  Proj ts n -> Proj (map g ts) n
  -- comparison
  LessThan t -> LessThan (g t)
  LessEqual t -> LessEqual (g t)
  GreaterThan t -> GreaterThan (g t)
  GreaterEqual t -> GreaterEqual (g t)
  Equal t -> Equal (g t)
  NotEqual t -> NotEqual (g t)
  -- combinational functions
  Fact -> builtin
  Choose -> builtin
  Permute -> builtin
  MultiChoose -> builtin
  -- data structures (segment trees carry a semigroup, not a type)
  ConvexHullTrickInit -> builtin
  ConvexHullTrickInsert -> builtin
  ConvexHullTrickGetMin -> builtin
  SegmentTreeInitList {} -> builtin
  SegmentTreeGetRange {} -> builtin
  SegmentTreeSetPoint {} -> builtin

-- | `mapExprM'` rewrites an expression with two monadic callbacks. @pre@ is
-- applied to a node before its children are visited, and @post@ is applied
-- after.
--
-- The children that get visited are those of the node returned by @pre@, not
-- of the original node. Both callbacks receive the environment of the node:
-- the variables bound by enclosing 'Lam' and 'Let' nodes, innermost first,
-- prepended to the given initial environment. The bound expression of a 'Let'
-- does not see its own variable; the body does.
mapExprM' :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> Expr -> m Expr
mapExprM' pre post = walk
  where
    -- pre-order callback, then the children, then the post-order callback
    walk env expr = pre env expr >>= descend env >>= post env
    descend env = \case
      Var y -> return (Var y)
      Lit lit -> return (Lit lit)
      App g e -> App <$> walk env g <*> walk env e
      Lam x t body -> Lam x t <$> walk ((x, t) : env) body
      Let y t e1 e2 -> Let y t <$> walk env e1 <*> walk ((y, t) : env) e2

-- | Runs 'mapExprM'' with @pre@ and @post@ on every expression in a
-- 'ToplevelExpr', extending the environment with the toplevel bindings.
--
-- A 'ToplevelLet' variable is visible only in its continuation. A
-- 'ToplevelLetRec' function is visible, with its curried function type, both
-- in its own body and in the continuation. Its arguments are also visible in
-- the body, with the last argument first.
mapExprToplevelExprM' :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr
mapExprToplevelExprM' pre post = loop
  where
    onExpr = mapExprM' pre post
    loop env = \case
      ResultExpr e -> ResultExpr <$> onExpr env e
      ToplevelLet y t e cont ->
        ToplevelLet y t <$> onExpr env e <*> loop ((y, t) : env) cont
      ToplevelLetRec g args ret body cont ->
        let funTy = foldr (FunTy . snd) ret args
            envWithG = (g, funTy) : env
            bodyEnv = reverse args ++ envWithG
         in ToplevelLetRec g args ret <$> onExpr bodyEnv body <*> loop envWithG cont

-- | Runs 'mapExprToplevelExprM'' over a whole program, starting from an empty
-- environment.
mapExprProgramM' :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> ([(VarName, Type)] -> Expr -> m Expr) -> Program -> m Program
mapExprProgramM' pre post prog = mapExprToplevelExprM' pre post [] prog

-- | `mapExprM` is 'mapExprM'' with a no-op pre-order callback. The given
-- function is applied to each node in post-order, after the node's children
-- have been rewritten.
mapExprM :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> Expr -> m Expr
mapExprM post = mapExprM' (const return) post

-- | Post-order variant of 'mapExprToplevelExprM'': the pre-order callback is a
-- no-op.
mapExprToplevelExprM :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> [(VarName, Type)] -> ToplevelExpr -> m ToplevelExpr
mapExprToplevelExprM post = mapExprToplevelExprM' (const return) post

-- | Post-order variant of 'mapExprProgramM'': the pre-order callback is a
-- no-op.
mapExprProgramM :: Monad m => ([(VarName, Type)] -> Expr -> m Expr) -> Program -> m Program
mapExprProgramM post = mapExprProgramM' (const return) post

-- | Pure version of 'mapExprM': applies @f@ to each node in post-order.
mapExpr :: ([(VarName, Type)] -> Expr -> Expr) -> [(VarName, Type)] -> Expr -> Expr
mapExpr f env e = runIdentity (mapExprM (\env' -> pure . f env') env e)

-- | Pure version of 'mapExprToplevelExprM': applies @f@ to each node of every
-- toplevel expression in post-order.
mapExprToplevelExpr :: ([(VarName, Type)] -> Expr -> Expr) -> [(VarName, Type)] -> ToplevelExpr -> ToplevelExpr
mapExprToplevelExpr f env e = runIdentity (mapExprToplevelExprM (\env' -> pure . f env') env e)

-- | Pure version of 'mapExprProgramM': applies @f@ to each node of the whole
-- program in post-order.
mapExprProgram :: ([(VarName, Type)] -> Expr -> Expr) -> Program -> Program
mapExprProgram f prog = runIdentity (mapExprProgramM (\env -> pure . f env) prog)

-- | Lists every subexpression of an expression, including the expression
-- itself.
--
-- Nodes are recorded as 'mapExprM' visits them (post-order) and accumulated in
-- a 'Dual' monoid, which prepends each later entry. The result is therefore in
-- reverse post-order, with the root first.
listSubExprs :: Expr -> [Expr]
listSubExprs root = getDual (execWriter (mapExprM record [] root))
  where
    record _ e = e <$ tell (Dual [e])

-- | Splits a curried function type @t1 -> ... -> tn -> ret@ into
-- @([t1, ..., tn], ret)@. A non-function type @t@ yields @([], t)@.
uncurryFunTy :: Type -> ([Type], Type)
uncurryFunTy (FunTy arg rest) = (arg : args, ret)
  where
    (args, ret) = uncurryFunTy rest
uncurryFunTy ret = ([], ret)

-- | Peels nested lambdas @\\x1 -> ... -> \\xn -> body@ into
-- @([(x1, t1), ..., (xn, tn)], body)@. A non-lambda @e@ yields @([], e)@.
uncurryLam :: Expr -> ([(VarName, Type)], Expr)
uncurryLam (Lam x t inner) = ((x, t) : params, body)
  where
    (params, body) = uncurryLam inner
uncurryLam body = ([], body)

-- | Splits a left-nested application @f e1 ... en@ into @(f, [e1, ..., en])@.
-- A non-application @e@ yields @(e, [])@.
--
-- The arguments are collected in an accumulator while walking down the left
-- spine, outermost (last) argument first. This makes the split linear in the
-- number of arguments; appending each argument with @++@ would be quadratic.
curryApp :: Expr -> (Expr, [Expr])
curryApp = go []
  where
    go args (App f e) = go (e : args) f
    go args f = (f, args)

-- | Builds the curried function type @t1 -> ... -> tn -> ret@ from
-- @[t1, ..., tn]@ and @ret@. Inverse of 'uncurryFunTy'.
curryFunTy :: [Type] -> Type -> Type
curryFunTy [] ret = ret
curryFunTy (t : ts) ret = FunTy t (curryFunTy ts ret)

-- | Wraps @body@ in nested lambdas, one per @(name, type)@ pair:
-- @curryLam [(x1, t1), ..., (xn, tn)] body@ is @\\x1 -> ... -> \\xn -> body@.
curryLam :: [(VarName, Type)] -> Expr -> Expr
curryLam args body = foldr (\ ~(x, t) inner -> Lam x t inner) body args

-- | Applies a function expression to arguments from left to right:
-- @uncurryApp f [e1, ..., en]@ is @f e1 ... en@.
-- For any @e@, @uncurry uncurryApp (curryApp e) == e@.
uncurryApp :: Expr -> [Expr] -> Expr
uncurryApp f [] = f
uncurryApp f (e : es) = uncurryApp (App f e) es

-- | Checks whether a type is a tuple of 'IntTy' only (see 'sizeOfVectorTy').
isVectorTy :: Type -> Bool
isVectorTy t = case sizeOfVectorTy t of
  Just _ -> True
  Nothing -> False

-- | 'isVectorTy' applied to the tuple type of the given component types.
isVectorTy' :: [Type] -> Bool
isVectorTy' ts = isVectorTy (TupleTy ts)

-- | Returns the number of components of a tuple type whose components are all
-- 'IntTy', and 'Nothing' for any other type. The empty tuple yields @Just 0@.
sizeOfVectorTy :: Type -> Maybe Int
sizeOfVectorTy (TupleTy ts)
  | all (== IntTy) ts = Just (length ts)
sizeOfVectorTy _ = Nothing

-- | Checks whether a type is a non-empty tuple of identical tuples of 'IntTy'
-- (see 'sizeOfMatrixTy').
isMatrixTy :: Type -> Bool
isMatrixTy t = case sizeOfMatrixTy t of
  Just _ -> True
  Nothing -> False

-- | 'isMatrixTy' applied to the tuple type of the given row types.
isMatrixTy' :: [Type] -> Bool
isMatrixTy' rows = isMatrixTy (TupleTy rows)

-- | Returns @(rows, columns)@ for a non-empty tuple whose elements are all the
-- same tuple type, with that tuple consisting only of 'IntTy'. Any other type
-- yields 'Nothing'.
sizeOfMatrixTy :: Type -> Maybe (Int, Int)
sizeOfMatrixTy (TupleTy rows@(TupleTy row : _))
  | all (== IntTy) row && all (== TupleTy row) rows = Just (length rows, length row)
sizeOfMatrixTy _ = Nothing

-- | Classifies builtins for 'isConstantTimeExpr'.
--
-- 'True' for arithmetic, logical, bitwise, matrix and modular operations,
-- tuple operations, comparisons, combinatorial functions, and the list
-- accessors 'Len' and 'At'. 'False' for 'Iterate', the other list functions
-- and the data-structure operations.
isConstantTimeBuiltin :: Builtin -> Bool
isConstantTimeBuiltin builtin = case builtin of
  -- arithmetical functions
  Negate -> True
  Plus -> True
  Minus -> True
  Mult -> True
  FloorDiv -> True
  FloorMod -> True
  CeilDiv -> True
  CeilMod -> True
  Pow -> True
  -- advanced arithmetical functions
  Abs -> True
  Gcd -> True
  Lcm -> True
  Min2 {} -> True
  Max2 {} -> True
  Iterate {} -> False
  -- logical functions
  Not -> True
  And -> True
  Or -> True
  Implies -> True
  If {} -> True
  -- bitwise functions
  BitNot -> True
  BitAnd -> True
  BitOr -> True
  BitXor -> True
  BitLeftShift -> True
  BitRightShift -> True
  -- matrix functions
  MatAp {} -> True
  MatZero {} -> True
  MatOne {} -> True
  MatAdd {} -> True
  MatMul {} -> True
  MatPow {} -> True
  VecFloorMod {} -> True
  MatFloorMod {} -> True
  -- modular functions
  ModNegate -> True
  ModPlus -> True
  ModMinus -> True
  ModMult -> True
  ModInv -> True
  ModPow -> True
  ModMatAp {} -> True
  ModMatAdd {} -> True
  ModMatMul {} -> True
  ModMatPow {} -> True
  -- list functions
  Cons {} -> False
  Snoc {} -> False
  Foldl {} -> False
  Scanl {} -> False
  Build {} -> False
  Len {} -> True
  Map {} -> False
  Filter {} -> False
  At {} -> True
  SetAt {} -> False
  Elem {} -> False
  Sum -> False
  Product -> False
  ModSum -> False
  ModProduct -> False
  Min1 {} -> False
  Max1 {} -> False
  ArgMin {} -> False
  ArgMax {} -> False
  All -> False
  Any -> False
  Sorted {} -> False
  Reversed {} -> False
  Range1 -> False
  Range2 -> False
  Range3 -> False
  -- tuple functions
  Tuple {} -> True
  Proj {} -> True
  -- comparison
  LessThan {} -> True
  LessEqual {} -> True
  GreaterThan {} -> True
  GreaterEqual {} -> True
  Equal {} -> True
  NotEqual {} -> True
  -- combinational functions
  Fact -> True
  Choose -> True
  Permute -> True
  MultiChoose -> True
  -- data structures
  ConvexHullTrickInit -> False
  ConvexHullTrickInsert -> False
  ConvexHullTrickGetMin -> False
  SegmentTreeInitList {} -> False
  SegmentTreeGetRange {} -> False
  SegmentTreeSetPoint {} -> False

-- | `isConstantTimeExpr` checks whether an expression is suitable to propagate.
--
-- The following are accepted:
--
-- * variables and literals;
-- * lambdas, regardless of their body;
-- * applications whose head is a builtin literal accepted by
--   'isConstantTimeBuiltin' and whose arguments are all accepted;
-- * lets whose bound expression and body are both accepted.
--
-- Any other application is rejected.
isConstantTimeExpr :: Expr -> Bool
isConstantTimeExpr expr = case expr of
  Var _ -> True
  Lit _ -> True
  App _ _ -> case curryApp expr of
    (Lit (LitBuiltin f), args) -> isConstantTimeBuiltin f && all isConstantTimeExpr args
    _ -> False
  Lam {} -> True
  Let _ _ e1 e2 -> isConstantTimeExpr e1 && isConstantTimeExpr e2

-- | `replaceLenF` replaces every @len(f)@ (i.e. @Len' _ (Var f)@) in an
-- expression with @i + k@.
--
-- Occurrences under a 'Lam' that binds @f@, or in the body of a 'Let' that
-- binds @f@, are left untouched because they refer to a different @f@.
--
-- This assumes that there are no name conflicts. It throws an internal error
-- if any 'Lam' or 'Let' it reaches binds @i@, since the inserted @Var i@ could
-- be captured by that binder.
replaceLenF :: MonadError Error m => VarName -> VarName -> Integer -> Expr -> m Expr
replaceLenF f i k = go
  where
    conflict = throwInternalError "Jikka.Core.Language.Util.replaceLenF: name conflict"
    -- skip the scope of a binder that shadows f
    underBinder x body
      | x == f = return body
      | otherwise = go body
    go = \case
      Len' _ (Var f') | f' == f -> return $ Plus' (Var i) (LitInt' k)
      Var y -> return (Var y)
      Lit lit -> return (Lit lit)
      App g e -> App <$> go g <*> go e
      Lam x t body
        | x == i -> conflict
        | otherwise -> Lam x t <$> underBinder x body
      Let y t e1 e2
        | y == i -> conflict
        | otherwise -> Let y t <$> go e1 <*> underBinder y e2