{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.Construct
( letSubExp,
letSubExps,
letExp,
letTupExp,
letTupExp',
letInPlace,
eSubExp,
eParam,
eIf,
eIf',
eBinOp,
eCmpOp,
eConvOp,
eSignum,
eCopy,
eBody,
eLambda,
eRoundToMultipleOf,
eSliceArray,
eBlank,
eAll,
eOutOfBounds,
asIntZ,
asIntS,
resultBody,
resultBodyM,
insertStmsM,
buildBody,
buildBody_,
mapResult,
foldBinOp,
binOpLambda,
cmpOpLambda,
mkLambda,
sliceDim,
fullSlice,
fullSliceNum,
isFullSlice,
sliceAt,
ifCommon,
module Futhark.Builder,
instantiateShapes,
instantiateShapes',
removeExistentials,
simpleMkLetNames,
ToExp (..),
toSubExp,
)
where
import Control.Monad.Identity
import Control.Monad.State
import Data.List (sortOn)
import qualified Data.Map.Strict as M
import qualified Data.Set as S
import Futhark.Builder
import Futhark.IR
import Futhark.Util (maybeNth)
-- | @letSubExp desc e@ binds the single-valued expression @e@ and returns
-- a 'SubExp' referring to its result.
--
-- If @e@ is already a bare 'SubExp' (@BasicOp (SubExp se)@), @se@ is
-- returned directly and no binding is created.  Otherwise @e@ is bound
-- with 'letExp' (which fails unless @e@ has exactly one result) and the
-- bound variable is returned.
letSubExp ::
  MonadBuilder m =>
  String ->
  Exp (Rep m) ->
  m SubExp
letSubExp _ (BasicOp (SubExp se)) = pure se
letSubExp desc e = Var <$> letExp desc e
-- | @letExp desc e@ binds the single-valued expression @e@ to a fresh
-- variable whose name is derived from @desc@, and returns that variable.
--
-- If @e@ is just a variable (@BasicOp (SubExp (Var v))@), @v@ is
-- returned and no binding is created.
--
-- One fresh name is created per element of 'expExtType' of @e@, and all
-- of them are bound.  Calls 'error' unless there is exactly one; use
-- 'letTupExp' for expressions with several results.
letExp ::
  MonadBuilder m =>
  String ->
  Exp (Rep m) ->
  m VName
letExp _ (BasicOp (SubExp (Var v))) = pure v
letExp desc e = do
  num_results <- length <$> expExtType e
  names <- replicateM num_results $ newVName desc
  letBindNames names e
  case names of
    [v] -> pure v
    _ -> error $ "letExp: tuple-typed expression given:\n" ++ pretty e
-- | @letInPlace desc src slice e@ performs an in-place update of @src@.
--
-- First @e@ is bound with 'letSubExp' under the name hint
-- @desc ++ "_tmp"@.  Then an 'Update' with 'Unsafe' safety, writing that
-- value into @src@ at @slice@, is bound with 'letExp' under @desc@.
-- The name bound to the update's result is returned.
letInPlace ::
  MonadBuilder m =>
  String ->
  VName ->
  Slice SubExp ->
  Exp (Rep m) ->
  m VName
letInPlace desc src slice e = do
  new_value <- letSubExp (desc ++ "_tmp") e
  letExp desc $ BasicOp $ Update Unsafe src slice new_value
-- | Bind each expression in turn with 'letSubExp', using the same name
-- hint @desc@ for all of them.  The resulting 'SubExp's are returned in
-- the order of the input list.
letSubExps ::
  MonadBuilder m =>
  String ->
  [Exp (Rep m)] ->
  m [SubExp]
letSubExps desc = mapM (letSubExp desc)
-- | @letTupExp name e@ binds all results of @e@ to fresh variables
-- derived from @name@.
--
-- One name is created and bound per element of 'expExtType' of @e@.  The
-- returned list omits names whose position occurs in the 'shapeContext'
-- of that type, presumably the positions that carry existential sizes
-- rather than values.
--
-- If @e@ is just a variable, it is returned as a singleton list and no
-- binding is created.
letTupExp ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [VName]
letTupExp _ (BasicOp (SubExp (Var v))) = pure [v]
letTupExp name e = do
  e_t <- expExtType e
  names <- replicateM (length e_t) $ newVName name
  letBindNames names e
  let ctx = shapeContext e_t
  pure [v | (v, i) <- zip names [0 ..], i `S.notMember` ctx]
-- | Like 'letTupExp', but returns 'SubExp's.
--
-- A bare 'SubExp' expression (constant or variable) is returned as a
-- singleton list without creating a binding.  Any other expression is
-- bound with 'letTupExp', and the returned names are wrapped in 'Var'.
letTupExp' ::
  (MonadBuilder m) =>
  String ->
  Exp (Rep m) ->
  m [SubExp]
letTupExp' _ (BasicOp (SubExp se)) = pure [se]
letTupExp' name ses = map Var <$> letTupExp name ses
-- | Lift a 'SubExp' to an expression (@BasicOp (SubExp se)@) in the
-- builder monad.  No bindings are created.
eSubExp ::
  MonadBuilder m =>
  SubExp ->
  m (Exp (Rep m))
eSubExp = pure . BasicOp . SubExp
-- | An expression referring to the variable named by the given parameter.
-- No bindings are created.
eParam ::
  MonadBuilder m =>
  Param t ->
  m (Exp (Rep m))
eParam = eSubExp . Var . paramName
-- | Construct an 'If' expression with sort 'IfNormal'.  This is 'eIf''
-- with the sort fixed.
eIf ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  m (Exp (Rep m))
eIf ce te fe = eIf' ce te fe IfNormal
-- | @eIf' ce te fe if_sort@ constructs an 'If' expression of the given
-- 'IfSort'.
--
-- * The condition produced by @ce@ is bound as @"cond"@ via 'letSubExp'.
--
-- * The branch actions @te@ and @fe@ are run through 'insertStmsM'.
--
-- * The branch return types are computed by applying 'generaliseExtTypes'
--   to the 'bodyExtType' of both branches.
--
-- * Each branch body is then rebuilt with context results prepended to
--   its value results.  The context results are the sizes that
--   'shapeExtMapping' associates with the return types, in ascending
--   index order.
--
-- * The 'IfDec' lists one @i64@ type per element of the 'shapeContext'
--   of the return types, followed by the return types themselves.
eIf' ::
  (MonadBuilder m, BranchType (Rep m) ~ ExtType) =>
  m (Exp (Rep m)) ->
  m (Body (Rep m)) ->
  m (Body (Rep m)) ->
  IfSort ->
  m (Exp (Rep m))
eIf' ce te fe if_sort = do
  ce' <- letSubExp "cond" =<< ce
  te' <- insertStmsM te
  fe' <- insertStmsM fe
  -- A single set of return types must describe both branches.
  ts <- generaliseExtTypes <$> bodyExtType te' <*> bodyExtType fe'
  te'' <- addContextForBranch ts te'
  fe'' <- addContextForBranch ts fe'
  let ctx_ts = replicate (length (shapeContext ts)) (Prim int64)
  pure $ If ce' te'' fe'' $ IfDec (ctx_ts ++ ts) if_sort
  where
    -- Prepend to the body's results the concrete sizes that
    -- 'shapeExtMapping' pairs with the branch return types.
    addContextForBranch branch_ts (Body _ stms val_res) = do
      body_ts <- extendedScope (traverse subExpResType val_res) stmsscope
      let ctx_res =
            map snd . sortOn fst . M.toList $
              shapeExtMapping branch_ts body_ts
      mkBodyM stms $ subExpsRes ctx_res ++ val_res
      where
        stmsscope = scopeOf stms
-- | The extended types of a body's results.
--
-- The result types are computed with 'subExpResType' in the current
-- scope extended with the bindings of the body's statements.  They are
-- converted with 'staticShapes', and then every name bound by those
-- statements is passed to 'existentialiseExtTypes'.  Presumably this
-- turns dimensions referring to body-local names into existentials,
-- since those names are not visible outside the body.
bodyExtType :: (HasScope rep m, Monad m) => Body rep -> m [ExtType]
bodyExtType (Body _ stms res) =
  existentialiseExtTypes (M.keys stmsscope) . staticShapes
    <$> extendedScope (traverse subExpResType res) stmsscope
  where
    stmsscope = scopeOf stms
-- | Construct a 'BinOp' expression from two operand actions.
--
-- The left operand is run and bound first (name hint @"x"@), then the
-- right operand (name hint @"y"@), both via 'letSubExp'.
eBinOp ::
  MonadBuilder m =>
  BinOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eBinOp op x y = do
  x' <- letSubExp "x" =<< x
  y' <- letSubExp "y" =<< y
  pure $ BasicOp $ BinOp op x' y'
-- | Construct a 'CmpOp' expression from two operand actions.
--
-- The left operand is run and bound first (name hint @"x"@), then the
-- right operand (name hint @"y"@), both via 'letSubExp'.
eCmpOp ::
  MonadBuilder m =>
  CmpOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCmpOp op x y = do
  x' <- letSubExp "x" =<< x
  y' <- letSubExp "y" =<< y
  pure $ BasicOp $ CmpOp op x' y'
-- | Construct a 'ConvOp' expression.  The operand action is run and its
-- result bound with name hint @"x"@ via 'letSubExp'.
eConvOp ::
  MonadBuilder m =>
  ConvOp ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eConvOp op x = do
  x' <- letSubExp "x" =<< x
  pure $ BasicOp $ ConvOp op x'
-- | Construct a signed signum ('SSignum') of the operand.
--
-- The operand is bound with name hint @"signum_arg"@ and its type is then
-- inspected.  Calls 'error' if that type is not a primitive integer type.
eSignum ::
  MonadBuilder m =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eSignum em = do
  e <- em
  e' <- letSubExp "signum_arg" e
  t <- subExpType e'
  case t of
    Prim (IntType int_t) ->
      pure $ BasicOp $ UnOp (SSignum int_t) e'
    _ ->
      error $ "eSignum: operand " ++ pretty e ++ " has invalid type."
-- | Construct a 'Copy' of the array produced by the given action.  The
-- action's result is bound to a variable (name hint @"copy_arg"@) with
-- 'letExp', so it must have exactly one result.
eCopy ::
  MonadBuilder m =>
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eCopy e = BasicOp . Copy <$> (letExp "copy_arg" =<< e)
-- | Build a body from a list of expression actions, using 'buildBody_'.
--
-- All actions are run first, in order.  Each resulting expression is then
-- bound with 'letTupExp' (name hint @"x"@).  The body returns all bound
-- names, concatenated in order.
eBody ::
  (MonadBuilder m) =>
  [m (Exp (Rep m))] ->
  m (Body (Rep m))
eBody es = buildBody_ $ do
  es' <- sequence es
  varsRes . concat <$> mapM (letTupExp "x") es'
-- | Inline a lambda.
--
-- Each lambda parameter is bound (under its own name) to the result of
-- the corresponding argument action.  The lambda body is then bound with
-- 'bodyBind', and its results are returned.
--
-- Parameters and arguments are paired with 'zipWithM_', so surplus
-- parameters or arguments are silently ignored.
eLambda ::
  MonadBuilder m =>
  Lambda (Rep m) ->
  [m (Exp (Rep m))] ->
  m [SubExpRes]
eLambda lam args = do
  zipWithM_ bindParam (lambdaParams lam) args
  bodyBind $ lambdaBody lam
  where
    bindParam param arg = letBindNames [paramName param] =<< arg
-- | @eRoundToMultipleOf t x d@ computes @x + ((d - x mod d) mod d)@ in
-- integer type @t@.
--
-- The computation uses wrapping 'Add'/'Sub' and unchecked 'SMod'.  For
-- non-negative @x@ and positive @d@ this rounds @x@ up to the nearest
-- multiple of @d@; other inputs are presumably not intended, so confirm
-- against callers.
--
-- Each operand action is run and bound exactly once (name hints @"x"@
-- and @"d"@), then reused.  Previously @x@ was run twice and @d@ three
-- times, duplicating any statements those actions emitted.
eRoundToMultipleOf ::
  MonadBuilder m =>
  IntType ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eRoundToMultipleOf t x d = do
  x_se <- letSubExp "x" =<< x
  d_se <- letSubExp "d" =<< d
  -- Bare SubExp expressions are not rebound by 'letSubExp' inside
  -- 'eBinOp', so reusing these creates no further bindings.
  let x' = eSubExp x_se
      d' = eSubExp d_se
  ePlus x' (eMod (eMinus d' (eMod x' d')) d')
  where
    eMod = eBinOp (SMod t Unsafe)
    eMinus = eBinOp (Sub t OverflowWrap)
    ePlus = eBinOp (Add t OverflowWrap)
-- | @eSliceArray d arr i n@ slices dimension @d@ (counting from 0) of
-- @arr@, taking @n@ elements starting at index @i@ with stride 1.
--
-- Each of the first @d@ dimensions is covered by a stride-1 slice from 0
-- spanning its whole extent.  The slice list is passed to 'fullSlice'
-- with the array's type; presumably this takes any remaining dimensions
-- in full.  The start and length are bound as @"slice_i"@ and
-- @"slice_n"@.
eSliceArray ::
  MonadBuilder m =>
  Int ->
  VName ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m)) ->
  m (Exp (Rep m))
eSliceArray d arr i n = do
  arr_t <- lookupType arr
  let leading = [strideOne (constant (0 :: Int64)) w | w <- take d $ arrayDims arr_t]
  i' <- letSubExp "slice_i" =<< i
  n' <- letSubExp "slice_n" =<< n
  pure $ BasicOp $ Index arr $ fullSlice arr_t $ leading ++ [strideOne i' n']
  where
    strideOne start len = DimSlice start len (constant (1 :: Int64))
-- | Build a boolean expression that is true when any of the given
-- indices lies outside the corresponding dimension of @arr@.  An index
-- @i@ is outside a dimension of size @w@ when @i < 0@ or @w <= i@
-- (signed 64-bit comparisons).  The index expressions are first bound
-- via 'letSubExp', then the per-dimension checks are combined with
-- 'LogOr' through 'foldBinOp' (seeded with 'False').  Indices and
-- dimensions are paired positionally, so anything beyond the shorter
-- of the two lists is not checked.
eOutOfBounds ::
  MonadBuilder m =>
  VName ->
  [m (Exp (Rep m))] ->
  m (Exp (Rep m))
eOutOfBounds arr is = do
  dims <- arrayDims <$> lookupType arr
  idxs <- traverse (letSubExp "write_i") =<< sequence is
  outside <- zipWithM dimOutside dims idxs
  foldBinOp LogOr (constant False) outside
  where
    -- True when the index is negative or not below the dimension size.
    dimOutside w i = do
      below <-
        letSubExp "less_than_zero" . BasicOp $
          CmpOp (CmpSlt Int64) i (constant (0 :: Int64))
      above <-
        letSubExp "greater_than_size" . BasicOp $
          CmpOp (CmpSle Int64) w i
      letSubExp "outside_bounds_dim" . BasicOp $
        BinOp LogOr below above
-- | An expression producing a "blank" value of the given type.  For
-- primitive types this is the constant 'blankPrimValue'.  For arrays it
-- is a 'Scratch' of the element type and shape.  Accumulator and memory
-- types have no blank value, so they call 'error'.
eBlank :: MonadBuilder m => Type -> m (Exp (Rep m))
eBlank t =
  case t of
    Prim pt -> pure . BasicOp . SubExp . Constant $ blankPrimValue pt
    Array pt shape _ -> pure . BasicOp . Scratch pt $ shapeDims shape
    Acc {} -> error "eBlank: cannot create blank accumulator"
    Mem {} -> error "eBlank: cannot create blank memory"
-- | Convert an integer 'SubExp' to the given integer type using 'SExt'.
-- See 'asInt'.
asIntS :: MonadBuilder m => IntType -> SubExp -> m SubExp
asIntS to_it se = asInt SExt to_it se
-- | Convert an integer 'SubExp' to the given integer type using 'ZExt'.
-- See 'asInt'.
asIntZ :: MonadBuilder m => IntType -> SubExp -> m SubExp
asIntZ to_it se = asInt ZExt to_it se
-- | Convert an integer-typed 'SubExp' to @to_it@ using the conversion
-- built by @ext@ (e.g. 'SExt' or 'ZExt').
--
-- If the operand already has type @to_it@, it is returned unchanged and
-- no statement is added.  Otherwise a 'ConvOp' is bound with
-- 'letSubExp'.  The name is the operand's base name for variables, and
-- @to_<type>@ for constants.
--
-- Calls 'error' if the operand is not of integer type.  The message
-- includes the offending type.
asInt ::
  MonadBuilder m =>
  (IntType -> IntType -> ConvOp) ->
  IntType ->
  SubExp ->
  m SubExp
asInt ext to_it e = do
  e_t <- subExpType e
  case e_t of
    Prim (IntType from_it)
      | to_it == from_it -> return e
      | otherwise -> letSubExp s $ BasicOp $ ConvOp (ext from_it to_it) e
    -- Report the actual type; a bare "wrong type" gives no hint what was passed.
    _ -> error $ "asInt: wrong type: " ++ pretty e_t
  where
    s :: String
    s = case e of
      Var v -> baseString v
      _ -> "to_" ++ pretty to_it
-- | Right-fold a list of 'SubExp's with a binary operator, using @ne@ as
-- the innermost operand.  For example, @foldBinOp op ne [a, b]@ builds
-- @a `op` (b `op` ne)@ through nested 'eBinOp' calls.  An empty list
-- yields just @ne@.
foldBinOp ::
  MonadBuilder m =>
  BinOp ->
  SubExp ->
  [SubExp] ->
  m (Exp (Rep m))
foldBinOp bop ne = foldr step (asExp ne)
  where
    asExp se = pure $ BasicOp $ SubExp se
    step x acc = eBinOp bop (asExp x) acc
-- | Logical conjunction of a list of boolean 'SubExp's.  The empty list
-- yields the constant @True@.  Otherwise the first element seeds a
-- 'LogAnd' fold over the rest (see 'foldBinOp').
eAll :: MonadBuilder m => [SubExp] -> m (Exp (Rep m))
eAll xs = case xs of
  [] -> pure $ BasicOp $ SubExp $ constant True
  x : rest -> foldBinOp LogAnd x rest
-- | A two-parameter lambda applying the given 'BinOp'.  Both parameters
-- and the single result have primitive type @t@.
binOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  BinOp ->
  PrimType ->
  m (Lambda (Rep m))
binOpLambda bop t =
  let op = BinOp bop
   in binLambda op t t
-- | A two-parameter lambda applying the given 'CmpOp'.  The parameters
-- have the operator's operand type ('cmpOpType') and the result is
-- 'Bool'.
cmpOpLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  CmpOp ->
  m (Lambda (Rep m))
cmpOpLambda cop = binLambda (CmpOp cop) operandType Bool
  where
    operandType = cmpOpType cop
-- | A lambda with two fresh parameters @x@ and @y@, both of type
-- @arg_t@.  Its body binds @bop x y@ under the name @binlam_res@ and
-- returns it as the single result of type @ret_t@.
binLambda ::
  (MonadBuilder m, Buildable (Rep m)) =>
  (SubExp -> SubExp -> BasicOp) ->
  PrimType ->
  PrimType ->
  m (Lambda (Rep m))
binLambda bop arg_t ret_t = do
  x <- newVName "x"
  y <- newVName "y"
  body <- buildBody_ $ do
    res <- letSubExp "binlam_res" $ BasicOp $ bop (Var x) (Var y)
    pure [subExpRes res]
  let params = [Param x (Prim arg_t), Param y (Prim arg_t)]
  pure $ Lambda params body [Prim ret_t]
-- | Build a lambda with the given parameters.  The body is produced by
-- running @m@ with the parameters in scope.  The return types are the
-- types of the results, computed in that same scope.
mkLambda ::
  MonadBuilder m =>
  [LParam (Rep m)] ->
  m Result ->
  m (Lambda (Rep m))
mkLambda params m = do
  (body, ret) <- buildBody $ localScope (scopeOfLParams params) resultAndTypes
  pure $ Lambda params body ret
  where
    -- Run the body action and pair its result with the result types.
    resultAndTypes = do
      res <- m
      ts <- traverse subExpResType res
      pure (res, ts)
-- | A slice spanning a whole dimension of size @d@: start 0, length @d@,
-- stride 1.
sliceDim :: SubExp -> DimIndex SubExp
sliceDim d = DimSlice zero d one
  where
    zero = constant (0 :: Int64)
    one = constant (1 :: Int64)
-- | Complete a possibly partial list of dimension indices for an array
-- of type @t@.  The given indices cover the leading dimensions, and
-- every remaining dimension is sliced in full with 'sliceDim'.
fullSlice :: Type -> [DimIndex SubExp] -> Slice SubExp
fullSlice t slice = Slice (slice ++ rest)
  where
    rest = map sliceDim $ drop (length slice) (arrayDims t)
-- | A slice that takes the first @n@ dimensions of @t@ in full and
-- applies the given indices to the dimensions after them.  Any further
-- dimensions are also taken in full, via 'fullSlice'.
sliceAt :: Type -> Int -> [DimIndex SubExp] -> Slice SubExp
sliceAt t n slice =
  let leading = [sliceDim d | d <- take n (arrayDims t)]
   in fullSlice t (leading ++ slice)
-- | Like 'fullSlice', but for any numeric dimension type, taking the
-- dimension sizes directly.  Each trailing dimension @d@ not covered by
-- the given indices becomes @DimSlice 0 d 1@.
fullSliceNum :: Num d => [d] -> [DimIndex d] -> Slice d
fullSliceNum dims slice =
  Slice $ slice ++ [DimSlice 0 d 1 | d <- drop (length slice) dims]
-- | Whether the slice covers the full size of an array of the given
-- shape.  Dimensions and indices are paired positionally; any beyond
-- the shorter list are not examined.
--
-- * A 'DimFix' counts only on a dimension whose size is a constant
--   one, as judged by 'oneIsh'.
-- * A 'DimSlice' counts when its length equals the dimension size.
--   Its start and stride are not inspected.
--
-- Anything else makes the result 'False'.
isFullSlice :: Shape -> Slice SubExp -> Bool
isFullSlice shape slice =
  all (uncurry allOfIt) $ zip (shapeDims shape) (unSlice slice)
  where
    allOfIt :: SubExp -> DimIndex SubExp -> Bool
    allOfIt d idx = case (d, idx) of
      (Constant v, DimFix {}) -> oneIsh v
      (_, DimSlice _ n _) -> d == n
      _ -> False
-- | An 'IfDec' with return types given by lifting the given types to
-- static 'ExtType's, marked 'IfNormal'.
ifCommon :: [Type] -> IfDec ExtType
ifCommon ts =
  let rets = staticShapes ts
   in IfDec rets IfNormal
-- | A body with no statements whose result is the given 'SubExp's.
resultBody :: Buildable rep => [SubExp] -> Body rep
resultBody ses = mkBody mempty (subExpsRes ses)
-- | Monadic 'resultBody': a body with no statements, built with
-- 'mkBodyM', whose result is the given 'SubExp's.
resultBodyM :: MonadBuilder m => [SubExp] -> m (Body (Rep m))
resultBodyM ses = mkBodyM mempty (subExpsRes ses)
-- | Run an action that produces a body and rebuild that body with the
-- statements collected (via 'collectStms') while the action ran placed
-- before the body's own statements.  The body's result is kept; its
-- decoration is discarded and recomputed by 'mkBodyM'.
insertStmsM ::
  MonadBuilder m =>
  m (Body (Rep m)) ->
  m (Body (Rep m))
insertStmsM m = do
  (Body _ stms res, emitted) <- collectStms m
  mkBodyM (emitted <> stms) res
-- | Run an action that yields a result and an extra value.  The
-- statements it produces (gathered with 'collectStms') and the result
-- become a body, which is returned together with the extra value.
buildBody ::
  MonadBuilder m =>
  m (Result, a) ->
  m (Body (Rep m), a)
buildBody m = do
  ((res, extra), stms) <- collectStms m
  (\body -> (body, extra)) <$> mkBodyM stms res
-- | Like 'buildBody', for an action that yields only a result.
buildBody_ ::
  MonadBuilder m =>
  m Result ->
  m (Body (Rep m))
buildBody_ m = fst <$> buildBody (fmap (\res -> (res, ())) m)
-- | Replace the result of a body using a function that builds a new body
-- from the old result.  The new body's statements are appended after
-- the original ones, and its result becomes the overall result.  Both
-- decorations are discarded and recomputed by 'mkBody'.
mapResult ::
  Buildable rep =>
  (Result -> Body rep) ->
  Body rep ->
  Body rep
mapResult f (Body _ stms res) = mkBody (stms <> extra) newres
  where
    -- Lazy pattern binding, matching the original @let@.
    Body _ extra newres = f res
-- | Turn existential shapes into concrete ones.
--
-- Each existential dimension @Ext i@ is replaced by the 'SubExp'
-- returned by @f i@.  Results are memoised in a map shared across all
-- the types, so @f@ is invoked at most once per distinct @i@.  Free
-- dimensions are kept as they are.
instantiateShapes ::
  Monad m =>
  (Int -> m SubExp) ->
  [TypeBase ExtShape u] ->
  m [TypeBase Shape u]
instantiateShapes f ts = evalStateT (mapM instantiate ts) M.empty
  where
    instantiate t = do
      dims <- mapM instantiateDim $ shapeDims $ arrayShape t
      pure $ t `setArrayShape` Shape dims

    instantiateDim (Free se) = pure se
    instantiateDim (Ext x) = do
      known <- get
      case M.lookup x known of
        Just se -> pure se
        Nothing -> do
          se <- lift $ f x
          put $ M.insert x se known
          pure se
-- | Instantiate each existential dimension @Ext i@ with @Var@ of the
-- @i@th name in the list (looked up with 'maybeNth').  Calls 'error'
-- when an index has no corresponding name.
instantiateShapes' :: [VName] -> [TypeBase ExtShape u] -> [TypeBase Shape u]
instantiateShapes' names ts = runIdentity $ instantiateShapes lookupName ts
  where
    lookupName :: Int -> Identity SubExp
    lookupName x =
      maybe
        (error $ "instantiateShapes': " ++ pretty names ++ ", " ++ show x)
        (pure . Var)
        (maybeNth x names)
-- | Give @t1@ concrete dimensions.  Each existential dimension of @t1@
-- is taken from the corresponding dimension of @t2@, while free
-- dimensions of @t1@ are kept.  Dimensions are paired positionally.
removeExistentials :: ExtType -> Type -> Type
removeExistentials t1 t2 = t1 `setArrayDims` dims
  where
    dims = zipWith pick (shapeDims $ arrayShape t1) (arrayDims t2)
    pick (Free d) _ = d
    pick (Ext _) d = d
-- | Build a 'Let' statement binding @names@ to @e@ with empty auxiliary
-- information ('defAux').  Pattern element types come from the
-- expression's extended type.  Existential dimension @Ext i@ is
-- instantiated with the @i@th name (see 'instantiateShapes'').
simpleMkLetNames ::
  ( ExpDec rep ~ (),
    LetDec rep ~ Type,
    MonadFreshNames m,
    TypedOp (Op rep),
    HasScope rep m
  ) =>
  [VName] ->
  Exp rep ->
  m (Stm rep)
simpleMkLetNames names e = do
  ts <- instantiateShapes' names <$> expExtType e
  pure $ Let (Pat $ zipWith PatElem names ts) (defAux ()) e
-- | Values that can be turned into an expression inside a builder monad.
class ToExp a where
  -- | Produce an expression representing the value.
  toExp ::
    MonadBuilder m =>
    a ->
    m (Exp (Rep m))
-- | A 'SubExp' becomes the trivial expression @BasicOp (SubExp se)@.
instance ToExp SubExp where
  toExp se = pure $ BasicOp $ SubExp se
-- | A variable name becomes @BasicOp (SubExp (Var v))@, exactly as the
-- 'SubExp' instance would produce for @Var v@.
instance ToExp VName where
  toExp v = toExp (Var v)
-- | Bind the expression produced by 'toExp' under the name hint @s@,
-- via 'letSubExp'.
toSubExp :: (MonadBuilder m, ToExp a) => String -> a -> m SubExp
toSubExp s e = do
  e' <- toExp e
  letSubExp s e'