-- |
-- This module provides safe bindings to libsvm functions and structures, with memory managed implicitly.
module Data.SVM
  ( Vector,
    Problem,
    KernelType (..),
    Algorithm (..),
    ExtraParam (..),
    Model,
    train,
    train',
    crossValidate,
    crossValidate',
    loadModel,
    saveModel,
    predict,
    withPrintFn,
    CSvmPrintFn,
  )
where

import Control.Exception
import Control.Monad (when)
import Data.IntMap (IntMap, toList)
import qualified Data.IntMap as M
import Data.SVM.Raw
  ( CSvmModel,
    CSvmNode (..),
    CSvmParameter,
    CSvmPrintFn,
    CSvmProblem (..),
    c_clone_model_support_vectors,
    c_svm_check_parameter,
    c_svm_cross_validation,
    c_svm_destroy_model,
    c_svm_load_model,
    c_svm_predict,
    c_svm_save_model,
    c_svm_set_print_string_function,
    c_svm_train,
    createSvmPrintFnPtr,
    defaultCParam,
  )
import qualified Data.SVM.Raw as R
import Foreign.C.String (newCString, peekCString, withCString)
import Foreign.ForeignPtr
  ( ForeignPtr,
    newForeignPtr,
    withForeignPtr,
  )
import Foreign.Marshal.Alloc (alloca, free, malloc)
import Foreign.Marshal.Array
  ( allocaArray,
    newArray,
    newArray0,
    peekArray,
    withArray0,
  )
import Foreign.Ptr (Ptr, freeHaskellFunPtr, nullFunPtr, nullPtr)
import Foreign.Storable (peek, poke)

-- | Sparse feature vector backed by an 'IntMap': keys are feature indices
-- and values are feature values. Entries equal to 0 are not sent to libsvm.
type Vector = IntMap Double

-- | Training set: a list of @(label, vector)@ pairs. The label is the class
-- for classification algorithms or the target value for regression
-- algorithms; it is not restricted to 1.0 / -1.0.
type Problem = [(Double, Vector)]

-- | 'Model' is a wrapper over a foreign pointer to a trained 'CSvmModel'.
-- The C model is released by the pointer's finalizer when the 'Model' is
-- garbage-collected.
newtype Model = Model (ForeignPtr CSvmModel)

-- | Kernel function used by the SVM. Each constructor selects the libsvm
-- kernel of the same name, and its fields are forwarded to the libsvm
-- parameters of the same name.
data KernelType
  = -- | Linear kernel function, i.e. dot product
    Linear
  | -- | Gaussian radial basis function
    RBF
      { -- | Forwarded as libsvm's @gamma@
        gamma :: Double
      }
  | -- | Sigmoid kernel function
    Sigmoid
      { gamma :: Double,
        -- | Forwarded as libsvm's @coef0@
        coef0 :: Double
      }
  | -- | Inhomogeneous polynomial function
    Poly
      { gamma :: Double,
        coef0 :: Double,
        -- | Forwarded as libsvm's @degree@
        degree :: Int
      }

-- | SVM algorithm together with its hyper-parameters.
data Algorithm
  = -- | C-SVC classification
    CSvc
      { -- | Cost parameter @C@
        c :: Double
      }
  | -- | nu-SVC classification
    NuSvc
      { -- | The @nu@ parameter
        nu :: Double
      }
  | -- | nu-SVR regression
    NuSvr
      { nu :: Double,
        c :: Double
      }
  | -- | epsilon-SVR regression
    EpsilonSvr
      { -- | The epsilon of epsilon-SVR
        epsilon :: Double,
        c :: Double
      }
  | -- | One-class SVM
    OneClassSvm
      { nu :: Double
      }

-- | Implementation-level settings forwarded to libsvm alongside the
-- algorithm and kernel choice.
data ExtraParam = ExtraParam
  { -- | Kernel cache size, forwarded as libsvm's @cache_size@
    -- (libsvm documents this in MB — confirm for the libsvm version in use)
    cacheSize :: Double,
    -- | Forwarded as libsvm's @shrinking@ flag (presumably 0 = off, 1 = on)
    shrinking :: Int,
    -- | Forwarded as libsvm's @probability@ flag (presumably 0 = off, 1 = on)
    probability :: Int
  }

-- | Default extra parameters of the SVM implementation: a cache size of
-- 1000, @shrinking = 1@ and @probability = 0@.
defaultExtra :: ExtraParam
defaultExtra = ExtraParam 1000 1 0

-- | Write the kernel selection and its parameters into a libsvm parameter
-- record, leaving every unrelated field untouched.
mergeKernel :: KernelType -> CSvmParameter -> CSvmParameter
mergeKernel kernel param =
  case kernel of
    Linear ->
      param {R.kernel_type = R.linear}
    RBF g ->
      param {R.kernel_type = R.rbf, R.gamma = realToFrac g}
    Sigmoid g cf ->
      param
        { R.kernel_type = R.sigmoid,
          R.gamma = realToFrac g,
          R.coef0 = realToFrac cf
        }
    Poly g cf d ->
      param
        { R.kernel_type = R.poly,
          R.gamma = realToFrac g,
          R.coef0 = realToFrac cf,
          R.degree = fromIntegral d
        }

-- | Write the algorithm choice and its hyper-parameters into a libsvm
-- parameter record. Only the fields relevant to the chosen algorithm are
-- written; every other field keeps its value from the input record.
mergeAlgo :: Algorithm -> CSvmParameter -> CSvmParameter
mergeAlgo (CSvc cf) param =
  param {R.svm_type = R.cSvc, R.c = realToFrac cf}
mergeAlgo (NuSvc n) param =
  param {R.svm_type = R.nuSvc, R.nu = realToFrac n}
mergeAlgo (NuSvr n cf) param =
  param
    { R.svm_type = R.nuSvr,
      R.nu = realToFrac n,
      R.c = realToFrac cf
    }
mergeAlgo (EpsilonSvr e cf) param =
  param
    { R.svm_type = R.epsilonSvr,
      -- In libsvm the epsilon of epsilon-SVR's loss function is the @p@
      -- field. @eps@ is the unrelated stopping tolerance, which previously
      -- got overwritten here while @p@ kept its default.
      R.p = realToFrac e,
      R.c = realToFrac cf
    }
mergeAlgo (OneClassSvm n) param =
  param {R.svm_type = R.oneClass, R.nu = realToFrac n}

-- | Copy the implementation-level settings of an 'ExtraParam' into a libsvm
-- parameter record.
mergeExtra :: ExtraParam -> CSvmParameter -> CSvmParameter
mergeExtra extra param =
  param
    { R.cache_size = realToFrac (cacheSize extra),
      R.shrinking = fromIntegral (shrinking extra),
      R.probability = fromIntegral (probability extra)
    }

-------------------------------------------------------------------------------

-- | Turn a sparse 'Vector' into libsvm nodes in ascending key order (the
-- order 'toList' yields), skipping entries whose value is 0. The end marker
-- is not included.
--
-- Keys are narrowed to the node's C int index with 'fromIntegral', so keys
-- outside the C int range wrap around. A key of -1 would collide with
-- 'endMarker'.
convertToNodeArray :: Vector -> [CSvmNode]
convertToNodeArray vec =
  [ CSvmNode (fromIntegral idx) (realToFrac val)
    | (idx, val) <- toList vec,
      val /= 0
  ]

-- | Sentinel node that terminates every node array handed to libsvm:
-- index -1, value 0.
endMarker :: CSvmNode
endMarker = CSvmNode (negate 1) 0

-- | Allocate a C array holding the non-zero entries of a 'Vector', followed
-- by 'endMarker'. The caller owns the array and must 'free' it.
newCSvmNodeArray :: Vector -> IO (Ptr CSvmNode)
newCSvmNodeArray = newArray0 endMarker . convertToNodeArray

-- | Temporarily marshal a 'Vector' into an 'endMarker'-terminated node
-- array. The array is valid only while the callback runs.
withCSvmNodeArray :: Vector -> (Ptr CSvmNode -> IO a) -> IO a
withCSvmNodeArray = withArray0 endMarker . convertToNodeArray

-- | Marshal a 'Problem' into a freshly allocated libsvm problem struct.
--
-- This allocates one node array per example, an array of pointers to those
-- node arrays, an array of labels and the struct itself. Release all of it
-- with 'freeCSVmProblem'.
newCSvmProblem :: Problem -> IO (Ptr CSvmProblem)
newCSvmProblem lvs = do
  let (labels, vectors) = unzip lvs
      count = fromIntegral (length lvs)
  rows <- mapM newCSvmNodeArray vectors
  rowsPtr <- newArray rows
  labelsPtr <- newArray (map realToFrac labels)
  problemPtr <- malloc
  poke problemPtr (CSvmProblem count labelsPtr rowsPtr)
  return problemPtr

-- | Release everything 'newCSvmProblem' allocated: the label array, each
-- per-example node array, the array of node pointers, and the struct.
freeCSVmProblem :: Ptr CSvmProblem -> IO ()
freeCSVmProblem problemPtr = do
  CSvmProblem count labelsPtr rowsPtr <- peek problemPtr
  free labelsPtr
  rows <- peekArray (fromIntegral count) rowsPtr
  mapM_ free rows
  free rowsPtr
  free problemPtr

-- | Run an action with a marshalled libsvm problem. The problem is freed
-- afterwards, even if the action throws.
withProblem :: Problem -> (Ptr CSvmProblem -> IO a) -> IO a
withProblem prob action =
  bracket (newCSvmProblem prob) freeCSVmProblem action

---

-- | Build the libsvm parameter record from the caller's choices, starting
-- from 'defaultCParam', and pass a pointer to it to the callback. The
-- pointer is valid only while the callback runs.
withParam ::
  ExtraParam ->
  Algorithm ->
  KernelType ->
  (Ptr CSvmParameter -> IO a) ->
  IO a
withParam extra algo kern f =
  alloca $ \paramPtr -> do
    poke paramPtr (mergeAlgo algo (mergeKernel kern (mergeExtra extra defaultCParam)))
    f paramPtr

-- | Ask libsvm to validate the parameters against the problem. If libsvm
-- returns a non-null error string, throw it as an 'ErrorCall' prefixed
-- with @"svm: "@.
--
-- NOTE(review): 'c_svm_check_parameter' is bound as a pure function (its
-- result is used without an IO bind) even though it reads the structs
-- behind the pointers. Consider an IO import in "Data.SVM.Raw".
checkParam :: Ptr CSvmProblem -> Ptr CSvmParameter -> IO ()
checkParam probPtr paramPtr =
  case c_svm_check_parameter probPtr paramPtr of
    errPtr
      | errPtr == nullPtr -> return ()
      | otherwise -> do
          msg <- peekCString errPtr
          error ("svm: " ++ msg)

--

-- | Like 'train' but with extra parameters.
--
-- The parameters are validated with 'checkParam' before training. After
-- training, 'c_clone_model_support_vectors' runs before the marshalled
-- problem is freed, presumably so that the model stops referencing the
-- problem's node arrays (TODO confirm in "Data.SVM.Raw"). Its 'CInt' result
-- is ignored. The returned 'Model' is released by 'c_svm_destroy_model'
-- when it is garbage-collected.
train' :: ExtraParam -> Algorithm -> KernelType -> Problem -> IO Model
train' extra algo kern prob =
  withProblem prob $ \probPtr ->
    withParam extra algo kern $ \paramPtr -> do
      checkParam probPtr paramPtr
      rawModel <- c_svm_train probPtr paramPtr
      _ <- c_clone_model_support_vectors rawModel
      Model <$> newForeignPtr c_svm_destroy_model rawModel

-- | The 'train' function trains a 'Model' from a 'Problem', given an
-- 'Algorithm' and a 'KernelType'. It uses the default extra parameters
-- (cache size 1000, @shrinking = 1@, @probability = 0@); use 'train'' to
-- override them.
train :: Algorithm -> KernelType -> Problem -> IO Model
train algo kern prob = train' defaultExtra algo kern prob

-- | Like 'crossValidate' but with extra parameters.
--
-- Runs libsvm's cross validation with @nFold@ folds. Returns one 'Double'
-- per training example, read back from the target array that libsvm fills
-- (libsvm documents it as the prediction made for each example during
-- validation).
--
-- Throws an 'ErrorCall' if @nFold < 2@, because the fold count is
-- otherwise passed to libsvm unchecked. Also throws if 'checkParam'
-- rejects the parameters.
crossValidate' ::
  ExtraParam ->
  Algorithm ->
  KernelType ->
  Problem ->
  Int ->
  IO [Double]
crossValidate' extra algo kern prob nFold = do
  -- Reject degenerate fold counts before allocating anything.
  when (nFold < 2) $
    throwIO . ErrorCall $
      "svm: crossValidate': the number of folds must be at least 2, got "
        ++ show nFold
  withProblem prob $ \probPtr ->
    withParam extra algo kern $ \paramPtr -> do
      -- Read the example count back from the C struct instead of
      -- traversing 'prob' again with 'length'.
      probLen <- fromIntegral . R.l <$> peek probPtr
      allocaArray probLen $ \targetPtr -> do
        checkParam probPtr paramPtr
        c_svm_cross_validation probPtr paramPtr (fromIntegral nFold) targetPtr
        map realToFrac <$> peekArray probLen targetPtr

-- | Stratified n-fold cross validation using the default extra
-- parameters. See 'crossValidate'' to supply your own.
crossValidate :: Algorithm -> KernelType -> Problem -> Int -> IO [Double]
crossValidate algo kern prob nFold =
  crossValidate' defaultExtra algo kern prob nFold

-----------------------------------------------------------------------

-- | Save a 'Model' to the file at the given path.
--
-- Throws an 'ErrorCall' carrying libsvm's return code if saving fails.
saveModel :: Model -> FilePath -> IO ()
saveModel (Model modelForeignPtr) path =
  withForeignPtr modelForeignPtr $ \modelPtr ->
    -- 'withCString' frees the marshalled path once the call returns.
    -- A bare 'newCString' would leak it on every save.
    withCString path $ \pathString -> do
      ret <- c_svm_save_model pathString modelPtr
      when (ret /= 0) $
        error $ "svm: error saving the model:" ++ show ret

-- | Load a 'Model' from the file at the given path.
--
-- Throws an 'ErrorCall' if libsvm returns a null model pointer (presumably
-- because the file is missing or unreadable as a model). Wrapping a null
-- pointer would instead crash on first use.
loadModel :: FilePath -> IO Model
loadModel path = do
  -- 'withCString' frees the marshalled path after the call instead of
  -- leaking it.
  modelPtr <- withCString path c_svm_load_model
  when (modelPtr == nullPtr) $
    throwIO . ErrorCall $ "svm: error loading the model from " ++ path
  Model <$> newForeignPtr c_svm_destroy_model modelPtr

-- | Predict a value for a 'Vector' using a 'Model'.
--
-- The vector is marshalled into a temporary node array for the duration of
-- the call, and libsvm's result is converted to 'Double'.
predict :: Model -> Vector -> IO Double
predict (Model modelForeignPtr) vector =
  withForeignPtr modelForeignPtr $ \modelPtr ->
    withCSvmNodeArray vector $ \nodes ->
      realToFrac <$> c_svm_predict modelPtr nodes

-- | Wrapper to change the libsvm output reporting function.
--
--  libsvm by default writes some statistics to stdout. If you don't
--  want any output from libsvm, you can do e.g.:
--
--  >>> withPrintFn (\_ -> return ()) $ train (NuSvc 0.25) (RBF 1) feats
--
--  On exit, normal or exceptional, the print function is reset with a null
--  function pointer before the Haskell callback is freed. Later libsvm calls
--  therefore never jump through a dangling pointer. libsvm documents a null
--  print function as restoring its default stdout printer.
--
--  Nesting is not supported: leaving an inner 'withPrintFn' restores libsvm's
--  default printer, not the outer callback.
withPrintFn :: CSvmPrintFn -> IO a -> IO a
withPrintFn printfn body = bracket install uninstall (const body)
  where
    install = do
      fnPtr <- createSvmPrintFnPtr printfn
      c_svm_set_print_string_function fnPtr
      return fnPtr
    uninstall fnPtr = do
      -- Detach the callback from libsvm first, then free it.
      c_svm_set_print_string_function nullFunPtr
      freeHaskellFunPtr fnPtr