{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-|
Module      : Grenade.Layers.Crop
Description : Cropping layer
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Layers.Crop (
    Crop (..)
  ) where

import           Data.Maybe
import           Data.Proxy
import           Data.Singletons.TypeLits
import           GHC.TypeLits

import           Grenade.Core
import           Grenade.Layers.Internal.Pad

import           Numeric.LinearAlgebra (Matrix, konst, subMatrix, diagBlock)
import           Numeric.LinearAlgebra.Static (L, extract, create)

-- | A cropping layer for a neural network.
--
--   The four type-level naturals are the number of pixels removed from the
--   left, top, right and bottom edges respectively
--   (@Crop cropLeft cropTop cropRight cropBottom@).
data Crop :: Nat -> Nat -> Nat -> Nat -> * where
  Crop :: Crop cropLeft cropTop cropRight cropBottom

instance Show (Crop cropLeft cropTop cropRight cropBottom) where
  show Crop = "Crop"

-- | Cropping has no trainable parameters: its gradient is @()@ and an update
--   returns the layer unchanged.
instance UpdateLayer (Crop l t r b) where
  type Gradient (Crop l t r b) = ()
  runUpdate _ x _ = x
  createRandom = return Crop

-- | Reflect a type-level natural as a runtime number (of whatever numeric
--   type the call site needs).
natNum :: (KnownNat n, Num a) => Proxy n -> a
natNum = fromIntegral . natVal

-- | Wrap a dynamically sized matrix as a statically sized one.
--
--   The instance constraints should make a size mismatch impossible, so the
--   error branch only fires on an internal bug; unlike a bare 'fromJust' the
--   message names the method that produced the badly sized matrix.
toStatic :: (KnownNat m, KnownNat n) => String -> Matrix Double -> L m n
toStatic context =
  fromMaybe (error ("Grenade.Layers.Crop." ++ context ++ ": matrix has unexpected dimensions"))
    . create

-- | A two dimensional image can be cropped.
instance ( KnownNat cropLeft
         , KnownNat cropTop
         , KnownNat cropRight
         , KnownNat cropBottom
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , (inputRows - cropTop - cropBottom) ~ outputRows
         , (inputColumns - cropLeft - cropRight) ~ outputColumns
         ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
  type Tape (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) = ()

  -- Keep the outputRows x outputColumns window whose top-left corner sits at
  -- row cropTop, column cropLeft of the input.
  runForwards Crop (S2D input) =
    let cropl = natNum (Proxy :: Proxy cropLeft)
        cropt = natNum (Proxy :: Proxy cropTop)
        nrows = natNum (Proxy :: Proxy outputRows)
        ncols = natNum (Proxy :: Proxy outputColumns)
        r     = subMatrix (cropt, cropl) (nrows, ncols) (extract input)
    in  ((), S2D (toStatic "runForwards (D2)" r))

  -- The cropped-away border did not influence the output, so its gradient is
  -- zero: place dE/dy on the block diagonal between a cropTop x cropLeft zero
  -- block and a cropBottom x cropRight zero block, yielding an input-sized
  -- matrix with dE/dy at offset (cropTop, cropLeft).
  runBackwards _ _ (S2D dEdy) =
    let cropl = natNum (Proxy :: Proxy cropLeft)
        cropt = natNum (Proxy :: Proxy cropTop)
        cropr = natNum (Proxy :: Proxy cropRight)
        cropb = natNum (Proxy :: Proxy cropBottom)
        vs    = diagBlock [konst 0 (cropt, cropl), extract dEdy, konst 0 (cropb, cropr)]
    in  ((), S2D (toStatic "runBackwards (D2)" vs))

-- | A three dimensional image (a stack of channels) can be cropped. The
--   channel count is unchanged; rows and columns shrink by the crop amounts,
--   as required by the instance constraints.
instance ( KnownNat cropLeft
         , KnownNat cropTop
         , KnownNat cropRight
         , KnownNat cropBottom
         , KnownNat inputRows
         , KnownNat inputColumns
         , KnownNat outputRows
         , KnownNat outputColumns
         , KnownNat channels
         , KnownNat (inputRows * channels)
         , KnownNat (outputRows * channels)
         , (outputRows + cropTop + cropBottom) ~ inputRows
         , (outputColumns + cropLeft + cropRight) ~ inputColumns
         ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
  type Tape (Crop cropLeft cropTop cropRight cropBottom) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) = ()

  -- Delegates to 'crop' from Grenade.Layers.Internal.Pad, passing the
  -- (smaller) output dimensions before the (larger) input dimensions.
  runForwards Crop (S3D input) =
    let cropl   = natNum (Proxy :: Proxy cropLeft)
        cropt   = natNum (Proxy :: Proxy cropTop)
        cropr   = natNum (Proxy :: Proxy cropRight)
        cropb   = natNum (Proxy :: Proxy cropBottom)
        inr     = natNum (Proxy :: Proxy inputRows)
        inc     = natNum (Proxy :: Proxy inputColumns)
        outr    = natNum (Proxy :: Proxy outputRows)
        outc    = natNum (Proxy :: Proxy outputColumns)
        ch      = natNum (Proxy :: Proxy channels)
        cropped = crop ch cropl cropt cropr cropb outr outc inr inc (extract input)
    in  ((), S3D (toStatic "runForwards (D3)" cropped))

  -- Delegates to 'pad' (same argument order as the forward 'crop' call) to
  -- grow the output-sized gradient back to the input size; presumably 'pad'
  -- zero-fills the cropped border -- confirm in Grenade.Layers.Internal.Pad.
  runBackwards Crop () (S3D gradient) =
    let cropl  = natNum (Proxy :: Proxy cropLeft)
        cropt  = natNum (Proxy :: Proxy cropTop)
        cropr  = natNum (Proxy :: Proxy cropRight)
        cropb  = natNum (Proxy :: Proxy cropBottom)
        inr    = natNum (Proxy :: Proxy inputRows)
        inc    = natNum (Proxy :: Proxy inputColumns)
        outr   = natNum (Proxy :: Proxy outputRows)
        outc   = natNum (Proxy :: Proxy outputColumns)
        ch     = natNum (Proxy :: Proxy channels)
        padded = pad ch cropl cropt cropr cropb outr outc inr inc (extract gradient)
    in  ((), S3D (toStatic "runBackwards (D3)" padded))