{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE NoStarIsType #-}
module Data.Array.Shaped.Convolve(convolve) where
import Data.Array.Shaped
import Data.Array.Shaped.MatMul
import Data.Array.Internal.Shape
import GHC.TypeLits
import qualified Numeric.LinearAlgebra as N

-- | Convolve the /n/ outer dimensions with the given kernel.
-- There is neither padding nor striding.
-- The input has shape /spatialSh/ ++ /channelSh/,
-- the kernel has shape /spatialKernelSh/ ++ /channelSh/ ++ /featureSh/,
-- and the result has shape /spatialOutSh/ ++ /featureSh/.
-- The /n/ gives the rank of the /spatialSh/.  It does not occur in the
-- argument or result types, so it has to be supplied with a type
-- application, as in the example below.
--
-- The computation is a single matrix product: the windowed input is
-- flattened to an /isp/ x /iwc/ matrix, the kernel to a /ksc/ x /ksf/
-- matrix (with /iwc/ ~ /ksc/), and the /isp/ x /ksf/ product is reshaped
-- to the output shape.
--
-- Example:
-- @
--  i :: Array [20,30,3] T  -- 20x30 image with 3 channels
--  k :: Array [5,5,3,8] T  -- 5x5 kernel with 8 output features
--  convolve @2 i k :: Array [16,26,8] T
-- @
convolve :: forall (n :: Nat) ish ksh osh wsh a ksc ksf i ws isp iwc .
            ( i ~ Rank ish                       -- input rank
            , ws ~ Take n ksh                    -- window size
            , Window ws ish wsh, KnownNat (Rank ws)
            , ksc ~ Size (Take i ksh)            -- spatial + channels
            , ksf ~ Size (Drop i ksh)            -- features
            , isp ~ Size (Take n wsh)            -- spatial
            , iwc ~ Size (Drop n wsh)            -- kernel + channels
            , iwc ~ ksc
            , osh ~ (Take n wsh ++ Drop i ksh)
            , Size wsh ~ (isp * iwc)
            , Size ksh ~ (ksc * ksf)
            , Size osh ~ (isp * ksf)
            , Shape wsh, Shape ksh, Shape osh
            , KnownNat ksc, KnownNat isp, KnownNat ksf
            , N.Numeric a
            ) =>
            Array ish a -> Array ksh a -> Array osh a
convolve i k =
  let -- Windowed view of the input; the window size is the first n
      -- dimensions of the kernel shape.
      iw :: Array wsh a
      iw = window @ws i
      -- Flatten the windowed input to a matrix: the first n dimensions
      -- of wsh become rows, the remaining ones become columns.
      ir :: Array [isp, iwc] a
      ir = reshape iw
      -- Flatten the kernel to a matrix: the first i dimensions become
      -- rows, the feature dimensions become columns.
      kr :: Array [ksc, ksf] a
      kr = reshape k
      -- The inner dimensions agree because of the iwc ~ ksc constraint.
      m  :: Array [isp, ksf] a
      m  = matMul ir kr
      -- Restore the spatial and feature dimensions.
      r  :: Array osh a
      r  = reshape m
  in  r

-- Compile-time check that the shape arithmetic of 'convolve' works out
-- for a 20x30 input with 3 channels and a 5x5 kernel with 8 features.
-- It is not exported; the leading underscore keeps GHC from warning
-- about an unused top-level binding.
_example :: Array [20,30,3] Float -> Array [5,5,3,8] Float -> Array [16,26,8] Float
_example = convolve @2