module Grenade.Recurrent.Core.Network (
Recurrent
, FeedForward
, RecurrentNetwork (..)
, RecurrentInputs (..)
, RecurrentTapes (..)
, RecurrentGradients (..)
, randomRecurrent
, runRecurrentNetwork
, runRecurrentGradient
, applyRecurrentUpdate
) where
import Control.Monad.Random ( MonadRandom )
import Data.Singletons ( SingI )
import Data.Singletons.Prelude ( Head, Last )
import Data.Serialize
import qualified Data.Vector.Storable as V
import Grenade.Core
import Grenade.Recurrent.Core.Layer
import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Static as LAS
-- | Type-level tag marking a layer of a 'RecurrentNetwork' as an ordinary
--   feed forward layer (one added with ':~~>'). It has no constructors; it
--   only appears in the layer list that indexes the network.
data FeedForward :: * -> *

-- | Type-level tag marking a layer of a 'RecurrentNetwork' as a recurrent
--   layer (one added with ':~@>'). It has no constructors; it only appears
--   in the layer list that indexes the network.
data Recurrent :: * -> *
-- | A network which may mix feed forward and recurrent layers.
--
--   The first index lists the layers, each wrapped in 'FeedForward' or
--   'Recurrent'. The second index lists the shapes between layers: the
--   input shape followed by the output shape of every layer, so a network
--   of @n@ layers is indexed by @n + 1@ shapes.
data RecurrentNetwork :: [*] -> [Shape] -> * where
  -- | The empty network; its single shape is both its input and its output.
  RNil   :: SingI i
         => RecurrentNetwork '[] '[i]

  -- | Prepend a feed forward layer taking shape @i@ to shape @h@.
  --   Both fields are strict.
  (:~~>) :: (SingI i, Layer x i h)
         => !x
         -> !(RecurrentNetwork xs (h ': hs))
         -> RecurrentNetwork (FeedForward x ': xs) (i ': h ': hs)

  -- | Prepend a recurrent layer taking shape @i@ to shape @h@.
  --   Both fields are strict.
  (:~@>) :: (SingI i, RecurrentLayer x i h)
         => !x
         -> !(RecurrentNetwork xs (h ': hs))
         -> RecurrentNetwork (Recurrent x ': xs) (i ': h ': hs)
-- Both constructors associate to the right so networks can be written
-- as @a :~~> b :~@> c :~~> RNil@ without parentheses.
infixr 5 :~~>
infixr 5 :~@>
-- | Gradients for every layer of a recurrent network, indexed by the same
--   layer list as the network.
--
--   Each layer holds a list of gradients (as built by 'runRecurrentGradient',
--   one per time step). The @phantom x@ index lets a single constructor cover
--   both 'FeedForward' and 'Recurrent' layers.
data RecurrentGradients :: [*] -> * where
  -- | No layers, no gradients.
  RGNil  :: RecurrentGradients '[]
  -- | The gradients of one layer, followed by those of the remaining layers.
  (://>) :: UpdateLayer x
         => [Gradient x]
         -> RecurrentGradients xs
         -> RecurrentGradients (phantom x ': xs)
-- | The recurrent (sideways) inputs of every layer of a recurrent network,
--   indexed by the network's layer list.
data RecurrentInputs :: [*] -> * where
  -- | No layers, no recurrent inputs.
  RINil   :: RecurrentInputs '[]
  -- | A feed forward layer carries no recurrent input, represented by '()'.
  (:~~+>) :: UpdateLayer x
          => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
  -- | A recurrent layer carries a value of its 'RecurrentShape'.
  --   Both fields are strict.
  (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
          => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)
-- | The tapes recorded while running a recurrent network forwards
--   ('runRecurrentNetwork'), indexed like the network itself.
--
--   Each layer stores a list of tapes, one per time step, which
--   'runRecurrentGradient' consumes. The tape lists are lazy; the link to
--   the remaining layers is strict.
data RecurrentTapes :: [*] -> [Shape] -> * where
  -- | No layers, no tapes.
  TRNil :: SingI i
        => RecurrentTapes '[] '[i]
  -- | Tapes of a feed forward layer, followed by the remaining layers' tapes.
  (:\~>) :: [Tape x i h]
         -> !(RecurrentTapes xs (h ': hs))
         -> RecurrentTapes (FeedForward x ': xs) (i ': h ': hs)
  -- | Tapes of a recurrent layer, followed by the remaining layers' tapes.
  (:\@>) :: [RecTape x i h]
         -> !(RecurrentTapes xs (h ': hs))
         -> RecurrentTapes (Recurrent x ': xs) (i ': h ': hs)
-- | Run a recurrent network forwards over a whole sequence.
--
--   The input list is taken in time order. Each layer is run over the entire
--   sequence before its outputs are handed on to the next layer; recurrent
--   layers thread their recurrent input from one time step to the next.
--
--   The result contains, in order:
--
--   * the tapes recorded by each layer, one per time step, in time order
--     (these are what 'runRecurrentGradient' consumes);
--   * the recurrent inputs left after the last time step; and
--   * the network's output for each time step, in time order.
runRecurrentNetwork :: forall shapes layers.
                       RecurrentNetwork layers shapes
                    -> RecurrentInputs layers
                    -> [S (Head shapes)]
                    -> (RecurrentTapes layers shapes, RecurrentInputs layers, [S (Last shapes)])
runRecurrentNetwork = forwardLayers
  where
    forwardLayers :: forall js sublayers. (Last js ~ Last shapes)
                  => RecurrentNetwork sublayers js
                  -> RecurrentInputs sublayers
                  -> [S (Head js)]
                  -> (RecurrentTapes sublayers js, RecurrentInputs sublayers, [S (Last js)])
    -- No layers remain: the values reaching this point are the outputs.
    forwardLayers RNil RINil !outputs
      = (TRNil, RINil, outputs)
    -- Feed forward layer: run it independently at every time step.
    forwardLayers (layer :~~> rest) (() :~~+> restInputs) !inputs
      = let steps                          = map (runForwards layer) inputs
            (restTapes, restState, final)  = forwardLayers rest restInputs (map snd steps)
        in  (map fst steps :\~> restTapes, () :~~+> restState, final)
    -- Recurrent layer: scan over the sequence, carrying its recurrent input.
    forwardLayers (layer :~@> rest) (state :~@+> restInputs) !inputs
      = let (state', steps)                = scanSequence layer state inputs
            (restTapes, restState, final)  = forwardLayers rest restInputs (map snd steps)
        in  (map fst steps :\@> restTapes, state' :~@+> restState, final)

    -- Step a recurrent layer through the sequence, feeding each step's new
    -- recurrent value into the next, and collect every (tape, output) pair.
    -- The layer and the recurrent value are forced at each step.
    scanSequence !layer !state (x : xs) =
      let (tape, nextState, output) = runRecurrentForwards layer state x
          (lastState, later)        = scanSequence layer nextState xs
      in  (lastState, (tape, output) : later)
    scanSequence _ state [] = (state, [])
-- | Backpropagate through a recurrent network over a whole sequence.
--
--   Arguments, in order:
--
--   * the network;
--   * the tapes 'runRecurrentNetwork' produced for the same sequence;
--   * for every recurrent layer, a value used as the incoming recurrent
--     gradient for the last time step; and
--   * the gradients of the network's outputs, one per time step in time
--     order.
--
--   The result contains the gradients for every layer, the recurrent
--   gradient left after stepping back to the first time step, and the
--   gradients with respect to the network inputs. Because the sequence is
--   walked in reverse, every returned list is ordered backwards in time
--   (last time step first).
runRecurrentGradient :: forall layers shapes.
                        RecurrentNetwork layers shapes
                     -> RecurrentTapes layers shapes
                     -> RecurrentInputs layers
                     -> [S (Last shapes)]
                     -> (RecurrentGradients layers, RecurrentInputs layers, [S (Head shapes)])
runRecurrentGradient net tapes r o =
  backwardLayers net tapes r
  where
    backwardLayers :: forall js ss. (Last js ~ Last shapes)
                   => RecurrentNetwork ss js
                   -> RecurrentTapes ss js
                   -> RecurrentInputs ss
                   -> (RecurrentGradients ss, RecurrentInputs ss, [S (Head js)])
    -- Past the last layer: start from the output gradients, reversed so they
    -- line up with the reversed tapes they are paired with.
    backwardLayers RNil TRNil RINil
      = (RGNil, RINil, reverse o)
    -- Feed forward layer: backpropagate each time step independently.
    backwardLayers (layer :~~> rest) (stepTapes :\~> restTapes) (() :~~+> restRecs) =
      let (restGrads, restOut, incoming) = backwardLayers rest restTapes restRecs
          results                        = zipWith (runBackwards layer) (reverse stepTapes) incoming
      in  (map fst results ://> restGrads, () :~~+> restOut, map snd results)
    -- Recurrent layer: walk the time steps in reverse, passing the recurrent
    -- gradient from each step into the one before it.
    backwardLayers (layer :~@> rest) (stepTapes :\@> restTapes) (recGrad :~@+> restRecs) =
      let (restGrads, restOut, incoming) = backwardLayers rest restTapes restRecs
          (recGradOut, results)          = scanBackwards layer recGrad (zip (reverse stepTapes) incoming)
      in  (map fst results ://> restGrads, recGradOut :~@+> restOut, map snd results)

    -- Run a recurrent layer backwards over (tape, output gradient) pairs,
    -- threading the recurrent gradient, and return the final recurrent
    -- gradient with every step's (layer gradient, input gradient).
    scanBackwards :: RecurrentLayer x i o => x -> S (RecurrentShape x) -> [(RecTape x i o, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
    scanBackwards layer !recGradIn ((tape, outGrad) : more) =
      let (layerGrad, prevRecGrad, inGrad) = runRecurrentBackwards layer tape recGradIn outGrad
          (finalRecGrad, later)            = scanBackwards layer prevRecGrad more
      in  (finalRecGrad, (layerGrad, inGrad) : later)
    scanBackwards _ !recGradIn [] = (recGradIn, [])
-- | Update every layer of a network with its gradients.
--
--   Each layer is passed, together with its list of gradients from the
--   matching position in the 'RecurrentGradients', to 'runUpdates' using the
--   given learning parameters. The network's structure is unchanged.
applyRecurrentUpdate :: LearningParameters
                     -> RecurrentNetwork layers shapes
                     -> RecurrentGradients layers
                     -> RecurrentNetwork layers shapes
applyRecurrentUpdate _ RNil RGNil
  = RNil
applyRecurrentUpdate params (layer :~~> layers) (grads ://> moreGrads)
  = runUpdates params layer grads :~~> applyRecurrentUpdate params layers moreGrads
applyRecurrentUpdate params (layer :~@> layers) (grads ://> moreGrads)
  = runUpdates params layer grads :~@> applyRecurrentUpdate params layers moreGrads
-- | The empty recurrent network is shown by its constructor name, 'RNil'.
--   (It previously printed "NNil", the name of the feed forward network's
--   terminator, which does not exist in this type.)
instance Show (RecurrentNetwork '[] '[i]) where
  show RNil = "RNil"
-- | Shows a feed forward layer, then a line holding the ~~> arrow, then the
--   rest of the network.
instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where
  show (layer :~~> rest) = concat [show layer, "\n~~>\n", show rest]
-- | Shows a recurrent layer, then a line holding the ~@> arrow (matching the
--   ':~@>' constructor, so recurrent and feed forward layers can be told
--   apart in the output), then the rest of the network.
instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where
  show (x :~@> xs) = show x ++ "\n~@>\n" ++ show xs
-- | Recurrent networks (layer list @xs@, shape list @ss@) that can be created
--   with random initial values.
class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
  -- | Randomly create a network together with an initial set of recurrent
  --   inputs for its layers.
  randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss, RecurrentInputs xs)
-- | The empty network has nothing to randomise.
instance SingI i => CreatableRecurrent '[] '[i] where
  randomRecurrent = pure (RNil, RINil)
-- | Randomly create this feed forward layer with 'createRandom', then the rest
--   of the network. The layer carries no recurrent input, hence '()'.
instance (SingI i, Layer x i o, CreatableRecurrent xs (o ': rs)) => CreatableRecurrent (FeedForward x ': xs) (i ': o ': rs) where
  randomRecurrent =
    createRandom >>= \layer ->
      randomRecurrent >>= \(network, inputs) ->
        return (layer :~~> network, () :~~+> inputs)
-- | Randomly create this recurrent layer with 'createRandom' and its initial
--   recurrent input with 'randomOfShape', then the rest of the network.
instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ': rs)) => CreatableRecurrent (Recurrent x ': xs) (i ': o ': rs) where
  randomRecurrent =
    createRandom >>= \layer ->
      randomOfShape >>= \initialState ->
        randomRecurrent >>= \(network, inputs) ->
          return (layer :~@> network, initialState :~@+> inputs)
-- | The empty network stores no data: encoding writes nothing and decoding
--   reads nothing.
instance SingI i => Serialize (RecurrentNetwork '[] '[i]) where
  put RNil = return ()
  get      = return RNil
-- | A feed forward layer is serialised, followed by the rest of the network.
--   Shapes are not stored; they are fixed by the type.
instance (SingI i, Layer x i o, Serialize x, Serialize (RecurrentNetwork xs (o ': rs))) => Serialize (RecurrentNetwork (FeedForward x ': xs) (i ': o ': rs)) where
  put (layer :~~> rest) = do
    put layer
    put rest
  get = do
    layer <- get
    rest  <- get
    return (layer :~~> rest)
-- | A recurrent layer is serialised, followed by the rest of the network.
--   Shapes are not stored; they are fixed by the type.
instance (SingI i, RecurrentLayer x i o, Serialize x, Serialize (RecurrentNetwork xs (o ': rs))) => Serialize (RecurrentNetwork (Recurrent x ': xs) (i ': o ': rs)) where
  put (layer :~@> rest) = do
    put layer
    put rest
  get = do
    layer <- get
    rest  <- get
    return (layer :~@> rest)
-- | No layers means no recurrent inputs: nothing is written or read.
--   The value being encoded is not inspected.
instance (Serialize (RecurrentInputs '[])) where
  put _ = pure ()
  get   = pure RINil
-- | A feed forward layer has no recurrent input ('()'), so only the remaining
--   layers' inputs are stored.
instance (UpdateLayer x, Serialize (RecurrentInputs ys)) => (Serialize (RecurrentInputs (FeedForward x ': ys))) where
  put (() :~~+> rest) = put rest
  get = do
    rest <- get
    return (() :~~+> rest)
-- | A recurrent layer's recurrent input is stored as the list of its values
--   (matrix-valued shapes are flattened with 'LA.flatten'), followed by the
--   remaining layers' inputs.
--
--   Decoding rebuilds the value with 'fromStorable' and fails with an
--   explicit message if the stored values do not fit the layer's
--   'RecurrentShape'. Previously this failure surfaced as an opaque
--   pattern-match failure.
instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Serialize (RecurrentInputs ys)) => (Serialize (RecurrentInputs (Recurrent x ': ys))) where
  put ( i :~@+> rest ) = putShape i >> put rest
    where
      -- Giving the match on 'S' its own signature supplies the fixed result
      -- type GHC needs to check these GADT patterns.
      putShape :: S s -> PutM ()
      putShape (S1D v) = putListOf put . LA.toList . LAS.extract $ v
      putShape (S2D m) = putListOf put . LA.toList . LA.flatten . LAS.extract $ m
      putShape (S3D m) = putListOf put . LA.toList . LA.flatten . LAS.extract $ m
  get = do
    values <- getListOf get
    i      <- maybe (fail "Serialize RecurrentInputs: stored values do not match the recurrent shape")
                    return
                    (fromStorable (V.fromList values))
    rest   <- get
    return ( i :~@+> rest)
-- | Arithmetic on the empty list of recurrent inputs: every operation yields
--   'RINil'. Restores the @(-)@ method, which was garbled to an invalid
--   @()@ binding.
instance (Num (RecurrentInputs '[])) where
  (+) _ _       = RINil
  (-) _ _       = RINil
  (*) _ _       = RINil
  abs _         = RINil
  signum _      = RINil
  fromInteger _ = RINil
-- | Arithmetic on recurrent inputs, layer by layer. A feed forward layer
--   holds no recurrent input ('()'), so each operation is simply applied to
--   the remaining layers. Restores the @(-)@ method and its subtraction,
--   which were garbled to @()@ and @(x y)@.
instance (UpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (FeedForward x ': ys))) where
  (+) (() :~~+> x) (() :~~+> y) = () :~~+> (x + y)
  (-) (() :~~+> x) (() :~~+> y) = () :~~+> (x - y)
  (*) (() :~~+> x) (() :~~+> y) = () :~~+> (x * y)
  abs (() :~~+> x)              = () :~~+> abs x
  signum (() :~~+> x)           = () :~~+> signum x
  fromInteger x                 = () :~~+> fromInteger x
-- | Arithmetic on recurrent inputs, layer by layer. Each operation is applied
--   to this layer's recurrent value (via the 'Num' instance of 'S') and to
--   the remaining layers. Restores the @(-)@ method and its subtractions,
--   which were garbled to @()@ and @(x y)@ / @(x' y')@.
instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (Recurrent x ': ys))) where
  (+) (x :~@+> x') (y :~@+> y') = (x + y) :~@+> (x' + y')
  (-) (x :~@+> x') (y :~@+> y') = (x - y) :~@+> (x' - y')
  (*) (x :~@+> x') (y :~@+> y') = (x * y) :~@+> (x' * y')
  abs (x :~@+> x')              = abs x :~@+> abs x'
  signum (x :~@+> x')           = signum x :~@+> signum x'
  fromInteger x                 = fromInteger x :~@+> fromInteger x