{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.CodeGen.Stencil
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.CodeGen.Stencil (

  mkStencil1,
  mkStencil2,

) where

import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Stencil
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Arithmetic
import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.IR
import Data.Array.Accelerate.LLVM.CodeGen.Loop                      hiding ( imapFromStepTo )
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Stencil
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache

import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop
import Data.Array.Accelerate.LLVM.Native.Target                     ( Native )

import qualified LLVM.AST.Global                                    as LLVM

import Control.Monad


-- | Generate code for a stencil operation over a single input array.
--
-- The stencil function is similar to a map, but has access to surrounding
-- elements as specified by the stencil pattern.
--
-- This generates two functions, which are combined with @(+++)@:
--
--  * stencil_inside: does not apply boundary conditions, assumes all element
--                    accesses are valid
--
--  * stencil_border: applies boundary condition check to each array access
--
-- Both functions share the same output array representation and the same
-- input-array parameters; they differ only in whether the boundary condition
-- is handed to 'stencilAccess'.
--
mkStencil1
    :: UID
    -> Gamma              aenv
    -> StencilR sh a stencil
    -> TypeR b
    -> IRFun1      Native aenv (stencil -> b)
    -> IRBoundary  Native aenv (Array sh a)
    -> MIRDelayed  Native aenv (Array sh a)
    -> CodeGen     Native      (IROpenAcc Native aenv (Array sh b))
mkStencil1 uid aenv sr tp f bnd marr =
  let (arrIn, paramIn) = delayedArray "in" marr
      -- The output array has the shape type of the stencil and element type 'tp'
      repr             = ArrayR (stencilShapeR sr) tp
   in (+++) <$> mkInside uid aenv repr (IRFun1 $ app1 f <=< stencilAccess sr Nothing    arrIn) paramIn
            <*> mkBorder uid aenv repr (IRFun1 $ app1 f <=< stencilAccess sr (Just bnd) arrIn) paramIn

-- | Generate code for a stencil operation over two input arrays.
--
-- As with 'mkStencil1', this produces a @stencil_inside@ function (no boundary
-- conditions passed to 'stencilAccess') and a @stencil_border@ function (the
-- respective boundary condition is passed for each input array), and combines
-- them with @(+++)@.
--
-- The representation of the output array is built from the shape of the first
-- stencil ('stencilShapeR' applied to the first 'StencilR') and the result
-- element type. The kernel parameters of the first input array precede those
-- of the second.
--
mkStencil2
    :: UID
    -> Gamma              aenv
    -> StencilR sh a stencil1
    -> StencilR sh b stencil2
    -> TypeR c
    -> IRFun2      Native aenv (stencil1 -> stencil2 -> c)
    -> IRBoundary  Native aenv (Array sh a)
    -> MIRDelayed  Native aenv (Array sh a)
    -> IRBoundary  Native aenv (Array sh b)
    -> MIRDelayed  Native aenv (Array sh b)
    -> CodeGen     Native      (IROpenAcc Native aenv (Array sh c))
mkStencil2 uid aenv sr1 sr2 tp f bnd1 marr1 bnd2 marr2 =
  let
      (arrIn1, paramIn1)  = delayedArray "in1" marr1
      (arrIn2, paramIn2)  = delayedArray "in2" marr2

      repr = ArrayR (stencilShapeR sr1) tp

      -- Interior: no boundary condition is supplied for either array
      inside  = IRFun1 $ \ix -> do
        stencil1 <- stencilAccess sr1 Nothing arrIn1 ix
        stencil2 <- stencilAccess sr2 Nothing arrIn2 ix
        app2 f stencil1 stencil2

      -- Border: each array access goes through its own boundary condition
      border  = IRFun1 $ \ix -> do
        stencil1 <- stencilAccess sr1 (Just bnd1) arrIn1 ix
        stencil2 <- stencilAccess sr2 (Just bnd2) arrIn2 ix
        app2 f stencil1 stencil2
  in
  (+++) <$> mkInside uid aenv repr inside (paramIn1 ++ paramIn2)
        <*> mkBorder uid aenv repr border (paramIn1 ++ paramIn2)


-- | Generate the @stencil_inside@ function.
--
-- The generated function walks the index range given by the gang parameters
-- using 'imapNestFromToTile' with an unroll amount of 4, applies the supplied
-- generator at every index, and writes the result to the output array @"out"@
-- at the corresponding linear index.
--
-- The function's parameters are, in order: the gang parameters, the output
-- array, the given input parameters, and the array environment.
--
mkInside
    :: UID
    -> Gamma aenv
    -> ArrayR (Array sh e)
    -> IRFun1  Native aenv (sh -> e)
    -> [LLVM.Parameter]
    -> CodeGen Native      (IROpenAcc Native aenv (Array sh e))
mkInside uid aenv repr apply paramIn =
  let
      (start, end, paramGang)   = gangParam    (arrayRshape repr)
      (arrOut, paramOut)        = mutableArray repr "out"
      paramEnv                  = envParam aenv
      shOut                     = irArrayShape arrOut
  in
  makeOpenAcc uid "stencil_inside" (paramGang ++ paramOut ++ paramIn ++ paramEnv) $ do

    imapNestFromToTile (arrayRshape repr) 4 start end shOut $ \ix i -> do
      r <- app1 apply ix                        -- apply generator function
      writeArray TypeInt arrOut i r             -- store result

    return_

-- | Generate the @stencil_border@ function.
--
-- The generated function walks the index range given by the gang parameters
-- using 'imapNestFromTo', applies the supplied generator at every index, and
-- writes the result to the output array @"out"@ at the corresponding linear
-- index. Unlike 'mkInside', the loop nest is not tiled.
--
-- The function's parameters are, in order: the gang parameters, the output
-- array, the given input parameters, and the array environment.
--
mkBorder
    :: UID
    -> Gamma aenv
    -> ArrayR (Array sh e)
    -> IRFun1  Native aenv (sh -> e)
    -> [LLVM.Parameter]
    -> CodeGen Native      (IROpenAcc Native aenv (Array sh e))
mkBorder uid aenv repr apply paramIn =
  let
      (start, end, paramGang)   = gangParam    (arrayRshape repr)
      (arrOut, paramOut)        = mutableArray repr "out"
      paramEnv                  = envParam aenv
      shOut                     = irArrayShape arrOut
  in
  makeOpenAcc uid "stencil_border" (paramGang ++ paramOut ++ paramIn ++ paramEnv) $ do

    imapNestFromTo (arrayRshape repr) start end shOut $ \ix i -> do
      r <- app1 apply ix                        -- apply generator function
      writeArray TypeInt arrOut i r             -- store result

    return_


-- | Generate a loop nest over the multidimensional index range between the
-- start and end indices, calling the body with each index together with its
-- linear index into an array of the given total extent.
--
-- For two-dimensional shapes the loop nest is tiled: the outer (y) dimension
-- advances by the unroll amount and the body is emitted once per row of the
-- tile inside the inner (x) loop. Rows left over after the tiled loop are
-- handled by a second, ordinary loop nest. All other ranks use a plain nested
-- loop, one loop per dimension.
--
imapNestFromToTile
    :: ShapeR sh
    -> Int                                                  -- ^ unroll amount (tile height)
    -> Operands sh                                          -- ^ initial index (inclusive)
    -> Operands sh                                          -- ^ final index (exclusive)
    -> Operands sh                                          -- ^ total array extent
    -> (Operands sh -> Operands Int -> CodeGen Native ())   -- ^ apply at each index
    -> CodeGen Native ()
imapNestFromToTile shr unroll start end extent body =
  go shr start end body'
  where
    -- Supply the body with the linear index corresponding to 'ix'
    body' ix = body ix =<< intOfIndex shr extent ix

    go :: ShapeR t
       -> Operands t
       -> Operands t
       -> (Operands t -> CodeGen Native ())
       -> CodeGen Native ()
    go ShapeRz OP_Unit OP_Unit k
      = k OP_Unit

    -- To correctly generate the unrolled loop nest we need to explicitly match
    -- on the last two dimensions.
    --
    go (ShapeRsnoc (ShapeRsnoc ShapeRz)) (OP_Pair (OP_Pair OP_Unit sy) sx) (OP_Pair (OP_Pair OP_Unit ey) ex) k
      = do
          -- Tile the stencil operator in the xy-plane by unrolling in the
          -- y-dimension and vectorising in the x-dimension.
          --
          sy' <- imapFromStepTo sy (liftInt unroll) ey $ \iy ->
                  imapFromTo    sx                  ex $ \ix ->
                    forM_ [0 .. unroll-1] $ \n -> do
                      iy' <- add numType iy (liftInt n)
                      k (OP_Pair (OP_Pair OP_Unit iy') ix)

          -- Take care of any remaining loop iterations in the y-dimension,
          -- starting from the counter value returned by the tiled loop.
          --
          imapFromTo sy' ey $ \iy ->
            imapFromTo sx ex $ \ix ->
              k (OP_Pair (OP_Pair OP_Unit iy) ix)

    -- The 1- and 3+-dimensional cases can recurse normally
    --
    go (ShapeRsnoc shr') (OP_Pair ssh ssz) (OP_Pair esh esz) k
      = go shr' ssh esh
      $ \sz      -> imapFromTo ssz esz
      $ \i       -> k (OP_Pair sz i)

-- | Generate a loop that runs the body at a counter value and then advances
-- the counter by the given step, built on 'while'.
--
-- The loop test is @counter + step < end@. Because the comparison is strict,
-- a step that would land exactly on @end@ is not taken by this loop. The
-- result of 'while' is returned, so that the caller can continue from it
-- (presumably the first counter value for which the test failed -- confirm
-- against the definition of 'while').
--
-- This local definition is used in place of the function of the same name
-- that is hidden from the "Data.Array.Accelerate.LLVM.CodeGen.Loop" import.
--
imapFromStepTo
    :: Operands Int                             -- ^ starting counter value
    -> Operands Int                             -- ^ step size
    -> Operands Int                             -- ^ upper bound used by the loop test
    -> (Operands Int -> CodeGen Native ())      -- ^ loop body, given the current counter
    -> CodeGen Native (Operands Int)
imapFromStepTo start step end body =
  let
      incr i = add numType i step

      -- Continue only while a whole further step stays strictly below 'end'
      test i = do
        i' <- incr i
        lt singleType i' end
  in
  while (TupRsingle scalarTypeInt) test
        (\i -> body i >> incr i)
        start