{-# LANGUAGE ForeignFunctionInterface #-}
module Numerical.Cubature
  (cubature)
  where
import           Control.Exception     (bracket)
import           Foreign.C.Types       (CUInt(..))
import           Foreign.Marshal.Alloc (alloca, free, mallocBytes)
import           Foreign.Marshal.Array (peekArray, pokeArray, withArray)
import           Foreign.Ptr           (FunPtr, Ptr, freeHaskellFunPtr)
import           Foreign.Storable      (peek, poke, sizeOf)

-- | Shape of the C-callable integrand handed to the integration routine.
--
-- Arguments, in C call order: a count, a pointer to the evaluation point,
-- an opaque pointer, a second count, and a pointer to the output slot.
-- The 'Int' result is a status code.
--
-- NOTE(review): the two counts look like the Cubature library's
-- @ndim@/@fdim@ and the result like a C @int@ status (0 = success?);
-- confirm against the C prototype. A C @int@ corresponds to 'CInt',
-- not 'Int'.
type Integrand
  =  CUInt       -- first count
  -> Ptr Double  -- evaluation point
  -> Ptr ()      -- opaque user-data pointer
  -> CUInt       -- second count
  -> Ptr Double  -- output slot
  -> IO Int      -- status code

-- | Outcome of a 'cubature' run: the integral value and the error estimate
-- reported by the C routine.
--
-- Fields are strict so a 'Result' never carries unevaluated thunks.
--
-- NOTE(review): the module export list is just @(cubature)@, so neither this
-- type nor its field accessors are visible to importers. Outside this module
-- a 'Result' can only be inspected through its 'Show' instance. Consider
-- exporting @Result(..)@.
data Result = Result
  { _integral :: !Double  -- ^ integral value returned by the C routine
  , _error    :: !Double  -- ^ error estimate written by the C routine
  }
  deriving (Show)

-- | Turn a Haskell 'Integrand' into a C function pointer using a
-- GHC-generated "wrapper" stub.
--
-- The stub behind each returned pointer is not garbage-collected.
-- Release it with 'freeHaskellFunPtr' once the C side no longer needs it;
-- 'cubature' does this after the foreign call.
foreign import ccall safe "wrapper"
  integrandPtr :: Integrand -> IO (FunPtr Integrand)

-- | Raw binding to the C entry point @mintegration@.
--
-- Arguments, in order:
--
-- 1. the cubature variant selector;
-- 2. the integrand callback;
-- 3. the number of variables;
-- 4. pointers to the lower-limit array;
-- 5. pointers to the upper-limit array;
-- 6. the requested relative error;
-- 7. an output slot that receives the error estimate.
--
-- The returned 'Double' is used as the integral value.
--
-- Declared @safe@ because the C side is handed a Haskell callback and must
-- be allowed to re-enter the runtime.
--
-- NOTE(review): 'Char' and 'Int' marshal as @HsChar@/@HsInt@, not C
-- @char@/@int@/@unsigned@. Confirm these match the C prototype of
-- @mintegration@; 'CChar'/'CInt'/'CUInt' would be the portable choices.
foreign import ccall safe "mintegration"
  c_cubature :: Char -> FunPtr Integrand -> Int
             -> Ptr Double -> Ptr Double -> Double -> Ptr Double
             -> IO Double

-- | Adapt a pure Haskell integrand to the C callback shape 'Integrand'.
--
-- The callback does three things:
--
-- * reads @n@ coordinates from the point pointer;
-- * applies @f@ and writes the single result to the output pointer;
-- * returns 0.
--
-- The two counts supplied by the C side and the opaque user-data pointer
-- are ignored: the dimension is the @n@ captured here, and exactly one
-- output value is written.
--
-- NOTE(review): this assumes the C side always calls back with the same
-- dimension it was given and a single output component, and that 0 means
-- success. Confirm against @mintegration@.
fun2integrand :: ([Double] -> Double) -> Int -> Integrand
fun2integrand f n _ndim x _userData _fdim fval = do
  point <- peekArray n x
  poke fval (f point)
  return 0

-- | Multivariate integration on an axis-aligned box.
--
-- Integrates the given function over the box whose corners are the lower
-- and upper limit lists. It does this by handing a C callback built from
-- the function to the C routine @mintegration@. The returned 'Result' holds
-- the value returned by that routine and the error estimate it writes
-- through its last argument.
--
-- Both limit lists must contain exactly @n@ elements. Otherwise an
-- 'IOError' is thrown (via 'ioError') before any C code runs.
cubature :: Char                 -- ^ cubature version, 'h' or 'p'
         -> ([Double] -> Double) -- ^ integrand
         -> Int                  -- ^ dimension (number of variables)
         -> [Double]             -- ^ lower limits of integration
         -> [Double]             -- ^ upper limits of integration
         -> Double               -- ^ desired relative error
         -> IO Result            -- ^ output: integral value and error estimate
cubature version f n xmin xmax relError
  -- C is told the dimension is n, and the limit buffers hold exactly the
  -- list contents. A length mismatch would let C (presumably reading n
  -- entries per array) run past a buffer or see the wrong limits.
  | nLower /= n || nUpper /= n =
      ioError . userError $
        "Numerical.Cubature.cubature: dimension is " ++ show n
          ++ " but got " ++ show nLower ++ " lower and "
          ++ show nUpper ++ " upper limits"
  | otherwise =
      -- bracket / withArray / alloca release the callback stub and every C
      -- buffer even if the foreign call or a marshalling step throws.
      bracket (integrandPtr (fun2integrand f n)) freeHaskellFunPtr $ \fPtr ->
        withArray xmin $ \xminPtr ->
          withArray xmax $ \xmaxPtr ->
            alloca $ \errorPtr -> do
              value <- c_cubature version fPtr n
                                  xminPtr xmaxPtr relError errorPtr
              estimate <- peek errorPtr
              return Result { _integral = value, _error = estimate }
  where
    nLower = length xmin
    nUpper = length xmax

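-- Usage example (kept as a comment, not compiled). It integrates the
-- unnormalised 2-D Gaussian exp(-|x|^2 / 2) over the box [-6,6] x [-6,6].
-- Over the whole plane this integral is exactly 2*pi (about 6.2831853).
-- Truncating to the box lowers it by a relative amount of about 4e-9, so the
-- result should agree with 2*pi to roughly eight significant digits.
--
-- fExample :: [Double] -> Double
-- fExample list = exp (-0.5 * (sum $ zipWith (*) list list))

-- example :: IO Result
-- example = cubature 'h' fExample 2 [-6,-6] [6,6] 1e-10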