{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE UnboxedTuples #-}

-- Raw foreign imports for the XGBoost C API together with thin IO wrappers
-- that marshal arguments and turn non-zero return codes into exceptions
-- (see 'guard_ffi').
module ML.DMLC.XGBoost.FFI where

import Foundation
import Foundation.Array.Internal
import Foundation.Class.Storable
import Foundation.Foreign
import Foundation.Primitive

import qualified Prelude (Show(..))

import Data.Bits ((.|.))
import qualified Foreign.Storable
import Foreign.Marshal.Alloc (alloca)
import GHC.Exts

import ML.DMLC.XGBoost.Exception
import ML.DMLC.XGBoost.Foreign

{-- Handle Types -------------------------------------------------------------}

-- | Opaque handle to an XGBoost data matrix; a newtype over @Ptr ()@.
--
-- Both foundation's 'Storable' and base's 'Foreign.Storable.Storable' are
-- derived: the former is used by foundation's 'peek', the latter by
-- 'alloca' from "Foreign.Marshal.Alloc".
newtype DMatrix = DMatrix (Ptr ())
    deriving (Eq, Storable, Foreign.Storable.Storable)

{--
Instances to make `DMatrix` (`Ptr ()`) as foundation's PrimType,
GeneralizedNewtypeDeriving doesn't work here.
--}
instance PrimType DMatrix where
    -- A DMatrix occupies exactly as many bytes as a @Ptr ()@.
    primSizeInBytes _ = size (Proxy :: Proxy (Ptr ()))
    {-# INLINE primSizeInBytes #-}

    -- The shift is log2 of the element size in bytes, so that
    -- @index `shiftL` primShiftToBytes@ is the byte offset of an element:
    -- 2 for a 4-byte pointer, 3 for an 8-byte pointer. (The previous values,
    -- 3 and 5, scaled offsets by 8 and 32 respectively.)
    primShiftToBytes _ =
        let (CountOf k) = size (Proxy :: Proxy (Ptr ()))
         in if k == 4 then 2 else 3
    {-# INLINE primShiftToBytes #-}

    -- Read the n-th address-sized element of an immutable byte array.
    primBaUIndex ba (Offset (I# n)) = DMatrix (Ptr (indexAddrArray# ba n))
    {-# INLINE primBaUIndex #-}

    -- Read the n-th address-sized element of a mutable byte array.
    primMbaURead mba (Offset (I# n)) = primitive $ \s1 ->
        let !(# s2, r1 #) = readAddrArray# mba n s1
         in (# s2, DMatrix (Ptr r1) #)
    {-# INLINE primMbaURead #-}

    -- Write the wrapped address as the n-th element of a mutable byte array.
    primMbaUWrite mba (Offset (I# n)) (DMatrix (Ptr w)) = primitive $ \s1 ->
        (# writeAddrArray# mba n w s1, () #)
    {-# INLINE primMbaUWrite #-}

    -- Same three operations, but relative to a raw address.
    primAddrIndex addr (Offset (I# n)) = DMatrix (Ptr (indexAddrOffAddr# addr n))
    {-# INLINE primAddrIndex #-}

    primAddrRead addr (Offset (I# n)) = primitive $ \s1 ->
        let !(# s2, r1 #) = readAddrOffAddr# addr n s1
         in (# s2, DMatrix (Ptr r1) #)
    {-# INLINE primAddrRead #-}

    primAddrWrite addr (Offset (I# n)) (DMatrix (Ptr w)) = primitive $ \s1 ->
        (# writeAddrOffAddr# addr n w s1, () #)
    {-# INLINE primAddrWrite #-}

-- | Pointer to the first element of a contiguous array of 'DMatrix' handles.
type DMatrixArray = Ptr DMatrix

-- | Opaque handle to an XGBoost booster.
newtype Booster = Booster (Ptr ())
    deriving (Storable, Foreign.Storable.Storable)

-- | Opaque handle to a data iterator.
newtype DataIter = DataIter (Ptr ())
    deriving (Storable, Foreign.Storable.Storable)

-- | Opaque handle to a data holder.
newtype DataHolder = DataHolder (Ptr ())
    deriving (Storable, Foreign.Storable.Storable)

-- | Opaque handle to a CSR batch.
newtype XGBoostBatchCSR = XGBoostBatchCSR (Ptr ())
    deriving (Storable, Foreign.Storable.Storable)

-- Placeholder: the callback is currently passed as an untyped pointer.
type XGBCallbackSetData = Ptr () -- TODO
-- type XGBCallbackSetData =
--        FunPtr (DataHolder
--            -> XGBoostBatchCSR
--            -> IO Int)

-- Placeholder: the callback is currently passed as an untyped pointer.
type XGBCallbackDataIterNext = Ptr () -- TODO
-- type XGBCallbackDataIterNext =
--        FunPtr (DataIter
--            -> XGBCallbackSetData
--            -> DataHolderHandle
--            -> IO Int)

{-- Foreign Imports ----------------------------------------------------------}

-- Every import below is an @unsafe@ ccall. All of them except
-- 'c_xgbGetLastError' return an 'Int32' status code, which the wrappers in
-- this module check with 'guard_ffi'.

foreign import ccall unsafe "XGBGetLastError" c_xgbGetLastError
  :: IO StringPtr

foreign import ccall unsafe "XGDMatrixCreateFromFile" c_xgDMatrixCreateFromFile
  :: StringPtr
  -> Int32
  -> Ptr DMatrix
  -> IO Int32

foreign import ccall unsafe "XGDMatrixCreateFromDataIter" c_xgDMatrixCreateFromDataIter
  :: DataIter
  -> XGBCallbackDataIterNext
  -> StringPtr
  -> Ptr DMatrix
  -> IO Int32

foreign import ccall unsafe "XGDMatrixCreateFromMat" c_xgDMatrixCreateFromMat
  :: FloatArray
  -> Word64
  -> Word64
  -> CFloat
  -> Ptr DMatrix
  -> IO Int32

foreign import ccall unsafe "XGDMatrixFree" c_xgDMatrixFree
  :: DMatrix
  -> IO Int32

foreign import ccall unsafe "XGDMatrixSaveBinary" c_xgDMatrixSaveBinary
  :: DMatrix
  -> StringPtr
  -> Int32
  -> IO Int32

foreign import ccall unsafe "XGDMatrixSetFloatInfo" c_xgDMatrixSetFloatInfo
  :: DMatrix
  -> StringPtr
  -> FloatArray
  -> Word64
  -> IO Int32

foreign import ccall unsafe "XGDMatrixSetUIntInfo" c_xgDMatrixSetUIntInfo
  :: DMatrix
  -> StringPtr
  -> UIntArray
  -> Word64
  -> IO Int32

foreign import ccall unsafe "XGDMatrixGetFloatInfo" c_xgDMatrixGetFloatInfo
  :: DMatrix
  -> StringPtr
  -> Ptr Word64
  -> Ptr FloatArray
  -> IO Int32

foreign import ccall unsafe "XGDMatrixGetUIntInfo" c_xgDMatrixGetUIntInfo
  :: DMatrix
  -> StringPtr
  -> Ptr Word64
  -> Ptr UIntArray
  -> IO Int32

foreign import ccall unsafe "XGDMatrixNumRow" c_xgDMatrixNumRow
  :: DMatrix
  -> Ptr Word64
  -> IO Int32

foreign import ccall unsafe "XGDMatrixNumCol" c_xgDMatrixNumCol
  :: DMatrix
  -> Ptr Word64
  -> IO Int32

foreign import ccall unsafe "XGBoosterCreate" c_xgBoosterCreate
  :: DMatrixArray
  -> Word64
  -> Ptr Booster
  -> IO Int32

foreign import ccall unsafe "XGBoosterFree" c_xgBoosterFree
  :: Booster
  -> IO Int32

foreign import ccall unsafe "XGBoosterSetParam" c_xgBoosterSetParam
  :: Booster
  -> StringPtr
  -> StringPtr
  -> IO Int32

foreign import ccall unsafe "XGBoosterUpdateOneIter" c_xgBoosterUpdateOneIter
  :: Booster
  -> Int32
  -> DMatrix
  -> IO Int32

foreign import ccall unsafe "XGBoosterBoostOneIter" c_xgBoosterBoostOneIter
  :: Booster
  -> DMatrix
  -> FloatArray
  -> FloatArray
  -> Word64
  -> IO Int32

foreign import ccall unsafe "XGBoosterEvalOneIter" c_xgBoosterEvalOneIter
  :: Booster
  -> Int32
  -> DMatrixArray
  -> StringArray
  -> Word64
  -> Ptr StringPtr
  -> IO Int32

foreign import ccall unsafe "XGBoosterPredict" c_xgBoosterPredict
  :: Booster
  -> DMatrix
  -> Int32
  -> Int32
  -> Ptr Word64
  -> Ptr FloatArray
  -> IO Int32

foreign import ccall unsafe "XGBoosterLoadModel" c_xgBoosterLoadModel
  :: Booster
  -> StringPtr
  -> IO Int32

foreign import ccall unsafe "XGBoosterSaveModel" c_xgBoosterSaveModel
  :: Booster
  -> StringPtr
  -> IO Int32

foreign import ccall unsafe "XGBoosterLoadModelFromBuffer" c_xgBoosterLoadModelFromBuffer
  :: Booster
  -> ByteArray
  -> Word64
  -> IO Int32

foreign import ccall unsafe "XGBoosterGetModelRaw" c_xgBoosterGetModelRaw
  :: Booster
  -> Ptr Word64
  -> Ptr ByteArray
  -> IO Int32

foreign import ccall unsafe "XGBoosterDumpModel" c_xgBoosterDumpModel
  :: Booster
  -> StringArray
  -> Int32
  -> Ptr Word64
  -> Ptr StringArray
  -> IO Int32

foreign import ccall unsafe "XGBoosterGetAttr" c_xgBoosterGetAttr
  :: Booster
  -> StringPtr
  -> Ptr StringPtr -- TODO
  -> Ptr Int32
  -> IO Int32

foreign import ccall unsafe "XGBoosterSetAttr" c_xgBoosterSetAttr
  :: Booster
  -> StringPtr
  -> StringPtr
  -> IO Int32

foreign import ccall unsafe "XGBoosterGetAttrNames" c_xgBoosterGetAttrNames
  :: Booster
  -> Ptr Word64
  -> Ptr StringArray
  -> IO Int32

foreign import ccall unsafe "XGBoosterLoadRabitCheckpoint" c_xgBoosterLoadRabitCheckpoint
  :: Booster
  -> Ptr Int32
  -> IO Int32

foreign import ccall unsafe "XGBoosterSaveRabitCheckpoint" c_xgBoosterSaveRabitCheckpoint
  :: Booster
  -> IO Int32

{-- Tag Types ----------------------------------------------------------------}

-- | In XGBoost, the float info is correctly restricted to DMatrix's meta
-- information, namely label and weight.
--
-- The 'Prelude.Show' instance yields the field name that is handed to the
-- C API (@"label"@, @"weight"@, @"base_margin"@).
--
-- Ref: /https://github.com/dmlc/xgboost/issues/1026#issuecomment-199873890/.
data FloatInfoField = LabelInfo | WeightInfo | BaseMarginInfo
    deriving Eq

instance Prelude.Show FloatInfoField where
    show LabelInfo      = "label"
    show WeightInfo     = "weight"
    show BaseMarginInfo = "base_margin"

-- | In XGBoost, the only uint field valid is "root_index".
--
-- Ref: /https://github.com/dmlc/xgboost/issues/1787#issuecomment-261653748/.
data UIntInfoField = RootIndexInfo
    deriving Eq

instance Prelude.Show UIntInfoField where
    show RootIndexInfo = "root_index"

-- | Prediction options; 'fromEnum' gives the bit value of each option
-- (0, 1, 2, 4) so several options can be OR-ed into one mask.
--
-- See https://github.com/dmlc/xgboost/blob/master/include/xgboost/c_api.h#L399
data PredictMask = Normal | Margin | LeafIndex | FeatureContrib

instance Enum PredictMask where
    toEnum 0 = Normal
    toEnum 1 = Margin
    toEnum 2 = LeafIndex
    toEnum 4 = FeatureContrib
    -- Partial by design: any other integer is not a single option.
    toEnum _ = error "No such PredictMask"

    fromEnum Normal         = 0
    fromEnum Margin         = 1
    fromEnum LeafIndex      = 2
    fromEnum FeatureContrib = 4

{-- Error Handling -----------------------------------------------------------}

-- | Run a foreign call and check its status code: @0@ means success, any
-- other value throws an 'XGBError' carrying the code and the message
-- obtained from 'xgbGetLastError'.
guard_ffi :: IO Int32 -> IO ()
guard_ffi action = action >>= \r ->
    if r == 0
        then return ()
        else xgbGetLastError >>= throw . XGBError r

-- | Fetch the last error message reported by the XGBoost library.
xgbGetLastError :: IO String
xgbGetLastError = c_xgbGetLastError >>= getString

{-- DMatrix FFI Bindings -----------------------------------------------------}

-- | Create a 'DMatrix' from a file via @XGDMatrixCreateFromFile@.
xgbFromFile
    :: String -- ^ file name
    -> Bool   -- ^ print messages during loading
    -> IO DMatrix
xgbFromFile filename silent = alloca $ \pm ->
    withString filename $ \ps -> do
        guard_ffi $ c_xgDMatrixCreateFromFile ps (boolToInt32 silent) pm
        peek pm

-- | Create a 'DMatrix' from a data iterator via
-- @XGDMatrixCreateFromDataIter@; the string is passed as the third argument
-- of the C call.
xgbFromDataIter
    :: DataIter
    -> XGBCallbackDataIterNext
    -> String
    -> IO DMatrix
xgbFromDataIter iter callback cacheinfo = alloca $ \pm ->
    withString cacheinfo $ \pc -> do
        guard_ffi $ c_xgDMatrixCreateFromDataIter iter callback pc pm
        peek pm

-- | Create a 'DMatrix' from a flat float array via @XGDMatrixCreateFromMat@.
--
-- NOTE(review): the row and column counts are forwarded as given and are not
-- checked against the length of the array here.
xgbFromMat
    :: UArray Float -- ^ mat
    -> Int          -- ^ rows
    -> Int          -- ^ columns
    -> Float        -- ^ missing value
    -> IO DMatrix
xgbFromMat arr r c missing = alloca $ \pm ->
    withPtr arr $ \parr -> do
        guard_ffi $
            c_xgDMatrixCreateFromMat parr (fromIntegral r) (fromIntegral c) (CFloat missing) pm
        peek pm

-- | Release a 'DMatrix' via @XGDMatrixFree@.
dmatrixFree :: DMatrix -> IO ()
dmatrixFree = guard_ffi . c_xgDMatrixFree

-- | Set a float info vector on a 'DMatrix'; the field name sent to the C API
-- is @show field@.
xgbSetFloatInfo
    :: DMatrix
    -> FloatInfoField -- ^ label field
    -> UArray Float   -- ^ info vector
    -> IO ()
xgbSetFloatInfo dm field value = do
    let (CountOf len) = length value
    withString (show field) $ \ps ->
        withPtr value $ \pv ->
            guard_ffi $ c_xgDMatrixSetFloatInfo dm ps pv (fromIntegral len)

-- | Read a float info vector from a 'DMatrix'. The length and data pointer
-- written by the C call are read back and the elements are collected with
-- 'peekArray'.
xgbGetFloatInfo
    :: DMatrix
    -> FloatInfoField    -- ^ label field
    -> IO (UArray Float) -- ^ info vector
xgbGetFloatInfo dm field = alloca $ \plen ->
    alloca $ \parr ->
        withString (show field) $ \ps -> do
            guard_ffi $ c_xgDMatrixGetFloatInfo dm ps plen parr
            len <- peek plen
            arr <- peek parr
            peekArray (CountOf (fromIntegral len)) arr

-- | 'xgbGetFloatInfo' specialised to 'LabelInfo'.
xgbGetLabel :: DMatrix -> IO (UArray Float)
xgbGetLabel mat = xgbGetFloatInfo mat LabelInfo

-- | 'xgbGetFloatInfo' specialised to 'WeightInfo'.
xgbGetWeight :: DMatrix -> IO (UArray Float)
xgbGetWeight mat = xgbGetFloatInfo mat WeightInfo

-- | 'xgbGetFloatInfo' specialised to 'BaseMarginInfo'.
xgbGetBaseMargin :: DMatrix -> IO (UArray Float)
xgbGetBaseMargin mat = xgbGetFloatInfo mat BaseMarginInfo

-- | 'xgbSetFloatInfo' specialised to 'LabelInfo'.
xgbSetLabel :: DMatrix -> UArray Float -> IO ()
xgbSetLabel mat = xgbSetFloatInfo mat LabelInfo

-- | 'xgbSetFloatInfo' specialised to 'WeightInfo'.
xgbSetWeight :: DMatrix -> UArray Float -> IO ()
xgbSetWeight mat = xgbSetFloatInfo mat WeightInfo

-- | 'xgbSetFloatInfo' specialised to 'BaseMarginInfo'.
xgbSetBaseMargin :: DMatrix -> UArray Float -> IO ()
xgbSetBaseMargin mat = xgbSetFloatInfo mat BaseMarginInfo

-- | Set an unsigned-integer info vector on a 'DMatrix'; the field name sent
-- to the C API is @show field@.
xgbSetUIntInfo
    :: DMatrix
    -> UIntInfoField -- ^ label field
    -> UArray Word32 -- ^ info vector
    -> IO ()
xgbSetUIntInfo dm field value = do
    let (CountOf len) = length value
    withString (show field) $ \ps ->
        withPtr value $ \pv ->
            guard_ffi $ c_xgDMatrixSetUIntInfo dm ps pv (fromIntegral len)

-- | Read an unsigned-integer info vector from a 'DMatrix'.
xgbGetUIntInfo
    :: DMatrix
    -> UIntInfoField      -- ^ label field
    -> IO (UArray Word32) -- ^ info vector
xgbGetUIntInfo dm field = alloca $ \plen ->
    alloca $ \parr ->
        withString (show field) $ \ps -> do
            guard_ffi $ c_xgDMatrixGetUIntInfo dm ps plen parr
            len <- peek plen
            arr <- peek parr
            peekArray (CountOf (fromIntegral len)) arr

-- | Number of rows, as reported by @XGDMatrixNumRow@.
xgbMatRow :: DMatrix -> IO Integer
xgbMatRow dm = alloca $ \pnum -> do
    guard_ffi $ c_xgDMatrixNumRow dm pnum
    fromIntegral <$> peek pnum

-- | Number of columns, as reported by @XGDMatrixNumCol@.
xgbMatCol :: DMatrix -> IO Integer
xgbMatCol dm = alloca $ \pnum -> do
    guard_ffi $ c_xgDMatrixNumCol dm pnum
    fromIntegral <$> peek pnum

{-- Booster FFI Bindings -----------------------------------------------------}

-- | Create a 'Booster' from a list of data matrices via @XGBoosterCreate@.
-- The list is first packed into a contiguous array of handles.
xgbBooster :: [DMatrix] -> IO Booster
xgbBooster dms = alloca $ \pb -> do
    let dms' = fromList dms
        (CountOf len) = length dms'
    withPtr dms' $ \pdm -> do
        guard_ffi $ c_xgBoosterCreate pdm (fromIntegral len) pb
        peek pb

-- | Release a 'Booster' via @XGBoosterFree@.
boosterFree :: Booster -> IO ()
boosterFree = guard_ffi . c_xgBoosterFree

-- | Set a booster parameter via @XGBoosterSetParam@.
setParam
    :: Booster
    -> String -- ^ Name of parameter.
    -> String -- ^ Value of parameter.
    -> IO ()
setParam booster name value =
    withString name $ \pname ->
        withString value $ \pvalue ->
            guard_ffi $ c_xgBoosterSetParam booster pname pvalue

-- | Run one update step via @XGBoosterUpdateOneIter@.
updateOneIter
    :: Booster
    -> Int32   -- ^ Current iteration rounds
    -> DMatrix -- ^ Training data
    -> IO ()
updateOneIter booster iter dtrain =
    guard_ffi $ c_xgBoosterUpdateOneIter booster iter dtrain

-- | Run one boosting step with caller-supplied gradients via
-- @XGBoosterBoostOneIter@. Only the length of the gradient array is passed
-- to the C call.
boostOneIter
    :: Booster
    -> DMatrix      -- ^ Training data
    -> UArray Float -- ^ Gradient statistics
    -> UArray Float -- ^ Second order gradient statistics, should have the same length with gradient statistics array (but not checked)
    -> IO ()
boostOneIter booster dtrain grad hess = do
    let (CountOf nlen) = length grad
    withPtr grad $ \pgrad ->
        withPtr hess $ \phess ->
            guard_ffi $ c_xgBoosterBoostOneIter booster dtrain pgrad phess (fromIntegral nlen)

-- | Evaluate the booster on several data sets via @XGBoosterEvalOneIter@.
-- Only the length of the data list is passed to the C call.
evalOneIter
    :: Booster
    -> Int32     -- ^ Current iteration rounds
    -> [DMatrix] -- ^ Pointers to data to be evaluated
    -> [String]  -- ^ Names of each data, should have the same length with data array (but not checked)
    -> IO String -- ^ The string containing evaluation statistics
evalOneIter booster iter dms names = do
    let dms' = fromList dms
        (CountOf nlen) = length dms'
    alloca $ \pstat ->
        withPtr dms' $ \pdms ->
            withStringArray names $ \pnames -> do
                guard_ffi $
                    c_xgBoosterEvalOneIter booster iter pdms pnames (fromIntegral nlen) pstat
                peek pstat >>= getString

-- | Predict via @XGBoosterPredict@.
--
-- The masks are OR-ed together (starting from 'Normal', i.e. 0) into the
-- third argument of the C call; the 'Int32' is forwarded unchanged as the
-- fourth argument (presumably the tree limit -- confirm against c_api.h).
boosterPredict
    :: Booster
    -> DMatrix
    -> [PredictMask]
    -> Int32
    -> IO (UArray Float)
boosterPredict booster dmat masks ntree = do
    let mask = fromIntegral $ foldl' (.|.) (fromEnum Normal) (fromEnum <$> masks)
    alloca $ \plen ->
        alloca $ \parr -> do
            guard_ffi $ c_xgBoosterPredict booster dmat mask ntree plen parr
            len <- peek plen
            arr <- peek parr
            peekArray (CountOf (fromIntegral len)) arr

-- | Load a model from a file via @XGBoosterLoadModel@.
loadModel
    :: Booster
    -> String -- ^ File name
    -> IO ()
loadModel booster fname =
    withString fname $ \pfname ->
        guard_ffi $ c_xgBoosterLoadModel booster pfname

-- | Save the model to a file via @XGBoosterSaveModel@.
saveModel
    :: Booster
    -> String -- ^ File name
    -> IO ()
saveModel booster fname =
    withString fname $ \pfname ->
        guard_ffi $ c_xgBoosterSaveModel booster pfname

-- | Load a model from an in-memory buffer via
-- @XGBoosterLoadModelFromBuffer@.
--
-- NOTE(review): the length is converted with 'fromIntegral' from 'Int32' to
-- 'Word64'; callers should pass a non-negative length.
loadModelFromBuffer
    :: Booster
    -> ByteArray -- ^ Pointer to buffer
    -> Int32
    -> IO ()
loadModelFromBuffer booster buffer nlen =
    guard_ffi $ c_xgBoosterLoadModelFromBuffer booster buffer (fromIntegral nlen)

-- | Look up a booster attribute via @XGBoosterGetAttr@. Returns the empty
-- string when the success flag written by the C call is false.
getBoosterAttr :: Booster -> String -> IO String
getBoosterAttr booster name = alloca $ \pout ->
    alloca $ \psucc ->
        withString name $ \pname -> do
            guard_ffi $ c_xgBoosterGetAttr booster pname pout psucc
            succ' <- peek psucc
            if int32ToBool succ'
                then peek pout >>= getString
                else return ""

-- | Set a booster attribute via @XGBoosterSetAttr@.
setBoosterAttr :: Booster -> String -> String -> IO ()
setBoosterAttr booster name value =
    withString name $ \pname ->
        withString value $ \pvalue ->
            guard_ffi $ c_xgBoosterSetAttr booster pname pvalue

-- | List the attribute names of a booster via @XGBoosterGetAttrNames@.
getAttrNames :: Booster -> IO [String]
getAttrNames booster = alloca $ \plen ->
    alloca $ \pout -> do
        guard_ffi $ c_xgBoosterGetAttrNames booster plen pout
        nlen <- peek plen
        peek pout >>= getStringArray' (CountOf (fromIntegral nlen))

-- | Load a Rabit checkpoint via @XGBoosterLoadRabitCheckpoint@.
loadRabitCheckpoint
    :: Booster
    -> IO Int32 -- ^ Return output version of the model
loadRabitCheckpoint booster = alloca $ \pversion -> do
    guard_ffi $ c_xgBoosterLoadRabitCheckpoint booster pversion
    peek pversion

-- | Save a Rabit checkpoint via @XGBoosterSaveRabitCheckpoint@.
saveRabitCheckpoint :: Booster -> IO ()
saveRabitCheckpoint booster =
    guard_ffi $ c_xgBoosterSaveRabitCheckpoint booster