{-# LANGUAGE FunctionalDependencies #-}
module Data.Vector.Algorithms.Quicksort.Fork2
(
Fork2(..)
, Sequential(..)
, Parallel
, mkParallel
, waitParallel
, ParStrategies
, defaultParStrategies
, setParStrategiesCutoff
, HasLength
, getLength
) where
import GHC.Conc (par, pseq)
import Control.Concurrent
import Control.Concurrent.STM
import Control.Monad.ST
import Data.Bits
import Data.Vector.Generic.Mutable qualified as GM
import GHC.ST (unsafeInterleaveST)
import System.IO.Unsafe
-- | A strategy for running two independent pieces of work.
--
-- @strategy@ selects the behaviour. @token@ is the per-branch value that the
-- strategy hands to each piece of work; the functional dependency makes it
-- determined by @strategy@. @m@ is the monad the work runs in.
class Fork2 strategy token m | strategy -> token where
  -- | Obtain the token for the top-level piece of work.
  startWork :: strategy -> m token

  -- | Signal that the work holding the given token has finished.
  endWork :: strategy -> token -> m ()

  -- | Run one continuation on the first input and another on the second.
  -- Each continuation receives the token it should use for its own nested work.
  fork2
    :: (HasLength b, HasLength d)
    => strategy              -- ^ strategy
    -> token                 -- ^ token of the current branch
    -> Int                   -- ^ recursion depth (used by some strategies to limit forking)
    -> (token -> b -> m ())  -- ^ work for the first input
    -> (token -> d -> m ())  -- ^ work for the second input
    -> b                     -- ^ first input
    -> d                     -- ^ second input
    -> m ()
-- | A strategy that never forks: both pieces of work run one after the other.
data Sequential = Sequential

-- | Sequential strategy for any 'Monad'; the token is @()@ and carries no state.
instance Monad m => Fork2 Sequential () m where
  {-# INLINE startWork #-}
  {-# INLINE endWork #-}
  {-# INLINE fork2 #-}

  startWork :: Sequential -> m ()
  startWork = \_ -> pure ()

  endWork :: Sequential -> () -> m ()
  endWork = \_ _ -> pure ()

  fork2 :: forall b d.
    (HasLength b, HasLength d) =>
    Sequential
    -> ()
    -> Int
    -> (() -> b -> m ())
    -> (() -> d -> m ())
    -> b
    -> d
    -> m ()
  fork2 _ token _ runFirst runSecond !first !second =
    -- First input strictly before the second, both with the same unit token.
    runFirst token first *> runSecond token second
-- | State of the fork-based parallel strategy: a job limit (compared against
-- the recursion depth by 'fork2') and a transactional counter of pending jobs.
data Parallel = Parallel !Int !(TVar Int)

-- | Build a 'Parallel' strategy with the given job limit and a pending-job
-- counter that starts at zero.
mkParallel :: Int -> IO Parallel
mkParallel jobs = do
  counter <- newTVarIO 0
  pure (Parallel jobs counter)
-- | Atomically increment the pending-job counter of a 'Parallel' strategy.
addPending :: Parallel -> IO ()
addPending (Parallel _ counter) = atomically $ do
  current <- readTVar counter
  -- Strict write so the counter never accumulates a thunk chain.
  writeTVar counter $! current + 1
-- | Atomically decrement the pending-job counter of a 'Parallel' strategy.
removePending :: Parallel -> IO ()
removePending (Parallel _ counter) =
  atomically (modifyTVar' counter (subtract 1))
-- | Block until the pending-job counter of a 'Parallel' strategy reaches zero.
-- The transaction retries, so the thread only wakes when the counter changes.
waitParallel :: Parallel -> IO ()
waitParallel (Parallel _ counter) =
  atomically $ readTVar counter >>= check . (== 0)
-- | Fork-based parallelism with an explicit pending-job counter.
--
-- The token is @(isSeq, shouldDecrement)@:
--
-- * @isSeq@: when 'True', 'fork2' runs both halves in the current thread, and
--   the sub-work also receives sequential tokens.
-- * @shouldDecrement@: when 'True', 'endWork' on this token calls
--   'removePending', releasing the unit that 'startWork' or a fork added.
instance Fork2 Parallel (Bool, Bool) IO where
  {-# INLINE startWork #-}
  {-# INLINE endWork #-}
  {-# INLINE fork2 #-}

  startWork :: Parallel -> IO (Bool, Bool)
  startWork !p = do
    addPending p
    pure (False, True)

  endWork :: Parallel -> (Bool, Bool) -> IO ()
  endWork p (_, shouldDecrement)
    | shouldDecrement = removePending p
    | otherwise = pure ()

  -- | Fork the first half into a new thread when the depth budget and the
  -- size of both halves allow it. Otherwise run both halves in this thread,
  -- handing the sequential token to whichever half runs last.
  fork2
    :: forall b d. (HasLength b, HasLength d)
    => Parallel
    -> (Bool, Bool)
    -> Int
    -> ((Bool, Bool) -> b -> IO ())
    -> ((Bool, Bool) -> d -> IO ())
    -> b
    -> d
    -> IO ()
  fork2 !p@(Parallel jobs _) tok@(!isSeq, shouldDecrement) !depth f g !b !d
    | isSeq
    = f (True, False) b *> g tok d
    | withinJobBudget && mn > 10_000
    = do
      -- The forked thread owns the new pending unit (shouldDecrement = True).
      -- NOTE(review): if @f@ throws before reaching 'endWork', that unit is
      -- never released and 'waitParallel' will block; consider handling this.
      addPending p
      _ <- forkIO $ f (False, True) b
      g tok d
    | bLen > dLen
    = f (False, False) b *> g (True, shouldDecrement) d
    | otherwise
    = g (False, False) d *> f (True, shouldDecrement) b
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d
      !mn :: Int
      !mn = min bLen dLen

      -- Tests 2^(depth+1) < jobs. The depth check keeps the shift inside the
      -- range where it is defined and cannot overflow Int. Without it, deep
      -- recursion (depth >= 62 on 64-bit) yields 0 or a negative number and
      -- the test spuriously succeeds.
      withinJobBudget :: Bool
      withinJobBudget =
        depth >= 0
          && depth < finiteBitSize depth - 2
          && 2 `unsafeShiftL` depth < jobs
-- | Spark-based parallel strategy. The field is the size cutoff: 'fork2'
-- sparks only when the smaller of the two inputs is longer than it.
data ParStrategies = ParStrategies !Int

-- | 'ParStrategies' with a cutoff of 10,000 elements.
defaultParStrategies :: ParStrategies
defaultParStrategies = ParStrategies 10_000

-- | Replace the cutoff of a 'ParStrategies' value.
setParStrategiesCutoff :: Int -> ParStrategies -> ParStrategies
setParStrategiesCutoff newCutoff = const (ParStrategies newCutoff)
-- | Spark-based parallelism in 'IO'. When both inputs are longer than the
-- cutoff, the work on the first input is wrapped in 'unsafePerformIO' and
-- sparked with 'par' while the current thread runs the second. The token is
-- @()@ and carries no state.
instance Fork2 ParStrategies () IO where
  {-# INLINE startWork #-}
  {-# INLINE endWork #-}
  {-# INLINE fork2 #-}

  startWork :: ParStrategies -> IO ()
  startWork _ = pure ()

  endWork :: ParStrategies -> () -> IO ()
  endWork _ _ = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => ParStrategies
    -> ()
    -> Int
    -> (() -> b -> IO ())
    -> (() -> d -> IO ())
    -> b
    -> d
    -> IO ()
  fork2 !(ParStrategies cutoff) _ _ f g !b !d
    | mn > cutoff
    = do
      -- Deferred work on @b@, eligible to be picked up by another capability.
      let b' :: ()
          b' = unsafePerformIO $ f () b
      d' <- b' `par` g () d
      -- Force the sparked result before this action completes. Previously it
      -- was only referenced from a returned, never-forced thunk, so the work
      -- on @b@ could still be running, or be dropped if the spark fizzled,
      -- when the caller resumed.
      b' `pseq` d' `pseq` pure ()
    | otherwise
    = do
      b' <- f () b
      d' <- g () d
      b' `pseq` d' `pseq` pure ()
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d
      !mn :: Int
      !mn = min bLen dLen
-- | Spark-based parallelism in 'ST'. When both inputs are longer than the
-- cutoff, the work on the first input is deferred with 'unsafeInterleaveST'
-- and sparked with 'par' while the current computation runs the second. The
-- token is @()@ and carries no state.
instance Fork2 ParStrategies () (ST s) where
  {-# INLINE startWork #-}
  {-# INLINE endWork #-}
  {-# INLINE fork2 #-}

  startWork :: ParStrategies -> ST s ()
  startWork _ = pure ()

  endWork :: ParStrategies -> () -> ST s ()
  endWork _ _ = pure ()

  fork2
    :: forall b d. (HasLength b, HasLength d)
    => ParStrategies
    -> ()
    -> Int
    -> (() -> b -> ST s ())
    -> (() -> d -> ST s ())
    -> b
    -> d
    -> ST s ()
  fork2 !(ParStrategies cutoff) _ _ f g !b !d
    | mn > cutoff
    = do
      -- Deferred work on @b@, eligible to be picked up by another capability.
      b' <- unsafeInterleaveST $ f () b
      d' <- b' `par` g () d
      -- Force the deferred result inside this ST computation. Previously it
      -- was only referenced from a returned, never-forced thunk, so the
      -- effects on @b@ could be skipped or still in flight when the caller
      -- continued.
      b' `pseq` d' `pseq` pure ()
    | otherwise
    = do
      b' <- f () b
      d' <- g () d
      b' `pseq` d' `pseq` pure ()
    where
      bLen, dLen :: Int
      !bLen = getLength b
      !dLen = getLength d
      !mn :: Int
      !mn = min bLen dLen
-- | Containers that can report how many elements they hold.
class HasLength t where
  -- | Element count of the container.
  getLength :: t -> Int

-- | Generic mutable vectors report their length via 'GM.length'.
instance GM.MVector vec e => HasLength (vec st e) where
  {-# INLINE getLength #-}
  getLength :: vec st e -> Int
  getLength vector = GM.length vector