module Grenade.Layers.Dropout (
Dropout (..)
, randomDropout
) where
import Control.Monad.Random hiding (fromList)
import GHC.TypeLits
import Grenade.Core
-- | Configuration for a dropout layer.
--
-- The layer carries a rate and an integer seed. In this module the
-- 'Layer' instance pattern-matches both fields away without reading them,
-- so at present they act purely as stored configuration.
data Dropout = Dropout
  { dropoutRate :: Double
    -- ^ Dropout rate. Whether this is the probability of dropping or of
    -- keeping a unit is not fixed by this module (TODO confirm against
    -- the intended semantics; 'createRandom' uses 0.95).
  , dropoutSeed :: Int
    -- ^ Integer seed; 'randomDropout' draws it from the random source.
  } deriving Show
-- | Dropout has no trainable parameters: its gradient is the unit type and
-- an update hands the layer back exactly as it was.
instance UpdateLayer Dropout where
  type Gradient Dropout = ()

  -- Learning parameters and the (unit) gradient carry no information for
  -- this layer, so both are ignored.
  runUpdate _learningParams layer _noGradient = layer

  -- A freshly created layer uses a fixed rate and a randomly drawn seed.
  createRandom = randomDropout defaultRate
    where
      defaultRate = 0.95
-- | Build a 'Dropout' layer with the supplied rate, drawing its seed from
-- the ambient 'MonadRandom' source.
randomDropout :: MonadRandom m
              => Double -> m Dropout
randomDropout rate = do
  seed <- getRandom
  return (Dropout rate seed)
-- | Dropout over a one-dimensional shape.
--
-- NOTE(review): as written both passes are the identity. No element is
-- zeroed or rescaled, and neither 'dropoutRate' nor 'dropoutSeed' is read.
-- The tape is unit because the backward pass needs nothing recorded from
-- the forward pass.
instance (KnownNat i) => Layer Dropout ('D1 i) ('D1 i) where
  type Tape Dropout ('D1 i) ('D1 i) = ()

  -- Forward: record nothing and pass the input through unchanged.
  runForwards (Dropout _ _) input@(S1D _) = ((), input)

  -- Backward: the layer has no parameters (unit gradient), and the
  -- incoming gradient is propagated to the input unchanged.
  runBackwards (Dropout _ _) _tape delta@(S1D _) = ((), delta)