{-# LANGUAGE ForeignFunctionInterface #-}
module Numerical.Integration
  (integration, IntegralResult(..))
  where
import           Control.Exception     (bracket)
import           Foreign.C             (CDouble(..), CInt(..))
import           Foreign.Marshal.Alloc (alloca, free, mallocBytes)
import           Foreign.Ptr           (FunPtr, Ptr, freeHaskellFunPtr)
import           Foreign.Storable      (peek, sizeOf)

-- | Outcome of a call to 'integration': the computed value together with
-- the error estimate and status code reported by the C routine.
data IntegralResult = IntegralResult
  { _value :: CDouble  -- ^ value returned by the C @integration@ routine
  , _error :: CDouble  -- ^ error estimate written back by the C routine
  , _code  :: Int      -- ^ status code written back by the C routine
  } deriving Show

-- | Wrap a Haskell integrand as a C function pointer so it can be passed to
-- 'c_integration'. Every pointer produced here holds a reference to the
-- Haskell closure and must be released with 'freeHaskellFunPtr'.
foreign import ccall safe "wrapper" funPtr
    :: (CDouble -> CDouble) -> IO (FunPtr (CDouble -> CDouble))

-- | Raw binding to the C function @integration@.
--
-- Arguments, in order: integrand callback, lower bound, upper bound,
-- desired relative error, number of subdivisions, then two output pointers
-- that receive the error estimate and the status code. The 'CDouble' result
-- is the computed value of the integral.
--
-- Imported as @safe@ because the C routine presumably invokes the Haskell
-- integrand through the function pointer; an @unsafe@ call may not call
-- back into Haskell.
foreign import ccall safe "integration" c_integration
    :: FunPtr (CDouble -> CDouble) -> CDouble -> CDouble -> CDouble -> CInt
    -> Ptr CDouble -> Ptr CInt -> IO CDouble

-- | Numerical integration of a function over @[lower, upper]@, delegated to
-- the C routine @integration@ through the FFI.
--
-- The Haskell integrand is wrapped as a C function pointer for the duration
-- of the call. Two scratch cells receive the error estimate and status code
-- written back by the C routine. All three resources are released even if an
-- exception (for example an allocation failure, or an asynchronous
-- exception) interrupts the sequence.
integration :: (CDouble -> CDouble)       -- ^ integrand
            -> CDouble                   -- ^ lower bound
            -> CDouble                   -- ^ upper bound
            -> CDouble                   -- ^ desired relative error
            -> CInt                      -- ^ number of subdivisions
            -> IO IntegralResult         -- ^ value, error estimate, error code
integration f lower upper relError subdiv =
  -- 'bracket' frees the callback wrapper on every exit path. The previous
  -- malloc/free sequence leaked whenever an intermediate step threw.
  bracket (funPtr f) freeHaskellFunPtr $ \fPtr ->
    -- 'alloca' scopes each out-parameter to this call and frees it on exit.
    alloca $ \errorEstimatePtr ->
      alloca $ \errorCodePtr -> do
        result <- c_integration fPtr lower upper relError subdiv
                                errorEstimatePtr errorCodePtr
        -- NOTE(review): the cells are not initialised before the call; this
        -- assumes the C routine always writes both outputs. Confirm.
        errorEstimate <- peek errorEstimatePtr
        errorCode     <- peek errorCodePtr
        return IntegralResult
          { _value = result
          , _error = errorEstimate
          , _code  = fromIntegral errorCode
          }