{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE TypeFamilies #-}
module Futhark.IR.RetType
( IsBodyType (..),
IsRetType (..),
expectedTypes,
)
where
import Control.Monad.Identity
import qualified Data.Map.Strict as M
import Futhark.IR.Prop.Types
import Futhark.IR.Syntax.Core
-- | Types that can describe the result of a body.  Besides being
-- showable and totally ordered, every instance is 'ExtTyped'.
class (Show rt, Eq rt, Ord rt, ExtTyped rt) => IsBodyType rt where
  -- | Build a body type that denotes the given primitive type.
  primBodyType :: PrimType -> rt
-- | A body type given directly as an 'ExtType'; a primitive type is
-- wrapped with the 'Prim' constructor.
instance IsBodyType ExtType where
  primBodyType pt = Prim pt
-- | Types that can describe the return type of a function.  Besides
-- being showable and totally ordered, every instance is
-- 'DeclExtTyped'.
class (Show rt, Eq rt, Ord rt, DeclExtTyped rt) => IsRetType rt where
  -- | Build a return type that denotes the given primitive type.
  primRetType :: PrimType -> rt

  -- | Instantiate a function's return types for one concrete call.
  --
  -- The arguments are the declared return types, the function's
  -- parameters, and each argument of the call paired with its type.
  -- The result is 'Nothing' when the instance rejects the call;
  -- presumably this means the arguments do not fit the parameters
  -- (confirm per instance).
  applyRetType ::
    Typed dec =>
    [rt] ->
    [Param dec] ->
    [(SubExp, Type)] ->
    Maybe [rt]
-- | Compute the types that a call's arguments are expected to have.
--
-- The first list names the shape parameters and the third supplies the
-- call's arguments.  They are paired positionally with 'zip', so any
-- surplus on the longer side is ignored.  Each value parameter's type
-- is obtained with 'typeOf'.  Then every 'Var' that 'mapOnType' visits
-- inside that type is replaced by its paired argument, if it names a
-- shape parameter.  All other subexpressions are left unchanged.
expectedTypes :: Typed t => [VName] -> [t] -> [SubExp] -> [Type]
expectedTypes shapes value_ts args =
  [instantiate (typeOf vt) | vt <- value_ts]
  where
    -- Positional binding from shape-parameter name to argument.
    bindings :: M.Map VName SubExp
    bindings = M.fromList (zip shapes args)

    instantiate t = runIdentity $ mapOnType (pure . substitute) t

    -- Swap a bound variable for its argument; keep anything else as is.
    substitute se@(Var v) = M.findWithDefault se v bindings
    substitute se = se
-- | Return types given directly as 'DeclExtType's.
instance IsRetType DeclExtType where
  primRetType = Prim

  -- The call is accepted only when two conditions hold:
  --
  --   * it supplies exactly as many arguments as there are parameters;
  --   * each argument's type is a 'subtypeOf' the type that
  --     'expectedTypes' computes for the matching parameter.
  --
  -- In that case, every 'Var' that 'mapOnExtType' visits in a declared
  -- return type, and that names a parameter, is replaced by the
  -- corresponding argument.  Otherwise the result is 'Nothing'.
  applyRetType extret params args
    | length args == length params,
      and (zipWith subtypeOf arg_ts expected_ts) =
        Just [instantiate rt | rt <- extret]
    | otherwise = Nothing
    where
      param_names = map paramName params
      arg_ses = map fst args
      arg_ts = map snd args

      -- Parameter types with the call's arguments substituted in.
      expected_ts = expectedTypes param_names params arg_ses

      -- Positional binding from parameter name to argument.
      bindings :: M.Map VName SubExp
      bindings = M.fromList (zip param_names arg_ses)

      instantiate rt = runIdentity $ mapOnExtType (pure . substitute) rt

      substitute se@(Var v) = M.findWithDefault se v bindings
      substitute se = se