{-# LANGUAGE ForeignFunctionInterface #-}
module Numerical.Integration
  (integration)
  where
import           Control.Exception     (bracket)
import           Foreign.Marshal.Alloc (alloca, free, mallocBytes)
import           Foreign.Ptr           (FunPtr, Ptr, freeHaskellFunPtr)
import           Foreign.Storable      (peek, sizeOf)

-- | Wrap a Haskell @Double -> Double@ function as a C-callable function
-- pointer (a GHC \"wrapper\" import).
--
-- Each call allocates a new wrapper; the returned 'FunPtr' must be released
-- with 'freeHaskellFunPtr' once the C side is done with it, otherwise it leaks.
foreign import ccall safe "wrapper" funPtr
    :: (Double -> Double) -> IO(FunPtr (Double -> Double))

-- | Raw binding to the C function @integration@.
--
-- Arguments, in order: integrand callback, lower bound, upper bound, desired
-- relative error, number of subdivisions, and two output pointers. The C side
-- presumably writes the error estimate and the error code into them; the
-- Haskell wrapper reads them back with 'peek' after the call.
--
-- Imported @safe@ because the C routine is handed a Haskell callback ('FunPtr'
-- from a \"wrapper\" import), and calling back into Haskell requires a safe
-- foreign call.
--
-- NOTE(review): Haskell 'Int' is not guaranteed to have the size of C @int@
-- (it is typically 64-bit on 64-bit platforms, while @int@ is typically
-- 32-bit). If the C prototype uses @int@ / @int *@ for the subdivision count
-- and error code, these should be 'Foreign.C.Types.CInt' — confirm against the
-- C header.
foreign import ccall safe "integration" c_integration
    :: FunPtr (Double -> Double) -> Double -> Double -> Double -> Int
    -> Ptr Double -> Ptr Int -> IO Double

-- | Numerically integrate a function over a finite interval by delegating to
-- the foreign C routine @integration@.
--
-- The integrand is wrapped in a C-callable 'FunPtr' for the duration of the
-- call. Two output slots are allocated for the C side to fill with the error
-- estimate and the error code; both are read back after the call returns.
--
-- NOTE(review): the meaning of the error code, and exactly how the C routine
-- uses the relative-error and subdivision arguments, are defined on the C side,
-- which is not visible here — confirm against its header.
integration :: (Double -> Double)       -- ^ integrand
            -> Double                   -- ^ lower bound
            -> Double                   -- ^ upper bound
            -> Double                   -- ^ desired relative error
            -> Int                      -- ^ number of subdivisions
            -> IO (Double, Double, Int) -- ^ value, error estimate, error code
integration f lower upper relError subdiv =
  -- 'alloca' releases the output slots and 'bracket' releases the callback
  -- wrapper on both normal and exceptional exit, so nothing leaks if the
  -- foreign call or a 'peek' fails.
  alloca $ \errorEstimatePtr ->
  alloca $ \errorCodePtr ->
  bracket (funPtr f) freeHaskellFunPtr $ \fPtr -> do
    result <- c_integration fPtr lower upper relError subdiv
                            errorEstimatePtr errorCodePtr
    errorEstimate <- peek errorEstimatePtr
    errorCode     <- peek errorCodePtr
    return (result, errorEstimate, errorCode)