{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}

-- |
-- Module      : Jikka.Core.Convert.MatrixExponentiation
-- Description : replaces repeated applications of linear (or, affine) functions with powers of matrices. / 線形な (あるいは affine な) 関数の繰り返しの適用を行列累乗で置き換えます。
-- Copyright   : (c) Kimiyuki Onaka, 2021
-- License     : Apache License 2.0
-- Maintainer  : kimiyuki95@gmail.com
-- Stability   : experimental
-- Portability : portable
module Jikka.Core.Convert.MatrixExponentiation
  ( run,
  )
where

import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import qualified Data.Vector as V
import Jikka.Common.Alpha
import Jikka.Common.Error
import Jikka.Common.Matrix
import Jikka.Core.Language.ArithmeticalExpr
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars
import Jikka.Core.Language.Lint
import Jikka.Core.Language.RewriteRules
import Jikka.Core.Language.Util

-- | Tries to read @e@ as an affine expression @a * x + b@ in the single variable @x@.
--
-- Returns 'Nothing' when 'makeVectorFromArithmeticalExpr' rejects @e@
-- (presumably when @e@ is not affine in @x@ — confirm against that helper).
-- Otherwise returns the pair @(a, b)@ as core expressions. A component is
-- reported as 'Nothing' when it is trivial: the coefficient when it is
-- recognized as one by 'isOneArithmeticalExpr', and the constant term when it
-- is recognized as zero by 'isZeroArithmeticalExpr'.
toLinearExpression :: VarName -> Expr -> Maybe (Maybe Expr, Maybe Expr)
toLinearExpression x e = do
  (coeffs, constant) <- makeVectorFromArithmeticalExpr (V.singleton x) (parseArithmeticalExpr e)
  -- Drop a component when the given predicate says it is the neutral value.
  let nonTrivial isTrivial c
        | isTrivial c = Nothing
        | otherwise = Just (formatArithmeticalExpr c)
  case V.toList coeffs of
    [coeff] -> Just (nonTrivial isOneArithmeticalExpr coeff, nonTrivial isZeroArithmeticalExpr constant)
    -- Unreachable in practice: exactly one variable was passed in.
    _ -> error $ "Jikka.Core.Convert.MatrixExponentiation.toLinearExpression: size mismtach: " ++ show (V.length coeffs)

-- | Converts a matrix of arithmetical expressions into a core expression.
--
-- With @(h, w) = matsize f@, the result is an @h@-tuple whose elements are
-- @w@-tuples of ints, one per row of @f@.
fromMatrix :: Matrix ArithmeticalExpr -> Expr
fromMatrix f = uncurryApp (Tuple' (replicate h rowTy)) rows
  where
    (h, w) = matsize f
    rowTy = TupleTy (replicate w IntTy)
    rows =
      [ uncurryApp (Tuple' (replicate w IntTy)) (formatArithmeticalExpr <$> V.toList row)
        | row <- V.toList (unMatrix f)
      ]

-- | Builds the augmented matrix @[[A, b], [0, ..., 0, 1]]@ as a core expression.
--
-- With @(h, w) = matsize a@, each of the @h@ rows of @a@ is extended with the
-- matching entry of @b@, and a final row of @w@ zeros followed by a one is
-- appended. The result is an @(h + 1)@-tuple of @(w + 1)@-tuples of ints.
-- Calls 'error' when the number of rows of @a@ differs from the length of @b@.
fromAffineMatrix :: Matrix ArithmeticalExpr -> V.Vector ArithmeticalExpr -> Expr
fromAffineMatrix a b
  | h /= V.length b = error $ "Jikka.Core.Convert.MatrixExponentiation.fromAffineMatrix: size mismtach: " ++ show (matsize a) ++ " and " ++ show (V.length b)
  | otherwise = uncurryApp (Tuple' (replicate (h + 1) rowTy)) (topRows ++ [bottomRow])
  where
    (h, w) = matsize a
    rowTys = replicate (w + 1) IntTy
    rowTy = TupleTy rowTys
    -- Row i of the result is row i of @a@ followed by b[i].
    extendRow row c = uncurryApp (Tuple' rowTys) (map formatArithmeticalExpr (V.toList row ++ [c]))
    topRows = V.toList (V.zipWith extendRow (unMatrix a) b)
    -- The homogeneous-coordinate row (0, ..., 0, 1).
    bottomRow = uncurryApp (Tuple' rowTys) (replicate w (LitInt' 0) ++ [LitInt' 1])

-- | Tries to read the body @step@ of a lambda @fun x -> step@ over an
-- @n@-tuple of ints as an affine map @x |-> A x + b@.
--
-- @step@ must be a tuple constructor applied to component expressions.
-- Projections @x[i]@ in each component are replaced by fresh scalar variables.
-- After that replacement @x@ must no longer occur, and each component is
-- parsed with 'makeVectorFromArithmeticalExpr'.
-- On success, returns the coefficient matrix @A@ and, unless every constant
-- term is recognized as zero, the constant vector @b@. Returns 'Nothing' if
-- any of these steps fails.
toMatrix :: MonadAlpha m => [(VarName, Type)] -> VarName -> Int -> Expr -> m (Maybe (Matrix ArithmeticalExpr, Maybe (V.Vector ArithmeticalExpr)))
toMatrix env x n step = case curryApp step of
  (Tuple' _, es) -> runMaybeT $ do
    -- One fresh scalar variable per tuple component of @x@.
    xs <- V.fromList <$> replicateM n (lift (genVarName x))
    let replaceProj _ = \case
          Proj' _ i (Var x') | x' == x -> Var (xs V.! i)
          e -> e
        parseRow e = do
          let e' = mapExpr replaceProj env e
          -- Any remaining use of @x@ is not a plain projection, so this row is not affine.
          guard (x `isUnusedVar` e')
          makeVectorFromArithmeticalExpr xs (parseArithmeticalExpr e')
    rows <- MaybeT (pure (mapM parseRow es))
    a <- MaybeT (pure (makeMatrix (V.fromList (fst <$> rows))))
    let consts = snd <$> rows
        b
          | all isZeroArithmeticalExpr consts = Nothing
          | otherwise = Just (V.fromList consts)
    pure (a, b)
  _ -> pure Nothing

-- | Given a variable @x@ holding an @n@-tuple of ints, builds the
-- @(n + 1)@-tuple @(x[0], ..., x[n - 1], 1)@.
addOneToVector :: Int -> VarName -> Expr
addOneToVector n x = uncurryApp (Tuple' (IntTy : ts)) (components ++ [LitInt' 1])
  where
    ts = replicate n IntTy
    components = [Proj' ts i (Var x) | i <- [0 .. n - 1]]

-- | Given a variable @x@ holding an @(n + 1)@-tuple of ints, builds the
-- @n@-tuple of its first @n@ components, dropping the last one.
removeOneFromVector :: Int -> VarName -> Expr
removeOneFromVector n x = uncurryApp (Tuple' ts) (proj <$> [0 .. n - 1])
  where
    ts = replicate n IntTy
    proj i = Proj' (IntTy : ts) i (Var x)

-- | Rewrites @iterate k (fun x -> step) base@ when @step@ is affine in @x@.
--
-- * Scalar case (@x : int@, @step = a * x + b@ via 'toLinearExpression'):
--   the result is @base@, @base + k * b@, @a ** k * base@, or the affine
--   closed form, depending on which of @a@ and @b@ are trivial.
--   The closed form @a ** k * base + (a ** k - 1) / (a - 1) * b@ is emitted
--   only when @a@ is an integer literal, so @a - 1@ is a nonzero constant.
--   For any other @a@, that value could be one at runtime and the division
--   would be by zero. In that case the division-free 2x2 matrix power
--   @((a, b), (0, 1)) ** k@ is applied to @(base, 1)@ instead.
-- * Vector case (@x@ a tuple of ints, @step@ read via 'toMatrix'): the result
--   is @matap (matpow A k) base@. When there is a constant vector @b@, @base@
--   is first extended with a trailing 1 ('addOneToVector'). It is then
--   multiplied by the power of the augmented matrix ('fromAffineMatrix'), and
--   the trailing component is dropped again ('removeOneFromVector').
rule :: MonadAlpha m => RewriteRule m
rule = RewriteRule $ \env -> \case
  Iterate' IntTy k (Lam x _ step) base ->
    return $ case toLinearExpression x step of
      Nothing -> Nothing
      Just (Nothing, Nothing) -> Just base
      Just (Nothing, Just b) -> Just $ Plus' base (Mult' k b)
      Just (Just a, Nothing) -> Just $ Mult' (Pow' a k) base
      Just (Just a@(LitInt' _), Just b) ->
        -- @a@ is a literal other than one (one is reported as 'Nothing'), so
        -- @a - 1@ is a nonzero constant and the geometric-series sum below
        -- is an exact integer division.
        let a' = Pow' a k
            b' = Mult' (FloorDiv' (Minus' (Pow' a k) (LitInt' 1)) (Minus' a (LitInt' 1))) b -- This division has no remainder.
         in Just $ Plus' (Mult' a' base) b'
      Just (Just a, Just b) ->
        -- @a@ may evaluate to one at runtime, where the formula above would
        -- divide by zero. Use (x', 1) = ((a, b), (0, 1)) ** k * (base, 1).
        let ts2 = [IntTy, IntTy]
            pair = uncurryApp (Tuple' ts2)
            mat = uncurryApp (Tuple' [TupleTy ts2, TupleTy ts2]) [pair [a, b], pair [LitInt' 0, LitInt' 1]]
         in Just $ Proj' ts2 0 (MatAp' 2 2 (MatPow' 2 mat k) (pair [base, LitInt' 1]))
  Iterate' (TupleTy ts) k (Lam x _ step) base | isVectorTy' ts -> do
    let n = length ts
        go size mat vec = MatAp' size size (MatPow' size mat k) vec
    parsed <- toMatrix env x n step
    case parsed of
      Nothing -> return Nothing
      Just (a, Nothing) -> return . Just $ go n (fromMatrix a) base
      Just (a, Just b) -> do
        y <- genVarName x
        z <- genVarName x
        return . Just $
          Let y (TupleTy ts) base $
            Let z (TupleTy (IntTy : ts)) (go (n + 1) (fromAffineMatrix a b) (addOneToVector n y)) $
              removeOneFromVector n z
  _ -> return Nothing

-- | Applies 'rule' to the whole program with @applyRewriteRuleProgram'@.
runProgram :: (MonadAlpha m, MonadError Error m) => Program -> m Program
runProgram prog = applyRewriteRuleProgram' rule prog

-- | `run` replaces @iterate@ (`Iterate`) applications of affine functions, on integers or on tuples (vectors) of integers, with closed forms or powers of matrices.
--
-- == Examples
--
-- This makes matrix multiplication. Before:
--
-- > iterate n (fun xs -> (xs[0] + 2 * xs[1], xs[1])) xs
--
-- After:
--
-- > matap (matpow ((1, 2), (0, 1)) n) xs
--
-- Also this works on integers. Before:
--
-- > iterate n (fun x -> 2 * x + 1) x
--
-- After:
--
-- > (2 ** n) * x + ((2 ** n - 1) / (2 - 1)) * 1
run :: (MonadAlpha m, MonadError Error m) => Program -> m Program
run prog = wrapError' "Jikka.Core.Convert.MatrixExponentiation" $ do
  -- The program is required to be well-typed both before and after the rewrite.
  precondition (ensureWellTyped prog)
  prog' <- runProgram prog
  prog' <$ postcondition (ensureWellTyped prog')