{-# LANGUAGE Safe, DeriveDataTypeable, DeriveFunctor #-}

module FRP.Reactivity.MeasurementWrapper (MeasurementWrapper(..), wrapOne, extractMW, measUntil, wrapperToList) where

import FRP.Reactivity.AlternateEvent
import FRP.Reactivity.Measurement
import Data.Time.Clock.POSIX
import Data.Typeable
import Data.Monoid
import Control.Monad
import Control.Monad.Fix
import Control.Applicative
import Control.Arrow (second)

-- | An event stream represented directly as a list of 'Measurement' values.
newtype MeasurementWrapper t = MeasurementWrapper
	{ unMeasurementWrapper :: [Measurement t] -- ^ The underlying list of measurements.
	}
	deriving (Typeable, Functor, Show)

-- | Wrap a single 'Measurement' as a one-element 'MeasurementWrapper'.
wrapOne :: Measurement t -> MeasurementWrapper t
wrapOne m = MeasurementWrapper [m]

-- | 'return' builds a one-measurement stream from the 'Measurement' monad's
-- own 'return'.  '>>=' applies the continuation to the 'copoint' of every
-- measurement and combines the resulting streams with 'mplus'; binding an
-- empty stream yields 'mzero'.
instance Monad MeasurementWrapper where
	return = MeasurementWrapper . (: []) . return
	MeasurementWrapper ms >>= f = foldr (\m rest -> f (copoint m) `mplus` rest) mzero ms
	fail _ = mzero

-- | 'mzero' is the stream with no measurements; 'mplus' combines the two
-- underlying lists with 'mergeStreams'.
instance MonadPlus MeasurementWrapper where
	mzero = MeasurementWrapper []
	mplus (MeasurementWrapper lhs) (MeasurementWrapper rhs) = MeasurementWrapper (mergeStreams lhs rhs)

-- | Monoid structure delegated to the 'MonadPlus' instance.
instance Monoid (MeasurementWrapper t) where
	mempty = mzero
	mappend lhs rhs = lhs `mplus` rhs

-- | Applicative structure derived from the 'Monad' instance: functions are
-- bound first, then arguments.
instance Applicative MeasurementWrapper where
	pure x = return x
	wf <*> wx = do
		f <- wf
		x <- wx
		return (f x)

-- | Alternative structure delegated to the 'MonadPlus' instance.
instance Alternative MeasurementWrapper where
	empty = mzero
	lhs <|> rhs = lhs `mplus` rhs

-- | The 'copoint' of the first measurement in the stream.
--
-- Calls 'error' when the stream contains no measurements.
extractMW :: MeasurementWrapper t -> t
extractMW (MeasurementWrapper ms) = case ms of
	first : _ -> copoint first
	[] -> error "Comonad.extract: empty MeasurementWrapper"

-- | Cut the first stream off at the first measurement of the second stream.
--
-- For two non-empty streams the result is the 'mplus' of two candidate
-- measurements: one that keeps the head @x@ and recurses on the rest of the
-- first stream, and one derived from the stopper's head @y@ that carries an
-- empty stream.
-- NOTE(review): this relies on 'mplus' for 'Measurement' preferring whichever
-- candidate occurs first -- confirm in "FRP.Reactivity.Measurement".
--
-- The function is total: an exhausted first stream yields an empty result and
-- an empty stopper stream leaves the first stream untouched.
measUntil :: MeasurementWrapper t -> MeasurementWrapper u -> Measurement (MeasurementWrapper t)
-- Nothing left to keep.  Without this equation the recursion below ran off
-- the end of every finite first stream and died with a pattern-match failure.
measUntil (MeasurementWrapper []) _ = return mzero
-- No stopping measurement exists, so the whole first stream is kept.
-- NOTE(review): the result's timestamp is whatever 'return' assigns in the
-- 'Measurement' monad; 'switch' in this module only reads the payload via
-- 'copoint' -- confirm for any caller that inspects the time.
measUntil wrapper (MeasurementWrapper []) = return wrapper
measUntil (MeasurementWrapper (x:xs)) stopper@(MeasurementWrapper (y:_)) = continueInX `mplus` halt where
	-- Keep x and prepend it to whatever survives from the rest of the stream.
	continueInX = liftM2 (\_ (MeasurementWrapper kept) -> MeasurementWrapper (x:kept))
		x
		(measUntil (MeasurementWrapper xs) stopper)
	-- The stopper's head replaces everything from here on with the empty stream.
	halt = fmap (const mzero) y

-- | List-backed implementation of the 'EventStream' interface.
instance EventStream MeasurementWrapper where
	eventFromList ls = MeasurementWrapper (fromList ls)

	-- Thread the state through the stream, emitting exactly one output
	-- measurement per input measurement.  The previous 'scanl'-based version
	-- also emitted its seed, which put an extra leading measurement with an
	-- 'undefined' payload at the front of every non-empty result.
	scan f x (MeasurementWrapper ms) = MeasurementWrapper (go x ms) where
		go _ [] = []
		go acc (m:rest) = fmap snd stepped : go (fst (copoint stepped)) rest where
			-- New state paired with the output, carried by the input measurement m.
			stepped = fmap (f acc) m

	-- The inner stream carried by x1 is cut off with 'measUntil' against the
	-- inner stream carried by x2, then combined with the switch of the rest.
	switch (MeasurementWrapper (x1:x2:xs)) = copoint (measUntil (copoint x1) (copoint x2)) `mplus` switch (MeasurementWrapper (x2:xs))
	switch (MeasurementWrapper [x1]) = copoint x1
	switch (MeasurementWrapper []) = mzero

	-- Pair every payload with the part of the stream that follows it.  The
	-- scan state is the stream starting at the current measurement, so
	-- dropping one element gives the remainder.
	withRemainder wrapper = scan step wrapper wrapper where
		step (MeasurementWrapper rest) y = (remainder, (y, remainder)) where
			remainder = MeasurementWrapper (drop 1 rest)

	channel = liftM (second MeasurementWrapper) chan

	-- Pair every payload with the 'time' of the measurement that carries it.
	adjoinTime (MeasurementWrapper ls) = MeasurementWrapper (map (\meas -> fmap (\x -> (x, time meas)) meas) ls)

{-# INLINE wrapperToList #-}
-- | Convert a stream to a list that pairs the 'copoint' of each measurement
-- with its 'time', in the order of the underlying list.
wrapperToList :: MeasurementWrapper t -> [(t, POSIXTime)]
wrapperToList (MeasurementWrapper ms) = [ (copoint m, time m) | m <- ms ]