{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE RecordWildCards       #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE UndecidableInstances  #-}
module Grenade.Recurrent.Layers.BasicRecurrent (
    BasicRecurrent (..)
  , randomBasicRecurrent
  ) where

import           Control.Monad.Random ( MonadRandom, getRandom )

import           Data.Singletons.TypeLits

import           Numeric.LinearAlgebra.Static

import           GHC.TypeLits

import           Grenade.Core
import           Grenade.Recurrent.Core

data BasicRecurrent :: Nat -- Input layer size
                    -> Nat -- Output layer size
                    -> * where
  BasicRecurrent :: ( KnownNat input
                    , KnownNat output
                    , KnownNat matrixCols
                    , matrixCols ~ (input + output))
                 => !(R output)            -- Bias neuron weights
                 -> !(R output)            -- Bias neuron momentum
                 -> !(L output matrixCols) -- Activation
                 -> !(L output matrixCols) -- Momentum
                 -> BasicRecurrent input output

data BasicRecurrent' :: Nat -- Input layer size
                     -> Nat -- Output layer size
                     -> * where
  BasicRecurrent' :: ( KnownNat input
                     , KnownNat output
                     , KnownNat matrixCols
                     , matrixCols ~ (input + output))
                  => !(R output)            -- Bias neuron gradients
                  -> !(L output matrixCols) -- Activation gradients
                  -> BasicRecurrent' input output

instance Show (BasicRecurrent i o) where
  show BasicRecurrent {} = "BasicRecurrent"

instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
  type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)

  runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
        newBias         = oldBias + newBiasMomentum
        newMomentum     = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
        regulariser     = konst (learningRegulariser * learningRate) * oldActivations
        newActivations  = oldActivations + newMomentum - regulariser
    in  BasicRecurrent newBias newBiasMomentum newActivations newMomentum

  createRandom = randomBasicRecurrent

instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentUpdateLayer (BasicRecurrent i o) where
  type RecurrentShape (BasicRecurrent i o) = 'D1 o

instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentLayer (BasicRecurrent i o) ('D1 i) ('D1 o) where

  type RecTape (BasicRecurrent i o) ('D1 i) ('D1 o) = (S ('D1 o), S ('D1 i))

  -- Do a matrix vector multiplication and return the result.
  runRecurrentForwards (BasicRecurrent wB _ wN _) (S1D lastOutput) (S1D thisInput) =
    let thisOutput = S1D $ wB + wN #> (thisInput # lastOutput)
    in  ((S1D lastOutput, S1D thisInput), thisOutput, thisOutput)

  -- Run a backpropagation step for a fully connected recurrent layer.
  runRecurrentBackwards (BasicRecurrent _ _ wN _) (S1D lastOutput, S1D thisInput) (S1D dRec) (S1D dEdy) =
    let biasGradient        = (dRec + dEdy)
        layerGrad           = (dRec + dEdy) `outer` (thisInput # lastOutput)
        -- calculate derivatives for next step
        (backGrad, recGrad) = split $ tr wN #> (dRec + dEdy)
    in  (BasicRecurrent' biasGradient layerGrad, S1D recGrad, S1D backGrad)

randomBasicRecurrent :: (MonadRandom m, KnownNat i, KnownNat o, KnownNat x, x ~ (i + o))
                     => m (BasicRecurrent i o)
randomBasicRecurrent = do
  seed1 <- getRandom
  seed2 <- getRandom
  let wB = randomVector seed1 Uniform * 2 - 1
      wN = uniformSample seed2 (-1) 1
      bm = konst 0
      mm = konst 0
  return $ BasicRecurrent wB bm wN mm
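
-- A minimal usage sketch, not part of the original module: constructing a
-- 3-input, 2-output 'BasicRecurrent' layer with random weights. The name
-- 'exampleLayer' is hypothetical; it only illustrates that the layer's size
-- is fixed by the type annotation, with the 'KnownNat' constraints solved by
-- GHC for the concrete literals.
exampleLayer :: MonadRandom m => m (BasicRecurrent 3 2)
exampleLayer = randomBasicRecurrent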