{-# LANGUAGE Trustworthy, DeriveDataTypeable, ScopedTypeVariables, DeriveFunctor, GeneralizedNewtypeDeriving #-}
-- | The elements of FRP.
module FRP.Reactivity.Measurement (Measurement(Empty), wait, stmAction, assertMeasurement, measure, await, blindAwait, fromList, assertChan, chan, first, leMeas, mergeStreams, getValue, copoint, time) where

import GHC.Conc hiding (Chan, newChan)
import Control.Concurrent.MVar
import Control.CUtils.Conc
import Control.CUtils.FChan
import System.IO.Unsafe
import Data.Time.Clock.POSIX
import Data.Typeable
import Control.Monad
import Control.Applicative
import Data.Maybe
import Data.Monoid

-- | Combine two actions through 'oneOfF' with two alternatives: index 0 runs the first
--   action, any other index runs the second, and both results are discarded.
--   NOTE(review): presumably this returns as soon as either action finishes — confirm
--   against Control.CUtils.Conc.oneOfF.
eitherOr :: IO t -> IO u -> IO ()
eitherOr m m2 = oneOfF 2 pick
	where
	pick 0 = void m
	pick _ = void m2

-- | Measurements are a basic building block for pull-based FRP. They are like futures in that, when you have
--   something running on a separate thread, you can use a Measurement to wait on it. They also establish a
--   measurement of an event's occurrence time. The primitives for Measurements make this measurement inside an
--   STM (software transactional memory) block. The STM system induces a global time ordering of transactions;
--   I piggyback on top of this mechanism to get a global time ordering of measurements as well. This is an
--   attempt to answer the tricky question of how to measure.
data Measurement t
	-- A measurement proper: an action used to block until the measurement is available, and a
	-- transaction that reads the recorded value and time ('Nothing' when none is available).
	= Measurement !(IO ()) !(STM (Maybe (t, POSIXTime)))
	-- The measurement that never occurs; 'first' treats it as an identity on either side.
	| Empty
	deriving (Typeable, Functor)

-- | Shows a measurement as the pair of its value and its occurrence time. Producing the
--   string blocks (through 'copoint' and 'time') until the measurement has occurred.
instance (Show t) => Show (Measurement t) where
	showsPrec d m = showsPrec d observed
		where observed = (copoint m, time m)

-- | Two measurements are equal when both their values and their occurrence times agree.
--   Comparing blocks (through 'copoint' and 'time') until the measurements have occurred.
instance (Eq t) => Eq (Measurement t) where
	a == b = (copoint a, time a) == (copoint b, time b)

-- | Sequencing of measurements. @m >>= f@ waits for @m@, then for the measurement @f x@
--   built from its value; its reported time is the later of the two times.
--   'return' gives a measurement that is immediately available at time 0, and a failed
--   pattern match yields 'mzero' ('Empty').
instance Monad Measurement where
	return x = Measurement (return ()) (return (Just (x, 0)))
	-- 'Empty' never occurs, so nothing sequenced after it can occur either. Without this
	-- equation the general case would build a measurement whose 'wait' calls 'getValue' on
	-- 'Empty' and crashes on the missing constructor. Returning 'Empty' keeps the MonadPlus
	-- law @mzero >>= f = mzero@ (e.g. for 'guard' and 'fail'), and 'first' discards it.
	Empty >>= _ = Empty
	-- NOTE(review): @f x@ is evaluated separately in the wait half and in the transaction
	-- half; if @f@ creates fresh measurements (e.g. via 'unsafePerformIO'), the two halves
	-- may observe different ones.
	meas >>= f = Measurement
		(getValue meas >>= wait . f . fst)
		(stmAction meas >>= maybe
			(return Nothing)
			(\(x, t) -> liftM (fmap (\(y, t') -> (y, max t t'))) (stmAction (f x))))
	fail _ = mzero

-- | Applicative structure derived from the 'Monad' instance: the function measurement is
--   bound first, then the argument measurement.
instance Applicative Measurement where
	pure x = return x
	mf <*> mx = do
		g <- mf
		v <- mx
		return (g v)

-- | The blocking half of a measurement: the action to run in order to wait for it.
--   The lazy pattern defers the match, so applying this to 'Empty' fails with an
--   irrefutable-pattern error only once the resulting action is demanded.
wait :: Measurement t -> IO ()
wait ~(Measurement co _) = co

-- | The transactional half of a measurement: reads the recorded value and time, or
--   'Nothing' when none is available. The lazy pattern defers the match, so applying this
--   to 'Empty' fails with an irrefutable-pattern error only once the transaction is demanded.
stmAction :: Measurement t -> STM (Maybe (t, POSIXTime))
stmAction ~(Measurement _ stm) = stm

-- | Shared implementation of 'assertMeasurement' and 'measure'.
--
--   Forks a thread that runs the action @m@. When @m@ yields @(x, my)@, the thread stores
--   @(x, t)@ in a 'TVar', where @t@ is @my@ if supplied and otherwise the clock read inside
--   the storing transaction. It then fills an 'MVar'. The returned 'Measurement' waits on
--   that 'MVar' and reads that 'TVar'. It is returned without waiting for @m@.
--
--   NOTE(review): if @m@ throws, the forked thread ends without filling the 'MVar', so 'wait'
--   on the result blocks indefinitely (GHC may raise BlockedIndefinitelyOnMVar). Confirm
--   whether that is intended.
_assertMeasurement :: IO (t, Maybe POSIXTime) -> IO (Measurement t)
_assertMeasurement m = do
	mv <- newEmptyMVar
	tv <- newTVarIO Nothing
	let writeMeas = do
		(x, my) <- m
		-- When no time was supplied, the clock is read inside the same transaction that
		-- publishes the value, so the timestamp is taken as part of that STM transaction.
		-- The value is only written if the TVar is still empty.
		atomically (maybe (unsafeIOToSTM getPOSIXTime) return my >>= \t -> readTVar tv >>= maybe (writeTVar tv (Just (x, t))) (\_ -> return ()))
		-- Signal waiters only after the transaction above has committed, so a caller that
		-- returns from 'wait' can read the value from the TVar.
		tryPutMVar mv ()
		return ()
	-- Blocking half: wait for the MVar. Transactional half: read the TVar.
	let meas = Measurement
		(readMVar mv)
		(readTVar tv)
	forkIO writeMeas
	return meas

-- | Run an action on a separate thread and measure its result. The action supplies the
--   occurrence time itself, instead of having it read from the clock.
assertMeasurement :: IO (t, POSIXTime) -> IO (Measurement t)
assertMeasurement act = _assertMeasurement (fmap (fmap Just) act)

-- | Run an action on a separate thread and measure its result. The occurrence time is read
--   from the clock when the result is recorded.
measure :: IO t -> IO (Measurement t)
measure = _assertMeasurement . fmap (flip (,) Nothing)

-- | Sleep for the time remaining until the given POSIX time; the current time is read once,
--   at call time. A target already in the past yields a non-positive delay for 'threadDelay'.
delayUntil :: POSIXTime -> IO ()
delayUntil target = do
	now <- getPOSIXTime
	let remaining = fromRational (toRational (target - now)) :: Double
	threadDelay (round (1000000 * remaining))

-- | Wait for a time, then measure that time. The recorded time is read from the clock after
--   the delay, so it may be later than the requested time.
await :: POSIXTime -> IO (Measurement ())
await = measure . delayUntil
 
-- | Give the parameter time as the time of the measurement. This is \"blind\" to system lags
--   that may disrupt the timing of a control signal. The blocking half still sleeps until
--   @t@. The transactional half, however, reports @t@ as the occurrence time immediately,
--   so comparisons of times (as in 'first') use the requested time rather than the time
--   the sleep actually ended.
blindAwait :: POSIXTime -> Measurement ()
blindAwait t = unsafePerformIO $ do
	-- Keep only the blocking half of the real measurement; its measured time is discarded.
	Measurement sleeper _ <- measure (delayUntil t)
	return (Measurement sleeper (return (Just ((), t))))
-- NOINLINE, as recommended for 'unsafePerformIO', so GHC does not inline or duplicate the
-- hidden 'measure' call (which forks a thread) at call sites.
{-# NOINLINE blindAwait #-}

{-# INLINE fromList #-}
-- | Turn a list of (value, time) pairs into measurements that occur at those times. Each
--   one is built with 'blindAwait', so it reports exactly its listed time. The rest of
--   the list is created lazily: it is reached through 'copoint', which blocks until the
--   preceding measurement has occurred.
fromList :: [(t, POSIXTime)] -> [Measurement t]
fromList pairs = foldr step [] pairs
	where
	step (x, t) rest = fmap fst tagged : snd (copoint tagged)
		where tagged = fmap (const (x, rest)) (blindAwait t)

{-# INLINE assertChan #-}
-- | Create a channel of externally timestamped values. It returns two things:
--
--   * a writer, which takes a value and the time to record for it;
--   * the stream of 'Measurement's of the written values, in write order.
--
--   The stream is built lazily: an element's successors are only reachable (via 'copoint')
--   once that element has occurred.
assertChan :: forall t. IO (t -> POSIXTime -> IO (), [Measurement t])
assertChan = do
	(write, source) <- newChan
	-- One measurement per channel item: its action takes the next item and defers (with
	-- 'unsafeInterleaveIO') building the measurements for the rest of the channel.
	let follow (c :: Chan (t, POSIXTime)) = do
		next <- assertMeasurement (do
			((x, tm), c') <- takeChan c
			rest <- unsafeInterleaveIO (follow c')
			return ((x, rest), tm))
		return (fmap fst next : snd (copoint next))
	stream <- follow source
	return (curry write, stream)

{-# INLINE chan #-}
-- | Like 'assertChan', but the writer stamps each value with the clock time read when it is
--   called.
chan :: IO (t -> IO (), [Measurement t])
chan = do
	(push, occurrences) <- assertChan
	return (\x -> getPOSIXTime >>= push x, occurrences)

-- | Decide which of the 'Measurement's comes first. I rely on the 'STM' subsystem to find
--   a consistent ordering.
--
--   'Empty' is an identity on either side. Otherwise:
--
--   * the blocking half combines both arguments' 'wait' actions with 'eitherOr';
--   * the transaction reports the occurrence with the smaller time, preferring the left
--     argument on a tie;
--   * if only one argument has occurred, that one is reported.
first :: Measurement t -> Measurement t -> Measurement t
first Empty other = other
first one Empty = one
first one other = Measurement (wait one `eitherOr` wait other) (liftM2 earlier (stmAction one) (stmAction other))
	where
	-- Both have occurred: keep the earlier one, the left one on a tie.
	earlier (Just a@(_, ta)) (Just b@(_, tb))
		| ta <= tb = Just a
		| otherwise = Just b
	-- At most one has occurred: keep it, or report no occurrence.
	earlier ma mb = ma `mplus` mb

-- | 'mzero' is the never-occurring 'Empty'; 'mplus' races two measurements with 'first'.
instance MonadPlus Measurement where
	mzero = Empty
	mplus a b = first a b

-- | Measurements form a monoid under 'first', with the never-occurring 'Empty' as identity.
instance Monoid (Measurement t) where
	mempty = Empty
	mappend = first

-- | 'empty' is the never-occurring 'Empty'; '<|>' races two measurements with 'first'.
instance Alternative Measurement where
	empty = Empty
	a <|> b = first a b

-- | @leMeas x y@ is 'True' when 'first' picks @x@ over @y@, i.e. @x@ occurs no later than
--   @y@ (ties favour @x@). Blocks, via 'copoint', until that race has a result.
leMeas :: Measurement t -> Measurement u -> Bool
leMeas x y = copoint ((True <$ x) `first` (False <$ y))

-- | Merge two streams of measurements into a single stream ordered by occurrence.
--   The two heads are raced with 'first'. The winner becomes the next element, and the rest
--   of the result is the merge of the winner's tail with the loser's whole stream. When one
--   stream is exhausted, the other is returned unchanged. The rest of the result only
--   becomes available (via 'copoint') once the head has occurred.
mergeStreams :: [Measurement t] -> [Measurement t] -> [Measurement t]
mergeStreams (x:xs) (y:ys) = fmap fst next : snd (copoint next) where
	-- Each side carries the merge that should follow if that side wins the race.
	next = fmap (\v -> (v, mergeStreams xs (y:ys))) x
		`first` fmap (\v -> (v, mergeStreams (x:xs) ys)) y
mergeStreams [] ys = ys
mergeStreams xs [] = xs

-- | Extract value and time of the measurement. Blocks on 'wait' until the measurement has
--   occurred, then reads the recorded value and time in a single transaction.
getValue :: Measurement t -> IO (t, POSIXTime)
getValue meas = do
	wait meas
	recorded <- atomically (stmAction meas)
	-- After 'wait' returns, the transaction is expected to report an occurrence. If that
	-- invariant is broken, fail with a descriptive message instead of an anonymous
	-- 'fromJust' error.
	maybe (error "FRP.Reactivity.Measurement.getValue: no value recorded after wait") return recorded

{-# INLINE copoint #-}
-- | The value of a measurement, obtained via 'getValue' (which blocks until it has occurred).
--   Uses 'unsafePerformIO'; this is presumably safe because a measurement's recorded value
--   does not change once written — confirm for any new kind of measurement.
copoint :: Measurement t -> t
copoint = \meas -> fst (unsafePerformIO (getValue meas))

{-# INLINE time #-}
-- | The occurrence time of a measurement, obtained via 'getValue' (which blocks until it has
--   occurred). Uses 'unsafePerformIO'; this is presumably safe because the recorded time does
--   not change once written — confirm for any new kind of measurement.
time :: Measurement t -> POSIXTime
time = \meas -> snd (unsafePerformIO (getValue meas))