module Control.Foldl.Statistics (
range
, sum'
, mean
, welfordMean
, meanWeighted
, harmonicMean
, geometricMean
, centralMoment
, centralMoments
, centralMoments'
, skewness
, kurtosis
, variance
, varianceUnbiased
, stdDev
, varianceWeighted
, fastVariance
, fastVarianceUnbiased
, fastStdDev
, fastLMVSK
, Stats4(..)
, fastLinearReg
, LinRegResult(..)
, correlation
, module Control.Foldl
) where
import Control.Foldl as F
import qualified Control.Foldl
import Data.Profunctor
import Numeric.Sum (KBNSum, kbn, add, zero)
-- | Strict pair of a 'Double' and an 'Int'.  Used by 'welfordMean' as
-- (running mean, count) and by 'harmonicMean' as (sum of reciprocals, count).
data T = T {-# UNPACK #-}!Double {-# UNPACK #-}!Int
-- | Strict pair of a 'KBNSum' running total and an 'Int' count.
data TS = TS {-# UNPACK #-}!KBNSum {-# UNPACK #-}!Int
-- | Strict triple of an 'Int' and two 'Double's.  Used by 'fastVar' as
-- (count, running mean, running sum of squared deviations).
data T1 = T1 {-# UNPACK #-}!Int {-# UNPACK #-}!Double {-# UNPACK #-}!Double
-- | Strict pair of 'Double's.  Used by 'meanWeighted' as (running mean, total
-- weight) and by 'stepV' as (running mean, running sum of squared deviations).
data V = V {-# UNPACK #-}!Double {-# UNPACK #-}!Double
-- | Strict triple of two 'Double' running sums and an 'Int' count.
data V1 = V1 {-# UNPACK #-}!Double {-# UNPACK #-}!Double {-# UNPACK #-}!Int
-- | Same shape as 'V1', but both running sums are kept as 'KBNSum's.
data V1S = V1S {-# UNPACK #-}!KBNSum {-# UNPACK #-}!KBNSum {-# UNPACK #-}!Int
-- | Sum of all inputs.
--
-- The running total is held in a 'KBNSum' (updated with 'add') and converted
-- back to a plain 'Double' with 'kbn' once the input is exhausted.
sum' :: Fold Double Double
{-# INLINE sum' #-}
sum' = Fold accumulate zero kbn
  where
    -- The local signature pins the accumulator type to 'KBNSum'.
    accumulate :: KBNSum -> Double -> KBNSum
    accumulate = add
-- | Range of the sample: the largest input minus the smallest input.
--
-- NOTE(review): this fold is partial.  'F.minimum' and 'F.maximum' yield
-- 'Maybe' values and only the 'Just' case is handled, so forcing the result
-- when either of them is 'Nothing' (presumably only for an empty input)
-- fails with a pattern-match exception.
range :: Fold Double Double
{-# INLINE range #-}
range = spread <$> F.minimum <*> F.maximum
  where
    -- Intentionally matches only the 'Just' / 'Just' case.
    spread (Just lo) (Just hi) = hi - lo
-- | Arithmetic mean of the inputs.
--
-- The running total is kept in a 'KBNSum' and divided by the number of
-- inputs at the end.  An empty input is not special-cased: the result is
-- then @0 / 0@, i.e. NaN.
mean :: Fold Double Double
{-# INLINE mean #-}
mean = Fold accumulate (TS zero 0) finish
  where
    accumulate (TS total count) x = TS (add total x) (count + 1)
    finish (TS total count) = kbn total / fromIntegral count
-- | Arithmetic mean computed incrementally: after every input the running
-- mean is moved towards that input by @(x - mean) / count@, so no running
-- total is ever stored.
--
-- The result for an empty input is the initial mean, @0@.
welfordMean :: Fold Double Double
{-# INLINE welfordMean #-}
welfordMean = Fold update (T 0 0) extract
  where
    extract (T avg _) = avg
    update (T avg seen) x =
      let seen' = seen + 1
          avg' = avg + (x - avg) / fromIntegral seen'
       in T avg' seen'
-- | Weighted mean of @(value, weight)@ pairs, updated incrementally.
--
-- Whenever the accumulated weight is exactly zero after an update the running
-- mean is reset to @0@ instead of dividing by zero.  The result for an empty
-- input is @0@.
meanWeighted :: Fold (Double,Double) Double
{-# INLINE meanWeighted #-}
meanWeighted = Fold update (V 0 0) extract
  where
    extract (V avg _) = avg
    update (V avg wsum) (x, w)
      | wsum' == 0 = V 0 wsum'
      | otherwise = V (avg + w * (x - avg) / wsum') wsum'
      where
        wsum' = wsum + w
-- | Harmonic mean: the number of inputs divided by the sum of their
-- reciprocals.
--
-- The reciprocals are accumulated in a plain 'Double'.  An empty input gives
-- @0 / 0@, i.e. NaN.
harmonicMean :: Fold Double Double
{-# INLINE harmonicMean #-}
harmonicMean = Fold accumulate (T 0 0) finish
  where
    accumulate (T recipSum count) x = T (recipSum + (1 / x)) (count + 1)
    finish (T recipSum count) = fromIntegral count / recipSum
-- | Geometric mean, computed as the exponential of the 'mean' of the
-- logarithms of the inputs.
geometricMean :: Fold Double Double
{-# INLINE geometricMean #-}
geometricMean = exp <$> F.premap log mean
-- | @centralMoment k mu@ is the average of @(x - mu)@ raised to the power
-- @k@ over all inputs @x@.
--
-- * @k < 0@ is rejected with 'error'.
-- * @k == 0@ and @k == 1@ ignore the input altogether and always produce
--   @1@ and @0@ respectively.
-- * Otherwise the powers are summed in a 'KBNSum' and divided by the number
--   of inputs.
centralMoment :: Int -> Double -> Fold Double Double
{-# INLINE centralMoment #-}
centralMoment k mu
  | k < 0 = error "Statistics.Sample.centralMoment: negative input"
  | k == 0 = pure 1
  | k == 1 = pure 0
  | otherwise = Fold accumulate (TS zero 0) finish
  where
    accumulate (TS total count) x = TS (add total ((x - mu) ^^^ k)) (count + 1)
    finish (TS total count) = kbn total / fromIntegral count
-- | @centralMoments p q mu@ computes the @p@-th and @q@-th central moments
-- about @mu@ together, sharing the deviation @x - mu@ between them.
--
-- When either order is below 2 the work is delegated to two separate
-- 'centralMoment' folds.  Otherwise both power sums are accumulated in plain
-- 'Double's; see 'centralMoments'' for the variant that uses 'KBNSum'.
centralMoments :: Int -> Int -> Double -> Fold Double (Double, Double)
{-# INLINE centralMoments #-}
centralMoments p q mu
  | min p q < 2 = (,) <$> centralMoment p mu <*> centralMoment q mu
  | otherwise = Fold accumulate (V1 0 0 0) finish
  where
    accumulate (V1 sumP sumQ count) x =
      let dev = x - mu
       in V1 (sumP + dev ^^^ p) (sumQ + dev ^^^ q) (count + 1)
    finish (V1 sumP sumQ count) =
      (sumP / fromIntegral count, sumQ / fromIntegral count)
-- | Like 'centralMoments', but the two power sums are accumulated in
-- 'KBNSum's rather than plain 'Double's.
--
-- When either order is below 2 the work is delegated to two separate
-- 'centralMoment' folds.
centralMoments' :: Int -> Int -> Double -> Fold Double (Double, Double)
{-# INLINE centralMoments' #-}
centralMoments' p q mu
  | min p q < 2 = (,) <$> centralMoment p mu <*> centralMoment q mu
  | otherwise = Fold accumulate (V1S zero zero 0) finish
  where
    accumulate (V1S sumP sumQ count) x =
      let dev = x - mu
       in V1S (add sumP (dev ^^^ p)) (add sumQ (dev ^^^ q)) (count + 1)
    finish (V1S sumP sumQ count) =
      (kbn sumP / fromIntegral count, kbn sumQ / fromIntegral count)
-- | Skewness about the supplied centre @mu@: the third central moment
-- multiplied by the second central moment raised to the power @-1.5@.
skewness :: Double -> Fold Double Double
{-# INLINE skewness #-}
skewness mu = normalise <$> centralMoments 3 2 mu
  where
    normalise (third, second) = third * second ** (-1.5)
-- | Excess kurtosis about the supplied centre @mu@: the fourth central
-- moment divided by the square of the second central moment, minus 3.
kurtosis :: Double -> Fold Double Double
{-# INLINE kurtosis #-}
kurtosis mu = normalise <$> centralMoments 4 2 mu
  where
    normalise (fourth, second) = fourth / (second * second) - 3
-- | Multiply a number by itself.
square :: Double -> Double
{-# INLINE square #-}
square v = v * v
-- | Accumulate the sum of squared deviations from @mu@ (in a 'KBNSum')
-- together with the number of inputs.  The raw accumulator is returned so
-- that callers can pick their own divisor.
robustSumVar :: Double -> Fold Double TS
{-# INLINE robustSumVar #-}
robustSumVar mu = Fold accumulate (TS zero 0) id
  where
    accumulate (TS total count) x = TS (add total (square (x - mu))) (count + 1)
-- | Variance about the supplied centre @m@ (normally the sample mean): the
-- sum of squared deviations divided by the number of inputs.
--
-- Yields @0@ when there are fewer than two inputs.
variance :: Double -> Fold Double Double
{-# INLINE variance #-}
variance m = finish <$> robustSumVar m
  where
    finish (TS total count)
      | count > 1 = kbn total / fromIntegral count
      | otherwise = 0
-- | Variance about the supplied centre @m@ (normally the sample mean) using
-- the @n - 1@ divisor: the sum of squared deviations divided by one less
-- than the number of inputs.
--
-- Yields @0@ when there are fewer than two inputs.
varianceUnbiased :: Double -> Fold Double Double
{-# INLINE varianceUnbiased #-}
varianceUnbiased m = finish <$> robustSumVar m
  where
    finish (TS total count)
      | count > 1 = kbn total / fromIntegral (count - 1)
      | otherwise = 0
-- | Standard deviation about the supplied centre @m@: the square root of
-- 'varianceUnbiased'.
stdDev :: Double -> Fold Double Double
{-# INLINE stdDev #-}
stdDev m = sqrt <$> varianceUnbiased m
-- | Accumulator for weighted variance over @(value, weight)@ pairs.
--
-- Tracks the weighted sum of squared deviations from @mu@, the total weight
-- and the number of pairs seen, all in plain 'Double' / 'Int' fields.
robustSumVarWeighted :: Double -> Fold (Double,Double) V1
{-# INLINE robustSumVarWeighted #-}
robustSumVarWeighted mu = Fold accumulate (V1 0 0 0) id
  where
    accumulate (V1 sqSum wSum count) (x, w) =
      let dev = x - mu
       in V1 (sqSum + w * dev * dev) (wSum + w) (count + 1)
-- | Weighted variance of @(value, weight)@ pairs about the supplied centre
-- @m@: the weighted sum of squared deviations divided by the total weight.
--
-- Yields @0@ when there are fewer than two pairs.
varianceWeighted :: Double -> Fold (Double,Double) Double
{-# INLINE varianceWeighted #-}
varianceWeighted m = finish <$> robustSumVarWeighted m
  where
    finish (V1 sqSum wSum count)
      | count > 1 = sqSum / wSum
      | otherwise = 0
-- | Single-pass accumulator shared by the @fast*@ variance folds.
--
-- Keeps the number of inputs, the running mean and a running sum that grows
-- by @(x - oldMean) * (x - newMean)@ on every input, so the mean does not
-- have to be known in advance.
fastVar :: Fold Double T1
{-# INLINE fastVar #-}
fastVar = Fold update (T1 0 0 0) id
  where
    update (T1 count avg sqDev) x =
      let count' = count + 1
          delta = x - avg
          avg' = avg + delta / fromIntegral count'
          sqDev' = sqDev + delta * (x - avg')
       in T1 count' avg' sqDev'
-- | Single-pass variance: the sum maintained by 'fastVar' divided by the
-- number of inputs.
--
-- Yields @0@ when there are fewer than two inputs.
fastVariance :: Fold Double Double
{-# INLINE fastVariance #-}
fastVariance = fmap finish fastVar
  where
    finish (T1 count _ sqDev) =
      if count > 1
        then sqDev / fromIntegral count
        else 0
-- | Single-pass variance with the @n - 1@ divisor: the sum maintained by
-- 'fastVar' divided by one less than the number of inputs.
--
-- Yields @0@ when there are fewer than two inputs.
fastVarianceUnbiased :: Fold Double Double
{-# INLINE fastVarianceUnbiased #-}
fastVarianceUnbiased = fmap finish fastVar
  where
    finish (T1 count _ sqDev) =
      if count > 1
        then sqDev / fromIntegral (count - 1)
        else 0
-- | Single-pass standard deviation: the square root of 'fastVariance'.
--
-- Note that this is based on 'fastVariance' (divisor @n@), not on
-- 'fastVarianceUnbiased'.
fastStdDev :: Fold Double Double
{-# INLINE fastStdDev #-}
fastStdDev = sqrt <$> fastVariance
-- | Sample count, mean and three higher-order statistics gathered in a
-- single pass.
--
-- The field descriptions below apply to the value produced by 'fastLMVSK'.
-- The same type is also used as the running accumulator of 'foldStats4'; in
-- that role the last three fields hold the raw running sums maintained by
-- 'stepStats4', before 'finalStats4' has normalised them.
data Stats4 = Stats4
  { stats4Count :: {-# UNPACK #-}!Int -- ^ number of samples seen
  , stats4Mean :: {-# UNPACK #-}!Double -- ^ running mean of the samples
  , stats4Variance :: {-# UNPACK #-}!Double -- ^ second-moment sum divided by @n - 1@
  , stats4Skewness :: {-# UNPACK #-}!Double -- ^ @sqrt n * m3 * m2 ** (-1.5)@ of the raw sums
  , stats4Kurtosis :: {-# UNPACK #-}!Double -- ^ @n * m4 / (m2 * m2) - 3@ of the raw sums
  } deriving (Show, Eq)
-- | Count, mean, variance, skewness and kurtosis in a single pass.
--
-- Runs 'foldStats4' over the input and then converts the raw moment sums
-- into the final statistics with 'finalStats4'.
fastLMVSK :: Fold Double Stats4
{-# INLINE fastLMVSK #-}
fastLMVSK = fmap finalStats4 foldStats4
-- | Initial accumulator for 'foldStats4': no samples seen, and the running
-- mean and all three moment sums at zero.
stats40 :: Stats4
{-# INLINE stats40 #-}
stats40 = Stats4 0 0 0 0 0
-- | Fold that threads a 'Stats4' accumulator through the input, starting
-- from 'stats40' and applying 'stepStats4' to each sample.
--
-- The accumulator is returned as it is, i.e. holding raw moment sums; use
-- 'fastLMVSK' to obtain the normalised statistics.
foldStats4 :: Fold Double Stats4
{-# INLINE foldStats4 #-}
foldStats4 = Fold stepStats4 stats40 (\acc -> acc)
-- | Fold one more sample into a 'Stats4' accumulator, updating the count,
-- the running mean and the running second-, third- and fourth-moment sums.
--
-- The polynomial @n*n - 3*n + 3@ in the fourth-moment update is evaluated in
-- 'Double' rather than 'Int'.  Squaring the count as an 'Int' wraps around
-- silently once the count exceeds the square root of 'maxBound' (roughly 46
-- thousand samples where 'Int' is 32 bits wide), which would corrupt the
-- fourth-moment sum from that point on.
stepStats4 :: Stats4 -> Double -> Stats4
{-# INLINE stepStats4 #-}
stepStats4 (Stats4 n1 m1 m2 m3 m4) x = Stats4 n' m1' m2' m3' m4'
  where
    n' = n1 + 1
    -- New count as a Double, shared by the divisions and the polynomial.
    nd' = fromIntegral n' :: Double
    delta = x - m1
    delta_n = delta / nd'
    delta_n2 = delta_n * delta_n
    term1 = delta * delta_n * fromIntegral n1
    m1' = m1 + delta_n
    -- The two higher sums are computed from the previous m2 and m3, not from
    -- the updated m2' and m3'.
    m4' = m4 + term1 * delta_n2 * (nd' * nd' - 3 * nd' + 3) + 6 * delta_n2 * m2 - 4 * delta_n * m3
    m3' = m3 + term1 * delta_n * fromIntegral (n' - 2) - 3 * delta_n * m2
    m2' = m2 + term1
-- | Turn the raw moment sums accumulated by 'stepStats4' into the final
-- statistics.  The count and mean are passed through unchanged.
--
-- No guard is applied for small counts, so the divisions by @n - 1@ and by
-- the second-moment sum are performed as they are.
finalStats4 :: Stats4 -> Stats4
finalStats4 (Stats4 count avg sum2 sum3 sum4) = Stats4 count avg var skew kurt
  where
    n = fromIntegral count
    var = sum2 / (n - 1)
    skew = sqrt n * sum3 * (sum2 ** (-1.5))
    kurt = n * sum4 / (sum2 * sum2) - 3.0
-- | Result of the single-pass linear regression computed by 'fastLinearReg'
-- over @(x, y)@ pairs.
data LinRegResult = LinRegResult
  {lrrCount :: {-# UNPACK #-}!Int -- ^ number of @(x, y)@ pairs seen
  ,lrrSlope :: {-# UNPACK #-}!Double -- ^ co-deviation sum divided by the squared-deviation sum of @x@
  ,lrrIntercept :: {-# UNPACK #-}!Double -- ^ mean of @y@ minus slope times mean of @x@
  ,lrrCorrelation :: {-# UNPACK #-}!Double -- ^ correlation between @x@ and @y@ as computed by 'fastLinearReg'
  } deriving (Show, Eq)
-- | Single-pass simple linear regression of @y@ on @x@ over @(x, y)@ pairs.
--
-- The accumulator holds the number of pairs, a running (mean, sum of squared
-- deviations) for each coordinate (advanced with 'stepV') and a running sum
-- of co-deviations.  At the end:
--
-- * slope is the co-deviation sum divided by the squared-deviation sum of @x@;
-- * intercept is @mean y - slope * mean x@;
-- * correlation is the co-deviation sum divided by @(n - 1)@ times the
--   product of the two @n - 1@ standard deviations.
--
-- No guard is applied for fewer than two pairs or for constant @x@ / @y@.
fastLinearReg :: Fold (Double,Double) LinRegResult
{-# INLINE fastLinearReg #-}
fastLinearReg = Fold update (V2 0 (V 0 0) (V 0 0) 0) summarise
  where
    update (V2 count xAcc@(V xMean _) yAcc@(V yMean _) coSum) (x, y) =
      V2 (count + 1) (stepV xAcc count x) (stepV yAcc count y) coSum'
      where
        seen = fromIntegral count
        seen' = fromIntegral (count + 1)
        -- Uses the means from *before* this pair is folded in.
        coSum' = coSum + (xMean - x) * (yMean - y) * seen / seen'
    summarise (V2 count (V xMean xSq) (V yMean ySq) coSum) =
      LinRegResult count slope intercept r
      where
        dof = fromIntegral (count - 1)
        slope = coSum / xSq
        intercept = yMean - slope * xMean
        spread = sqrt (xSq / dof) * sqrt (ySq / dof)
        r = coSum / (dof * spread)
-- | Accumulator of 'fastLinearReg': the number of pairs seen, one 'V' per
-- coordinate (running mean and running sum of squared deviations, as
-- maintained by 'stepV') and the running sum of co-deviations.
data V2 = V2 {-# UNPACK #-}!Int {-# UNPACK #-}!V {-# UNPACK #-}!V {-# UNPACK #-}!Double
-- | Advance a (running mean, running sum of squared deviations) pair by one
-- sample.
--
-- The 'Int' argument is the number of samples folded in /before/ this one.
stepV :: V -> Int -> Double -> V
{-# INLINE stepV #-}
stepV (V avg sqDev) seen x = V (avg + shift) (sqDev + delta * shift * fromIntegral seen)
  where
    delta = x - avg
    -- Amount by which the mean moves once the new sample is included.
    shift = delta / fromIntegral (seen + 1)
-- | Correlation of paired samples given precomputed centres and scales.
--
-- @correlation (m1, m2) (s1, s2)@ subtracts @m1@ / @m2@ from the first /
-- second component of every pair, divides by @s1@ / @s2@ (presumably the
-- two standard deviations), sums the products of the two standardised
-- components in a 'KBNSum' and finally divides by one less than the number
-- of pairs.
correlation :: (Double, Double) -> (Double, Double) -> Fold (Double,Double) Double
correlation (m1, m2) (s1, s2) = Fold accumulate (TS zero 0) finish
  where
    accumulate (TS total count) (x1, x2) =
      let z1 = (x1 - m1) / s1
          z2 = (x2 - m2) / s2
       in TS (add total (z1 * z2)) (count + 1)
    finish (TS total count) = kbn total / fromIntegral (count - 1)
-- | Raise a 'Double' to a non-negative 'Int' power by repeated
-- multiplication.
--
-- * Exponent @0@ yields @1@.
-- * A negative exponent is rejected with 'error'; without that guard the
--   recursion would count downwards and never reach a base case.
-- * For exponents of 1 and above the multiplications are performed in the
--   same right-nested order as before: @x * (x * (... * x))@.
(^^^) :: Double -> Int -> Double
x ^^^ n
  | n > 1 = x * (x ^^^ (n - 1))
  | n == 1 = x
  | n == 0 = 1
  | otherwise = error "Control.Foldl.Statistics.(^^^): negative exponent"
-- Inlining is delayed to phase 2 so that the rewrite rules below get a
-- chance to match calls with a literal exponent first.
{-# INLINE[2] (^^^) #-}
-- Rewrite small literal exponents into straight-line multiplications.
{-# RULES
"pow 2" forall x. x ^^^ 2 = x * x
"pow 3" forall x. x ^^^ 3 = x * x * x
"pow 4" forall x. x ^^^ 4 = x * x * x * x
"pow 5" forall x. x ^^^ 5 = x * x * x * x * x
"pow 6" forall x. x ^^^ 6 = x * x * x * x * x * x
"pow 7" forall x. x ^^^ 7 = x * x * x * x * x * x * x
"pow 8" forall x. x ^^^ 8 = x * x * x * x * x * x * x * x
"pow 9" forall x. x ^^^ 9 = x * x * x * x * x * x * x * x * x
"pow 10" forall x. x ^^^ 10 = x * x * x * x * x * x * x * x * x * x
  #-}