{-# LANGUAGE ExistentialQuantification #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RecordWildCards #-}

module Dataflow.Primitives (
  Dataflow(..),
  DataflowState,
  Vertex(..),
  initDataflowState,
  duplicateDataflowState,
  StateRef,
  newState,
  readState,
  writeState,
  modifyState,
  Edge,
  Timestamp(..),
  registerVertex,
  registerFinalizer,
  incrementEpoch,
  input,
  send,
  finalize
) where

import Control.Arrow ((>>>))
import Control.Monad (forM, (>=>))
import Control.Monad.IO.Class (liftIO)
import Control.Monad.State.Strict (StateT, get, gets, modify)
import Control.Monad.Trans (lift)
import Data.Hashable (Hashable (..))
import Data.IORef (IORef, atomicModifyIORef', atomicWriteIORef, newIORef, readIORef)
import Data.Vector (Vector, empty, snoc, (!))
import Numeric.Natural (Natural)
import Prelude
import Unsafe.Coerce (unsafeCoerce)

-- | Position of a vertex inside 'dfsVertices'.
newtype VertexID = VertexID Int deriving (Eq, Ord, Show)

-- | Position of a state cell inside 'dfsStates'.
newtype StateID = StateID Int deriving (Eq, Ord, Show)

-- | Counter that is bumped once per call to 'incrementEpoch'.
newtype Epoch = Epoch Natural deriving (Eq, Ord, Hashable, Show)

-- | 'Timestamp's represent instants in the causal timeline.
--
-- @since 0.1.0.0
newtype Timestamp = Timestamp Epoch deriving (Eq, Ord, Hashable, Show)

-- | An 'Edge' is a typed reference to a computational vertex that
-- takes 'a's as its input.
--
-- @since 0.1.0.0
newtype Edge a = Edge VertexID

-- | Class of entities that can be incremented by one.
class Incrementable a where
  inc :: a -> a

instance Incrementable VertexID where
  inc (VertexID n) = VertexID (n + 1)

instance Incrementable StateID where
  inc (StateID n) = StateID (n + 1)

instance Incrementable Epoch where
  inc (Epoch n) = Epoch (n + 1)

-- | 'ErasedType' erases the type it wraps.
data ErasedType = forall i. EraseType i

-- | Recover a value from an 'ErasedType' at whatever type the caller asks for.
--
-- This is an unchecked cast: it is only sound when the requested type is the
-- type that was originally wrapped. In this module that type is carried by
-- the phantom parameter of 'Edge' and 'StateRef', whose constructors are not
-- exported, so a reference obtained from 'registerVertex' or 'newState' is
-- read back at the type it was stored with.
-- NOTE(review): nothing ties a reference to the 'DataflowState' that issued
-- it; using one against an unrelated state would defeat this invariant.
unEraseType :: ErasedType -> a
unEraseType (EraseType x) = unsafeCoerce x

-- | Everything a running dataflow needs to keep track of.
data DataflowState = DataflowState {
  -- | Registered vertices; a 'VertexID' is an index into this vector.
  dfsVertices       :: Vector ErasedType,
  -- | Mutable state cells; a 'StateID' is an index into this vector.
  dfsStates         :: Vector (IORef ErasedType),
  -- | Callbacks run by 'finalize', most recently registered first.
  dfsFinalizers     :: [Timestamp -> Dataflow ()],
  -- | The last 'VertexID' handed out (@-1@ before any registration).
  dfsLastVertexID   :: VertexID,
  -- | The last 'StateID' handed out (@-1@ before any allocation).
  dfsLastStateID    :: StateID,
  -- | The last 'Epoch' returned by 'incrementEpoch'.
  dfsLastInputEpoch :: Epoch
}

-- | `Dataflow` is the type of all dataflow operations.
--
-- @since 0.1.0.0
newtype Dataflow a = Dataflow { runDataflow :: StateT DataflowState IO a }
  deriving (Functor, Applicative, Monad)

-- | A state with no vertices, no state cells and no finalizers.
--
-- The ID counters start at @-1@ so that the first ID handed out is @0@,
-- matching the index of the first element appended to the vectors.
initDataflowState :: DataflowState
initDataflowState = DataflowState {
  dfsVertices       = empty,
  dfsStates         = empty,
  dfsFinalizers     = [],
  dfsLastVertexID   = VertexID (-1),
  dfsLastStateID    = StateID (-1),
  dfsLastInputEpoch = Epoch 0
}

-- | Produce a copy of the current 'DataflowState' whose state cells are
-- fresh 'IORef's holding the current contents of the originals.
--
-- All other fields are shared with the current state. The current state
-- itself is left untouched.
duplicateDataflowState :: Dataflow DataflowState
duplicateDataflowState = Dataflow $ do
  current <- get
  copies  <- liftIO $ forM (dfsStates current) dupIORef
  return current { dfsStates = copies }
  where
    -- Read the current content and put it into a brand new reference.
    dupIORef = readIORef >=> newIORef

-- | Get the next input Epoch.
incrementEpoch :: Dataflow Epoch
incrementEpoch = Dataflow $ do
  epoch <- gets (dfsLastInputEpoch >>> inc)
  -- Evaluate the new epoch before storing it. Left lazy, the field would hold
  -- an @inc ...@ thunk that keeps the previous state record alive, and the
  -- chain would grow with every call whose result is never inspected.
  epoch `seq` modify (\s -> s { dfsLastInputEpoch = epoch })
  return epoch

-- | A unit of computation that consumes inputs of type @i@.
--
-- A stateful vertex carries the 'StateRef' that is handed to its callback on
-- every invocation; a stateless vertex is just a callback.
data Vertex i = forall s. StatefulVertex (StateRef s) (StateRef s -> Timestamp -> i -> Dataflow ())
              | StatelessVertex (Timestamp -> i -> Dataflow ())

-- | Retrieve the vertex for a given edge.
--
-- The lookup is an unchecked index into 'dfsVertices'; an ID that is out of
-- range for the current state raises the vector's indexing error.
lookupVertex :: Edge i -> Dataflow (Vertex i)
lookupVertex (Edge (VertexID vindex)) = Dataflow $ do
  vertices <- gets dfsVertices
  return $ unEraseType (vertices ! vindex)

-- | Store a provided vertex and obtain an 'Edge' that refers to it.
registerVertex :: Vertex i -> Dataflow (Edge i)
registerVertex vertex = Dataflow $ do
  vid <- gets (dfsLastVertexID >>> inc)
  modify $ addVertex vertex vid
  return (Edge vid)
  where
    -- Append the vertex and record the ID that now refers to it.
    addVertex vtx vid s = s {
      dfsVertices     = dfsVertices s `snoc` EraseType vtx,
      dfsLastVertexID = vid
    }

-- | Store a provided finalizer.
registerFinalizer :: (Timestamp -> Dataflow ()) -> Dataflow ()
registerFinalizer finalizer = Dataflow $ modify prepend
  where
    -- The newest finalizer goes to the front of the list.
    prepend s = s { dfsFinalizers = finalizer : dfsFinalizers s }

-- | Mutable state that holds an `a`.
--
-- @since 0.1.0.0
newtype StateRef a = StateRef StateID

-- | Create a `StateRef` initialized to the provided `a`.
--
-- @since 0.1.0.0
newState :: a -> Dataflow (StateRef a)
newState a = Dataflow $ do
  sid   <- gets (dfsLastStateID >>> inc)
  ioref <- lift $ newIORef (EraseType a)
  modify $ addState ioref sid
  return (StateRef sid)
  where
    -- Append the new cell and record the ID that now refers to it.
    addState ref sid s = s {
      dfsStates      = dfsStates s `snoc` ref,
      dfsLastStateID = sid
    }

-- | Find the 'IORef' backing a `StateRef`.
--
-- The lookup is an unchecked index into 'dfsStates'; an ID that is out of
-- range for the current state raises the vector's indexing error.
lookupStateRef :: StateRef s -> Dataflow (IORef ErasedType)
lookupStateRef (StateRef (StateID sindex)) = Dataflow $ do
  states <- gets dfsStates
  return (states ! sindex)

-- | Read the value stored in the `StateRef`.
--
-- @since 0.1.0.0
readState :: StateRef a -> Dataflow a
readState sref = do
  ioref <- lookupStateRef sref
  Dataflow $ lift (unEraseType <$> readIORef ioref)

-- | Overwrite the value stored in the `StateRef`.
--
-- @since 0.1.0.0
writeState :: StateRef a -> a -> Dataflow ()
writeState sref x = do
  ioref <- lookupStateRef sref
  Dataflow $ lift $ atomicWriteIORef ioref (EraseType x)

-- | Update the value stored in `StateRef`.
--
-- The updated value is evaluated to weak head normal form before it is
-- stored, so repeated updates do not pile up unevaluated applications of
-- the update function inside the reference.
--
-- @since 0.1.0.0
modifyState :: StateRef a -> (a -> a) -> Dataflow ()
modifyState sref op = do
  ioref <- lookupStateRef sref
  Dataflow $ lift $ atomicModifyIORef' ioref update
  where
    -- 'atomicModifyIORef'' only forces the stored value to WHNF, and an
    -- 'EraseType' constructor is already in WHNF no matter what it wraps.
    -- Force the payload explicitly so the strict variant does its job.
    update old =
      let new = op (unEraseType old)
      in new `seq` (EraseType new, ())

{-# INLINEABLE input #-}
-- | Feed a collection of items to the vertex behind the given edge.
--
-- A fresh 'Timestamp' is taken from 'incrementEpoch', every item is sent
-- with that timestamp in traversal order, and the timestamp is then
-- finalized.
input :: Traversable t => t i -> Edge i -> Dataflow ()
input inputs next = do
  timestamp <- Timestamp <$> incrementEpoch
  mapM_ (send next timestamp) inputs
  finalize timestamp

{-# INLINE send #-}
-- | Send an `input` item to be worked on to the indicated vertex.
--
-- @since 0.1.0.0
send :: Edge input -> Timestamp -> input -> Dataflow ()
send e t i = lookupVertex e >>= \vertex ->
  case vertex of
    -- A stateful vertex is handed its own state reference on every call.
    StatefulVertex sref callback -> callback sref t i
    StatelessVertex callback     -> callback t i

-- | Notify all relevant vertices that no more input is coming for `Timestamp`.
--
-- Every registered finalizer is run with the given timestamp, in the order
-- in which they are stored in the dataflow state.
--
-- @since 0.1.0.0
finalize :: Timestamp -> Dataflow ()
finalize t = do
  finalizers <- Dataflow $ gets dfsFinalizers
  mapM_ ($ t) finalizers