{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
{-|
Module      : Grenade.Layers.Pooling
Description : Max Pooling layer for 2D and 3D images
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Layers.Pooling (
    Pooling (..)
  ) where

import           Data.Maybe
import           Data.Proxy
import           Data.Serialize
import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Grenade.Core
import           Grenade.Layers.Internal.Pooling

import           Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)

-- | A pooling layer for a neural network.
--
--   Performs max pooling: a kernel is moved over the input in the same way
--   as in a convolution layer, but only the maximum under the kernel is kept.
--   This layer is often used to provide minor amounts of translational
--   invariance.
--
--   The kernel size and stride dictate which input and output sizes "fit";
--   they must satisfy @out = (in - kernel) / stride + 1@ in both dimensions.
--
--   The four type parameters are, in order: kernel rows, kernel columns,
--   stride rows and stride columns.
data Pooling :: Nat -> Nat -> Nat -> Nat -> * where
  Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns

-- | The layer carries no data, so it is always rendered as its constructor name.
instance Show (Pooling kernelRows kernelColumns strideRows strideColumns) where
  show Pooling = "Pooling"

-- | Pooling has no trainable parameters: its gradient is unit and an update
--   hands back the same (only) value.
instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
  type Gradient (Pooling kernelRows kernelColumns strideRows strideColumns) = ()
  runUpdate _ Pooling _ = Pooling
  createRandom          = pure Pooling

-- | Nothing is written when serialising and nothing is read when deserialising.
instance Serialize (Pooling kernelRows kernelColumns strideRows strideColumns) where
  put _ = pure ()
  get   = pure Pooling

-- | A two dimensional image can be pooled.
instance ( KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  -- The tape is the layer's own input; the backward pass reads it again.
  type Tape (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) = S ('D2 inputRows inputColumns)

  -- Forward pass. Returns the untouched input as the tape, paired with the
  -- result of 'poolForward' run with a channel count of 1.
  runForwards Pooling (S2D image) = (S2D image, S2D pooled)
    where
      -- NOTE(review): 'fromJust' assumes 'poolForward' yields a matrix of
      -- exactly outputRows x outputColumns; confirm against the helper.
      pooled   = fromJust (create rawOut)
      rawOut   = poolForward 1 inRows inCols kerRows kerCols strRows strCols (extract image)

      inRows   = natInt (Proxy :: Proxy inputRows)
      inCols   = natInt (Proxy :: Proxy inputColumns)
      kerRows  = natInt (Proxy :: Proxy kernelRows)
      kerCols  = natInt (Proxy :: Proxy kernelColumns)
      strRows  = natInt (Proxy :: Proxy strideRows)
      strCols  = natInt (Proxy :: Proxy strideColumns)

      -- Bring a type-level natural down to a value-level Int.
      natInt :: KnownNat n => Proxy n -> Int
      natInt = fromIntegral . natVal

  -- Backward pass. The layer has no parameters, so its own gradient is ();
  -- the gradient for the input comes from 'poolBackward', which is given the
  -- taped input and the incoming gradient.
  runBackwards Pooling (S2D image) (S2D dEdy) = ((), S2D inputGrad)
    where
      -- NOTE(review): 'fromJust' assumes 'poolBackward' yields a matrix of
      -- exactly inputRows x inputColumns; confirm against the helper.
      inputGrad = fromJust (create rawGrad)
      rawGrad   = poolBackward 1 inRows inCols kerRows kerCols strRows strCols (extract image) (extract dEdy)

      inRows   = natInt (Proxy :: Proxy inputRows)
      inCols   = natInt (Proxy :: Proxy inputColumns)
      kerRows  = natInt (Proxy :: Proxy kernelRows)
      kerCols  = natInt (Proxy :: Proxy kernelColumns)
      strRows  = natInt (Proxy :: Proxy strideRows)
      strCols  = natInt (Proxy :: Proxy strideColumns)

      -- Bring a type-level natural down to a value-level Int.
      natInt :: KnownNat n => Proxy n -> Int
      natInt = fromIntegral . natVal

-- | A three dimensional image can be pooled on each layer.
instance ( KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , KnownNat channels
         , KnownNat (outputRows * channels)
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
  -- The tape is the layer's own input; the backward pass reads it again.
  type Tape (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = S ('D3 inputRows inputColumns channels)

  -- Forward pass. Returns the untouched input as the tape, paired with the
  -- result of 'poolForward' run with the type-level channel count.
  runForwards Pooling (S3D volume) = (S3D volume, S3D pooled)
    where
      -- NOTE(review): 'fromJust' assumes 'poolForward' yields a matrix of
      -- exactly (outputRows * channels) x outputColumns; confirm against the
      -- helper.
      pooled    = fromJust (create rawOut)
      rawOut    = poolForward nChannels inRows inCols kerRows kerCols strRows strCols (extract volume)

      inRows    = natInt (Proxy :: Proxy inputRows)
      inCols    = natInt (Proxy :: Proxy inputColumns)
      kerRows   = natInt (Proxy :: Proxy kernelRows)
      kerCols   = natInt (Proxy :: Proxy kernelColumns)
      strRows   = natInt (Proxy :: Proxy strideRows)
      strCols   = natInt (Proxy :: Proxy strideColumns)
      nChannels = natInt (Proxy :: Proxy channels)

      -- Bring a type-level natural down to a value-level Int.
      natInt :: KnownNat n => Proxy n -> Int
      natInt = fromIntegral . natVal

  -- Backward pass. The layer has no parameters, so its own gradient is ();
  -- the gradient for the input comes from 'poolBackward', which is given the
  -- taped input and the incoming gradient.
  runBackwards Pooling (S3D volume) (S3D dEdy) = ((), S3D inputGrad)
    where
      -- NOTE(review): 'fromJust' assumes 'poolBackward' yields a matrix of
      -- exactly (inputRows * channels) x inputColumns; confirm against the
      -- helper.
      inputGrad = fromJust (create rawGrad)
      rawGrad   = poolBackward nChannels inRows inCols kerRows kerCols strRows strCols (extract volume) (extract dEdy)

      inRows    = natInt (Proxy :: Proxy inputRows)
      inCols    = natInt (Proxy :: Proxy inputColumns)
      kerRows   = natInt (Proxy :: Proxy kernelRows)
      kerCols   = natInt (Proxy :: Proxy kernelColumns)
      strRows   = natInt (Proxy :: Proxy strideRows)
      strCols   = natInt (Proxy :: Proxy strideColumns)
      nChannels = natInt (Proxy :: Proxy channels)

      -- Bring a type-level natural down to a value-level Int.
      natInt :: KnownNat n => Proxy n -> Int
      natInt = fromIntegral . natVal