{-# LANGUAGE FlexibleInstances      #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE GADTs                  #-}
module Data.TDigest.Postprocess.Internal (
    
    HasHistogram (..),
    HistBin (..),
    histogramFromCentroids,
    
    quantile,
    
    
    
    
    mean,
    variance,
    
    cdf,
    
    validateHistogram,
    
    Affine (..),
    ) where
import Data.Foldable           (toList)
import Data.Functor.Compose    (Compose (..))
import Data.Functor.Identity   (Identity (..))
import Data.List.NonEmpty      (NonEmpty (..), nonEmpty)
import Data.Proxy              (Proxy (..))
import Data.Semigroup          (Semigroup (..))
import Data.Semigroup.Foldable (foldMap1)
import Prelude ()
import Prelude.Compat
import qualified Data.List.NonEmpty  as NE
import Data.TDigest.Internal
-- | A single histogram bin.
--
-- In histograms produced by 'histogramFromCentroids', the bins are ordered and
-- contiguous: each bin's 'hbMax' equals the next bin's 'hbMin'.
-- 'validateHistogram' checks this invariant.
data HistBin = HistBin
    { hbMin       :: !Mean    -- ^ lower edge of the bin
    , hbMax       :: !Mean    -- ^ upper edge of the bin
    , hbValue     :: !Mean    -- ^ value the bin represents (the centroid mean in 'histogramFromCentroids'); not necessarily the midpoint of the edges
    , hbWeight    :: !Weight  -- ^ weight of this bin
    , hbCumWeight :: !Weight  -- ^ total weight of all preceding bins; excludes this bin
    }
  deriving (Show)
-- | Types from which a histogram can be extracted.
--
-- The functional dependency fixes the 'Affine' wrapper @f@ for each source
-- type @a@. For example, 'Identity' is used when a histogram always exists,
-- and 'Maybe' when the source may be empty.
class Affine f => HasHistogram a f | a -> f where
    -- | The histogram bins, wrapped in the 'Affine' container @f@.
    histogram   :: a -> f (NonEmpty HistBin)
    -- | Total weight of the histogram.
    totalWeight :: a -> Weight
-- | A non-empty sequence of bins already is a histogram.
--
-- The head is @NonEmpty e@ with an @e ~ HistBin@ constraint, rather than
-- @NonEmpty HistBin@. GHC therefore selects this instance as soon as the
-- container is known to be 'NonEmpty', and the equality constraint then fixes
-- the element type. A concrete head would make the equality constraint
-- mention a variable absent from the head, so it would have no effect.
instance (e ~ HistBin) => HasHistogram (NonEmpty e) Identity where
    histogram = Identity
    -- The last bin's cumulative weight covers all preceding bins;
    -- adding its own weight gives the total.
    totalWeight = tw . NE.last where
        tw hb = hbWeight hb + hbCumWeight hb
-- | A possibly-empty list of bins. An empty list has no histogram
-- ('Nothing') and total weight @0@.
--
-- The head is @[e]@ with an @e ~ HistBin@ constraint, rather than
-- @[HistBin]@. GHC therefore selects this instance as soon as the container
-- is known to be a list, and the equality constraint then fixes the element
-- type. A concrete head would leave the equality constraint without effect.
instance (e ~ HistBin) => HasHistogram [e] Maybe where
    histogram = nonEmpty
    -- Delegate to the non-empty histogram's total weight; 0 for an empty list.
    totalWeight = affine 0 totalWeight . histogram
-- | Turn an ordered sequence of centroids @(mean, weight)@ into contiguous
-- histogram bins, one bin per centroid.
--
-- Interior bin edges lie halfway between neighbouring centroid means. The
-- first bin starts at the first mean and the last bin ends at the last mean,
-- so a single centroid yields a zero-width bin. Each bin's 'hbCumWeight' is
-- the sum of the weights of the centroids before it.
histogramFromCentroids :: NonEmpty Centroid -> NonEmpty HistBin
histogramFromCentroids ((x, w) :| next) = case next of
    []          -> HistBin x x x w 0 :| []
    (x', _) : _ -> HistBin x (half x x') x w 0 :| bins x w next
  where
    half :: Mean -> Mean -> Mean
    half a b = (a + b) / 2

    -- bins prev below cs: the bins for cs. @prev@ is the mean of the centroid
    -- immediately before cs, and @below@ is the weight of everything before cs.
    bins :: Mean -> Weight -> [Centroid] -> [HistBin]
    bins _    _     []              = []
    bins prev below ((m, v) : more) =
        HistBin (half prev m) upper m v below : bins m (below + v) more
      where
        -- The last bin ends at its own mean; the others end halfway
        -- to the following centroid.
        upper = case more of
            []          -> m
            (m', _) : _ -> half m m'
-- | Approximate the @q@-quantile of a histogram.
--
-- Finds the first bin whose cumulative range @[t, t + w)@ contains @q * tw@,
-- then interpolates linearly between that bin's edges. If no earlier bin
-- qualifies (e.g. @q >= 1@), the last bin is used, which may extrapolate.
quantile
    :: Double            -- ^ @q@, the requested quantile (normally in @[0, 1]@)
    -> Weight            -- ^ @tw@, total weight of the histogram
    -> NonEmpty HistBin  -- ^ ordered histogram bins
    -> Double
quantile q tw = go
  where
    q' = q * tw

    -- Matching on 'NonEmpty' directly makes the empty case impossible, so no
    -- unreachable 'error' branch is needed.
    go (HistBin a b _ w t :| rest) = case rest of
        []                       -> interpolate
        next : more
            | q' < t + w         -> interpolate
            | otherwise          -> go (next :| more)
      where
        interpolate = a + (b - a) * (q' - t) / w
-- | Mean of the bin values ('hbValue'), each weighted by its 'hbWeight'.
mean :: NonEmpty HistBin -> Double
mean bins = getMean (foldMap1 (\b -> Mean (hbWeight b) (hbValue b)) bins)
-- | Partial weighted mean: total weight, then the mean of the values seen.
data Mean' = Mean !Double !Double

-- | The mean component of a partial mean.
getMean :: Mean' -> Double
getMean (Mean _ m) = m

-- | Merging two partial means adds their weights and takes the weighted
-- average of their means. If both weights are zero, the division yields NaN.
instance Semigroup Mean' where
    (<>) (Mean wa ma) (Mean wb mb) =
        let total = wa + wb
        in  Mean total ((ma * wa + mb * wb) / total)
-- | Variance of the bin values ('hbValue'), each weighted by its 'hbWeight'.
-- The sum of squared deviations is divided by @totalWeight - 1@
-- (see 'getVariance').
variance :: NonEmpty HistBin -> Double
variance bins =
    getVariance (foldMap1 (\b -> Variance (hbWeight b) (hbValue b) 0) bins)
-- | Partial variance accumulator. The fields are the total weight, the
-- weighted mean, and the weighted sum of squared deviations from that mean.
data Variance = Variance !Double !Double !Double

-- | Weighted sample variance: the sum of squared deviations divided by
-- @totalWeight - 1@. This divides by zero when the total weight is exactly 1.
getVariance :: Variance -> Double
getVariance (Variance w _ d) = d / (w - 1)

-- | Merge two accumulators using the parallel update of Chan et al.
instance Semigroup Variance where
    Variance w1 x1 d1 <> Variance w2 x2 d2 = Variance w x d
      where
        w = w1 + w2
        x = (x1 * w1 + x2 * w2) / w
        -- Merging adds delta^2 * w1 * w2 / w to the squared-deviation sum.
        -- This is algebraically equal to w1*x1^2 + w2*x2^2 - w*x^2. That
        -- form subtracts large, nearly equal squares, and it loses all
        -- precision when the means are large relative to their spread
        -- (e.g. values near 1e9).
        delta = x2 - x1
        d = d1 + d2 + delta * delta * w1 * w2 / w
-- | Cumulative distribution function of the histogram at @x@.
--
-- Skips every bin whose edges are both at or below @x@, then looks at the
-- first remaining bin. The result is @0@ if @x@ is below that bin's lower
-- edge; otherwise it interpolates linearly inside the bin and divides by @n@.
-- If no bin remains (including for an empty list), the result is @1@.
cdf :: Double    -- ^ @x@, the point at which to evaluate
    -> Double    -- ^ @n@, divisor for the cumulative weight; presumably the total weight
    -> [HistBin] -- ^ bins, assumed ordered and contiguous
    -> Double
cdf x n bins = case dropWhile passed bins of
    [] -> 1
    HistBin lo hi _ w cum : _
        | x < lo    -> 0
        | otherwise -> (cum + w * (x - lo) / (hi - lo)) / n
  where
    -- True when x is not below either edge of the bin, i.e. the scan should
    -- move on to the next bin.
    passed (HistBin lo hi _ _ _) = not (x < lo || x < hi)
-- | Check the structural invariants of a sequence of bins.
--
-- For every pair of adjacent bins, this checks that
--
-- * the left bin's upper edge equals the right bin's lower edge (no gaps), and
-- * the left bin's cumulative weight plus its own weight equals the right
--   bin's cumulative weight.
--
-- On success the input is returned unchanged. On failure, the first violated
-- check is reported together with the offending pair of bins.
validateHistogram :: Foldable f => f HistBin -> Either String (f HistBin)
validateHistogram bs = mapM_ validPair (pairs (toList bs)) >> pure bs
  where
    validPair (lb@(HistBin _ lmax _ lwt lcw), rb@(HistBin rmin _ _ _ rcw)) = do
        check (lmax == rmin)     "gap between bins"
        check (lcw + lwt == rcw) "mismatch in weight cumulation"
      where
        check False err = Left $ err ++ " " ++ show (lb, rb)
        check True  _   = Right ()
    -- Adjacent pairs. 'drop 1' is total, unlike 'tail'.
    pairs xs = zip xs (drop 1 xs)
-- | Traversable containers with at most one element, which can therefore be
-- eliminated by supplying a default for the empty case.
class Traversable t => Affine t where
    -- | Apply the function to the element, or return the default when the
    -- container is empty.
    affine :: b -> (a -> b) -> t a -> b
    affine def f container = fromAffine def (fmap f container)

    -- | Extract the element, or return the default when the container is
    -- empty.
    fromAffine :: a -> t a -> a
    fromAffine def = affine def id

    {-# MINIMAL fromAffine | affine #-}

-- | Always holds exactly one element.
instance Affine Identity where
    fromAffine _ (Identity a) = a

-- | 'Nothing' falls back to the default.
instance Affine Maybe where
    affine def f m = case m of
        Nothing -> def
        Just a  -> f a

-- | Never holds an element; the container argument is not inspected.
instance Affine Proxy where
    affine def _ _ = def

-- | Composition of two affine containers is affine: the outer one is
-- eliminated, then the inner one, with the same default for both.
instance (Affine f, Affine g) => Affine (Compose f g) where
    affine def f = affine def (affine def f) . getCompose