{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeOperators       #-}
{-# LANGUAGE ViewPatterns        #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.CodeGen.FoldSeg
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.CodeGen.FoldSeg
  where

import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic                as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache

import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Fold
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )

import Control.Monad
import Prelude                                                      as P

{--
-- Segmented reduction where a single processor reduces the entire array. The
-- segments array contains the length of each segment.
--
mkFoldSegS
    :: forall aenv sh i e. (Shape sh, IsIntegral i, Elt i, Elt e)
    => UID
    -> Gamma             aenv
    -> IRFun2     Native aenv (e -> e -> e)
    -> MIRExp     Native aenv e
    -> MIRDelayed Native aenv (Array (sh :. Int) e)
    -> MIRDelayed Native aenv (Segments i)
    -> CodeGen    Native      (IROpenAcc Native aenv (Array (sh :. Int) e))
mkFoldSegS uid aenv combine mseed marr mseg =
  let
      (start, end, paramGang) = gangParam @DIM1
      (arrOut, paramOut)      = mutableArray @(sh:.Int) "out"
      (arrIn,  paramIn)       = delayedArray @(sh:.Int) "in"  marr
      (arrSeg, paramSeg)      = delayedArray @DIM1      "seg" mseg
      paramEnv                = envParam aenv
  in
  makeOpenAcc uid "foldSegS" (paramGang ++ paramOut ++ paramIn ++ paramSeg ++ paramEnv) $ do

    -- Number of segments, useful only if reducing DIM2 and higher
    ss <- indexHead <$> delayedExtent arrSeg

    let test si = A.lt singleType (A.fst si) (indexHead end)
        initial = A.pair (indexHead start) (lift 0)

        body :: IR (Int,Int) -> CodeGen Native (IR (Int,Int))
        body (A.unpair -> (s,inf)) = do
          -- We can avoid an extra division if this is a DIM1 array. Higher
          -- dimensional reductions need to wrap around the segment array at
          -- each new lower-dimensional index.
          s'  <- case rank @sh of
                   0 -> return s
                   _ -> A.rem integralType s ss

          len <- A.fromIntegral integralType numType =<< app1 (delayedLinearIndex arrSeg) s'
          sup <- A.add numType inf len

          r   <- case mseed of
                   Just seed -> do z <- seed
                                   reduceFromTo  inf sup (app2 combine) z (app1 (delayedLinearIndex arrIn))
                   Nothing   ->    reduce1FromTo inf sup (app2 combine)   (app1 (delayedLinearIndex arrIn))
          writeArray arrOut s r

          t <- A.add numType s (lift 1)
          return $ A.pair t sup

    void $ while test body initial
    return_
--}


-- | Segmented reduction along the innermost dimension of an array. Performs
-- one reduction per segment of the source array. When no seed is given,
-- assumes that /all/ segments are non-empty.
--
-- This implementation assumes that the segments array represents the offset
-- indices to the source array, rather than the lengths of each segment. The
-- segment-offset approach is required for parallel implementations.
--
-- For the output element at innermost index @i@, the input elements reduced
-- are those whose innermost index runs from @seg!i@ to @seg!(i+1)@. With
-- offsets produced by @scanl (+) 0@ this is the half-open range
-- @[seg!i, seg!(i+1))@. That assumes 'reduceFromTo' / 'reduce1FromTo' exclude
-- their upper bound (TODO confirm in the Fold module).
--
-- Because @seg!(i+1)@ is read for every output index @i@, the segments array
-- must hold one more entry than the innermost extent of the output. The same
-- offsets are reused for every outer (@sh@) index.
--
-- The generated kernel is named @foldSegP@. Its parameters are, in order:
-- the gang range, the output array, the input array, the segments array, and
-- the free-variable environment.
--
mkFoldSeg
    :: UID
    -> Gamma             aenv
    -> ArrayR (Array (sh, Int) e)                   -- ^ representation of the output array
    -> IntegralType i                               -- ^ element type of the segment offsets
    -> IRFun2     Native aenv (e -> e -> e)         -- ^ combination function
    -> MIRExp     Native aenv e                     -- ^ seed; 'Nothing' requires non-empty segments
    -> MIRDelayed Native aenv (Array (sh, Int) e)   -- ^ input array
    -> MIRDelayed Native aenv (Segments i)          -- ^ segment offsets
    -> CodeGen    Native      (IROpenAcc Native aenv (Array (sh, Int) e))
mkFoldSeg uid aenv aR@(ArrayR shR eR) int combine mseed marr mseg =
  let
      (start, end, paramGang) = gangParam    shR
      (arrOut, paramOut)      = mutableArray aR "out"
      (arrIn,  paramIn)       = delayedArray    "in"  marr
      (arrSeg, paramSeg)      = delayedArray    "seg" mseg
      paramEnv                = envParam aenv
  in
  makeOpenAcc uid "foldSegP" (paramGang ++ paramOut ++ paramIn ++ paramSeg ++ paramEnv) $ do

    -- Visit the output indices delimited by this worker's gang range
    -- (start/end). Each index computes exactly one segment reduction.
    -- 'ii' is the Int index used below to write that output element.
    imapNestFromTo shR start end (irArrayShape arrOut) $ \ix ii -> do

      -- Determine the start and end indices of the innermost portion of
      -- the array to reduce. This is a segment-offset array computed by
      -- 'scanl (+) 0' of the segment length array.
      --
      let iz = indexTail ix   -- outer (sh) part of the index, shared with the input
          i  = indexHead ix   -- which segment this output element reduces
      --
      j <- A.add numType i (liftInt 1)
      u <- A.fromIntegral int numType =<< app1 (delayedLinearIndex arrSeg) i
      v <- A.fromIntegral int numType =<< app1 (delayedLinearIndex arrSeg) j

      -- Reduce the input elements at (iz, u) .. (iz, v) with 'combine'.
      -- Without a seed, 'reduce1FromTo' is used, so the segment must be
      -- non-empty.
      r <- case mseed of
             Just seed -> do z <- seed
                             reduceFromTo  eR u v (app2 combine) z (app1 (delayedIndex arrIn) . indexCons iz)
             Nothing   ->    reduce1FromTo eR u v (app2 combine)   (app1 (delayedIndex arrIn) . indexCons iz)

      writeArray TypeInt arrOut ii r

    return_