{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE PatternGuards       #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.Execute.Divide
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.Execute.Divide (

  divideWork, divideWork1

) where

import Data.Array.Accelerate.Representation.Shape

import Data.Bits
import Data.Sequence                                                ( Seq )
import qualified Data.Sequence                                      as Seq
import qualified Data.Vector.Unboxed                                as U
import qualified Data.Vector.Unboxed.Mutable                        as M


-- Divide the given multidimensional index range into a sequence of work pieces.
-- Splits will be made on the outermost (left-most) index preferentially, so
-- that spans are longest on the innermost dimension (because caches).
--
-- No dimension will be made smaller than the given minimum.
--
-- The number of subdivisions is only a hint (at most, it should generate a
-- number of pieces rounded up to the next power-of-two).
--
-- Full pieces will occur first in the resulting sequence, with smaller pieces
-- at the end (suitable for work-stealing). Note that the pieces are not sorted
-- by size, and are ordered in the resulting sequence depending only
-- on whether all dimensions are above the minimum threshold or not. The integer
-- parameter to the apply action can be used to access the chunks linearly (for
-- example, this is useful when evaluating non-commutative operations).
--
-- {-# INLINABLE divideWork #-}
divideWork
    :: ShapeR sh
    -> Int                        -- #subdivisions (hint)
    -> Int                        -- minimum size of a dimension (must be a power of two)
    -> sh                         -- start index (e.g. top-left)
    -> sh                         -- end index   (e.g. bottom-right)
    -> (Int -> sh -> sh -> a)     -- action given start/end index range, and split number in the range [0..]
    -> Seq a
-- Dispatch on the shape representation. The empty shape and the shape with
-- exactly one dimension each have a specialised implementation; every other
-- shape goes through the generic vector-based splitter.
divideWork ShapeRz              = divideWork0
divideWork (ShapeRsnoc ShapeRz) = divideWork1
divideWork shr                  = divideWorkN shr
  --
  -- It is slightly faster to use lists instead of a Sequence here (though the
  -- difference is <1us on 'divideWork empty (Z:.2000) nop 8 32'). However,
  -- later operations will benefit from more efficient append, etc.

-- Work division for the empty (zero-dimensional) shape. There is nothing to
-- subdivide, so the subdivision hint and the minimum size are ignored and the
-- result is a single piece: the action applied with split number 0.
--
divideWork0 :: Int -> Int -> DIM0 -> DIM0 -> (Int -> DIM0 -> DIM0 -> a) -> Seq a
divideWork0 _ _ () () action = Seq.singleton (action 0 () ())

-- Work division specialised to one-dimensional index ranges.
--
-- The range between 'from' and 'to' is bisected recursively. Every level of
-- the recursion halves the remaining subdivision budget (which starts at
-- 'pieces'), and the recursion stops as soon as the budget reaches zero or
-- 'findSplitPoint1' declines to split the range any further.
--
-- The action is applied to each resulting piece together with its split
-- number. Split numbers are handed out in recursion order: the piece ending at
-- a split point is always numbered before the piece starting at it.
--
-- Two sequences are accumulated and concatenated at the end. Pieces whose
-- length is at least 'minsize' when the budget runs out come first; pieces
-- shorter than that, and pieces which could not be split further, come last.
--
divideWork1 :: Int -> Int -> DIM1 -> DIM1 -> (Int -> DIM1 -> DIM1 -> a) -> Seq a
divideWork1 !pieces !minsize ((), (!from)) ((), (!to)) action =
  let
      -- Arguments: remaining budget, lower bound, upper bound, next split
      -- number, sequence of full pieces, sequence of remaining pieces.
      --
      -- Budget exhausted: emit the piece unchanged, classified by its length.
      split 0 !lo !hi !n !full !rest
        | hi - lo < minsize = (n+1, full, rest Seq.|> apply n lo hi)
        | otherwise         = (n+1, full Seq.|> apply n lo hi, rest)
      --
      -- Budget remaining: try to bisect and recurse into both parts, threading
      -- the split counter and both accumulators through the first part and
      -- then through the second.
      split !budget !lo !hi !n0 !full0 !rest0 =
        case findSplitPoint1 lo hi minsize of
          Nothing         -> (n0+1, full0, rest0 Seq.|> apply n0 lo hi)
          Just (lo', hi') ->
            let budget'           = unsafeShiftR budget 1
                (n1,full1,rest1)  = split budget' lo  hi' n0 full0 rest0
                (n2,full2,rest2)  = split budget' lo' hi  n1 full1 rest1
            in
            (n2, full2, rest2)

      -- Re-wrap the scalar bounds as one-dimensional shapes for the action.
      apply n lo hi = action n ((), lo) ((), hi)

      (_, fs, ss)   = split pieces from to 0 Seq.empty Seq.empty
  in
  fs Seq.>< ss

-- Determine if and where to split the one-dimensional range from 'u' to 'v'.
--
-- Returns 'Nothing' when the length of the range does not exceed 'minsize'.
-- Otherwise the split point is placed at half the length (rounded up), rounded
-- up again to the next multiple of 'minsize', and offset by 'u'. The rounding
-- uses bit masking and is therefore only meaningful when 'minsize' is a power
-- of two.
--
-- The result is a pair (start of the second piece, end of the first piece).
-- Both components denote the same index, because @v - a@ is just @u@.
--
{-# INLINE findSplitPoint1 #-}
findSplitPoint1
    :: Int                    -- start
    -> Int                    -- end
    -> Int                    -- minimum size of a piece (must be a power of two)
    -> Maybe (Int, Int)
findSplitPoint1 !u !v !minsize =
  let a = v - u in
  if a <= minsize
    then Nothing
    else
      let b = unsafeShiftR (a+1) 1        -- divide by 2 (rounded up)
          c = minsize - 1
          d = (b+c) .&. complement c      -- round up to next multiple of minsize
      in
      Just (d+u, v-a+d)


-- Work division for index ranges of arbitrary shape.
--
-- The start and end indices are converted to unboxed vectors of extents and
-- bisected recursively with 'findSplitPointN'. Every level of the recursion
-- halves the remaining subdivision budget (which starts at 'pieces'), and the
-- recursion stops as soon as the budget reaches zero or no split point is
-- found. The bounds of each piece are converted back to the shape type before
-- being passed to the action together with the split number.
--
-- Two sequences are accumulated and concatenated at the end: pieces for which
-- every dimension spans at least 'minsize' when the budget runs out come
-- first; all other pieces come last.
--
divideWorkN :: ShapeR sh -> Int -> Int -> sh -> sh -> (Int -> sh -> sh -> a) -> Seq a
divideWorkN !shr !pieces !minsize !from !to action =
  let
      -- Is it worth checking whether the piece is full? Doing so ensures that
      -- full pieces are assigned to threads first, with the non-full blocks
      -- being the ones at the end of the work queue to be stolen.
      --
      -- Arguments: remaining budget, lower bounds, upper bounds, next split
      -- number, sequence of full pieces, sequence of remaining pieces.
      --
      split 0 !lo !hi !n !full !rest
        | U.any (< minsize) (U.zipWith (-) hi lo) = (n+1, full, rest Seq.|> apply n lo hi)
        | otherwise                               = (n+1, full Seq.|> apply n lo hi, rest)
      --
      -- Budget remaining: try to bisect and recurse into both parts, threading
      -- the split counter and both accumulators through the first part and
      -- then through the second.
      split !budget !lo !hi !n0 !full0 !rest0 =
        case findSplitPointN lo hi minsize of
          Nothing         -> (n0+1, full0, rest0 Seq.|> apply n0 lo hi)
          Just (lo', hi') ->
            let budget'           = unsafeShiftR budget 1
                (n1,full1,rest1)  = split budget' lo  hi' n0 full0 rest0
                (n2,full2,rest2)  = split budget' lo' hi  n1 full1 rest1
            in
            (n2, full2, rest2)

      -- Convert the vector bounds back into shapes for the action.
      apply n lo hi = action n (vecToShape shr lo) (vecToShape shr hi)

      (_, fs, ss)   = split pieces (shapeToVec shr from) (shapeToVec shr to) 0 Seq.empty Seq.empty
  in
  fs Seq.>< ss


-- Determine if and where to split the given index range. Returns new start and
-- end indices if found.
--
-- Exactly one dimension is split. Of all dimensions whose extent is strictly
-- larger than 'minsize', the one with the smallest extent is chosen; on a tie
-- the dimension at the highest vector position is kept, because the fold runs
-- from the right and a candidate is only replaced by a strictly smaller one.
-- If no dimension exceeds 'minsize' the result is 'Nothing'.
--
-- In the result, the first vector is 'from' with the chosen component moved to
-- the split point, and the second vector is 'to' with the chosen component
-- moved to the same point (@f - a@ is just @e@).
--
{-# INLINE findSplitPointN #-}
findSplitPointN
    :: U.Vector Int           -- start
    -> U.Vector Int           -- end
    -> Int                    -- minimum size of a dimension (must be power of 2)
    -> Maybe (U.Vector Int, U.Vector Int)
findSplitPointN !from !to !minsize =
  let
      -- The (dimension index, extent) to split along, if any.
      mix = U.ifoldr' combine Nothing
          $ U.zipWith (-) to from

      combine i v old =
        if v <= minsize
          then old
          else case old of
                 Nothing    -> Just (i,v)
                 Just (_,u) -> if v < u
                                 then Just (i,v)
                                 else old
  in
  case mix of
    Nothing     -> Nothing
    Just (i,a)  ->
      let b     = unsafeShiftR (a+1) 1    -- divide by 2 (rounded up)
          c     = minsize - 1
          d     = (b+c) .&. complement c  -- round up to next multiple of chunk size
          e     = U.unsafeIndex from i
          f     = U.unsafeIndex to   i
          --
          -- Overwrite only component 'i'; all other components are copied.
          from' = U.modify (\mv -> M.unsafeWrite mv i (d+e))   from
          to'   = U.modify (\mv -> M.unsafeWrite mv i (f-a+d)) to
      in
      Just (from', to')

-- Convert an unboxed vector of extents into a shape of the given
-- representation, going through an intermediate list.
--
{-# INLINE vecToShape #-}
vecToShape :: ShapeR sh -> U.Vector Int -> sh
vecToShape shr = listToShape shr . U.toList

-- Convert a shape into an unboxed vector of its extents, going through an
-- intermediate list. The vector is allocated with 'rank shr' elements.
--
{-# INLINE shapeToVec #-}
shapeToVec :: ShapeR sh -> sh -> U.Vector Int
shapeToVec shr sh = U.fromListN (rank shr) (shapeToList shr sh)