{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE TypeFamilies #-}
-- |
-- Module      : Data.Massiv.Array.Stencil
-- Copyright   : (c) Alexey Kuleshevich 2018-2019
-- License     : BSD3
-- Maintainer  : Alexey Kuleshevich <lehins@yandex.ru>
-- Stability   : experimental
-- Portability : non-portable
--
module Data.Massiv.Array.Stencil
  ( -- * Stencil
    Stencil
  , Value
  , makeStencil
  , makeStencilDef
  , getStencilSize
  , getStencilCenter
  -- ** Padding
  , Padding(..)
  , noPadding
  , samePadding
  -- ** Application
  , mapStencil
  , applyStencil
  -- ** Common stencils
  , idStencil
  , sumStencil
  , productStencil
  , avgStencil
  , maxStencil
  , minStencil
  , foldlStencil
  , foldrStencil
  , foldStencil
  -- ** Profunctor
  , dimapStencil
  , lmapStencil
  , rmapStencil
  -- * Convolution
  , module Data.Massiv.Array.Stencil.Convolution
  -- * Re-export
  , Default(def)
  ) where

import Data.Coerce
import Data.Default.Class (Default(def))
import Data.Massiv.Array.Delayed.Windowed
import Data.Massiv.Array.Manifest
import Data.Massiv.Array.Stencil.Convolution
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Array.Stencil.Unsafe
import Data.Massiv.Core.Common
import Data.Semigroup
import GHC.Exts (inline)

-- | Get the size of the stencil
--
-- @since 0.4.3
getStencilSize :: Stencil ix e a -> Sz ix
getStencilSize = stencilSize

-- | Get the index of the stencil's center
--
-- @since 0.4.3
getStencilCenter :: Stencil ix e a -> ix
getStencilCenter = stencilCenter

-- | Map a constructed stencil over an array. Resulting array must be
-- `Data.Massiv.Array.compute`d in order to be useful.
--
-- This is `applyStencil` with the padding produced by `samePadding`.
--
-- @since 0.1.0
mapStencil ::
     (Source r ix e, Manifest r ix e)
  => Border e
  -- ^ Border resolution technique
  -> Stencil ix e a
  -- ^ Stencil to map over the array
  -> Array r ix e
  -- ^ Source array
  -> Array DW ix a
mapStencil b stencil = applyStencil (samePadding stencil b) stencil
{-# INLINE mapStencil #-}

-- | Padding of the source array before stencil application.
--
-- ==== __Examples__
--
-- In order to see the effect of padding we can simply apply an identity stencil to an
-- array:
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ resize' (Sz2 2 3) (Ix1 1 ... 6)
-- >>> applyStencil noPadding idStencil a
-- Array DW Seq (Sz (2 :. 3))
--   [ [ 1, 2, 3 ]
--   , [ 4, 5, 6 ]
--   ]
-- >>> applyStencil (Padding (Sz2 1 2) (Sz2 3 4) (Fill 0)) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   , [ 0, 0, 1, 2, 3, 0, 0, 0, 0 ]
--   , [ 0, 0, 4, 5, 6, 0, 0, 0, 0 ]
--   , [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   , [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   , [ 0, 0, 0, 0, 0, 0, 0, 0, 0 ]
--   ]
--
-- It is also a nice technique to see the border resolution strategy in action:
--
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Wrap) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 1, 2, 3, 1, 2, 3, 1, 2, 3 ]
--   , [ 4, 5, 6, 4, 5, 6, 4, 5, 6 ]
--   , [ 1, 2, 3, 1, 2, 3, 1, 2, 3 ]
--   , [ 4, 5, 6, 4, 5, 6, 4, 5, 6 ]
--   , [ 1, 2, 3, 1, 2, 3, 1, 2, 3 ]
--   , [ 4, 5, 6, 4, 5, 6, 4, 5, 6 ]
--   ]
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Edge) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 1, 1, 1, 1, 2, 3, 3, 3, 3 ]
--   , [ 1, 1, 1, 1, 2, 3, 3, 3, 3 ]
--   , [ 1, 1, 1, 1, 2, 3, 3, 3, 3 ]
--   , [ 4, 4, 4, 4, 5, 6, 6, 6, 6 ]
--   , [ 4, 4, 4, 4, 5, 6, 6, 6, 6 ]
--   , [ 4, 4, 4, 4, 5, 6, 6, 6, 6 ]
--   ]
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Reflect) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 6, 5, 4, 4, 5, 6, 6, 5, 4 ]
--   , [ 3, 2, 1, 1, 2, 3, 3, 2, 1 ]
--   , [ 3, 2, 1, 1, 2, 3, 3, 2, 1 ]
--   , [ 6, 5, 4, 4, 5, 6, 6, 5, 4 ]
--   , [ 6, 5, 4, 4, 5, 6, 6, 5, 4 ]
--   , [ 3, 2, 1, 1, 2, 3, 3, 2, 1 ]
--   ]
-- >>> applyStencil (Padding (Sz2 2 3) (Sz2 2 3) Continue) idStencil a
-- Array DW Seq (Sz (6 :. 9))
--   [ [ 1, 3, 2, 1, 2, 3, 2, 1, 3 ]
--   , [ 4, 6, 5, 4, 5, 6, 5, 4, 6 ]
--   , [ 1, 3, 2, 1, 2, 3, 2, 1, 3 ]
--   , [ 4, 6, 5, 4, 5, 6, 5, 4, 6 ]
--   , [ 1, 3, 2, 1, 2, 3, 2, 1, 3 ]
--   , [ 4, 6, 5, 4, 5, 6, 5, 4, 6 ]
--   ]
--
-- @since 0.4.3
data Padding ix e = Padding
  { paddingFromOrigin :: !(Sz ix)
  -- ^ Amount of padding placed in front of the array along each dimension, i.e. on the
  -- side of the origin. In the @Fill 0@ example above this is @Sz2 1 2@: one row above
  -- and two columns to the left of the source elements.
  , paddingFromBottom :: !(Sz ix)
  -- ^ Amount of padding placed after the array along each dimension, i.e. on the side
  -- opposite to the origin. In the @Fill 0@ example above this is @Sz2 3 4@.
  , paddingWithElement :: !(Border e)
  -- ^ Border resolution technique that supplies the elements to do padding with
  } deriving (Eq, Show)

-- | Also known as "valid" padding. When stencil is applied to an array, that array will
-- shrink, unless the stencil is of size 1.
--
-- @since 0.4.3
noPadding :: Index ix => Padding ix e
noPadding = Padding zeroSz zeroSz Edge

-- | Padding that matches the size of the stencil, which is known as "same" padding,
-- because when a stencil is applied to an array with such matching padding, the resulting
-- array will be of the same size as the source array. This is exactly the behavior of
-- `mapStencil`
--
-- @since 0.4.3
samePadding :: Index ix => Stencil ix e a -> Border e -> Padding ix e
samePadding (Stencil (Sz sSz) sCenter _) border =
  Padding
    { paddingFromOrigin = Sz sCenter
      -- Whatever remains of the stencil past its center: size - (center + 1)
    , paddingFromBottom = Sz (liftIndex2 (-) sSz (liftIndex (+1) sCenter))
    , paddingWithElement = border
    }

-- | Apply a constructed stencil over an array. Resulting array must be
-- `Data.Massiv.Array.compute`d in order to be useful. Unlike `mapStencil`, the size of
-- the resulting array will not necessarily be the same as the source array, which will
-- depend on the padding.
--
-- @since 0.4.3
applyStencil ::
     (Source r ix e, Manifest r ix e)
  => Padding ix e
  -- ^ Padding to be applied to the source array. This will dictate the resulting size of
  -- the array. No padding will cause it to shrink by the size of the stencil
  -> Stencil ix e a
  -- ^ Stencil to apply to the array
  -> Array r ix e
  -- ^ Source array
  -> Array DW ix a
applyStencil (Padding (Sz po) (Sz pb) border) (Stencil sSz sCenter stencilF) !arr =
  insertWindow warr window
  where
    -- Translation from an index in the result array to the index in the source array at
    -- which the center of the stencil is placed.
    !offset = liftIndex2 (-) sCenter po
    -- Delayed array covering the whole result. Elements are looked up with `borderIndex`,
    -- so reads that fall outside of the source array are resolved with the border.
    !warr =
      DArray
        (getComp arr)
        sz
        (unValue . stencilF (Value . borderIndex border arr) . liftIndex2 (+) offset)
    -- Size by which the resulting array will shrink (not accounting for padding)
    !shrinkSz = Sz (liftIndex (subtract 1) (unSz sSz))
    -- Size of the result: source size plus padding on both sides, minus the shrinkage
    !sz = liftSz2 (-) (SafeSz (liftIndex2 (+) po (liftIndex2 (+) pb (unSz (size arr))))) shrinkSz
    -- Size of the window: source size minus the shrinkage
    !wsz = liftSz2 (-) (size arr) shrinkSz
    -- Window starting at the origin padding. It reads the source with `unsafeIndex`,
    -- i.e. without any border resolution.
    !window =
      Window
        { windowStart = po
        , windowSize = wsz
        , windowIndex = unValue . stencilF (Value . unsafeIndex arr) . liftIndex2 (+) offset
          -- Derived from the stencil size at dimension 2, whenever `pullOutSzM` succeeds
        , windowUnrollIx2 = unSz . fst <$> pullOutSzM sSz 2
        }
{-# INLINE applyStencil #-}

-- | Construct a stencil from a function, which describes how to calculate the
-- value at a point while having access to neighboring elements with a function
-- that accepts indices relative to the center of stencil. Trying to index
-- outside the stencil box will result in a runtime error upon stencil
-- creation.
--
-- ==== __Example__
--
-- Below is an example of creating a `Stencil`, which, when mapped over a
-- 2-dimensional array, will compute an average of all elements in a 3x3 square
-- for each element in that array.
--
-- /Note/ - Make sure to add an @INLINE@ pragma, otherwise performance will be terrible.
--
-- > average3x3Stencil :: (Default a, Fractional a) => Stencil Ix2 a a
-- > average3x3Stencil = makeStencil (Sz (3 :. 3)) (1 :. 1) $ \ get ->
-- >   (  get (-1 :. -1) + get (-1 :. 0) + get (-1 :. 1) +
-- >      get ( 0 :. -1) + get ( 0 :. 0) + get ( 0 :. 1) +
-- >      get ( 1 :. -1) + get ( 1 :. 0) + get ( 1 :. 1)   ) / 9
-- > {-# INLINE average3x3Stencil #-}
--
-- @since 0.1.0
makeStencil ::
     (Index ix, Default e)
  => Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the stencil
  -> ((ix -> Value e) -> Value a)
  -- ^ Stencil function that receives a "get" function as its argument that can
  -- retrieve values of cells in the source array with respect to the center of
  -- the stencil. Stencil function must return a value that will be assigned to
  -- the cell in the result array. Offset supplied to the "get" function
  -- cannot go outside the boundaries of the stencil, otherwise an error will be
  -- raised during stencil creation.
  -> Stencil ix e a
makeStencil = makeStencilDef def
{-# INLINE makeStencil #-}

-- | Same as `makeStencil`, but with ability to specify default value for stencil validation.
--
-- @since 0.2.3
makeStencilDef ::
     Index ix
  => e
  -- ^ Default element that will be used for stencil validation only.
  -> Sz ix
  -- ^ Size of the stencil
  -> ix
  -- ^ Center of the stencil
  -> ((ix -> Value e) -> Value a)
  -- ^ Stencil function.
  -> Stencil ix e a
makeStencilDef defVal !sSz !sCenter relStencil =
  validateStencil defVal $ Stencil sSz sCenter stencil
  where
    -- Turn the relative "get" into an absolute one by adding the index at which the
    -- stencil is currently positioned to every requested offset.
    stencil getVal !ix =
      inline relStencil $ \ !ixD -> getVal (liftIndex2 (+) ix ixD)
    {-# INLINE stencil #-}
{-# INLINE makeStencilDef #-}

-- | Identity stencil that does not change the elements of the source array.
--
-- @since 0.4.3
idStencil :: Index ix => Stencil ix e e
idStencil = makeUnsafeStencil oneSz zeroIndex $ \ _ get -> get zeroIndex
{-# INLINE idStencil #-}

-- | Stencil that does a left fold in a row-major order. Regardless of the supplied size
-- resulting stencil will be centered at zero, although by using `Padding` it is possible
-- to overcome this limitation.
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 3 4) (+1) (10 :: Int)
-- >>> a
-- Array P Seq (Sz (3 :. 4))
--   [ [ 11, 12, 13, 14 ]
--   , [ 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22 ]
--   ]
-- >>> applyStencil noPadding (foldlStencil (flip (:)) [] (Sz2 3 2)) a
-- Array DW Seq (Sz (1 :. 3))
--   [ [ [20,19,16,15,12,11], [21,20,17,16,13,12], [22,21,18,17,14,13] ]
--   ]
-- >>> applyStencil (Padding (Sz2 1 0) 0 (Fill 10)) (foldlStencil (flip (:)) [] (Sz2 3 2)) a
-- Array DW Seq (Sz (2 :. 3))
--   [ [ [16,15,12,11,10,10], [17,16,13,12,10,10], [18,17,14,13,10,10] ]
--   , [ [20,19,16,15,12,11], [21,20,17,16,13,12], [22,21,18,17,14,13] ]
--   ]
--
-- @since 0.4.3
foldlStencil :: Index ix => (a -> e -> a) -> a -> Sz ix -> Stencil ix e a
foldlStencil f acc0 sz =
  makeUnsafeStencil sz zeroIndex $ \_ get ->
    -- Walk from the zero index up to (but excluding) the size, in steps of one
    iter zeroIndex (unSz sz) oneIndex (<) acc0 $ \ix -> (`f` get ix)
{-# INLINE foldlStencil #-}

-- | Stencil that does a right fold in a row-major order. Regardless of the supplied size
-- resulting stencil will be centered at zero, although by using `Padding` it is possible
-- to overcome this limitation.
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 3 4) (+1) (10 :: Int)
-- >>> a
-- Array P Seq (Sz (3 :. 4))
--   [ [ 11, 12, 13, 14 ]
--   , [ 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22 ]
--   ]
-- >>> applyStencil noPadding (foldrStencil (:) [] (Sz2 2 3)) a
-- Array DW Seq (Sz (2 :. 2))
--   [ [ [11,12,13,15,16,17], [12,13,14,16,17,18] ]
--   , [ [15,16,17,19,20,21], [16,17,18,20,21,22] ]
--   ]
--
-- @since 0.4.3
foldrStencil :: Index ix => (e -> a -> a) -> a -> Sz ix -> Stencil ix e a
foldrStencil f acc0 sz =
  -- Start at the last index of the stencil region (size - 1) and walk back down to the
  -- zero index inclusive, in steps of negative one.
  let ixStart = liftIndex2 (-) (unSz sz) oneIndex
   in makeUnsafeStencil sz zeroIndex $ \_ get ->
        iter ixStart zeroIndex (pureIndex (-1)) (>=) acc0 $ \ix -> f (get ix)
{-# INLINE foldrStencil #-}

-- | Stencil that combines all elements in the region of supplied size with `mappend`,
-- starting with `mempty`. It is defined in terms of `foldlStencil`, therefore elements
-- are combined with a left fold in a row-major order and the resulting stencil is
-- centered at zero.
foldStencil :: (Monoid e, Index ix) => Sz ix -> Stencil ix e e
foldStencil = foldlStencil mappend mempty
{-# INLINE foldStencil #-}

-- | Create a stencil centered at 0 that will extract the maximum value in the region of
-- supplied size.
--
-- ==== __Example__
--
-- Here is a sample implementation of max pooling.
--
-- >>> import Data.Massiv.Array as A
-- >>> a <- computeAs P <$> resizeM (Sz2 9 9) (Ix1 10 ..: 91)
-- >>> a
-- Array P Seq (Sz (9 :. 9))
--   [ [ 10, 11, 12, 13, 14, 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22, 23, 24, 25, 26, 27 ]
--   , [ 28, 29, 30, 31, 32, 33, 34, 35, 36 ]
--   , [ 37, 38, 39, 40, 41, 42, 43, 44, 45 ]
--   , [ 46, 47, 48, 49, 50, 51, 52, 53, 54 ]
--   , [ 55, 56, 57, 58, 59, 60, 61, 62, 63 ]
--   , [ 64, 65, 66, 67, 68, 69, 70, 71, 72 ]
--   , [ 73, 74, 75, 76, 77, 78, 79, 80, 81 ]
--   , [ 82, 83, 84, 85, 86, 87, 88, 89, 90 ]
--   ]
-- >>> computeWithStrideAs P (Stride 3) $ mapStencil Edge (maxStencil (Sz 3)) a
-- Array P Seq (Sz (3 :. 3))
--   [ [ 30, 33, 36 ]
--   , [ 57, 60, 63 ]
--   , [ 84, 87, 90 ]
--   ]
--
-- @since 0.4.3
maxStencil :: (Bounded e, Ord e, Index ix) => Sz ix -> Stencil ix e e
maxStencil = dimapStencil coerce getMax . foldStencil
{-# INLINE maxStencil #-}

-- | Create a stencil centered at 0 that will extract the minimum value in the region of
-- supplied size.
--
-- ==== __Example__
--
-- Here is a sample implementation of min pooling.
--
-- >>> import Data.Massiv.Array as A
-- >>> a <- computeAs P <$> resizeM (Sz2 9 9) (Ix1 10 ..: 91)
-- >>> a
-- Array P Seq (Sz (9 :. 9))
--   [ [ 10, 11, 12, 13, 14, 15, 16, 17, 18 ]
--   , [ 19, 20, 21, 22, 23, 24, 25, 26, 27 ]
--   , [ 28, 29, 30, 31, 32, 33, 34, 35, 36 ]
--   , [ 37, 38, 39, 40, 41, 42, 43, 44, 45 ]
--   , [ 46, 47, 48, 49, 50, 51, 52, 53, 54 ]
--   , [ 55, 56, 57, 58, 59, 60, 61, 62, 63 ]
--   , [ 64, 65, 66, 67, 68, 69, 70, 71, 72 ]
--   , [ 73, 74, 75, 76, 77, 78, 79, 80, 81 ]
--   , [ 82, 83, 84, 85, 86, 87, 88, 89, 90 ]
--   ]
-- >>> computeWithStrideAs P (Stride 3) $ mapStencil Edge (minStencil (Sz 3)) a
-- Array P Seq (Sz (3 :. 3))
--   [ [ 10, 13, 16 ]
--   , [ 37, 40, 43 ]
--   , [ 64, 67, 70 ]
--   ]
--
-- @since 0.4.3
minStencil :: (Bounded e, Ord e, Index ix) => Sz ix -> Stencil ix e e
minStencil = dimapStencil coerce getMin . foldStencil
{-# INLINE minStencil #-}

-- | Sum all elements in the stencil region
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 2 5) (* 2) (1 :: Int)
-- >>> a
-- Array P Seq (Sz (2 :. 5))
--   [ [ 2, 4, 8, 16, 32 ]
--   , [ 64, 128, 256, 512, 1024 ]
--   ]
-- >>> applyStencil noPadding (sumStencil (Sz2 1 2)) a
-- Array DW Seq (Sz (2 :. 4))
--   [ [ 6, 12, 24, 48 ]
--   , [ 192, 384, 768, 1536 ]
--   ]
-- >>> [2 + 4, 4 + 8, 8 + 16, 16 + 32] :: [Int]
-- [6,12,24,48]
--
-- @since 0.4.3
sumStencil :: (Num e, Index ix) => Sz ix -> Stencil ix e e
sumStencil = dimapStencil coerce getSum . foldStencil
{-# INLINE sumStencil #-}

-- | Multiply all elements in the stencil region
--
-- ==== __Examples__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 2 2) (+1) (0 :: Int)
-- >>> a
-- Array P Seq (Sz (2 :. 2))
--   [ [ 1, 2 ]
--   , [ 3, 4 ]
--   ]
-- >>> applyStencil (Padding 0 2 (Fill 0)) (productStencil 2) a
-- Array DW Seq (Sz (3 :. 3))
--   [ [ 24, 0, 0 ]
--   , [ 0, 0, 0 ]
--   , [ 0, 0, 0 ]
--   ]
-- >>> applyStencil (Padding 0 2 Reflect) (productStencil 2) a
-- Array DW Seq (Sz (3 :. 3))
--   [ [ 24, 64, 24 ]
--   , [ 144, 256, 144 ]
--   , [ 24, 64, 24 ]
--   ]
--
-- @since 0.4.3
productStencil :: (Num e, Index ix) => Sz ix -> Stencil ix e e
productStencil = dimapStencil coerce getProduct . foldStencil
{-# INLINE productStencil #-}

-- | Find the average value of all elements in the stencil region
--
-- ==== __Example__
--
-- >>> import Data.Massiv.Array as A
-- >>> a = computeAs P $ iterateN (Sz2 3 4) (+1) (10 :: Double)
-- >>> a
-- Array P Seq (Sz (3 :. 4))
--   [ [ 11.0, 12.0, 13.0, 14.0 ]
--   , [ 15.0, 16.0, 17.0, 18.0 ]
--   , [ 19.0, 20.0, 21.0, 22.0 ]
--   ]
-- >>> applyStencil noPadding (avgStencil (Sz2 2 3)) a
-- Array DW Seq (Sz (2 :. 2))
--   [ [ 14.0, 15.0 ]
--   , [ 18.0, 19.0 ]
--   ]
-- >>> Prelude.sum [11.0, 12.0, 13.0, 15.0, 16.0, 17.0] / 6 :: Double
-- 14.0
--
-- @since 0.4.3
avgStencil :: (Fractional e, Index ix) => Sz ix -> Stencil ix e e
avgStencil sz = sumStencil sz / fromIntegral (totalElem sz)
{-# INLINE avgStencil #-}