{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}
module Dataflow.Primitives (
Dataflow(..),
DataflowState,
Vertex(..),
initDataflowState,
duplicateDataflowState,
StateRef,
newState,
readState,
writeState,
modifyState,
Edge,
Timestamp(..),
registerVertex,
registerFinalizer,
incrementEpoch,
input,
send,
finalize
) where
import Control.Arrow ((>>>))
import Control.Monad (forM, (>=>))
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State.Strict (StateT, get, gets, modify)
import Control.Monad.Trans (lift)
import Data.Hashable (Hashable (..))
import Data.IORef (IORef, modifyIORef', newIORef,
readIORef, writeIORef)
import Data.Vector (Vector, empty, snoc, unsafeIndex)
import GHC.Exts (Any)
import Numeric.Natural (Natural)
import Prelude
import Unsafe.Coerce (unsafeCoerce)
-- | Identifier handed out by 'registerVertex'. The wrapped 'Int' is used by
-- 'lookupVertex' as the position of the vertex inside 'dfsVertices'.
newtype VertexID = VertexID Int
  deriving (Eq, Ord, Show)

-- | Identifier handed out by 'newState'. The wrapped 'Int' is used by
-- 'lookupStateRef' as the position of the cell inside 'dfsStates'.
newtype StateID = StateID Int
  deriving (Eq, Ord, Show)

-- | Counter that 'incrementEpoch' advances by one on each call.
newtype Epoch = Epoch Natural
  deriving (Eq, Ord, Hashable, Show)

-- | The time label passed to vertex callbacks and finalizers; currently just
-- a wrapper around an 'Epoch'.
newtype Timestamp = Timestamp Epoch
  deriving (Eq, Ord, Hashable, Show)

-- | Reference to a registered vertex that accepts inputs of type @a@.
--
-- The type parameter does not occur in the representation; 'lookupVertex'
-- relies on it to pick the type it coerces the stored vertex back to.
newtype Edge a = Edge VertexID
-- | Counter-like types that can be advanced by a single step.
class Incrementable a where
  -- | Produce the value that follows the argument.
  inc :: a -> a

-- | Adds one to the wrapped 'Int'.
instance Incrementable VertexID where
  inc (VertexID n) = VertexID $ n + 1

-- | Adds one to the wrapped 'Int'.
instance Incrementable StateID where
  inc (StateID n) = StateID $ n + 1

-- | Adds one to the wrapped 'Natural'.
instance Incrementable Epoch where
  inc (Epoch n) = Epoch $ n + 1
-- | Bookkeeping carried by the 'Dataflow' monad: the registered vertices,
-- the state cells, the finalizers, and the counters used to allocate
-- identifiers and epochs.
data DataflowState = DataflowState
  { dfsVertices       :: Vector Any
    -- ^ Registered vertices, stored type-erased and looked up by 'VertexID'.
  , dfsStates         :: Vector (IORef Any)
    -- ^ One mutable cell per 'newState' call, looked up by 'StateID'.
  , dfsFinalizers     :: [Timestamp -> Dataflow ()]
    -- ^ Callbacks run by 'finalize'; 'registerFinalizer' conses onto the
    -- front, so the most recently registered one comes first.
  , dfsLastVertexID   :: VertexID
    -- ^ Identifier most recently allocated by 'registerVertex'.
  , dfsLastStateID    :: StateID
    -- ^ Identifier most recently allocated by 'newState'.
  , dfsLastInputEpoch :: Epoch
    -- ^ Epoch most recently produced by 'incrementEpoch'.
  }
-- | The monad in which graphs are built and run: a 'StateT' layer holding
-- the 'DataflowState' on top of 'IO'.
newtype Dataflow a = Dataflow
  { runDataflow :: StateT DataflowState IO a
  }
  deriving (Functor, Applicative, Monad)
-- | The state of an empty graph: no vertices, no state cells, no finalizers.
--
-- Both identifier counters start at @-1@ so that the first increment yields
-- @0@, the first valid index into the corresponding vector.
initDataflowState :: DataflowState
initDataflowState =
  DataflowState
    { dfsVertices       = empty
    , dfsStates         = empty
    , dfsFinalizers     = []
    , dfsLastVertexID   = VertexID (-1)
    , dfsLastStateID    = StateID (-1)
    , dfsLastInputEpoch = Epoch 0
    }
-- | Return a copy of the current state whose state cells are fresh 'IORef's
-- initialised with the current contents of the originals.
--
-- Every other field is carried over unchanged, and the state held by the
-- running computation is not modified.
duplicateDataflowState :: Dataflow DataflowState
duplicateDataflowState = Dataflow $ do
  current <- get
  copies  <- liftIO $ mapM cloneRef (dfsStates current)
  return current { dfsStates = copies }
  where
    -- Read the current value and allocate a new reference holding it.
    cloneRef ref = readIORef ref >>= newIORef
-- | Advance 'dfsLastInputEpoch' by one step and return the new epoch.
incrementEpoch :: Dataflow Epoch
incrementEpoch =
  Dataflow $ do
    epoch <- gets (dfsLastInputEpoch >>> inc)
    -- 'gets' hands back an unevaluated @inc (...)@ thunk. Storing it as-is
    -- would let a chain of such thunks pile up across calls whenever nobody
    -- inspects the timestamp, so evaluate it before writing it back.
    epoch `seq` modify (\s -> s { dfsLastInputEpoch = epoch })
    return epoch
-- | A node of the graph that consumes inputs of type @i@.
--
-- A 'StatefulVertex' pairs a 'StateRef' with a callback that receives that
-- reference, the timestamp and the input; the state type is existentially
-- hidden. A 'StatelessVertex' is only a callback taking the timestamp and
-- the input.
data Vertex i
  = forall s. StatefulVertex
      (StateRef s)
      (StateRef s -> Timestamp -> i -> Dataflow ())
  | StatelessVertex
      (Timestamp -> i -> Dataflow ())
-- | Fetch the vertex an 'Edge' points at.
--
-- The stored value is type-erased, so it is coerced back to @'Vertex' i@
-- based solely on the edge's type parameter. The index is used without a
-- bounds check; presumably this is sound because edges are only created by
-- 'registerVertex' -- confirm that edges never cross 'DataflowState's.
lookupVertex :: Edge i -> Dataflow (Vertex i)
lookupVertex (Edge (VertexID vindex)) =
  Dataflow $ gets $ \st -> unsafeCoerce (dfsVertices st `unsafeIndex` vindex)
-- | Add a vertex to the graph and return the 'Edge' that addresses it.
--
-- The vertex is appended to 'dfsVertices' in type-erased form and the newly
-- allocated identifier is recorded as 'dfsLastVertexID'.
registerVertex :: Vertex i -> Dataflow (Edge i)
registerVertex vertex =
  Dataflow $ do
    nextId <- gets (inc . dfsLastVertexID)
    modify $ \st ->
      st { dfsVertices     = snoc (dfsVertices st) (unsafeCoerce vertex)
         , dfsLastVertexID = nextId
         }
    return (Edge nextId)
-- | Register a callback to be run by 'finalize'.
--
-- The callback is placed at the front of 'dfsFinalizers'.
registerFinalizer :: (Timestamp -> Dataflow ()) -> Dataflow ()
registerFinalizer finalizer = Dataflow (modify prepend)
  where
    prepend st = st { dfsFinalizers = finalizer : dfsFinalizers st }
-- | Handle to a state cell created by 'newState'.
--
-- The type parameter records the type of the stored value; it does not
-- occur in the representation, which is only the cell's 'StateID'.
newtype StateRef a = StateRef StateID
-- | Allocate a new state cell holding the given initial value and return a
-- typed handle to it.
--
-- The value is stored type-erased in a fresh 'IORef' appended to
-- 'dfsStates'; the new identifier is recorded as 'dfsLastStateID'.
newState :: a -> Dataflow (StateRef a)
newState a =
  Dataflow $ do
    cell   <- liftIO $ newIORef (unsafeCoerce a)
    nextId <- gets (inc . dfsLastStateID)
    modify $ \st ->
      st { dfsStates      = snoc (dfsStates st) cell
         , dfsLastStateID = nextId
         }
    return (StateRef nextId)
-- | Fetch the type-erased 'IORef' behind a 'StateRef'.
--
-- The index is used without a bounds check; presumably this is sound
-- because references are only created by 'newState' -- confirm.
lookupStateRef :: StateRef s -> Dataflow (IORef Any)
lookupStateRef (StateRef (StateID sindex)) =
  Dataflow $ gets $ \st -> dfsStates st `unsafeIndex` sindex
-- | Read the current value of a state cell.
--
-- The stored value is coerced back to the type named by the 'StateRef'.
readState :: StateRef a -> Dataflow a
readState sref = do
  cell   <- lookupStateRef sref
  erased <- Dataflow $ liftIO (readIORef cell)
  return (unsafeCoerce erased)
-- | Overwrite the value of a state cell.
--
-- The new value is stored type-erased.
writeState :: StateRef a -> a -> Dataflow ()
writeState sref x =
  lookupStateRef sref >>= \cell ->
    Dataflow . liftIO $ writeIORef cell (unsafeCoerce x)
-- | Apply a function to the value of a state cell, using the strict
-- 'modifyIORef''.
modifyState :: StateRef a -> (a -> a) -> Dataflow ()
modifyState sref op = do
  cell <- lookupStateRef sref
  Dataflow . lift $ modifyIORef' cell erasedOp
  where
    -- The cell holds a type-erased value: recover the real type, apply the
    -- caller's function, then erase the result again.
    erasedOp :: Any -> Any
    erasedOp = unsafeCoerce . op . unsafeCoerce
{-# INLINEABLE input #-}
-- | Feed a batch of inputs into the graph.
--
-- A fresh 'Timestamp' is obtained from 'incrementEpoch', every element of
-- the batch is sent to the given edge under that timestamp, and 'finalize'
-- is then run for the same timestamp.
--
-- The container only needs to be 'Foldable', because the elements are
-- visited for their effects and no result structure is built. Every
-- 'Traversable' container is also 'Foldable', so existing callers are
-- unaffected.
input :: Foldable t => t i -> Edge i -> Dataflow ()
input inputs next = do
  timestamp <- Timestamp <$> incrementEpoch
  mapM_ (send next timestamp) inputs
  finalize timestamp
{-# INLINE send #-}
-- | Deliver one input to the vertex behind an edge.
--
-- A stateful vertex's callback additionally receives its own 'StateRef'.
send :: Edge input -> Timestamp -> input -> Dataflow ()
send e t i = do
  vertex <- lookupVertex e
  case vertex of
    StatefulVertex sref callback -> callback sref t i
    StatelessVertex callback     -> callback t i
-- | Run every registered finalizer with the given timestamp, in the order
-- they appear in 'dfsFinalizers'.
finalize :: Timestamp -> Dataflow ()
finalize t = Dataflow (gets dfsFinalizers) >>= mapM_ ($ t)