{-# LANGUAGE QuasiQuotes #-}

-- | Translation of ImpCode Exp and Code to C.
module Futhark.CodeGen.Backends.GenericC.Code
  ( compilePrimExp,
    compileExp,
    compileExpToName,
    compileCode,
    compileDest,
    compileArg,
    compileLMADCopy,
    compileLMADCopyWith,
    errorMsgString,
    linearCode,
  )
where

import Control.Monad
import Control.Monad.Reader (asks)
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericC.Monad
import Futhark.CodeGen.ImpCode
import Futhark.IR.Prop (isBuiltInFunction)
import Futhark.MonadFreshNames
import Language.C.Quote.OpenCL qualified as C
import Language.C.Syntax qualified as C

-- | Translate an 'ErrorMsg' into a C @printf@-style format string together
-- with the C expressions that supply its arguments, in order.
--
-- Each message part contributes one conversion specifier to the format
-- string and exactly one argument expression:
--
-- * literal text becomes @%s@ with a C string literal;
-- * booleans become @%s@ with a @"true"@/@"false"@ conditional;
-- * the unit value becomes @%s@ with the literal @"()"@;
-- * integers use a width-appropriate specifier (64-bit values are cast to
--   @long long int@ to match @%lld@);
-- * floats use @%f@, with narrower types cast to @double@.
errorMsgString :: ErrorMsg Exp -> CompilerM op s (String, [C.Exp])
errorMsgString (ErrorMsg parts) = do
  (formatstrs, formatargs) <- mapAndUnzipM onPart parts
  pure (mconcat formatstrs, formatargs)
  where
    onPart (ErrorString s) = pure ("%s", [C.cexp|$string:(T.unpack s)|])
    -- A unit value carries no information, so its expression is never
    -- compiled; a constant placeholder is printed instead.
    onPart (ErrorVal Unit _) = pure ("%s", [C.cexp|"()"|])
    onPart (ErrorVal t x) = formatValue t <$> compileExp x

    -- Pair a compiled value with the specifier that prints it, adjusting
    -- the expression where C's variadic argument rules require a cast.
    formatValue Bool e = ("%s", [C.cexp|($exp:e) ? "true" : "false"|])
    -- Unreachable through 'onPart' (handled above); kept for totality.
    formatValue Unit _ = ("%s", [C.cexp|"()"|])
    formatValue (IntType Int8) e = ("%hhd", e)
    formatValue (IntType Int16) e = ("%hd", e)
    formatValue (IntType Int32) e = ("%d", e)
    formatValue (IntType Int64) e = ("%lld", [C.cexp|(long long int)$exp:e|])
    formatValue (FloatType Float16) e = ("%f", [C.cexp|(double)$exp:e|])
    formatValue (FloatType Float32) e = ("%f", [C.cexp|(double)$exp:e|])
    formatValue (FloatType Float64) e = ("%f", e)

-- | Make the value of an expression available under a C variable name.
--
-- A bare variable ('LeafExp') is returned as-is without emitting any code.
-- Any other expression is compiled and bound to a fresh variable, whose
-- name is derived from the given description. The variable is declared
-- with the C type corresponding to the given 'PrimType'. The fresh name is
-- returned.
compileExpToName :: String -> PrimType -> Exp -> CompilerM op s VName
compileExpToName _ _ (LeafExp v _) = pure v
compileExpToName desc t e = do
  desc' <- newVName desc
  e' <- compileExp e
  decl [C.cdecl|$ty:(primTypeToCType t) $id:desc' = $e';|]
  pure desc'

-- | Compile an ImpCode expression to a C expression. Every variable is
-- emitted as a plain C identifier of the same name.
compileExp :: Exp -> CompilerM op s C.Exp
compileExp = compilePrimExp $ \v -> pure [C.cexp|$id:v|]

-- | Tell me how to compile a @v@, and I'll compile any @PrimExp v@ for you.
--
-- Simple operators are emitted as native C operators. Everything else is
-- emitted as a call to a C function named after the operator's
-- pretty-printed form.
compilePrimExp :: (Monad m) => (v -> m C.Exp) -> PrimExp v -> m C.Exp
compilePrimExp _ (ValueExp val) =
  pure $ C.toExp val mempty
compilePrimExp f (LeafExp v _) =
  f v
compilePrimExp f (UnOpExp op x) = do
  x' <- compilePrimExp f x
  pure $ case op of
    Complement {} -> [C.cexp|~$exp:x'|]
    Not {} -> [C.cexp|!$exp:x'|]
    FAbs Float32 -> [C.cexp|(float)fabs($exp:x')|]
    FAbs Float64 -> [C.cexp|fabs($exp:x')|]
    SSignum {} -> [C.cexp|($exp:x' > 0 ? 1 : 0) - ($exp:x' < 0 ? 1 : 0)|]
    -- The trailing "!= 0" makes the result 1 for any nonzero operand,
    -- including one that compares as negative in C, and 0 otherwise.
    USignum {} -> [C.cexp|($exp:x' > 0 ? 1 : 0) - ($exp:x' < 0 ? 1 : 0) != 0|]
    _ -> [C.cexp|$id:(prettyString op)($exp:x')|]
compilePrimExp f (CmpOpExp cmp x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  pure $ case cmp of
    CmpEq {} -> [C.cexp|$exp:x' == $exp:y'|]
    FCmpLt {} -> [C.cexp|$exp:x' < $exp:y'|]
    FCmpLe {} -> [C.cexp|$exp:x' <= $exp:y'|]
    CmpLlt {} -> [C.cexp|$exp:x' < $exp:y'|]
    CmpLle {} -> [C.cexp|$exp:x' <= $exp:y'|]
    _ -> [C.cexp|$id:(prettyString cmp)($exp:x', $exp:y')|]
compilePrimExp f (ConvOpExp conv x) = do
  x' <- compilePrimExp f x
  pure [C.cexp|$id:(prettyString conv)($exp:x')|]
compilePrimExp f (BinOpExp bop x y) = do
  x' <- compilePrimExp f x
  y' <- compilePrimExp f y
  -- Note that integer addition, subtraction, and multiplication with
  -- OverflowWrap are not handled by explicit operators, but rather by
  -- functions.  This is because we want to implicitly convert them to
  -- unsigned numbers, so we can do overflow without invoking
  -- undefined behaviour.
  pure $ case bop of
    Add _ OverflowUndef -> [C.cexp|$exp:x' + $exp:y'|]
    Sub _ OverflowUndef -> [C.cexp|$exp:x' - $exp:y'|]
    Mul _ OverflowUndef -> [C.cexp|$exp:x' * $exp:y'|]
    FAdd {} -> [C.cexp|$exp:x' + $exp:y'|]
    FSub {} -> [C.cexp|$exp:x' - $exp:y'|]
    FMul {} -> [C.cexp|$exp:x' * $exp:y'|]
    FDiv {} -> [C.cexp|$exp:x' / $exp:y'|]
    Xor {} -> [C.cexp|$exp:x' ^ $exp:y'|]
    And {} -> [C.cexp|$exp:x' & $exp:y'|]
    Or {} -> [C.cexp|$exp:x' | $exp:y'|]
    LogAnd {} -> [C.cexp|$exp:x' && $exp:y'|]
    LogOr {} -> [C.cexp|$exp:x' || $exp:y'|]
    _ -> [C.cexp|$id:(prettyString bop)($exp:x', $exp:y')|]
compilePrimExp f (FunExp h args _) = do
  args' <- mapM (compilePrimExp f) args
  pure [C.cexp|$id:(funName (nameFromString h))($args:args')|]

-- | Flatten a tree of sequenced statements (':>>:') into the list of its
-- non-sequence leaves, in execution order (left to right).
linearCode :: Code op -> [Code op]
linearCode c = go c []
  where
    -- Prepend the leaves of the first argument onto the already-flattened
    -- remainder. This yields the list in order without a final reverse.
    go :: Code a -> [Code a] -> [Code a]
    go (x :>>: y) rest = go x (go y rest)
    go x rest = x : rest

-- | If the binary operator can be emitted as a C compound assignment
-- (@d op= e@), return a builder for that assignment; otherwise 'Nothing'.
--
-- Only integer addition, subtraction and multiplication with
-- 'OverflowUndef' qualify. Wrapping ('OverflowWrap') integer arithmetic
-- must not be emitted as a native C operator. Signed overflow in C is
-- undefined behaviour, so such operations are compiled to helper function
-- calls that perform the arithmetic on unsigned values. A compound
-- assignment would bypass that.
assignmentOperator :: BinOp -> Maybe (VName -> C.Exp -> C.Exp)
assignmentOperator (Add _ OverflowUndef) = Just $ \d e -> [C.cexp|$id:d += $exp:e|]
assignmentOperator (Sub _ OverflowUndef) = Just $ \d e -> [C.cexp|$id:d -= $exp:e|]
assignmentOperator (Mul _ OverflowUndef) = Just $ \d e -> [C.cexp|$id:d *= $exp:e|]
assignmentOperator _ = Nothing

-- | Generate a C expression that reads element @iexp@ of the memory @src@.
--
-- How the read is expressed depends on the space:
--
-- * reading a 'Unit' yields the unit value without touching memory;
-- * 'ScalarSpace' memory is indexed directly as a C array;
-- * 'DefaultSpace' memory is dereferenced through a pointer to the storage
--   type, qualified according to the volatility;
-- * any other named space is delegated to the backend's 'opsReadScalar'.
--
-- For the last two cases, the value read is converted from its storage
-- representation with 'fromStorage'.
generateRead ::
  C.Exp ->
  C.Exp ->
  PrimType ->
  Space ->
  Volatility ->
  CompilerM op s C.Exp
generateRead _ _ Unit _ _ =
  pure [C.cexp|$exp:(UnitValue)|]
generateRead src iexp _ ScalarSpace {} _ =
  pure [C.cexp|$exp:src[$exp:iexp]|]
generateRead src iexp restype DefaultSpace vol =
  pure . fromStorage restype $
    derefPointer
      src
      iexp
      [C.cty|$tyquals:(volQuals vol) $ty:(primStorageType restype)*|]
generateRead src iexp restype (Space space) vol = do
  reader <- asks (opsReadScalar . envOperations)
  fromStorage restype <$> reader src iexp (primStorageType restype) space vol

-- | Emit a C statement that stores @elemexp@ at index @idx@ of the memory
-- @dest@.
--
-- Behaviour by case:
--
-- * writing a 'Unit' emits nothing;
-- * 'ScalarSpace' memory is assigned as a C array element;
-- * 'DefaultSpace' memory is assigned through a pointer to the storage
--   type, qualified according to the volatility;
-- * any other named space is delegated to the backend's 'opsWriteScalar'.
--
-- For the last two cases, the value is first converted to its storage
-- representation with 'toStorage'.
generateWrite ::
  C.Exp ->
  C.Exp ->
  PrimType ->
  Space ->
  Volatility ->
  C.Exp ->
  CompilerM op s ()
generateWrite _ _ Unit _ _ _ = pure ()
generateWrite dest idx _ ScalarSpace {} _ elemexp =
  stm [C.cstm|$exp:dest[$exp:idx] = $exp:elemexp;|]
generateWrite dest idx elemtype DefaultSpace vol elemexp = do
  let deref =
        derefPointer
          dest
          idx
          [C.cty|$tyquals:(volQuals vol) $ty:(primStorageType elemtype)*|]
      elemexp' = toStorage elemtype elemexp
  stm [C.cstm|$exp:deref = $exp:elemexp';|]
generateWrite dest idx elemtype (Space space) vol elemexp = do
  writer <- asks (opsWriteScalar . envOperations)
  writer dest idx (primStorageType elemtype) space vol (toStorage elemtype elemexp)

-- | Compile a read of one element from the memory block named @src@.
--
-- The raw memory expression for @src@ is obtained with 'rawMem', and the
-- index is compiled with 'compileExp'. The read is then generated for the
-- given element type, space and volatility.
compileRead ::
  VName ->
  Count u (TPrimExp t VName) ->
  PrimType ->
  Space ->
  Volatility ->
  CompilerM op s C.Exp
compileRead src (Count iexp) restype space vol = do
  src' <- rawMem src
  iexp' <- compileExp (untyped iexp)
  generateRead src' iexp' restype space vol

-- | Whether the memory block @v@ must be wrapped in a fat-memory struct
-- before being passed across a function boundary.
--
-- This holds exactly when both of the following are true:
--
-- * 'fatMemory' reports that 'DefaultSpace' uses fat memory;
-- * @v@ has a 'cacheMem' entry, i.e. it is held as a raw pointer locally.
--
-- NOTE(review): the "raw pointer" reading is taken from the comments at the
-- call sites; confirm against 'cacheMem'.
memNeedsWrapping :: VName -> CompilerM op s Bool
memNeedsWrapping v = do
  refcount <- fatMemory DefaultSpace
  cached <- isJust <$> cacheMem v
  pure $ refcount && cached

-- | Compile an argument to a function application.
--
-- Expression arguments are compiled with 'compileExp'. Memory arguments
-- are passed by name. If 'memNeedsWrapping' says so, they are first wrapped
-- in a compound literal of the 'DefaultSpace' fat-memory type, with
-- @.references = NULL@.
compileArg :: Arg -> CompilerM op s C.Exp
compileArg (MemArg m) = do
  -- Function might expect fat memory, so if this is a lexical/cached
  -- raw pointer, we have to wrap it in a struct.
  wrap <- memNeedsWrapping m
  pure $
    if wrap
      then [C.cexp|($ty:(fatMemType DefaultSpace)) {.references = NULL, .mem = $exp:m}|]
      else [C.cexp|$exp:m|]
compileArg (ExpArg e) = compileExp e

-- | Prepare a destination for function application.
--
-- Returns the name the callee should write to, plus statements to emit
-- after the call.
--
-- * If @v@ does not need wrapping (see 'memNeedsWrapping'), @v@ itself is
--   returned and no statements are needed.
-- * Otherwise, a fresh fat-memory struct named after @v@ (with suffix
--   @_struct@) is declared, wrapping @v@ with @.references = NULL@. It is
--   returned together with a statement that copies its @.mem@ field back
--   into @v@.
compileDest :: VName -> CompilerM op s (VName, [C.Stm])
compileDest v = do
  -- Function result be fat memory, so if target is a raw pointer, we
  -- have to wrap it in a struct and unwrap it afterwards.
  wrap <- memNeedsWrapping v
  if wrap
    then do
      v' <- newVName $ baseString v <> "_struct"
      item [C.citem|$ty:(fatMemType DefaultSpace) $id:v' = {.references = NULL, .mem = $exp:v};|]
      pure (v', [C.cstms|$id:v = $id:v'.mem;|])
    else pure (v, mempty)

compileCode :: Code op -> CompilerM op s ()
compileCode :: forall op s. Code op -> CompilerM op s ()
compileCode (Op op
op) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ (CompilerEnv op s -> op -> CompilerM op s ())
-> CompilerM op s (op -> CompilerM op s ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Operations op s -> op -> CompilerM op s ()
forall op s. Operations op s -> OpCompiler op s
opsCompiler (Operations op s -> op -> CompilerM op s ())
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> op
-> CompilerM op s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations) CompilerM op s (op -> CompilerM op s ())
-> CompilerM op s op -> CompilerM op s (CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> op -> CompilerM op s op
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure op
op
compileCode Code op
Skip = () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Comment Text
s Code op
code) = do
  [BlockItem]
xs <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
code
  let comment :: String
comment = String
"// " String -> String -> String
forall a. [a] -> [a] -> [a]
++ Text -> String
T.unpack Text
s
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|$comment:comment
              { $items:xs }
             |]
compileCode (TracePrint ErrorMsg Exp
msg) = do
  (String
formatstr, [Exp]
formatargs) <- ErrorMsg Exp -> CompilerM op s (String, [Exp])
forall op s. ErrorMsg Exp -> CompilerM op s (String, [Exp])
errorMsgString ErrorMsg Exp
msg
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|fprintf(ctx->log, $string:formatstr, $args:formatargs);|]
compileCode (DebugPrint String
s (Just Exp
e)) = do
  Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|if (ctx->debugging) {
          fprintf(ctx->log, $string:fmtstr, $exp:s, ($ty:ety)$exp:e', '\n');
       }|]
  where
    (String
fmt, Type
ety) = case Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
e of
      IntType IntType
_ -> (String
"llu", [C.cty|long long int|])
      FloatType FloatType
_ -> (String
"f", [C.cty|double|])
      PrimType
_ -> (String
"d", [C.cty|int|])
    fmtstr :: String
fmtstr = String
"%s: %" String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
fmt String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"%c"
compileCode (DebugPrint String
s Maybe Exp
Nothing) =
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|if (ctx->debugging) {
          fprintf(ctx->log, "%s\n", $exp:s);
       }|]
-- :>>: is treated in a special way to detect declare-set pairs in
-- order to generate prettier code.
compileCode (Code op
c1 :>>: Code op
c2) = [Code op] -> CompilerM op s ()
forall {op} {s}. [Code op] -> CompilerM op s ()
go (Code op -> [Code op]
forall op. Code op -> [Code op]
linearCode (Code op
c1 Code op -> Code op -> Code op
forall a. Code a -> Code a -> Code a
:>>: Code op
c2))
  where
    go :: [Code op] -> CompilerM op s ()
go (DeclareScalar VName
name Volatility
vol PrimType
t : SetScalar VName
dest Exp
e : [Code op]
code)
      | VName
name VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest = do
          let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
          Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
          BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $exp:e';|]
          [Code op] -> CompilerM op s ()
go [Code op]
code
    go (DeclareScalar VName
name Volatility
vol PrimType
t : Read VName
dest VName
src Count Elements (TExp Int64)
i PrimType
restype Space
space Volatility
read_vol : [Code op]
code)
      | VName
name VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest = do
          let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
          Exp
e <- VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> CompilerM op s Exp
forall {k} {k} (u :: k) (t :: k) op s.
VName
-> Count u (TPrimExp t VName)
-> PrimType
-> Space
-> Volatility
-> CompilerM op s Exp
compileRead VName
src Count Elements (TExp Int64)
i PrimType
restype Space
space Volatility
read_vol
          BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $exp:e;|]
          [Code op] -> CompilerM op s ()
go [Code op]
code
    go (DeclareScalar VName
name Volatility
vol PrimType
t : Call [VName
dest] Name
fname [Arg]
args : [Code op]
code)
      | VName
name VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
dest,
        Name -> Bool
isBuiltInFunction Name
fname = do
          let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
          [Exp]
args' <- (Arg -> CompilerM op s Exp) -> [Arg] -> CompilerM op s [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM Arg -> CompilerM op s Exp
forall op s. Arg -> CompilerM op s Exp
compileArg [Arg]
args
          BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item [C.citem|$tyquals:(volQuals vol) $ty:ct $id:name = $id:(funName fname)($args:args');|]
          [Code op] -> CompilerM op s ()
go [Code op]
code
    go (Code op
x : [Code op]
xs) = Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
x CompilerM op s () -> CompilerM op s () -> CompilerM op s ()
forall a b.
CompilerM op s a -> CompilerM op s b -> CompilerM op s b
forall (m :: * -> *) a b. Monad m => m a -> m b -> m b
>> [Code op] -> CompilerM op s ()
go [Code op]
xs
    go [] = () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Assert Exp
e ErrorMsg Exp
msg (SrcLoc
loc, [SrcLoc]
locs)) = do
  Exp
e' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  [BlockItem]
err <-
    CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ())
-> CompilerM op s [BlockItem]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s [BlockItem])
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$
      (CompilerEnv op s -> ErrorMsg Exp -> String -> CompilerM op s ())
-> CompilerM op s (ErrorMsg Exp -> String -> CompilerM op s ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Operations op s -> ErrorMsg Exp -> String -> CompilerM op s ()
forall op s. Operations op s -> ErrorCompiler op s
opsError (Operations op s -> ErrorMsg Exp -> String -> CompilerM op s ())
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> ErrorMsg Exp
-> String
-> CompilerM op s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations) CompilerM op s (ErrorMsg Exp -> String -> CompilerM op s ())
-> CompilerM op s (ErrorMsg Exp)
-> CompilerM op s (String -> CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> ErrorMsg Exp -> CompilerM op s (ErrorMsg Exp)
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ErrorMsg Exp
msg CompilerM op s (String -> CompilerM op s ())
-> CompilerM op s String -> CompilerM op s (CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> String -> CompilerM op s String
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure String
stacktrace
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|if (!$exp:e') { $items:err }|]
  where
    stacktrace :: String
stacktrace = Text -> String
T.unpack (Text -> String) -> Text -> String
forall a b. (a -> b) -> a -> b
$ Int -> [Text] -> Text
prettyStacktrace Int
0 ([Text] -> Text) -> [Text] -> Text
forall a b. (a -> b) -> a -> b
$ (SrcLoc -> Text) -> [SrcLoc] -> [Text]
forall a b. (a -> b) -> [a] -> [b]
map SrcLoc -> Text
forall a. Located a => a -> Text
locText ([SrcLoc] -> [Text]) -> [SrcLoc] -> [Text]
forall a b. (a -> b) -> a -> b
$ SrcLoc
loc SrcLoc -> [SrcLoc] -> [SrcLoc]
forall a. a -> [a] -> [a]
: [SrcLoc]
locs
compileCode (Allocate VName
_ Count Bytes (TExp Int64)
_ ScalarSpace {}) =
  -- Handled by the declaration of the memory block, which is
  -- translated to an actual array.
  () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Allocate VName
name (Count (TPrimExp Exp
e)) Space
space) = do
  Exp
size <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
e
  Maybe VName
cached <- VName -> CompilerM op s (Maybe VName)
forall a op s. ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem VName
name
  case Maybe VName
cached of
    Just VName
cur_size ->
      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
        [C.cstm|if ($exp:cur_size < $exp:size) {
                 err = lexical_realloc(ctx, &$exp:name, &$exp:cur_size, $exp:size);
                 if (err != FUTHARK_SUCCESS) {
                   goto cleanup;
                 }
                }|]
    Maybe VName
_ ->
      VName -> Exp -> Space -> Stm -> CompilerM op s ()
forall a b op s.
(ToExp a, ToExp b) =>
a -> b -> Space -> Stm -> CompilerM op s ()
allocMem VName
name Exp
size Space
space [C.cstm|{err = 1; goto cleanup;}|]
compileCode (Free VName
name Space
space) = do
  Bool
cached <- Maybe VName -> Bool
forall a. Maybe a -> Bool
isJust (Maybe VName -> Bool)
-> CompilerM op s (Maybe VName) -> CompilerM op s Bool
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s (Maybe VName)
forall a op s. ToExp a => a -> CompilerM op s (Maybe VName)
cacheMem VName
name
  Bool -> CompilerM op s () -> CompilerM op s ()
forall (f :: * -> *). Applicative f => Bool -> f () -> f ()
unless Bool
cached (CompilerM op s () -> CompilerM op s ())
-> CompilerM op s () -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ VName -> Space -> CompilerM op s ()
forall a op s. ToExp a => a -> Space -> CompilerM op s ()
unRefMem VName
name Space
space
compileCode (For VName
i Exp
bound Code op
body) = do
  let i' :: SrcLoc -> Id
i' = VName -> SrcLoc -> Id
forall a. ToIdent a => a -> SrcLoc -> Id
C.toIdent VName
i
      t :: Type
t = PrimType -> Type
primTypeToCType (PrimType -> Type) -> PrimType -> Type
forall a b. (a -> b) -> a -> b
$ Exp -> PrimType
forall v. PrimExp v -> PrimType
primExpType Exp
bound
  Exp
bound' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
bound
  [BlockItem]
body' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|for ($ty:t $id:i' = 0; $id:i' < $exp:bound'; $id:i'++) {
            $items:body'
          }|]
compileCode (While TExp Bool
cond Code op
body) = do
  Exp
cond' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp) -> Exp -> CompilerM op s Exp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Bool
cond
  [BlockItem]
body' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm
    [C.cstm|while ($exp:cond') {
            $items:body'
          }|]
compileCode (If TExp Bool
cond Code op
tbranch Code op
fbranch) = do
  Exp
cond' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp) -> Exp -> CompilerM op s Exp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Bool
cond
  [BlockItem]
tbranch' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
tbranch
  [BlockItem]
fbranch' <- CompilerM op s () -> CompilerM op s [BlockItem]
forall op s. CompilerM op s () -> CompilerM op s [BlockItem]
collect (CompilerM op s () -> CompilerM op s [BlockItem])
-> CompilerM op s () -> CompilerM op s [BlockItem]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
fbranch
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm (Stm -> CompilerM op s ()) -> Stm -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ case ([BlockItem]
tbranch', [BlockItem]
fbranch') of
    ([BlockItem]
_, []) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' }|]
    ([], [BlockItem]
_) ->
      [C.cstm|if (!($exp:cond')) { $items:fbranch' }|]
    ([BlockItem]
_, [C.BlockStm x :: Stm
x@C.If {}]) ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else $stm:x|]
    ([BlockItem], [BlockItem])
_ ->
      [C.cstm|if ($exp:cond') { $items:tbranch' } else { $items:fbranch' }|]
compileCode (LMADCopy PrimType
t [Count Elements (TExp Int64)]
shape (VName
dst, Space
dstspace) (Count Elements (TExp Int64)
dstoffset, [Count Elements (TExp Int64)]
dststrides) (VName
src, Space
srcspace) (Count Elements (TExp Int64)
srcoffset, [Count Elements (TExp Int64)]
srcstrides)) = do
  Maybe (DoLMADCopy op s)
cp <- (CompilerEnv op s -> Maybe (DoLMADCopy op s))
-> CompilerM op s (Maybe (DoLMADCopy op s))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((CompilerEnv op s -> Maybe (DoLMADCopy op s))
 -> CompilerM op s (Maybe (DoLMADCopy op s)))
-> (CompilerEnv op s -> Maybe (DoLMADCopy op s))
-> CompilerM op s (Maybe (DoLMADCopy op s))
forall a b. (a -> b) -> a -> b
$ (Space, Space)
-> Map (Space, Space) (DoLMADCopy op s) -> Maybe (DoLMADCopy op s)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (Space
dstspace, Space
srcspace) (Map (Space, Space) (DoLMADCopy op s) -> Maybe (DoLMADCopy op s))
-> (CompilerEnv op s -> Map (Space, Space) (DoLMADCopy op s))
-> CompilerEnv op s
-> Maybe (DoLMADCopy op s)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Operations op s -> Map (Space, Space) (DoLMADCopy op s)
forall op s.
Operations op s -> Map (Space, Space) (DoLMADCopy op s)
opsCopies (Operations op s -> Map (Space, Space) (DoLMADCopy op s))
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> Map (Space, Space) (DoLMADCopy op s)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations
  case Maybe (DoLMADCopy op s)
cp of
    Maybe (DoLMADCopy op s)
Nothing ->
      PrimType
-> [Count Elements (TExp Int64)]
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> CompilerM op s ()
forall op s.
PrimType
-> [Count Elements (TExp Int64)]
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> CompilerM op s ()
compileLMADCopy PrimType
t [Count Elements (TExp Int64)]
shape (VName
dst, Space
dstspace) (Count Elements (TExp Int64)
dstoffset, [Count Elements (TExp Int64)]
dststrides) (VName
src, Space
srcspace) (Count Elements (TExp Int64)
srcoffset, [Count Elements (TExp Int64)]
srcstrides)
    Just DoLMADCopy op s
cp' -> do
      [Count Elements Exp]
shape' <- (Count Elements (TExp Int64)
 -> CompilerM op s (Count Elements Exp))
-> [Count Elements (TExp Int64)]
-> CompilerM op s [Count Elements Exp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((TExp Int64 -> CompilerM op s Exp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [Count Elements (TExp Int64)]
shape
      Exp
dst' <- VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dst
      Exp
src' <- VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
src
      Count Elements Exp
dstoffset' <- (TExp Int64 -> CompilerM op s Exp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped) Count Elements (TExp Int64)
dstoffset
      [Count Elements Exp]
dststrides' <- (Count Elements (TExp Int64)
 -> CompilerM op s (Count Elements Exp))
-> [Count Elements (TExp Int64)]
-> CompilerM op s [Count Elements Exp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((TExp Int64 -> CompilerM op s Exp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [Count Elements (TExp Int64)]
dststrides
      Count Elements Exp
srcoffset' <- (TExp Int64 -> CompilerM op s Exp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped) Count Elements (TExp Int64)
srcoffset
      [Count Elements Exp]
srcstrides' <- (Count Elements (TExp Int64)
 -> CompilerM op s (Count Elements Exp))
-> [Count Elements (TExp Int64)]
-> CompilerM op s [Count Elements Exp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((TExp Int64 -> CompilerM op s Exp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements Exp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (Exp -> CompilerM op s Exp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s Exp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [Count Elements (TExp Int64)]
srcstrides
      DoLMADCopy op s
cp' CopyBarrier
CopyBarrier PrimType
t [Count Elements Exp]
shape' Exp
dst' (Count Elements Exp
dstoffset', [Count Elements Exp]
dststrides') Exp
src' (Count Elements Exp
srcoffset', [Count Elements Exp]
srcstrides')
compileCode (Write VName
_ Count Elements (TExp Int64)
_ PrimType
Unit Space
_ Volatility
_ Exp
_) = () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Write VName
dst (Count TExp Int64
idx) PrimType
elemtype Space
space Volatility
vol Exp
elemexp) = do
  Exp
dst' <- VName -> CompilerM op s Exp
forall op s. VName -> CompilerM op s Exp
rawMem VName
dst
  Exp
idx' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp (TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
idx)
  Exp
elemexp' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
elemexp
  Exp
-> Exp
-> PrimType
-> Space
-> Volatility
-> Exp
-> CompilerM op s ()
forall op s.
Exp
-> Exp
-> PrimType
-> Space
-> Volatility
-> Exp
-> CompilerM op s ()
generateWrite Exp
dst' Exp
idx' PrimType
elemtype Space
space Volatility
vol Exp
elemexp'
compileCode (Read VName
x VName
src Count Elements (TExp Int64)
i PrimType
restype Space
space Volatility
vol) = do
  Exp
e <- VName
-> Count Elements (TExp Int64)
-> PrimType
-> Space
-> Volatility
-> CompilerM op s Exp
forall {k} {k} (u :: k) (t :: k) op s.
VName
-> Count u (TPrimExp t VName)
-> PrimType
-> Space
-> Volatility
-> CompilerM op s Exp
compileRead VName
src Count Elements (TExp Int64)
i PrimType
restype Space
space Volatility
vol
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$id:x = $exp:e;|]
compileCode (DeclareMem VName
name Space
space) =
  VName -> Space -> CompilerM op s ()
forall op s. VName -> Space -> CompilerM op s ()
declMem VName
name Space
space
compileCode (DeclareScalar VName
name Volatility
vol PrimType
t) = do
  let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
  InitGroup -> CompilerM op s ()
forall op s. InitGroup -> CompilerM op s ()
decl [C.cdecl|$tyquals:(volQuals vol) $ty:ct $id:name;|]
compileCode (DeclareArray VName
name PrimType
t ArrayContents
vs) = do
  VName
name_realtype <- String -> CompilerM op s VName
forall (m :: * -> *). MonadFreshNames m => String -> m VName
newVName (String -> CompilerM op s VName) -> String -> CompilerM op s VName
forall a b. (a -> b) -> a -> b
$ VName -> String
baseString VName
name String -> String -> String
forall a. [a] -> [a] -> [a]
++ String
"_realtype"
  let ct :: Type
ct = PrimType -> Type
primTypeToCType PrimType
t
  case ArrayContents
vs of
    ArrayValues [PrimValue]
vs' -> do
      let vs'' :: [Initializer]
vs'' = [[C.cinit|$exp:v|] | PrimValue
v <- [PrimValue]
vs']
      Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:(length vs')] = {$inits:vs''};|]
    ArrayZeros Int
n ->
      Definition -> CompilerM op s ()
forall op s. Definition -> CompilerM op s ()
earlyDecl [C.cedecl|static $ty:ct $id:name_realtype[$int:n];|]
  -- Fake a memory block.
  BlockItem -> CompilerM op s ()
forall op s. BlockItem -> CompilerM op s ()
item
    [C.citem|struct memblock $id:name =
               (struct memblock){NULL,
                                 (unsigned char*)$id:name_realtype,
                                 0,
                                 $string:(prettyString name)};|]
-- For assignments of the form 'x = x OP e', we generate C assignment
-- operators to make the resulting code slightly nicer.  This has no
-- effect on performance.
compileCode (SetScalar VName
dest (BinOpExp BinOp
op (LeafExp VName
x PrimType
_) Exp
y))
  | VName
dest VName -> VName -> Bool
forall a. Eq a => a -> a -> Bool
== VName
x,
    Just VName -> Exp -> Exp
f <- BinOp -> Maybe (VName -> Exp -> Exp)
assignmentOperator BinOp
op = do
      Exp
y' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
y
      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$exp:(f dest y');|]
compileCode (SetScalar VName
dest Exp
src) = do
  Exp
src' <- Exp -> CompilerM op s Exp
forall op s. Exp -> CompilerM op s Exp
compileExp Exp
src
  Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$id:dest = $exp:src';|]
compileCode (SetMem VName
dest VName
src Space
space) =
  VName -> VName -> Space -> CompilerM op s ()
forall a b op s.
(ToExp a, ToExp b) =>
a -> b -> Space -> CompilerM op s ()
setMem VName
dest VName
src Space
space
compileCode (Call [VName
dest] Name
fname [Arg]
args)
  | Name -> Bool
isBuiltInFunction Name
fname = do
      [Exp]
args' <- (Arg -> CompilerM op s Exp) -> [Arg] -> CompilerM op s [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM Arg -> CompilerM op s Exp
forall op s. Arg -> CompilerM op s Exp
compileArg [Arg]
args
      Stm -> CompilerM op s ()
forall op s. Stm -> CompilerM op s ()
stm [C.cstm|$id:dest = $id:(funName fname)($args:args');|]
compileCode (Call [VName]
dests Name
fname [Arg]
args) = do
  ([VName]
dests', [[Stm]]
unpack_dest) <- (VName -> CompilerM op s (VName, [Stm]))
-> [VName] -> CompilerM op s ([VName], [[Stm]])
forall (m :: * -> *) a b c.
Applicative m =>
(a -> m (b, c)) -> [a] -> m ([b], [c])
mapAndUnzipM VName -> CompilerM op s (VName, [Stm])
forall op s. VName -> CompilerM op s (VName, [Stm])
compileDest [VName]
dests
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    (CompilerEnv op s -> [VName] -> Name -> [Exp] -> CompilerM op s ())
-> CompilerM op s ([VName] -> Name -> [Exp] -> CompilerM op s ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks (Operations op s -> [VName] -> Name -> [Exp] -> CompilerM op s ()
forall op s. Operations op s -> CallCompiler op s
opsCall (Operations op s -> [VName] -> Name -> [Exp] -> CompilerM op s ())
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> [VName]
-> Name
-> [Exp]
-> CompilerM op s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations)
      CompilerM op s ([VName] -> Name -> [Exp] -> CompilerM op s ())
-> CompilerM op s [VName]
-> CompilerM op s (Name -> [Exp] -> CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [VName] -> CompilerM op s [VName]
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [VName]
dests'
      CompilerM op s (Name -> [Exp] -> CompilerM op s ())
-> CompilerM op s Name
-> CompilerM op s ([Exp] -> CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Name -> CompilerM op s Name
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure Name
fname
      CompilerM op s ([Exp] -> CompilerM op s ())
-> CompilerM op s [Exp] -> CompilerM op s (CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> (Arg -> CompilerM op s Exp) -> [Arg] -> CompilerM op s [Exp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM Arg -> CompilerM op s Exp
forall op s. Arg -> CompilerM op s Exp
compileArg [Arg]
args
  [Stm] -> CompilerM op s ()
forall op s. [Stm] -> CompilerM op s ()
stms ([Stm] -> CompilerM op s ()) -> [Stm] -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ [[Stm]] -> [Stm]
forall a. Monoid a => [a] -> a
mconcat [[Stm]]
unpack_dest

-- | Compile an 'LMADCopy' into a nest of sequential C loops, one loop
-- per dimension of @shape@ (the first dimension is the outermost
-- loop).  In the innermost body, a flat destination index and a flat
-- source index are computed from the respective LMADs as the offset
-- plus the sum of each loop counter multiplied by its stride.  The
-- element is then obtained with @doRead@ and stored with @doWrite@.
-- How those individual accesses are performed is up to the caller.
compileLMADCopyWith ::
  [Count Elements (TExp Int64)] ->
  (C.Exp -> C.Exp -> CompilerM op s ()) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (C.Exp -> CompilerM op s C.Exp) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileLMADCopyWith shape doWrite dst_lmad doRead src_lmad = do
  dims <- mapM compileCount shape
  -- Per-element work: compute both flat indices, then move one
  -- element (the read is emitted before the write).
  copyBody <- collect $ do
    dstIdx <- compileCount $ flatIndex dst_lmad
    srcIdx <- compileCount $ flatIndex src_lmad
    elemVal <- doRead srcIdx
    doWrite dstIdx elemVal
  -- Wrap the body in one loop per dimension; foldr makes the first
  -- (loop variable, extent) pair the outermost loop.
  items $ foldr nestLoop copyBody $ zip loopVars dims
  where
    compileCount c = compileExp $ untyped $ unCount c

    -- Loop counters i_0 .. i_(rank-1), one per dimension.
    loopVars :: [VName]
    loopVars = [VName "i" k | k <- [0 .. length shape - 1]]

    loopIdxs :: [Count Elements (TExp Int64)]
    loopIdxs = [elements (le64 v) | v <- loopVars]

    flatIndex ::
      (Count Elements (TExp Int64), [Count Elements (TExp Int64)]) ->
      Count Elements (TExp Int64)
    flatIndex (offset, strides) = offset + sum (zipWith (*) loopIdxs strides)

    nestLoop :: (VName, C.Exp) -> [C.BlockItem] -> [C.BlockItem]
    nestLoop (i, n) inner =
      [C.citems|for (typename int64_t $id:i = 0; $id:i < $exp:n; $id:i++)
                  { $items:inner }|]

-- | Compile an 'LMADCopy' as sequential nested loops in which every
-- element is moved with an individual scalar read ('generateRead' on
-- the source space) and an individual scalar write ('generateWrite'
-- on the destination space), both non-volatile.  This is always
-- applicable, but it can be slow when such per-element accesses are
-- expensive.
compileLMADCopy ::
  PrimType ->
  [Count Elements (TExp Int64)] ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileLMADCopy t shape (dst, dstspace) dst_lmad (src, srcspace) src_lmad = do
  srcMem <- rawMem src
  dstMem <- rawMem dst
  compileLMADCopyWith
    shape
    (\dstIdx v -> generateWrite dstMem dstIdx t dstspace Nonvolatile v)
    dst_lmad
    (\srcIdx -> generateRead srcMem srcIdx t srcspace Nonvolatile)
    src_lmad