{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE EmptyDataDecls        #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE UndecidableInstances  #-}

module Grenade.Recurrent.Core.Network (
    Recurrent
  , FeedForward

  , RecurrentNetwork (..)
  , RecurrentInputs (..)
  , RecurrentTapes (..)
  , RecurrentGradients (..)

  , randomRecurrent
  , runRecurrentNetwork
  , runRecurrentGradient
  , applyRecurrentUpdate
  ) where


import           Control.Monad.Random ( MonadRandom )
import           Data.Singletons ( SingI )
import           Data.Singletons.Prelude ( Head, Last )
import           Data.Serialize
import qualified Data.Vector.Storable as V

import           Grenade.Core
import           Grenade.Recurrent.Core.Layer

import qualified Numeric.LinearAlgebra as LA
import qualified Numeric.LinearAlgebra.Static as LAS

-- | Witness type indicating that a layer in a 'RecurrentNetwork' is a
--   normal feed forward layer. It has no constructors and is only used
--   as a tag in the type level list of layers.
data FeedForward :: * -> *
-- | Witness type indicating that a layer in a 'RecurrentNetwork' is a
--   recurrent layer. It has no constructors and is only used as a tag in
--   the type level list of layers.
data Recurrent :: * -> *

-- | Type of a recurrent neural network.
--
--   The [*] index specifies the types of the layers.
--
--   The [Shape] index specifies the shapes of data passed between the layers.
--
--   The definition is similar to a Network, but every layer in the
--   type is tagged by whether it's a FeedForward layer or a Recurrent layer.
--
--   Often, to make the definitions more concise, one will use a type alias
--   for these empty data types.
data RecurrentNetwork :: [*] -> [Shape] -> * where
  -- The empty network: no layers and exactly one shape, which is both
  -- the input shape and the output shape.
  RNil   :: SingI i
         => RecurrentNetwork '[] '[i]

  -- Prepend a feed forward layer x (a Layer from shape i to shape h) to
  -- a network whose input shape is h. Both fields are strict.
  (:~~>) :: (SingI i, Layer x i h)
         => !x
         -> !(RecurrentNetwork xs (h ': hs))
         -> RecurrentNetwork (FeedForward x ': xs) (i ': h ': hs)

  -- Prepend a recurrent layer x (a RecurrentLayer from shape i to shape
  -- h) to a network whose input shape is h. Both fields are strict.
  (:~@>) :: (SingI i, RecurrentLayer x i h)
         => !x
         -> !(RecurrentNetwork xs (h ': hs))
         -> RecurrentNetwork (Recurrent x ': xs) (i ': h ': hs)
-- Both layer operators associate to the right, so networks can be
-- written as layer1 :~~> layer2 :~@> RNil without parentheses.
infixr 5 :~~>
infixr 5 :~@>

-- | Gradients of a recurrent network.
--
--   Parameterised on the layers of the network. Each layer carries a
--   list of gradients rather than a single one.
data RecurrentGradients :: [*] -> * where
   -- No layers, so no gradients.
   RGNil  :: RecurrentGradients '[]

   -- The list of gradients for layer x, followed by the gradients of
   -- the remaining layers. The layer tag (phantom) is left as a type
   -- variable, so this single constructor serves both FeedForward and
   -- Recurrent layers.
   (://>) :: UpdateLayer x
          => [Gradient x]
          -> RecurrentGradients xs
          -> RecurrentGradients (phantom x ': xs)

-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph).
--
--   Parameterised on the layers of a network: there is one entry per
--   layer, and what it holds depends on the layer's tag.
data RecurrentInputs :: [*] -> * where
   -- No layers, so no recurrent inputs.
   RINil   :: RecurrentInputs '[]

   -- A feed forward layer has no recurrent input, so it only holds a
   -- unit placeholder.
   (:~~+>) :: UpdateLayer x
           => ()                      -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)

   -- A recurrent layer holds a value of its recurrent shape
   -- (RecurrentShape x), stored strictly.
   (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
           => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)

-- | All the information required to backpropagate
--   through time safely.
--
--   Indexed on the layers and shapes of the network, like
--   'RecurrentNetwork'. Every layer stores a list of tapes.
--
--   NOTE(review): the tapes are held in plain lists, so nothing in this
--   type forces the tape lists of different layers to have the same
--   length.
data RecurrentTapes :: [*] -> [Shape] -> * where
   -- No layers, so no tapes.
   TRNil  :: SingI i
          => RecurrentTapes '[] '[i]

   -- The tapes of a feed forward layer, followed by the tapes of the
   -- remaining layers.
   (:\~>) :: [Tape x i h]
          -> !(RecurrentTapes xs (h ': hs))
          -> RecurrentTapes (FeedForward x ': xs) (i ': h ': hs)


   -- The tapes of a recurrent layer, followed by the tapes of the
   -- remaining layers.
   (:\@>) :: [RecTape x i h]
          -> !(RecurrentTapes xs (h ': hs))
          -> RecurrentTapes (Recurrent x ': xs) (i ': h ': hs)


-- | Run a recurrent network forwards over a list of inputs.
--
--   Takes the network, the recurrent inputs for its layers, and the
--   list of inputs to feed through it in order.
--
--   Returns the tapes recorded by every layer, the recurrent values
--   left over once the last input has been processed, and the list of
--   values produced by the final layer (one per input, in input order).
runRecurrentNetwork  :: forall shapes layers.
                        RecurrentNetwork layers shapes
                     -> RecurrentInputs layers
                     -> [S (Head shapes)]
                     -> (RecurrentTapes layers shapes, RecurrentInputs layers, [S (Last shapes)])
runRecurrentNetwork =
  go
    where
  -- Walks down the network one layer at a time. The equality constraint
  -- records that every suffix of the network ends in the same output
  -- shape as the whole network.
  go  :: forall js sublayers. (Last js ~ Last shapes)
      => RecurrentNetwork sublayers js
      -> RecurrentInputs sublayers
      -> [S (Head js)]
      -> (RecurrentTapes sublayers js, RecurrentInputs sublayers, [S (Last js)])
  -- This is a simple non-recurrent layer, just map it forwards
  go (layer :~~> n) (() :~~+> nIn) !xs
      = let tys                 = runForwards layer <$> xs
            feedForwardTapes    = fst <$> tys
            forwards            = snd <$> tys
            -- recursively run the rest of the network on this layer's
            -- outputs, collecting its tapes, recurrent values and answers.
            (newFN, ig, answer) = go n nIn forwards
        in (feedForwardTapes :\~> newFN, () :~~+> ig, answer)

  -- This is a recurrent layer, so we need to do a scan, first input to last, providing
  -- the recurrent shape output to the next time step of the same layer.
  go (layer :~@> n) (recIn :~@+> nIn) !xs
      = let (recOut, tys)       = goR layer recIn xs
            recurrentTapes      = fst <$> tys
            forwards            = snd <$> tys

            (newFN, ig, answer) = go n nIn forwards
        in (recurrentTapes :\@> newFN, recOut :~@+> ig, answer)

  -- End of the network: no layers remain, so the values which reached
  -- this point are the outputs, and there is nothing to record.
  go RNil RINil !x
    = (TRNil, RINil, x)

  -- Helper function for recurrent layers
  -- Scans over the recurrent direction of the graph: each step runs the
  -- layer on one input together with the recurrent value produced by the
  -- previous step. Returns the recurrent value left after the final step
  -- and the (tape, output) pairs in input order.
  goR !layer !recShape (x:xs) =
    let (tape, lerec, lepush) = runRecurrentForwards layer recShape x
        (rems, push)          = goR layer lerec xs
    in  (rems, (tape, lepush) : push)
  -- No inputs left: hand back the current recurrent value unchanged.
  goR _ rin []      = (rin, [])

-- | Run a recurrent network backwards, using the tapes of a forward run.
--
--   Takes the network, the tapes for every layer, the recurrent
--   gradients to start from (used at the last time step of each
--   recurrent layer), and the gradients on the network's outputs in
--   forward time order.
--
--   Returns the gradients of every layer, the recurrent gradients left
--   over after the earliest time step has been processed, and the
--   gradients on the network's inputs.
--
--   The output gradients and every tape list are reversed before use,
--   so, assuming the tapes are in forward time order, the returned
--   gradient lists are backwards through time.
--
--   NOTE(review): tapes and gradients are paired with 'zip', so if the
--   number of output gradients differs from the number of tapes the
--   longer list is silently truncated.
runRecurrentGradient :: forall layers shapes.
                        RecurrentNetwork layers shapes
                     -> RecurrentTapes layers shapes
                     -> RecurrentInputs layers
                     -> [S (Last shapes)]
                     -> (RecurrentGradients layers, RecurrentInputs layers, [S (Head shapes)])
runRecurrentGradient net tapes r o =
  go net tapes r
    where
  -- We have to be careful regarding the direction of the lists
  -- Inputs come in forwards, but our return value is backwards
  -- through time.
  go  :: forall js ss. (Last js ~ Last shapes)
      => RecurrentNetwork ss js
      -> RecurrentTapes ss js
      -> RecurrentInputs ss
      -> (RecurrentGradients ss, RecurrentInputs ss, [S (Head js)])
  -- This is a simple non-recurrent layer
  -- Run the rest of the network, then fmap the tapes and gradients
  go (layer :~~> n) (feedForwardTapes :\~> nTapes) (() :~~+> nRecs) =
    let (gradients, rins, feed)  = go n nTapes nRecs
        -- feed is already backwards in time, so the tapes are reversed
        -- to line up with it.
        backs                    = uncurry (runBackwards layer) <$> zip (reverse feedForwardTapes) feed
    in  ((fst <$> backs) ://> gradients, () :~~+> rins, snd <$> backs)

  -- This is a recurrent layer
  -- Run the rest of the network, scan over the tapes in reverse
  go (layer :~@> n) (recurrentTapes :\@> nTapes) (recGrad :~@+> nRecs) =
    let (gradients, rins, feed)  = go n nTapes nRecs
        backExamples             = zip (reverse recurrentTapes) feed
        (rg, backs)              = goX layer recGrad backExamples
    in  ((fst <$> backs) ://> gradients, rg :~@+> rins, snd <$> backs)

  -- End of the road, so we reflect the given gradients backwards.
  -- Crucially, we reverse the list, so it's backwards in time as
  -- well.
  go RNil TRNil RINil
    = (RGNil, RINil, reverse o)

  -- Helper function for recurrent layers
  -- Scans over the recurrent direction of the graph: each step runs the
  -- layer backwards on one (tape, output gradient) pair together with the
  -- recurrent gradient produced by the previous step. Returns the
  -- recurrent gradient left after the final pair and the
  -- (layer gradient, input gradient) pairs in the order processed.
  goX :: RecurrentLayer x i o => x -> S (RecurrentShape x) -> [(RecTape x i o, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
  goX layer !lastback ((recTape, backgrad):xs) =
    let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recTape lastback backgrad
        (pushedback, ll)             = goX layer recgrad xs
    in  (pushedback, (layergrad, ingrad) : ll)
  -- Nothing left to process: hand back the current recurrent gradient.
  goX _ !lastback []      = (lastback, [])

-- | Update every layer of a recurrent network with its list of gradients.
--
--   Each layer is updated with 'runUpdates' (which a layer may
--   specialise), using the given learning parameters; the structure of
--   the network is unchanged.
applyRecurrentUpdate :: LearningParameters
                     -> RecurrentNetwork layers shapes
                     -> RecurrentGradients layers
                     -> RecurrentNetwork layers shapes
-- Nothing to update in the empty network.
applyRecurrentUpdate _ RNil RGNil = RNil

-- Feed forward layer: update it, then carry on down the network.
applyRecurrentUpdate params (ff :~~> deeper) (ffGrads ://> deeperGrads) =
  let ff'     = runUpdates params ff ffGrads
      deeper' = applyRecurrentUpdate params deeper deeperGrads
  in  ff' :~~> deeper'

-- Recurrent layer: update it, then carry on down the network.
applyRecurrentUpdate params (rec :~@> deeper) (recGrads ://> deeperGrads) =
  let rec'    = runUpdates params rec recGrads
      deeper' = applyRecurrentUpdate params deeper deeperGrads
  in  rec' :~@> deeper'


-- The empty recurrent network is shown by the name of its own
-- constructor, RNil.
instance Show (RecurrentNetwork '[] '[i]) where
  show RNil = "RNil"
-- A network starting with a feed forward layer: the layer, the feed
-- forward arrow on a line of its own, then the rest of the network.
instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (FeedForward x ': xs) (i ': rs)) where
  show (layer :~~> rest) = concat [show layer, "\n~~>\n", show rest]
-- A network starting with a recurrent layer: the layer, the recurrent
-- arrow on a line of its own, then the rest of the network. The
-- separator mirrors the (:~@>) constructor so that recurrent layers can
-- be told apart from feed forward ones in the output.
instance (Show x, Show (RecurrentNetwork xs rs)) => Show (RecurrentNetwork (Recurrent x ': xs) (i ': rs)) where
  show (x :~@> xs) = show x ++ "\n~@>\n" ++ show xs


-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
--   recurrent network and a set of random inputs for it is with 'randomRecurrent'.
--
--   The class is indexed on the layer types and the shapes of the
--   network to build.
class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
  -- | Create a network of the types requested, together with a set of
  --   recurrent inputs for it.
  randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss, RecurrentInputs xs)

-- The empty network needs no randomness at all.
instance SingI i => CreatableRecurrent '[] '[i] where
  randomRecurrent = pure (RNil, RINil)

-- Feed forward layer: create the layer first, then the rest of the
-- network; the layer only gets a unit placeholder as recurrent input.
instance (SingI i, Layer x i o, CreatableRecurrent xs (o ': rs)) => CreatableRecurrent (FeedForward x ': xs) (i ': o ': rs) where
  randomRecurrent =
    createRandom    >>= \layer ->
    randomRecurrent >>= \(deeper, deeperInputs) ->
    return (layer :~~> deeper, () :~~+> deeperInputs)

-- Recurrent layer: create the layer, then a random value of its
-- recurrent shape, then the rest of the network (in that order).
instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ':  rs)) => CreatableRecurrent (Recurrent x ': xs) (i ': o ': rs) where
  randomRecurrent =
    createRandom    >>= \layer ->
    randomOfShape   >>= \recurrentInput ->
    randomRecurrent >>= \(deeper, deeperInputs) ->
    return (layer :~@> deeper, recurrentInput :~@+> deeperInputs)

-- | Very simple serialisation for recurrent networks.
--
--   The empty network writes nothing and reads nothing.
instance SingI i => Serialize (RecurrentNetwork '[] '[i]) where
  put RNil = return ()
  get      = return RNil

-- A feed forward layer is written first, followed by the rest of the
-- network; reading happens in the same order.
instance (SingI i, Layer x i o, Serialize x, Serialize (RecurrentNetwork xs (o ': rs))) => Serialize (RecurrentNetwork (FeedForward x ': xs) (i ': o ': rs)) where
  put (layer :~~> deeper) = do
    put layer
    put deeper
  get = do
    layer  <- get
    deeper <- get
    return (layer :~~> deeper)

-- A recurrent layer is written first, followed by the rest of the
-- network; reading happens in the same order.
instance (SingI i, RecurrentLayer x i o, Serialize x, Serialize (RecurrentNetwork xs (o ': rs))) => Serialize (RecurrentNetwork (Recurrent x ': xs) (i ': o ': rs)) where
  put (layer :~@> deeper) = do
    put layer
    put deeper
  get = do
    layer  <- get
    deeper <- get
    return (layer :~@> deeper)

-- The empty set of recurrent inputs writes nothing and reads nothing.
instance (Serialize (RecurrentInputs '[])) where
  put _ = pure ()
  get   = pure RINil

-- A feed forward layer has no recurrent input, so only the inputs of
-- the remaining layers are written and read.
instance (UpdateLayer x, Serialize (RecurrentInputs ys)) => (Serialize (RecurrentInputs (FeedForward x ': ys))) where
  put (() :~~+> deeper) = put deeper
  get = do
    deeper <- get
    return (() :~~+> deeper)

-- Serialise the recurrent inputs of a network starting with a
-- recurrent layer.
--
-- The recurrent value is written as a flat list of its elements,
-- followed by the inputs of the remaining layers. Reading rebuilds the
-- value from that list with fromStorable and fails the parse with a
-- descriptive message if the list cannot be turned into a value of the
-- expected shape.
instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Serialize (RecurrentInputs ys)) => (Serialize (RecurrentInputs (Recurrent x ': ys))) where
  put ( i :~@+> rest ) = do
    -- The annotation pins the result type of the case over the shape
    -- constructors.
    _ <- (case i of
           (S1D x) -> putListOf put . LA.toList . LAS.extract $ x
           (S2D x) -> putListOf put . LA.toList . LA.flatten . LAS.extract $ x
           (S3D x) -> putListOf put . LA.toList . LA.flatten . LAS.extract $ x
         ) :: PutM ()
    put rest

  get = do
    values <- getListOf get
    -- Match on the result explicitly rather than with a partial pattern
    -- bind, so that a mismatch is reported with a meaningful message.
    case fromStorable (V.fromList values) of
      Just i  -> do
        rest <- get
        return ( i :~@+> rest)
      Nothing ->
        fail "RecurrentInputs: serialised recurrent input does not match the expected shape"


-- Num instances for `RecurrentInputs layers`.
--
-- It is not clear that these are really needed: at the moment only
-- `fromInteger 0` is used during training, to create a null gradient
-- on the recurrent edge.
--
-- It does raise an interesting question though: is a 0 gradient
-- actually the best choice?
--
-- One could imagine that weakly pushing back towards the optimum input
-- would help make a more stable generator.
--
-- With no layers there is nothing to combine, so every operation
-- yields the empty set of inputs.
instance (Num (RecurrentInputs '[])) where
  _ + _         = RINil
  _ - _         = RINil
  _ * _         = RINil
  abs _         = RINil
  signum _      = RINil
  fromInteger _ = RINil

-- A feed forward layer only holds a unit placeholder, so every
-- operation is simply passed on to the inputs of the remaining layers.
instance (UpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (FeedForward x ': ys))) where
  (() :~~+> lhs) + (() :~~+> rhs) = () :~~+> (lhs + rhs)
  (() :~~+> lhs) - (() :~~+> rhs) = () :~~+> (lhs - rhs)
  (() :~~+> lhs) * (() :~~+> rhs) = () :~~+> (lhs * rhs)
  abs (() :~~+> inner)            = () :~~+> abs inner
  signum (() :~~+> inner)         = () :~~+> signum inner
  fromInteger n                   = () :~~+> fromInteger n

-- A recurrent layer holds a value of its recurrent shape: every
-- operation is applied to that value and, separately, to the inputs of
-- the remaining layers.
instance (SingI (RecurrentShape x), RecurrentUpdateLayer x, Num (RecurrentInputs ys)) => (Num (RecurrentInputs (Recurrent x ': ys))) where
  (s :~@+> ss) + (t :~@+> ts) = (s + t) :~@+> (ss + ts)
  (s :~@+> ss) - (t :~@+> ts) = (s - t) :~@+> (ss - ts)
  (s :~@+> ss) * (t :~@+> ts) = (s * t) :~@+> (ss * ts)
  abs (s :~@+> ss)            = abs s :~@+> abs ss
  signum (s :~@+> ss)         = signum s :~@+> signum ss
  fromInteger n               = fromInteger n :~@+> fromInteger n