{-# LANGUAGE TypeSynonymInstances, FlexibleInstances, LambdaCase #-}
module Futhark.Analysis.AlgSimplify
( ScalExp
, Error
, simplify
, mkSuffConds
, RangesRep
, ppRangesRep
, linFormScalE
, pickSymToElim
)
where
import qualified Data.Set as S
import qualified Data.Map.Strict as M
import Data.List
import Control.Monad
import Control.Monad.Reader
import Control.Monad.State
import Futhark.Representation.AST hiding (SDiv, SMod, SQuot, SRem, SSignum)
import Futhark.Analysis.ScalExp
import qualified Futhark.Representation.Primitive as P
-- | Per variable: its elimination priority (larger values are picked first
-- by 'pickSymToElim') together with an optional lower and upper bound.
type RangesRep = M.Map VName (Int, Maybe ScalExp, Maybe ScalExp)

-- | Render a ranges map as sorted lines, one per variable.  Coinciding
-- bounds are printed as @== b@, otherwise as @[lower, upper]@; a missing
-- bound is printed as @?@.
ppRangesRep :: RangesRep -> String
ppRangesRep rangesrep =
  unlines $ sort [ pretty name ++ ": " ++ showInterval lo hi
                 | (name, (_, lo, hi)) <- M.toList rangesrep ]
  where
    showInterval lo hi
      | lo == hi  = "== " ++ showBound lo
      | otherwise = "[" ++ showBound lo ++ ", " ++ showBound hi ++ "]"
    showBound b = maybe "?" pretty b
-- | Read-only environment of the simplifier.
data AlgSimplifyEnv = AlgSimplifyEnv { inSolveLTH0 :: Bool
                                       -- ^ Set while inside a Gaussian-elimination
                                       -- query ('markGaussLTH0'); 'simplifyNRel'
                                       -- then does not start another one.
                                     , ranges :: RangesRep
                                       -- ^ Priorities and bounds of variables.
                                     , maxSteps :: Int
                                       -- ^ Budget checked by 'step'.
                                     }
-- | Failure modes: the step budget ran out, or an internal error occurred.
data Error = StepsExceeded | Error String
-- | The simplifier monad: a step counter over the read-only environment,
-- failing with an 'Error'.
type AlgSimplifyM = StateT Int (ReaderT AlgSimplifyEnv (Either Error))
-- | Run a simplifier action with the step counter at 0 and a budget of 100
-- steps.  The 'Bool' is the initial value of 'inSolveLTH0'.
runAlgSimplifier :: Bool -> AlgSimplifyM a -> RangesRep -> Either Error a
runAlgSimplifier s x r = runReaderT (evalStateT x 0) initialEnv
  where
    initialEnv = AlgSimplifyEnv { inSolveLTH0 = s, ranges = r, maxSteps = 100 }

-- | Record one unit of work; fail with 'StepsExceeded' once the counter
-- exceeds 'maxSteps'.
step :: AlgSimplifyM ()
step = do
  modify (+ 1)
  taken <- get
  budget <- asks maxSteps
  when (taken > budget) stepsExceeded

-- | Abort with 'StepsExceeded'.
stepsExceeded :: AlgSimplifyM a
stepsExceeded = lift . lift $ Left StepsExceeded

-- | Abort with an internal error message.
badAlgSimplifyM :: String -> AlgSimplifyM a
badAlgSimplifyM msg = lift $ lift $ Left $ Error msg

-- | Set the 'inSolveLTH0' flag.
markInSolve :: AlgSimplifyEnv -> AlgSimplifyEnv
markInSolve env = env { inSolveLTH0 = True }

-- | Run an action with 'inSolveLTH0' set.
markGaussLTH0 :: AlgSimplifyM a -> AlgSimplifyM a
markGaussLTH0 action = local markInSolve action
-- | A numeric expression in sum-of-products form, tagged with its
-- primitive type.
data NNumExp = NSum [NNumExp] PrimType
               -- ^ A sum of terms; the code in this module expects every
               -- term to be an 'NProd' (nested sums are rejected).
             | NProd [ScalExp] PrimType
               -- ^ A product of factors; a leading 'Val' factor is treated
               -- as the term's constant coefficient.
               deriving (Eq, Ord, Show)
-- | A literal of a formula in disjunctive normal form.
data BTerm = NRelExp RelOp0 NNumExp
             -- ^ @e < 0@ ('LTH0') or @e <= 0@ ('LEQ0').
           | LogCt !Bool
             -- ^ A boolean constant.
           | PosId VName
             -- ^ A variable used as a boolean.
           | NegId VName
             -- ^ The negation of such a variable.
             deriving (Eq, Ord, Show)
-- | A conjunction of literals.
type NAnd = [BTerm]
-- | A disjunction of conjunctions.
type DNF = [NAnd ]
-- | Simplify a scalar expression under the given ranges.  If the step
-- budget runs out, the input is returned unchanged; an internal error
-- aborts via 'error'.
simplify :: ScalExp -> RangesRep -> ScalExp
simplify e rangesrep =
  either onFailure id $ runAlgSimplifier False (simplifyScal e) rangesrep
  where
    onFailure StepsExceeded = e
    onFailure (Error err) =
      error $ "Error during algebraic simplification of: " ++ pretty e ++
              "\n" ++ err

-- | Split an expression into @(a, b)@ with @e = a*i + b@, if it is linear
-- in @i@ (see 'linearFormScalExp').
linFormScalE :: VName -> ScalExp -> RangesRep -> Either Error (Maybe (ScalExp,ScalExp))
linFormScalE i e rangesrep = runAlgSimplifier False (linearFormScalExp i e) rangesrep

-- | Run 'gaussElimRel' on @e@, starting with 'inSolveLTH0' set.
mkSuffConds :: ScalExp -> RangesRep -> Either Error [[ScalExp]]
mkSuffConds e rangesrep = runAlgSimplifier True (gaussElimRel e) rangesrep
-- | Simplify a relation literal.  A relation on a single constant is
-- evaluated directly.  Otherwise, an integer @e < 0@ is decided with
-- 'gaussAllLTH0' in static mode, unless we are already inside such a query.
-- Anything undecided, and every literal that is not an @LTH0@ relation,
-- is returned unchanged.
simplifyNRel :: BTerm -> AlgSimplifyM BTerm
simplifyNRel inp_term@(NRelExp LTH0 inp_sofp) = do
  term <- cheapSimplify inp_term
  in_gauss <- asks inSolveLTH0
  let is_int = typeOfNAlg inp_sofp `elem` map IntType allIntTypes
  if in_gauss || trivial term || not is_int
    then return term
    else do
      verdict <- markGaussLTH0 $ gaussAllLTH0 True S.empty inp_sofp
      return $ case verdict of
        Val (BoolValue c) -> LogCt c
        _ -> term
  where
    -- A relation whose whole left-hand side is a single constant.
    trivial :: BTerm -> Bool
    trivial (NRelExp _ (NProd [Val _] _)) = True
    trivial _ = False

    -- Evaluate a relation over a single constant.
    cheapSimplify :: BTerm -> AlgSimplifyM BTerm
    cheapSimplify (NRelExp rel (NProd [Val v] _)) = LogCt <$> valLTHEQ0 rel v
    cheapSimplify other = return other
simplifyNRel inp_term = return inp_term
-- | Run Gaussian elimination (non-static mode) on @e < 0@ and return the
-- resulting condition in DNF: a list of conjunctions of literals.  The
-- result is presumably meant as a sufficient condition for @e < 0@ (cf.
-- 'mkSuffConds') — NOTE(review): confirm with callers.
-- Fails for anything other than an @LTH0@ relation.
gaussElimRel :: ScalExp -> AlgSimplifyM [[ScalExp]]
gaussElimRel (RelExp LTH0 e) = do
  e_sofp <- toNumSofP =<< simplifyScal e
  e_scal <- simplifyScal =<< gaussAllLTH0 False S.empty e_sofp
  e_dnf <- toDNF e_scal
  -- Convert every literal back with the shared converter, so that a
  -- negated variable (NegId) stays negated.
  mapM (mapM bterm2ScalExp) e_dnf
gaussElimRel _ =
  badAlgSimplifyM "gaussElimRel: only LTH0 Int relations please!"
-- | Whether the expression is an integer constant that is negative.
primScalExpLTH0 :: ScalExp -> Bool
primScalExpLTH0 se = case se of
  Val (IntValue v) -> P.intToInt64 v < 0
  _ -> False
-- | The factors of a product term.
type Prod = [ScalExp]
-- | Try to decide @sofp < 0@ by eliminating, one at a time, variables that
-- have entries in 'ranges' (chosen by 'pickSymToElim').
--
-- With @static_only@ set, the result is @Val (BoolValue b)@ when the
-- relation was decided, and @RelExp LTH0 e@ when it was not.  Without it,
-- the result may be a boolean combination ('SLogAnd' / 'SLogOr') of
-- bound-dependent sub-conditions.
--
-- @el_syms@ holds variables that must not be picked again.  Every call
-- consumes one 'step'.
gaussAllLTH0 :: Bool -> S.Set VName -> NNumExp -> AlgSimplifyM ScalExp
gaussAllLTH0 static_only el_syms sofp = do
  step
  let tp = typeOfNAlg sofp
  rangesrep <- asks ranges
  e_scal <- fromNumSofP sofp
  let mi = pickSymToElim rangesrep el_syms e_scal
  case mi of
    Nothing -> return $ if primScalExpLTH0 e_scal
                        then Val (BoolValue True)
                        else RelExp LTH0 e_scal
    Just i -> do
      (jmm, fs0, terms) <- findMinMaxTerm i sofp
      -- fs: the other factors of the term holding the min/max (1 if none).
      fs <- if not (null fs0) then return fs0
            else do one <- getPos1 tp; return [Val one]
      case jmm of
        Just (MaxMin _ []) ->
          badAlgSimplifyM "gaussAllLTH0: Empty MinMax Node!"
        Just (MaxMin ismin mmts) -> do
          mone <- getNeg1 tp
          -- fs - 1 < 0, i.e. fs <= 0
          fsm1 <- toNumSofP =<< simplifyScal =<< fromNumSofP
                    (NSum [NProd fs tp, NProd [Val mone] tp] tp)
          fs_leq0 <- gaussAllLTH0 static_only el_syms fsm1
          -- -fs - 1 < 0, i.e. fs >= 0
          mfsm1 <- toNumSofP =<< simplifyScal =<< fromNumSofP
                     (NSum [NProd (Val mone:fs) tp, NProd [Val mone] tp] tp)
          fs_geq0 <- gaussAllLTH0 static_only el_syms mfsm1
          -- For every min/max operand t:  t*fs + terms < 0
          mm_terms <- mapM (\t -> toNumSofP =<< simplifyScal =<< fromNumSofP
                                    (NSum (NProd (t:fs) tp : terms) tp)) mmts
          mms <- mapM (gaussAllLTH0 static_only el_syms) mm_terms
          if static_only
            then return $ decideStatic ismin fs_geq0 fs_leq0 mms e_scal
            else do
              -- fs >= 0: min needs one operand case, max needs all of them.
              let mm_fsgeq0 = foldl (if ismin then SLogOr else SLogAnd)
                                    (Val (BoolValue (not ismin))) mms
              -- fs <= 0: the roles of min and max swap.
                  mm_fslth0 = foldl (if ismin then SLogAnd else SLogOr)
                                    (Val (BoolValue ismin)) mms
              return $ SLogOr (SLogAnd fs_geq0 mm_fsgeq0) (SLogAnd fs_leq0 mm_fslth0)
        Just _ -> badAlgSimplifyM "gaussAllLTH0: (Just MinMax) invariant violated!"
        Nothing -> do
          m_sofp <- gaussOneDefaultLTH0 static_only i el_syms sofp
          case m_sofp of
            Nothing -> gaussAllLTH0 static_only (S.insert i el_syms) sofp
            Just res_eofp -> return res_eofp
  where
    -- Static verdict for  fs * mm(t_1..t_n) + terms < 0,  given the verdicts
    -- @mms@ for each  fs * t_k + terms < 0.  The expression equals one of
    -- those n cases, so "all true" or "all false" decides it outright.  A
    -- single true case suffices only when fs*mm acts as a minimum of the
    -- cases, and a single false case only when it acts as a maximum.
    -- Otherwise the relation is left undecided.
    decideStatic :: Bool -> ScalExp -> ScalExp -> [ScalExp] -> ScalExp -> ScalExp
    decideStatic ismin fs_geq0 fs_leq0 mms e_scal
      | all (== true_) mms || (acts_as_min && true_ `elem` mms) = true_
      | all (== false_) mms || (acts_as_max && false_ `elem` mms) = false_
      | otherwise = RelExp LTH0 e_scal
      where
        true_ = Val (BoolValue True)
        false_ = Val (BoolValue False)
        acts_as_min = (fs_geq0 == true_ && ismin) || (fs_leq0 == true_ && not ismin)
        acts_as_max = (fs_geq0 == true_ && not ismin) || (fs_leq0 == true_ && ismin)

    -- Locate a min/max factor mentioning ii.  Inside a sum it is only
    -- searched for when ii has both bounds in 'ranges'.  Returns the node,
    -- the remaining factors of its term, and the remaining terms.
    findMinMaxTerm :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, Prod, [NNumExp])
    findMinMaxTerm _ (NSum [] _) = return (Nothing, [], [])
    findMinMaxTerm _ (NSum [NProd [MaxMin ismin e] _] _) =
      return (Just (MaxMin ismin e), [], [])
    findMinMaxTerm _ (NProd [MaxMin ismin e] _) =
      return (Just (MaxMin ismin e), [], [])
    findMinMaxTerm ii t@NProd{} = do
      (mm, fs) <- findMinMaxFact ii t
      return (mm, fs, [])
    findMinMaxTerm ii (NSum (t:ts) tp) = do
      rangesrep <- asks ranges
      case M.lookup ii rangesrep of
        Just (_, Just _, Just _) -> do
          f <- findMinMaxFact ii t
          case f of
            (Just mm, fs) -> return (Just mm, fs, ts)
            (Nothing, _) -> do
              (mm, fs', ts') <- findMinMaxTerm ii (NSum ts tp)
              return (mm, fs', t:ts')
        _ -> return (Nothing, [], t:ts)

    -- Locate a min/max factor whose operands mention ii; also return the
    -- remaining factors.
    findMinMaxFact :: VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp, Prod)
    findMinMaxFact _ (NProd [] _) = return (Nothing, [])
    findMinMaxFact ii (NProd (f:fs) tp) =
      case f of
        MaxMin ismin ts -> do
          let id_set = mconcat $ map freeIn ts
          if S.member ii id_set
            then return (Just (MaxMin ismin ts), fs)
            else do
              (mm, fs') <- findMinMaxFact ii (NProd fs tp)
              return (mm, f:fs')
        _ -> do
          (mm, fs') <- findMinMaxFact ii (NProd fs tp)
          return (mm, f:fs')
    findMinMaxFact ii (NSum [f] _) = findMinMaxFact ii f
    findMinMaxFact _ (NSum _ _) =
      badAlgSimplifyM "findMinMaxFact: NSum argument illegal!"
-- | Try to decide @e < 0@ by eliminating the variable @i@ using its bounds
-- in 'ranges'.  This requires @e@ to be linear in @i@ ('linearForm'), i.e.
-- @e = a*i + b@; otherwise the result is 'Nothing'.
--
-- Reasoning used:
--
-- * if @a <= 0@ and @i >= lb@, then @a*i + b <= a*lb + b@, so @a*lb + b < 0@
--   proves the relation;
-- * if @a >= 0@ and @i <= ub@, then @a*ub + b < 0@ proves it;
-- * the mirrored facts with @-a*bound - b - 1 < 0@ (i.e. @a*bound + b >= 0@)
--   prove it false.
--
-- In @static_only@ mode the result is a constant, or 'Nothing' when
-- undecided.  Otherwise, with both bounds known, it is the condition
-- @(a <= 0 && a*lb+b < 0) || (a >= 0 && a*ub+b < 0)@.  Fails when @i@ has
-- no entry in 'ranges' or has neither bound.
gaussOneDefaultLTH0 :: Bool -> VName -> S.Set VName -> NNumExp -> AlgSimplifyM (Maybe ScalExp)
gaussOneDefaultLTH0 static_only i elsyms e = do
  aipb <- linearForm i e
  case aipb of
    Nothing -> return Nothing
    Just (a, b) -> do
      rangesrep <- asks ranges
      one <- getPos1 (typeOfNAlg e)
      ascal <- fromNumSofP a
      mam1 <- toNumSofP =<< simplifyScal (SNeg (SPlus ascal (Val one)))   -- -a - 1
      am1 <- toNumSofP =<< simplifyScal (SMinus ascal (Val one))           -- a - 1
      ma <- toNumSofP =<< simplifyScal (SNeg ascal)                         -- -a
      b_scal<- fromNumSofP b
      mbm1 <- toNumSofP =<< simplifyScal (SNeg (SPlus b_scal (Val one)))  -- -b - 1
      aleq0 <- simplifyScal =<< gaussAllLTH0 static_only elsyms am1   -- a - 1 < 0, i.e. a <= 0
      ageq0 <- simplifyScal =<< gaussAllLTH0 static_only elsyms mam1  -- -a - 1 < 0, i.e. a >= 0
      case M.lookup i rangesrep of
        Nothing ->
          badAlgSimplifyM "gaussOneDefaultLTH0: sym not in ranges!"
        Just (_, Nothing, Nothing) ->
          badAlgSimplifyM "gaussOneDefaultLTH0: both bounds are undefined!"
        -- Only a lower bound.
        Just (_, Just lb, Nothing) -> do
          -- a*lb + b < 0
          alpblth0 <- gaussElimHalf static_only elsyms lb a b
          and_half <- simplifyScal alpblth0
          case (and_half, aleq0) of
            (Val (BoolValue True), Val (BoolValue True)) ->
              return $ Just and_half
            _ -> do
              -- -a*lb - b - 1 < 0, i.e. a*lb + b >= 0
              malmbm1lth0 <- gaussElimHalf static_only elsyms lb ma mbm1
              other_half <- simplifyScal malmbm1lth0
              case (other_half, ageq0) of
                (Val (BoolValue True), Val (BoolValue True)) ->
                  return $ Just (Val (BoolValue False))
                _ -> return Nothing
        -- Only an upper bound.
        Just (_, Nothing, Just ub) -> do
          -- a*ub + b < 0
          aupblth0 <- gaussElimHalf static_only elsyms ub a b
          and_half <- simplifyScal aupblth0
          case (and_half, ageq0) of
            (Val (BoolValue True), Val (BoolValue True)) ->
              return $ Just and_half
            _ -> do
              -- -a*ub - b - 1 < 0, i.e. a*ub + b >= 0
              maumbm1 <- gaussElimHalf static_only elsyms ub ma mbm1
              other_half <- simplifyScal maumbm1
              case (other_half, aleq0) of
                (Val (BoolValue True), Val (BoolValue True)) ->
                  return $ Just (Val (BoolValue False))
                _ -> return Nothing
        -- Both bounds.
        Just (_, Just lb, Just ub) ->
          if static_only
            then
              if aleq0 == Val (BoolValue True)
                then do
                  alpblth0 <- simplifyScal =<< gaussElimHalf static_only elsyms lb a b
                  if alpblth0 == Val (BoolValue True)
                    then return $ Just (Val (BoolValue True))
                    else do
                      maubmbm1 <- simplifyScal =<< gaussElimHalf static_only elsyms ub ma mbm1
                      return $ if maubmbm1 == Val (BoolValue True)
                               then Just (Val (BoolValue False))
                               else Nothing
                else
                  if ageq0 == Val (BoolValue True)
                    then do
                      aupblth0 <- simplifyScal =<< gaussElimHalf static_only elsyms ub a b
                      if aupblth0 == Val (BoolValue True)
                        then return $ Just (Val (BoolValue True))
                        else do
                          malbmbm1 <- simplifyScal =<< gaussElimHalf static_only elsyms lb ma mbm1
                          return $ if malbmbm1 == Val (BoolValue True)
                                   then Just (Val (BoolValue False))
                                   else Nothing
                    else return Nothing
            else do
              alpblth0 <- gaussElimHalf static_only elsyms lb a b
              aupblth0 <- gaussElimHalf static_only elsyms ub a b
              res <- simplifyScal $ SLogOr (SLogAnd aleq0 alpblth0) (SLogAnd ageq0 aupblth0)
              return $ Just res
  where
    -- Recursively decide  a*q + b < 0  with q substituted for the
    -- eliminated variable.
    gaussElimHalf :: Bool -> S.Set VName -> ScalExp -> NNumExp -> NNumExp -> AlgSimplifyM ScalExp
    gaussElimHalf only_static elsyms0 q a b = do
      a_scal <- fromNumSofP a
      b_scal <- fromNumSofP b
      e_num_scal <- simplifyScal (SPlus (STimes a_scal q) b_scal)
      e_num <- toNumSofP e_num_scal
      gaussAllLTH0 only_static elsyms0 e_num
-- | Choose the next variable to eliminate from @e_scal@.  Candidates are
-- the free variables that are not in @elsyms0@ and have an entry in the
-- ranges map.  The first candidate with the highest priority wins, where
-- candidates are taken in 'S.toList' order.  Returns 'Nothing' when there
-- is no candidate.
pickSymToElim :: RangesRep -> S.Set VName -> ScalExp -> Maybe VName
pickSymToElim rangesrep elsyms0 e_scal =
  case sortOn (negate . fst) candidates of
    [] -> Nothing
    (_, v) : _ -> Just v
  where
    candidates = [ (prio, s)
                 | s <- S.toList $ freeIn e_scal
                 , not (S.member s elsyms0)
                 , Just (prio, _, _) <- [M.lookup s rangesrep] ]
-- | Like 'linearForm', but on a plain 'ScalExp': simplify it, split it into
-- @(a, b)@ with @e = a*sym + b@, and simplify both parts.
linearFormScalExp :: VName -> ScalExp -> AlgSimplifyM (Maybe (ScalExp, ScalExp))
linearFormScalExp sym scl_exp = do
  sofp <- toNumSofP =<< simplifyScal scl_exp
  linearForm sym sofp >>= maybe (return Nothing) toScalPair
  where
    toScalPair (a_sofp, b_sofp) = do
      a <- fromNumSofP a_sofp
      b <- fromNumSofP b_sofp
      a' <- simplifyScal a
      b' <- simplifyScal b
      return $ Just (a', b')
-- | Split a sum of products into @(a, b)@ such that @e = a*idd + b@.
--
-- A term goes into @a@ (divided by @idd@) when dividing it by @idd@
-- simplifies away the division; all other terms go into @b@.  The result
-- is 'Nothing' unless all of the following hold:
--
-- * at least one term contains @idd@;
-- * no term of @b@ mentions @idd@;
-- * the coefficient @a@ itself is free of @idd@.  For instance, @i*i / i@
--   simplifies to @i@, and @i*i@ is not linear in @i@.
--
-- Fails on empty sums or products.  When there are no @b@ terms, @b@ is 0.
linearForm :: VName -> NNumExp -> AlgSimplifyM (Maybe (NNumExp, NNumExp))
linearForm _ (NProd [] _) =
  badAlgSimplifyM "linearForm: empty Prod!"
linearForm idd ee@NProd{} = linearForm idd (NSum [ee] (typeOfNAlg ee))
linearForm _ (NSum [] _) =
  badAlgSimplifyM "linearForm: empty Sum!"
linearForm idd (NSum terms tp) = do
  quotients <- mapM divByIdd terms
  let (a_pairs, b_pairs) = partition (not . isDiv . snd) $ zip terms quotients
      a_terms = map snd a_pairs
      b_terms = map fst b_pairs
  b_free <- and <$> mapM freeOfIdd b_terms
  case a_terms of
    t:ts | b_free -> do
      a_simpl <- simplifyScal $ foldl SPlus t ts
      if idd `S.member` freeIn a_simpl
        then return Nothing -- coefficient depends on idd: not linear
        else do
          a_terms_sofp <- toNumSofP a_simpl
          b_terms_sofp <- if null b_terms
                          then do zero <- getZero tp; return $ NProd [Val zero] tp
                          else return $ NSum b_terms tp
          return $ Just (a_terms_sofp, b_terms_sofp)
    _ -> return Nothing
  where
    -- Simplified quotient  t / idd  of a (non-empty) product term.
    divByIdd t@(NProd (_:_) _) = do
      t_scal <- fromNumSofP t
      simplifyScal $ SDiv t_scal (Id idd (scalExpType t_scal))
    divByIdd _ = badAlgSimplifyM "linearForm: ILLEGAL111!!!!"

    -- A quotient that is still a division: idd did not divide the term.
    isDiv SDiv{} = True
    isDiv _ = False

    -- Whether no factor of a product term mentions idd.
    freeOfIdd (NProd fs _) = return $ not $ idd `S.member` mconcat (map freeIn fs)
    freeOfIdd _ = badAlgSimplifyM "linearForm: ILLEGAL222!!!!"
-- | Algebraically simplify a scalar expression.
--
-- * Boolean expressions are normalised through disjunctive normal form
--   ('toDNF', 'simplifyDNF', 'fromDNF').
-- * Numeric expressions are normalised through sum-of-products form
--   ('toNumSofP', 'fromNumSofP'): like terms are merged, constant
--   coefficients are folded, and products are multiplied out.
-- * Sums, products and divisions with a 'MaxMin' operand are pushed into
--   the min/max.
-- * 'SMod', 'SQuot' and 'SRem' only get their operands simplified.
-- * 'SAbs' and 'SSignum' are returned as they are.
simplifyScal :: ScalExp -> AlgSimplifyM ScalExp
simplifyScal (Val v) = return $ Val v
simplifyScal (Id x t) = return $ Id x t
-- Boolean expressions: round-trip through DNF.
simplifyScal e@SNot{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal e@SLogAnd{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal e@SLogOr{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal e@RelExp{} = fromDNF =<< simplifyDNF =<< toDNF e
simplifyScal (MaxMin _ []) =
  badAlgSimplifyM "Scalar MaxMin expression with empty arglist."
simplifyScal (MaxMin _ [e]) = simplifyScal e
-- Fold all constant operands into one constant, flatten directly nested
-- min/max of the same kind, and drop duplicate operands.
simplifyScal (MaxMin ismin es) = do
  es0 <- mapM simplifyScal es
  let evals = filter isValue es0
      es' = filter (not . isValue) es0
      mvv = case evals of
              [] -> Nothing
              v:vs -> let myop = if ismin then min else max
                          myval= getValue v
                          oneval = (foldl myop myval . map getValue) vs
                      in Just $ Val oneval
  case (es', mvv) of
    ([], Just vv) -> return vv
    (_, Just vv) -> return $ MaxMin ismin $ remDups $ foldl flatop [] $ vv:es'
    (_, Nothing) -> return $ MaxMin ismin $ remDups $ foldl flatop [] es'
  where
    isValue :: ScalExp -> Bool
    isValue e = case e of
                  Val _ -> True
                  _ -> False
    -- Only called on constants; the fallback is never used.
    getValue :: ScalExp -> PrimValue
    getValue se = case se of
                    Val v -> v
                    _ -> value (0::Int32)
    -- Splice in the operands of a nested min/max of the same kind.
    flatop :: [ScalExp] -> ScalExp -> [ScalExp]
    flatop a e@(MaxMin ismin' ses) =
      a ++ if ismin == ismin' then ses else [e]
    flatop a e = a++[e]
    -- Deduplicate via a set; the operands come out in set order.
    remDups :: [ScalExp] -> [ScalExp]
    remDups l = S.toList (S.fromList l)
simplifyScal (SPlus e1o e2o) = do
  e1' <- simplifyScal e1o
  e2' <- simplifyScal e2o
  if isMaxMin e1' || isMaxMin e2'
    then helperPlusMinMax $ SPlus e1' e2'
    else normalPlus e1' e2'
  where
    -- Merge the terms of both sides.  Each term is split into (product,
    -- coefficient); terms are sorted by product so that equal products are
    -- adjacent; 'discriminate' adds their coefficients; terms whose
    -- coefficient became zero are dropped.
    normalPlus :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
    normalPlus e1 e2 = do
      e1' <- toNumSofP e1
      e2' <- toNumSofP e2
      let tp = scalExpType e1
      let terms = getTerms e1' ++ getTerms e2'
      splittedTerms <- mapM splitTerm terms
      let sortedTerms = sortBy (\(n1,_) (n2,_) -> compare n1 n2) splittedTerms
      merged <- reverse <$> foldM discriminate [] sortedTerms
      let filtered = filter (\(_,v) -> not $ zeroIsh v ) merged
      if null filtered
        then do
          zero <- getZero tp
          fromNumSofP $ NProd [Val zero] tp
        else do
          terms' <- mapM joinTerm filtered
          fromNumSofP $ NSum terms' tp
-- e1 - e2 becomes e1 + (-1)*e2 (or 0 when both sides are syntactically equal).
simplifyScal (SMinus e1 e2) = do
  let tp = scalExpType e1
  if e1 == e2
    then Val <$> getZero tp
    else do min_1 <- getNeg1 $ scalExpType e1
            simplifyScal $ SPlus e1 $ STimes (Val min_1) e2
-- -e becomes (-1)*e.
simplifyScal (SNeg e) = do
  negOne <- getNeg1 $ scalExpType e
  simplifyScal $ STimes (Val negOne) e
simplifyScal (SAbs e) = return $ SAbs e
simplifyScal (SSignum e) = return $ SSignum e
simplifyScal (STimes e1o e2o) = do
  e1'' <- simplifyScal e1o
  e2'' <- simplifyScal e2o
  if isMaxMin e1'' || isMaxMin e2''
    then helperMultMinMax $ STimes e1'' e2''
    else normalTimes e1'' e2''
  where
    -- Multiply two sums of products out into one sorted sum of products.
    normalTimes :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
    normalTimes e1 e2 = do
      let tp = scalExpType e1
      e1' <- toNumSofP e1
      e2' <- toNumSofP e2
      case (e1', e2') of
        (NProd xs _, y@NProd{}) -> fromNumSofP =<< makeProds xs y
        (NProd xs _, y) -> do
          prods <- mapM (makeProds xs) $ getTerms y
          fromNumSofP $ NSum (sort prods) tp
        (x, NProd ys _) -> do
          prods <- mapM (makeProds ys) $ getTerms x
          fromNumSofP $ NSum (sort prods) tp
        (NSum xs _, NSum ys _) -> do
          xsMultChildren <- mapM getMultChildren xs
          prods <- mapM (\x -> mapM (makeProds x) ys) xsMultChildren
          fromNumSofP $ NSum (sort $ concat prods) tp
    -- Multiply a list of factors into a product term.  A leading constant
    -- is kept first; two leading constants are multiplied together.
    makeProds :: [ScalExp] -> NNumExp -> AlgSimplifyM NNumExp
    makeProds [] _ =
      badAlgSimplifyM " In simplifyAlgN, makeProds: 1st arg is the empty list! "
    makeProds _ (NProd [] _) =
      badAlgSimplifyM " In simplifyAlgN, makeProds: 2nd arg is the empty list! "
    makeProds _ NSum{} =
      badAlgSimplifyM " In simplifyAlgN, makeProds: e1 * e2: e2 is a sum of sums! "
    makeProds (Val v1:exs) (NProd (Val v2:ys) tp1) = do
      v <- mulVals v1 v2
      return $ NProd (Val v : sort (ys++exs) ) tp1
    makeProds (Val v:exs) (NProd ys tp1) =
      return $ NProd (Val v : sort (ys++exs) ) tp1
    makeProds exs (NProd (Val v : ys) tp1) =
      return $ NProd (Val v : sort (ys++exs) ) tp1
    makeProds exs (NProd ys tp1) =
      return $ NProd (sort (ys++exs)) tp1
simplifyScal (SDiv e1o e2o) = do
  e1' <- simplifyScal e1o
  e2' <- simplifyScal e2o
  if isMaxMin e1' || isMaxMin e2'
    then helperMultMinMax $ SDiv e1' e2'
    else normalFloatDiv e1' e2'
  where
    -- Divide a sum by a product.  Each factor of the divisor is cancelled
    -- only if it divides every term ('trySimplifyDivRec'); if none cancels,
    -- the plain quotient is kept.  NOTE(review): x/x is rewritten to 1
    -- without checking that x is nonzero.
    normalFloatDiv :: ScalExp -> ScalExp -> AlgSimplifyM ScalExp
    normalFloatDiv e1 e2
      | e1 == e2 = do one <- getPos1 $ scalExpType e1
                      return $ Val one
      | otherwise = do
          e1' <- toNumSofP e1
          e2' <- toNumSofP e2
          case e2' of
            NProd fs tp -> do
              e1Split <- mapM splitTerm (getTerms e1')
              case e1Split of
                [] -> Val <$> getZero tp
                _ -> do
                  (fs', e1Split') <- trySimplifyDivRec fs [] e1Split
                  if length fs' == length fs
                    then turnBackAndDiv e1' e2'
                    else do
                      terms_e1' <- mapM joinTerm e1Split'
                      e1'' <- fromNumSofP $ NSum terms_e1' tp
                      case fs' of
                        [] -> return e1''
                        _ -> do e2'' <- fromNumSofP $ NProd fs' tp
                                return $ SDiv e1'' e2''
            _ -> turnBackAndDiv e1' e2'
    turnBackAndDiv :: NNumExp -> NNumExp -> AlgSimplifyM ScalExp
    turnBackAndDiv ee1 ee2 = do
      ee1' <- fromNumSofP ee1
      ee2' <- fromNumSofP ee2
      return $ SDiv ee1' ee2'
simplifyScal (SMod e1o e2o) =
  SMod <$> simplifyScal e1o <*> simplifyScal e2o
simplifyScal (SQuot e1o e2o) =
  SQuot <$> simplifyScal e1o <*> simplifyScal e2o
simplifyScal (SRem e1o e2o) =
  SRem <$> simplifyScal e1o <*> simplifyScal e2o
-- Powers: 1^x and x^0 become 1, x^1 becomes x, constant powers are
-- evaluated, and x^n for a constant integer n >= 1 is expanded into an
-- n-fold product.
simplifyScal (SPow e1 e2) = do
  let tp = scalExpType e1
  e1' <- simplifyScal e1
  e2' <- simplifyScal e2
  if isCt1 e1' || isCt0 e2'
    then Val <$> getPos1 tp
    else if isCt1 e2'
      then return e1'
      else case (e1', e2') of
        (Val v1, Val v2)
          | Just v <- powVals v1 v2 -> return $ Val v
        (_, Val (IntValue n)) ->
          if P.intToInt64 n >= 1
            then do
              new_e <- fromNumSofP $ NProd (genericReplicate (P.intToInt64 n) e1') tp
              simplifyScal new_e
            else return $ SPow e1' e2'
        (_, _) -> return $ SPow e1' e2'
  where
    powVals :: PrimValue -> PrimValue -> Maybe PrimValue
    powVals (IntValue v1) (IntValue v2) = IntValue <$> P.doPow v1 v2
    powVals _ _ = Nothing
-- | Whether the expression is a min/max node.
isMaxMin :: ScalExp -> Bool
isMaxMin se = case se of
  MaxMin{} -> True
  _ -> False

-- | Distribute an addition over a min/max operand:
-- @mm(ts) + e = mm(t + e | t <- ts)@, then simplify the result.
helperPlusMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperPlusMinMax se = case se of
  SPlus (MaxMin ismin es) other ->
    simplifyScal $ MaxMin ismin [ SPlus x other | x <- es ]
  SPlus other (MaxMin ismin es) ->
    simplifyScal $ MaxMin ismin [ SPlus other x | x <- es ]
  _ -> badAlgSimplifyM "helperPlusMinMax: Reached unreachable case!"

-- | Dispatch a product or quotient with a min/max operand to
-- 'helperTimesDivMinMax', recording which operation it is and on which
-- side the min/max sits.
helperMultMinMax :: ScalExp -> AlgSimplifyM ScalExp
helperMultMinMax se = case se of
  STimes e em@MaxMin{} -> helperTimesDivMinMax True True em e
  STimes em@MaxMin{} e -> helperTimesDivMinMax True False em e
  SDiv e em@MaxMin{} -> helperTimesDivMinMax False True em e
  SDiv em@MaxMin{} e -> helperTimesDivMinMax False False em e
  _ -> badAlgSimplifyM "helperMultMinMax: Reached unreachable case!"
-- | Push a multiplication or division by @e@ into the min/max @emo@.
--
-- @isTimes@ selects multiplication (otherwise division); @isRev@ means
-- @e@ is the left operand (@e * mm@ or @e / mm@).
--
-- When the sign of @e@ is decided (via @e < 0@), the operation is applied
-- to every operand, and min/max is swapped according to 'cond'.
-- NOTE(review): for the reversed division @e / mm@, that swap presumably
-- assumes all min/max operands have the same sign — confirm.
--
-- When the sign is undecided, multiplication by a sum is distributed over
-- the sum's terms; anything else is left as a plain product or quotient.
helperTimesDivMinMax :: Bool -> Bool -> ScalExp -> ScalExp -> AlgSimplifyM ScalExp
helperTimesDivMinMax isTimes isRev emo@MaxMin{} e = do
  em <- simplifyScal emo
  case em of
    MaxMin ismin es -> do
      e' <- simplifyScal e
      e'_sop <- toNumSofP e'
      p' <- simplifyNRel $ NRelExp LTH0 e'_sop
      case p' of
        LogCt ctbool -> do
          -- True when the kind (min or max) is preserved.
          let cond = ( isTimes && not ctbool ) ||
                     ( not isTimes && not isRev && not ctbool ) ||
                     ( not isTimes && isRev && ctbool )
          let ismin' = if cond then ismin else not ismin
          simplifyScal $ MaxMin ismin' $ map (`mkTimesDiv` e') es
        _ -> if not isTimes then return $ mkTimesDiv em e'
             else
               case e'_sop of
                 NProd _ _ -> return $ mkTimesDiv em e'
                 NSum ts tp -> do
                   new_ts <-
                     mapM (\case
                             NProd fs _ -> return $ NProd (em:fs) tp
                             _ -> badAlgSimplifyM
                                    "helperTimesDivMinMax: SofP invariant violated!"
                          ) ts
                   simplifyScal =<< fromNumSofP ( NSum new_ts tp )
    _ -> simplifyScal $ mkTimesDiv em e
  where
    -- Rebuild the product/quotient with the operands on the original sides.
    mkTimesDiv :: ScalExp -> ScalExp -> ScalExp
    mkTimesDiv e1 e2
      | not isTimes = if isRev then SDiv e2 e1 else SDiv e1 e2
      | isRev = STimes e2 e1
      | otherwise = STimes e1 e2
helperTimesDivMinMax _ _ _ _ =
  badAlgSimplifyM "helperTimesDivMinMax: Reached unreachable case!"
-- | Logical negation of a DNF literal.
--
-- For an integer @e < 0@, the result is @-e - 1 < 0@ (i.e. @e >= 0@).
-- Any other relation flips between 'LTH0' and 'LEQ0' and negates @e@ with
-- 'negateSimplified'.
negateBTerm :: BTerm -> AlgSimplifyM BTerm
negateBTerm bt =
    case bt of
      LogCt v -> return $ LogCt $ not v
      PosId i -> return $ NegId i
      NegId i -> return $ PosId i
      NRelExp rel e -> negateRel rel e
  where
    negateRel LTH0 e
      | IntType it <- typeOfNAlg e = do
          se <- fromNumSofP e
          let one = Val $ value $ P.intValue it (1::Int)
          NRelExp LTH0 <$> (toNumSofP =<< simplifyScal (SNeg $ SPlus se one))
    negateRel rel e =
      NRelExp (if rel == LEQ0 then LTH0 else LEQ0) <$>
        (toNumSofP =<< negateSimplified =<< fromNumSofP e)
-- | Convert a DNF literal back into a boolean 'ScalExp'.
--
-- 'toDNF' drops the type of a variable used as a literal.  Such a
-- variable only ever stands for a boolean, so it is rebuilt with type
-- 'Bool' rather than an integer type.
bterm2ScalExp :: BTerm -> AlgSimplifyM ScalExp
bterm2ScalExp (LogCt v) = return $ Val $ BoolValue v
bterm2ScalExp (PosId i) = return $ Id i Bool
bterm2ScalExp (NegId i) = return $ SNot $ Id i Bool
bterm2ScalExp (NRelExp rel e) = RelExp rel <$> fromNumSofP e
-- | Turn a DNF back into a boolean expression: literals inside each
-- conjunction are joined with 'SLogAnd', and conjunctions with 'SLogOr'.
-- Each fold puts the newest operand on the left.  Fails on an empty DNF or
-- an empty conjunction.
fromDNF :: DNF -> AlgSimplifyM ScalExp
fromDNF [] = badAlgSimplifyM "fromDNF: empty DNF!"
fromDNF (t:ts) = do
  first <- conj t
  foldM (\acc x -> (`SLogOr` acc) <$> conj x) first ts
  where
    conj [] = badAlgSimplifyM "fromDNF, translFact empty DNF factor!"
    conj (f:fs) = do
      f' <- bterm2ScalExp f
      foldM (\acc x -> (`SLogAnd` acc) <$> bterm2ScalExp x) f' fs
-- | Convert a boolean expression to disjunctive normal form.  Relations are
-- normalised and simplified with 'simplifyNRel'; negations are pushed
-- inwards (De Morgan).  Fails on non-boolean expressions.
toDNF :: ScalExp -> AlgSimplifyM DNF
toDNF (Val (BoolValue v)) = return [[LogCt v]]
toDNF (Id idd _ ) = return [[PosId idd]]
toDNF (RelExp rel e ) = do
  let t = scalExpType e
  case t of
    -- Integers: e <= 0 is rewritten to e - 1 < 0, so every integer
    -- relation becomes an LTH0 relation.
    IntType it -> do
      e' <- if rel == LEQ0
              then do
                m1 <- getNeg1 $ IntType it
                return $ SPlus e $ Val m1
              else return e
      ne <- toNumSofP =<< simplifyScal e'
      nrel <- simplifyNRel $ NRelExp LTH0 ne
      return [[nrel]]
    -- Other types: Gaussian elimination is suppressed via markGaussLTH0.
    _ -> do
      ne <- toNumSofP =<< simplifyScal e
      nrel <- markGaussLTH0 $ simplifyNRel $ NRelExp rel ne
      return [[nrel]]
toDNF (SNot (SNot e)) = toDNF e
toDNF (SNot (Val (BoolValue v))) = return [[LogCt $ not v]]
toDNF (SNot (Id idd _)) = return [[NegId idd]]
-- not (e < 0) is -e <= 0, and not (e <= 0) is -e < 0.
toDNF (SNot (RelExp rel e)) = do
  let not_rel = if rel == LEQ0 then LTH0 else LEQ0
  neg_e <- simplifyScal (SNeg e)
  toDNF $ RelExp not_rel neg_e
toDNF (SLogOr e1 e2 ) = do
  e1s <- toDNF e1
  e2s <- toDNF e2
  return $ sort $ e1s ++ e2s
-- A conjunction of two DNFs is the cross product of their conjunctions.
toDNF (SLogAnd e1 e2 ) = do
  e1s <- toDNF e1
  e2s <- toDNF e2
  let lll = map (\t2-> map (++t2) e1s) e2s
  return $ sort $ concat lll
toDNF (SNot (SLogAnd e1 e2)) = do
  e1s <- toDNF (SNot e1)
  e2s <- toDNF (SNot e2)
  return $ sort $ e1s ++ e2s
toDNF (SNot (SLogOr e1 e2)) = do
  e1s <- toDNF $ SNot e1
  e2s <- toDNF $ SNot e2
  let lll = map (\t2-> map (++t2) e1s) e2s
  return $ sort $ concat lll
toDNF _ = badAlgSimplifyM "toDNF: not a boolean expression!"
-- | Simplify a DNF.
--
-- 1. Each conjunction is simplified with 'simplifyAndOr'.
-- 2. If any conjunction is constant true, the result is true; otherwise
--    false conjunctions are dropped and duplicates removed.  An empty
--    result means false.
-- 3. If every remaining conjunction is a single literal, the disjunction of
--    those literals is simplified as well.
simplifyDNF :: DNF -> AlgSimplifyM DNF
simplifyDNF terms0 = do
  conjs <- mapM (simplifyAndOr True) terms0
  let kept
        | [LogCt True] `elem` conjs = [[LogCt True]]
        | otherwise = S.toList $ S.fromList [ c | c <- conjs, c /= [LogCt False] ]
  case kept of
    [] -> return [[LogCt False]]
    _ | all ((== 1) . length) kept ->
          map (: []) <$> simplifyAndOr False (concat kept)
      | otherwise -> return kept
-- | Simplify a conjunction (@is_and@ = 'True') or a disjunction ('False')
-- of literals.
--
-- * The absorbing constant short-circuits the whole list.
-- * Neutral constants and duplicates are removed.
-- * Literals made redundant by implication are dropped ('trimImplies').
-- * A literal that contradicts (for AND), or complements (for OR), another
--   one collapses the result to 'False' or 'True'.
--
-- Fails on an empty list.
simplifyAndOr :: Bool -> [BTerm] -> AlgSimplifyM [BTerm]
simplifyAndOr _ [] = badAlgSimplifyM "simplifyAndOr: not a boolean expression!"
simplifyAndOr is_and fs =
  if LogCt (not is_and) `elem` fs
    then return [LogCt $ not is_and]
    else do
      let fs' = S.toList . S.fromList . filter (/=LogCt is_and) $ fs
      if null fs'
        then return [LogCt is_and]
        else do
          fs''<- foldM (\l x-> do
                          (addx, l') <- trimImplies is_and x l
                          return $ if addx then x:l' else l'
                       ) [] fs'
          isF <- foldM (\b x -> if b
                                  then return b
                                  else do
                                    notx <- negateBTerm x
                                    impliesAny is_and x notx fs''
                       ) False fs''
          return $ if not isF then fs''
                   else if is_and then [LogCt False]
                   else [LogCt True ]
  where
    -- Conservative implication test between literals: a 'False' result
    -- means "not proven".
    impliesRel :: BTerm -> BTerm -> AlgSimplifyM Bool
    impliesRel (LogCt False) _ = return True
    impliesRel _ (LogCt True) = return True
    impliesRel (LogCt True) e = do
      let e' = e
      return $ e' == LogCt True
    impliesRel e (LogCt False) = do
      e' <- negateBTerm e
      return $ e' == LogCt True
    -- For integers, e1 < 0 implies e2 < 0 when e2 - (e1 + 1) < 0.
    impliesRel (NRelExp rel1 e1) (NRelExp rel2 e2) = do
      let btp = typeOfNAlg e1
      if btp /= typeOfNAlg e2
        then return False
        else do
          one <- getPos1 btp
          e1' <- fromNumSofP e1
          e2' <- fromNumSofP e2
          case (rel1, rel2, btp) of
            (LTH0, LTH0, IntType _) -> do
              e2me1m1 <- toNumSofP =<< simplifyScal (SMinus e2' $ SPlus e1' $ Val one)
              diffrel <- simplifyNRel $ NRelExp LTH0 e2me1m1
              return $ diffrel == LogCt True
            (_, _, IntType _) -> badAlgSimplifyM "impliesRel: LEQ0 for Int!"
            (_, _, _) -> badAlgSimplifyM "impliesRel: exp of illegal type!"
    impliesRel p1 p2
      | p1 == p2 = return True
      | otherwise = return False
    -- Insert literal x into the already-kept list.  In an AND, literals
    -- implied by x are dropped, and x is not added if it is implied by a
    -- kept literal; in an OR the roles are reversed.  The Bool result says
    -- whether x should be added.
    trimImplies :: Bool -> BTerm -> [BTerm] -> AlgSimplifyM (Bool, [BTerm])
    trimImplies _ _ [] = return (True, [])
    trimImplies and_case x (p:ps) = do
      succc <- impliesRel x p
      if succc
        then if and_case then trimImplies and_case x ps else return (False, p:ps)
        else do
          suc <- impliesRel p x
          if suc
            then if and_case then return (False, p:ps) else trimImplies and_case x ps
            else do
              (addx, newps) <- trimImplies and_case x ps
              return (addx, p:newps)
    -- Whether some other literal p contradicts x: p implies (not x) in an
    -- AND, or (not x) implies p in an OR.
    impliesAny :: Bool -> BTerm -> BTerm -> [BTerm] -> AlgSimplifyM Bool
    impliesAny _ _ _ [] = return False
    impliesAny and_case x notx (p:ps)
      | x == p = impliesAny and_case x notx ps
      | otherwise = do
          succ' <- if and_case then impliesRel p notx else impliesRel notx p
          if succ' then return True
          else impliesAny and_case x notx ps
-- | Negate an already simplified expression.
--
-- * Numeric expressions: arithmetic negation, pushed inwards where that is
--   exact (sums, products, signum, min/max).
-- * Opaque forms are multiplied by -1: identifiers, powers, absolute
--   values, and the remainders 'SMod' / 'SRem'.
-- * Relations and 'SNot': logical negation.  @not (e <= 0)@ becomes
--   @-e < 0@ and vice versa.
--
-- Fails on 'SLogAnd' and 'SLogOr'.
negateSimplified :: ScalExp -> AlgSimplifyM ScalExp
negateSimplified (SNeg e) = return e
negateSimplified (SNot e) = return e
-- |e| is non-negative, so -|e| cannot be pushed inside: multiply by -1.
negateSimplified e@SAbs{} = do
  m1 <- getNeg1 $ scalExpType e
  return $ STimes (Val m1) e
negateSimplified (SSignum e) =
  SSignum <$> negateSimplified e
negateSimplified e@(Val v) = do
  m1 <- getNeg1 $ scalExpType e
  v' <- mulVals m1 v; return $ Val v'
negateSimplified e@Id{} = do
  m1 <- getNeg1 $ scalExpType e
  return $ STimes (Val m1) e
negateSimplified (SMinus e1 e2) = do
  e1' <- negateSimplified e1
  return $ SPlus e1' e2
negateSimplified (SPlus e1 e2) = do
  e1' <- negateSimplified e1
  e2' <- negateSimplified e2
  return $ SPlus e1' e2'
negateSimplified e@(SPow _ _) = do
  m1 <- getNeg1 $ scalExpType e
  return $ STimes (Val m1) e
negateSimplified (STimes e1 e2) = do
  (e1', e2') <- helperNegateMult e1 e2; return $ STimes e1' e2'
-- NOTE(review): for integer floor division, -(a / b) differs from
-- (-a) / b when b does not divide a; confirm this is acceptable here.
negateSimplified (SDiv e1 e2) = do
  (e1', e2') <- helperNegateMult e1 e2; return $ SDiv e1' e2'
-- A remainder is not odd in its operands in general, so negate it as a whole.
negateSimplified e@SMod{} = do
  m1 <- getNeg1 $ scalExpType e
  return $ STimes (Val m1) e
negateSimplified (SQuot e1 e2) = do
  (e1', e2') <- helperNegateMult e1 e2; return $ SQuot e1' e2'
negateSimplified e@SRem{} = do
  m1 <- getNeg1 $ scalExpType e
  return $ STimes (Val m1) e
negateSimplified (MaxMin ismin ts) =
  MaxMin (not ismin) <$> mapM negateSimplified ts
negateSimplified (RelExp LEQ0 e) =
  RelExp LTH0 <$> negateSimplified e
negateSimplified (RelExp LTH0 e) =
  RelExp LEQ0 <$> negateSimplified e
negateSimplified SLogAnd{} = badAlgSimplifyM "negateSimplified: SLogAnd unimplemented!"
negateSimplified SLogOr{} = badAlgSimplifyM "negateSimplified: SLogOr unimplemented!"
-- | Negate a product or quotient by negating exactly one operand.  The
-- preferred target is a constant operand, or the constant coefficient of
-- an operand; otherwise the left operand is negated.
helperNegateMult :: ScalExp -> ScalExp -> AlgSimplifyM (ScalExp, ScalExp)
helperNegateMult e1 e2
  | Val _ <- e1 = negateLeft
  | STimes (Val v) e1r <- e1 = do
      ev <- negateSimplified (Val v)
      return (STimes ev e1r, e2)
  | Val _ <- e2 = do
      e2' <- negateSimplified e2
      return (e1, e2')
  | STimes (Val v) e2r <- e2 = do
      ev <- negateSimplified (Val v)
      return (e1, STimes ev e2r)
  | otherwise = negateLeft
  where
    negateLeft = do
      e1' <- negateSimplified e1
      return (e1', e2)
-- | Convert an already simplified numeric expression into sum-of-products
-- form.
--
-- * A product is read as a right-nested chain of 'STimes'.
-- * Sums are flattened.
-- * Any other expression becomes a single-factor product.
--
-- Fails on 'SMinus' and 'SNeg', which simplification should already have
-- removed.
toNumSofP :: ScalExp -> AlgSimplifyM NNumExp
toNumSofP (SMinus _ _) = badAlgSimplifyM "toNumSofP: SMinus is not in SofP form!"
toNumSofP (SNeg _) = badAlgSimplifyM "toNumSofP: SNeg is not in SofP form!"
toNumSofP (STimes e1 e2) = do
  rest <- toNumSofP e2
  case rest of
    NProd es2 t -> return $ NProd (e1:es2) t
    _ -> badAlgSimplifyM "toNumSofP: STimes nor in SofP form!"
toNumSofP (SPlus e1 e2) = do
  let t = scalExpType e1
  lhs <- toNumSofP e1
  rhs <- toNumSofP e2
  return $ NSum (getTerms lhs ++ getTerms rhs) t
toNumSofP e = return $ NProd [e] $ scalExpType e
-- | Convert sum-of-products form back into a 'ScalExp' built from
-- right-nested 'SPlus' and 'STimes' chains.  An empty sum is 0; an empty
-- product is an error.
fromNumSofP :: NNumExp -> AlgSimplifyM ScalExp
fromNumSofP (NSum terms t) =
  case terms of
    [] -> Val <$> getZero t
    [f] -> fromNumSofP f
    f:fs -> do
      rest <- fromNumSofP $ NSum fs t
      hd <- fromNumSofP f
      return $ SPlus hd rest
fromNumSofP (NProd factors t) =
  case factors of
    [] -> badAlgSimplifyM " In fromNumSofP, empty NProd expression! "
    [f] -> return f
    f:fs -> STimes f <$> fromNumSofP (NProd fs t)
-- | The terms of a sum; a product counts as a one-term sum.
getTerms :: NNumExp -> [NNumExp]
getTerms e = case e of
  NSum es _ -> es
  NProd{} -> [e]

-- | The factors of a product term; fails on a (nested) sum.
getMultChildren :: NNumExp -> AlgSimplifyM [ScalExp]
getMultChildren e = case e of
  NProd xs _ -> return xs
  NSum _ _ -> badAlgSimplifyM "getMultChildren, NaryPlus should not be nested 2 levels deep "
-- | Split a term into @(rest, coefficient)@.  A leading constant becomes
-- the coefficient; a term that is only a constant becomes @(1, c)@; any
-- other term gets coefficient 1.
splitTerm :: NNumExp -> AlgSimplifyM (NNumExp, PrimValue)
splitTerm term = case term of
    NProd [] _ -> badAlgSimplifyM "splitTerm: Empty n-ary list of factors."
    NProd [Val v] tp -> do
      one <- getPos1 tp
      return (NProd [Val one] tp, v)
    NProd (Val v : fs@(_:_)) tp -> return (NProd fs tp, v)
    _ -> withUnitCoeff
  where
    withUnitCoeff = do
      one <- getPos1 (typeOfNAlg term)
      return (term, one)
-- | Inverse of 'splitTerm': put a coefficient back in front of a product.
-- The product's other factors end up sorted.  A coefficient of 1 on a term
-- without a leading constant is omitted.
joinTerm :: (NNumExp, PrimValue) -> AlgSimplifyM NNumExp
joinTerm (term, v) = case term of
  NSum _ _ -> badAlgSimplifyM "joinTerm: NaryPlus two levels deep."
  NProd [] _ -> badAlgSimplifyM "joinTerm: Empty NaryProd."
  NProd (Val l : fs) tp -> do
    coeff <- mulVals v l
    return $ NProd (Val coeff : sort fs) tp
  NProd fs tp
    | P.oneIsh v -> return term
    | otherwise -> return $ NProd (Val v : sort fs) tp
-- | Fold step for merging sorted (term, coefficient) pairs.  If the new
-- pair has the same term as the most recently pushed one, the
-- coefficients are added; otherwise the pair is pushed on the front.
discriminate :: [(NNumExp, PrimValue)] -> (NNumExp, PrimValue) -> AlgSimplifyM [(NNumExp, PrimValue)]
discriminate acc (k', v') = case acc of
  (k, v) : rest | k == k' -> do
    v'' <- addVals v v'
    return ((k, v'') : rest)
  _ -> return ((k', v') : acc)
-- | The constant 0 of an integer type; fails for any other type.
getZero :: PrimType -> AlgSimplifyM PrimValue
getZero (IntType t) = return $ value $ intValue t (0::Int)
getZero tp = badAlgSimplifyM ("getZero for type: "++pretty tp)

-- | The constant 1 of an integer type; fails for any other type.
getPos1 :: PrimType -> AlgSimplifyM PrimValue
getPos1 (IntType t) = return $ value $ intValue t (1::Int)
getPos1 tp = badAlgSimplifyM ("getPos1 for type: "++pretty tp)

-- | The constant -1 of an integer type; fails for any other type.
getNeg1 :: PrimType -> AlgSimplifyM PrimValue
getNeg1 (IntType t) = return $ value $ intValue t (-1::Int)
getNeg1 tp = badAlgSimplifyM ("getNeg1 for type: "++pretty tp)

-- | Evaluate @v <= 0@ ('LEQ0') or @v < 0@ ('LTH0') for an integer constant.
valLTHEQ0 :: RelOp0 -> PrimValue -> AlgSimplifyM Bool
valLTHEQ0 LEQ0 (IntValue iv) = return $ P.intToInt64 iv <= 0
valLTHEQ0 LTH0 (IntValue iv) = return $ P.intToInt64 iv < 0
valLTHEQ0 _ _ = badAlgSimplifyM "valLTHEQ0 for non-numeric type!"

-- | Whether the expression is a constant equal to one.
isCt1 :: ScalExp -> Bool
isCt1 (Val v) = P.oneIsh v
isCt1 _ = False

-- | Whether the expression is a constant equal to zero.
isCt0 :: ScalExp -> Bool
isCt0 (Val v) = P.zeroIsh v
isCt0 _ = False
-- | Add two integer constants; fails for non-integer operands.
addVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
addVals v1 v2 = case (v1, v2) of
  (IntValue x, IntValue y) -> return $ IntValue $ P.doAdd x y
  _ -> badAlgSimplifyM "addVals: operands not of (the same) numeral type! "

-- | Multiply two integer constants; fails for non-integer operands.
mulVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
mulVals (IntValue x) (IntValue y) = return . IntValue $ P.doMul x y
mulVals v1 v2 =
  badAlgSimplifyM $ "mulVals: operands not of (the same) numeral type! " ++
                    pretty v1 ++ " " ++ pretty v2

-- | Signed division of two integer constants; fails on division by zero
-- or non-integer operands.
divVals :: PrimValue -> PrimValue -> AlgSimplifyM PrimValue
divVals (IntValue x) (IntValue y) =
  maybe (badAlgSimplifyM "Division by zero") (return . IntValue) (P.doSDiv x y)
divVals _ _ =
  badAlgSimplifyM "divVals: operands not of (the same) numeral type! "

-- | Whether the first integer constant is divisible by the second (signed
-- modulo is zero).  Division by zero yields 'False'.
canDivValsEvenly :: PrimValue -> PrimValue -> AlgSimplifyM Bool
canDivValsEvenly (IntValue x) (IntValue y) =
  return $ maybe False (P.zeroIsh . IntValue) (P.doSMod x y)
canDivValsEvenly _ _ =
  badAlgSimplifyM "canDivValsEvenly: operands not of (the same) numeral type!"

-- | The primitive type an expression in sum-of-products form is tagged with.
typeOfNAlg :: NNumExp -> PrimType
typeOfNAlg e = case e of
  NSum _ t -> t
  NProd _ t -> t
-- | Cancel divisor factors against a list of split terms.  A factor is
-- cancelled only if it divides every term ('tryDivProdByOneFact').
-- Factors that could not be cancelled are appended to the second argument.
-- Returns the remaining factors and the (partially divided) terms.
trySimplifyDivRec :: [ScalExp] -> [ScalExp] -> [(NNumExp, PrimValue)] ->
                     AlgSimplifyM ([ScalExp], [(NNumExp, PrimValue)])
trySimplifyDivRec [] fs' spl_terms = return (fs', spl_terms)
trySimplifyDivRec (f:fs) fs' spl_terms = do
  (succs, divided) <- unzip <$> mapM (tryDivProdByOneFact f) spl_terms
  if and succs
    then trySimplifyDivRec fs fs' divided
    else trySimplifyDivRec fs (fs' ++ [f]) spl_terms
-- | Try to divide a split term (product, coefficient) by the single factor
-- @f@.
--
-- * A constant @f@ divides the coefficient, when that division is exact.
-- * Any other @f@ is tried against each factor in turn with 'tryDivTriv'.
--   A constant quotient is folded into the coefficient.
--
-- The 'Bool' reports success.  On failure the input pair is returned
-- unchanged.
tryDivProdByOneFact :: ScalExp -> (NNumExp, PrimValue) -> AlgSimplifyM (Bool, (NNumExp, PrimValue))
tryDivProdByOneFact (Val f) (e, v) = do
  succc <- canDivValsEvenly v f
  if succc
    then do
      vres <- divVals v f
      return (True, (e, vres))
    else return (False,(e, v) )
tryDivProdByOneFact _ pev@(NProd [] _, _) = return (False, pev)
tryDivProdByOneFact f pev@(NProd (t:tfs) tp, v) = do
  (succc, newt) <- tryDivTriv t f
  one <- getPos1 tp
  if not succc
    then do
      -- t is not divisible by f: try the remaining factors.
      (succ', (tfs', v')) <- tryDivProdByOneFact f (NProd tfs tp, v)
      case (succ', tfs') of
        -- The rest now starts with a constant: fold it into the coefficient.
        (True, NProd (Val vv:tfs'') _) -> do
          vres <- mulVals v' vv
          return (True, (NProd (t:tfs'') tp, vres))
        (True, NProd tfs'' _) -> return (True, (NProd (t:tfs'') tp, v'))
        (_, _) -> return (False, pev)
    else case (newt, tfs) of
      -- t / f is constant: fold it into the coefficient, keeping at least
      -- the factor 1.
      (Val vv, _) -> do
        vres <- mulVals vv v
        return $ if null tfs
                 then (True, (NProd [Val one] tp, vres))
                 else (True, (NProd tfs tp, vres))
      (_, _) -> return (True, (NProd (newt:tfs) tp, v))
tryDivProdByOneFact _ (NSum _ _, _) =
  badAlgSimplifyM "tryDivProdByOneFact: unreachable case NSum reached!"
-- | Try to divide one factor by another syntactically.
--
-- * Equal factors give 1.
-- * Powers of the same base subtract their exponents.  For integers this
--   succeeds only when the exponent difference is 0 or 1, or is proven
--   positive.
-- * A bare base is treated as its first power.
--
-- On failure the 'Bool' is 'False' and the second component is the plain
-- quotient.
tryDivTriv :: ScalExp -> ScalExp -> AlgSimplifyM (Bool, ScalExp)
tryDivTriv (SPow a e1) (SPow d e2)
  | a == d && e1 == e2 = do
      one <- getPos1 $ scalExpType a
      return (True, Val one)
  | a == d = do
      let tp = scalExpType a
      one <- getPos1 tp
      e1me2 <- simplifyScal $ SMinus e1 e2
      case (tp, e1me2) of
        (IntType _, Val v) | P.zeroIsh v ->
          return (True, Val one)
        (IntType _, Val v) | P.oneIsh v ->
          return (True, a)
        -- Succeed when e2 - e1 < 0, i.e. the remaining exponent is positive.
        (IntType _, _) -> do
          e2me1 <- negateSimplified e1me2
          e2me1_sop <- toNumSofP e2me1
          p' <- simplifyNRel $ NRelExp LTH0 e2me1_sop
          return $ if p' == LogCt True
                   then (True, SPow a e1me2)
                   else (False, SDiv (SPow a e1) (SPow d e2))
        (_, _) -> return (False, SDiv (SPow a e1) (SPow d e2))
  | otherwise = return (False, SDiv (SPow a e1) (SPow d e2))
tryDivTriv (SPow a e1) b
  | a == b = do
      one <- getPos1 $ scalExpType a
      tryDivTriv (SPow a e1) (SPow a (Val one))
  | otherwise = return (False, SDiv (SPow a e1) b)
tryDivTriv b (SPow a e1)
  | a == b = do
      one <- getPos1 $ scalExpType a
      tryDivTriv (SPow a (Val one)) (SPow a e1)
  | otherwise = return (False, SDiv b (SPow a e1))
tryDivTriv t f
  | t == f = do
      one <- getPos1 $ scalExpType t
      return (True, Val one)
  | otherwise = return (False, SDiv t f)