{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Helpers for moving IR fragments from the 'SOACS' representation into
-- the 'Kernels' representation, plus small builders for size operations
-- and thread-level segmentation.
module Futhark.Pass.ExtractKernels.ToKernels
  ( getSize,
    segThread,
    soacsLambdaToKernels,
    soacsStmToKernels,
    scopeForKernels,
    scopeForSOACs,
    injectSOACS,
  )
where

import Control.Monad.Identity
import Data.List ()
import Futhark.Analysis.Rephrase
import Futhark.IR
import Futhark.IR.Kernels
import Futhark.IR.SOACS (SOACS)
import qualified Futhark.IR.SOACS.SOAC as SOAC
import Futhark.Tools

-- | Bind a 'GetSize' size operation and return the bound subexpression.
--
-- The size key is the pretty-printed form of a fresh name generated from
-- @desc@, so every call produces a distinct key. The binding itself is
-- also named after @desc@.
getSize ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  -- | Base name for both the size key and the bound variable.
  String ->
  -- | The class of size being requested.
  SizeClass ->
  m SubExp
getSize desc size_class = do
  size_key <- nameFromString . pretty <$> newVName desc
  letSubExp desc $ Op $ SizeOp $ GetSize size_key size_class

-- | Build a 'SegThread' level with 'SegVirt' virtualisation.
--
-- The number of groups and the group size are each obtained through
-- 'getSize'. Their descriptions are @desc@ suffixed with @_num_groups@
-- and @_group_size@ respectively.
segThread ::
  (MonadBinder m, Op (Lore m) ~ HostOp (Lore m) inner) =>
  -- | Base description used to name the two size parameters.
  String ->
  m SegLevel
segThread desc =
  SegThread
    <$> (Count <$> getSize (desc ++ "_num_groups") SizeNumGroups)
    <*> (Count <$> getSize (desc ++ "_group_size") SizeGroup)
    <*> pure SegVirt

-- | A 'Rephraser' that changes a program's representation from @from@ to
-- @to@.
--
-- All decorations are passed through unchanged; the equality constraints
-- guarantee they coincide. Each SOAC is rebuilt in the target
-- representation, with lambdas nested inside it rephrased recursively.
-- The result is then wrapped into the target's 'Op' type with @f@.
injectSOACS ::
  ( Monad m,
    SameScope from to,
    ExpDec from ~ ExpDec to,
    BodyDec from ~ BodyDec to,
    RetType from ~ RetType to,
    BranchType from ~ BranchType to,
    Op from ~ SOAC from
  ) =>
  -- | How to embed a target-representation SOAC into the target's 'Op'.
  (SOAC to -> Op to) ->
  Rephraser m from to
injectSOACS f =
  Rephraser
    { rephraseExpLore = pure,
      rephraseBodyLore = pure,
      rephraseLetBoundLore = pure,
      rephraseFParamLore = pure,
      rephraseLParamLore = pure,
      rephraseOp = fmap f . onSOAC,
      rephraseRetType = pure,
      rephraseBranchType = pure
    }
  where
    -- No local type signatures here: without ScopedTypeVariables they
    -- could not refer to the outer @from@/@to@/@m@.
    onSOAC = SOAC.mapSOACM mapper
    -- Subexpressions and names are kept as-is; only lambdas need to be
    -- converted, which is done by recursing with the same injection.
    mapper =
      SOAC.SOACMapper
        { SOAC.mapOnSOACSubExp = pure,
          SOAC.mapOnSOACVName = pure,
          SOAC.mapOnSOACLambda = rephraseLambda $ injectSOACS f
        }

-- | Convert a 'SOACS' statement to 'Kernels'.
--
-- Every SOAC in the statement is wrapped in 'OtherOp'.
soacsStmToKernels :: Stm SOACS -> Stm Kernels
soacsStmToKernels = runIdentity . rephraseStm (injectSOACS OtherOp)

-- | Convert a 'SOACS' lambda to 'Kernels'.
--
-- Every SOAC in the lambda is wrapped in 'OtherOp'.
soacsLambdaToKernels :: Lambda SOACS -> Lambda Kernels
soacsLambdaToKernels = runIdentity . rephraseLambda (injectSOACS OtherOp)

-- | Reinterpret a 'Kernels' scope as a 'SOACS' scope via 'castScope'.
scopeForSOACs :: Scope Kernels -> Scope SOACS
scopeForSOACs = castScope

-- | Reinterpret a 'SOACS' scope as a 'Kernels' scope via 'castScope'.
scopeForKernels :: Scope SOACS -> Scope Kernels
scopeForKernels = castScope