{-# LANGUAGE ForeignFunctionInterface #-}
module Numerical.Integration
  (integration, IntegralResult(..))
  where
import           Control.Exception     (bracket)
import           Foreign.C             (CDouble(..), CInt(..))
import           Foreign.Marshal.Alloc (free, mallocBytes)
import           Foreign.Ptr           (FunPtr, Ptr, freeHaskellFunPtr)
import           Foreign.Storable      (peek, sizeOf)

-- | Positive zero as a 'Double'; used to build 'nanDouble'.
zeroDouble :: Double
zeroDouble = 0.0

-- zeroCDouble :: CDouble

-- zeroCDouble = 0.0


-- | A NaN ('Double' not-a-number), computed as @0 / 0@.
nanDouble :: Double
nanDouble = zeroDouble / zeroDouble

-- nanCDouble :: CDouble

-- nanCDouble = zeroCDouble / zeroCDouble


-- | Convert a 'Double' to a 'CDouble' without loss.
--
-- 'CDouble' is a newtype around 'Double', so wrapping with its constructor
-- is exact for every value, including NaN, infinities and negative zero.
-- 'realToFrac' is defined as @fromRational . toRational@ and is only exact
-- when GHC's rewrite rules happen to fire; going through 'Rational' loses
-- NaN and the sign of zero.
double2Cdouble :: Double -> CDouble
double2Cdouble = CDouble

-- cdouble2double :: CDouble -> Double

-- cdouble2double = realToFrac



-- | Result of a call to 'integration'.
data IntegralResult = IntegralResult
  { _value :: Double  -- ^ value returned by the C routine (the integral)
  , _error :: Double  -- ^ error estimate written by the C routine
  , _code  :: Int     -- ^ error code written by the C routine
  } deriving Show

-- | Wrap a Haskell function as a C function pointer (a GHC @\"wrapper\"@
-- import). Every pointer produced here must eventually be released with
-- 'freeHaskellFunPtr'; otherwise the Haskell closure it refers to is never
-- freed.
foreign import ccall safe "wrapper" funPtr
    :: (CDouble -> CDouble) -> IO (FunPtr (CDouble -> CDouble))

-- | Raw binding to the C function @integration@.
--
-- Arguments, in order:
--
-- * the integrand, as a C function pointer;
-- * the lower bound;
-- * the upper bound;
-- * the desired relative error;
-- * the number of subdivisions;
-- * an out-parameter for the error estimate;
-- * an out-parameter for the error code.
--
-- The returned 'CDouble' is the computed value. 'integration' reads both
-- out-parameters after the call, so the C side is expected to write them.
--
-- The call is @safe@ because the C routine presumably invokes the integrand,
-- i.e. calls back into Haskell. GHC only allows such callbacks from safe
-- foreign calls.
foreign import ccall safe "integration" c_integration
    :: FunPtr (CDouble -> CDouble) -> CDouble -> CDouble -> CDouble -> CInt
    -> Ptr CDouble -> Ptr CInt -> IO CDouble

-- | Numerical integration.
--
-- Delegates to the C function @integration@. The integrand is exposed to C
-- as a function pointer for the duration of the call. The error estimate
-- and the error code come back through two heap-allocated out-parameters.
--
-- Every foreign resource is acquired with 'bracket': the function pointer
-- and both buffers. They are released even if the foreign call or a 'peek'
-- throws.
integration :: (CDouble -> CDouble)   -- ^ integrand
            -> Double                 -- ^ lower bound
            -> Double                 -- ^ upper bound
            -> Double                 -- ^ desired relative error
            -> Int                    -- ^ number of subdivisions
            -> IO IntegralResult      -- ^ value, error estimate, error code
integration f lower upper relError subdiv =
  bracket (funPtr f) freeHaskellFunPtr $ \fPtr ->
  bracket (mallocBytes (sizeOf (0 :: CDouble))) free $ \errorEstimatePtr ->
  bracket (mallocBytes (sizeOf (0 :: CInt))) free $ \errorCodePtr -> do
    result <- c_integration fPtr
                (double2Cdouble lower)
                (double2Cdouble upper)
                (double2Cdouble relError)
                (fromIntegral subdiv)
                errorEstimatePtr
                errorCodePtr
    -- Read the out-parameters while the buffers are still allocated.
    errorEstimate <- peek errorEstimatePtr
    errorCode     <- peek errorCodePtr
    pure IntegralResult
      { _value = fromC result
      , _error = fromC errorEstimate
      , _code  = fromIntegral errorCode
      }
  where
    -- Convert a C result back to 'Double', mapping NaN to 'nanDouble'
    -- explicitly: 'realToFrac' may go through 'Rational', which cannot
    -- represent NaN.
    fromC :: CDouble -> Double
    fromC x
      | isNaN x   = nanDouble
      | otherwise = realToFrac x