module Grenade.Core.Runner (
train
, backPropagate
, runNet
) where
import Data.Singletons.Prelude
import Grenade.Core.LearningParameters
import Grenade.Core.Network
import Grenade.Core.Shape
-- | Compute the per-layer gradients for a single training example.
--
-- The network is run forward on the input, recording the tapes needed for
-- the backward pass. The difference between the network's output and the
-- target is then fed back through 'runGradient'. That difference
-- (@output - target@) is the derivative of the half squared-error loss
-- @0.5 * |output - target|^2@ with respect to the output.
backPropagate :: SingI (Last shapes)
              => Network layers shapes
              -> S (Head shapes)   -- ^ input fed to the first layer
              -> S (Last shapes)   -- ^ expected (target) output
              -> Gradients layers
backPropagate network input target =
  let (tapes, output) = runNetwork network input
      -- The error signal is the element-wise difference between the output
      -- and the target. NOTE(review): subtraction on 'S' presumably relies on
      -- a Num instance that needs the 'SingI' constraint above. Confirm
      -- against Grenade.Core.Shape.
      -- The second component returned by 'runGradient' is not needed here.
      (grads, _)      = runGradient network tapes (output - target)
  in  grads
-- | Perform one training step on a single example.
--
-- Gradients for the example are obtained from 'backPropagate'. They are then
-- applied to the network with 'applyUpdate' using the supplied learning
-- parameters. The updated network is returned.
train :: SingI (Last shapes)
      => LearningParameters      -- ^ parameters passed to 'applyUpdate'
      -> Network layers shapes   -- ^ network to update
      -> S (Head shapes)         -- ^ training input
      -> S (Last shapes)         -- ^ expected output for that input
      -> Network layers shapes
train params net x expected = applyUpdate params net gradients
  where
    gradients = backPropagate net x expected
-- | Run the network forward on an input and return only its output.
--
-- The tapes produced by 'runNetwork' are discarded. This is useful for
-- inference, where no backward pass follows.
runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
runNet net input =
  case runNetwork net input of
    (_tapes, result) -> result