{-# LANGUAGE GADTs               #-}
{-# LANGUAGE OverloadedStrings   #-}
{-# LANGUAGE RecordWildCards     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}
-- |
-- Module      : Data.Array.Accelerate.LLVM.Native.CodeGen.Map
-- Copyright   : [2014..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.LLVM.Native.CodeGen.Map
  where

import Data.Array.Accelerate.Representation.Array
import Data.Array.Accelerate.Representation.Shape
import Data.Array.Accelerate.Representation.Type
import Data.Array.Accelerate.Type

import Data.Array.Accelerate.LLVM.CodeGen.Array
import Data.Array.Accelerate.LLVM.CodeGen.Base
import Data.Array.Accelerate.LLVM.CodeGen.Environment
import Data.Array.Accelerate.LLVM.CodeGen.Exp
import Data.Array.Accelerate.LLVM.CodeGen.Monad
import Data.Array.Accelerate.LLVM.CodeGen.Sugar
import Data.Array.Accelerate.LLVM.Compile.Cache

import Data.Array.Accelerate.LLVM.Native.Target                 ( Native )
import Data.Array.Accelerate.LLVM.Native.CodeGen.Base
import Data.Array.Accelerate.LLVM.Native.CodeGen.Loop


-- C Code
-- ======
--
-- float f(float);
--
-- void map(float* __restrict__ out, const float* __restrict__ in, const int n)
-- {
--     for (int i = 0; i < n; ++i)
--         out[i] = f(in[i]);
--
--     return;
-- }

-- Corresponding LLVM
-- ==================
--
-- define void @map(float* noalias nocapture %out, float* noalias nocapture %in, i32 %n) nounwind uwtable ssp {
--   %1 = icmp sgt i32 %n, 0
--   br i1 %1, label %.lr.ph, label %._crit_edge
--
-- .lr.ph:                                           ; preds = %0, %.lr.ph
--   %indvars.iv = phi i64 [ %indvars.iv.next, %.lr.ph ], [ 0, %0 ]
--   %2 = getelementptr inbounds float* %in, i64 %indvars.iv
--   %3 = load float* %2, align 4
--   %4 = tail call float @apply(float %3) nounwind
--   %5 = getelementptr inbounds float* %out, i64 %indvars.iv
--   store float %4, float* %5, align 4
--   %indvars.iv.next = add i64 %indvars.iv, 1
--   %lftr.wideiv = trunc i64 %indvars.iv.next to i32
--   %exitcond = icmp eq i32 %lftr.wideiv, %n
--   br i1 %exitcond, label %._crit_edge, label %.lr.ph
--
-- ._crit_edge:                                      ; preds = %.lr.ph, %0
--   ret void
-- }
--
-- declare float @apply(float)
--

-- | Apply the given unary function to each element of an array.
--
-- The map operation can always treat an array of any dimension in its flat
-- underlying representation, which simplifies code generation: the generated
-- loop runs over a single linear index, whatever the shape @sh@ is.
--
-- The generated function is labelled @"map"@ and receives its parameters in
-- this order: the range parameters from 'gangParam', the output array, the
-- input array, and finally the parameters derived from the array environment.
--
mkMap :: UID
      -> Gamma aenv
      -> ArrayR (Array sh a)
      -> TypeR b
      -> IRFun1  Native aenv (a -> b)
      -> CodeGen Native      (IROpenAcc Native aenv (Array sh b))
mkMap uid aenv (ArrayR shR aR) bR apply =
  let
      -- One-dimensional start/end indices of the range this call processes,
      -- together with the function parameters that carry them.
      (start, end, paramGang)   = gangParam dim1

      -- The input and output arrays share the shape representation 'shR' and
      -- differ only in their element type ('aR' versus 'bR').
      (arrIn,  paramIn)         = mutableArray (ArrayR shR aR) "in"
      (arrOut, paramOut)        = mutableArray (ArrayR shR bR) "out"

      -- Parameters for the free array variables recorded in 'aenv'.
      paramEnv                  = envParam aenv
  in
  makeOpenAcc uid "map" (paramGang ++ paramOut ++ paramIn ++ paramEnv) $ do

    -- Only the innermost component of the DIM1 bounds is needed, because the
    -- arrays are traversed through a linear index.
    imapFromTo (indexHead start) (indexHead end) $ \i -> do
      xs <- readArray TypeInt arrIn i     -- load the i-th input element
      ys <- app1 apply xs                 -- apply the scalar function to it
      writeArray TypeInt arrOut i ys      -- store the result at the same index

    return_