module Statistics.Math.RootFinding
(
Root(..)
, fromRoot
, ridders
) where
import Data.Aeson (FromJSON, ToJSON)
import Control.Applicative (Alternative(..), Applicative(..))
import Control.Monad (MonadPlus(..), ap)
import Data.Binary (Binary)
import Data.Binary (put, get)
import Data.Binary.Get (getWord8)
import Data.Binary.Put (putWord8)
import Data.Data (Data, Typeable)
import GHC.Generics (Generic)
import Numeric.MathFunctions.Comparison (within)
-- | The outcome of searching for a root of a function.
data Root a
  = NotBracketed
    -- ^ The search interval does not bracket a root.  'ridders' returns
    -- this when the function has the same nonzero sign at both bounds.
  | SearchFailed
    -- ^ The search did not produce a root.  'ridders' returns this once
    -- it has performed 100 iterations without converging.
  | Root a
    -- ^ A root was found.
  deriving (Data, Eq, Generic, Read, Show, Typeable)
-- | JSON encoding for 'Root'.  The instance body is empty, so it relies
-- on aeson's default methods (this is why 'Root' derives 'Generic').
instance ToJSON a => ToJSON (Root a)

-- | JSON decoding for 'Root', likewise using aeson's default methods.
instance FromJSON a => FromJSON (Root a)
-- | Binary serialisation of 'Root'.
--
-- A value is written as a one-byte tag (0 = 'NotBracketed',
-- 1 = 'SearchFailed', 2 = 'Root'), followed by the serialised payload
-- for 'Root' only.  Decoding any other tag fails with an error message.
instance (Binary a) => Binary (Root a) where
  put r = case r of
    NotBracketed -> putWord8 0
    SearchFailed -> putWord8 1
    Root x       -> putWord8 2 >> put x

  get = getWord8 >>= fromTag
    where
      fromTag tag = case tag of
        0 -> return NotBracketed
        1 -> return SearchFailed
        2 -> fmap Root get
        _ -> fail $ "Root.get: Invalid value: " ++ show tag
-- | Applies a function to the value of a successful 'Root'; the failure
-- constructors 'NotBracketed' and 'SearchFailed' pass through unchanged.
instance Functor Root where
  fmap f r = case r of
    Root x       -> Root (f x)
    NotBracketed -> NotBracketed
    SearchFailed -> SearchFailed
-- | Sequencing of root searches.  A successful 'Root' feeds its value to
-- the continuation; 'NotBracketed' and 'SearchFailed' short-circuit and
-- are propagated unchanged.
instance Monad Root where
  NotBracketed >>= _ = NotBracketed
  SearchFailed >>= _ = SearchFailed
  Root a       >>= k = k a

  -- Canonical definition: 'return' must coincide with 'pure'.  A separate
  -- definition is flagged by GHC's -Wnoncanonical-monad-instances.
  return = pure
-- | 'MonadPlus' for 'Root'.  'mzero' is 'SearchFailed', and 'mplus' keeps
-- the first argument if it is a successful 'Root', otherwise returns the
-- second argument.
instance MonadPlus Root where
  -- Delegate to the 'Alternative' instance (the canonical MonadPlus
  -- definitions) so the two instances can never drift apart.
  mzero = empty
  mplus = (<|>)
-- | 'pure' wraps a value as a successful 'Root'.  If the function side of
-- '<*>' is a failure, that failure is returned without inspecting the
-- argument; otherwise the function is mapped over the argument.
instance Applicative Root where
  pure = Root

  NotBracketed <*> _ = NotBracketed
  SearchFailed <*> _ = SearchFailed
  Root f       <*> r = fmap f r
-- | Left-biased choice.  The first argument wins if it is a successful
-- 'Root'; any failure is replaced by the second argument.  'empty' is
-- 'SearchFailed'.
instance Alternative Root where
  empty = SearchFailed

  left <|> right = case left of
    Root _ -> left
    _      -> right
-- | Extract the value from a successful 'Root'.  If the search produced
-- 'NotBracketed' or 'SearchFailed', return the supplied default instead.
fromRoot :: a       -- ^ Default value, used when no root was found.
         -> Root a  -- ^ Result of a root search.
         -> a
fromRoot fallback result = case result of
  Root x -> x
  _      -> fallback
-- | Find a root of a function using Ridders' method.
--
-- The function must have opposite signs at the two bounds, i.e. the root
-- must be bracketed; otherwise 'NotBracketed' is returned.  If either
-- bound is already an exact zero, that bound is returned immediately.
-- The search returns 'SearchFailed' after 100 iterations without
-- converging.
ridders :: Double              -- ^ Absolute error tolerance.
        -> (Double, Double)    -- ^ Lower and upper bounds for the search.
        -> (Double -> Double)  -- ^ Function to find a root of.
        -> Root Double
ridders tol (lo, hi) f
  | flo == 0      = Root lo
  | fhi == 0      = Root hi
  | flo * fhi > 0 = NotBracketed  -- root is not bracketed
  | otherwise     = go lo flo hi fhi 0
  where
    -- a and b are the current bracket with function values fa and fb;
    -- i counts the iterations performed so far.
    go !a !fa !b !fb !i
      -- The bracket has collapsed to within 1 ulp; no further progress.
      | within 1 a b      = Root a
      -- Exact zeros at the midpoint or at the new estimate.  Returning
      -- here also guarantees an exact zero is never passed to 'go' as a
      -- bracket endpoint.
      | fm == 0           = Root m
      | fn == 0           = Root n
      -- The bracket is narrower than the requested tolerance.
      | d < tol           = Root n
      -- Iteration budget exhausted.
      | i >= (100 :: Int) = SearchFailed
      -- Ridders' estimate coincides with an old bound: fall back to
      -- bisection on whichever half still shows a sign change.
      | n == a || n == b  = if fm * fa < 0
                              then go a fa m fm (i + 1)
                              else go m fm b fb (i + 1)
      -- Otherwise keep a sub-interval over which the sign changes.
      | fn * fm < 0       = go n fn m fm (i + 1)
      | fn * fa < 0       = go a fa n fn (i + 1)
      | otherwise         = go n fn b fb (i + 1)
      where
        d   = abs (b - a)
        dm  = (b - a) * 0.5
        !m  = a + dm
        !fm = f m
        -- Ridders' correction to the midpoint:
        --   sign(fb - fa) * dm * f(m) / sqrt(f(m)^2 - fa*fb)
        !dn = signum (fb - fa) * dm * fm / sqrt (fm * fm - fa * fb)
        -- Step away from the midpoint, limited in size to |dm| - tol/2 so
        -- the estimate stays away from the bracket ends.
        !n  = m - signum dn * min (abs dn) (abs dm - 0.5 * tol)
        !fn = f n
    !flo = f lo
    !fhi = f hi