{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE FlexibleContexts      #-}
{-|
Module      : Grenade.Core.Runner
Description : Functions to perform training and backpropagation
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental
-}
module Grenade.Core.Runner (
    train
  , backPropagate
  , runNet
  ) where

import           Data.Singletons.Prelude

import           Grenade.Core.LearningParameters
import           Grenade.Core.Network
import           Grenade.Core.Shape

-- | Compute the gradients of every layer for a single training
--   example using reverse-mode automatic differentiation.
--
--   The network is first run forwards on @input@ with 'runNetwork'.
--   The difference @output - target@ is then pushed backwards through
--   the recorded tapes with 'runGradient'.
--
--   /Note:/ @output - target@ is the gradient of the squared-error loss
--   @(1/2) * ||output - target||^2@ with respect to the network output.
--   The intent is for this to serve both regression (as a squared loss)
--   and classification (as a log-loss).
--
--   For any other loss function, call 'runNetwork' and 'runGradient'
--   directly, passing the gradient of that loss with respect to the
--   network output.
--
backPropagate :: SingI (Last shapes)
              => Network layers shapes
              -> S (Head shapes)
              -> S (Last shapes)
              -> Gradients layers
backPropagate network input target =
    -- Only the layer gradients are wanted; the second component
    -- returned by 'runGradient' is discarded.
    fst (runGradient network tapes outputError)
  where
    -- Forward pass: the tapes are handed to 'runGradient' below,
    -- alongside the network's prediction for this input.
    (tapes, prediction) = runNetwork network input
    -- Error signal seeded into the backward pass.
    outputError         = prediction - target


-- | Perform one training step on a single example.
--
--   This back-propagates @input@ against the expected @output@ with
--   'backPropagate'. It then applies the resulting gradients to the
--   network with 'applyUpdate', using the given 'LearningParameters'.
--   The updated network is returned.
train :: SingI (Last shapes)
      => LearningParameters
      -> Network layers shapes
      -> S (Head shapes)
      -> S (Last shapes)
      -> Network layers shapes
train rate network input output =
    applyUpdate rate network $ backPropagate network input output


-- | Run the network forwards on an input and return only its final
--   output, discarding the tapes produced by 'runNetwork'.
runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
runNet net input =
    let (_tapes, output) = runNetwork net input
    in  output