-- |
-- Module:     Control.Wire.Prefab.Analyze
-- Copyright:  (c) 2011 Ertugrul Soeylemez
-- License:    BSD3
-- Maintainer: Ertugrul Soeylemez <es@ertes.de>
--
-- Various signal analysis tools

module Control.Wire.Prefab.Analyze
    ( -- * Statistics
      -- ** Average
      avg,
      avgAll,
      avgFps,
      avgFpsInt,
      -- ** Peak
      highPeak,
      lowPeak,
      peakBy,

      -- * Monitoring
      collect,
      diff,
      firstSeen,
      lastSeen
    )
    where

import qualified Data.Map as M
import qualified Data.Set as S
import qualified Data.Vector.Unboxed as Vu
import qualified Data.Vector.Unboxed.Mutable as Vum
import Control.Arrow
import Control.Monad.Fix
import Control.Monad.ST
import Control.Wire.Trans.Clock
import Control.Wire.Trans.Sample
import Control.Wire.Types
import Data.Map (Map)
import Data.Monoid
import Data.Set (Set)


-- | Calculate the average of the signal over the given number of last
-- samples.  If you need an average over all samples ever produced,
-- consider using 'avgAll' instead.
--
-- At the first instant the window is filled with copies of the first
-- input, so early outputs are biased towards that value until the
-- window has been refilled with fresh samples.  A window size below 1
-- is treated as 1 (the wire then behaves as the identity), because an
-- average over zero samples is undefined and would otherwise fail with
-- an out-of-bounds index at the second instant.
--
-- The sample buffer and the running sum are updated eagerly at every
-- instant, so no pending work accumulates when the output is unused.
--
-- NOTE: the sample buffer is overwritten in place (via
-- 'Vu.unsafeThaw'), so reusing an older wire state after it has been
-- stepped will observe the already-mutated buffer.
--
-- * Complexity: O(n) space, O(1) time wrt number of samples.
--
-- * Depends: current instant.

avg ::
    forall e v (>~).
    (Fractional v, Vu.Unbox v, WirePure (>~))
    => Int
    -> Wire e (>~) v v
avg size = mkPure $ \x -> (Right x, avg' (Vu.replicate n (x/d)) x 0)
    where
    -- Effective window size, clamped so that the ring-buffer index
    -- used below always lies within the (non-empty) sample vector.
    n :: Int
    n = max 1 size

    -- Samples are stored pre-divided by the window size, so the
    -- running sum of the buffer directly is the average.
    avg' :: Vu.Vector v -> v -> Int -> Wire e (>~) v v
    avg' samples' s' cur' =
        mkPure $ \((/d) -> x) ->
            let cur = let ncur = succ cur' in
                      if ncur >= n then 0 else ncur
                -- The sample leaving the window; it must be read before
                -- the in-place write below overwrites slot 'cur'.
                x' = samples' Vu.! cur
                samples =
                    x' `seq` runST $ do
                        sam <- Vu.unsafeThaw samples'
                        Vum.write sam cur x
                        Vu.unsafeFreeze sam
                s = s' - x' + x
            -- Forcing 'samples' performs the buffer update (reading x'
            -- first); forcing 's' keeps the running sum evaluated.
            in samples `seq` s `seq` (Right s, avg' samples s cur)

    d :: v
    d = realToFrac n


-- | Calculate the average of the signal over all samples.
--
-- Please note that somewhat surprisingly this wire runs in constant
-- space and is generally faster than 'avg', but most applications will
-- benefit from averages over only the last few samples.
--
-- The number of samples is counted with an exact 'Int' rather than in
-- the signal type, so low-precision types such as 'Float' keep
-- averaging correctly past 2^24 samples.
--
-- * Complexity: O(1) space, O(1) time per instant.
--
-- * Depends: current instant.

avgAll :: forall e v (>~). (Fractional v, WirePure (>~)) => Wire e (>~) v v
avgAll = mkPure $ \x -> (Right x, avgAll' 1 x)
    where
    -- State: number of samples seen so far and their mean.
    avgAll' :: Int -> v -> Wire e (>~) v v
    avgAll' n' a' =
        mkPure $ \x ->
            let n = n' + 1
                -- Incremental mean: move the old mean towards the new
                -- sample by 1/n of the difference.
                a = a' + (x - a') / fromIntegral n
            in n `seq` a `seq` (Right a, avgAll' n a)


-- | Average frame rate, in frames per virtual second, measured over the
-- last given number of frames.
--
-- The frame durations are taken from the underlying arrow's 'WWithDT'
-- clock (via 'passDT'), averaged with 'avg', and the reciprocal of that
-- mean duration is the output.  If the clock does not track real time,
-- neither does the result of this wire.

avgFps ::
    (Arrow (Wire e (>~)), Fractional t, Vu.Unbox t, WirePure (>~), WWithDT t (>~))
    => Int
    -> Wire e (>~) a t
avgFps n = passDT (avg n) >>> arr recip


-- | Same as 'avgFps', but samples only at regular intervals.  This can
-- improve performance if querying the clock is an expensive operation.
--
-- NOTE(review): the 'avgFps' result is multiplied by the interval size,
-- presumably because the sampled wire only sees one time delta per
-- @int@ instants, so its raw reciprocal would be the frame rate divided
-- by @int@ — confirm against the semantics of 'sampleInt'.

avgFpsInt ::
    (Arrow (Wire e (>~)), Fractional t, Vu.Unbox t, WirePure (>~), WSampleInt (>~), WWithDT t (>~))
    => Int  -- ^ Interval size.
    -> Int  -- ^ Number of samples averaged by the underlying 'avgFps'.
    -> Wire e (>~) a t
avgFpsInt int n =
    proc x' ->
        -- Arrow-notation control operator: 'sampleInt' wraps the scaled
        -- 'avgFps' command, and the trailing @int@ appears to supply the
        -- interval to it via command application — confirm.
        (| sampleInt ((* fromIntegral int) ^<< avgFps n -< x') |) int


-- | Collects all distinct inputs ever received.
--
-- The accumulated set is forced at every instant, so memory stays
-- proportional to the number of distinct inputs even when the output
-- of this wire is never demanded.
--
-- * Complexity: O(n) space, O(log n) time wrt collected inputs so far.
--
-- * Depends: current instant.

collect :: forall b e (>~). (Ord b, WirePure (>~)) => Wire e (>~) b (Set b)
collect = collect' S.empty
    where
    collect' :: Set b -> Wire e (>~) b (Set b)
    collect' ins' =
        mkPure $ \x ->
            let ins = S.insert x ins'
            -- Force the new set now; otherwise unconsumed outputs leave
            -- a chain of pending 'S.insert' thunks growing per instant.
            in ins `seq` (Right ins, collect' ins)


-- | Whenever the input differs from the previously remembered input,
-- outputs that previous value and remembers the new one.  At the first
-- instant the input is passed through unchanged.
--
-- * Depends: current instant.
--
-- * Inhibits: on no change after the first instant.

diff :: forall b e (>~). (Eq b, Monoid e, WirePure (>~)) => Wire e (>~) b b
diff = mkPure $ \x -> (Right x, watch x)
    where
    -- State: the most recent input that differed from its predecessor.
    watch :: b -> Wire e (>~) b b
    watch prev = mkPure step
        where
        step cur
            | prev == cur = (Left mempty, watch prev)
            | otherwise   = (Right prev, watch cur)


-- | Reports the first global time the given input was seen.  The first
-- sighting of an input records the current system time for it; later
-- sightings output that recorded time.
--
-- * Complexity: O(n) space, O(log n) time wrt collected inputs so far.
--
-- * Depends: Current instant.

firstSeen ::
    forall a e t (>~). (Ord a, WirePure (>~), WWithSysTime t (>~))
    => Wire e (>~) a t
firstSeen = withSysTime (remember M.empty)
    where
    -- State: the time each input was first observed.
    remember :: Map a t -> Wire e (>~) (a, t) t
    remember seen =
        fix $ \self ->
        mkPure $ \(x, now) ->
            maybe (Right now, remember (M.insert x now seen))
                  (\firstTime -> (Right firstTime, self))
                  (M.lookup x seen)


-- | Outputs the high peak of the input signal, i.e. the largest value
-- seen so far according to 'compare'.
--
-- * Depends: Current instant.

highPeak :: (Ord b, WirePure (>~)) => Wire e (>~) b b
highPeak = peakBy byValue
    where
    -- A new input wins when it compares greater than the current peak.
    byValue new peak = compare new peak


-- | Reports the previous time the given input was seen, i.e. the time
-- recorded at its last sighting before the current instant.  Inhibits
-- when seeing a signal for the first time.
--
-- The internal map is forced at every instant, so memory stays
-- proportional to the number of distinct inputs even when the output
-- of this wire is never demanded.
--
-- * Complexity: O(n) space, O(log n) time wrt collected inputs so far.
--
-- * Depends: Current instant.
--
-- * Inhibits: On first sight of a signal.

lastSeen ::
    forall a e t (>~). (Monoid e, Ord a, WirePure (>~), WWithSysTime t (>~))
    => Wire e (>~) a t
lastSeen = withSysTime (lastSeen' M.empty)
    where
    lastSeen' :: Map a t -> Wire e (>~) (a, t) t
    lastSeen' xs' =
        mkPure $ \(x, t) ->
            let xs = M.insert x t xs'
            -- Force the updated map now; otherwise, if the output is
            -- never consumed, pending 'M.insert' thunks pile up.
            in xs `seq`
               (maybe (Left mempty) Right (M.lookup x xs'),
                lastSeen' xs)


-- | Outputs the low peak of the input signal, i.e. the smallest value
-- seen so far according to 'compare'.
--
-- * Depends: Current instant.

lowPeak :: (Ord b, WirePure (>~)) => Wire e (>~) b b
lowPeak = peakBy (\new peak -> compare peak new)


-- | Outputs the high peak of the input signal with respect to the given
-- comparison function.  The first input becomes the initial peak; each
-- later input replaces the peak only when @comp input peak@ is 'GT', so
-- ties keep the earlier peak.
--
-- * Depends: Current instant.

peakBy ::
    forall b e (>~). WirePure (>~)
    => (b -> b -> Ordering)
    -> Wire e (>~) b b
peakBy comp = mkPure $ \x -> (Right x, track x)
    where
    -- State: the current peak.
    track :: b -> Wire e (>~) b b
    track peak =
        mkPure $ \x ->
            let peak' = case comp x peak of
                          GT -> x
                          _  -> peak
            in (Right peak', track peak')