{-# LANGUAGE FlexibleContexts #-}

-- | Defines simplification functions for 'PrimExp's.
module Futhark.Analysis.PrimExp.Simplify (simplifyPrimExp, simplifyExtPrimExp) where

import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.Optimise.Simplify.Engine as Engine

-- | Simplify a 'PrimExp', including copy propagation.  If a 'LeafExp'
-- refers to a name that is a 'Constant', the node turns into a
-- 'ValueExp'.
--
-- Each leaf name is run through the simplifier as @'Var' v@.  If the
-- simplifier yields another variable, the leaf is renamed to it; if it
-- yields a constant, the leaf is replaced by that literal value.  All
-- interior operator nodes are rebuilt unchanged around their simplified
-- operands (see @simplifyAnyPrimExp@).
simplifyPrimExp ::
  SimplifiableLore lore =>
  PrimExp VName ->
  SimpleM lore (PrimExp VName)
simplifyPrimExp = simplifyAnyPrimExp onLeaf
  where
    -- No local type signature: it would have to mention 'lore', which
    -- (without ScopedTypeVariables) would be a fresh, unconstrained
    -- type variable.
    onLeaf v pt = do
      se <- simplify $ Var v
      case se of
        Var v' -> pure $ LeafExp v' pt
        Constant pv -> pure $ ValueExp pv

-- | Like 'simplifyPrimExp', but where leaves may be 'Ext's.
--
-- 'Free' leaves are simplified exactly as in 'simplifyPrimExp'.  A
-- variable result stays a 'Free' leaf and a 'Constant' result becomes a
-- 'ValueExp'.  'Ext' leaves are returned unchanged, since they are not
-- names the simplifier can look up.
simplifyExtPrimExp ::
  SimplifiableLore lore =>
  PrimExp (Ext VName) ->
  SimpleM lore (PrimExp (Ext VName))
simplifyExtPrimExp = simplifyAnyPrimExp onLeaf
  where
    -- As in 'simplifyPrimExp', the helper deliberately has no local
    -- signature mentioning 'lore'.
    onLeaf (Free v) pt = do
      se <- simplify $ Var v
      case se of
        Var v' -> pure $ LeafExp (Free v') pt
        Constant pv -> pure $ ValueExp pv
    onLeaf (Ext i) pt = pure $ LeafExp (Ext i) pt

-- | Rebuild a 'PrimExp' bottom-up, replacing every 'LeafExp' with the
-- result of the given leaf action.
--
-- The leaf action receives the leaf's payload and its 'PrimType'.
-- 'ValueExp' nodes are kept as they are.  Every other node keeps its
-- operator (and, for 'FunExp', its name and result type) and has its
-- children simplified recursively, left to right.
simplifyAnyPrimExp ::
  SimplifiableLore lore =>
  (a -> PrimType -> SimpleM lore (PrimExp a)) ->
  PrimExp a ->
  SimpleM lore (PrimExp a)
simplifyAnyPrimExp f (LeafExp v pt) = f v pt
simplifyAnyPrimExp _ (ValueExp pv) =
  pure $ ValueExp pv
simplifyAnyPrimExp f (BinOpExp bop e1 e2) =
  BinOpExp bop <$> simplifyAnyPrimExp f e1 <*> simplifyAnyPrimExp f e2
simplifyAnyPrimExp f (CmpOpExp cmp e1 e2) =
  CmpOpExp cmp <$> simplifyAnyPrimExp f e1 <*> simplifyAnyPrimExp f e2
simplifyAnyPrimExp f (UnOpExp op e) =
  UnOpExp op <$> simplifyAnyPrimExp f e
simplifyAnyPrimExp f (ConvOpExp conv e) =
  ConvOpExp conv <$> simplifyAnyPrimExp f e
simplifyAnyPrimExp f (FunExp h args t) =
  FunExp h <$> traverse (simplifyAnyPrimExp f) args <*> pure t