-- | A generic Python code generator which is polymorphic in the type
-- of the operations.  Concretely, we use this to handle both
-- sequential and PyOpenCL Python code.
module Futhark.CodeGen.Backends.GenericPython
  ( compileProg,
    CompilerMode,
    Constructor (..),
    emptyConstructor,
    compileName,
    compileVar,
    compileDim,
    compileExp,
    compilePrimExp,
    compileCode,
    compilePrimValue,
    compilePrimType,
    compilePrimToNp,
    compilePrimToExtNp,
    fromStorage,
    toStorage,
    Operations (..),
    DoCopy,
    defaultOperations,
    unpackDim,
    CompilerM (..),
    OpCompiler,
    WriteScalar,
    ReadScalar,
    Allocate,
    Copy,
    EntryOutput,
    EntryInput,
    CompilerEnv (..),
    CompilerState (..),
    stm,
    atInit,
    collect',
    collect,
    simpleCall,
  )
where

import Control.Monad
import Control.Monad.RWS hiding (reader, writer)
import Data.Char (isAlpha, isAlphaNum)
import Data.Map qualified as M
import Data.Maybe
import Data.Text qualified as T
import Futhark.CodeGen.Backends.GenericPython.AST
import Futhark.CodeGen.Backends.GenericPython.Options
import Futhark.CodeGen.ImpCode (Count (..), Elements, TExp, elements, le64, untyped)
import Futhark.CodeGen.ImpCode qualified as Imp
import Futhark.CodeGen.RTS.Python
import Futhark.Compiler.Config (CompilerMode (..))
import Futhark.IR.Prop (isBuiltInFunction, subExpVars)
import Futhark.IR.Syntax.Core (Space (..))
import Futhark.MonadFreshNames
import Futhark.Util (zEncodeText)
import Futhark.Util.Pretty (prettyString, prettyText)
import Language.Futhark.Primitive hiding (Bool)

-- | A substitute expression compiler, tried before the main
-- compilation function.  It handles a single extended operation of
-- type @op@ and is run only for its effects in 'CompilerM'.
type OpCompiler op s =
  op ->
  CompilerM op s ()

-- | Write a scalar to the given memory block with the given index and
-- in the given memory space.  Presumably the arguments are, in order:
-- the memory block, the index, the scalar's type, the memory space,
-- and the value to write (TODO confirm against the backends).
type WriteScalar op s =
  PyExp -> PyExp -> PrimType -> Imp.SpaceId -> PyExp -> CompilerM op s ()

-- | Read a scalar from the given memory block with the given index and
-- in the given memory space, producing a Python expression for it.
-- Presumably the arguments are, in order: the memory block, the index,
-- the scalar's type, and the memory space (TODO confirm).
type ReadScalar op s =
  PyExp -> PyExp -> PrimType -> Imp.SpaceId -> CompilerM op s PyExp

-- | Allocate a memory block of the given size in the given memory
-- space, saving a reference in the given variable name.
-- NOTE(review): which of the two 'PyExp' arguments is the variable and
-- which is the size is not visible here; confirm against the callers.
type Allocate op s =
  PyExp -> PyExp -> Imp.SpaceId -> CompilerM op s ()

-- | Copy from one memory block to another.
--
-- The argument grouping below is an assumption based on the shape of
-- the type (two blocks of memory/offset/space, then a trailing
-- expression and element type); TODO confirm against the backends.
type Copy op s =
  PyExp -> PyExp -> Imp.Space -> -- memory, offset, space (presumably destination)
  PyExp -> PyExp -> Imp.Space -> -- memory, offset, space (presumably source)
  PyExp -> PrimType -> -- presumably the amount to copy and element type
  CompilerM op s ()

-- | Perform an 'Imp.Copy'.  It is expected that these functions
-- are each specialised on which spaces they operate on, so that is
-- not part of their arguments.
--
-- Presumably the arguments are: the element type, the shape of the
-- copied region, then the destination memory with its offset and
-- strides, then the source memory with its offset and strides
-- (destination-first as in the @(dst,src)@ keys of 'opsCopies');
-- TODO confirm against the copy implementations.
type DoCopy op s =
  PrimType ->
  [Count Elements PyExp] ->
  PyExp ->
  (Count Elements PyExp, [Count Elements PyExp]) ->
  PyExp ->
  (Count Elements PyExp, [Count Elements PyExp]) ->
  CompilerM op s ()

-- | Construct the Python array being returned from an entry point.
-- The arguments describe the array; presumably they are the name of
-- the variable holding its memory, its memory space, element type,
-- signedness, and shape (TODO confirm).
type EntryOutput op s =
  VName -> Imp.SpaceId -> PrimType -> Imp.Signedness -> [Imp.DimSize] ->
  CompilerM op s PyExp

-- | Unpack the array being passed to an entry point.  Presumably the
-- first 'PyExp' is where the unpacked memory goes and the last is the
-- Python value being unpacked; the arguments in between describe the
-- array's space, element type, signedness, and shape (TODO confirm).
type EntryInput op s =
  PyExp -> Imp.SpaceId -> PrimType -> Imp.Signedness -> [Imp.DimSize] ->
  PyExp -> CompilerM op s ()

-- | The callbacks through which the generic code generator delegates
-- everything that depends on the concrete backend (e.g. sequential
-- Python or PyOpenCL).  See 'defaultOperations' for a set that only
-- supports the default memory space.
data Operations op s = Operations
  { -- | Write a scalar into a memory block.
    opsWriteScalar :: WriteScalar op s,
    -- | Read a scalar from a memory block.
    opsReadScalar :: ReadScalar op s,
    -- | Allocate a memory block.
    opsAllocate :: Allocate op s,
    -- | @(dst,src)@-space mapping to copy functions.
    opsCopies :: M.Map (Space, Space) (DoCopy op s),
    -- | Compile an extended operation.
    opsCompiler :: OpCompiler op s,
    -- | Construct an array returned from an entry point.
    opsEntryOutput :: EntryOutput op s,
    -- | Unpack an array passed to an entry point.
    opsEntryInput :: EntryInput op s
  }

-- | A set of operations that fail for every operation involving
-- non-default memory spaces.  The only supported copy is between two
-- 'DefaultSpace' memory blocks (via 'lmadcopyCPU'); every other
-- operation, including compiling extended operations, calls 'error'.
defaultOperations :: Operations op s
defaultOperations =
  Operations
    { opsWriteScalar = defWriteScalar,
      opsReadScalar = defReadScalar,
      opsAllocate = defAllocate,
      opsCopies = M.singleton (DefaultSpace, DefaultSpace) lmadcopyCPU,
      opsCompiler = defCompiler,
      opsEntryOutput = defEntryOutput,
      opsEntryInput = defEntryInput
    }
  where
    -- Each stub ignores its arguments; the messages are phrased
    -- uniformly so a failure clearly names the unsupported operation.
    defWriteScalar _ _ _ _ _ =
      error "Cannot write to non-default memory space"
    defReadScalar _ _ _ _ =
      error "Cannot read from non-default memory space"
    defAllocate _ _ _ =
      error "Cannot allocate in non-default memory space"
    defCompiler _ =
      error "The default compiler cannot compile extended operations"
    defEntryOutput _ _ _ _ =
      error "Cannot return array not in default memory space"
    defEntryInput _ _ _ _ =
      error "Cannot accept array not in default memory space"

-- | The read-only environment of 'CompilerM'.
data CompilerEnv op s = CompilerEnv
  { -- | The backend operations in use.
    envOperations :: Operations op s,
    -- | Python expressions keyed by variable name; presumably used in
    -- place of those variables when they are compiled (TODO confirm).
    -- Empty in an environment built by 'newCompilerEnv'.
    envVarExp :: M.Map String PyExp
  }

-- | The backend's compiler for extended operations.
envOpCompiler :: CompilerEnv op s -> OpCompiler op s
envOpCompiler env = opsCompiler (envOperations env)

-- | The backend's scalar reader.
envReadScalar :: CompilerEnv op s -> ReadScalar op s
envReadScalar env = opsReadScalar (envOperations env)

-- | The backend's scalar writer.
envWriteScalar :: CompilerEnv op s -> WriteScalar op s
envWriteScalar env = opsWriteScalar (envOperations env)

-- | The backend's memory allocator.
envAllocate :: CompilerEnv op s -> Allocate op s
envAllocate env = opsAllocate (envOperations env)

-- | The backend's constructor for arrays returned from entry points.
envEntryOutput :: CompilerEnv op s -> EntryOutput op s
envEntryOutput env = opsEntryOutput (envOperations env)

-- | The backend's unpacker for arrays passed to entry points.
envEntryInput :: CompilerEnv op s -> EntryInput op s
envEntryInput env = opsEntryInput (envOperations env)

-- | An environment holding the given operations and an empty
-- 'envVarExp' map.
newCompilerEnv :: Operations op s -> CompilerEnv op s
newCompilerEnv ops = CompilerEnv ops M.empty

-- | The state threaded through 'CompilerM'.
data CompilerState s = CompilerState
  { -- | Source of fresh names (see the 'MonadFreshNames' instance).
    compNameSrc :: VNameSource,
    -- | Statements registered with 'atInit', oldest first.
    compInit :: [PyStmt],
    -- | Backend-specific state.
    compUserState :: s
  }

-- | The initial state: the given name source and user state, with no
-- initialisation statements registered yet.
newCompilerState :: VNameSource -> s -> CompilerState s
newCompilerState src = CompilerState src []

-- | The code generation monad: it reads a 'CompilerEnv', accumulates
-- emitted 'PyStmt's, and threads a 'CompilerState'.  Every instance is
-- inherited from the underlying 'RWS'.
newtype CompilerM op s a
  = CompilerM (RWS (CompilerEnv op s) [PyStmt] (CompilerState s) a)
  deriving
    ( Functor,
      Applicative,
      Monad,
      MonadReader (CompilerEnv op s),
      MonadWriter [PyStmt],
      MonadState (CompilerState s)
    )

-- | Fresh names are drawn from, and stored back into, 'compNameSrc'.
instance MonadFreshNames (CompilerM op s) where
  getNameSource = compNameSrc <$> get
  putNameSource newSrc = modify setSrc
    where
      setSrc st = st {compNameSrc = newSrc}

-- | Run an action and return the statements it produced instead of
-- emitting them.
collect :: CompilerM op s () -> CompilerM op s [PyStmt]
collect m = do
  ((), stmts) <- collect' m
  pure stmts

-- | Run an action and return its result paired with the statements it
-- produced.  Those statements are captured with 'listen' and then
-- suppressed with 'censor', so nothing is emitted.
collect' :: CompilerM op s a -> CompilerM op s (a, [PyStmt])
collect' = censor (const []) . listen

-- | Register a statement at the end of 'compInit'.  Unlike 'stm', this
-- does not emit the statement at the current position.
atInit :: PyStmt -> CompilerM op s ()
atInit x = modify addInit
  where
    addInit st = st {compInit = compInit st ++ [x]}

-- | Emit a single Python statement.
stm :: PyStmt -> CompilerM op s ()
stm = tell . (: [])

-- | The Python name for a Futhark function: the z-encoded name,
-- prefixed with @futhark_@.
futharkFun :: T.Text -> T.Text
futharkFun = T.append "futhark_" . zEncodeText

-- | Python variable references for the given parameters, in order.
compileOutput :: [Imp.Param] -> [PyExp]
compileOutput params =
  [Var (compileName (Imp.paramName p)) | p <- params]

-- | Run a 'CompilerM' computation from a fresh environment and state.
-- Only the result is returned; the emitted statements and the final
-- state are discarded.
runCompilerM ::
  Operations op s ->
  VNameSource ->
  s ->
  CompilerM op s a ->
  a
runCompilerM ops src userstate (CompilerM m) = result
  where
    env = newCompilerEnv ops
    initial = newCompilerState src userstate
    (result, _discarded) = evalRWS m env initial

-- | Command-line options shared by every generated Python program.
standardOptions :: [Option]
standardOptions = [optTuning, optCacheFile, optLog]
  where
    -- Calls @read_tuning_file(sizes, optarg)@.
    optTuning =
      Option
        { optionLongName = "tuning",
          optionShortName = Nothing,
          optionArgument = RequiredArgument "open",
          optionAction =
            [Exp (simpleCall "read_tuning_file" [Var "sizes", Var "optarg"])]
        }
    -- Does not actually do anything for Python backends.
    optCacheFile =
      Option
        { optionLongName = "cache-file",
          optionShortName = Nothing,
          optionArgument = RequiredArgument "str",
          optionAction = [Pass]
        }
    -- Accepted, but the action is just @pass@.
    optLog =
      Option
        { optionLongName = "log",
          optionShortName = Just 'L',
          optionArgument = NoArgument,
          optionAction = [Pass]
        }

-- | Command-line options of a generated Python executable: the
-- 'standardOptions' followed by executable-specific options.
executableOptions :: [Option]
executableOptions =
  standardOptions
    ++ [optWriteRuntime, optRuns, optEntryPoint, optBinaryOutput]
  where
    assignVar v e = Assign (Var v) e

    -- Close any previously opened @runtime_file@, then open the
    -- argument for writing and store it in @runtime_file@.
    optWriteRuntime =
      Option
        { optionLongName = "write-runtime-to",
          optionShortName = Just 't',
          optionArgument = RequiredArgument "str",
          optionAction =
            [ If
                (Var "runtime_file")
                [Exp (simpleCall "runtime_file.close" [])]
                [],
              assignVar "runtime_file" $
                simpleCall "open" [Var "optarg", String "w"]
            ]
        }

    -- Store the argument in @num_runs@ and request a warmup run.
    optRuns =
      Option
        { optionLongName = "runs",
          optionShortName = Just 'r',
          optionArgument = RequiredArgument "str",
          optionAction =
            [ assignVar "num_runs" (Var "optarg"),
              assignVar "do_warmup_run" (Bool True)
            ]
        }

    -- Select the entry point to run.
    optEntryPoint =
      Option
        { optionLongName = "entry-point",
          optionShortName = Just 'e',
          optionArgument = RequiredArgument "str",
          optionAction = [assignVar "entry_point" (Var "optarg")]
        }

    -- Set the @binary_output@ flag.
    optBinaryOutput =
      Option
        { optionLongName = "binary-output",
          optionShortName = Just 'b',
          optionArgument = NoArgument,
          optionAction = [assignVar "binary_output" (Bool True)]
        }

-- | All external values of an entry point: its results followed by its
-- arguments, with the accompanying annotations (uniqueness, and for
-- arguments also the name) dropped.
functionExternalValues :: Imp.EntryPoint -> [Imp.ExternalValue]
functionExternalValues entry =
  [ev | (_, ev) <- Imp.entryPointResults entry]
    ++ [ev | (_, ev) <- Imp.entryPointArgs entry]

-- | Is this name a valid Python identifier?  If not, it should be escaped
-- before being emitted.
--
-- The check is deliberately conservative: the name must be non-empty,
-- start with a letter, and otherwise consist only of letters, digits,
-- and underscores.  Names Python would accept but this rejects (such as
-- those with a leading underscore) are simply escaped.  The empty text
-- is never a valid identifier.
--
-- NOTE(review): Python keywords such as @lambda@ or @pass@ also pass
-- this check; confirm such names cannot reach here.
isValidPyName :: T.Text -> Bool
isValidPyName name =
  case T.uncons name of
    Nothing -> False
    Just (c, cs) -> isAlpha c && T.all constituent cs
  where
    constituent c = isAlphaNum c || c == '_'

-- | Render a name for use in Python code.  If its text is a valid
-- identifier according to 'isValidPyName', it is returned verbatim;
-- otherwise it is z-encoded so that it becomes valid.
escapeName :: Name -> T.Text
escapeName v =
  if isValidPyName txt then txt else zEncodeText txt
  where
    txt = nameToText v

-- | For every opaque type that occurs among the results or arguments of
-- an entry point, a list of Python string literals describing its
-- constituent values.  Each description is the 'readTypeEnum' text of
-- the element type; for arrays it is prefixed by one @[]@ per dimension.
-- If an opaque type occurs more than once, the first occurrence wins
-- (the 'M.Map' monoid is a left-biased union).
opaqueDefs :: Imp.Functions a -> M.Map T.Text [PyExp]
opaqueDefs (Imp.Functions funs) =
  mconcat
    [ describe ev
      | (_, fun) <- funs,
        Just entry <- [Imp.functionEntry fun],
        ev <- functionExternalValues entry
    ]
  where
    describe Imp.TransparentValue {} = mempty
    describe (Imp.OpaqueValue name vds) =
      M.singleton (nameToText name) [String (valueDesc v) | v <- vds]

    valueDesc (Imp.ScalarValue pt s _) = readTypeEnum pt s
    valueDesc (Imp.ArrayValue _ _ pt s dims) =
      T.replicate (length dims) "[]" <> readTypeEnum pt s

-- | The class generated by the code generator must have a
-- constructor, although it can be vacuous.  It becomes the class's
-- @__init__@ method.
data Constructor
  = Constructor
      [String]
      -- ^ Parameter names of @__init__@ (e.g. @self@).
      [PyStmt]
      -- ^ Statements making up the body of @__init__@.

-- | A constructor that takes no arguments besides @self@ and does
-- nothing (its body is a single @pass@).
emptyConstructor :: Constructor
emptyConstructor = Constructor selfOnly doNothing
  where
    selfOnly = ["self"]
    doNothing = [Pass]

-- | Turn a 'Constructor' into the class's @__init__@ method.  The given
-- extra statements are appended after the constructor's own body.
constructorToFunDef :: Constructor -> [PyStmt] -> PyFunDef
constructorToFunDef (Constructor params body) extra =
  Def "__init__" params (body ++ extra)

-- | Generate the complete Python program for an Imp program.
--
-- The output always starts with the given imports, an @argparse@
-- import, an empty @sizes@ dictionary, the given definitions and the
-- runtime-support code from "Futhark.CodeGen.RTS.Python".  What follows
-- depends on the 'CompilerMode':
--
-- * 'ToLibrary': a class (named by the second argument) holding the
--   @entry_points@ and @opaques@ tables, the constructor, the compiled
--   functions and the entry point methods.
--
-- * 'ToServer': an option parser for the standard options, the same
--   class, and code that runs @Server(ClassName())@.
--
-- * 'ToExecutable': an option parser for the executable options, the
--   class (constructor and compiled functions only), an instance bound
--   to @self@, one top-level function per entry point, and code that
--   runs the entry point named by @entry_point@.
compileProg ::
  (MonadFreshNames m) =>
  -- | What kind of Python program to generate.
  CompilerMode ->
  -- | Name of the generated class.
  String ->
  -- | Constructor of the generated class.
  Constructor ->
  -- | Statements placed at the very start of the program.
  [PyStmt] ->
  -- | Statements placed after the @argparse@ import and @sizes@ table.
  [PyStmt] ->
  -- | Backend-specific operations.
  Operations op s ->
  -- | Initial backend-specific compiler state.
  s ->
  -- | Passed on to the compilation of every entry point (presumably
  -- synchronisation code; TODO confirm against the backends).
  [PyStmt] ->
  -- | Backend-specific command line options, added to the standard ones.
  [Option] ->
  -- | The program to compile.
  Imp.Definitions op ->
  m T.Text
compileProg mode class_name constructor imports defines ops userstate sync options prog = do
  src <- getNameSource
  let prog' = runCompilerM ops src userstate compileProg'
  pure . prettyText . PyProg $
    imports
      ++ [ Import "argparse" Nothing,
           Assign (Var "sizes") $ Dict []
         ]
      ++ defines
      ++ [ Escape valuesPy,
           Escape memoryPy,
           Escape panicPy,
           Escape tuningPy,
           Escape scalarPy,
           Escape serverPy
         ]
      ++ prog'
  where
    Imp.Definitions _types consts (Imp.Functions funs) = prog

    compileProg' = withConstantSubsts consts $ do
      compileConstants consts

      definitions <- mapM compileFunc funs
      at_inits <- gets compInit

      let constructor' = constructorToFunDef constructor at_inits

          -- Compile the entry points for library/server use.  Functions
          -- for which 'compileEntryFun' yields 'Nothing' are skipped.
          compileEntryPoints timing =
            unzip . catMaybes <$> mapM (compileEntryFun sync timing) funs

          -- The class shared by library and server mode: the entry point
          -- and opaque type tables, followed by the constructor, the
          -- compiled functions and the entry point methods.
          libraryClass eps ep_types =
            ClassDef . Class class_name $
              Assign (Var "entry_points") (Dict ep_types)
                : Assign (Var "opaques") opaque_table
                : map FunDef (constructor' : definitions ++ eps)

      case mode of
        ToLibrary -> do
          (eps, ep_types) <- compileEntryPoints DoNotReturnTiming
          pure [libraryClass eps ep_types]
        ToServer -> do
          (eps, ep_types) <- compileEntryPoints ReturnTiming
          pure $
            parse_options_server
              ++ [ libraryClass eps ep_types,
                   Assign
                     (Var "server")
                     (simpleCall "Server" [simpleCall class_name []]),
                   Exp $ simpleCall "server.run" []
                 ]
        ToExecutable -> do
          let classinst = Assign (Var "self") $ simpleCall class_name []
          (entry_point_defs, entry_point_names, entry_points) <-
            unzip3 . catMaybes <$> mapM (callEntryFun sync) funs
          pure $
            parse_options_executable
              ++ ClassDef
                ( Class class_name $
                    map FunDef $
                      constructor' : definitions
                )
              : classinst
              : map FunDef entry_point_defs
              ++ selectEntryPoint entry_point_names entry_points

    -- Defaults for the executable's run-time settings, followed by the
    -- option parser that may override them.
    parse_options_executable =
      Assign (Var "runtime_file") None
        : Assign (Var "do_warmup_run") (Bool False)
        : Assign (Var "num_runs") (Integer 1)
        : Assign (Var "entry_point") (String "main")
        : Assign (Var "binary_output") (Bool False)
        : generateOptionParser (executableOptions ++ options)

    parse_options_server =
      generateOptionParser (standardOptions ++ options)

    -- The @opaques@ table: each opaque type name mapped to a tuple of
    -- the expressions 'opaqueDefs' associates with it.
    opaque_table =
      Dict
        [ (String name, Tuple payload)
          | (name, payload) <- M.toList $ opaqueDefs $ Imp.defFuns prog
        ]

    -- Executable mode: look up the entry point named by @entry_point@
    -- (which defaults to "main" above) and run it, or exit listing the
    -- available entry points.
    selectEntryPoint entry_point_names entry_points =
      [ Assign (Var "entry_points") $
          Dict $
            zip (map String entry_point_names) entry_points,
        Assign (Var "entry_point_fun") $
          simpleCall "entry_points.get" [Var "entry_point"],
        If
          (BinOp "==" (Var "entry_point_fun") None)
          [ Exp $
              simpleCall
                "sys.exit"
                [ Call
                    ( Field
                        -- The option is spelled --entry-point (with a
                        -- hyphen); the message previously said
                        -- "--entry point", which is not a valid flag.
                        (String "No entry point '{}'.  Select another with --entry-point.  Options are:\n{}")
                        "format"
                    )
                    [ Arg $ Var "entry_point",
                      Arg $
                        Call
                          (Field (String "\n") "join")
                          [Arg $ simpleCall "entry_points.keys" []]
                    ]
                ]
          ]
          [Exp $ simpleCall "entry_point_fun" []]
      ]

-- | Run an action in an environment where every constant parameter is
-- compiled to a lookup in the @self.constants@ dictionary, keyed by the
-- pretty-printed parameter name.  Note that 'envVarExp' is replaced
-- wholesale, not extended.
withConstantSubsts :: Imp.Constants op -> CompilerM op s a -> CompilerM op s a
withConstantSubsts (Imp.Constants const_params _) = local substitute
  where
    substitute env =
      env {envVarExp = M.unions $ map lookupConstant const_params}
    lookupConstant param =
      let name = Imp.paramName param
          key = String $ prettyText name
       in M.singleton
            (compileName name)
            (Index (Var "self.constants") (IdxExp key))

-- | Register the initialisation of the program constants via 'atInit':
-- first @self.constants = {}@, then every statement produced by
-- compiling the constants' initialisation code, in order.
compileConstants :: Imp.Constants op -> CompilerM op s ()
compileConstants (Imp.Constants _ init_code) = do
  atInit $ Assign (Var "self.constants") (Dict [])
  init_stmts <- collect $ compileCode init_code
  forM_ init_stmts atInit

-- | Compile an Imp function to a method of the generated class.  The
-- method is named with 'futharkFun', takes @self@ followed by the input
-- parameters, runs the compiled body, and finally returns the output
-- parameters: the bare value if there is exactly one, a tuple otherwise.
compileFunc :: (Name, Imp.Function op) -> CompilerM op s PyFunDef
compileFunc (fname, Imp.Function _ outputs inputs body) = do
  body_stmts <- collect (compileCode body)
  let method_name = T.unpack . futharkFun $ nameToText fname
      method_params = "self" : [compileName (Imp.paramName p) | p <- inputs]
      return_stmt = Return . tupleOrSingle $ compileOutput outputs
  pure $ Def method_name method_params (body_stmts ++ [return_stmt])

-- | Use a lone expression as-is; wrap any other number of expressions
-- (including zero) in a 'Tuple'.
tupleOrSingle :: [PyExp] -> PyExp
tupleOrSingle es = case es of
  [only] -> only
  _ -> Tuple es

-- | Call the function named by the first argument, passing each
-- expression of the list as a positional 'Arg'.
simpleCall :: String -> [PyExp] -> PyExp
simpleCall fname args = Call (Var fname) [Arg arg | arg <- args]

-- | The Python identifier used for a 'VName': its pretty-printed form,
-- passed through 'zEncodeText' (presumably to escape characters that
-- are not valid in Python identifiers).
compileName :: VName -> String
compileName name = T.unpack (zEncodeText (prettyText name))

-- | Compile an array dimension: a constant becomes a literal value,
-- while a variable is compiled with 'compileVar'.
compileDim :: Imp.DimSize -> CompilerM op s PyExp
compileDim dim = case dim of
  Imp.Constant value -> pure (compilePrimValue value)
  Imp.Var name -> compileVar name

-- | Emit a check of dimension @i@ of @arr_name.shape@ against an
-- expected dimension.  A constant dimension is checked with an
-- @assert@.  A variable dimension is bound to @np.int64@ of the actual
-- size if it is still @None@, and asserted equal to it otherwise.
unpackDim :: PyExp -> Imp.DimSize -> Int32 -> CompilerM op s ()
unpackDim arr_name dim i =
  case dim of
    Imp.Constant c ->
      stm . Assert (BinOp "==" (compilePrimValue c) actual_size) $
        String "Entry point arguments have invalid sizes."
    Imp.Var var -> do
      var' <- compileVar var
      stm $
        If
          (BinOp "==" var' None)
          [Assign var' $ simpleCall "np.int64" [actual_size]]
          [ Assert (BinOp "==" var' actual_size) $
              String "Error: entry point arguments have invalid sizes."
          ]
  where
    -- The Python expression arr_name.shape[i].
    actual_size =
      Index (Field arr_name "shape") $ IdxExp $ Integer $ toInteger i

-- | Convert a result of an entry point into the Python value returned
-- to the caller.
--
-- * Opaque values become a call to @opaque@ with the pretty-printed
--   type name followed by the converted components.
--
-- * Scalars are passed to the constructor named by 'compilePrimToExtNp'.
--
-- * Arrays in a memory space constructed with 'Space' are handed to the
--   backend-specific 'envEntryOutput' hook.
--
-- * Arrays in any other space are produced by viewing the memory block
--   with the external element type, slicing off as many elements as the
--   product of the dimensions, and reshaping to the dimensions.
entryPointOutput :: Imp.ExternalValue -> CompilerM op s PyExp
entryPointOutput (Imp.OpaqueValue desc vs) =
  simpleCall "opaque" . (String (prettyText desc) :)
    <$> mapM (entryPointOutput . Imp.TransparentValue) vs
entryPointOutput (Imp.TransparentValue (Imp.ScalarValue bt ept name)) = do
  name' <- compileVar name
  pure $ simpleCall tf [name']
  where
    tf = compilePrimToExtNp bt ept
entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem (Imp.Space sid) bt ept dims)) = do
  pack_output <- asks envEntryOutput
  pack_output mem sid bt ept dims
entryPointOutput (Imp.TransparentValue (Imp.ArrayValue mem _ bt ept dims)) = do
  mem' <- compileVar mem
  dims' <- mapM compileDim dims
  pure $
    simpleCall
      "np.reshape"
      [ Index
          (Call (Field mem' "view") [Arg $ Var $ compilePrimToExtNp bt ept])
          (IdxRange (Integer 0) (numElements dims')),
        Tuple dims'
      ]
  where
    -- Product of the dimensions.  Written without the partial 'foldl1'
    -- so that an empty dimension list yields 1 instead of crashing the
    -- compiler; for one or more dimensions the generated expression is
    -- exactly the left-nested product that 'foldl1' built.
    numElements [] = Integer 1
    numElements (d : ds) = foldl (BinOp "*") d ds

-- | A @raise TypeError(...)@ statement for entry point argument number
-- @i@ (the expression @e@) whose Futhark type is @t@.  The message
-- reports the Futhark type, the Python type of @e@, and @e@ itself.
badInput :: Int -> PyExp -> T.Text -> PyStmt
badInput i e t =
  Raise $ simpleCall "TypeError" [format_call]
  where
    format_call =
      Call (Field (String template) "format") $
        map Arg [String t, simpleCall "type" [e], e]
    template =
      T.unlines
        [ "Argument #" <> prettyText i <> " has invalid value",
          "Futhark type: {}",
          "Argument has Python type {} and value: {}"
        ]

-- | A @raise TypeError(...)@ statement for entry point argument number
-- @i@ (the expression @e@, Futhark type @t@) whose element dtype is
-- wrong.  Besides the information 'badInput' reports, the message shows
-- the expected dtype @de@ and the dtype @dg@ that was given.
badInputType :: Int -> PyExp -> T.Text -> PyExp -> PyExp -> PyStmt
badInputType i e t de dg =
  Raise $ simpleCall "TypeError" [format_call]
  where
    format_call =
      Call (Field (String template) "format") $
        map Arg [String t, simpleCall "type" [e], e, de, dg]
    template =
      T.unlines
        [ "Argument #" <> prettyText i <> " has invalid value",
          "Futhark type: {}",
          "Argument has Python type {} and value: {}",
          "Expected array with elements of dtype: {}",
          "The array given has elements of dtype: {}"
        ]

-- | Build a Python @raise TypeError(...)@ statement for an array
-- argument with the wrong number of dimensions.  @i@ is the argument
-- index, @e@ the argument expression, @typ@ the element type name and
-- @dimf@ the expected rank.
badInputDim :: Int -> PyExp -> T.Text -> Int -> PyStmt
badInputDim i e typ dimf = Raise $ simpleCall "TypeError" [message]
  where
    message = Call (Field (String template) "format") [Arg expected, Arg actual]
    -- The expected type is known statically: one "[]" per dimension.
    expected = String $ T.replicate dimf "[]" <> typ
    -- The actual type is computed in Python from the runtime rank of e.
    actual = BinOp "+" (BinOp "*" (String "[]") (Field e "ndim")) (String typ)
    template =
      T.unlines
        [ "Argument #" <> prettyText i <> " has invalid value",
          "Dimensionality mismatch",
          "Expected Futhark type: {}",
          "Bad Python value passed",
          "Actual Futhark type: {}"
        ]

-- | For every variable occurring in the dimensions of an array-typed
-- entry point input (including arrays inside opaque values), emit a
-- statement assigning @None@ to that variable.  Scalar inputs
-- contribute no statements.
declEntryPointInputSizes :: [Imp.ExternalValue] -> CompilerM op s ()
declEntryPointInputSizes evs =
  forM_ (concatMap extSizes evs) $ \v ->
    stm $ Assign (Var (compileName v)) None
  where
    extSizes (Imp.TransparentValue vd) = descSizes vd
    extSizes (Imp.OpaqueValue _ vds) = concatMap descSizes vds
    descSizes (Imp.ArrayValue _ _ _ _ dims) = subExpVars dims
    descSizes Imp.ScalarValue {} = []

-- | Generate Python code that validates the entry point argument with
-- index @i@, given as the Python expression @e@, and binds it to the
-- variables used by the compiled function body.
--
-- * Opaque values: check that @e@ is an @opaque@ instance whose @desc@
--   matches, then recurse on each @e.data[k]@ as a transparent value,
--   reporting failures under the same argument index.
--
-- * Scalars: convert @e@ with the Numpy constructor for the type
--   inside a @try@; @TypeError@/@AssertionError@ become a descriptive
--   error.
--
-- * Arrays in a named memory space: run the 'envEntryInput' function
--   from the compiler environment inside a @try@, with the same error
--   conversion.
--
-- * Other arrays: check the Python type, dtype and rank of @e@, call
--   'unpackDim' for each dimension, then assign @unwrapArray(e)@ to the
--   array's memory variable.
entryPointInput :: (Int, Imp.ExternalValue, PyExp) -> CompilerM op s ()
entryPointInput (i, Imp.OpaqueValue desc vs, e) = do
  let descText = nameToText desc
      isExpectedOpaque =
        BinOp
          "and"
          (simpleCall "isinstance" [e, Var "opaque"])
          (BinOp "==" (Field e "desc") (String descText))
      components = map (Index (Field e "data") . IdxExp . Integer) [0 ..]
  stm $ If (UnOp "not" isExpectedOpaque) [badInput i e descText] []
  mapM_ entryPointInput $
    zip3 (repeat i) (map Imp.TransparentValue vs) components
entryPointInput (i, Imp.TransparentValue (Imp.ScalarValue bt s name), e) = do
  target <- compileVar name
  -- HACK: A Numpy int64 will signal an OverflowError if we pass
  -- it a number bigger than 2**63.  This does not happen if we
  -- pass e.g. int8 a number bigger than 2**7.  As a workaround,
  -- we first go through the corresponding ctypes type, which does
  -- not have this problem.
  let converted =
        simpleCall (compilePrimToNp bt) $
          case bt of
            IntType Int64 -> [simpleCall (compilePrimType bt) [e]]
            _ -> [e]
      futType = prettySigned (s == Imp.Unsigned) bt
  stm $
    Try
      [Assign target converted]
      [ Catch
          (Tuple [Var "TypeError", Var "AssertionError"])
          [badInput i e futType]
      ]
entryPointInput (i, Imp.TransparentValue (Imp.ArrayValue mem (Imp.Space sid) bt ept dims), e) = do
  unpackHook <- asks envEntryInput
  memVar <- compileVar mem
  unpackStms <- collect $ unpackHook memVar sid bt ept dims e
  let futType =
        mconcat (replicate (length dims) "[]")
          <> prettySigned (ept == Imp.Unsigned) bt
  stm $
    Try
      unpackStms
      [ Catch
          (Tuple [Var "TypeError", Var "AssertionError"])
          [badInput i e futType]
      ]
entryPointInput (i, Imp.TransparentValue (Imp.ArrayValue mem _ t s dims), e) = do
  let rank = length dims
      npType = compilePrimToExtNp t s
      elemType = prettySigned (s == Imp.Unsigned) t
      futType = mconcat (replicate rank "[]") <> elemType
      notNdarray =
        UnOp "not" $ BinOp "in" (simpleCall "type" [e]) $ List [Var "np.ndarray"]
      wrongDtype =
        UnOp "not" $ BinOp "==" (Field e "dtype") $ Var npType
      wrongRank =
        UnOp "not" $ BinOp "==" (Field e "ndim") $ Integer $ toInteger rank
      -- Emit "if cond: <failure>" with no else branch.
      guardWith cond failure = stm $ If cond [failure] []
  guardWith notNdarray $ badInput i e futType
  guardWith wrongDtype $
    badInputType i e futType (simpleCall "np.dtype" [Var npType]) (Field e "dtype")
  guardWith wrongRank $ badInputDim i e elemType rank
  zipWithM_ (unpackDim e) dims [0 ..]
  memVar <- compileVar mem
  stm $ Assign memVar $ simpleCall "unwrapArray" [e]

-- | The Python variable name used for an external value, always passed
-- through 'extName'.  A transparent value is named after its variable.
-- An opaque value is named after its z-encoded type description, plus
-- the tag of its first component when it has one.
extValueDescName :: Imp.ExternalValue -> T.Text
extValueDescName ev = extName $ case ev of
  Imp.TransparentValue v ->
    T.pack $ compileName $ valueDescVName v
  Imp.OpaqueValue desc [] ->
    zEncodeText $ nameToText desc
  Imp.OpaqueValue desc (v : _) ->
    zEncodeText (nameToText desc) <> "_" <> prettyText (baseTag (valueDescVName v))

-- | Append the @_ext@ suffix used by 'extValueDescName'.
extName :: T.Text -> T.Text
extName base = base <> T.pack "_ext"

-- | The variable a value descriptor refers to: the variable of a
-- scalar, or the memory variable of an array.
valueDescVName :: Imp.ValueDesc -> VName
valueDescVName vd = case vd of
  Imp.ScalarValue _ _ v -> v
  Imp.ArrayValue mem _ _ _ _ -> mem

-- | Key into the FUTHARK_PRIMTYPES dict.  Integer keys are an @i@ or
-- @u@ prefix (from the signedness) followed by the bit width; float
-- keys ignore signedness.  Both 'Imp.Bool' and 'Unit' map to @"bool"@.
readTypeEnum :: PrimType -> Imp.Signedness -> T.Text
readTypeEnum (IntType it) sign =
  let prefix = case sign of
        Imp.Unsigned -> "u"
        Imp.Signed -> "i"
      bits = case it of
        Int8 -> "8"
        Int16 -> "16"
        Int32 -> "32"
        Int64 -> "64"
   in prefix <> bits
readTypeEnum (FloatType ft) _ =
  case ft of
    Float16 -> "f16"
    Float32 -> "f32"
    Float64 -> "f64"
readTypeEnum Imp.Bool _ = "bool"
readTypeEnum Unit _ = "bool"

-- | A Python statement that reads one entry point argument with
-- @read_value@ into the variable named by 'extValueDescName'.  The key
-- passed to @read_value@ is the 'readTypeEnum' key, prefixed with one
-- @"[]"@ per array dimension.  Opaque arguments cannot be read, so for
-- them the statement raises an @Exception@ instead.
readInput :: Imp.ExternalValue -> PyStmt
readInput (Imp.OpaqueValue desc _) =
  Raise $
    simpleCall
      "Exception"
      [String $ "Cannot read argument of type " <> nameToText desc <> "."]
readInput decl@(Imp.TransparentValue vd) =
  Assign target $ simpleCall "read_value" [String typeKey]
  where
    target = Var $ T.unpack $ extValueDescName decl
    typeKey = case vd of
      Imp.ScalarValue bt ept _ ->
        readTypeEnum bt ept
      Imp.ArrayValue _ _ bt ept dims ->
        T.replicate (length dims) "[]" <> readTypeEnum bt ept

-- | Generate the Python statements that print entry point results to
-- standard output, one result per line.  Non-opaque values are written
-- with @write_value@ (binary or text depending on the Python variable
-- @binary_output@).  Opaque values cannot be printed, so a
-- @#<opaque desc>@ placeholder is written instead.  Every result,
-- opaque or not, is followed by a newline so that consecutive results
-- never run together on one line.
printValue :: [(Imp.ExternalValue, PyExp)] -> CompilerM op s [PyStmt]
printValue = fmap concat . mapM (uncurry printValue')
  where
    -- We copy non-host arrays to the host before printing.  This is
    -- done in a hacky way - we assume the value has a .get()-method
    -- that returns an equivalent Numpy array.  This works for PyOpenCL,
    -- but we will probably need yet another plugin mechanism here in
    -- the future.
    printValue' (Imp.OpaqueValue desc _) _ =
      pure
        [ Exp $
            simpleCall
              "sys.stdout.write"
              [String $ "#<opaque " <> nameToText desc <> ">"],
          -- Terminate the placeholder like every other result.
          newline
        ]
    printValue' (Imp.TransparentValue (Imp.ArrayValue mem (Space _) bt ept shape)) e =
      printValue' (Imp.TransparentValue (Imp.ArrayValue mem DefaultSpace bt ept shape)) $
        simpleCall (prettyString e ++ ".get") []
    printValue' (Imp.TransparentValue _) e =
      pure
        [ Exp $
            Call
              (Var "write_value")
              [ Arg e,
                ArgKeyword "binary" (Var "binary_output")
              ],
          newline
        ]

    newline = Exp $ simpleCall "sys.stdout.write" [String "\n"]

-- | Assemble the pieces of a Python wrapper for the entry point
-- function @fname@.  The result tuple contains:
--
-- 1. the wrapper's parameter names, one per entry point argument
--    (from 'extValueDescName');
-- 2. statements that declare size variables and validate/unpack the
--    arguments;
-- 3. statements calling @self.@ followed by the 'futharkFun'-mangled
--    name on the unpacked inputs, binding the function's output
--    parameters, inside an @np.errstate@ block that sets
--    divide/over/under/invalid to @"ignore"@;
-- 4. statements preparing the outputs (from 'entryPointOutput');
-- 5. the prepared results, paired with their external descriptions.
prepareEntry ::
  Imp.EntryPoint ->
  (Name, Imp.Function op) ->
  CompilerM
    op
    s
    ( [String],
      [PyStmt],
      [PyStmt],
      [PyStmt],
      [(Imp.ExternalValue, PyExp)]
    )
prepareEntry (Imp.EntryPoint _ results args) (fname, Imp.Function _ outputs inputs _) = do
  let argValues = map snd args
      resultValues = map snd results
      argNames = map (T.unpack . extValueDescName) argValues

  prepareIn <- collect $ do
    declEntryPointInputSizes argValues
    mapM_ entryPointInput $ zip3 [0 ..] argValues $ map Var argNames
  (res, prepareOut) <- collect' $ mapM entryPointOutput resultValues

  let paramVar = compileName . Imp.paramName
      outTuple = tupleOrSingle $ map (Var . paramVar) outputs
      callee = T.unpack $ "self." <> futharkFun (nameToText fname)
      -- We ignore overflow errors and the like for executable entry
      -- points.  These are (somewhat) well-defined in Futhark.
      errstate =
        Call (Var "np.errstate") $
          [ ArgKeyword flag $ String "ignore"
            | flag <- ["divide", "over", "under", "invalid"]
          ]
      callStms =
        [ With
            errstate
            [Assign outTuple $ simpleCall callee $ map (Var . paramVar) inputs]
        ]

  pure (argNames, prepareIn, callStms, prepareOut, zip resultValues res)

-- | Whether a generated entry point returns a @runtime@ value together
-- with its results.
data ReturnTiming
  = -- | Return a @(runtime, results)@ tuple.
    ReturnTiming
  | -- | Return only the results.
    DoNotReturnTiming

-- | Compile an entry point into a Python method taking @self@ plus the
-- entry point's parameters.
--
-- Functions that are not entry points yield 'Nothing'.  Otherwise the
-- result pairs the method definition with a descriptor: the entry
-- point's name as a string, and a tuple of its escaped name, its
-- parameter type names, and its result type names.
--
-- The method body always records @time_start@ and computes @runtime@
-- (in microseconds).  With 'ReturnTiming', the given synchronisation
-- statements are also run before the runtime is taken, and @runtime@
-- is returned in front of the results.
compileEntryFun ::
  [PyStmt] ->
  ReturnTiming ->
  (Name, Imp.Function op) ->
  CompilerM op s (Maybe (PyFunDef, (PyExp, PyExp)))
compileEntryFun sync timing fun =
  case Imp.functionEntry (snd fun) of
    Nothing -> pure Nothing
    Just entry -> do
      (params, prepare_in, body_lib, prepare_out, res) <- prepareEntry entry fun
      let ename = Imp.entryPointName entry
          escaped = escapeName ename
          results = tupleOrSingle (map snd res)
          (maybe_sync, ret) =
            case timing of
              DoNotReturnTiming -> ([], Return results)
              ReturnTiming -> (sync, Return (Tuple [Var "runtime", results]))
          (pts, rts) = entryTypes entry
          time_now = simpleCall "time.time" []
          elapsed =
            BinOp "-" (toMicroseconds time_now) (toMicroseconds (Var "time_start"))
          do_run =
            [Assign (Var "time_start") time_now]
              ++ body_lib
              ++ maybe_sync
              ++ [Assign (Var "runtime") elapsed]
          method =
            Def (T.unpack escaped) ("self" : params) $
              concat [prepare_in, do_run, prepare_out, sync, [ret]]
          descriptor =
            Tuple
              [ String escaped,
                List (map String pts),
                List (map String rts)
              ]
      pure $ Just (method, (String (nameToText ename), descriptor))

-- | Textual type descriptions for an entry point's parameters and
-- results, in that order.
--
-- Each description starts with the pretty-printed uniqueness.  An
-- opaque value then gives its type name; a transparent scalar gives the
-- 'readTypeEnum' text for its type and signedness; a transparent array
-- gives one @[]@ per dimension followed by that same 'readTypeEnum'
-- text.
entryTypes :: Imp.EntryPoint -> ([T.Text], [T.Text])
entryTypes (Imp.EntryPoint _ results params) =
  ( [describe u v | ((_, u), v) <- params],
    [describe u v | (u, v) <- results]
  )
  where
    describe u v = prettyText u <> valueType v

    valueType (Imp.OpaqueValue name _) = nameToText name
    valueType (Imp.TransparentValue (Imp.ScalarValue pt s _)) =
      readTypeEnum pt s
    valueType (Imp.TransparentValue (Imp.ArrayValue _ _ pt s dims)) =
      T.replicate (length dims) "[]" <> readTypeEnum pt s

-- | Build the parameterless @entry_<name>@ Python function used by the
-- standalone executable to drive an entry point.
--
-- Functions without an entry point yield 'Nothing'.  The generated
-- function does the following, in order:
--
-- 1. Reads every argument and calls @end_of_input@.
-- 2. Runs the preparation code.
-- 3. Inside a @try@, optionally does a warm-up run (guarded by
--    @do_warmup_run@), then runs the timed body @int(num_runs)@ times.
-- 4. Closes the runtime file and prints the outputs.
--
-- An @AssertionError@ raised inside the @try@ terminates the program via
-- @sys.exit@.  The Python names @do_warmup_run@, @num_runs@ and
-- @runtime_file@ are presumably bound by the surrounding generated
-- program (not visible here).
--
-- The result triple is the function definition, the entry point's
-- name, and a reference to the generated function.
callEntryFun ::
  [PyStmt] ->
  (Name, Imp.Function op) ->
  CompilerM op s (Maybe (PyFunDef, T.Text, PyExp))
callEntryFun pre_timing fun@(fname, func) =
  case Imp.functionEntry func of
    Nothing -> pure Nothing
    Just entry -> do
      (_, prepare_in, body_bin, _, res) <- prepareEntry entry fun
      str_output <- printValue res
      let Imp.EntryPoint ename _ decl_args = entry
          py_fname = "entry_" ++ T.unpack (escapeName fname)

          -- One untimed execution of the body; the timed variant wraps it.
          run_once = body_bin ++ pre_timing
          (timed_run, close_runtime_file) = addTiming run_once

          read_inputs = [readInput v | (_, v) <- decl_args]
          finish_input =
            Exp $ simpleCall "end_of_input" [String (prettyText fname)]

          on_assert_fail =
            Catch
              (Var "AssertionError")
              [ Exp $
                  simpleCall
                    "sys.exit"
                    [Field (String "Assertion.{} failed") "format(e)"]
              ]
          warmup = If (Var "do_warmup_run") run_once []
          bench_loop =
            For
              "i"
              (simpleCall "range" [simpleCall "int" [Var "num_runs"]])
              timed_run

          body =
            read_inputs
              ++ [finish_input]
              ++ prepare_in
              ++ [Try [warmup, bench_loop] [on_assert_fail], close_runtime_file]
              ++ str_output
      pure $ Just (Def py_fname [] body, nameToText ename, Var py_fname)

-- | Wrap statements with wall-clock timing.
--
-- The first component records @time_start@ and runs the statements.  It
-- then records @time_end@ and, when @runtime_file@ is truthy, writes
-- the elapsed microseconds plus a newline to it and flushes.  The
-- second component is a statement that closes @runtime_file@ if it is
-- truthy.
addTiming :: [PyStmt] -> ([PyStmt], PyStmt)
addTiming statements = (timed, close_file)
  where
    now = simpleCall "time.time" []
    whenRuntimeFile body = If (Var "runtime_file") body []

    timed =
      Assign (Var "time_start") now
        : statements
        ++ [ Assign (Var "time_end") now,
             whenRuntimeFile print_runtime
           ]

    close_file = whenRuntimeFile [Exp $ simpleCall "runtime_file.close" []]

    elapsed =
      BinOp
        "-"
        (toMicroseconds (Var "time_end"))
        (toMicroseconds (Var "time_start"))

    print_runtime =
      [ Exp $ simpleCall "runtime_file.write" [simpleCall "str" [elapsed]],
        Exp $ simpleCall "runtime_file.write" [String "\n"],
        Exp $ simpleCall "runtime_file.flush" []
      ]

-- | Python expression converting a time in seconds to whole
-- microseconds, via @int(x * 1000000)@.
toMicroseconds :: PyExp -> PyExp
toMicroseconds x = simpleCall "int" [scaled]
  where
    scaled = BinOp "*" x (Integer 1000000)

-- | The Python prefix operator or function name used for a unary
-- operator.  Signum on integers goes through the runtime helpers
-- @ssignum@ / @usignum@; on floats it uses @np.sign@.
compileUnOp :: Imp.UnOp -> String
compileUnOp Not = "not"
compileUnOp Complement {} = "~"
compileUnOp Abs {} = "abs"
compileUnOp FAbs {} = "abs"
compileUnOp SSignum {} = "ssignum"
compileUnOp USignum {} = "usignum"
compileUnOp FSignum {} = "np.sign"

-- | Compile both operands of a binary expression, left one first.
--
-- Returns the two compiled operands together with a function that,
-- given a Python infix operator, builds the corresponding 'BinOp'.
compileBinOpLike ::
  (Monad m) =>
  (v -> m PyExp) ->
  Imp.PrimExp v ->
  Imp.PrimExp v ->
  m (PyExp, PyExp, String -> m PyExp)
compileBinOpLike f x y = do
  (x', y') <- (,) <$> compilePrimExp f x <*> compilePrimExp f y
  pure (x', y', \op -> pure (BinOp op x' y'))

-- | The ctypes type corresponding to a 'PrimType'.
--
-- Half-precision floats map to a 16-bit unsigned ctypes integer, and
-- both booleans and unit map to @ct.c_bool@.
compilePrimType :: PrimType -> String
compilePrimType (IntType Int8) = "ct.c_int8"
compilePrimType (IntType Int16) = "ct.c_int16"
compilePrimType (IntType Int32) = "ct.c_int32"
compilePrimType (IntType Int64) = "ct.c_int64"
compilePrimType (FloatType Float16) = "ct.c_uint16"
compilePrimType (FloatType Float32) = "ct.c_float"
compilePrimType (FloatType Float64) = "ct.c_double"
compilePrimType Imp.Bool = "ct.c_bool"
compilePrimType Unit = "ct.c_bool"

-- | The Numpy type corresponding to a 'PrimType'.
--
-- Integers are always the signed NumPy type of the same width; booleans
-- and unit are both represented as @np.byte@.
compilePrimToNp :: Imp.PrimType -> String
compilePrimToNp bt =
  case bt of
    IntType it -> intName it
    FloatType ft -> floatName ft
    Imp.Bool -> "np.byte"
    Unit -> "np.byte"
  where
    intName Int8 = "np.int8"
    intName Int16 = "np.int16"
    intName Int32 = "np.int32"
    intName Int64 = "np.int64"

    floatName Float16 = "np.float16"
    floatName Float32 = "np.float32"
    floatName Float64 = "np.float64"

-- | The Numpy type corresponding to a 'PrimType', taking sign into account.
--
-- Only integer types are affected by the signedness: unsigned integers
-- map to @np.uintN@, everything else to @np.intN@.  Unlike
-- 'compilePrimToNp', booleans map to @np.bool_@.
compilePrimToExtNp :: Imp.PrimType -> Imp.Signedness -> String
compilePrimToExtNp bt ept =
  case bt of
    IntType it ->
      case ept of
        Imp.Unsigned -> unsignedName it
        _ -> signedName it
    FloatType Float16 -> "np.float16"
    FloatType Float32 -> "np.float32"
    FloatType Float64 -> "np.float64"
    Imp.Bool -> "np.bool_"
    Unit -> "np.byte"
  where
    unsignedName Int8 = "np.uint8"
    unsignedName Int16 = "np.uint16"
    unsignedName Int32 = "np.uint32"
    unsignedName Int64 = "np.uint64"

    signedName Int8 = "np.int8"
    signedName Int16 = "np.int16"
    signedName Int32 = "np.int32"
    signedName Int64 = "np.int64"

-- | Convert from scalar to storage representation for the given type.
--
-- Half-precision floats are stored as their bit pattern, obtained via
-- @futhark_to_bits16@ and wrapped in @ct.c_int16@.  All other types go
-- through their 'compilePrimType' constructor.
toStorage :: PrimType -> PyExp -> PyExp
toStorage t e =
  case t of
    FloatType Float16 ->
      simpleCall "ct.c_int16" [simpleCall "futhark_to_bits16" [e]]
    _ -> simpleCall (compilePrimType t) [e]

-- | Convert from storage to scalar representation for the given type.
--
-- Half-precision floats are read back from their bit pattern: the value
-- is reinterpreted as @np.int16@ and passed to @futhark_from_bits16@.
-- All other types go through their 'compilePrimToNp' constructor.
fromStorage :: PrimType -> PyExp -> PyExp
fromStorage t e =
  case t of
    FloatType Float16 ->
      simpleCall "futhark_from_bits16" [simpleCall "np.int16" [e]]
    _ -> simpleCall (compilePrimToNp t) [e]

-- | Compile a primitive value to a Python expression.
--
-- Integers and floats are wrapped in the NumPy scalar constructor of
-- the matching width, so the Python value has the intended type.
-- Non-finite floats have no literal syntax; they are built from
-- @np.inf@ / @np.nan@ inside the same constructor.  This is done
-- uniformly for every float width.  Before this change, 'Float64'
-- infinities were emitted as bare @np.inf@ / @-np.inf@, which are plain
-- Python floats rather than @np.float64@.  That was inconsistent with
-- the 'Float64' NaN and finite cases.
compilePrimValue :: Imp.PrimValue -> PyExp
compilePrimValue pv =
  case pv of
    IntValue (Int8Value v) -> npCall "np.int8" (Integer (toInteger v))
    IntValue (Int16Value v) -> npCall "np.int16" (Integer (toInteger v))
    IntValue (Int32Value v) -> npCall "np.int32" (Integer (toInteger v))
    IntValue (Int64Value v) -> npCall "np.int64" (Integer (toInteger v))
    FloatValue (Float16Value v) -> npFloat "np.float16" v
    FloatValue (Float32Value v) -> npFloat "np.float32" v
    FloatValue (Float64Value v) -> npFloat "np.float64" v
    BoolValue v -> Bool v
    UnitValue -> Var "np.byte(0)"
  where
    npCall ctor x = simpleCall ctor [x]

    -- Wrap a float in the given NumPy constructor, spelling the
    -- non-finite values through NumPy's constants.
    npFloat :: (RealFloat a) => String -> a -> PyExp
    npFloat ctor v
      | isInfinite v =
          Var $ ctor ++ (if v > 0 then "(np.inf)" else "(-np.inf)")
      | isNaN v = Var $ ctor ++ "(np.nan)"
      | otherwise = simpleCall ctor [Float $ fromRational $ toRational v]

-- | Compile a variable reference.
--
-- If the environment's 'envVarExp' map has an entry for the compiled
-- name, that expression is used.  Otherwise the variable is referenced
-- by its compiled name.
compileVar :: VName -> CompilerM op s PyExp
compileVar v = do
  let name = compileName v
  overrides <- asks envVarExp
  pure $ M.findWithDefault (Var name) name overrides

-- | Tell me how to compile a @v@, and I'll Compile any @PrimExp v@ for you.
--
-- Binary and comparison operators that have a direct Python infix form
-- are emitted as infix operators.  All other operators, and all
-- conversions, become calls to a function named after the operator's
-- pretty-printed form.  Built-in functions are called through
-- 'futharkFun'.
compilePrimExp :: (Monad m) => (v -> m PyExp) -> Imp.PrimExp v -> m PyExp
compilePrimExp f e =
  case e of
    Imp.ValueExp v -> pure $ compilePrimValue v
    Imp.LeafExp v _ -> f v
    Imp.BinOpExp op x y -> binary x y (binOpSymbol op) (prettyString op)
    Imp.CmpOpExp cmp x y -> binary x y (cmpOpSymbol cmp) (prettyString cmp)
    Imp.ConvOpExp conv x ->
      (\x' -> simpleCall (prettyString conv) [x']) <$> compilePrimExp f x
    Imp.UnOpExp op x ->
      UnOp (compileUnOp op) <$> compilePrimExp f x
    Imp.FunExp h args _ ->
      simpleCall (T.unpack (futharkFun (prettyText h)))
        <$> mapM (compilePrimExp f) args
  where
    -- Compile both operands; use the infix symbol when there is one,
    -- otherwise call the named fallback function on the operands.
    binary x y infix_sym fallback = do
      (x', y', simple) <- compileBinOpLike f x y
      case infix_sym of
        Just sym -> simple sym
        Nothing -> pure $ simpleCall fallback [x', y']

    binOpSymbol :: BinOp -> Maybe String
    binOpSymbol op =
      case op of
        Add {} -> Just "+"
        Sub {} -> Just "-"
        Mul {} -> Just "*"
        FAdd {} -> Just "+"
        FSub {} -> Just "-"
        FMul {} -> Just "*"
        FDiv {} -> Just "/"
        FMod {} -> Just "%"
        Xor {} -> Just "^"
        And {} -> Just "&"
        Or {} -> Just "|"
        Shl {} -> Just "<<"
        LogAnd {} -> Just "and"
        LogOr {} -> Just "or"
        _ -> Nothing

    cmpOpSymbol :: CmpOp -> Maybe String
    cmpOpSymbol cmp =
      case cmp of
        CmpEq {} -> Just "=="
        FCmpLt {} -> Just "<"
        FCmpLe {} -> Just "<="
        CmpLlt -> Just "<"
        CmpLle -> Just "<="
        _ -> Nothing

-- | Compile an ImpCode expression, resolving its variables with
-- 'compileVar'.
compileExp :: Imp.Exp -> CompilerM op s PyExp
compileExp e = compilePrimExp compileVar e

-- | Turn an error message into a Python %-style format string plus the
-- list of compiled arguments for its placeholders.
--
-- Literal fragments are passed as @%s@ arguments.  Integer values use
-- @%d@, floats use @%f@, and booleans and unit use @%r@.
errorMsgString :: Imp.ErrorMsg Imp.Exp -> CompilerM op s (T.Text, [PyExp])
errorMsgString (Imp.ErrorMsg parts) = do
  compiled <- mapM compilePart parts
  pure (T.concat (map fst compiled), map snd compiled)
  where
    compilePart (Imp.ErrorString s) = pure ("%s", String s)
    compilePart (Imp.ErrorVal t x) = do
      x' <- compileExp x
      pure (formatSpec t, x')

    formatSpec :: PrimType -> T.Text
    formatSpec IntType {} = "%d"
    formatSpec FloatType {} = "%f"
    formatSpec Imp.Bool = "%r"
    formatSpec Unit = "%r"

-- | Generate a read of one scalar of the given type at index @iexp@ of
-- the memory block @src@.
--
-- The cases are handled as follows:
--
-- * Unit reads never touch memory and yield the unit constant.
-- * Reads from a scalar space are an internal error.
-- * Default-space reads call the Python helper @indexArray@ with the
--   ctypes type and convert the result with 'fromStorage'.
-- * Any other named space is delegated to the backend's 'envReadScalar'.
generateRead ::
  PyExp ->
  PyExp ->
  PrimType ->
  Space ->
  CompilerM op s PyExp
generateRead src iexp pt space
  | Unit <- pt = pure $ compilePrimValue UnitValue
  | otherwise =
      case space of
        ScalarSpace {} ->
          error "GenericPython.generateRead: ScalarSpace"
        DefaultSpace ->
          pure . fromStorage pt $
            simpleCall "indexArray" [src, iexp, Var (compilePrimType pt)]
        Space sid -> do
          read_scalar <- asks envReadScalar
          read_scalar src iexp pt sid

-- | Emit a statement that stores the scalar @elemexp@ of type @pt@ at
-- element index @iexp@ of the memory block @dst@ in the given memory
-- space.
--
-- * A 'Unit' write is a no-op.
-- * A 'ScalarSpace' write is unsupported here and is a compiler error.
-- * A named space is delegated to the backend's 'envWriteScalar' hook.
-- * A 'DefaultSpace' write emits a call to the runtime helper
--   @writeScalarArray@.
--
-- NOTE(review): unlike 'generateRead', no 'toStorage' conversion is
-- applied on the 'DefaultSpace' path; @elemexp@ is passed through
-- unchanged.  Confirm that callers supply a value in the representation
-- @writeScalarArray@ expects.
generateWrite ::
  PyExp ->
  PyExp ->
  PrimType ->
  Space ->
  PyExp ->
  CompilerM op s ()
generateWrite dst iexp pt space elemexp =
  case (pt, space) of
    (Unit, _) ->
      pure ()
    (_, ScalarSpace {}) ->
      error "GenericPython.generateWrite: ScalarSpace"
    (_, Imp.Space sid) -> do
      writeScalar <- asks envWriteScalar
      writeScalar dst iexp pt sid elemexp
    (_, DefaultSpace) ->
      stm . Exp $ simpleCall "writeScalarArray" [dst, iexp, elemexp]

-- | Compile a 'Copy' using sequential nested loops, but
-- parameterised over how to do the reads and writes.
--
-- One Python @for@ loop is emitted per dimension of @shape@, with the
-- outermost loop iterating over the first dimension.  In the innermost
-- body, the flat destination and source indices are computed as
-- @offset + sum (index * stride)@ from the respective (offset, strides)
-- pair.  The element is obtained with @doRead@ and stored with
-- @doWrite@.
compileCopyWith ::
  [Count Elements (TExp Int64)] ->
  (PyExp -> PyExp -> CompilerM op s ()) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (PyExp -> CompilerM op s PyExp) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileCopyWith shape doWrite dst_lmad doRead src_lmad = do
  let (dstoffset, dststrides) = dst_lmad
      (srcoffset, srcstrides) = src_lmad
  -- Loop counters are drawn from the fresh-name source, as the 'For'
  -- case of 'compileCode' does for its helper variables.  Hard-coded
  -- tags could collide with other identifiers in the generated code.
  is <- replicateM (length shape) (newVName "i")
  let is' :: [Count Elements (TExp Int64)]
      is' = map (elements . le64) is
      -- Flat element index for the current loop iteration.
      flatIndex ::
        Count Elements (TExp Int64) ->
        [Count Elements (TExp Int64)] ->
        Count Elements (TExp Int64)
      flatIndex offset strides = offset + sum (zipWith (*) is' strides)
  shape' <- mapM (compileExp . untyped . unCount) shape
  body <- collect $ do
    dst_i <- compileExp . untyped . unCount $ flatIndex dstoffset dststrides
    src_i <- compileExp . untyped . unCount $ flatIndex srcoffset srcstrides
    doWrite dst_i =<< doRead src_i
  mapM_ stm $ loops (zip is shape') body
  where
    -- Wrap the body in one @for i in range(n)@ per (counter, extent) pair.
    loops :: [(VName, PyExp)] -> [PyStmt] -> [PyStmt]
    loops [] body = body
    loops ((i, n) : ins) body =
      [For (compileName i) (simpleCall "range" [n]) $ loops ins body]

-- | Compile a 'Copy' using sequential nested loops and
-- 'Imp.Read'/'Imp.Write' of individual scalars.  This always works,
-- but can be pretty slow if those reads and writes are costly.
--
-- Every element is read with 'generateRead' from the source memory
-- space and written with 'generateWrite' to the destination memory
-- space.  The loop structure itself comes from 'compileCopyWith'.
compileCopy ::
  PrimType ->
  [Count Elements (TExp Int64)] ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  (VName, Space) ->
  ( Count Elements (TExp Int64),
    [Count Elements (TExp Int64)]
  ) ->
  CompilerM op s ()
compileCopy t shape (dst, dstspace) dst_lmad (src, srcspace) src_lmad = do
  srcMem <- compileVar src
  dstMem <- compileVar dst
  compileCopyWith
    shape
    (\dstIdx -> generateWrite dstMem dstIdx t dstspace)
    dst_lmad
    (\srcIdx -> generateRead srcMem srcIdx t srcspace)
    src_lmad

compileCode :: Imp.Code op -> CompilerM op s ()
compileCode :: forall op s. Code op -> CompilerM op s ()
compileCode Imp.DebugPrint {} =
  () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode Imp.TracePrint {} =
  () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Imp.Op op
op) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ (CompilerEnv op s -> op -> CompilerM op s ())
-> CompilerM op s (op -> CompilerM op s ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> op -> CompilerM op s ()
forall op s. CompilerEnv op s -> OpCompiler op s
envOpCompiler CompilerM op s (op -> CompilerM op s ())
-> CompilerM op s op -> CompilerM op s (CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> op -> CompilerM op s op
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure op
op
compileCode (Imp.If TExp Bool
cond Code op
tb Code op
fb) = do
  PyExp
cond' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp) -> Exp -> CompilerM op s PyExp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
Imp.untyped TExp Bool
cond
  [PyStmt]
tb' <- CompilerM op s () -> CompilerM op s [PyStmt]
forall op s. CompilerM op s () -> CompilerM op s [PyStmt]
collect (CompilerM op s () -> CompilerM op s [PyStmt])
-> CompilerM op s () -> CompilerM op s [PyStmt]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
tb
  [PyStmt]
fb' <- CompilerM op s () -> CompilerM op s [PyStmt]
forall op s. CompilerM op s () -> CompilerM op s [PyStmt]
collect (CompilerM op s () -> CompilerM op s [PyStmt])
-> CompilerM op s () -> CompilerM op s [PyStmt]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
fb
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> [PyStmt] -> [PyStmt] -> PyStmt
If PyExp
cond' [PyStmt]
tb' [PyStmt]
fb'
compileCode (Code op
c1 Imp.:>>: Code op
c2) = do
  Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
c1
  Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
c2
compileCode (Imp.While TExp Bool
cond Code op
body) = do
  PyExp
cond' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp) -> Exp -> CompilerM op s PyExp
forall a b. (a -> b) -> a -> b
$ TExp Bool -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
Imp.untyped TExp Bool
cond
  [PyStmt]
body' <- CompilerM op s () -> CompilerM op s [PyStmt]
forall op s. CompilerM op s () -> CompilerM op s [PyStmt]
collect (CompilerM op s () -> CompilerM op s [PyStmt])
-> CompilerM op s () -> CompilerM op s [PyStmt]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> [PyStmt] -> PyStmt
While PyExp
cond' [PyStmt]
body'
compileCode (Imp.For VName
i Exp
bound Code op
body) = do
  PyExp
bound' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
bound
  let i' :: [Char]
i' = VName -> [Char]
compileName VName
i
  [PyStmt]
body' <- CompilerM op s () -> CompilerM op s [PyStmt]
forall op s. CompilerM op s () -> CompilerM op s [PyStmt]
collect (CompilerM op s () -> CompilerM op s [PyStmt])
-> CompilerM op s () -> CompilerM op s [PyStmt]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
body
  [Char]
counter <- VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (VName -> [Char]) -> CompilerM op s VName -> CompilerM op s [Char]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> CompilerM op s VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"counter"
  [Char]
one <- VName -> [Char]
forall a. Pretty a => a -> [Char]
prettyString (VName -> [Char]) -> CompilerM op s VName -> CompilerM op s [Char]
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [Char] -> CompilerM op s VName
forall (m :: * -> *). MonadFreshNames m => [Char] -> m VName
newVName [Char]
"one"
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> PyExp -> PyStmt
Assign ([Char] -> PyExp
Var [Char]
i') (PyExp -> PyStmt) -> PyExp -> PyStmt
forall a b. (a -> b) -> a -> b
$ [Char] -> [PyExp] -> PyExp
simpleCall (PrimType -> [Char]
compilePrimToNp (Exp -> PrimType
forall v. PrimExp v -> PrimType
Imp.primExpType Exp
bound)) [Integer -> PyExp
Integer Integer
0]
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> PyExp -> PyStmt
Assign ([Char] -> PyExp
Var [Char]
one) (PyExp -> PyStmt) -> PyExp -> PyStmt
forall a b. (a -> b) -> a -> b
$ [Char] -> [PyExp] -> PyExp
simpleCall (PrimType -> [Char]
compilePrimToNp (Exp -> PrimType
forall v. PrimExp v -> PrimType
Imp.primExpType Exp
bound)) [Integer -> PyExp
Integer Integer
1]
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    [Char] -> PyExp -> [PyStmt] -> PyStmt
For [Char]
counter ([Char] -> [PyExp] -> PyExp
simpleCall [Char]
"range" [PyExp
bound']) ([PyStmt] -> PyStmt) -> [PyStmt] -> PyStmt
forall a b. (a -> b) -> a -> b
$
      [PyStmt]
body' [PyStmt] -> [PyStmt] -> [PyStmt]
forall a. [a] -> [a] -> [a]
++ [[Char] -> PyExp -> PyExp -> PyStmt
AssignOp [Char]
"+" ([Char] -> PyExp
Var [Char]
i') ([Char] -> PyExp
Var [Char]
one)]
compileCode (Imp.SetScalar VName
name Exp
exp1) =
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ())
-> CompilerM op s PyStmt -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PyExp -> PyExp -> PyStmt
Assign (PyExp -> PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s (PyExp -> PyStmt)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
name CompilerM op s (PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s PyStmt
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
exp1
compileCode Imp.DeclareMem {} = () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Imp.DeclareScalar VName
v Volatility
_ PrimType
Unit) = do
  PyExp
v' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
v
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> PyExp -> PyStmt
Assign PyExp
v' (PyExp -> PyStmt) -> PyExp -> PyStmt
forall a b. (a -> b) -> a -> b
$ [Char] -> PyExp
Var [Char]
"True"
compileCode Imp.DeclareScalar {} = () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()
compileCode (Imp.DeclareArray VName
name PrimType
t ArrayContents
vs) = do
  let arr_name :: [Char]
arr_name = VName -> [Char]
compileName VName
name [Char] -> [Char] -> [Char]
forall a. Semigroup a => a -> a -> a
<> [Char]
"_arr"
  -- It is important to store the Numpy array in a temporary variable
  -- to prevent it from going "out-of-scope" before calling
  -- unwrapArray (which internally uses the .ctype method); see
  -- https://docs.scipy.org/doc/numpy/reference/generated/numpy.ndarray.ctypes.html
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> PyExp -> PyStmt
Assign ([Char] -> PyExp
Var [Char]
arr_name) (PyExp -> PyStmt) -> PyExp -> PyStmt
forall a b. (a -> b) -> a -> b
$ case ArrayContents
vs of
    Imp.ArrayValues [PrimValue]
vs' ->
      PyExp -> [PyArg] -> PyExp
Call
        ([Char] -> PyExp
Var [Char]
"np.array")
        [ PyExp -> PyArg
Arg (PyExp -> PyArg) -> PyExp -> PyArg
forall a b. (a -> b) -> a -> b
$ [PyExp] -> PyExp
List ([PyExp] -> PyExp) -> [PyExp] -> PyExp
forall a b. (a -> b) -> a -> b
$ (PrimValue -> PyExp) -> [PrimValue] -> [PyExp]
forall a b. (a -> b) -> [a] -> [b]
map PrimValue -> PyExp
compilePrimValue [PrimValue]
vs',
          [Char] -> PyExp -> PyArg
ArgKeyword [Char]
"dtype" (PyExp -> PyArg) -> PyExp -> PyArg
forall a b. (a -> b) -> a -> b
$ [Char] -> PyExp
Var ([Char] -> PyExp) -> [Char] -> PyExp
forall a b. (a -> b) -> a -> b
$ PrimType -> [Char]
compilePrimToNp PrimType
t
        ]
    Imp.ArrayZeros Int
n ->
      PyExp -> [PyArg] -> PyExp
Call
        ([Char] -> PyExp
Var [Char]
"np.zeros")
        [ PyExp -> PyArg
Arg (PyExp -> PyArg) -> PyExp -> PyArg
forall a b. (a -> b) -> a -> b
$ Integer -> PyExp
Integer (Integer -> PyExp) -> Integer -> PyExp
forall a b. (a -> b) -> a -> b
$ Int -> Integer
forall a b. (Integral a, Num b) => a -> b
fromIntegral Int
n,
          [Char] -> PyExp -> PyArg
ArgKeyword [Char]
"dtype" (PyExp -> PyArg) -> PyExp -> PyArg
forall a b. (a -> b) -> a -> b
$ [Char] -> PyExp
Var ([Char] -> PyExp) -> [Char] -> PyExp
forall a b. (a -> b) -> a -> b
$ PrimType -> [Char]
compilePrimToNp PrimType
t
        ]
  PyExp
name' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
name
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ PyExp -> PyExp -> PyStmt
Assign PyExp
name' (PyExp -> PyStmt) -> PyExp -> PyStmt
forall a b. (a -> b) -> a -> b
$ [Char] -> [PyExp] -> PyExp
simpleCall [Char]
"unwrapArray" [[Char] -> PyExp
Var [Char]
arr_name]
compileCode (Imp.Comment Text
s Code op
code) = do
  [PyStmt]
code' <- CompilerM op s () -> CompilerM op s [PyStmt]
forall op s. CompilerM op s () -> CompilerM op s [PyStmt]
collect (CompilerM op s () -> CompilerM op s [PyStmt])
-> CompilerM op s () -> CompilerM op s [PyStmt]
forall a b. (a -> b) -> a -> b
$ Code op -> CompilerM op s ()
forall op s. Code op -> CompilerM op s ()
compileCode Code op
code
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$ [Char] -> [PyStmt] -> PyStmt
Comment (Text -> [Char]
T.unpack Text
s) [PyStmt]
code'
compileCode (Imp.Assert Exp
e ErrorMsg Exp
msg (SrcLoc
loc, [SrcLoc]
locs)) = do
  PyExp
e' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
e
  (Text
formatstr, [PyExp]
formatargs) <- ErrorMsg Exp -> CompilerM op s (Text, [PyExp])
forall op s. ErrorMsg Exp -> CompilerM op s (Text, [PyExp])
errorMsgString ErrorMsg Exp
msg
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    PyExp -> PyExp -> PyStmt
Assert
      PyExp
e'
      ( [Char] -> PyExp -> PyExp -> PyExp
BinOp
          [Char]
"%"
          (Text -> PyExp
String (Text -> PyExp) -> Text -> PyExp
forall a b. (a -> b) -> a -> b
$ Text
"Error: " Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
formatstr Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
"\n\nBacktrace:\n" Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text
stacktrace)
          ([PyExp] -> PyExp
Tuple [PyExp]
formatargs)
      )
  where
    stacktrace :: Text
stacktrace = Int -> [Text] -> Text
prettyStacktrace Int
0 ([Text] -> Text) -> [Text] -> Text
forall a b. (a -> b) -> a -> b
$ (SrcLoc -> Text) -> [SrcLoc] -> [Text]
forall a b. (a -> b) -> [a] -> [b]
map SrcLoc -> Text
forall a. Located a => a -> Text
locText ([SrcLoc] -> [Text]) -> [SrcLoc] -> [Text]
forall a b. (a -> b) -> a -> b
$ SrcLoc
loc SrcLoc -> [SrcLoc] -> [SrcLoc]
forall a. a -> [a] -> [a]
: [SrcLoc]
locs
compileCode (Imp.Call [VName]
dests Name
fname [Arg]
args) = do
  [PyExp]
args' <- (Arg -> CompilerM op s PyExp) -> [Arg] -> CompilerM op s [PyExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM Arg -> CompilerM op s PyExp
forall {op} {s}. Arg -> CompilerM op s PyExp
compileArg [Arg]
args
  PyExp
dests' <- [PyExp] -> PyExp
tupleOrSingle ([PyExp] -> PyExp)
-> CompilerM op s [PyExp] -> CompilerM op s PyExp
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (VName -> CompilerM op s PyExp)
-> [VName] -> CompilerM op s [PyExp]
forall (t :: * -> *) (m :: * -> *) a b.
(Traversable t, Monad m) =>
(a -> m b) -> t a -> m (t b)
forall (m :: * -> *) a b. Monad m => (a -> m b) -> [a] -> m [b]
mapM VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar [VName]
dests
  let fname' :: Text
fname'
        | Name -> Bool
isBuiltInFunction Name
fname = Text -> Text
futharkFun (Name -> Text
forall a. Pretty a => a -> Text
prettyText Name
fname)
        | Bool
otherwise = Text
"self." Text -> Text -> Text
forall a. Semigroup a => a -> a -> a
<> Text -> Text
futharkFun (Name -> Text
forall a. Pretty a => a -> Text
prettyText Name
fname)
      call' :: PyExp
call' = [Char] -> [PyExp] -> PyExp
simpleCall (Text -> [Char]
T.unpack Text
fname') [PyExp]
args'
  -- If the function returns nothing (is called only for side
  -- effects), take care not to assign to an empty tuple.
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ()) -> PyStmt -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    if [VName] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [VName]
dests
      then PyExp -> PyStmt
Exp PyExp
call'
      else PyExp -> PyExp -> PyStmt
Assign PyExp
dests' PyExp
call'
  where
    compileArg :: Arg -> CompilerM op s PyExp
compileArg (Imp.MemArg VName
m) = VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
m
    compileArg (Imp.ExpArg Exp
e) = Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
e
compileCode (Imp.SetMem VName
dest VName
src Space
_) =
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ())
-> CompilerM op s PyStmt -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PyExp -> PyExp -> PyStmt
Assign (PyExp -> PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s (PyExp -> PyStmt)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
dest CompilerM op s (PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s PyStmt
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
src
compileCode (Imp.Allocate VName
name (Imp.Count (Imp.TPrimExp Exp
e)) (Imp.Space [Char]
space)) =
  CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall (m :: * -> *) a. Monad m => m (m a) -> m a
join (CompilerM op s (CompilerM op s ()) -> CompilerM op s ())
-> CompilerM op s (CompilerM op s ()) -> CompilerM op s ()
forall a b. (a -> b) -> a -> b
$
    (CompilerEnv op s -> PyExp -> PyExp -> [Char] -> CompilerM op s ())
-> CompilerM op s (PyExp -> PyExp -> [Char] -> CompilerM op s ())
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks CompilerEnv op s -> PyExp -> PyExp -> [Char] -> CompilerM op s ()
forall op s. CompilerEnv op s -> Allocate op s
envAllocate
      CompilerM op s (PyExp -> PyExp -> [Char] -> CompilerM op s ())
-> CompilerM op s PyExp
-> CompilerM op s (PyExp -> [Char] -> CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
name
      CompilerM op s (PyExp -> [Char] -> CompilerM op s ())
-> CompilerM op s PyExp
-> CompilerM op s ([Char] -> CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
e
      CompilerM op s ([Char] -> CompilerM op s ())
-> CompilerM op s [Char] -> CompilerM op s (CompilerM op s ())
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> [Char] -> CompilerM op s [Char]
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure [Char]
space
compileCode (Imp.Allocate VName
name (Imp.Count (Imp.TPrimExp Exp
e)) Space
_) = do
  PyExp
e' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
e
  let allocate' :: PyExp
allocate' = [Char] -> [PyExp] -> PyExp
simpleCall [Char]
"allocateMem" [PyExp
e']
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ())
-> CompilerM op s PyStmt -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PyExp -> PyExp -> PyStmt
Assign (PyExp -> PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s (PyExp -> PyStmt)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
name CompilerM op s (PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s PyStmt
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> PyExp -> CompilerM op s PyExp
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure PyExp
allocate'
compileCode (Imp.Free VName
name Space
_) =
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ())
-> CompilerM op s PyStmt -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PyExp -> PyExp -> PyStmt
Assign (PyExp -> PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s (PyExp -> PyStmt)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
name CompilerM op s (PyExp -> PyStmt)
-> CompilerM op s PyExp -> CompilerM op s PyStmt
forall a b.
CompilerM op s (a -> b) -> CompilerM op s a -> CompilerM op s b
forall (f :: * -> *) a b. Applicative f => f (a -> b) -> f a -> f b
<*> PyExp -> CompilerM op s PyExp
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure PyExp
None
compileCode (Imp.Copy PrimType
t [Count Elements (TExp Int64)]
shape (VName
dst, Space
dstspace) (Count Elements (TExp Int64)
dstoffset, [Count Elements (TExp Int64)]
dststrides) (VName
src, Space
srcspace) (Count Elements (TExp Int64)
srcoffset, [Count Elements (TExp Int64)]
srcstrides)) = do
  Maybe (DoCopy op s)
cp <- (CompilerEnv op s -> Maybe (DoCopy op s))
-> CompilerM op s (Maybe (DoCopy op s))
forall r (m :: * -> *) a. MonadReader r m => (r -> a) -> m a
asks ((CompilerEnv op s -> Maybe (DoCopy op s))
 -> CompilerM op s (Maybe (DoCopy op s)))
-> (CompilerEnv op s -> Maybe (DoCopy op s))
-> CompilerM op s (Maybe (DoCopy op s))
forall a b. (a -> b) -> a -> b
$ (Space, Space)
-> Map (Space, Space) (DoCopy op s) -> Maybe (DoCopy op s)
forall k a. Ord k => k -> Map k a -> Maybe a
M.lookup (Space
dstspace, Space
srcspace) (Map (Space, Space) (DoCopy op s) -> Maybe (DoCopy op s))
-> (CompilerEnv op s -> Map (Space, Space) (DoCopy op s))
-> CompilerEnv op s
-> Maybe (DoCopy op s)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Operations op s -> Map (Space, Space) (DoCopy op s)
forall op s. Operations op s -> Map (Space, Space) (DoCopy op s)
opsCopies (Operations op s -> Map (Space, Space) (DoCopy op s))
-> (CompilerEnv op s -> Operations op s)
-> CompilerEnv op s
-> Map (Space, Space) (DoCopy op s)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. CompilerEnv op s -> Operations op s
forall op s. CompilerEnv op s -> Operations op s
envOperations
  case Maybe (DoCopy op s)
cp of
    Maybe (DoCopy op s)
Nothing ->
      PrimType
-> [Count Elements (TExp Int64)]
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> CompilerM op s ()
forall op s.
PrimType
-> [Count Elements (TExp Int64)]
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> (VName, Space)
-> (Count Elements (TExp Int64), [Count Elements (TExp Int64)])
-> CompilerM op s ()
compileCopy PrimType
t [Count Elements (TExp Int64)]
shape (VName
dst, Space
dstspace) (Count Elements (TExp Int64)
dstoffset, [Count Elements (TExp Int64)]
dststrides) (VName
src, Space
srcspace) (Count Elements (TExp Int64)
srcoffset, [Count Elements (TExp Int64)]
srcstrides)
    Just DoCopy op s
cp' -> do
      [Count Elements PyExp]
shape' <- (Count Elements (TExp Int64)
 -> CompilerM op s (Count Elements PyExp))
-> [Count Elements (TExp Int64)]
-> CompilerM op s [Count Elements PyExp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((TExp Int64 -> CompilerM op s PyExp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements PyExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s PyExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [Count Elements (TExp Int64)]
shape
      PyExp
dst' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
dst
      PyExp
src' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
src
      Count Elements PyExp
dstoffset' <- (TExp Int64 -> CompilerM op s PyExp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements PyExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s PyExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped) Count Elements (TExp Int64)
dstoffset
      [Count Elements PyExp]
dststrides' <- (Count Elements (TExp Int64)
 -> CompilerM op s (Count Elements PyExp))
-> [Count Elements (TExp Int64)]
-> CompilerM op s [Count Elements PyExp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((TExp Int64 -> CompilerM op s PyExp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements PyExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s PyExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [Count Elements (TExp Int64)]
dststrides
      Count Elements PyExp
srcoffset' <- (TExp Int64 -> CompilerM op s PyExp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements PyExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s PyExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped) Count Elements (TExp Int64)
srcoffset
      [Count Elements PyExp]
srcstrides' <- (Count Elements (TExp Int64)
 -> CompilerM op s (Count Elements PyExp))
-> [Count Elements (TExp Int64)]
-> CompilerM op s [Count Elements PyExp]
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> [a] -> f [b]
traverse ((TExp Int64 -> CompilerM op s PyExp)
-> Count Elements (TExp Int64)
-> CompilerM op s (Count Elements PyExp)
forall (t :: * -> *) (f :: * -> *) a b.
(Traversable t, Applicative f) =>
(a -> f b) -> t a -> f (t b)
forall (f :: * -> *) a b.
Applicative f =>
(a -> f b) -> Count Elements a -> f (Count Elements b)
traverse (Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp)
-> (TExp Int64 -> Exp) -> TExp Int64 -> CompilerM op s PyExp
forall b c a. (b -> c) -> (a -> b) -> a -> c
. TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped)) [Count Elements (TExp Int64)]
srcstrides
      DoCopy op s
cp' PrimType
t [Count Elements PyExp]
shape' PyExp
dst' (Count Elements PyExp
dstoffset', [Count Elements PyExp]
dststrides') PyExp
src' (Count Elements PyExp
srcoffset', [Count Elements PyExp]
srcstrides')
compileCode (Imp.Write VName
dst (Imp.Count TExp Int64
idx) PrimType
pt Space
space Volatility
_ Exp
elemexp) = do
  PyExp
dst' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
dst
  PyExp
idx' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp) -> Exp -> CompilerM op s PyExp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
Imp.untyped TExp Int64
idx
  PyExp
elemexp' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp Exp
elemexp
  PyExp -> PyExp -> PrimType -> Space -> PyExp -> CompilerM op s ()
forall op s.
PyExp -> PyExp -> PrimType -> Space -> PyExp -> CompilerM op s ()
generateWrite PyExp
dst' PyExp
idx' PrimType
pt Space
space PyExp
elemexp'
compileCode (Imp.Read VName
x VName
src (Imp.Count TExp Int64
iexp) PrimType
pt Space
space Volatility
_) = do
  PyExp
x' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
x
  PyExp
iexp' <- Exp -> CompilerM op s PyExp
forall op s. Exp -> CompilerM op s PyExp
compileExp (Exp -> CompilerM op s PyExp) -> Exp -> CompilerM op s PyExp
forall a b. (a -> b) -> a -> b
$ TExp Int64 -> Exp
forall {k} (t :: k) v. TPrimExp t v -> PrimExp v
untyped TExp Int64
iexp
  PyExp
src' <- VName -> CompilerM op s PyExp
forall op s. VName -> CompilerM op s PyExp
compileVar VName
src
  PyStmt -> CompilerM op s ()
forall op s. PyStmt -> CompilerM op s ()
stm (PyStmt -> CompilerM op s ())
-> (PyExp -> PyStmt) -> PyExp -> CompilerM op s ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. PyExp -> PyExp -> PyStmt
Assign PyExp
x' (PyExp -> CompilerM op s ())
-> CompilerM op s PyExp -> CompilerM op s ()
forall (m :: * -> *) a b. Monad m => (a -> m b) -> m a -> m b
=<< PyExp -> PyExp -> PrimType -> Space -> CompilerM op s PyExp
forall op s.
PyExp -> PyExp -> PrimType -> Space -> CompilerM op s PyExp
generateRead PyExp
src' PyExp
iexp' PrimType
pt Space
space
compileCode Code op
Imp.Skip = () -> CompilerM op s ()
forall a. a -> CompilerM op s a
forall (f :: * -> *) a. Applicative f => a -> f a
pure ()

-- | A 'DoCopy' implementation that performs the copy by emitting a
-- single Python expression statement calling the @lmad_copy@ runtime
-- function.  Judging by the name this is intended for copies between
-- host (CPU) arrays; NOTE(review): confirm which space pairs it is
-- registered for in 'opsCopies'.
--
-- The generated Python call has the argument order
--
-- > lmad_copy(type, dst, dstoffset, [dststrides], src, srcoffset, [srcstrides], [shape])
--
-- Offsets and strides are measured in elements (they arrive wrapped in
-- @Count Elements@).  The shape is passed /last/ on the Python side,
-- even though it is the second Haskell parameter.  @lmad_copy@ is
-- presumably provided by the Python runtime system; verify there if
-- the argument order ever changes.
lmadcopyCPU :: DoCopy op s
lmadcopyCPU t shape dst (dstoffset, dststride) src (srcoffset, srcstride) =
  stm . Exp . simpleCall "lmad_copy" $
    [ Var (compilePrimType t),
      dst,
      unCount dstoffset,
      elemList dststride,
      src,
      unCount srcoffset,
      elemList srcstride,
      elemList shape
    ]
  where
    -- Strip the element-count wrappers and build a Python list literal.
    -- Written as a function binding so the monomorphism restriction
    -- does not apply.
    elemList xs = List (map unCount xs)