{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE TypeFamilies #-}

-- | Code generation for ImpCode with multicore operations.
module Futhark.CodeGen.ImpGen.Multicore
  ( Futhark.CodeGen.ImpGen.Multicore.compileProg,
    Warnings,
  )
where

import qualified Futhark.CodeGen.ImpCode.Multicore as Imp
import Futhark.CodeGen.ImpGen
import Futhark.CodeGen.ImpGen.Multicore.Base
import Futhark.CodeGen.ImpGen.Multicore.SegHist
import Futhark.CodeGen.ImpGen.Multicore.SegMap
import Futhark.CodeGen.ImpGen.Multicore.SegRed
import Futhark.CodeGen.ImpGen.Multicore.SegScan
import Futhark.IR.MCMem
import Futhark.MonadFreshNames
import Prelude hiding (quot, rem)

-- | Atomic operations that GCC supports as primitives, keyed on the
-- 'BinOp' they implement.  Looking up any 'BinOp' that is not in the
-- table yields 'Nothing'.
--
-- The table covers addition (with 'OverflowUndef'), subtraction (with
-- 'OverflowUndef'), bitwise and, exclusive or, and inclusive or, for
-- both 'Int32' and 'Int64'.
--
-- TODO: Add support for 1, 2, and 16 bytes too.
gccAtomics :: AtomicBinOp
gccAtomics = flip lookup cpu
  where
    -- Entries for all 32-bit operations come first, then the 64-bit ones.
    -- This is the same order as a hand-written table; since every key is
    -- distinct, the order does not affect 'lookup' anyway.
    cpu = concatMap atomicsFor [Int32, Int64]

    -- The supported (BinOp, atomic constructor) pairs for one integer type.
    atomicsFor t =
      [ (Add t OverflowUndef, Imp.AtomicAdd t),
        (Sub t OverflowUndef, Imp.AtomicSub t),
        (And t, Imp.AtomicAnd t),
        (Xor t, Imp.AtomicXor t),
        (Or t, Imp.AtomicOr t)
      ]

-- | Compile an 'MCMem' program to multicore ImpCode.
--
-- This delegates to the generic 'Futhark.CodeGen.ImpGen.compileProg' with
-- the following arguments:
--
-- * 'gccAtomics' as the host environment's atomic-operation table;
-- * an operation compiler that handles allocations generically and
--   routes every inner operation to 'compileMCOp';
-- * 'Imp.DefaultSpace' as the 'Space' argument.
compileProg ::
  MonadFreshNames m =>
  Prog MCMem ->
  m (Warnings, Imp.Definitions Imp.Multicore)
compileProg =
  -- The name is qualified because this module defines its own
  -- 'compileProg', which would otherwise make the reference ambiguous.
  Futhark.CodeGen.ImpGen.compileProg
    (HostEnv gccAtomics)
    ops
    Imp.DefaultSpace
  where
    ops :: Operations MCMem HostEnv Imp.Multicore
    ops = defaultOperations opCompiler

    -- Allocations are handled by the generic 'compileAlloc'.  Every other
    -- ('Inner') operation is a multicore operation.
    opCompiler :: OpCompiler MCMem HostEnv Imp.Multicore
    opCompiler dest (Alloc e space) = compileAlloc dest e space
    opCompiler dest (Inner op) = compileMCOp dest op

-- | Compile a multicore operation into ImpCode.
--
-- * 'OtherOp' produces no code.
-- * @'ParOp' par_op op@ emits a single 'Imp.Segop'.  Its sequential task
--   is the compiled outer @op@.  When @par_op@ is present, its optional
--   parallel task is the compiled nested operation.  Both compilations
--   share the same @num_tasks@ variable.
compileMCOp ::
  Pattern MCMem ->
  MCOp MCMem () ->
  ImpM MCMem HostEnv Imp.Multicore ()
compileMCOp _ (OtherOp ()) = pure ()
compileMCOp pat (ParOp par_op op) = do
  let space = getSpace op
      -- Flat index variable of the nested operation, if there is one.
      nested_flat = segFlat . getSpace <$> par_op

  dPrimV_ (segFlat space) (0 :: Imp.TExp Int64)
  iterations <- getIterationDomain op space
  nsubtasks <- dPrim "num_tasks" $ IntType Int32
  seq_code <- compileSegOp pat op nsubtasks
  retvals <- getReturnParams pat op

  -- The nested operation is compiled after the outer one.  Keep this
  -- ordering: it determines the order in which names are generated.
  par_code <- case par_op of
    Just nested_op -> do
      dPrimV_ (segFlat (getSpace nested_op)) (0 :: Imp.TExp Int64)
      compileSegOp pat nested_op nsubtasks
    Nothing -> pure mempty

  let scheduling_info =
        Imp.SchedulerInfo (tvVar nsubtasks) (untyped iterations)
      seq_task = Imp.ParallelTask seq_code (segFlat space)
      par_task = Imp.ParallelTask par_code <$> nested_flat
      -- Names excluded from the free-parameter computation: the flat
      -- indices, the task count and the return parameters.  Presumably
      -- the Segop itself binds these; confirm against 'freeParams'.
      non_free =
        [segFlat space, tvVar nsubtasks]
          ++ map Imp.paramName retvals
          ++ maybe [] pure nested_flat

  s <- segOpString op
  free_params <- freeParams (par_code <> seq_code) non_free
  emit $
    Imp.Op $
      Imp.Segop s free_params seq_task par_task retvals $
        scheduling_info (decideScheduling' op seq_code)

-- | Compile a single 'SegOp' to ImpCode by dispatching to the
-- operation-specific compiler.
--
-- The 'TV' 'Int32' argument is forwarded to the histogram, scan and
-- reduction compilers.  'SegMap' does not use it.
compileSegOp ::
  Pattern MCMem ->
  SegOp () MCMem ->
  TV Int32 ->
  ImpM MCMem HostEnv Imp.Multicore Imp.Code
compileSegOp pat (SegHist _ space histops _ kbody) ntasks =
  compileSegHist pat space histops kbody ntasks
compileSegOp pat (SegScan _ space scans _ kbody) ntasks =
  compileSegScan pat space scans kbody ntasks
compileSegOp pat (SegRed _ space reds _ kbody) ntasks =
  compileSegRed pat space reds kbody ntasks
compileSegOp pat (SegMap _ space _ kbody) _ =
  compileSegMap pat space kbody