{-----------------------------------------------------------------------------
    reactive-banana
------------------------------------------------------------------------------}
{-# LANGUAGE ExistentialQuantification, NamedFieldPuns #-}
{-# LANGUAGE TypeSynonymInstances, FlexibleInstances #-}
module Reactive.Banana.Prim.Types where

import           Control.Monad.Trans.RWSIO
import           Control.Monad.Trans.Reader
import           Control.Monad.Trans.ReaderWriterIO
import           Data.Functor
import           Data.Hashable
import           Data.Monoid (Monoid, mempty, mappend)
import           Data.Semigroup
import qualified Data.Vault.Lazy                    as Lazy
import           System.IO.Unsafe
import           System.Mem.Weak

import Reactive.Banana.Prim.Graph            (Graph)
import Reactive.Banana.Prim.OrderedBag as OB (OrderedBag, empty)
import Reactive.Banana.Prim.Util

{-----------------------------------------------------------------------------
    Network
------------------------------------------------------------------------------}
-- | A 'Network' represents the state of a pulse/latch network.
data Network = Network
    { nTime    :: !Time                 -- ^ Current time.
    , nOutputs :: !(OrderedBag Output)  -- ^ Registered outputs, remembered here
                                        --   to prevent garbage collection.
    , nAlwaysP :: !(Maybe (Pulse ()))   -- ^ Pulse that always fires, if any.
    }

-- | A list of nodes paired with a 'Lazy.Vault'
-- (presumably the nodes that receive input and their values -- confirm at
-- call sites).
type Inputs = ([SomeNode], Lazy.Vault)

-- | A computation over a 'Network': it receives the current network and
-- yields a result together with the new network.
type EvalNetwork a = Network -> IO (a, Network)

-- | A network evaluation whose result is itself an 'IO ()' action.
type Step = EvalNetwork (IO ())

-- | The initial network.
--
-- Its time is @next beginning@ (see Note [Timestamp]); no outputs are
-- registered and no always-firing pulse exists yet.
emptyNetwork :: Network
emptyNetwork = Network
    { nTime    = next beginning
    , nOutputs = OB.empty
    , nAlwaysP = Nothing
    }

-- | Monad for building and modifying the network.
--
-- reader : 'BuildR'
-- writer : 'BuildW'
type Build  = ReaderWriterIOT BuildR BuildW IO

-- | Read-only environment of 'Build'.
type BuildR = (Time, Pulse ())
    -- ( current time
    -- , pulse that always fires)

-- | Output accumulated by 'Build'.
newtype BuildW = BuildW (DependencyBuilder, [Output], Action, Maybe (Build ()))
    -- writer : ( actions that change the network topology
    --          , outputs to be added to the network
    --          , late IO actions
    --          , late build actions
    --          )

-- | Combine two writer outputs component-wise, using the 'Semigroup'
-- instance of each of the four tuple components.
instance Semigroup BuildW where
    BuildW x <> BuildW y = BuildW (x <> y)

-- | The empty writer output: 'mempty' in every component.
instance Monoid BuildW where
    mempty  = BuildW mempty
    -- Explicitly the same as '(<>)' (presumably for older base versions in
    -- which 'mappend' does not default to '(<>)').
    mappend = (<>)

-- | Accumulated changes to the dependency graph: a composable graph
-- transformation plus a list of node pairs (presumably dependency edges;
-- the meaning of each pair's order is not established here).
type DependencyBuilder = (Endo (Graph SomeNode), [(SomeNode, SomeNode)])

-- | Alias for 'Build'.
type BuildIO = Build

{-----------------------------------------------------------------------------
    Synonyms
------------------------------------------------------------------------------}
-- | Priority used to determine evaluation order for pulses.
type Level = Int

-- | The base level, @0@.
ground :: Level
ground = 0

-- | 'IO' actions as a monoid with respect to sequencing.
--
-- 'doit' extracts the wrapped action so it can be run.
newtype Action = Action { doit :: IO () }
-- | Sequence two actions: run the left one, then the right one.
instance Semigroup Action where
    Action x <> Action y = Action (x >> y)
-- | The empty action does nothing.
instance Monoid Action where
    mempty  = Action $ return ()
    mappend = (<>)

-- | Lens-like functionality: a getter for a part @a@ of a structure @s@,
-- paired with a function that replaces that part.
data Lens s a = Lens (s -> a) (a -> s -> s)

-- | Replace the part of a structure that a 'Lens' focuses on.
set :: Lens s a -> a -> s -> s
set (Lens _ put) = put

-- | Modify the part of a structure that a 'Lens' focuses on by applying
-- a function to it.
update :: Lens s a -> (a -> a) -> s -> s
update (Lens getter setter) f = \s -> setter (f $ getter s) s

{-----------------------------------------------------------------------------
    Pulse and Latch
------------------------------------------------------------------------------}
-- | A pulse is a mutable reference to its node data.
type Pulse  a = Ref (Pulse' a)

-- | Node data of a pulse.
data Pulse' a = Pulse
    { _keyP      :: Lazy.Key (Maybe a) -- ^ Key to retrieve the pulse from the cache.
    , _seenP     :: !Time              -- ^ See Note [Timestamp].
    , _evalP     :: EvalP (Maybe a)    -- ^ Calculate the current value.
    , _childrenP :: [Weak SomeNode]    -- ^ Weak references to child nodes.
    , _parentsP  :: [Weak SomeNode]    -- ^ Weak references to parent nodes.
    , _levelP    :: !Level             -- ^ Priority in evaluation order.
    , _nameP     :: String             -- ^ Name for debugging.
    }

-- | Debugging output: the pulse's name, a space, and the hash of its
-- reference (salt 0).
--
-- NOTE: the name is read from the mutable reference with 'unsafePerformIO',
-- so the result reflects whatever the reference holds when the string is
-- forced. Only use this for debugging output.
instance Show (Pulse a) where
    show p = _nameP (unsafePerformIO $ readRef p) ++ " " ++ show (hashWithSalt 0 p)

-- | A latch is a mutable reference to its node data.
type Latch  a = Ref (Latch' a)

-- | Node data of a latch.
data Latch' a = Latch
    { _seenL  :: !Time    -- ^ Timestamp for the current value. See Note [Timestamp].
    , _valueL :: a        -- ^ Current value.
    , _evalL  :: EvalL a  -- ^ Recalculate the current latch value.
    }
-- | Reference to a node that writes a value into a 'Latch'.
type LatchWrite = Ref LatchWrite'

-- | Node data of a latch write. The type of the written value is
-- existentially hidden, so 'LatchWrite' itself has no type parameter.
data LatchWrite' = forall a. LatchWrite
    { _evalLW  :: EvalP a         -- ^ Calculate the value to write.
    , _latchLW :: Weak (Latch a)  -- ^ Destination 'Latch' to write to.
    }

-- | Reference to an output node.
type Output  = Ref Output'

-- | Node data of an output.
data Output' = Output
    { _evalO :: EvalP EvalO  -- ^ Evaluation that produces the output's 'EvalO'.
    }
-- | Outputs are compared by their references, using 'equalRef'.
instance Eq Output where
    (==) = equalRef

-- | A node of any kind, used e.g. as the vertex type of the dependency
-- 'Graph'. The value type of a pulse is existentially hidden.
data SomeNode
    = forall a. P (Pulse a)   -- a pulse
    | L LatchWrite            -- a latch write
    | O Output                -- an output

-- | Hash a node by hashing its underlying reference.
instance Hashable SomeNode where
    hashWithSalt s (P x) = hashWithSalt s x
    hashWithSalt s (L x) = hashWithSalt s x
    hashWithSalt s (O x) = hashWithSalt s x

-- | Two nodes are equal when they are of the same kind and 'equalRef'
-- holds for their underlying references. Nodes of different kinds are
-- never equal.
instance Eq SomeNode where
    P x == P y = equalRef x y
    L x == L y = equalRef x y
    O x == O y = equalRef x y
    -- Without this clause, comparing nodes of different kinds (for example
    -- when keys of different kinds collide in a hash table) would crash
    -- with a non-exhaustive pattern match error.
    _   == _   = False

-- | Create a weak pointer to the value @v@ by calling 'mkWeakRefValue' on
-- the reference underlying the given node.
{-# INLINE mkWeakNodeValue #-}
mkWeakNodeValue :: SomeNode -> v -> IO (Weak v)
mkWeakNodeValue (P x) = mkWeakRefValue x
mkWeakNodeValue (L x) = mkWeakRefValue x
mkWeakNodeValue (O x) = mkWeakRefValue x

-- Lenses for various parameters
-- | Lens onto the '_seenP' timestamp of a pulse.
seenP :: Lens (Pulse' a) Time
seenP = Lens _seenP (\a s -> s { _seenP = a })

-- | Lens onto the '_seenL' timestamp of a latch.
seenL :: Lens (Latch' a) Time
seenL = Lens _seenL (\a s -> s { _seenL = a })

-- | Lens onto the current value '_valueL' of a latch.
valueL :: Lens (Latch' a) a
valueL = Lens _valueL (\a s -> s { _valueL = a })

-- | Lens onto the weak parent references '_parentsP' of a pulse.
parentsP :: Lens (Pulse' a) [Weak SomeNode]
parentsP = Lens _parentsP (\a s -> s { _parentsP = a })

-- | Lens onto the weak child references '_childrenP' of a pulse.
childrenP :: Lens (Pulse' a) [Weak SomeNode]
childrenP = Lens _childrenP (\a s -> s { _childrenP = a })

-- | Lens onto the evaluation priority '_levelP' of a pulse.
levelP :: Lens (Pulse' a) Level
levelP = Lens _levelP (\a s -> s { _levelP = a })

-- | Evaluation monads.
type EvalPW   = (EvalLW, [(Output, EvalO)])
    -- ( latch updates
    -- , output evaluations, each paired with its 'Output')
type EvalLW   = Action

type EvalO    = Future (IO ())
type Future   = IO

-- Note: For efficiency reasons, we unroll the monad transformer stack.
-- Conceptually: type EvalP = RWST () Lazy.Vault EvalPW Build
type EvalP    = RWSIOT BuildR (EvalPW,BuildW) Lazy.Vault IO
    -- reader : 'BuildR' (current time, pulse that always fires)
    -- writer : (latch updates and output evaluations, 'Build' writer output)
    -- state  : current pulse values

-- | Computation with a timestamp that indicates the last time it was
-- performed. Written timestamps are combined with the 'Time' semigroup,
-- i.e. the latest one wins.
type EvalL    = ReaderWriterIOT () Time IO

{-----------------------------------------------------------------------------
    Show functions for debugging
------------------------------------------------------------------------------}
-- | Describe a node for debugging. A pulse is shown by its name, a latch
-- write as @"L"@ and an output as @"O"@.
printNode :: SomeNode -> IO String
printNode (P p) = _nameP <$> readRef p
printNode (L _) = return "L"
printNode (O _) = return "O"

{-----------------------------------------------------------------------------
    Time monoid
------------------------------------------------------------------------------}
-- | A timestamp local to this program run.
--
-- Useful e.g. for controlling cache validity.
newtype Time = T Integer deriving (Eq, Ord, Show, Read)

-- | Before the beginning of time; marks a latch that has never been
-- written to. See Note [Timestamp].
agesAgo :: Time
agesAgo = T (-1)

-- | The beginning of time. See Note [Timestamp].
beginning :: Time
beginning = T 0

-- | The timestamp immediately following the given one.
next :: Time -> Time
next (T n) = T (n + 1)

-- | Combine two timestamps by keeping the later one.
instance Semigroup Time where
    T x <> T y = T (max x y)

-- | 'mempty' is 'beginning'.
--
-- NOTE: this is an identity for '(<>)' only on timestamps at or after
-- 'beginning'. For example, @agesAgo <> mempty == beginning@, because
-- '(<>)' takes the maximum.
instance Monoid Time where
    mappend = (<>)
    mempty  = beginning

{-----------------------------------------------------------------------------
    Notes
------------------------------------------------------------------------------}
{- Note [Timestamp]

The timestamp indicates how recent the current value is.

For Pulse:
During pulse evaluation, a timestamp equal to the current
time indicates that the pulse has already been evaluated in this phase.

For Latch:
The timestamp indicates the last time at which the latch has been written to.

    agesAgo   = The latch has never been written to.
    beginning = The latch has been written to before everything starts.

The second meaning is guaranteed because the network only writes
timestamps starting from `next beginning`.

-}