module Futhark.Representation.AST.Attributes.Reshape
(
newDim
, newDims
, newShape
, shapeCoerce
, repeatShapes
, reshapeOuter
, reshapeInner
, repeatDims
, shapeCoercion
, fuseReshape
, fuseReshapes
, informReshape
, reshapeIndex
, flattenIndex
, unflattenIndex
, sliceSizes
)
where
import Data.Foldable
import Prelude hiding (sum, product, quot)
import Futhark.Representation.AST.Attributes.Types
import Futhark.Representation.AST.Syntax
import Futhark.Util.IntegralExp
-- | The dimension that a single 'DimChange' results in.  Both
-- 'DimCoercion' and 'DimNew' carry the resulting dimension; only the
-- constructor differs.
newDim :: DimChange d -> d
newDim change = case change of
  DimCoercion d -> d
  DimNew d -> d
-- | The resulting dimension of every change in a 'ShapeChange', in order.
newDims :: ShapeChange d -> [d]
newDims changes = [newDim change | change <- changes]
-- | The 'Shape' described by a 'ShapeChange', forgetting whether each
-- dimension was a coercion or a new dimension.
newShape :: ShapeChange SubExp -> Shape
newShape changes = Shape (newDims changes)
-- | A 'Reshape' expression of the array @arr@ to the dimensions
-- @newdims@, with every dimension marked as a 'DimCoercion'.
shapeCoerce :: [SubExp] -> VName -> Exp lore
shapeCoerce newdims arr = BasicOp (Reshape coercions arr)
  where coercions = map DimCoercion newdims
-- | Split a list of shapes for the type @t@ into one outer shape per
-- dimension of @t@ plus a trailing inner shape.
--
-- If @shapes@ has exactly @arrayRank t + 1@ elements, the first
-- @arrayRank t@ are the outer shapes and the last one is the inner shape.
-- Otherwise the inner shape is empty, and @shapes@ is padded with empty
-- shapes so that there is at least one outer shape per dimension of @t@.
repeatShapes :: [Shape] -> Type -> ([Shape], Shape)
repeatShapes shapes t =
  case splitAt t_rank shapes of
    (outer_shapes, [inner_shape]) ->
      (outer_shapes, inner_shape)
    _ ->
      -- Pad up to the rank of @t@.  The count must be
      -- @t_rank - length shapes@.  The reverse subtraction is negative
      -- exactly when padding is needed, so 'replicate' would yield nothing.
      -- A consumer that zips outer shapes with dimensions (such as
      -- 'repeatDims') would then silently drop dimensions.
      (shapes ++ replicate (t_rank - length shapes) (Shape []), Shape [])
  where t_rank = arrayRank t
-- | @repeatDims outers inner t@ rewrites the shape of @t@ (via
-- 'modifyArrayShape').
--
-- * The dimensions of the i'th outer shape are placed in front of the
--   i'th dimension of @t@.
-- * The dimensions of @inner@ are appended at the end.
--
-- Outer shapes and dimensions are paired positionally, so anything beyond
-- the shorter of the two lists is dropped.
repeatDims :: [Shape] -> Shape -> Type -> Type
repeatDims shape innershape = modifyArrayShape expand
  where
    expand (Shape ds) =
      Shape $ concat [shapeDims outer ++ [d] | (outer, d) <- zip shape ds]
        ++ shapeDims innershape
-- | @reshapeOuter newshape n oldshape@ replaces the outermost @n@
-- dimensions of @oldshape@ with @newshape@ and keeps the remaining inner
-- dimensions.
--
-- The kept dimensions are marked 'DimCoercion' when @newshape@ has
-- exactly @n@ entries (so the rank is unchanged), and 'DimNew' otherwise.
reshapeOuter :: ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp
reshapeOuter newshape n oldshape =
  let same_rank = length newshape == n
      mark = if same_rank then DimCoercion else DimNew
      kept_inner = drop n $ shapeDims oldshape
  in newshape ++ map mark kept_inner
-- | @reshapeInner newshape n oldshape@ keeps the outermost @n@ dimensions
-- of @oldshape@ and replaces the remaining inner dimensions with
-- @newshape@.
--
-- The kept dimensions are marked 'DimCoercion' when @newshape@ has as many
-- entries as the @rank - n@ dimensions it replaces (so the rank is
-- unchanged), and 'DimNew' otherwise.
reshapeInner :: ShapeChange SubExp -> Int -> Shape -> ShapeChange SubExp
reshapeInner newshape n oldshape =
  let replaced = shapeRank oldshape - n
      mark = if length newshape == replaced then DimCoercion else DimNew
      kept_outer = take n $ shapeDims oldshape
  in map mark kept_outer ++ newshape
-- | If every change in the 'ShapeChange' is a 'DimCoercion', return the
-- coerced dimensions in order.  If any change is a 'DimNew', return
-- 'Nothing'.
shapeCoercion :: ShapeChange d -> Maybe [d]
shapeCoercion = traverse onlyCoercion
  where
    onlyCoercion change = case change of
      DimCoercion d -> Just d
      DimNew _ -> Nothing
-- | Combine two shape changes into one.
--
-- If the lengths differ, @s2@ is returned unchanged.  Otherwise the two
-- are combined position-wise, with @s2@ always supplying the resulting
-- dimension:
--
-- * 'DimNew' followed by 'DimCoercion' gives 'DimNew'.
-- * @'DimCoercion' d1@ followed by @'DimNew' d2@ gives 'DimCoercion' if
--   @d1 == d2@, and 'DimNew' otherwise.
-- * In every other case the entry from @s2@ is kept as-is.
fuseReshape :: Eq d => ShapeChange d -> ShapeChange d -> ShapeChange d
fuseReshape s1 s2
  | length s1 /= length s2 = s2
  | otherwise = zipWith combine s1 s2
  where
    combine earlier later = case (earlier, later) of
      (DimNew _, DimCoercion d2) -> DimNew d2
      (DimCoercion d1, DimNew d2)
        | d1 == d2 -> DimCoercion d2
        | otherwise -> DimNew d2
      _ -> later
-- | Left-fold 'fuseReshape' over a sequence of shape changes, starting
-- from the given shape change.  At each step the accumulated change is
-- the first argument to 'fuseReshape'.
--
-- The strict left fold ('foldl'') is used so that long sequences are
-- combined step by step, rather than building a chain of unevaluated
-- thunks as the lazy 'foldl' does.
fuseReshapes :: (Eq d, Data.Foldable.Foldable t) =>
                ShapeChange d -> t (ShapeChange d) -> ShapeChange d
fuseReshapes = Data.Foldable.foldl' fuseReshape
-- | Use known dimensions @shape@ to refine a 'ShapeChange'.
--
-- When the lengths match, each @'DimNew' d2@ whose @d2@ equals the
-- corresponding entry of @shape@ becomes @'DimCoercion' d2@, and all
-- other entries are kept.  When the lengths differ, the shape change is
-- returned unchanged.
informReshape :: Eq d => [d] -> ShapeChange d -> ShapeChange d
informReshape shape sc
  | length shape /= length sc = sc
  | otherwise = zipWith inform shape sc
  where
    inform known change = case change of
      DimNew d2 | known == d2 -> DimCoercion d2
      _ -> change
-- | Translate the multi-dimensional index @is@ into an array with
-- dimensions @from_dims@ into the index of the same row-major position
-- in an array with dimensions @to_dims@.
reshapeIndex :: IntegralExp num =>
                [num] -> [num] -> [num] -> [num]
reshapeIndex to_dims from_dims is =
  let flat = flattenIndex from_dims is
  in unflattenIndex to_dims flat
-- | Convert a flat row-major offset into a multi-dimensional index for an
-- array with the given dimensions.
unflattenIndex :: IntegralExp num =>
                  [num] -> num -> [num]
unflattenIndex dims i =
  -- Drop the total size; what remains is the stride of each dimension.
  let strides = drop 1 (sliceSizes dims)
  in unflattenIndexFromSlices strides i
-- | Decompose the flat offset @i@ using per-dimension strides (outermost
-- first).  Each component is the quotient by that dimension's stride, and
-- the remainder is carried on to the next dimension.
unflattenIndexFromSlices :: IntegralExp num =>
                            [num] -> num -> [num]
unflattenIndexFromSlices slices i = case slices of
  [] -> []
  size : rest ->
    let q = i `quot` size
    in q : unflattenIndexFromSlices rest (i - q * size)
-- | Convert the multi-dimensional index @is@ into a flat row-major offset
-- for an array with dimensions @dims@.  Index components and strides are
-- paired positionally; any excess in either list is ignored.
flattenIndex :: IntegralExp num =>
                [num] -> [num] -> num
flattenIndex dims is =
  let strides = drop 1 (sliceSizes dims)
  in sum [i * stride | (i, stride) <- zip is strides]
-- | The product of every suffix of @dims@, longest first, ending with the
-- empty product @1@.  For example, @sliceSizes [a, b, c]@ is
-- @[a*b*c, b*c, c, 1]@.
sliceSizes :: IntegralExp num =>
              [num] -> [num]
sliceSizes dims = case dims of
  [] -> [1]
  _ : rest -> product dims : sliceSizes rest