{-# OPTIONS_GHC -Wall -fno-warn-orphans -fno-warn-missing-signatures #-}
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}

-- | Distributed ST computations.
--
-- Computations of type 'DistST' are data-parallel computations which
-- are run on each thread of a gang. At the moment, they can only access the
-- element of a (possibly mutable) distributed value owned by the current
-- thread.
--
-- /TODO:/ Add facilities for implementing parallel scans etc.
module Data.Array.Parallel.Unlifted.Distributed.DistST
  ( DistST
  , stToDistST
  , distST_
  , distST
  , runDistST
  , runDistST_seq
  , traceDistST
  , myIndex
  , myD
  , readMyMD
  , writeMyMD
  ) where

import Data.Array.Parallel.Base (ST, runST)
import Data.Array.Parallel.Unlifted.Distributed.Gang
import Data.Array.Parallel.Unlifted.Distributed.Types (DT(..), Dist, MDist)

-- 'Applicative' is only in the Prelude from GHC 7.10 on; importing it
-- explicitly keeps the instances below compiling on older compilers too.
import Control.Applicative (Applicative(..))
import Control.Monad (liftM)

-- | Data-parallel computations.
--
-- When applied to a thread gang, the computation implicitly knows the index
-- of the thread it's working on. Alternatively, if we know the thread index
-- then we can make a regular ST computation.
newtype DistST s a = DistST { unDistST :: Int -> ST s a }
  -- ^ The wrapped function receives the index of the running thread.

-- 'Functor' and 'Applicative' are superclasses of 'Monad' since GHC 7.10
-- (Applicative-Monad Proposal); without these instances the 'Monad'
-- instance is rejected by the compiler.
instance Functor (DistST s) where
  {-# INLINE fmap #-}
  fmap f (DistST p) = DistST (liftM f . p)

instance Applicative (DistST s) where
  {-# INLINE pure #-}
  -- Ignores the thread index, exactly like the original 'return'.
  pure x = DistST (\_ -> return x)
  {-# INLINE (<*>) #-}
  -- Both sides see the same thread index; effects run left to right.
  DistST pf <*> DistST px = DistST $ \i -> do
    f <- pf i
    x <- px i
    return (f x)

instance Monad (DistST s) where
  {-# INLINE return #-}
  return = pure
  {-# INLINE (>>=) #-}
  -- Thread the same index through both the action and its continuation.
  DistST p >>= f = DistST $ \i -> do
    x <- p i
    unDistST (f x) i

-- | Yields the index of the current thread within its gang.
myIndex :: DistST s Int
myIndex = DistST return
{-# INLINE myIndex #-}

-- | Lifts an 'ST' computation into the 'DistST' monad.
-- The lifted computation should be data parallel.
-- The thread index is ignored.
stToDistST :: ST s a -> DistST s a
stToDistST p = DistST $ \_ -> p
{-# INLINE stToDistST #-}

-- | Yields the 'Dist' element owned by the current thread.
--
-- The @"myD"@ string passed to 'indexD' is presumably used to identify the
-- caller in 'indexD' error messages.
myD :: DT a => Dist a -> DistST s a
myD dt = liftM (indexD "myD" dt) myIndex
{-# NOINLINE myD #-}
-- | Reads the slot of a mutable distributed value ('MDist') that belongs to
-- the current thread, i.e. the slot at 'myIndex'.
readMyMD :: DT a => MDist a s -> DistST s a
readMyMD marr = myIndex >>= \me -> stToDistST (readMD marr me)
{-# NOINLINE readMyMD #-}

-- | Overwrites the slot of a mutable distributed value ('MDist') that belongs
-- to the current thread, i.e. the slot at 'myIndex'.
writeMyMD :: DT a => MDist a s -> a -> DistST s ()
writeMyMD marr v = myIndex >>= \me -> stToDistST (writeMD marr me v)
{-# NOINLINE writeMyMD #-}

-- | Execute a data-parallel computation on a 'Gang'.
--
-- The same 'DistST' computation is handed to 'gangST' for every thread of
-- the gang; 'gangST' presumably supplies each thread's own index.
distST_ :: Gang -> DistST s () -> ST s ()
distST_ gang = \job -> gangST gang (unDistST job)
{-# INLINE distST_ #-}

-- | Execute a data-parallel computation, yielding the distributed result.
--
-- A fresh 'MDist' is allocated for the gang, each thread stores its own
-- result into its own slot, and the value is then frozen with
-- 'unsafeFreezeMD'. The mutable value is local to this function and is not
-- touched after the freeze.
distST :: DT a => Gang -> DistST s a -> ST s (Dist a)
distST gang comp = do
  marr <- newMD gang
  distST_ gang (comp >>= writeMyMD marr)
  unsafeFreezeMD marr
{-# INLINE distST #-}

-- | Run a data-parallel computation on a gang and return its distributed
-- result as a pure value.
runDistST :: DT a => Gang -> (forall s. DistST s a) -> Dist a
runDistST gang comp = runST (distST gang comp)
{-# NOINLINE runDistST #-}

-- | Sequential counterpart of 'runDistST'.
--
-- Instead of dispatching the computation to the gang's threads, this runs
-- it once for every index @0 .. gangSize g - 1@, in ascending order, inside
-- a single 'runST', storing the result for index @i@ in slot @i@.
runDistST_seq :: forall a. DT a => Gang -> (forall s. DistST s a) -> Dist a
runDistST_seq g p = runST build
  where
    -- Gang size, forced up front.
    !n = gangSize g

    build :: forall s. ST s (Dist a)
    build = do
      md <- newMD g
      let loop i
            | i >= n    = return ()
            | otherwise = unDistST p i >>= writeMD md i >> loop (i + 1)
      loop 0
      unsafeFreezeMD md
{-# NOINLINE runDistST_seq #-}

-- | Emit a trace message through 'traceGangST', tagged with the index of
-- the current thread.
traceDistST :: String -> DistST s ()
traceDistST msg = DistST (\me -> traceGangST (concat ["Worker ", show me, ": ", msg]))