{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice (
withLibdeviceNVVM,
withLibdeviceNVPTX,
) where
import LLVM.Context
import qualified LLVM.Module as LLVM
import LLVM.AST as AST
import LLVM.AST.Global as G
import LLVM.AST.Linkage
import Data.Array.Accelerate.LLVM.PTX.Compile.Libdevice.Load
import qualified Data.Array.Accelerate.LLVM.PTX.Debug as Debug
import Foreign.CUDA.Analysis
import Control.Monad
import Data.ByteString ( ByteString )
import Data.ByteString.Short.Char8 ( ShortByteString )
import Data.HashSet ( HashSet )
import Data.List
import Data.Maybe
import Text.Printf
import qualified Data.ByteString.Short.Char8 as S8
import qualified Data.ByteString.Short.Extra as BS
import qualified Data.HashSet as Set
-- | Lower the given module AST into an LLVM module within the supplied
-- context and hand it to the continuation.
--
-- If 'analyse' reports that the AST declares any @__nv_@ functions, the
-- 'nvvmReflect' module and an internalised copy of libdevice (selected by the
-- device's compute capability) are lowered as well and linked into the main
-- module before the continuation runs. A debug message listing the required
-- functions is emitted under the 'Debug.dump_cc' flag after linking.
--
-- When no such functions are declared the AST is lowered and passed on as-is.
--
withLibdeviceNVPTX
    :: DeviceProperties
    -> Context
    -> Module
    -> (LLVM.Module -> IO a)
    -> IO a
withLibdeviceNVPTX dev ctx ast next
  | Set.null required = LLVM.withModuleFromAST ctx ast next
  | otherwise         =
      LLVM.withModuleFromAST ctx ast $ \target ->
        LLVM.withModuleFromAST ctx nvvmReflect $ \reflect ->
          LLVM.withModuleFromAST ctx (internalise required retargeted) $ \support -> do
            -- The main module is the first argument of both link calls.
            LLVM.linkModules target reflect
            LLVM.linkModules target support
            Debug.traceIO Debug.dump_cc message
            next target
  where
    -- Names of the libdevice functions referenced by the input module.
    required   = analyse ast

    -- The libdevice module, with its target triple and data layout replaced
    -- by those of the input module (presumably so that linking does not see
    -- two differently-targeted modules -- TODO confirm).
    retargeted = (libdevice (computeCapability dev))
                   { moduleTargetTriple = moduleTargetTriple ast
                   , moduleDataLayout   = moduleDataLayout ast
                   }

    message    = printf "cc: linking with libdevice: %s"
                        ((intercalate ", " . map S8.unpack . Set.toList) required)
-- | Lower the given module AST into an LLVM module within the supplied
-- context and pass it to the continuation, together with the list of extra
-- libraries the module requires.
--
-- Nothing is linked here. If 'analyse' reports that the AST declares any
-- @__nv_@ functions, the continuation receives 'nvvmReflect' and the
-- libdevice for the device's compute capability, and a debug message listing
-- those functions is emitted under the 'Debug.dump_cc' flag. Otherwise the
-- continuation receives an empty list and no message is emitted.
--
withLibdeviceNVVM
    :: DeviceProperties
    -> Context
    -> Module
    -> ([(String, ByteString)] -> LLVM.Module -> IO a)
    -> IO a
withLibdeviceNVVM dev ctx ast next =
  LLVM.withModuleFromAST ctx ast $ \target -> do
    unless unused $
      Debug.traceIO Debug.dump_cc message
    next extras target
  where
    -- Names of the libdevice functions referenced by the input module.
    required = analyse ast
    unused   = Set.null required

    -- Libraries handed to the continuation.
    extras
      | unused    = []
      | otherwise = [ nvvmReflect, libdevice (computeCapability dev) ]

    message  = printf "cc: linking with libdevice: %s"
                      ((intercalate ", " . map S8.unpack . Set.toList) required)
-- | Collect the names of all functions in the module that are declared but
-- not defined (they have no basic blocks) and whose name starts with the
-- prefix @__nv_@. Only functions identified by a 'Name' are considered.
--
analyse :: Module -> HashSet ShortByteString
analyse mdl =
  Set.fromList
    [ fn
    | GlobalDefinition Function{ G.basicBlocks = [], G.name = Name fn }
        <- moduleDefinitions mdl          -- non-matching definitions are skipped
    , BS.take 5 fn == "__nv_"
    ]
-- | Give internal linkage to every function in the module that has a body
-- and whose name is not a member of the given set. All other definitions,
-- and every other field of the module, are returned unchanged.
--
internalise :: HashSet ShortByteString -> Module -> Module
internalise externals mdl =
  mdl { moduleDefinitions = map hide (moduleDefinitions mdl) }
  where
    hide (GlobalDefinition fn@Function{ G.name = Name ident, G.basicBlocks = body })
      -- Keep the linkage of anything listed in 'externals', and of bare
      -- declarations (functions without basic blocks).
      | not (Set.member ident externals || null body)
      = GlobalDefinition (fn { G.linkage = Internal })
    hide other
      = other