{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE ViewPatterns #-}
module Data.Array.Accelerate.LLVM.Native.CodeGen.FoldSeg
where
import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic as A
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Fold
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target ( Native )
import Control.Monad
import Prelude as P
-- | Generate the @foldSegP@ kernel: a segmented reduction along the innermost
-- dimension of an array.
--
-- For every index @ix@ visited by 'imapNestFromTo' over the gang range, the
-- kernel:
--
--   1. splits @ix@ into its outer part (@indexTail@) and its innermost
--      component @i@ (@indexHead@);
--
--   2. reads entries @i@ and @i + 1@ of the segments array and converts them
--      to 'Int', giving the bounds @u@ and @v@ handed to the reduction
--      (NOTE(review): this presumably means the segments array holds
--      offsets rather than lengths -- confirm against the caller);
--
--   3. reduces the input elements at @(outer, k)@ for the @k@ supplied by
--      'reduceFromTo' / 'reduce1FromTo' between @u@ and @v@, using the
--      combination function; and
--
--   4. writes the result to the output array at the 'Int' index supplied by
--      'imapNestFromTo'.
--
-- When a seed expression is supplied it is evaluated once per segment and
-- used as the initial value via 'reduceFromTo'; otherwise 'reduce1FromTo' is
-- used, which takes no initial value.
--
mkFoldSeg
    :: UID                                          -- ^ identifier passed on to 'makeOpenAcc'
    -> Gamma aenv                                   -- ^ array environment
    -> ArrayR (Array (sh, Int) e)                   -- ^ representation of the output array
    -> IntegralType i                               -- ^ element type of the segments array
    -> IRFun2 Native aenv (e -> e -> e)             -- ^ combination function
    -> MIRExp Native aenv e                         -- ^ optional seed element
    -> MIRDelayed Native aenv (Array (sh, Int) e)   -- ^ input array
    -> MIRDelayed Native aenv (Segments i)          -- ^ segments array
    -> CodeGen Native (IROpenAcc Native aenv (Array (sh, Int) e))
mkFoldSeg uid aenv aR@(ArrayR shR eR) int combine mseed marr mseg =
  let
      (start, end, paramGang) = gangParam shR
      (arrOut, paramOut)      = mutableArray aR "out"
      (arrIn,  paramIn)       = delayedArray "in"  marr
      (arrSeg, paramSeg)      = delayedArray "seg" mseg
      paramEnv                = envParam aenv
  in
  makeOpenAcc uid "foldSegP" (paramGang ++ paramOut ++ paramIn ++ paramSeg ++ paramEnv) $ do

    imapNestFromTo shR start end (irArrayShape arrOut) $ \ix ii -> do

      let
          -- Outer (non-reduced) part of the index and the innermost
          -- component, which is used to index the segments array.
          iz = indexTail ix
          i  = indexHead ix

          -- Read one entry of the segments array by linear index and
          -- convert it from the segment integral type to 'Int'.
          segmentBound :: Operands Int -> CodeGen Native (Operands Int)
          segmentBound k = A.fromIntegral int numType =<< app1 (delayedLinearIndex arrSeg) k

          -- Fetch the input element at innermost position @k@ of the
          -- current outer index.
          fetch = app1 (delayedIndex arrIn) . indexCons iz

      -- Bounds of the innermost range to reduce: entries i and i+1 of the
      -- segments array.
      j <- A.add numType i (liftInt 1)
      u <- segmentBound i
      v <- segmentBound j

      r <- case mseed of
             Just seed -> do z <- seed
                             reduceFromTo  eR u v (app2 combine) z fetch
             Nothing   ->    reduce1FromTo eR u v (app2 combine)   fetch

      writeArray TypeInt arrOut ii r

    return_