module ML.DMLC.XGBoost.Rabit.FFI where
import Foundation
import Foundation.Array.Internal
import Foundation.Class.Storable
import Foundation.Collection
import Foundation.Foreign
import Foreign.Ptr (nullPtr)
import Foreign.Marshal.Alloc (alloca)
import qualified Foreign.Storable (peek)
import ML.DMLC.XGBoost.Foreign
-- | Binding to the C entry point @RabitInit(argc, argv)@.
--
-- Imported @safe@ rather than @unsafe@. Initialisation presumably contacts
-- the Rabit tracker and can block (TODO confirm against the C side). While an
-- @unsafe@ foreign call is running, the GHC runtime cannot garbage-collect,
-- so every other Haskell thread that needs a GC stalls until it returns.
foreign import ccall safe "RabitInit" c_rabitInit
  :: Int32
  -> StringArray
  -> IO ()

-- | Binding to the C entry point @RabitFinalize()@.
--
-- Imported @safe@ for the same reason as 'c_rabitInit': shutdown presumably
-- notifies the tracker and may block (TODO confirm).
foreign import ccall safe "RabitFinalize" c_rabitFinalize
  :: IO ()
-- | Binding to @RabitGetRank()@: the rank of this worker, as a C int.
foreign import ccall unsafe "RabitGetRank"
  c_rabitGetRank :: IO Int32

-- | Binding to @RabitGetWorldSize()@: the number of workers, as a C int.
foreign import ccall unsafe "RabitGetWorldSize"
  c_rabitGetWorldSize :: IO Int32

-- | Binding to @RabitIsDistributed()@. The C int result is a boolean flag;
-- the Haskell wrapper converts it with 'int32ToBool'.
foreign import ccall unsafe "RabitIsDistributed"
  c_rabitIsDistributed :: IO Int32
-- | Binding to @RabitTrackerPrint(msg)@, which takes a C string.
--
-- Imported @safe@: the message is presumably forwarded to the tracker over
-- a socket (TODO confirm). A blocking @unsafe@ call would prevent the GHC
-- runtime from garbage-collecting until it returns.
foreign import ccall safe "RabitTrackerPrint" c_rabitTrackerPrint
  :: StringPtr
  -> IO ()
-- | Binding to @RabitGetProcessorName(out_name, out_len, max_len)@.
--
-- Arguments, in order:
--
-- * the output character buffer;
-- * an out-pointer that receives the length of the written name;
-- * the buffer capacity limit handed to C.
foreign import ccall unsafe "RabitGetProcessorName"
  c_rabitGetProcessorName :: StringPtr -> Ptr CULong -> CULong -> IO ()
-- | Binding to @RabitBroadcast(sendrecv_data, size, root)@.
--
-- Imported @safe@ rather than @unsafe@. Broadcast is a collective operation
-- that waits for the other workers. While an @unsafe@ foreign call is
-- blocked, the GHC runtime cannot garbage-collect, which stalls every other
-- Haskell thread. Because GC may run during a @safe@ call, the pointer must
-- refer to pinned or foreign memory.
foreign import ccall safe "RabitBroadcast" c_rabitBroadcast
  :: Ptr a
  -> CULong
  -> Int32
  -> IO ()

-- | Binding to @RabitAllreduce(sendrecvbuf, count, enum_dtype, enum_op,
-- prepare_fun, prepare_arg)@.
--
-- Imported @safe@ for the same reason as 'c_rabitBroadcast': it is a blocking
-- collective operation.
--
-- NOTE(review): @prepare_fun@ is a C function pointer, so @FunPtr@ would be
-- the precise type. The wrapper in this module only passes 'nullPtr'.
foreign import ccall safe "RabitAllreduce" c_rabitAllreduce
  :: Ptr a
  -> CSize
  -> Int32
  -> Int32
  -> Ptr ()
  -> Ptr ()
  -> IO ()
-- | Binding to @RabitLoadCheckPoint(out_global_model, out_global_len,
-- out_local_model, out_local_len)@.
--
-- All four arguments are out-pointers. The C int result is returned
-- unchanged; presumably it is the checkpoint version number (TODO confirm).
--
-- Imported @safe@: loading a checkpoint may involve recovery traffic with
-- other workers (TODO confirm). A blocking @unsafe@ call would keep the GHC
-- runtime from garbage-collecting until it returns.
foreign import ccall safe "RabitLoadCheckPoint" c_rabitLoadCheckPoint
  :: Ptr StringPtr
  -> Ptr CULong
  -> Ptr StringPtr
  -> Ptr CULong
  -> IO Int32

-- | Binding to @RabitCheckPoint(global_model, global_len, local_model,
-- local_len)@.
--
-- Each model is passed as a buffer plus its length. Imported @safe@ for the
-- same reason as 'c_rabitLoadCheckPoint'.
foreign import ccall safe "RabitCheckPoint" c_rabitCheckPoint
  :: StringPtr
  -> CULong
  -> StringPtr
  -> CULong
  -> IO ()
-- | Binding to @RabitVersionNumber()@, which returns a C int.
foreign import ccall unsafe "RabitVersionNumber"
  c_rabitVersionNumber :: IO Int32

-- | Binding to @RabitLinkTag()@, which returns a C int.
foreign import ccall unsafe "RabitLinkTag"
  c_rabitLinkTag :: IO Int32
-- | Initialise Rabit from command-line-style arguments.
--
-- The argument count given to @RabitInit@ is the number of elements in
-- @args@. The strings themselves are marshalled with 'withStringArray'.
rabitInit :: [String] -> IO ()
rabitInit args = withStringArray args (c_rabitInit argc)
  where
    CountOf n = length args
    argc = fromIntegral n
-- | Call the C @RabitFinalize@ entry point.
rabitFinalize :: IO ()
rabitFinalize = c_rabitFinalize

-- | Rank of this worker, as returned by @RabitGetRank@.
rabitGetRank :: IO Int32
rabitGetRank = c_rabitGetRank

-- | Number of workers, as returned by @RabitGetWorldSize@.
rabitGetWorldSize :: IO Int32
rabitGetWorldSize = c_rabitGetWorldSize

-- | Misspelled original name of 'rabitGetWorldSize'. It is kept so that
-- existing callers keep compiling; prefer 'rabitGetWorldSize'.
rabitGetWordSize :: IO Int32
rabitGetWordSize = rabitGetWorldSize

-- | Whether Rabit reports distributed mode. The C int flag returned by
-- @RabitIsDistributed@ is converted with 'int32ToBool'.
rabitIsDistributed :: IO Bool
rabitIsDistributed = int32ToBool <$> c_rabitIsDistributed
-- | Pass a message to @RabitTrackerPrint@. The message is marshalled to a C
-- string with 'withString'.
rabitTrackerPrint :: String -> IO ()
rabitTrackerPrint msg = withString msg c_rabitTrackerPrint
-- | Name of the processor this worker runs on, as reported by
-- @RabitGetProcessorName@.
--
-- At most @nlimit@ (64) characters are requested. Longer names are
-- presumably truncated by the C side (TODO confirm).
rabitGetProcessorName :: IO String
rabitGetProcessorName = do
  let nlimit = 64
  -- Allocate one element more than the limit given to C. Rabit's
  -- implementation copies a name of exactly @max_len@ characters together
  -- with its NUL terminator, which would overflow a buffer of exactly
  -- @max_len@ elements by one byte.
  buf <- mutNew (CountOf (nlimit + 1))
  alloca $ \plen ->
    withMutablePtr buf $ \pbuf -> do
      c_rabitGetProcessorName pbuf plen (fromIntegral nlimit)
      -- C reports how many characters it actually wrote.
      nlen <- Foreign.Storable.peek plen
      -- Read the result while still inside 'withMutablePtr', so @pbuf@ is
      -- still the buffer that C wrote into.
      getString' (CountOf (fromIntegral nlen)) pbuf
-- | Broadcast a buffer from the @root@ worker to all others via
-- @RabitBroadcast@.
--
-- The length is widened to the C @unsigned long@ with 'fromIntegral'.
-- A negative value would therefore wrap to a huge size, so callers must pass
-- a non-negative length. The unit is presumably bytes (TODO confirm against
-- the Rabit C API).
rabitBroadcast
  :: Ptr a  -- ^ buffer to send (on @root@) or receive into (elsewhere)
  -> Int32  -- ^ buffer length
  -> Int32  -- ^ rank of the broadcasting worker
  -> IO ()
rabitBroadcast pdata nlen root =
  c_rabitBroadcast pdata (fromIntegral nlen) root

-- | Misspelled original name of 'rabitBroadcast'. It is kept so that existing
-- callers keep compiling; prefer 'rabitBroadcast'.
rabitBoradcast :: Ptr a -> Int32 -> Int32 -> IO ()
rabitBoradcast = rabitBroadcast
-- | Reduction operator used by 'rabitAllreduce'.
--
-- The 'fromEnum' code of each constructor is passed to @RabitAllreduce@ as
-- @enum_op@. The codes must therefore match Rabit's operator enumeration
-- (@kMax = 0@, @kMin = 1@, @kSum = 2@, @kBitwiseOR = 3@), and 'fromEnum'
-- must invert 'toEnum'.
data AllreduceOpType = KMax | KMin | KSum | KBitwiseOR deriving Eq

instance Enum AllreduceOpType where
  toEnum 0 = KMax
  toEnum 1 = KMin
  toEnum 2 = KSum
  toEnum 3 = KBitwiseOR
  toEnum _ = error "No such AllreduceOpType"
  -- These codes mirror 'toEnum' exactly. Any gap or shift here silently
  -- makes the C side run a different reduction.
  fromEnum KMax = 0
  fromEnum KMin = 1
  fromEnum KSum = 2
  fromEnum KBitwiseOR = 3
-- | Element type of the buffer handed to 'rabitAllreduce'.
--
-- 'fromEnum' gives the integer code passed to @RabitAllreduce@ as
-- @enum_dtype@. The codes run from 0 to 9 in constructor order, and
-- 'toEnum' inverts them.
data AllreduceDataType
  = KChar
  | KUChar
  | KInt
  | KUInt
  | KLong
  | KULong
  | KFloat
  | KDouble
  | KLongLong
  | KULongLong
  deriving Eq

instance Enum AllreduceDataType where
  fromEnum dtype = case dtype of
    KChar      -> 0
    KUChar     -> 1
    KInt       -> 2
    KUInt      -> 3
    KLong      -> 4
    KULong     -> 5
    KFloat     -> 6
    KDouble    -> 7
    KLongLong  -> 8
    KULongLong -> 9

  toEnum code = case code of
    0 -> KChar
    1 -> KUChar
    2 -> KInt
    3 -> KUInt
    4 -> KLong
    5 -> KULong
    6 -> KFloat
    7 -> KDouble
    8 -> KLongLong
    9 -> KULongLong
    _ -> error "No such AllreduceDataType"
-- | In-place allreduce of @count@ elements of type @dtype@ using the
-- reduction @optype@, via @RabitAllreduce@.
--
-- The element type and operator are sent as their 'fromEnum' codes. No
-- prepare callback or callback argument is supplied (both are 'nullPtr').
-- @count@ is widened to 'CSize' with 'fromIntegral', so a negative value
-- would wrap around; callers must pass a non-negative count.
rabitAllreduce
  :: Ptr a
  -> Int32
  -> AllreduceDataType
  -> AllreduceOpType
  -> IO ()
rabitAllreduce pdata count dtype optype =
  c_rabitAllreduce pdata cCount cDtype cOp nullPtr nullPtr
  where
    cCount = fromIntegral count
    cDtype = fromIntegral (fromEnum dtype)
    cOp = fromIntegral (fromEnum optype)
-- Thin wrappers returning the C int results of @RabitVersionNumber@ and
-- @RabitLinkTag@ unchanged.
rabitVersionNumber, rabitLinkTag :: IO Int32
rabitVersionNumber = c_rabitVersionNumber
rabitLinkTag = c_rabitLinkTag