{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-|
Module      : Grenade.Layers.Elu
Description : Exponential linear unit layer
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Layers.Elu (
    Elu (..)
  ) where

import           Data.Serialize

import           GHC.TypeLits
import           Grenade.Core
import           Numeric                      (expm1)

import qualified Numeric.LinearAlgebra.Static as LAS

-- | Exponential linear unit (ELU) activation layer.
--
--   Maps every element @a@ to @a@ when @a > 0@ and to @exp a - 1@ otherwise.
--   The output has the same shape as the input; instances are provided for
--   @'D1@, @'D2@ and @'D3@ shapes. The layer carries no trainable parameters.
data Elu = Elu deriving (Show)

instance UpdateLayer Elu where
  -- Nothing is learned by this layer, so its gradient is the unit type and an
  -- update always yields the same (only) value.
  type Gradient Elu = ()
  createRandom = return Elu
  runUpdate _ _ = const Elu

instance Serialize Elu where
  -- The layer is stateless: it writes no bytes and decodes from none.
  get = return Elu
  put = const (return ())

-- Element-wise ELU on rank-1 shapes. The tape records the layer's input,
-- which the backward pass uses to evaluate the derivative.
instance ( KnownNat i) => Layer Elu ('D1 i) ('D1 i) where
  type Tape Elu ('D1 i) ('D1 i) = S ('D1 i)

  runForwards _ (S1D y) = (S1D y, S1D (elu y))
    where
      -- 'expm1' computes exp a - 1 without the cancellation error the naive
      -- form suffers for inputs close to zero.
      elu = LAS.dvmap (\a -> if a <= 0 then expm1 a else a)
  runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (elu' y * dEdy))
    where
      -- d/da (exp a - 1) = exp a on the non-positive side; slope 1 otherwise.
      elu' = LAS.dvmap (\a -> if a <= 0 then exp a else 1)

-- Element-wise ELU on rank-2 shapes. The tape records the layer's input,
-- which the backward pass uses to evaluate the derivative.
instance (KnownNat i, KnownNat j) => Layer Elu ('D2 i j) ('D2 i j) where
  type Tape Elu ('D2 i j) ('D2 i j) = S ('D2 i j)

  runForwards _ (S2D y) = (S2D y, S2D (elu y))
    where
      -- 'expm1' computes exp a - 1 without the cancellation error the naive
      -- form suffers for inputs close to zero.
      elu = LAS.dmmap (\a -> if a <= 0 then expm1 a else a)
  runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (elu' y * dEdy))
    where
      -- d/da (exp a - 1) = exp a on the non-positive side; slope 1 otherwise.
      elu' = LAS.dmmap (\a -> if a <= 0 then exp a else 1)

-- Element-wise ELU on rank-3 shapes, applied to the S3D payload with
-- 'LAS.dmmap'. The tape records the layer's input, which the backward pass
-- uses to evaluate the derivative.
instance (KnownNat i, KnownNat j, KnownNat k) => Layer Elu ('D3 i j k) ('D3 i j k) where

  type Tape Elu ('D3 i j k) ('D3 i j k) = S ('D3 i j k)

  runForwards _ (S3D y) = (S3D y, S3D (elu y))
    where
      -- 'expm1' computes exp a - 1 without the cancellation error the naive
      -- form suffers for inputs close to zero.
      elu = LAS.dmmap (\a -> if a <= 0 then expm1 a else a)
  runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (elu' y * dEdy))
    where
      -- d/da (exp a - 1) = exp a on the non-positive side; slope 1 otherwise.
      elu' = LAS.dmmap (\a -> if a <= 0 then exp a else 1)