{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

module Futhark.Pass.ExtractKernels.ToKernels
  ( getSize,
    segThread,
    soacsLambdaToKernels,
    soacsStmToKernels,
    scopeForKernels,
    scopeForSOACs,
    injectSOACS,
  )
where

import Control.Monad.Identity
import Data.List ()
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.Kernels
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS.SOAC as SOAC
import Futhark.Tools

-- | Bind a run-time size of the given 'SizeClass' and return the
-- 'SubExp' that refers to it.
--
-- The size key is the pretty-printed form of a fresh 'VName' generated
-- from @desc@; the same @desc@ is also used as the name hint for the
-- binding created by 'letSubExp'.
getSize ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  String ->
  SizeClass ->
  m SubExp
getSize desc size_class = do
  size_key <- nameFromString . pretty <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | Construct a 'SegThread' level whose number of groups and group size
-- are obtained through 'getSize'.
--
-- The two sizes are requested under the names @desc ++ "_num_groups"@
-- (class 'SizeNumGroups') and @desc ++ "_group_size"@ (class
-- 'SizeGroup'), in that order.  The virtualisation field is always
-- 'SegVirt'.
segThread ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  String ->
  m SegLevel
segThread desc =
  SegThread
    <$> (Count <$> getSize (desc ++ "_num_groups") SizeNumGroups)
    <*> (Count <$> getSize (desc ++ "_group_size") SizeGroup)
    <*> pure SegVirt

-- | Build a 'Rephraser' that moves a program from a representation
-- whose operations are plain SOACs (@Op from ~ SOAC from@) into the
-- representation @to@.
--
-- All decorations (expression, body, let-bound, function and lambda
-- parameter), return types and branch types are passed through
-- unchanged with 'return'.  Each operation is rewritten by mapping over
-- the SOAC and then wrapping the result with the supplied function
-- @f@.  Lambdas inside the SOAC are rephrased with this same
-- rephraser, so nested SOACs are converted as well.
injectSOACS ::
  ( Monad m,
    SameScope from to,
    ExpDec from ~ ExpDec to,
    BodyDec from ~ BodyDec to,
    RetType from ~ RetType to,
    BranchType from ~ BranchType to,
    Op from ~ SOAC from
  ) =>
  (SOAC to -> Op to) ->
  Rephraser m from to
injectSOACS f =
  Rephraser
    { rephraseExpLore = return,
      rephraseBodyLore = return,
      rephraseLetBoundLore = return,
      rephraseFParamLore = return,
      rephraseLParamLore = return,
      rephraseOp = fmap f . onSOAC,
      rephraseRetType = return,
      rephraseBranchType = return
    }
  where
    -- Convert one SOAC; sub-expressions and names are left alone.
    onSOAC = SOAC.mapSOACM mapper
    mapper =
      SOAC.SOACMapper
        { SOAC.mapOnSOACSubExp = return,
          SOAC.mapOnSOACVName = return,
          -- Recurse so that SOACs nested inside lambdas are injected too.
          SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
        }

-- | Convert a statement of the 'SOACS' representation into the
-- 'Kernels' representation by wrapping every SOAC in 'OtherOp'.
--
-- The rephrasing is run in the 'Identity' monad, so the conversion is
-- pure.
soacsStmToKernels :: Stm SOACS -> Stm Kernels
soacsStmToKernels = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Convert a lambda of the 'SOACS' representation into the 'Kernels'
-- representation by wrapping every SOAC in 'OtherOp'.
--
-- The rephrasing is run in the 'Identity' monad, so the conversion is
-- pure.
soacsLambdaToKernels :: Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Reinterpret a 'Kernels' scope as a 'SOACS' scope via 'castScope'.
scopeForSOACs :: Scope Kernels -> Scope SOACS
scopeForSOACs = castScope

-- | Reinterpret a 'SOACS' scope as a 'Kernels' scope via 'castScope'.
scopeForKernels :: Scope SOACS -> Scope Kernels
scopeForKernels = castScope