module Grenade.Layers.Elu (
Elu (..)
) where
import Data.Serialize
import GHC.TypeLits
import Grenade.Core
import qualified Numeric.LinearAlgebra.Static as LAS
-- | Exponential linear unit layer.
--
--   The type has a single nullary constructor, so a value of it carries
--   no data of its own.
data Elu
  = Elu
  deriving (Show)
-- | 'Elu' has nothing to learn: its gradient type is unit, an update
--   ignores all of its arguments, and creation always yields 'Elu'.
instance UpdateLayer Elu where
  type Gradient Elu = ()

  -- No random values are drawn; the only possible value is returned.
  createRandom = pure Elu

  -- All three arguments are ignored and the layer is returned as-is.
  runUpdate _ _ _ = Elu
-- | Serialisation of 'Elu' writes nothing and reads nothing.
instance Serialize Elu where
  -- The argument is never inspected and no output is produced.
  put _ = pure ()

  -- Decoding consumes no input and always yields 'Elu'.
  get = pure Elu
-- | ELU over one-dimensional shapes, applied to every entry independently.
--
--   The tape is the unmodified input, since the backward pass evaluates
--   the derivative at the original input values.
instance (KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
  type Tape Elu ('D1 i) ('D1 i) = S ('D1 i)

  -- Forward pass: returns the input as the tape together with the
  -- activated output. Non-positive entries map to @exp a - 1@; positive
  -- entries pass through unchanged.
  runForwards _ (S1D y) = (S1D y, S1D (elu y))
    where
      elu = LAS.dvmap (\a -> if a <= 0 then exp a - 1 else a)

  -- Backward pass: there is no layer gradient (unit); the incoming
  -- gradient is multiplied by the ELU derivative at the taped input,
  -- which is @exp a@ for non-positive entries and 1 otherwise.
  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy))
    where
      elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)
-- | ELU over two-dimensional shapes, applied to every entry independently.
--
--   The tape is the unmodified input, since the backward pass evaluates
--   the derivative at the original input values.
instance (KnownNat i, KnownNat j) => Layer Elu ('D2 i j) ('D2 i j) where
  type Tape Elu ('D2 i j) ('D2 i j) = S ('D2 i j)

  -- Forward pass: returns the input as the tape together with the
  -- activated output. Non-positive entries map to @exp a - 1@; positive
  -- entries pass through unchanged.
  runForwards _ (S2D y) = (S2D y, S2D (elu y))
    where
      elu = LAS.dmmap (\a -> if a <= 0 then exp a - 1 else a)

  -- Backward pass: there is no layer gradient (unit); the incoming
  -- gradient is multiplied by the ELU derivative at the taped input,
  -- which is @exp a@ for non-positive entries and 1 otherwise.
  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (elu' y * dEdy))
    where
      elu' = LAS.dmmap (\a -> if a <= 0 then exp a else 1)
-- | ELU over three-dimensional shapes, applied to every entry independently.
--
--   The tape is the unmodified input, since the backward pass evaluates
--   the derivative at the original input values.
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Elu ('D3 i j k) ('D3 i j k) where
  type Tape Elu ('D3 i j k) ('D3 i j k) = S ('D3 i j k)

  -- Forward pass: returns the input as the tape together with the
  -- activated output. Non-positive entries map to @exp a - 1@; positive
  -- entries pass through unchanged. The value wrapped by 'S3D' is mapped
  -- with 'LAS.dmmap', the same mapper used for the 'S2D' case.
  runForwards _ (S3D y) = (S3D y, S3D (elu y))
    where
      elu = LAS.dmmap (\a -> if a <= 0 then exp a - 1 else a)

  -- Backward pass: there is no layer gradient (unit); the incoming
  -- gradient is multiplied by the ELU derivative at the taped input,
  -- which is @exp a@ for non-positive entries and 1 otherwise.
  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (elu' y * dEdy))
    where
      elu' = LAS.dmmap (\a -> if a <= 0 then exp a else 1)