{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies        #-}
{-# LANGUAGE TypeOperators       #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.CodeGen.Base
-- Copyright   : [2015..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.CodeGen.Base
  where

import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Module
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import LLVM.AST.Type.Downcast
import LLVM.AST.Type.Name
import qualified LLVM.AST.Global                                    as LLVM
import qualified LLVM.AST.Type                                      as LLVM

import Control.Monad
import Data.Monoid
import Data.String
import Text.Printf
import Prelude                                                      as P


-- | Generate function parameters that will specify the first and last (linear)
-- index of the array this thread should evaluate.
--
-- The result is a triple of:
--
--   * the operands referring to the start index (named @ix.start@),
--
--   * the operands referring to the end index (named @ix.end@), and
--
--   * the function parameters declaring both, start parameters first.
--
-- Both indices use the type derived from the shape representation via
-- 'shapeType'.
--
gangParam :: ShapeR sh -> (Operands sh, Operands sh, [LLVM.Parameter])
gangParam shr =
  let start = "ix.start"
      end   = "ix.end"
      tp    = shapeType shr
  in
  (local tp start, local tp end, parameter tp start ++ parameter tp end)


-- | The worker ID of the calling thread
--
-- Returns the operand referring to the thread identifier (named @ix.tid@)
-- together with the single function parameter that declares it.
--
gangId :: (Operands Int, [LLVM.Parameter])
gangId =
  let tid = "ix.tid"
  in  ( local (TupRsingle scalarTypeInt) tid
      , [ scalarParameter scalarType tid ]
      )


-- Global function definitions
-- ---------------------------

-- | Kernel metadata for the 'Native' target. The single constructor wraps a
-- unit value, so it records no additional information.
data instance KernelMetadata Native = KM_Native ()

-- | Combine kernels into a single program
--
-- The kernel list of the left operand is followed by the kernel list of the
-- right operand in the result.
--
(+++) :: IROpenAcc Native aenv a -> IROpenAcc Native aenv a -> IROpenAcc Native aenv a
IROpenAcc k1 +++ IROpenAcc k2 = IROpenAcc (k1 ++ k2)

-- | Create a single kernel program
--
-- The kernel is built with 'makeKernel'. Its label is the given @name@ with
-- an underscore and the 'show'n @uid@ appended (i.e. @name_uid@). The result
-- is a program containing exactly that one kernel.
--
makeOpenAcc :: UID -> Label -> [LLVM.Parameter] -> CodeGen Native () -> CodeGen Native (IROpenAcc Native aenv a)
makeOpenAcc uid name param kernel = do
  body  <- makeKernel (name <> fromString (printf "_%s" (show uid))) param kernel
  return $ IROpenAcc [body]

-- | Create a complete kernel function by running the code generation process
-- specified in the final parameter.
--
-- The supplied code generator is run first for its effects; the basic blocks
-- it produced are then collected with 'createBlocks'. The resulting global
-- function definition starts from 'LLVM.functionDefaults' and overrides:
--
--   * the return type, which is void;
--
--   * the name, obtained by downcasting the given label;
--
--   * the parameters, as supplied, paired with @False@ (in llvm-hs the second
--     component of this field is the var-args flag); and
--
--   * the basic blocks, as collected above.
--
makeKernel :: Label -> [LLVM.Parameter] -> CodeGen Native () -> CodeGen Native (Kernel Native aenv a)
makeKernel name param kernel = do
  _    <- kernel
  code <- createBlocks
  return $ Kernel
    { kernelMetadata = KM_Native ()
    , unKernel       = LLVM.functionDefaults
                     { LLVM.returnType  = LLVM.VoidType
                     , LLVM.name        = downcast name
                     , LLVM.parameters  = (param, False)
                     , LLVM.basicBlocks = code
                     }
    }