{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}

-- | Functions for inspecting and constructing various types.
module Futhark.IR.Prop.Types
  ( rankShaped,
    arrayRank,
    arrayShape,
    setArrayShape,
    existential,
    uniqueness,
    unique,
    staticShapes,
    staticShapes1,
    primType,
    arrayOf,
    arrayOfRow,
    arrayOfShape,
    setOuterSize,
    setDimSize,
    setOuterDim,
    setDim,
    setArrayDims,
    peelArray,
    stripArray,
    arrayDims,
    arrayExtDims,
    shapeSize,
    arraySize,
    arraysSize,
    rowType,
    elemType,
    transposeType,
    rearrangeType,
    mapOnExtType,
    mapOnType,
    diet,
    subtypeOf,
    subtypesOf,
    toDecl,
    fromDecl,
    isExt,
    isFree,
    extractShapeContext,
    shapeContext,
    hasStaticShape,
    generaliseExtTypes,
    existentialiseExtTypes,
    shapeExtMapping,

    -- * Abbreviations
    int8,
    int16,
    int32,
    int64,
    float32,
    float64,

    -- * The Typed typeclass
    Typed (..),
    DeclTyped (..),
    ExtTyped (..),
    DeclExtTyped (..),
    SetType (..),
    FixExt (..),
  )
where

import Control.Monad.State
import Data.List (elemIndex, foldl')
import qualified Data.Map.Strict as M
import Data.Maybe
import qualified Data.Set as S
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Rearrange
import Futhark.IR.Syntax.Core

-- | Forget the concrete shape of a type, keeping only its rank.
-- Primitive and memory types are passed through as they are.
rankShaped :: ArrayShape shape => TypeBase shape u -> TypeBase Rank u
rankShaped t = case t of
  Array elemt shape uniq -> Array elemt (Rank (shapeRank shape)) uniq
  Prim elemt -> Prim elemt
  Mem space -> Mem space

-- | The number of dimensions of a type, computed as the rank of its
-- 'arrayShape'.  Non-array types have the 'mempty' shape.
arrayRank :: ArrayShape shape => TypeBase shape u -> Int
arrayRank t = shapeRank (arrayShape t)

-- | The shape component of an array type.  Any type that is not an
-- array yields 'mempty'.
arrayShape :: ArrayShape shape => TypeBase shape u -> shape
arrayShape t = case t of
  Array _ shape _ -> shape
  _ -> mempty

-- | Apply a function to the shape of an array type.  If the
-- resulting shape has rank zero, the array degenerates to the
-- primitive element type.  Non-array types are left alone (the
-- function is not applied).
modifyArrayShape ::
  ArrayShape newshape =>
  (oldshape -> newshape) ->
  TypeBase oldshape u ->
  TypeBase newshape u
modifyArrayShape change t = case t of
  Array elemt shape uniq ->
    let shape' = change shape
     in if shapeRank shape' == 0
          then Prim elemt
          else Array elemt shape' uniq
  Prim elemt -> Prim elemt
  Mem space -> Mem space

-- | Replace the shape of an array type with the given one.  An
-- array given a rank-zero shape becomes its primitive element type.
-- Types that are not arrays keep their structure; only the shape
-- type parameter changes.
setArrayShape ::
  ArrayShape newshape =>
  TypeBase oldshape u ->
  newshape ->
  TypeBase newshape u
setArrayShape t shape = modifyArrayShape (\_ -> shape) t

-- | Does the type have at least one dimension whose size is
-- existential (an 'Ext') rather than known ('Free')?
existential :: ExtType -> Bool
existential t = any isExistential (shapeDims (arrayShape t))
  where
    isExistential dim = case dim of
      Ext _ -> True
      Free _ -> False

-- | The uniqueness attribute of a type.  Only arrays carry one;
-- every other type is reported as 'Nonunique'.
uniqueness :: TypeBase shape Uniqueness -> Uniqueness
uniqueness t = case t of
  Array _ _ uniq -> uniq
  _ -> Nonunique

-- | @unique t@ holds exactly when the 'uniqueness' of @t@ is
-- 'Unique'.
unique :: TypeBase shape Uniqueness -> Bool
unique t = uniqueness t == Unique

-- | Convert types with non-existential shapes to types with
-- existential shapes.  Only the representation changes: every
-- dimension of the result is 'Free'.
staticShapes :: [TypeBase Shape u] -> [TypeBase ExtShape u]
staticShapes ts = [staticShapes1 t | t <- ts]

-- | Like 'staticShapes', but for a single type: each dimension of an
-- array is wrapped in 'Free'; primitive and memory types are
-- unchanged.
staticShapes1 :: TypeBase Shape u -> TypeBase ExtShape u
staticShapes1 t = case t of
  Prim pt -> Prim pt
  Array pt (Shape dims) uniq -> Array pt (Shape [Free d | d <- dims]) uniq
  Mem space -> Mem space

-- | @arrayOf t s u@ constructs an array type whose element type is
-- @t@.  Unlike the bare 'Array' constructor, @t@ may itself be an
-- array: if @t@ has @n@ dimensions and @s@ has rank @m@, the result
-- has @n+m@ dimensions, with shape @s <> shape-of-t@.  The result
-- always has uniqueness @u@; the uniqueness of @t@ is discarded.
--
-- If @t@ is primitive and @s@ has rank 0, the primitive type is
-- returned.  Calling this on a memory type is an error.
arrayOf ::
  ArrayShape shape =>
  TypeBase shape u_unused ->
  shape ->
  u ->
  TypeBase shape u
arrayOf t outer uniq = case t of
  Array elemt inner _ -> Array elemt (outer <> inner) uniq
  Prim elemt
    | shapeRank outer == 0 -> Prim elemt
    | otherwise -> Array elemt outer uniq
  Mem {} -> error "arrayOf Mem"

-- | Build an array type from a row type and an outer size.  A thin
-- wrapper around 'arrayOf' using a one-dimensional shape.
arrayOfRow ::
  ArrayShape (ShapeBase d) =>
  TypeBase (ShapeBase d) NoUniqueness ->
  d ->
  TypeBase (ShapeBase d) NoUniqueness
arrayOfRow t outer = arrayOf t rowShape NoUniqueness
  where
    rowShape = Shape [outer]

-- | Build an array type from a row type and a t'Shape' giving the
-- outer dimensions.  A thin wrapper around 'arrayOf'.
arrayOfShape :: Type -> Shape -> Type
arrayOfShape t shape = arrayOf t shape uniq
  where
    uniq = NoUniqueness

-- | Replace the dimensions of an array type with the given list, by
-- way of 'setArrayShape'.  Non-array types keep their structure.
setArrayDims :: TypeBase oldshape u -> [SubExp] -> TypeBase Shape u
setArrayDims t dims = setArrayShape t (Shape dims)

-- | Replace the size of dimension 0 (the outermost) of an array
-- type.  Non-array types are returned unchanged.
setOuterSize ::
  ArrayShape (ShapeBase d) =>
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setOuterSize t size = setDimSize 0 t size

-- | Replace the size of the dimension at the given index of an array
-- type.  The new shape is computed with 'setDim' and installed with
-- 'setArrayShape', so non-array types are returned unchanged.
setDimSize ::
  ArrayShape (ShapeBase d) =>
  Int ->
  TypeBase (ShapeBase d) u ->
  d ->
  TypeBase (ShapeBase d) u
setDimSize i t size = setArrayShape t shape'
  where
    shape' = setDim i (arrayShape t) size

-- | Replace dimension 0 (the outermost) of a shape; see 'setDim'.
setOuterDim :: ShapeBase d -> d -> ShapeBase d
setOuterDim shape d = setDim 0 shape d

-- | Replace the dimension at the given index of a shape.  Note that
-- if the index is at or beyond the rank of the shape, the new
-- dimension is appended after the existing ones rather than
-- replacing anything.
setDim :: Int -> ShapeBase d -> d -> ShapeBase d
setDim i (Shape ds) e = Shape (prefix ++ [e] ++ suffix)
  where
    prefix = take i ds
    suffix = drop (i + 1) ds

-- | @peelArray n t@ is the type obtained by peeling the first @n@
-- array dimensions off @t@.  Peeling zero dimensions returns @t@
-- itself; peeling every dimension yields the primitive element
-- type.  The result is 'Nothing' when @t@ has fewer than @n@
-- dimensions (including when @t@ is not an array and @n@ is
-- nonzero).
peelArray ::
  ArrayShape shape =>
  Int ->
  TypeBase shape u ->
  Maybe (TypeBase shape u)
peelArray 0 t = Just t
peelArray n (Array elemt shape uniq) =
  case compare (shapeRank shape) n of
    EQ -> Just (Prim elemt)
    GT -> Just (Array elemt (stripDims n shape) uniq)
    LT -> Nothing
peelArray _ _ = Nothing

-- | @stripArray n t@ removes the @n@ outermost dimensions of the
-- array type @t@; in effect, the type of indexing a @t@ with @n@
-- indexes.  If @n@ is at least the rank of the array, the primitive
-- element type is returned.  Non-array types are returned as-is.
stripArray :: ArrayShape shape => Int -> TypeBase shape u -> TypeBase shape u
stripArray n t = case t of
  Array elemt shape uniq ->
    if n < shapeRank shape
      then Array elemt (stripDims n shape) uniq
      else Prim elemt
  _ -> t

-- | The size of the dimension at the given index of a shape.  When
-- the shape has no such dimension, a constant 64-bit zero is
-- returned instead.
shapeSize :: Int -> Shape -> SubExp
shapeSize i shape =
  fromMaybe (constant (0 :: Int64)) (listToMaybe (drop i (shapeDims shape)))

-- | The list of dimension sizes of a type.  Non-array types have no
-- dimensions, so the list is empty for them.
arrayDims :: TypeBase Shape u -> [SubExp]
arrayDims t = shapeDims (arrayShape t)

-- | The list of (possibly existential) dimension sizes of a type.
-- Non-array types have no dimensions, so the list is empty for them.
arrayExtDims :: TypeBase ExtShape u -> [ExtSize]
arrayExtDims t = shapeDims (arrayShape t)

-- | The size of the dimension at the given index of a type, as
-- computed by 'shapeSize' on the type's 'arrayShape'.  A missing
-- dimension gives the zero constant.
arraySize :: Int -> TypeBase Shape u -> SubExp
arraySize i t = shapeSize i (arrayShape t)

-- | The size of the given dimension of the first type in the list;
-- the remaining types are ignored.  If the list is empty, or the
-- dimension does not exist, the zero constant is returned.
arraysSize :: Int -> [TypeBase Shape u] -> SubExp
arraysSize i ts = maybe (constant (0 :: Int64)) (arraySize i) (listToMaybe ts)

-- | The type of a single row of an array, i.e. the array with its
-- outermost dimension stripped.  For @[[int]]@ this is @[int]@.
rowType :: ArrayShape shape => TypeBase shape u -> TypeBase shape u
rowType t = stripArray 1 t

-- | Is the type primitive, i.e. neither an array nor a memory
-- block?
primType :: TypeBase shape u -> Bool
primType t = case t of
  Array {} -> False
  Mem {} -> False
  _ -> True

-- | The innermost element type of an array.  For @[[int]]@ this is
-- @int@.  A primitive type is its own element type.  Calling this
-- on a memory type is an error.
elemType :: TypeBase shape u -> PrimType
elemType t = case t of
  Array pt _ _ -> pt
  Prim pt -> pt
  Mem {} -> error "elemType Mem"

-- | Exchange the two outermost dimensions of a type, i.e.
-- 'rearrangeType' with the permutation @[1, 0]@.
transposeType :: Type -> Type
transposeType t = rearrangeType [1, 0] t

-- | Permute the dimensions of a type.  A permutation shorter than
-- the rank of the type is padded with the identity for the
-- remaining dimensions before being applied.
rearrangeType :: [Int] -> Type -> Type
rearrangeType perm t = setArrayShape t (Shape permutedDims)
  where
    -- Indexes not mentioned by @perm@ keep their position.
    identityTail = [length perm .. arrayRank t - 1]
    permutedDims = rearrangeShape (perm ++ identityTail) (arrayDims t)

-- | Monadically transform every t'SubExp' occurring in the
-- dimensions of a type with existential shape.  Existential
-- dimensions hold no t'SubExp' and are traversed without effect;
-- primitive and memory types are returned unchanged.
mapOnExtType ::
  Monad m =>
  (SubExp -> m SubExp) ->
  TypeBase ExtShape u ->
  m (TypeBase ExtShape u)
mapOnExtType f t = case t of
  Prim pt -> pure (Prim pt)
  Mem space -> pure (Mem space)
  Array pt shape uniq -> do
    dims <- mapM (traverse f) (shapeDims shape)
    pure (Array pt (Shape dims) uniq)

-- | Monadically transform every t'SubExp' occurring in the
-- dimensions of a type.  Primitive and memory types are returned
-- unchanged.
mapOnType ::
  Monad m =>
  (SubExp -> m SubExp) ->
  TypeBase Shape u ->
  m (TypeBase Shape u)
mapOnType f t = case t of
  Prim pt -> pure (Prim pt)
  Mem space -> pure (Mem space)
  Array pt shape uniq -> do
    dims <- mapM f (shapeDims shape)
    pure (Array pt (Shape dims) uniq)

-- | @diet t@ describes how a function parameter of type @t@ might
-- consume its argument: primitives are 'ObservePrim', unique arrays
-- are 'Consume', and nonunique arrays and memory blocks are
-- 'Observe'.
diet :: TypeBase shape Uniqueness -> Diet
diet t = case t of
  Prim {} -> ObservePrim
  Mem {} -> Observe
  Array _ _ uniq -> case uniq of
    Unique -> Consume
    Nonunique -> Observe

-- | @x \`subtypeOf\` y@ is true if @x@ is a subtype of @y@ (or equal
-- to @y@), meaning @x@ is valid whenever @y@ is.  For arrays this
-- requires that the uniqueness of @y@ is at most that of @x@, that
-- the element types agree, and that the shape of @x@ is a
-- 'subShapeOf' the shape of @y@.  Primitive and memory types must
-- be equal.  Types of different kinds are never subtypes.
subtypeOf ::
  (Ord u, ArrayShape shape) =>
  TypeBase shape u ->
  TypeBase shape u ->
  Bool
subtypeOf x y = case (x, y) of
  (Array xt xshape xu, Array yt yshape yu) ->
    and [yu <= xu, xt == yt, xshape `subShapeOf` yshape]
  (Prim xt, Prim yt) -> xt == yt
  (Mem xspace, Mem yspace) -> xspace == yspace
  _ -> False

-- | @xs \`subtypesOf\` ys@ is true if @xs@ has the same length as
-- @ys@, and each element of @xs@ is a subtype of the corresponding
-- element of @ys@.
subtypesOf ::
  (Ord u, ArrayShape shape) =>
  [TypeBase shape u] ->
  [TypeBase shape u] ->
  Bool
subtypesOf xs ys = sameLength && pairwise
  where
    sameLength = length xs == length ys
    pairwise = all (uncurry subtypeOf) (zip xs ys)

-- | Attach the given uniqueness to a type.  Only arrays record it;
-- for primitive and memory types the uniqueness is ignored.
toDecl ::
  TypeBase shape NoUniqueness ->
  Uniqueness ->
  TypeBase shape Uniqueness
toDecl t uniq = case t of
  Prim pt -> Prim pt
  Array pt shape _ -> Array pt shape uniq
  Mem space -> Mem space

-- | Strip the uniqueness information from a type, replacing it with
-- 'NoUniqueness' on arrays.
fromDecl ::
  TypeBase shape Uniqueness ->
  TypeBase shape NoUniqueness
fromDecl t = case t of
  Prim pt -> Prim pt
  Array pt shape _ -> Array pt shape NoUniqueness
  Mem space -> Mem space

-- | The existential index of a size, or 'Nothing' if the size is
-- not existential.
isExt :: Ext a -> Maybe Int
isExt size = case size of
  Ext i -> Just i
  _ -> Nothing

-- | The known size carried by a 'Free', or 'Nothing' if the size is
-- not known.
isFree :: Ext a -> Maybe a
isFree size = case size of
  Free d -> Just d
  _ -> Nothing

-- | Given the existential return type of a function, and the shapes
-- of the values returned by the function, return the existential
-- shape context: for every existential index, the concrete size
-- found at its first occurrence (scanning types left to right, and
-- dimensions outermost first).  Sizes at 'Free' positions and at
-- repeated occurrences of an index are dropped.
extractShapeContext :: [TypeBase ExtShape u] -> [[a]] -> [a]
extractShapeContext ts shapes =
  concat (evalState (zipWithM fromType ts shapes) S.empty)
  where
    -- Pair each declared dimension of one type with the
    -- corresponding concrete size.
    fromType t dims =
      catMaybes <$> zipWithM fromDim (shapeDims (arrayShape t)) dims

    -- The state is the set of existential indexes already emitted.
    fromDim dim v = case dim of
      Free _ -> pure Nothing
      Ext x -> do
        alreadySeen <- gets (S.member x)
        if alreadySeen
          then pure Nothing
          else Just v <$ modify (S.insert x)

-- | The set of identifiers used for the shape context in the given
-- 'ExtType's.
--
-- That is, the indices of every 'Ext' dimension occurring in the
-- shapes of the given types.
shapeContext :: [TypeBase ExtShape u] -> S.Set Int
shapeContext =
  S.fromList . concatMap (mapMaybe isExt . shapeDims . arrayShape)

-- | If all dimensions of the given 'ExtShape' are statically known,
-- change to the corresponding t'Shape'.
--
-- 'Prim' and 'Mem' types carry no dimensions and are always
-- converted.  An 'Array' is converted only if every dimension is
-- 'Free'; a single existential dimension yields 'Nothing'.
hasStaticShape :: TypeBase ExtShape u -> Maybe (TypeBase Shape u)
hasStaticShape (Prim bt) = Just $ Prim bt
hasStaticShape (Mem space) = Just $ Mem space
hasStaticShape (Array bt (Shape shape) u) =
  (\dims -> Array bt (Shape dims) u) <$> traverse isFree shape

-- | Given two lists of 'ExtType's of the same length, return a list
-- of 'ExtType's that is a subtype of the two operands.
--
-- The types are unified pairwise and dimension by dimension.  A
-- dimension stays 'Free' only when both operands have the same known
-- size; in every other case it becomes existential.  Existential
-- indices in the result are numbered afresh, starting from zero.
-- Everything except the shape (element type, uniqueness) is taken
-- from the first operand.
generaliseExtTypes ::
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u] ->
  [TypeBase ExtShape u]
generaliseExtTypes rt1 rt2 =
  evalState (zipWithM unifyExtShapes rt1 rt2) (0, M.empty)
  where
    -- The state is a pair: the next unused existential index, and a
    -- mapping from operand existential indices to result indices.
    unifyExtShapes t1 t2 =
      setArrayShape t1 . Shape
        <$> zipWithM
          unifyExtDims
          (shapeDims $ arrayShape t1)
          (shapeDims $ arrayShape t2)

    unifyExtDims (Free se1) (Free se2)
      | se1 == se2 = return $ Free se1 -- Arbitrary
      | otherwise = do
          -- Differing known sizes: allocate a fresh index that is not
          -- recorded in the mapping.
          (n, m) <- get
          put (n + 1, m)
          return $ Ext n
    -- Same index on both sides: reuse its previous renumbering if
    -- there is one.  When the indices differ, the guard fails and
    -- the equation below applies.
    unifyExtDims (Ext x) (Ext y)
      | x == y =
          Ext <$> (maybe (new x) return =<< gets (M.lookup x . snd))
    unifyExtDims (Ext x) _ = Ext <$> new x
    unifyExtDims _ (Ext x) = Ext <$> new x

    -- Allocate the next result index and record it for operand
    -- index @x@, replacing any earlier entry.
    new x = do
      (n, m) <- get
      put (n + 1, M.insert x n m)
      return n

-- | Given a list of 'ExtType's and a list of "forbidden" names,
-- modify the dimensions of the 'ExtType's such that they are 'Ext'
-- where they were previously 'Free' with a variable in the set of
-- forbidden names.
--
-- The existential index given to such a dimension is the position of
-- the variable in the list of forbidden names.  All other dimensions
-- are left untouched.
--
-- NOTE(review): 'Ext' dimensions already present in the input keep
-- their indices, so they may coincide with the indices introduced
-- here - confirm that callers account for this.
existentialiseExtTypes :: [VName] -> [ExtType] -> [ExtType]
existentialiseExtTypes inaccessible = map makeBoundShapesFree
  where
    makeBoundShapesFree = modifyArrayShape $ fmap checkDim

    checkDim (Free (Var v))
      | Just i <- v `elemIndex` inaccessible = Ext i
    checkDim d = d

-- | Produce a mapping for the dimensions context.
--
-- The existential types and the concrete types are paired
-- positionally, as are their dimensions.  Every 'Ext' dimension maps
-- its index to the concrete size found at the same position; 'Free'
-- dimensions contribute nothing.  The per-dimension maps are merged
-- with 'mappend', which for maps is a left-biased union, so when an
-- index occurs several times its first occurrence is kept.
shapeExtMapping :: [TypeBase ExtShape u] -> [TypeBase Shape u1] -> M.Map Int SubExp
shapeExtMapping = dimMapping arrayExtDims arrayDims match mappend
  where
    match Free {} _ = mempty
    match (Ext i) dim = M.singleton i dim

-- | Combine the dimensions of two lists of things pairwise.
--
-- The two lists are paired positionally, and so are the dimension
-- lists extracted from each pair; in both cases any excess elements
-- of the longer list are ignored.  Every pair of dimensions is turned
-- into a result with the third argument, and the results are folded
-- from the left with the fourth argument, starting from 'mempty'.
dimMapping ::
  Monoid res =>
  (t1 -> [dim1]) ->
  (t2 -> [dim2]) ->
  (dim1 -> dim2 -> res) ->
  (res -> res -> res) ->
  [t1] ->
  [t2] ->
  res
dimMapping getDims1 getDims2 f comb ts1 ts2 =
  foldl' comb mempty . concat $ zipWith perPair ts1 ts2
  where
    -- Results for the dimensions of one pair of things.
    perPair t1 t2 = zipWith f (getDims1 t1) (getDims2 t2)

-- | @IntType Int8@
int8 :: PrimType
int8 = IntType Int8

-- | @IntType Int16@
int16 :: PrimType
int16 = IntType Int16

-- | @IntType Int32@
int32 :: PrimType
int32 = IntType Int32

-- | @IntType Int64@
int64 :: PrimType
int64 = IntType Int64

-- | @FloatType Float32@
float32 :: PrimType
float32 = FloatType Float32

-- | @FloatType Float64@
float64 :: PrimType
float64 = FloatType Float64

-- | Typeclass for things that contain 'Type's.
class Typed t where
  -- | The 'Type' contained in the given thing.
  typeOf :: t -> Type

instance Typed Type where
  typeOf = id

-- | The uniqueness information is dropped via 'fromDecl'.
instance Typed DeclType where
  typeOf = fromDecl

instance Typed Ident where
  typeOf = identType

-- | A parameter has the type of its decoration.
instance Typed dec => Typed (Param dec) where
  typeOf = typeOf . paramDec

-- | A pattern element has the type of its decoration.
instance Typed dec => Typed (PatElemT dec) where
  typeOf = typeOf . patElemDec

-- | A pair has the type of its second component.
instance Typed b => Typed (a, b) where
  typeOf = typeOf . snd

-- | Typeclass for things that contain 'DeclType's.
class DeclTyped t where
  -- | The 'DeclType' contained in the given thing.
  declTypeOf :: t -> DeclType

instance DeclTyped DeclType where
  declTypeOf = id

-- | A parameter has the declared type of its decoration.
instance DeclTyped dec => DeclTyped (Param dec) where
  declTypeOf = declTypeOf . paramDec

-- | Typeclass for things that contain 'ExtType's.
class FixExt t => ExtTyped t where
  -- | The 'ExtType' contained in the given thing.
  extTypeOf :: t -> ExtType

instance ExtTyped ExtType where
  extTypeOf = id

-- | Typeclass for things that contain 'DeclExtType's.
class FixExt t => DeclExtTyped t where
  -- | The 'DeclExtType' contained in the given thing.
  declExtTypeOf :: t -> DeclExtType

instance DeclExtTyped DeclExtType where
  declExtTypeOf = id

-- | Typeclass for things whose type can be changed.
class Typed a => SetType a where
  -- | Replace the type of the first argument with the given 'Type'.
  setType :: a -> Type -> a

-- | The old type is simply discarded.
instance SetType Type where
  setType _ t = t

-- | Only the second component is changed.
instance SetType b => SetType (a, b) where
  setType (a, b) t = (a, setType b t)

-- | The name is kept; the type of the decoration is changed.
instance SetType dec => SetType (PatElemT dec) where
  setType (PatElem name dec) t =
    PatElem name $ setType dec t

-- | Something with an existential context that can be (partially)
-- fixed.
class FixExt t where
  -- | Fix the given existential variable to the indicated free
  -- value.
  fixExt :: Int -> SubExp -> t -> t

-- | Fixes the variable in the shape of the type.
instance (FixExt shape, ArrayShape shape) => FixExt (TypeBase shape u) where
  fixExt i se = modifyArrayShape $ fixExt i se

-- | Fixes the variable in every dimension of the shape.
instance FixExt d => FixExt (ShapeBase d) where
  fixExt i se = fmap $ fixExt i se

-- | Fixes the variable in every element of the list.
instance FixExt a => FixExt [a] where
  fixExt i se = fmap $ fixExt i se

-- | The fixed index becomes 'Free'.  Existential indices greater
-- than the fixed one are decremented by one; smaller indices and
-- already 'Free' sizes are unchanged.
instance FixExt ExtSize where
  fixExt i se (Ext j)
    | j > i = Ext $ j - 1
    | j == i = Free se
    | otherwise = Ext j
  fixExt _ _ (Free x) = Free x

-- | The unit value has nothing to fix.
instance FixExt () where
  fixExt _ _ () = ()