{-# LANGUAGE CPP          #-}
{-# LANGUAGE GADTs        #-}
{-# LANGUAGE LambdaCase   #-}
{-# LANGUAGE ViewPatterns #-}
-- |
-- Module      : Data.Array.Accelerate.Numeric.Sum.LLVM.Prim
-- Copyright   : [2017..2020] Trevor L. McDonell
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Numeric.Sum.LLVM.Prim (

  fadd, fsub, fmul,

) where

import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Sugar.Elt

import Data.Array.Accelerate.LLVM.CodeGen.IR                        ( Operands(..), IROP(..) )
import Data.Array.Accelerate.LLVM.CodeGen.Monad                     ( CodeGen, freshName, instr_ )
import Data.Array.Accelerate.LLVM.CodeGen.Sugar                     ( IROpenFun1(..) )

import LLVM.AST.Type.Downcast                                       ( downcast )
import qualified LLVM.AST.Type.Name                                 as A
import qualified LLVM.AST.Type.Operand                              as A
import qualified LLVM.AST.Type.Representation                       as A

import LLVM.AST.Instruction
import LLVM.AST.Name                                                ( Name(..) )
import LLVM.AST.Operand                                             ( Operand(..) )
import LLVM.AST.Type                                                ( Type(..), FloatingPointType(..) )

import Prelude                                                      hiding (uncurry)


-- | Uncurry a binary function over 'Operands' so that it accepts the
-- nested-pair encoding @(((), a), b)@ of a two-element argument tuple.
-- It adapts 'binop' to the single-argument shape expected by 'IRFun1' in
-- 'fadd', 'fsub' and 'fmul'.
--
-- The pattern is written in prefix form so that it matches
-- @(((), a), b)@ regardless of any fixity declared for 'OP_Pair'.
--
uncurry :: (Operands a -> Operands b -> c) -> Operands (((), a), b) -> c
uncurry f (OP_Pair (OP_Pair OP_Unit x) y) = f x y

-- | As (+), but don't allow potentially unsafe floating-point optimisations.
--
-- Generates a single LLVM @fadd@ instruction through 'binop', which attaches
-- the flags 'fmf' (no fast-math flags set), so the optimiser is not permitted
-- to apply value-changing fast-math rewrites such as reassociation.
--
fadd :: FloatingType a -> IROpenFun1 arch env aenv ((((), EltR a), EltR a) -> EltR a)
fadd = \case
  -- Matching on the GADT constructor fixes 'a', so the pattern match yields
  -- a ~ EltR a, which lets 'binop' be used directly on the unpacked operands.
  TypeHalf   -> IRFun1 $ uncurry (binop FAdd TypeHalf)
  TypeFloat  -> IRFun1 $ uncurry (binop FAdd TypeFloat)
  TypeDouble -> IRFun1 $ uncurry (binop FAdd TypeDouble)

-- | As (-), but don't allow potentially unsafe floating-point optimisations.
--
-- Generates a single LLVM @fsub@ instruction through 'binop', which attaches
-- the flags 'fmf' (no fast-math flags set), so the optimiser is not permitted
-- to apply value-changing fast-math rewrites such as reassociation.
--
fsub :: FloatingType a -> IROpenFun1 arch env aenv ((((), EltR a), EltR a) -> EltR a)
fsub = \case
  -- Matching on the GADT constructor fixes 'a', so the pattern match yields
  -- a ~ EltR a, which lets 'binop' be used directly on the unpacked operands.
  TypeHalf   -> IRFun1 $ uncurry (binop FSub TypeHalf)
  TypeFloat  -> IRFun1 $ uncurry (binop FSub TypeFloat)
  TypeDouble -> IRFun1 $ uncurry (binop FSub TypeDouble)

-- | As (*), but don't allow potentially unsafe floating-point optimisations.
--
-- Generates a single LLVM @fmul@ instruction through 'binop', which attaches
-- the flags 'fmf' (no fast-math flags set), so the optimiser is not permitted
-- to apply value-changing fast-math rewrites such as reassociation.
--
fmul :: FloatingType a -> IROpenFun1 arch env aenv ((((), EltR a), EltR a) -> EltR a)
fmul = \case
  -- Matching on the GADT constructor fixes 'a', so the pattern match yields
  -- a ~ EltR a, which lets 'binop' be used directly on the unpacked operands.
  TypeHalf   -> IRFun1 $ uncurry (binop FMul TypeHalf)
  TypeFloat  -> IRFun1 $ uncurry (binop FMul TypeFloat)
  TypeDouble -> IRFun1 $ uncurry (binop FMul TypeDouble)

-- | Generate a binary floating-point instruction on two operands of the same
-- 'FloatingType'.
--
-- The steps are:
--
-- 1. The typed operands are unwrapped with 'op' (via view patterns).
-- 2. They are lowered to the untyped @llvm-hs@ representation with 'downcast'.
-- 3. They are combined with the instruction constructor @f@ (e.g. 'FAdd'),
--    using the module-wide flags 'fmf' and metadata 'md'.
-- 4. The instruction is emitted with 'instr'.
-- 5. The untyped result is converted back to typed 'Operands' with 'upcast'.
--
binop :: (FastMathFlags -> Operand -> Operand -> InstructionMetadata -> Instruction)
      -> FloatingType a
      -> Operands a
      -> Operands a
      -> CodeGen arch (Operands a)
binop f t (op t -> x) (op t -> y) =
  upcast t <$> instr (downcast t) (f fmf (downcast x) (downcast y) md)


-- Prim
-- ----

-- | Instruction metadata attached to every instruction emitted by 'binop':
-- always empty.
--
md :: InstructionMetadata
md = []

-- | Fast-math flags attached to every instruction emitted by 'binop'.
--
-- No flags are set. This is what stops LLVM from applying potentially unsafe
-- (value-changing) floating-point optimisations to 'fadd', 'fsub' and 'fmul'.
--
fmf :: FastMathFlags
fmf = noFastMathFlags

-- | A new local name obtained from 'freshName'.
--
-- Its type index is erased with 'downcast' so that it can label an untyped
-- @llvm-hs@ instruction.
--
fresh :: CodeGen arch Name
fresh = downcast <$> freshName

-- | Emit the instruction @ins@ (via 'instr_') bound to a 'fresh' local name,
-- and return an untyped reference to its result.
--
-- The 'Type' argument is used only to annotate the returned 'LocalReference'.
-- It is not checked against the result type of @ins@, so callers must supply
-- the instruction's actual result type.
--
instr :: Type -> Instruction -> CodeGen arch Operand
instr ty ins = do
  name <- fresh
  instr_ (name := ins)
  return (LocalReference ty name)

-- | Convert an untyped @llvm-hs@ operand (as returned by 'instr') back into
-- the typed 'Operands' representation for the given 'FloatingType'.
--
-- Only an unnamed ('UnName') local reference whose LLVM floating-point type
-- matches the requested 'FloatingType' is accepted. Any other operand is
-- reported with 'internalError'.
--
upcast :: FloatingType t -> Operand -> Operands t
upcast TypeHalf{}   (LocalReference (FloatingPointType HalfFP)   (UnName x)) = OP_Half   (A.LocalReference A.type' (A.UnName x))
upcast TypeFloat{}  (LocalReference (FloatingPointType FloatFP)  (UnName x)) = OP_Float  (A.LocalReference A.type' (A.UnName x))
upcast TypeDouble{} (LocalReference (FloatingPointType DoubleFP) (UnName x)) = OP_Double (A.LocalReference A.type' (A.UnName x))
-- 'internalError' takes a single message argument and returns a polymorphic
-- result. A second string argument would be applied to that (bottom) result
-- and silently discarded, so the whole message is passed as one string.
upcast _ _ = internalError "upcast: expected local reference"