-- | Occurrence counts per node of an 'Edges' graph.
--
-- A 'NodeCounts' wraps an array of 'Word32' counters indexed by node (see
-- 'toList' and 'toUnboxedVector'). This module builds counts for a single
-- node and propagates counts from source nodes to target nodes along edges.
module Edges.NodeCounts
  ( NodeCounts,
    node,
    nodeTargets,
    targets,
    toList,
    toUnboxedVector,
  )
where

import Edges.Prelude hiding (index, toList)
import Edges.Types
import Edges.Cereal.Instances ()
import qualified PrimitiveExtras.UnfoldM as A
import qualified PrimitiveExtras.Pure as C
import qualified PrimitiveExtras.IO as D
import qualified PrimitiveExtras.Fold as E
import qualified DeferredFolds.UnfoldM as B
import qualified Data.Vector.Unboxed as F
import qualified Control.Monad.Par.IO as Par
import qualified Control.Monad.Par as Par hiding (runParIO)

-- | Shows the counts as a plain list of 'Word32's, in array-index order.
instance Show (NodeCounts a) where
  show = show . toList

-- | Counts in which only the given node occurs, with a count of @1@.
--
-- The result has one slot per row of the edges' multi-array
-- ('C.primMultiArrayOuterLength'). It is built by 'C.oneHotPrimArray', so it
-- is presumably @0@ at every index other than the node's.
node :: Edges entity anyEntity -> Node entity -> NodeCounts entity
node (Edges _ edgesPma) =
  let size = C.primMultiArrayOuterLength edgesPma
   in nodeWithSize size

-- | One-hot counts of length @size@ with a count of @1@ at the node's index.
nodeWithSize :: Int -> Node entity -> NodeCounts entity
nodeWithSize size (Node index) = NodeCounts (C.oneHotPrimArray size index 1)

-- | How many times each target occurs among the edges of a single source node.
--
-- The source's row of the edges' multi-array is folded into @targetAmount@
-- counters by 'E.indexCounts'. A target listed several times in that row is
-- presumably counted once per listing.
--
-- The 'Edges' value and the 'Node' share the @source@ type variable. This is
-- deliberate: it stops a node of an unrelated entity type from being used to
-- index this edge set. An earlier version misspelled it as @surce@, which left
-- the two types unconnected.
nodeTargets :: Edges source target -> Node source -> NodeCounts target
nodeTargets (Edges targetAmount edgesPma) (Node sourceIndex) =
  let indexUnfold = fmap fromIntegral (A.primMultiArrayAt edgesPma sourceIndex)
      indexFold = E.indexCounts targetAmount
      countPa = B.fold indexFold indexUnfold
   in NodeCounts countPa

{-|
Count the occurrences of targets based on the occurrences of sources.
Utilizes concurrency.

For every source whose count is non-zero, that count is added to the counter
of each target listed in the source's row of the edges' multi-array. Sources
are processed in parallel with 'Par.parFor'. The per-target counters live in
an array of TVars created with 'D.newTVarArray'. That array is frozen into the
result once 'Par.parFor' has returned.

The counters are 'Word32', so a total above 'maxBound' wraps around silently.
-}
targets :: Edges source target -> NodeCounts source -> NodeCounts target
targets (Edges targetAmount edgesPma) (NodeCounts sourceCountsPa) =
  -- NOTE(review): 'unsafePerformIO' is presumably sound here, for two reasons.
  -- The action only mutates a TVar array that it allocates itself, and the
  -- frozen result depends only on the arguments. Confirm that
  -- 'D.modifyTVarArrayAt' updates atomically, so that concurrent additions to
  -- the same target from different sources are not lost.
  unsafePerformIO $ Par.runParIO $ do
    targetCountVarTable <- liftIO (D.newTVarArray 0 targetAmount)
    let lastSourceIndex = pred (sizeofPrimArray sourceCountsPa)
    Par.parFor (Par.InclusiveRange 0 lastSourceIndex) $ \sourceIndex ->
      case indexPrimArray sourceCountsPa sourceIndex of
        -- A source that never occurred contributes nothing; skip its row.
        0 -> return ()
        sourceCount ->
          liftIO $
            B.forM_ (A.primMultiArrayAt edgesPma sourceIndex) $ \targetIndex ->
              D.modifyTVarArrayAt
                targetCountVarTable
                (fromIntegral targetIndex)
                (+ sourceCount)
    targetCountsPa <- liftIO (D.freezeTVarArrayAsPrimArray targetCountVarTable)
    return (NodeCounts targetCountsPa)

-- | The counts as a list, in array-index order.
toList :: NodeCounts entity -> [Word32]
toList (NodeCounts pa) = foldrPrimArray' (:) [] pa

-- | The counts as an unboxed vector, converted by 'C.primArrayUnboxedVector'.
toUnboxedVector :: NodeCounts entity -> F.Vector Word32
toUnboxedVector (NodeCounts pa) = C.primArrayUnboxedVector pa