-- |
-- Operations on 'NodeCounts': building them from a single node ('node'),
-- from the edges leaving a single node ('nodeTargets'), propagating
-- existing counts along edges ('targets'), and converting them to a list
-- or an unboxed vector.
module Edges.NodeCounts
(
NodeCounts,
node,
nodeTargets,
targets,
toList,
toUnboxedVector,
)
where
import Edges.Prelude hiding (index, toList)
import Edges.Types
import Edges.Cereal.Instances ()
import qualified PrimitiveExtras.UnfoldM as A
import qualified PrimitiveExtras.Pure as C
import qualified PrimitiveExtras.IO as D
import qualified PrimitiveExtras.Fold as E
import qualified DeferredFolds.UnfoldM as B
import qualified Data.Vector.Unboxed as F
import qualified Control.Monad.Par.IO as Par
import qualified Control.Monad.Par as Par hiding (runParIO)
-- | Renders the counts exactly as the list produced by 'toList' is rendered.
instance Show (NodeCounts a) where
  show counts = show (toList counts)
-- | Counts for a single node: the result is sized by the outer length of
-- the edges' multi-array and built by 'nodeWithSize' for the given node.
node :: Edges entity anyEntity -> Node entity -> NodeCounts entity
node (Edges _ edgesPma) =
  nodeWithSize (C.primMultiArrayOuterLength edgesPma)
-- | Builds counts of the given size via @C.oneHotPrimArray@, passing the
-- node's index as the position and @1@ as the value.
-- Presumably every other slot is zero — confirm against PrimitiveExtras.
nodeWithSize :: Int -> Node entity -> NodeCounts entity
nodeWithSize size (Node position) =
  NodeCounts oneHot
  where
    oneHot = C.oneHotPrimArray size position 1
-- | Counts over the targets of a single source node.
--
-- The target indices stored in the edges under the node's index are fed
-- into the @E.indexCounts targetAmount@ fold. Presumably this yields, for
-- every target, how many edges lead to it from the node — confirm against
-- PrimitiveExtras.
--
-- The @source@ tag of the 'Node' is tied to the @source@ tag of the
-- 'Edges' (same as in 'targets'), so a node of an unrelated entity type
-- is rejected at compile time.
nodeTargets :: Edges source target -> Node source -> NodeCounts target
nodeTargets (Edges targetAmount edgesPma) (Node sourceIndex) =
  NodeCounts (B.fold countFold targetIndices)
  where
    -- Indices are converted with 'fromIntegral' to the type the fold consumes.
    targetIndices = fmap fromIntegral (A.primMultiArrayAt edgesPma sourceIndex)
    countFold = E.indexCounts targetAmount
-- | Propagates source counts along the edges to the targets.
--
-- A table of @targetAmount@ transactional counters is created with every
-- slot initialised to @0@. Every source index is then visited with
-- 'Par.parFor'; a source whose count is @0@ is skipped, otherwise its
-- count is added to the counter of each target index stored under that
-- source in the edges. Finally the table is frozen into a primitive array.
--
-- The computation runs in 'Par.runParIO' behind 'unsafePerformIO'.
-- NOTE(review): treating this as pure assumes the per-slot updates done by
-- @D.modifyTVarArrayAt@ are atomic, so the final sums do not depend on
-- scheduling — confirm against PrimitiveExtras.
targets :: Edges source target -> NodeCounts source -> NodeCounts target
targets (Edges targetAmount edgesPma) (NodeCounts sourceCountsPa) =
  unsafePerformIO (Par.runParIO accumulate)
  where
    -- Upper bound of the inclusive range; -1 (an empty range) when there
    -- are no sources.
    lastSourceIndex = pred (sizeofPrimArray sourceCountsPa)

    accumulate = do
      countVars <- liftIO (D.newTVarArray 0 targetAmount)
      Par.parFor (Par.InclusiveRange 0 lastSourceIndex) (propagate countVars)
      frozenCounts <- liftIO (D.freezeTVarArrayAsPrimArray countVars)
      return (NodeCounts frozenCounts)

    -- Adds the count of one source to every target it points at.
    propagate countVars sourceIndex =
      case indexPrimArray sourceCountsPa sourceIndex of
        0 -> return ()
        sourceCount ->
          liftIO $
            B.forM_ (A.primMultiArrayAt edgesPma sourceIndex) $ \targetIndex ->
              D.modifyTVarArrayAt countVars (fromIntegral targetIndex) (+ sourceCount)
-- | Lists the counts in index order, built with a strict right fold over
-- the underlying primitive array.
toList :: NodeCounts entity -> [Word32]
toList nodeCounts =
  case nodeCounts of
    NodeCounts counts -> foldrPrimArray' (:) [] counts
-- | Converts the counts to an unboxed vector via
-- @C.primArrayUnboxedVector@.
toUnboxedVector :: NodeCounts entity -> F.Vector Word32
toUnboxedVector nodeCounts =
  case nodeCounts of
    NodeCounts counts -> C.primArrayUnboxedVector counts