-- |
-- Module    : G500.Generate
-- Copyright : (C) 2013 Parallel Scientific Labs, LLC.
-- License   : GPLv2
--
-- A generator for the Graph500 benchmark, translated from the GNU Octave reference code in the Graph500 specification.

{-# LANGUAGE BangPatterns #-}

module G500.Generate
    ( generate
      -- $spec
    ) where

import Control.Concurrent
import Control.Exception (throwIO)
import Control.Monad
import Control.Monad.State
import Data.Array.IO
import Data.Bits
import System.Random.Mersenne.Pure64

import G500.Index

-------------------------------------------------------------------------------
-- Documentation.

{- $spec

@
%% Original specification in GNU Octave, annotated by sergueyz wherever the original seemed too terse.
%%
%% Octave expression syntax is described in the GNU Octave manual: <http://sunsite.univie.ac.at/textbooks/octave/octave_9.html>

function ij = kronecker_generator (SCALE, edgefactor)
%% Generate an edgelist according to the Graph500
%% parameters.  In this sample, the edge list is
%% returned in an array with two rows, where StartVertex
%% is first row and EndVertex is the second.  The vertex
%% labels start at zero.
%%
%% Example, creating a sparse matrix for viewing:
%%   ij = kronecker_generator (10, 16);
%%   G = sparse (ij(1,:)+1, ij(2,:)+1, ones (1, size (ij, 2)));
%%   spy (G);
%% The spy plot should appear fairly dense. Any locality
%% is removed by the final permutations.

  %% Set number of vertices.
  N = 2^SCALE;

  %% Set number of edges.
  M = edgefactor * N;

  %% Set initiator probabilities.
  [A, B, C] = deal (0.57, 0.19, 0.19); %% Just a tuple assignment.

  %% Create index arrays.
  ij = ones (2, M); %% A 2-by-M matrix of ones.

  %% Probabilities.
  ab = A + B;
  c_norm = C/(1 - (A + B));
  a_norm = A/(A + B);

  %% Loop over each order of bit.
  for ib = 1:SCALE,
    %% Compare with probabilities and set bits of indices.
    ii_bit = rand (1, M) > ab; %% either 0 or 1
    jj_bit = rand (1, M) > ( c_norm * ii_bit + a_norm * not (ii_bit) ); %% either 0 or 1.

    %% Note that ij is one-based: each iteration adds the current power of two to the entries of ij,
    %% so after the loop every entry of ij lies in the range 1..2^SCALE.
    ij = ij + 2^(ib-1) * [ii_bit; jj_bit];
  end

  %% Permute vertex labels
  p = randperm (N); %% A random permutation of the numbers 1..N.
  ij = p(ij); %% Indexing a vector by a matrix yields a matrix of the same shape,
              %% so this sets ij(a,b) = p(ij(a,b)) for every entry.

  %% Permute the edge list
  p = randperm (M);
  ij = ij(:, p); %% Reorders columns: ij(a,b) = ij(a,p(b)). Both rows of a column move together,
                 %% so each StartVertex stays paired with its EndVertex.

  %% Adjust to zero-based labels.
  ij = ij - 1;
@
-}

-------------------------------------------------------------------------------
-- Monad definition and helpers.

-- |Generation monad: a 'PureMT' random generator threaded as state over 'IO'.
type GenM = StateT PureMT IO

a, b, c, ab, c_norm, a_norm :: Float
-- Initiator probabilities A, B and C from the specification.
a = 0.57
b = 0.19
c = 0.19
-- A + B: a start-vertex bit is 1 only when a uniform draw exceeds this.
ab = a + b
-- C / (1 - (A + B)): end-bit threshold used when the start bit is 1.
c_norm = c / (1 - ab)
-- A / (A + B): end-bit threshold used when the start bit is 0.
a_norm = a / ab

-- |Draws one random index using the bit mask @indexMask@, which should be a
-- power of two minus 1. The result is presumably in @0 .. indexMask@
-- (TODO confirm against 'randomIndex', which does the actual drawing).
genRandomIndexMask :: Index -> GenM Index
genRandomIndexMask indexMask = get >>= \gen ->
  let (!idx, !gen') = randomIndex indexMask gen
  in put gen' >> return idx

-- |Draws a random index in the inclusive range @0 .. maxIndex@.
--
-- Rejection sampling: candidates come from 'genRandomIndexMask' with the
-- smallest mask of the form @2^k - 1@ that is @>= maxIndex@. A candidate is
-- redrawn until it falls in range. Assuming 'genRandomIndexMask' is uniform
-- over @0 .. mask@, each draw is accepted with probability above 1/2.
--
-- Fails with 'error' for a negative @maxIndex@, for which no index exists.
genRandomIndex :: Index -> GenM Index
genRandomIndex maxIndex
  | maxIndex < 0 = error "G500.Generate.genRandomIndex: negative maxIndex"
  | otherwise = draw
  where
    -- Smallest value 2^k - 1 that is >= maxIndex (0 when maxIndex == 0).
    indexMask = widen 1
    widen n
      | n > maxIndex = n - 1
      | otherwise = widen (2 * n)
    -- `<=` keeps maxIndex itself reachable. With maxIndex == 0 the mask is 0,
    -- so the first candidate is accepted instead of looping forever.
    draw = do
      v <- genRandomIndexMask indexMask
      if v <= maxIndex then return v else draw

-- |A uniformly distributed random 'Float' in the half-open interval [0, 1).
--
-- Uses the low 24 bits of one 64-bit draw. A 'Float' mantissa holds exactly
-- 24 bits, so both the conversion and the division by 2^24 are exact. The
-- result therefore never rounds up to 1.0, which a wider mask would allow.
getRandomFloat :: GenM Float
getRandomFloat = do
  g <- get
  let (!i, !g') = randomInt64 g
  put g'
  return $ fromIntegral (i .&. mask24) / fromIntegral (mask24 + 1)
  where
    mask24 = 0xffffff

-- |Core of the data set generation: draws the pair of bits (each 0 or 1)
-- that the current power-of-two weight contributes to an edge's start and
-- end index.
--
-- The start bit is 1 when the first draw exceeds 'ab'. The end bit is 1 when
-- the second draw exceeds 'c_norm' (start bit 1) or 'a_norm' (start bit 0).
-- This mirrors the Octave threshold
-- @c_norm * ii_bit + a_norm * not (ii_bit)@.
ii_jj_bits :: GenM (Index, Index)
ii_jj_bits = liftM2 pickBits getRandomFloat getRandomFloat
  where
    pickBits iiDraw jjDraw =
      let startSet = iiDraw > ab
          endThreshold
            | startSet = c_norm
            | otherwise = a_norm
      in (bitOf startSet, bitOf (jjDraw > endThreshold))
    bitOf :: Bool -> Index
    bitOf = fromIntegral . fromEnum

-------------------------------------------------------------------------------
-- Main driver.

-- |Mutable unboxed array of 'Index' values, indexed by 'Index'. Used for the
-- edge start/end arrays returned by 'generate' and for permutation tables.
type GraphArr = IOUArray Index Index

-- |Generates a Graph500 Kronecker edge list, following the Octave
-- specification quoted in this module's documentation.
--
-- Returns @(start, end)@: two arrays indexed @0 .. edgeFactor * 2^scale - 1@,
-- where edge @i@ goes from @start[i]@ to @end[i]@. Vertex labels are
-- zero-based, in @0 .. 2^scale - 1@.
generate :: Int -- ^ Scale
         -> Int -- ^ Edge Factor
         -> IO (GraphArr, GraphArr)
generate scale edgeFactor = do
  start <- newIndexArr
  end <- newIndexArr
  std <- newPureMT
  runGenM std $ go start end
  return (start, end)
  where
    runGenM g genM = runStateT genM g >> return ()
    newIndexArr = liftIO $ newArray (0, maxEdgeIndex) 0
    n = shiftL 1 scale
    m = n * fromIntegral edgeFactor
    maxEdgeIndex = m - 1
    maxIndex = n - 1

    incrIndex arr i incr = do
      v <- readArray arr i
      writeArray arr i (v + incr)

    -- Exchanges the elements at positions i and j of arr.
    swapElems arr i j = do
      x <- readArray arr i
      y <- readArray arr j
      writeArray arr i y
      writeArray arr j x

    go start end = do
      -- One pass per bit weight, 2^(scale-1) down to 2^0. The list is empty
      -- for scale == 0, so no negative shift is ever attempted.
      forM_ [scale - 1, scale - 2 .. 0] $ \bitNo ->
        parallelGeneration (shiftL 1 bitNo) start end
      -- Relabel vertices with one random permutation of 0..maxIndex.
      p <- permutation maxIndex
      permute maxEdgeIndex p start
      permute maxEdgeIndex p end
      -- Shuffle the edge list. The same swap sequence is applied to both
      -- arrays, so every start vertex stays paired with its end vertex.
      p1 <- permutation maxEdgeIndex
      permuteIndices maxEdgeIndex p1 start
      permuteIndices maxEdgeIndex p1 end

    -- A random permutation of 0..maxIndex': start from the identity, then
    -- swap every position i with a position drawn by genRandomIndex.
    -- NOTE(review): swapping each i with an arbitrary j is not a uniform
    -- shuffle. Fisher-Yates (i descending, j drawn from 0..i) would be.
    permutation :: Index -> GenM GraphArr
    permutation maxIndex' = do
      p <- liftIO $ newListArray (0, maxIndex') [0 .. maxIndex']
      forM_ [0 .. maxIndex'] $ \i -> do
        j <- genRandomIndex maxIndex'
        liftIO $ swapElems p i j
      return p

    -- Replaces every arr[i], for i in 0..n1, with p[arr[i]].
    permute :: Index -> GraphArr -> GraphArr -> GenM ()
    permute n1 p arr = liftIO $ forM_ [0 .. n1] $ \i ->
      readArray arr i >>= readArray p >>= writeArray arr i

    -- For i = 0..n1 in order, swaps arr[i] with arr[p[i]]. The rearrangement
    -- depends only on p, so equal p rearranges any two arrays identically.
    permuteIndices :: Index -> GraphArr -> GraphArr -> GenM ()
    permuteIndices n1 p arr = liftIO $ forM_ [0 .. n1] $ \i ->
      readArray p i >>= swapElems arr i

    -- Edges handled by each worker thread.
    parallelPortion :: Index
    parallelPortion = fromIntegral edgeFactor * 1024

    -- First edge index of each worker chunk. The last chunk may be shorter,
    -- so edge lists smaller than parallelPortion are still generated.
    chunkStarts
      | parallelPortion > 0 = [0, parallelPortion .. maxEdgeIndex]
      | otherwise = []

    -- Adds pow2 times a random bit to both endpoints of every edge, with one
    -- forked worker per chunk. Blocks until every worker has finished, then
    -- rethrows the first worker failure (in chunk order) instead of hanging.
    parallelGeneration pow2 start end = do
      dones <- forM chunkStarts $ \startI -> do
        -- Seed each worker from the main generator instead of a separate
        -- clock-seeded newPureMT per worker, so the seeds do not depend on
        -- how quickly the workers are forked.
        seed <- state randomWord64
        let endI = min maxEdgeIndex (startI + parallelPortion - 1)
        liftIO $ do
          done <- newEmptyMVar
          _ <- forkFinally
            (runGenM (pureMT seed) $ forM_ [startI .. endI] $ genBit pow2 start end)
            (putMVar done)
          return done
      liftIO $ forM_ dones $ takeMVar >=> either throwIO return

    -- Draws one (start, end) bit pair and adds bit * pow2 to edge i.
    genBit pow2 start end i = do
      (startBit, endBit) <- ii_jj_bits
      liftIO $ do
        incrIndex start i (startBit * pow2)
        incrIndex end i (endBit * pow2)