{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# OPTIONS_HADDOCK hide #-}
module Internal.Convolution(
corr, conv, corrMin,
corr2, conv2, separable
) where
import qualified Data.Vector.Storable as SV
import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import Internal.Element
import Internal.Conversion
import Internal.Container
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ((<>))
#endif
-- | Build a matrix whose rows are the successive length-@n@ slices of @v@,
-- starting at offsets @0, 1, .. , dim v - n@.
--
-- When @n@ exceeds @dim v@ the offset list is empty, so 'fromRows' is handed
-- an empty list.
vectSS :: Element t => Int -> Vector t -> Matrix t
vectSS n v = fromRows (map window [0 .. lastStart])
  where
    -- Slice of length n beginning at the given offset.
    window start = subVector start n v
    -- Largest offset at which a full slice still fits.
    lastStart = dim v - n
-- | Correlation of a kernel with a vector.
--
-- Every length-@dim ker@ slice of the source is stacked as a matrix row
-- (see 'vectSS') and that matrix is multiplied by the kernel, giving one
-- output element per slice (@dim v - dim ker + 1@ elements).
--
-- An empty kernel yields a zero vector with the same length as the source.
-- A kernel longer than the source is rejected with 'error'.
corr
  :: (Container Vector t, Product t)
  => Vector t  -- ^ kernel
  -> Vector t  -- ^ source
  -> Vector t
corr ker v
  | kerLen == 0 = konst 0 srcLen
  | kerLen > srcLen =
      error $ "corr: dim kernel (" ++ show kerLen ++ ") > dim vector (" ++ show srcLen ++ ")"
  | otherwise = vectSS kerLen v <> ker
  where
    kerLen = dim ker
    srcLen = dim v
-- | Convolution, expressed as 'corr' with the kernel reversed and the input
-- extended by @dim ker - 1@ zeros at both ends.
--
-- An empty kernel yields a zero vector with the same length as the input.
conv :: (Container Vector t, Product t, Num t) => Vector t -> Vector t -> Vector t
conv ker v =
  if kerLen == 0
    then konst 0 (dim v)
    else corr flipped padded
  where
    kerLen = dim ker
    flipped = SV.reverse ker
    -- Zero margin placed before and after the input.
    margin = konst 0 (kerLen - 1)
    padded = vjoin [margin, v, margin]
-- | Variant of 'corr' in which each slice of the source is combined with the
-- kernel through @cond a b a a b@ (selecting, element by element, the smaller
-- of the two operands) instead of a product. The selected values of each
-- slice are then summed by multiplying with a vector of ones.
--
-- Calls 'error' when the kernel is empty, and also when the kernel is longer
-- than the source. The second check mirrors 'corr': without it no slices are
-- produced and the failure surfaces later as an unrelated shape error.
corrMin
  :: (Container Vector t, RealElement t, Product t)
  => Vector t  -- ^ kernel
  -> Vector t  -- ^ source
  -> Vector t
corrMin ker v
  | kerLen == 0 = error "corrMin: empty kernel"
  | kerLen > srcLen =
      error $ "corrMin: dim kernel (" ++ show kerLen ++ ") > dim vector (" ++ show srcLen ++ ")"
  | otherwise = minEvery windows (asRow ker) <> ones
  where
    kerLen = dim ker
    srcLen = dim v
    -- cond picks its 3rd/4th/5th argument when a < b / a == b / a > b.
    minEvery a b = cond a b a a b
    -- One row per length-kerLen slice of the source.
    windows = vectSS kerLen v
    -- Multiplying by ones sums each row.
    ones = konst 1 kerLen
-- | List every band of @dr@ consecutive rows of @m@, from the band starting
-- at row 0 up to the one starting at row @rows m - dr@.
--
-- Each band is cut out of the flattened matrix as a run of @dr * cols m@
-- elements and reshaped back to @cols m@ columns.
matSS :: Element t => Int -> Matrix t -> [Matrix t]
matSS dr m =
  [ reshape width (subVector (top * width) bandLen flat)
  | top <- [0 .. rows m - dr]
  ]
  where
    flat = flatten m
    width = cols m
    -- Number of elements in one band of dr rows.
    bandLen = dr * width
-- | 2D correlation without padding.
--
-- The result has @rows mat - rows ker + 1@ rows and
-- @cols mat - cols ker + 1@ columns; if either of those is not positive the
-- function calls 'error'.
corr2 :: Product a => Matrix a -> Matrix a -> Matrix a
corr2 ker mat
  | outRows > 0 && outCols > 0 = (outRows >< outCols) responses
  | otherwise =
      error $ "corr2: dim kernel (" ++ sz ker ++ ") > dim matrix (" ++ sz mat ++ ")"
  where
    kerRows = rows ker
    kerCols = cols ker
    outRows = rows mat - kerRows + 1
    outCols = cols mat - kerCols + 1
    -- Patches are cut from transposed row bands, so they come out
    -- transposed; the kernel is transposed to match before flattening.
    kerFlat = flatten (trans ker)
    -- One dot product per kernel position, enumerated band by band.
    responses =
      [ udot kerFlat (flatten patch)
      | band <- matSS kerRows mat
      , patch <- matSS kerCols (trans band)
      ]
    sz m = show (rows m) ++ "x" ++ show (cols m)
-- | 2D convolution, expressed as 'corr2' with the kernel flipped in both
-- directions and the input surrounded by zeros.
--
-- A kernel with zero rows or zero columns yields a zero matrix of size
-- @(rows m + rows k - 1, cols m + cols k - 1)@.
conv2
  :: (Num (Matrix a), Product a, Container Vector a)
  => Matrix a  -- ^ kernel
  -> Matrix a  -- ^ input matrix
  -> Matrix a
conv2 k m
  | kRows == 0 || kCols == 0 = konst 0 (rows m + kRows - 1, cols m + kCols - 1)
  | otherwise = corr2 flipped framed
  where
    kRows = rows k
    kCols = cols k
    flipped = fliprl (flipud k)
    -- Zero block that fixes the border size; the literal 0 blocks below
    -- are left for fromBlocks to size.
    corner = konst 0 (kRows - 1, kCols - 1)
    framed =
      fromBlocks
        [ [corner, 0, 0]
        , [0, m, 0]
        , [0, 0, corner]
        ]
-- | Apply a vector operation separably: first to every row of the matrix,
-- then to every column of the row-processed result.
separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t
separable f m = fromColumns (map f (toColumns rowPass))
  where
    -- Matrix rebuilt after running f over each row.
    rowPass = fromRows (map f (toRows m))