{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE FlexibleContexts #-}
module Data.Massiv.Array.Stencil.Convolution
( makeConvolutionStencil
, makeConvolutionStencilFromKernel
, makeCorrelationStencil
, makeCorrelationStencilFromKernel
) where
import Data.Massiv.Array.Ops.Fold (ifoldlS)
import Data.Massiv.Array.Stencil.Internal
import Data.Massiv.Core.Common
import GHC.Exts (inline)
-- | Build a convolution stencil from its size, its center and a function that
-- drives the accumulation over the kernel's coefficients.
--
-- @relStencil@ is handed an accumulation step and the initial accumulator @0@.
-- Given a relative index @ixD@, a coefficient @kVal@ and an accumulator @acc@,
-- the step returns @getVal (ix - ixD) * kVal + acc@. The subtraction mirrors
-- the kernel, which is what distinguishes convolution from correlation.
-- @relStencil@ is presumably expected to fold the step over every kernel
-- element, with @ixD@ taken relative to @sCenter@ (caller's contract, not
-- checked here).
--
-- The resulting stencil is passed through 'validateStencil' (with @0@ as its
-- first argument) before being returned.
makeConvolutionStencil
  :: (Index ix, Num e)
  => Sz ix
  -> ix
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -> Stencil ix e e
makeConvolutionStencil !sz !sCenter relStencil =
  validateStencil 0 (Stencil sz flippedCenter convolve)
  where
    -- Reading at @ix - ixD@ mirrors the window, so the center stored in the
    -- stencil is the caller's center reflected across it: (sz - 1) - sCenter.
    !flippedCenter =
      liftIndex2 (-) (liftIndex (subtract 1) (unSz sz)) sCenter
    convolve getVal !ix = inline relStencil addTerm 0
      where
        addTerm !ixD !kVal !acc =
          getVal (liftIndex2 (-) ix ixD) * kVal + acc
        {-# INLINE addTerm #-}
    {-# INLINE convolve #-}
{-# INLINE makeConvolutionStencil #-}
-- | Build a convolution stencil directly from a kernel array.
--
-- The kernel's center is its size halved (with 'quot') along every dimension.
-- For an output position @ix@, the stencil folds over the kernel with
-- 'ifoldlS', starting from @0@. Each element @kVal@ at index @kIx@ contributes
-- @unValue (getVal (ix + center - kIx)) * kVal@. The kernel is therefore
-- applied mirrored, as convolution requires.
--
-- Unlike 'makeConvolutionStencil', the stencil built here is not passed
-- through 'validateStencil'.
makeConvolutionStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e
  -> Stencil ix e e
makeConvolutionStencilFromKernel kArr = Stencil kSize flippedCenter convolve
  where
    !kSize@(Sz kDims) = size kArr
    -- Geometric center of the kernel.
    !kCenter = liftIndex (`quot` 2) kDims
    -- Index of the kernel's last element along every dimension.
    !kLast = liftIndex (subtract 1) kDims
    -- The mirrored reads shift the window, so its center is (size - 1) - center.
    !flippedCenter = liftIndex2 (-) kLast kCenter
    convolve getVal !ix = Value (ifoldlS addTerm 0 kArr)
      where
        !shifted = liftIndex2 (+) ix kCenter
        addTerm !acc !kIx !kVal =
          unValue (getVal (liftIndex2 (-) shifted kIx)) * kVal + acc
        {-# INLINE addTerm #-}
    {-# INLINE convolve #-}
{-# INLINE makeConvolutionStencilFromKernel #-}
-- | Build a correlation stencil from its size, its center and a function that
-- drives the accumulation over the kernel's coefficients.
--
-- @relStencil@ is handed an accumulation step and the initial accumulator @0@.
-- Given a relative index @ixD@, a coefficient @kVal@ and an accumulator @acc@,
-- the step returns @getVal (ix + ixD) * kVal + acc@. The offset is not
-- mirrored, so @sCenter@ is stored as the stencil's center unchanged.
-- @relStencil@ is presumably expected to fold the step over every kernel
-- element, with @ixD@ taken relative to @sCenter@ (caller's contract, not
-- checked here).
--
-- The resulting stencil is passed through 'validateStencil' (with @0@ as its
-- first argument) before being returned.
makeCorrelationStencil
  :: (Index ix, Num e)
  => Sz ix
  -> ix
  -> ((ix -> Value e -> Value e -> Value e) -> Value e -> Value e)
  -> Stencil ix e e
makeCorrelationStencil !sSz !sCenter relStencil =
  validateStencil 0 (Stencil sSz sCenter correlate)
  where
    correlate getVal !ix = inline relStencil addTerm 0
      where
        addTerm !ixD !kVal !acc =
          getVal (liftIndex2 (+) ix ixD) * kVal + acc
        {-# INLINE addTerm #-}
    {-# INLINE correlate #-}
{-# INLINE makeCorrelationStencil #-}
-- | Build a correlation stencil directly from a kernel array.
--
-- The kernel's center is its size halved (with 'div') along every dimension,
-- and it is used as the stencil's center unchanged. For an output position
-- @ix@, the stencil folds over the kernel with 'ifoldlS', starting from @0@.
-- Each element @kVal@ at index @kIx@ contributes
-- @unValue (getVal (ix - center + kIx)) * kVal@. The kernel is laid over the
-- input without mirroring.
--
-- Unlike 'makeCorrelationStencil', the stencil built here is not passed
-- through 'validateStencil'.
makeCorrelationStencilFromKernel
  :: (Manifest r ix e, Num e)
  => Array r ix e
  -> Stencil ix e e
makeCorrelationStencilFromKernel kArr = Stencil kSize kCenter correlate
  where
    !kSize = size kArr
    -- Geometric center of the kernel.
    !kCenter = liftIndex (`div` 2) (unSz kSize)
    correlate getVal !ix = Value (ifoldlS addTerm 0 kArr)
      where
        -- Input position that lines up with the kernel's first element.
        !origin = liftIndex2 (-) ix kCenter
        addTerm !acc !kIx !kVal =
          unValue (getVal (liftIndex2 (+) origin kIx)) * kVal + acc
        {-# INLINE addTerm #-}
    {-# INLINE correlate #-}
{-# INLINE makeCorrelationStencilFromKernel #-}