{-# LANGUAGE BangPatterns, PackageImports #-}
{-# OPTIONS -Wall -fno-warn-missing-signatures -fno-warn-incomplete-patterns #-}

-- | Generic stencil based convolutions. 
-- 
--   If your stencil fits within a 7x7 tile and is known at compile-time then using
--   the built-in stencil support provided by the main Repa package will be
--   5-10x faster. 
-- 
--   If you have a larger stencil, the coefficients are not statically known,
--   or you need more complex boundary handling than the built-in functions
--   provide, then use this version instead.
--
module Data.Array.Repa.Algorithms.Convolve
        ( -- * Arbitrary boundary handling
          convolveP

          -- * Specialised boundary handling
        , GetOut
        , outAs
        , outClamp
        , convolveOutP )
where
import Data.Array.Repa                                  as R
import Data.Array.Repa.Unsafe                           as R
import Data.Array.Repa.Repr.Unboxed                     as R
import qualified Data.Vector.Unboxed                    as V
import qualified Data.Array.Repa.Shape                  as S
import Prelude                                          as P


-- Plain Convolve -------------------------------------------------------------
-- | Image-kernel convolution,
--   which takes a function specifying what value to return when the
--   kernel doesn't apply.
--
--   For an output pixel, the kernel is laid over a kernel-sized window of
--   the image. The window's top-left corner is offset from the pixel by half
--   the kernel's height and width, rounded down. Each kernel coefficient is
--   multiplied by the image element it covers, and the products are summed
--   in row-major order. The kernel is not flipped.
--
--   If the window would extend past the image edge, the border function is
--   called with the pixel's index instead.
convolveP
        :: (Num a, Unbox a, Monad m)
        => (DIM2 -> a)          -- ^ Function to get border elements when
                                --   the stencil does not apply.
        -> Array U DIM2 a       -- ^ Stencil to use in the convolution.
        -> Array U DIM2 a       -- ^ Input image.
        -> m (Array U DIM2 a)

convolveP makeOut kernel image
 = kernel `deepSeqArray` image `deepSeqArray`
   computeP (unsafeTraverse image id pixelAt)
 where
        (Z :. kRows :. kCols)           = extent kernel
        imgSh@(Z :. iRows :. iCols)     = extent image
        kVec                            = toUnboxed kernel
        iVec                            = toUnboxed image

        -- Offset from a window's top-left corner to the pixel it is
        -- centred on (rounded down for even kernel sizes).
        !halfRows                       = kRows `div` 2
        !halfCols                       = kCols `div` 2

        -- Output coordinates inside these inclusive ranges have a window
        -- that lies entirely within the image.
        !firstCol                       = halfCols
        !lastCol                        = iCols - halfCols - 1
        !firstRow                       = halfRows
        !lastRow                        = iRows - halfRows - 1

        -- Use the kernel where the window fits, the border function elsewhere.
        {-# INLINE pixelAt #-}
        pixelAt _ ix@(_ :. row :. col)
         | fits                         = window row col
         | otherwise                    = makeOut ix
         where  fits    =  col >= firstCol && col <= lastCol
                        && row >= firstRow && row <= lastRow

        -- Weighted sum of the window whose centre is at (row, col).
        {-# INLINE window #-}
        window row col
         = sumFrom 0 0 0 (S.toIndex imgSh (Z :. row - halfRows :. col - halfCols)) 0

        -- Walk the kernel in row-major order. imgOff and krnOff are flat
        -- offsets into the image and kernel vectors. At the end of each
        -- kernel row, imgOff jumps to the window's first column on the next
        -- image row.
        {-# INLINE sumFrom #-}
        sumFrom !acc !kx !ky !imgOff !krnOff
         | ky >= kRows                  = acc
         | kx <  kCols                  = sumFrom (acc + (iVec `V.unsafeIndex` imgOff)
                                                       * (kVec `V.unsafeIndex` krnOff))
                                                  (kx + 1) ky (imgOff + 1) (krnOff + 1)
         | otherwise                    = sumFrom acc 0 (ky + 1) (imgOff + iCols - kCols) krnOff
{-# INLINE convolveP #-}


-- Convolve Out -----------------------------------------------------------------------------------
-- | A function that gets out of range elements from an image.
--
--   It receives the image's own get function. This lets an implementation
--   redirect the request to an in-range element (as 'outClamp' does), or
--   ignore the request and return a fixed value (as 'outAs' does).
type GetOut a
        = (DIM2 -> a)   -- ^ The original get function.
        -> DIM2         -- ^ The shape of the image.
        -> DIM2         -- ^ Index of element we were trying to get.
        -> a            -- ^ Value to use in place of the missing element.


-- | Use the provided value for every out-of-range element.
--
--   The get function, the image shape and the requested index are all
--   ignored; the same constant is returned for every call.
outAs :: a -> GetOut a
{-# INLINE outAs #-}
outAs value _get _shape _ix = value


-- | If the requested element is out of range use
--   the closest one from the real image.
--
--   Each coordinate is clamped on its own to @[0, len - 1]@ of its image
--   dimension. The original get function is then called with the clamped
--   index.
outClamp :: GetOut a
{-# INLINE outClamp #-}
outClamp get (_ :. yLen :. xLen) (sh :. j :. i)
 = get (sh :. clampTo yLen j :. clampTo xLen i)
 where  -- Pin a coordinate into the valid range [0, len - 1].
        {-# INLINE clampTo #-}
        clampTo :: Int -> Int -> Int
        clampTo !len !v
          | v < 0       = 0
          | v >= len    = len - 1
          | otherwise   = v


-- | Image-kernel convolution,
--   which takes a function specifying what value to use for out-of-range elements.
--
--   For every output pixel, the kernel is placed so that the element at half
--   its height and width (rounded down) sits on that pixel. Each kernel
--   coefficient is multiplied by the image element under it, and the products
--   are summed. Kernel positions that fall outside the image are read through
--   the supplied 'GetOut' function. Every position inside the image, including
--   the outermost rows and columns, is read directly from the image.
convolveOutP
        :: (Num a, Unbox a, Monad m)
        => GetOut a             -- ^ How to handle out-of-range elements.
        -> Array U DIM2 a       -- ^ Stencil to use in the convolution.
        -> Array U DIM2 a       -- ^ Input image.
        -> m (Array U DIM2 a)

convolveOutP getOut kernel image
 = kernel `deepSeqArray` image `deepSeqArray`
   computeP $ unsafeTraverse image id stencil
 where
        krnSh@(Z :. krnHeight :. krnWidth)  = extent kernel
        imgSh@(Z :. imgHeight :. imgWidth)  = extent image

        !krnHeight2     = krnHeight `div` 2
        !krnWidth2      = krnWidth  `div` 2
        !krnSize        = S.size krnSh

        -- True when an index lies inside the image.
        -- Only indices that fail this test are handed to getOut. The earlier
        -- version tested against the kernel-centre interior bounds, which made
        -- real pixels within half a kernel of the edge go through getOut.
        -- With outAs, those real pixels were replaced by the constant.
        {-# INLINE inImage #-}
        inImage (_ :. y :. x)
                =  x >= 0 && x < imgWidth
                && y >= 0 && y < imgHeight

        -- The actual stencil function.
        {-# INLINE stencil #-}
        stencil get (_ :. j :. i)
         = let
                -- Read an image element, falling back to getOut off-image.
                {-# INLINE get' #-}
                get' ix
                 | inImage ix           = get ix
                 | otherwise            = getOut get imgSh ix

                -- Image coordinates of the kernel's top-left element.
                !ikrnWidth'     = i - krnWidth2
                !jkrnHeight'    = j - krnHeight2

                -- Accumulate over every kernel element by flat index.
                {-# INLINE integrate #-}
                integrate !count !acc
                 | count == krnSize             = acc
                 | otherwise
                 = let  !ix@(sh :. y :. x)      = S.fromIndex krnSh count
                        !ix'                    = sh :. y + jkrnHeight' :. x + ikrnWidth'
                        !here                   = (kernel `unsafeIndex` ix) * get' ix'
                   in   integrate (count + 1) (acc + here)

           in   integrate 0 0
{-# INLINE convolveOutP #-}