{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE TupleSections #-}

module Jikka.Core.Language.ArithmeticalExpr
  ( -- * Basic functions
    ArithmeticalExpr,
    parseArithmeticalExpr,
    formatArithmeticalExpr,
    integerArithmeticalExpr,
    negateArithmeticalExpr,
    plusArithmeticalExpr,
    minusArithmeticalExpr,
    multArithmeticalExpr,
    isZeroArithmeticalExpr,
    isOneArithmeticalExpr,

    -- * Advanced functions
    unNPlusKPattern,
    makeVectorFromArithmeticalExpr,
    makeAffineFunctionFromArithmeticalExpr,
    splitConstantFactorArithmeticalExpr,
  )
where

import Control.Arrow
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans
import Control.Monad.Trans.Maybe
import Data.List (findIndices, groupBy, sort, sortBy)
import Data.STRef
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import Jikka.Core.Language.BuiltinPatterns
import Jikka.Core.Language.Expr
import Jikka.Core.Language.FreeVars

-- | A monomial: an integer coefficient times a list of opaque factor
-- expressions, i.e. \(c \cdot e_0 \cdot e_1 \cdots\).
-- An empty factor list stands for the constant \(c\) alone.
data ProductExpr = ProductExpr
  { -- | the integer coefficient
    productExprConst :: Integer,
    -- | the non-literal factors, multiplied together
    productExprList :: [Expr]
  }
  deriving (Eq, Ord, Show, Read)

-- | A sum of monomials plus an integer constant term, i.e.
-- \(\sum_i p_i + c\).
data SumExpr = SumExpr
  { -- | the monomial terms
    sumExprList :: [ProductExpr],
    -- | the constant term
    sumExprConst :: Integer
  }
  deriving (Eq, Ord, Show, Read)

-- | An integer arithmetic expression held in sum-of-products form.
-- The wrapped 'SumExpr' is not necessarily normalized.
newtype ArithmeticalExpr = ArithmeticalExpr {unArithmeticalExpr :: SumExpr}
  deriving (Show)

-- Equality is decided on normal forms, so syntactically different but
-- normalization-equivalent expressions compare equal.
instance Eq ArithmeticalExpr where
  a == b = canonical a == canonical b
    where
      canonical = unArithmeticalExpr . normalizeArithmeticalExpr

-- Ordering is likewise decided on normal forms, consistent with 'Eq'.
instance Ord ArithmeticalExpr where
  compare a b = compare (canonical a) (canonical b)
    where
      canonical = unArithmeticalExpr . normalizeArithmeticalExpr

-- | A factor-free monomial, i.e. just the constant @c@.
integerProductExpr :: Integer -> ProductExpr
integerProductExpr c = ProductExpr c []

-- | Negates a monomial by flipping the sign of its coefficient; the factors
-- are left untouched.
negateProductExpr :: ProductExpr -> ProductExpr
negateProductExpr (ProductExpr c factors) = ProductExpr (negate c) factors

-- | Multiplies two monomials: coefficients are multiplied and factor lists
-- are concatenated (left operand's factors first).
multProductExpr :: ProductExpr -> ProductExpr -> ProductExpr
multProductExpr lhs rhs = ProductExpr coeff factors
  where
    coeff = productExprConst lhs * productExprConst rhs
    factors = productExprList lhs ++ productExprList rhs

-- | Reads an expression as a single monomial.
--
-- * integer literals become the coefficient,
-- * negation flips the coefficient's sign,
-- * products are multiplied with 'multProductExpr',
-- * a power with a literal exponent @k@, @0 <= k < 10@, is expanded into
--   @k@ copies of its base,
-- * anything else is kept as one opaque factor with coefficient @1@.
parseProductExpr :: Expr -> ProductExpr
parseProductExpr expr = case expr of
  LitInt' n -> integerProductExpr n
  Negate' e -> negateProductExpr (parseProductExpr e)
  Mult' e1 e2 -> parseProductExpr e1 `multProductExpr` parseProductExpr e2
  Pow' base (LitInt' k)
    | 0 <= k && k < 10 ->
      let b = parseProductExpr base
       in foldr multProductExpr (integerProductExpr 1) (replicate (fromInteger k) b)
  _ -> ProductExpr 1 [expr]

-- | A sum consisting of exactly one monomial and a zero constant term.
sumExprFromProductExpr :: ProductExpr -> SumExpr
sumExprFromProductExpr p = SumExpr [p] 0

-- | Wraps a single monomial as an 'ArithmeticalExpr'.
arithmeticalExprFromProductExpr :: ProductExpr -> ArithmeticalExpr
arithmeticalExprFromProductExpr p = ArithmeticalExpr (sumExprFromProductExpr p)

-- | A sum with no monomials, i.e. just the given constant.
integerSumExpr :: Integer -> SumExpr
integerSumExpr = SumExpr []

-- | The constant expression @n@.
integerArithmeticalExpr :: Integer -> ArithmeticalExpr
integerArithmeticalExpr n = ArithmeticalExpr (integerSumExpr n)

-- | Negates every monomial and the constant term.
negateSumExpr :: SumExpr -> SumExpr
negateSumExpr s =
  SumExpr
    [negateProductExpr p | p <- sumExprList s]
    (negate (sumExprConst s))

-- | Adds two sums: monomial lists are concatenated (left first) and the
-- constants are added. No merging of like terms happens here.
plusSumExpr :: SumExpr -> SumExpr -> SumExpr
plusSumExpr a b =
  SumExpr
    (sumExprList a ++ sumExprList b)
    (sumExprConst a + sumExprConst b)

-- | Multiplies two sums by full distribution:
-- \((c_1 + \sum_i p_i)(c_2 + \sum_j q_j)\).
--
-- Each constant is promoted to a factor-free monomial before taking the cross
-- product, so the product \(c_1 c_2\) is already one of the resulting
-- monomials. The constant field of the result must therefore be @0@; setting
-- it to \(c_1 c_2\) as well would count that term twice once factor-free
-- monomials are folded into the constant (e.g. @2 * 3@ would become @12@).
multSumExpr :: SumExpr -> SumExpr -> SumExpr
multSumExpr e1 e2 =
  SumExpr
    { sumExprList = [multProductExpr p q | p <- terms e1, q <- terms e2],
      sumExprConst = 0
    }
  where
    -- all terms of a sum, with its constant promoted to a monomial
    terms e = integerProductExpr (sumExprConst e) : sumExprList e

-- | Negation, \(-e\).
negateArithmeticalExpr :: ArithmeticalExpr -> ArithmeticalExpr
negateArithmeticalExpr = ArithmeticalExpr . negateSumExpr . unArithmeticalExpr

-- | Addition, \(e_1 + e_2\). The result is not normalized.
plusArithmeticalExpr :: ArithmeticalExpr -> ArithmeticalExpr -> ArithmeticalExpr
plusArithmeticalExpr e1 e2 =
  ArithmeticalExpr (unArithmeticalExpr e1 `plusSumExpr` unArithmeticalExpr e2)

-- | Subtraction, \(e_1 - e_2\), expressed as \(e_1 + (-e_2)\).
-- The result is not normalized.
minusArithmeticalExpr :: ArithmeticalExpr -> ArithmeticalExpr -> ArithmeticalExpr
minusArithmeticalExpr e1 e2 = plusArithmeticalExpr e1 (negateArithmeticalExpr e2)

-- | Multiplication, \(e_1 \cdot e_2\), by distributing the two sums.
-- The result is not normalized.
multArithmeticalExpr :: ArithmeticalExpr -> ArithmeticalExpr -> ArithmeticalExpr
multArithmeticalExpr e1 e2 =
  ArithmeticalExpr (unArithmeticalExpr e1 `multSumExpr` unArithmeticalExpr e2)

-- | Reads an expression as a sum of monomials. Literals, negation, addition,
-- subtraction and multiplication are interpreted; every other expression is
-- handed to 'parseProductExpr' and becomes a single monomial.
parseSumExpr :: Expr -> SumExpr
parseSumExpr expr = case expr of
  LitInt' n -> integerSumExpr n
  Negate' e -> negateSumExpr (parseSumExpr e)
  Plus' e1 e2 -> parseSumExpr e1 `plusSumExpr` parseSumExpr e2
  Minus' e1 e2 -> parseSumExpr e1 `plusSumExpr` negateSumExpr (parseSumExpr e2)
  Mult' e1 e2 -> parseSumExpr e1 `multSumExpr` parseSumExpr e2
  _ -> sumExprFromProductExpr (parseProductExpr expr)

-- | `parseArithmeticalExpr` converts a given expr to a sum-of-products form \(\sum_i \prod_j e _ {i,j}\).
-- This assumes given exprs have the type \(\mathbf{int}\).
-- The result is not normalized here; like terms are merged only when the expression is normalized
-- (which the 'Eq' and 'Ord' instances of 'ArithmeticalExpr' do before comparing).
parseArithmeticalExpr :: Expr -> ArithmeticalExpr
parseArithmeticalExpr e = ArithmeticalExpr (parseSumExpr e)

-- | Renders a monomial as an 'Expr'. The factors are joined with left-nested
-- 'Mult''. The coefficient is then applied as follows:
--
-- * @0@ gives the literal @0@,
-- * @1@ is omitted,
-- * @-1@ becomes a 'Negate'',
-- * anything else is appended as a trailing literal factor.
--
-- A monomial without factors is rendered as its coefficient.
formatProductExpr :: ProductExpr -> Expr
formatProductExpr (ProductExpr c factors) = case factors of
  [] -> LitInt' c
  f : fs -> withCoeff (foldl Mult' f fs)
  where
    withCoeff body
      | c == 0 = LitInt' 0
      | c == 1 = body
      | c == -1 = Negate' body
      | otherwise = Mult' body (LitInt' c)

-- | Renders a sum as an 'Expr'. The first monomial is rendered as is. Each
-- later monomial is attached with 'Plus'' or 'Minus'' according to the sign
-- of its coefficient, using the absolute coefficient. Monomials with
-- coefficient @0@ are skipped. The constant term is attached last in the
-- same way. A sum without monomials is rendered as its constant.
formatSumExpr :: SumExpr -> Expr
formatSumExpr (SumExpr terms c) = case terms of
  [] -> LitInt' c
  t : ts -> addConst (foldl step (formatProductExpr t) ts)
  where
    step acc p = case compare (productExprConst p) 0 of
      GT -> Plus' acc (magnitude p)
      LT -> Minus' acc (magnitude p)
      EQ -> acc
    -- the monomial rendered with its coefficient made non-negative
    magnitude p = formatProductExpr (p {productExprConst = abs (productExprConst p)})
    addConst body = case compare c 0 of
      GT -> Plus' body (LitInt' c)
      LT -> Minus' body (LitInt' (abs c))
      EQ -> body

-- | Normalizes the expression and renders it back as an 'Expr'.
formatArithmeticalExpr :: ArithmeticalExpr -> Expr
formatArithmeticalExpr e = formatSumExpr (unArithmeticalExpr (normalizeArithmeticalExpr e))

-- | Puts a monomial into canonical form: the factors are sorted, and a
-- monomial with coefficient @0@ loses all of its factors.
normalizeProductExpr :: ProductExpr -> ProductExpr
normalizeProductExpr p = p {productExprList = canonicalFactors}
  where
    canonicalFactors
      | productExprConst p == 0 = []
      | otherwise = sort (productExprList p)

-- | Puts a sum into canonical form:
--
-- * every monomial is normalized with 'normalizeProductExpr',
-- * monomials with identical factor lists are merged by adding coefficients,
-- * merged monomials with coefficient @0@ are dropped,
-- * factor-free monomials are folded into the constant term.
--
-- The remaining monomials are ordered by their factor lists.
normalizeSumExpr :: SumExpr -> SumExpr
normalizeSumExpr e =
  SumExpr
    { sumExprList = filter (\p -> productExprConst p /= 0 && not (null (productExprList p))) merged,
      sumExprConst = sumExprConst e + sum [productExprConst p | p <- merged, null (productExprList p)]
    }
  where
    byFactors p q = productExprList p `compare` productExprList q
    sorted = sortBy byFactors (map normalizeProductExpr (sumExprList e))
    -- 'groupBy' only yields non-empty groups, so matching the first element
    -- here is total (this replaces the former partial 'head').
    merged =
      [ ProductExpr (sum (map productExprConst grp)) factors
        | grp@(ProductExpr _ factors : _) <- groupBy (\p q -> byFactors p q == EQ) sorted
      ]

-- | Normalizes the underlying sum with 'normalizeSumExpr'.
normalizeArithmeticalExpr :: ArithmeticalExpr -> ArithmeticalExpr
normalizeArithmeticalExpr (ArithmeticalExpr s) = ArithmeticalExpr (normalizeSumExpr s)

-- | `makeVectorFromArithmeticalExpr` makes a vector \(f\) and an expr \(c\) from a given vector of variables \(x_0, x_1, \dots, x _ {n - 1}\) and a given expr \(e\) s.t. \(f\) and \(c\) don't have \(x_0, x_1, \dots, x _ {n - 1}\) as free variables and \(e = c + f \cdot (x_0, x_1, \dots, x _ {n - 1})\) holds.
-- This assumes given variables and exprs have the type \(\mathbf{int}\).
--
-- It returns 'Nothing' when \(e\) is not linear in the variables: when a monomial mentions the variables more than once
-- (e.g. \(x_0 x_1\) or \(x_0^2\)), or mentions a variable inside a factor that is not the bare variable itself
-- (e.g. @x ^ 20@, @x / 2@ or @f x@).
--
-- * The returned exprs are normalized with `normalizeArithmeticalExpr`.
makeVectorFromArithmeticalExpr :: V.Vector VarName -> ArithmeticalExpr -> Maybe (V.Vector ArithmeticalExpr, ArithmeticalExpr)
makeVectorFromArithmeticalExpr xs es = runST $
  runMaybeT $ do
    -- coeffs ! i accumulates the coefficient of x_i; constRef accumulates c
    coeffs <- lift $ MV.replicate (V.length xs) (integerArithmeticalExpr 0)
    constRef <- lift $ newSTRef (integerArithmeticalExpr (sumExprConst (unArithmeticalExpr es)))
    forM_ (sumExprList (unArithmeticalExpr es)) $ \p -> do
      let factors = productExprList p
          -- every (i, j) such that x_i occurs free in the j-th factor of p
          occurrences = concat (V.toList (V.imap (\i x -> map (i,) (findIndices (x `isFreeVar`) factors)) xs))
      case occurrences of
        -- p mentions none of the variables, so it belongs to c
        [] -> lift $ modifySTRef constRef (plusArithmeticalExpr (arithmeticalExprFromProductExpr p))
        -- p mentions a single variable once; this is linear only when that
        -- factor is exactly x_i, otherwise dropping it would be unsound
        [(i, j)] -> case splitAt j factors of
          (before, Var y : after)
            | y == xs V.! i ->
              let rest = p {productExprList = before ++ after}
               in lift $ MV.modify coeffs (plusArithmeticalExpr (arithmeticalExprFromProductExpr rest)) i
          _ -> mzero
        -- several occurrences: p is not linear in the variables
        _ -> mzero
    frozen <- lift $ V.freeze coeffs
    c <- lift $ readSTRef constRef
    return (V.map normalizeArithmeticalExpr frozen, normalizeArithmeticalExpr c)

-- | Tests whether the expression, once normalized, is the integer constant \(0\).
isZeroArithmeticalExpr :: ArithmeticalExpr -> Bool
isZeroArithmeticalExpr = (== integerArithmeticalExpr 0) . normalizeArithmeticalExpr

-- | Tests whether the expression, once normalized, is the integer constant \(1\).
isOneArithmeticalExpr :: ArithmeticalExpr -> Bool
isOneArithmeticalExpr = (== integerArithmeticalExpr 1) . normalizeArithmeticalExpr

-- | `unNPlusKPattern` recognizes a pattern of \(x + k\) for a variable \(x\) and an integer constant \(k \in \mathbb{Z}\).
--
-- The input is normalized first. The match succeeds only when the normalized sum
-- has exactly one monomial, that monomial is a bare variable with coefficient 1,
-- and the remaining constant term is \(k\) (which may be 0).
unNPlusKPattern :: ArithmeticalExpr -> Maybe (VarName, Integer)
unNPlusKPattern e =
  case unArithmeticalExpr (normalizeArithmeticalExpr e) of
    SumExpr
      { sumExprList = [ProductExpr {productExprConst = 1, productExprList = [Var x]}],
        sumExprConst = k
      } ->
        Just (x, k)
    _ -> Nothing

-- | `makeAffineFunctionFromArithmeticalExpr` is a specialized version of `makeVectorFromArithmeticalExpr` for a single variable.
-- Given a variable \(x\) and an expr \(e = a x + b\), where neither \(a\) nor \(b\) has \(x\) free, it returns \((a, b)\).
-- It returns 'Nothing' exactly when `makeVectorFromArithmeticalExpr` does.
makeAffineFunctionFromArithmeticalExpr :: VarName -> ArithmeticalExpr -> Maybe (ArithmeticalExpr, ArithmeticalExpr)
makeAffineFunctionFromArithmeticalExpr x es =
  takeSoleCoefficient <$> makeVectorFromArithmeticalExpr (V.singleton x) es
  where
    -- Only one variable was requested, so the coefficient vector is presumably of
    -- length one (one entry per requested variable — TODO confirm); keep its head.
    takeSoleCoefficient result = (V.head (fst result), snd result)

-- | `splitConstantFactorArithmeticalExpr` finds \(k\) and \(e'\) for given \(e\) s.t. \(e = k e'\).
--
-- The input is normalized first, then split as follows:
--
-- * \(0\) becomes \((0, 0)\), and any other constant \(c\) becomes \((c, 1)\).
-- * A single monomial with no constant term becomes its coefficient and the monic monomial.
-- * Otherwise \(k\) is the (non-negative) gcd of the constant term and every monomial
--   coefficient, and \(e'\) has all of them divided by \(k\).
splitConstantFactorArithmeticalExpr :: ArithmeticalExpr -> (Integer, ArithmeticalExpr)
splitConstantFactorArithmeticalExpr e =
  let normalized = unArithmeticalExpr $ normalizeArithmeticalExpr e
   in case (sumExprConst normalized, sumExprList normalized) of
        (0, []) -> (0, integerArithmeticalExpr 0)
        (k, []) -> (k, integerArithmeticalExpr 1)
        (0, [p]) -> second arithmeticalExprFromProductExpr $ splitConstantFactorProductExpr p
        (k, ps) ->
          let kps = map splitConstantFactorProductExpr ps
              -- gcd is associative and commutative, so the fold direction is irrelevant.
              d = foldr (gcd . fst) k kps
           in if d == 0
                then -- Every coefficient is zero, so e itself is zero; avoid dividing by zero.
                  (0, integerArithmeticalExpr 0)
                else
                  ( d,
                    ArithmeticalExpr
                      SumExpr
                        { sumExprConst = k `div` d,
                          sumExprList = map (\(c, q) -> q {productExprConst = (c * productExprConst q) `div` d}) kps
                        }
                  )

-- | Separates a monomial into its integer coefficient and the same monomial with coefficient 1.
splitConstantFactorProductExpr :: ProductExpr -> (Integer, ProductExpr)
splitConstantFactorProductExpr p =
  let coefficient = productExprConst p
      monic = p {productExprConst = 1}
   in (coefficient, monic)