module Futhark.Internalise.AccurateSizes
( argShapes,
ensureResultShape,
ensureResultExtShape,
ensureExtShape,
ensureShape,
ensureArgShapes,
)
where
import Control.Monad
import Data.Map.Strict qualified as M
import Data.Maybe
import Futhark.Construct
import Futhark.IR.SOACS
import Futhark.Internalise.Monad
import Futhark.Util (takeLast)
-- | Relate the names occurring in the types of the trailing value
-- parameters to the corresponding parts of the actual argument types.
--
-- Only the last @length value_arg_types@ parameters of @all_params@ are
-- considered, paired positionally with @value_arg_types@.  For each
-- parameter/argument pair:
--
-- * two arrays contribute every parameter dimension that is a variable,
--   mapped to the argument's dimension at the same position;
--
-- * two accumulators contribute the parameter accumulator name (mapped to
--   the argument accumulator), the variable dimensions of the index
--   space, and, recursively, the mapping for the accumulated types;
--
-- * any other combination contributes nothing.
--
-- The per-parameter mappings are combined with 'mconcat'.  No monadic
-- effects are performed; the monadic result type is kept for interface
-- compatibility.
shapeMapping ::
  (HasScope SOACS m, Monad m) =>
  [FParam SOACS] ->
  [Type] ->
  m (M.Map VName SubExp)
shapeMapping all_params value_arg_types =
  pure . mconcat $
    zipWith typeMapping (map paramType value_params) value_arg_types
  where
    value_params = takeLast (length value_arg_types) all_params

    -- Pure on purpose: nothing here needs the ambient monad, so there is
    -- no reason to thread it through 'zipWithM'.
    typeMapping :: Type -> Type -> M.Map VName SubExp
    typeMapping t1@Array {} t2@Array {} =
      M.fromList . mapMaybe match $ zip (arrayDims t1) (arrayDims t2)
    typeMapping (Acc acc1 ispace1 ts1 _) (Acc acc2 ispace2 ts2 _) =
      M.singleton acc1 (Var acc2)
        <> M.fromList
          (mapMaybe match $ zip (shapeDims ispace1) (shapeDims ispace2))
        <> mconcat (zipWith typeMapping ts1 ts2)
    typeMapping _ _ = mempty

    -- Only variable dimensions on the parameter side can be bound.
    match :: (SubExp, b) -> Maybe (VName, b)
    match (Var v, se) = Just (v, se)
    match _ = Nothing
-- | Determine a value for each shape parameter named in @shapes@ from the
-- types of the actual value arguments @valargts@, using the mapping that
-- 'shapeMapping' computes against @all_params@.
--
-- A name that the argument types do not determine yields an 'error'
-- call; because the result list is built lazily, the error is raised only
-- when that element is demanded.
argShapes :: [VName] -> [FParam SOACS] -> [Type] -> InternaliseM [SubExp]
argShapes shapes all_params valargts = do
  mapping <- shapeMapping all_params valargts
  let resolve name =
        fromMaybe (error $ "argShapes: " ++ prettyString name) $
          M.lookup name mapping
  pure [resolve name | name <- shapes]
-- | Variant of 'ensureResultExtShape' for a return type given as plain
-- 'Type's.  The types are converted with 'staticShapes' and passed on.
ensureResultShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  [Type] ->
  Result ->
  InternaliseM Result
ensureResultShape msg loc rettype =
  ensureResultExtShape msg loc (staticShapes rettype)
-- | Make the results @res@ conform to @rettype@ using
-- 'ensureResultExtShapeNoCtx'.  Then prepend the shape context that
-- 'extractShapeContext' derives from @rettype@ and the dimensions of the
-- conformed results.
ensureResultExtShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  [ExtType] ->
  Result ->
  InternaliseM Result
ensureResultExtShape msg loc rettype res = do
  conformed <- ensureResultExtShapeNoCtx msg loc rettype res
  conformed_dims <- map arrayDims <$> mapM subExpResType conformed
  pure $ subExpsRes (extractShapeContext rettype conformed_dims) <> conformed
-- | Make each result in @es@ conform to the corresponding type in
-- @rettype@, without adding any shape context.
--
-- The positions reported by 'shapeExtMapping' for @rettype@ against the
-- actual result types are first substituted into @rettype@ with
-- 'fixExt'.  Each result is then passed through 'ensureExtShape' against
-- its substituted type, keeping its certificates.
ensureResultExtShapeNoCtx ::
  ErrorMsg SubExp ->
  SrcLoc ->
  [ExtType] ->
  Result ->
  InternaliseM Result
ensureResultExtShapeNoCtx msg loc rettype es = do
  es_ts <- mapM subExpResType es
  let fixed_rettype =
        foldr (uncurry fixExt) rettype . M.toList $
          shapeExtMapping rettype es_ts
  zipWithM conform fixed_rettype es
  where
    conform :: ExtType -> SubExpRes -> InternaliseM SubExpRes
    conform t (SubExpRes cs se) =
      SubExpRes cs <$> ensureExtShape msg loc t "result_proper_shape" se
-- | Make @orig@ conform to @t@.
--
-- Only an array type paired with a variable is handled; that case is
-- delegated to 'ensureShapeVar', with @name@ passed along.  Every other
-- combination (non-array type, or a constant) returns @orig@ unchanged.
ensureExtShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  ExtType ->
  String ->
  SubExp ->
  InternaliseM SubExp
ensureExtShape msg loc t name orig =
  case (t, orig) of
    (Array {}, Var v) -> Var <$> ensureShapeVar msg loc t name v
    _ -> pure orig
-- | Variant of 'ensureExtShape' for a plain 'Type'.  The type is
-- converted with 'staticShapes1' first.
ensureShape ::
  ErrorMsg SubExp ->
  SrcLoc ->
  Type ->
  String ->
  SubExp ->
  InternaliseM SubExp
ensureShape msg loc t =
  ensureExtShape msg loc (staticShapes1 t)
-- | Make the function arguments @args@ conform to the types that
-- 'expectedTypes' computes from @shapes@, @paramts@ and @args@.
--
-- Only variables whose expected type has rank at least one are checked,
-- via 'ensureShape' named after the variable's base string.  Constants and
-- variables of rank-0 expected type are returned unchanged.
ensureArgShapes ::
  (Typed (TypeBase Shape u)) =>
  ErrorMsg SubExp ->
  SrcLoc ->
  [VName] ->
  [TypeBase Shape u] ->
  [SubExp] ->
  InternaliseM [SubExp]
ensureArgShapes msg loc shapes paramts args =
  zipWithM conform (expectedTypes shapes paramts args) args
  where
    conform :: Type -> SubExp -> InternaliseM SubExp
    conform t (Var v)
      | arrayRank t >= 1 = ensureShape msg loc t (baseString v) (Var v)
    conform _ se = pure se
-- | Ensure the variable @v@ has the dimensions required by the array type
-- @t@.
--
-- The required dimensions are those of @removeExistentials t@ applied to
-- @v@'s type.  If they already equal @v@'s dimensions, @v@ is returned
-- and no code is emitted.  Otherwise the code below is emitted:
--
-- * each pair of dimensions is compared with an @int64@ equality;
-- * the conjunction is asserted ('assert' with @"empty_or_match_cert"@,
--   @msg@ and @loc@);
-- * under the resulting certificates, a 'shapeCoerce' of @v@ to the
--   required dimensions is bound with 'letExp' using @name@.
--
-- If @t@ is not an array type, @v@ is returned unchanged.
ensureShapeVar ::
  ErrorMsg SubExp ->
  SrcLoc ->
  ExtType ->
  String ->
  VName ->
  InternaliseM VName
ensureShapeVar msg loc t name v
  | Array {} <- t = do
      -- Look the type up once; both dimension lists derive from it.
      v_t <- lookupType v
      let newdims = arrayDims $ removeExistentials t v_t
          olddims = arrayDims v_t
      if newdims == olddims
        then pure v
        else do
          matches <- zipWithM checkDim newdims olddims
          all_match <- letSubExp "match" =<< eAll matches
          cs <- assert "empty_or_match_cert" all_match msg loc
          certifying cs $ letExp name $ shapeCoerce newdims v
  | otherwise = pure v
  where
    checkDim :: SubExp -> SubExp -> InternaliseM SubExp
    checkDim desired has =
      letSubExp "dim_match" $ BasicOp $ CmpOp (CmpEq int64) desired has