{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.Kernels.Simplify
  ( simplifyKernels,
    simplifyLambda,
    Kernels,

    -- * Building blocks
    simplifyKernelOp,
  )
where

import qualified Futhark.Analysis.SymbolTable as ST
import Futhark.IR.Kernels
import qualified Futhark.IR.SOACS.Simplify as SOAC
import Futhark.MonadFreshNames
import qualified Futhark.Optimise.Simplify as Simplify
import qualified Futhark.Optimise.Simplify.Engine as Engine
import Futhark.Optimise.Simplify.Lore
import Futhark.Optimise.Simplify.Rule
import Futhark.Optimise.Simplify.Rules
import Futhark.Pass
import Futhark.Tools
import qualified Futhark.Transform.FirstOrderTransform as FOT

-- | Simplifier operations for the 'Kernels' representation.  These are
-- the bindable defaults from 'Simplify.bindableSimpleOps', with 'Op's
-- handled by 'simplifyKernelOp'.  Any SOAC nested in an 'OtherOp' is
-- simplified by 'SOAC.simplifySOAC'.
simpleKernels :: Simplify.SimpleOps Kernels
simpleKernels = Simplify.bindableSimpleOps $ simplifyKernelOp SOAC.simplifySOAC

-- | Simplify an entire 'Kernels' program.  Uses 'simpleKernels' and
-- 'kernelRules', and adds no hoisting blockers beyond the defaults.
simplifyKernels :: Prog Kernels -> PassM (Prog Kernels)
simplifyKernels =
  Simplify.simplifyProg simpleKernels kernelRules Simplify.noExtraHoistBlockers

-- | Simplify a single 'Kernels' lambda.  The surrounding scope comes from
-- the 'HasScope' context.  The same operations ('simpleKernels') and
-- rules ('kernelRules') are used as for whole programs.
simplifyLambda ::
  (HasScope Kernels m, MonadFreshNames m) =>
  Lambda Kernels ->
  m (Lambda Kernels)
simplifyLambda =
  -- Same hoist-blocker value as 'simplifyKernels', referenced through the
  -- same module for consistency.
  Simplify.simplifyLambda simpleKernels kernelRules Simplify.noExtraHoistBlockers

-- | Simplify a 'HostOp'.
--
-- * 'OtherOp' payloads are passed to the supplied operation simplifier.
--   Its hoisted statements are returned unchanged.
-- * 'SegOp's are simplified with 'simplifySegOp'.
-- * The 'SubExp' (and 'SplitOrdering') operands of 'SizeOp's are run
--   through 'Engine.simplify'.  Names and size classes are kept as-is.
--   These cases never produce statements to hoist.
simplifyKernelOp ::
  ( Engine.SimplifiableLore lore,
    BodyDec lore ~ ()
  ) =>
  Simplify.SimplifyOp lore op ->
  HostOp lore op ->
  Engine.SimpleM lore (HostOp (Wise lore) (OpWithWisdom op), Stms (Wise lore))
simplifyKernelOp f (OtherOp op) = do
  (op', stms) <- f op
  pure (OtherOp op', stms)
simplifyKernelOp _ (SegOp op) = do
  (op', hoisted) <- simplifySegOp op
  pure (SegOp op', hoisted)
simplifyKernelOp _ (SizeOp sizeop) = do
  sizeop' <- case sizeop of
    SplitSpace o w i elems_per_thread ->
      SplitSpace
        <$> Engine.simplify o
        <*> Engine.simplify w
        <*> Engine.simplify i
        <*> Engine.simplify elems_per_thread
    -- These carry only a key name and a size class; nothing to simplify.
    GetSize {} -> pure sizeop
    GetSizeMax {} -> pure sizeop
    CmpSizeLe key size_class x ->
      CmpSizeLe key size_class <$> Engine.simplify x
    -- Simplify every 'SubExp' operand, including the group size, for
    -- consistency with the other size operations.
    CalcNumGroups w max_num_groups group_size ->
      CalcNumGroups
        <$> Engine.simplify w
        <*> pure max_num_groups
        <*> Engine.simplify group_size
  pure (SizeOp sizeop', mempty)

-- | Declares no methods of its own, so every 'BinderOps' method for
-- @Wise Kernels@ comes from the class's default implementations.
instance BinderOps (Wise Kernels) where

-- | Segmented operations in the simplified 'Kernels' representation are
-- the 'SegOp' constructor of 'HostOp', at 'SegLevel'.
instance HasSegOp (Wise Kernels) where
  type SegOpLevel (Wise Kernels) = SegLevel

  -- Only a 'SegOp' wraps a segmented operation; other ops do not.
  asSegOp op = case op of
    SegOp segop -> Just segop
    _ -> Nothing

  segOp = SegOp

-- | SOACs in the simplified 'Kernels' representation live inside the
-- 'OtherOp' constructor of 'HostOp'.  This lets the SOAC simplification
-- rules in 'kernelRules' fire on them.
instance SOAC.HasSOAC (Wise Kernels) where
  -- Only an 'OtherOp' can hold a SOAC.
  asSOAC op = case op of
    OtherOp soac -> Just soac
    _ -> Nothing

  soacOp = OtherOp

-- | Simplification rules for 'Kernels'.  The combined rule book is:
--
-- 1. 'standardRules',
-- 2. then 'segOpRules',
-- 3. then a few SOAC-specific top-down rules plus
--    'removeUnnecessaryCopy' as a bottom-up rule.
kernelRules :: RuleBook (Wise Kernels)
kernelRules =
  standardRules <> segOpRules <> ruleBook topDownRules bottomUpRules
  where
    topDownRules =
      [ RuleOp redomapIotaToLoop,
        RuleOp SOAC.simplifyKnownIterationSOAC,
        RuleOp SOAC.removeReplicateMapping,
        RuleOp SOAC.liftIdentityMapping
      ]
    bottomUpRules =
      [ RuleBasicOp removeUnnecessaryCopy
      ]

-- | Turn reductions over (solely) iotas into do-loops, because there is
-- no useful structure to exploit here anyway.  This is mostly a hack to
-- work around the fact that loop tiling would otherwise pointlessly tile
-- them.
--
-- The rule fires only when both of these hold:
--
-- * the 'Screma' has exactly one input array;
-- * 'isRedomapSOAC' accepts its form, and the symbol table shows the
--   array is bound by an 'Iota'.
--
-- The SOAC is then expanded by 'FOT.transformSOAC', keeping the original
-- statement's certificates.  In every other case the rule skips.
redomapIotaToLoop :: TopDownRuleOp (Wise Kernels)
redomapIotaToLoop vtable pat aux (OtherOp soac@(Screma _ [arr] form))
  | Just _ <- isRedomapSOAC form,
    Just (Iota {}, _) <- ST.lookupBasicOp arr vtable =
    Simplify $ certifying (stmAuxCerts aux) $ FOT.transformSOAC pat soac
redomapIotaToLoop _ _ _ _ = Skip