module Grenade.Layers.Logit (
Logit (..)
) where
import Data.Serialize
import Data.Singletons
import Grenade.Core
-- | A parameter-free activation layer that applies the logistic sigmoid to
--   its input. The single nullary constructor means there is no state to
--   train, initialise, or serialise.
data Logit
  = Logit
  deriving (Show)
-- | 'Logit' has no trainable parameters. Its gradient is the unit type,
--   an update leaves the layer unchanged, and random initialisation just
--   yields the sole constructor.
instance UpdateLayer Logit where
  type Gradient Logit = ()

  -- Every argument is ignored: there is nothing to update.
  runUpdate _ _ _ = Logit

  createRandom = pure Logit
-- | Shape-preserving sigmoid layer: input and output shapes must coincide
--   (@a ~ b@).
instance (a ~ b, SingI a) => Layer Logit a b where
  -- The tape keeps the raw input, so the backward pass can evaluate the
  -- sigmoid's derivative at the same point.
  type Tape Logit a b = S a

  runForwards _ input =
    let activated = logistic input
    in  (input, activated)

  -- Chain rule: scale the incoming gradient by the local derivative.
  -- The layer has no parameters, so its own gradient is ().
  runBackwards _ input delta =
    let localGradient = logistic' input
    in  ((), localGradient * delta)
-- | 'Logit' has no fields, so its encoding is empty. Nothing is written,
--   and decoding consumes no input.
instance Serialize Logit where
  put = const (pure ())
  get = pure Logit
-- | The logistic sigmoid, @σ(x) = 1 / (1 + e^(-x))@.
--
--   Maps @0@ to @0.5@ and increases monotonically, tending to @0@ as
--   @x → -∞@ and to @1@ as @x → +∞@.
--
--   The exponent must be negated. Without the negation the function
--   computes @σ(-x)@, which inverts the layer's output. For IEEE types
--   such as 'Double', very negative inputs make @exp@ overflow to
--   infinity, which still evaluates to @0@ here rather than NaN.
logistic :: Floating a => a -> a
logistic x = 1 / (1 + exp (negate x))
-- | Derivative of 'logistic'.
--
--   Uses the identity @σ'(x) = σ(x) * (1 - σ(x))@, so the sigmoid is
--   evaluated only once per call. The subtraction is required: writing
--   @(1 s)@ would apply the literal @1@ as a function and fail to
--   type-check.
logistic' :: Floating a => a -> a
logistic' x = s * (1 - s)
  where
    s = logistic x