{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Torch.Core.LogAdd
( logAdd
, logSub
, expMinusApprox
) where
import Torch.Core.Exceptions
import qualified Torch.FFI.TH.LogAdd as TH
logAdd :: Double -> Double -> IO Double
logAdd a b = realToFrac <$> TH.c_THLogAdd (realToFrac a) (realToFrac b)
logSub :: Double -> Double -> IO Double
logSub log_a log_b
| log_a < log_b = throw $ MathException "log_a must be greater than log_b"
| otherwise = realToFrac <$> TH.c_THLogSub (realToFrac log_a) (realToFrac log_b)
expMinusApprox :: Double -> IO Double
expMinusApprox a = realToFrac <$> TH.c_THExpMinusApprox (realToFrac a)
expMinusApprox' :: forall f . RealFrac f => f -> Maybe f
expMinusApprox' x
| x < 0 = Nothing
| x < 13 = Just $ 1 / (y*y*y*y)
| otherwise = Just 0
where
a0, a1, a2, a3, a4 :: f
a0 = 1
a1 = 0.125
a2 = 0.0078125
a3 = 0.00032552083
a4 = 1.0172526e-5
y = a0 + x * (a1 + x * (a2 + x * (a3 + x * a4)))