module Grenade.Layers.Pooling (
Pooling (..)
) where
import Data.Maybe
import Data.Proxy
import Data.Serialize
import Data.Singletons.TypeLits
import GHC.TypeLits
import Grenade.Core
import Grenade.Layers.Internal.Pooling
import Numeric.LinearAlgebra.Static as LAS hiding ((|||), build, toRows)
-- | A pooling layer, indexed at the type level by its kernel size and stride
-- (rows first, then columns).
--
-- The layer has a single, field-less constructor: everything it needs is
-- carried in the four type-level naturals.
--
-- NOTE(review): presumably this is max pooling; the actual reduction lives in
-- "Grenade.Layers.Internal.Pooling" and is not visible from this module.
data Pooling (kernelRows :: Nat) (kernelColumns :: Nat) (strideRows :: Nat) (strideColumns :: Nat) where
  Pooling :: Pooling kernelRows kernelColumns strideRows strideColumns
-- | Every 'Pooling' value renders as the bare constructor name; the
-- type-level kernel and stride sizes are not part of the output.
instance Show (Pooling k k' s s') where
  show layer =
    case layer of
      Pooling -> "Pooling"
-- | 'Pooling' holds no data to train: its gradient is the unit type, an
-- update returns the layer as it was, and a "random" layer is just the
-- constructor.
instance UpdateLayer (Pooling kernelRows kernelColumns strideRows strideColumns) where
  type Gradient (Pooling kernelRows kernelColumns strideRows strideColumns) = ()

  -- The first and third arguments are ignored; the layer is returned as-is.
  runUpdate _ layer@Pooling _ = layer

  createRandom = pure Pooling
-- | Serialisation is trivial: nothing is written for a 'Pooling' layer and
-- reading one back consumes no input.
instance Serialize (Pooling kernelRows kernelColumns strideRows strideColumns) where
  put = const (pure ())
  get = pure Pooling
-- | Pooling from a two-dimensional shape to a two-dimensional shape.
--
-- The two equality constraints tie the output size to the input size, the
-- kernel size and the stride:
--
-- > (outputRows    - 1) * strideRows    ~ inputRows    - kernelRows
-- > (outputColumns - 1) * strideColumns ~ inputColumns - kernelColumns
--
-- so a network whose shapes do not line up is rejected by the type checker.
instance ( KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  -- The tape is the untouched input; 'runBackwards' passes it to
  -- 'poolBackward' together with the downstream gradient.
  type Tape (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) = S ('D2 inputRows inputColumns)

  -- Forward pass: returns the input as the tape and the pooled matrix as the
  -- output.
  runForwards Pooling (S2D input) =
    let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        width  = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx     = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky     = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx     = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy     = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        -- The leading 1 sits where the 'D3' instance passes its channel count.
        pooled = poolForward 1 height width kx ky sx sy (extract input)
        -- 'create' yields Nothing when the dynamic matrix does not have the
        -- statically expected size; report that explicitly rather than via a
        -- bare 'fromJust'.
        result = fromMaybe
                   (error "Grenade.Layers.Pooling.runForwards (D2): poolForward returned a matrix of unexpected size")
                   (create pooled)
    in  (S2D input, S2D result)

  -- Backward pass: the layer gradient is unit; the second component is the
  -- gradient with respect to the input, computed by 'poolBackward' from the
  -- taped input and the gradient of the output.
  runBackwards Pooling (S2D input) (S2D dEdy) =
    let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        width  = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx     = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky     = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx     = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy     = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        dEdx   = poolBackward 1 height width kx ky sx sy (extract input) (extract dEdy)
        result = fromMaybe
                   (error "Grenade.Layers.Pooling.runBackwards (D2): poolBackward returned a matrix of unexpected size")
                   (create dEdx)
    in  ((), S2D result)
-- | Pooling from a three-dimensional shape to a three-dimensional shape with
-- the same number of channels.
--
-- The two equality constraints tie the output size to the input size, the
-- kernel size and the stride:
--
-- > (outputRows    - 1) * strideRows    ~ inputRows    - kernelRows
-- > (outputColumns - 1) * strideColumns ~ inputColumns - kernelColumns
--
-- so a network whose shapes do not line up is rejected by the type checker.
instance ( KnownNat kernelRows
         , KnownNat kernelColumns
         , KnownNat strideRows
         , KnownNat strideColumns
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , KnownNat channels
         , KnownNat (outputRows * channels)
         , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
         , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
         ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
  -- The tape is the untouched input; 'runBackwards' passes it to
  -- 'poolBackward' together with the downstream gradient.
  type Tape (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = S ('D3 inputRows inputColumns channels)

  -- Forward pass: returns the input as the tape and the pooled matrix as the
  -- output.
  runForwards Pooling (S3D input) =
    let ix     = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy     = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx     = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky     = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx     = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy     = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        ch     = fromIntegral $ natVal (Proxy :: Proxy channels)
        pooled = poolForward ch ix iy kx ky sx sy (extract input)
        -- 'create' yields Nothing when the dynamic matrix does not have the
        -- statically expected size; report that explicitly rather than via a
        -- bare 'fromJust'.
        result = fromMaybe
                   (error "Grenade.Layers.Pooling.runForwards (D3): poolForward returned a matrix of unexpected size")
                   (create pooled)
    in  (S3D input, S3D result)

  -- Backward pass: the layer gradient is unit; the second component is the
  -- gradient with respect to the input, computed by 'poolBackward' from the
  -- taped input and the gradient of the output.
  runBackwards Pooling (S3D input) (S3D dEdy) =
    let ix     = fromIntegral $ natVal (Proxy :: Proxy inputRows)
        iy     = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
        kx     = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
        ky     = fromIntegral $ natVal (Proxy :: Proxy kernelColumns)
        sx     = fromIntegral $ natVal (Proxy :: Proxy strideRows)
        sy     = fromIntegral $ natVal (Proxy :: Proxy strideColumns)
        ch     = fromIntegral $ natVal (Proxy :: Proxy channels)
        dEdx   = poolBackward ch ix iy kx ky sx sy (extract input) (extract dEdy)
        result = fromMaybe
                   (error "Grenade.Layers.Pooling.runBackwards (D3): poolBackward returned a matrix of unexpected size")
                   (create dEdx)
    in  ((), S3D result)