{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
-----------------------------------------------------------------------------
{- |
Module      :  Internal.Convolution
Copyright   :  (c) Alberto Ruiz 2012
License     :  BSD3
Maintainer  :  Alberto Ruiz
Stability   :  provisional

-}
-----------------------------------------------------------------------------
{-# OPTIONS_HADDOCK hide #-}

module Internal.Convolution(
   corr, conv, corrMin,
   corr2, conv2, separable
) where

import qualified Data.Vector.Storable as SV
import Internal.Vector
import Internal.Matrix
import Internal.Numeric
import Internal.Element
import Internal.Conversion
import Internal.Container
#if MIN_VERSION_base(4,11,0)
import Prelude hiding ((<>))
#endif


-- | Stack every contiguous length-@n@ window of @v@ as the rows of a matrix:
-- row @k@ is @subVector k n v@, for @k@ from 0 to @dim v - n@.
-- When @n > dim v@ the list of windows is empty.
vectSS :: Element t => Int -> Vector t -> Matrix t
vectSS n v = fromRows (map window starts)
  where
    window k = subVector k n v
    starts   = [0 .. dim v - n]


corr
  :: (Container Vector t, Product t)
    => Vector t -- ^ kernel
    -> Vector t -- ^ source
    -> Vector t
{- ^ correlation (without padding)

Element @k@ of the result is the dot product of the kernel with the window
of the source that starts at position @k@. The result therefore has
@dim source - dim kernel + 1@ elements.

An empty kernel yields a zero vector as long as the source. A kernel longer
than the source is an error.

>>> corr (fromList[1,2,3]) (fromList [1..10])
[14.0,20.0,26.0,32.0,38.0,44.0,50.0,56.0]
it :: (Enum t, Product t, Container Vector t) => Vector t

-}
corr ker v
    | nk == 0   = konst 0 nv
    | nk <= nv  = vectSS nk v <> ker
    | otherwise = error $ "corr: dim kernel (" ++ show nk ++ ") > dim vector (" ++ show nv ++ ")"
  where
    nk = dim ker
    nv = dim v


conv :: (Container Vector t, Product t, Num t) => Vector t -> Vector t -> Vector t
{- ^ convolution ('corr' with reversed kernel and padded input, equivalent to polynomial product)

For a non-empty kernel the result has @dim ker + dim v - 1@ elements. This
includes an empty source, which yields @dim ker - 1@ zeros. An empty kernel
yields a zero vector as long as the source.

>>> conv (fromList[1,1]) (fromList [-1,1])
[-1.0,0.0,1.0]
it :: (Product t, Container Vector t) => Vector t

-}
conv ker v
    | nk == 0    = konst 0 (dim v)
    -- Without this guard a 1-element kernel and an empty source would pad to
    -- an empty vector and make 'corr' fail. Longer kernels already yielded
    -- @nk - 1@ zeros here.
    | dim v == 0 = konst 0 (nk - 1)
    | otherwise  = corr (SV.reverse ker) padded
  where
    nk     = dim ker
    -- @nk - 1@ zeros on each side, so every partial overlap of the reversed
    -- kernel with @v@ contributes an output element (full convolution).
    z      = konst 0 (nk - 1)
    padded = vjoin [z, v, z]

corrMin :: (Container Vector t, RealElement t, Product t)
        => Vector t
        -> Vector t
        -> Vector t
-- ^ similar to 'corr', using 'min' instead of (*)
--
-- Each length-@dim ker@ window of @v@ is combined with the kernel by an
-- elementwise minimum. Each row is then summed by multiplying with a vector
-- of ones. An empty kernel is an error.
corrMin ker v
    | nk == 0   = error "corrMin: empty kernel"
    | otherwise = pointwiseMin windows (asRow ker) <> konst 1 nk
  where
    nk      = dim ker
    windows = vectSS nk v
    -- 'cond' a b lt eq gt: choosing @a@ on LT/EQ and @b@ on GT gives the minimum.
    -- NOTE(review): relies on 'cond' expanding the one-row kernel across all
    -- window rows; confirm against the Container Matrix instance.
    pointwiseMin a b = cond a b a a b



-- | Every contiguous band of @dr@ consecutive rows of @m@, as a list of
-- matrices with @cols m@ columns. Band @k@ starts at row @k@. Each band is
-- sliced out of @flatten m@ (rows laid end to end) and reshaped back.
matSS :: Element t => Int -> Matrix t -> [Matrix t]
matSS dr m = [ reshape c (subVector (k * c) len flat) | k <- [0 .. rows m - dr] ]
  where
    flat = flatten m
    c    = cols m
    len  = dr * c


{- | 2D correlation (without padding)

The result has @rows mat - rows ker + 1@ rows and @cols mat - cols ker + 1@
columns. It is an error if either count would not be positive.

>>> disp 5 $ corr2 (konst 1 (3,3)) (ident 10 :: Matrix Double)
8x8
3  2  1  0  0  0  0  0
2  3  2  1  0  0  0  0
1  2  3  2  1  0  0  0
0  1  2  3  2  1  0  0
0  0  1  2  3  2  1  0
0  0  0  1  2  3  2  1
0  0  0  0  1  2  3  2
0  0  0  0  0  1  2  3

-}
corr2 :: Product a => Matrix a -> Matrix a -> Matrix a
corr2 ker mat = build (concatMap outputRow (matSS r mat))
  where
    r = rows ker
    c = cols ker
    -- Kernel flattened from its transpose. This matches each window below,
    -- which is also cut from a transposed band before flattening.
    kerT = flatten (trans ker)
    -- One output row: slide the kernel horizontally across a band of r rows.
    outputRow band = [ udot kerT (flatten w) | w <- matSS c (trans band) ]
    outRows = rows mat - r + 1
    outCols = cols mat - c + 1
    build
      | outRows > 0 && outCols > 0 = (outRows >< outCols)
      | otherwise = error $ "corr2: dim kernel (" ++ sz ker ++ ") > dim matrix (" ++ sz mat ++ ")"
    sz x = show (rows x) ++ "x" ++ show (cols x)
-- TODO check empty kernel

{- | 2D convolution

'corr2' of the kernel rotated by 180 degrees with the input, zero-padded by
@rows k - 1@ rows and @cols k - 1@ columns on every side. If the kernel has
zero rows or zero columns, the result is a zero matrix of size
@(rows m + rows k - 1, cols m + cols k - 1)@.

>>> disp 5 $ conv2 (konst 1 (3,3)) (ident 10 :: Matrix Double)
12x12
1  1  1  0  0  0  0  0  0  0  0  0
1  2  2  1  0  0  0  0  0  0  0  0
1  2  3  2  1  0  0  0  0  0  0  0
0  1  2  3  2  1  0  0  0  0  0  0
0  0  1  2  3  2  1  0  0  0  0  0
0  0  0  1  2  3  2  1  0  0  0  0
0  0  0  0  1  2  3  2  1  0  0  0
0  0  0  0  0  1  2  3  2  1  0  0
0  0  0  0  0  0  1  2  3  2  1  0
0  0  0  0  0  0  0  1  2  3  2  1
0  0  0  0  0  0  0  0  1  2  2  1
0  0  0  0  0  0  0  0  0  1  1  1

-}
conv2
    :: (Num (Matrix a), Product a, Container Vector a)
    => Matrix a -- ^ kernel
    -> Matrix a -> Matrix a
conv2 k m
    | r == 0 || c == 0 = konst 0 (rows m + r - 1, cols m + c - 1)
    | otherwise        = corr2 rotated padded
  where
    r = rows k
    c = cols k
    -- Kernel flipped both up-down and left-right.
    rotated = fliprl (flipud k)
    -- Zero corners of size (r-1) x (c-1). The bare 0 blocks come from the
    -- Num (Matrix a) instance. NOTE(review): presumably 'fromBlocks'
    -- expands them to fit their row and column; confirm.
    corner = konst 0 (r - 1, c - 1)
    padded = fromBlocks [ [corner, 0, 0]
                        , [0, m, 0]
                        , [0, 0, corner] ]


separable :: Element t => (Vector t -> Vector t) -> Matrix t -> Matrix t
-- ^ matrix computation implemented as separated vector operations by rows and columns.
--
-- @f@ is applied first to every row of the input, then to every column of
-- that intermediate result.
separable f = overColumns . overRows
  where
    overRows    = fromRows . map f . toRows
    overColumns = fromColumns . map f . toColumns