{-# LANGUAGE FlexibleContexts #-}

-- | Defines simplification functions for 'PrimExp's.
module Futhark.Analysis.PrimExp.Simplify (simplifyPrimExp, simplifyExtPrimExp) where

import Futhark.Analysis.PrimExp
import Futhark.IR
import Futhark.Optimise.Simplify.Engine as Engine

-- | Simplify a 'PrimExp', including copy propagation.  Every
-- 'LeafExp' name is passed through 'simplify' as a 'Var': if the
-- result is still a 'Var', the leaf is kept with the returned name
-- and its original type; if the result is a 'Constant', the node
-- turns into a 'ValueExp' holding that value.
simplifyPrimExp ::
  SimplifiableLore lore =>
  PrimExp VName ->
  SimpleM lore (PrimExp VName)
simplifyPrimExp = simplifyAnyPrimExp onLeaf
  where
    -- Leaf handler: simplify the variable, then rebuild the node from
    -- whichever kind of 'SubExp' comes back.
    onLeaf v pt = do
      se <- simplify $ Var v
      pure $ case se of
        Var v' -> LeafExp v' pt
        Constant pv -> ValueExp pv

-- | Like 'simplifyPrimExp', but where leaves may be 'Ext's.  'Free'
-- leaves are simplified exactly as in 'simplifyPrimExp' (the result
-- is re-wrapped in 'Free' when it is still a variable); 'Ext' leaves
-- are returned unchanged.
simplifyExtPrimExp ::
  SimplifiableLore lore =>
  PrimExp (Ext VName) ->
  SimpleM lore (PrimExp (Ext VName))
simplifyExtPrimExp = simplifyAnyPrimExp onLeaf
  where
    -- A free name: simplify it as a variable and rebuild the node
    -- from the resulting 'SubExp'.
    onLeaf (Free v) pt = do
      se <- simplify $ Var v
      pure $ case se of
        Var v' -> LeafExp (Free v') pt
        Constant pv -> ValueExp pv
    -- An existential index carries no name to simplify; keep it.
    onLeaf (Ext i) pt = pure $ LeafExp (Ext i) pt

-- | Traverse a 'PrimExp', applying the given monadic function to
-- the name and type of every 'LeafExp' and splicing its result in
-- place of the leaf.  All other nodes are rebuilt with the same
-- constructor and operator, with their operands processed
-- recursively (left to right); 'ValueExp's are returned as they are.
simplifyAnyPrimExp ::
  SimplifiableLore lore =>
  (a -> PrimType -> SimpleM lore (PrimExp a)) ->
  PrimExp a ->
  SimpleM lore (PrimExp a)
simplifyAnyPrimExp f = go
  where
    -- Local worker so the leaf handler need not be threaded through
    -- every recursive call.
    go (LeafExp v pt) = f v pt
    go (ValueExp pv) = pure $ ValueExp pv
    go (BinOpExp bop e1 e2) = BinOpExp bop <$> go e1 <*> go e2
    go (CmpOpExp cmp e1 e2) = CmpOpExp cmp <$> go e1 <*> go e2
    go (UnOpExp op e) = UnOpExp op <$> go e
    go (ConvOpExp conv e) = ConvOpExp conv <$> go e
    go (FunExp h args t) = FunExp h <$> mapM go args <*> pure t