{-|
Module      : TypedFlow.Layers.RNN
Description : RNN cells, layers and combinators.
Copyright   : (c) Jean-Philippe Bernardy, 2017
License     : LGPL-3
Maintainer  : jean-philippe.bernardy@gu.se
Stability   : experimental
-}

{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE TypeInType #-}
{-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-}
{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE UnicodeSyntax #-}
{-# LANGUAGE PatternSynonyms #-}

module TypedFlow.Layers.RNN (
  -- * Types
  RnnCell, RnnLayer,
  -- * Combinators
  stackRnnCells, (.-.),
  stackRnnLayers, (.--.),
  bothRnnLayers,(.++.),
  withBypass,
  onStates,
  timeDistribute, timeDistribute',
  -- * RNN Cells
  cellInitializerBit,
  LSTMP(..),
  lstm,
  GRUP(..),
  gru,
  -- * RNN unfolding functions
  rnn,
  rnnBackward,
  rnnBackwardsWithCull,
  rnnWithCull,
  -- * Attention mechanisms
  -- ** Scoring functions
  AttentionScoring,
  multiplicativeScoring,
  AdditiveScoringP(..), additiveScoring,
  -- ** Attention functions
  AttentionFunction,
  uniformAttn,
  luongAttention,
  -- ** Attention combinators
  attentiveWithFeedback
  )

where

import Prelude hiding (tanh,Num(..),Floating(..),floor)
import GHC.TypeLits
-- import Text.PrettyPrint.Compact (float)
import TypedFlow.TF
import TypedFlow.Types
import TypedFlow.Layers.Core (DenseP(..),(#))
-- import Data.Type.Equality
-- import Data.Kind (Type,Constraint)
import Data.Monoid ((<>))


-- | A cell in an rnn. @state@ is the state propagated through time.
type RnnCell t states input output = (HTV (Flt t) states , input) -> Gen (HTV (Flt t) states , output)

-- | A layer in an rnn. @n@ is the length of the time sequence. @state@ is the state propagated through time.
type RnnLayer b n state input t output u = HTV (Flt b) state -> Tensor (n ': input) t -> Gen (HTV (Flt b) state , Tensor (n ': output) u)
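
-- To make the relationship between the two types concrete: a cell
-- processes a single time step, and 'rnn' (below) unrolls it over a
-- whole sequence. A sketch, with hypothetical sizes @x@, @y@, @bs@:
--
-- > cell  :: RnnCell b states (T '[x,bs] t) (T '[y,bs] t)
-- > layer :: RnnLayer b n states '[x,bs] t '[y,bs] t
-- > layer = rnn cell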

--------------------------------------
-- Combinators


-- | Compose two rnn layers sequentially: the output sequence of the
-- first layer is fed as input to the second one.
(.--.),stackRnnLayers :: forall s1 s2 a t b u c v n bits. KnownLen s1 =>
                  RnnLayer bits n s1 a t b u -> RnnLayer bits n s2 b u c v -> RnnLayer bits n (s1 ++ s2) a t c v
stackRnnLayers f g (hsplit @s1 -> (s0,s1)) x = do
  (s0',y) <- f s0 x
  (s1',z) <- g s1 y
  return (happ s0' s1',z)

infixr .--.
(.--.) = stackRnnLayers


-- | Run two rnn layers in parallel on the same input, and concatenate
-- their outputs. This is useful for example to combine forward and
-- backward layers.
bothRnnLayers,(.++.)  :: forall s1 s2 a t b u c n bs bits. KnownLen s1 =>
                  RnnLayer bits n s1 a t '[b,bs] u -> RnnLayer bits n s2 a t '[c,bs] u -> RnnLayer bits n (s1 ++ s2) a t '[b+c,bs] u
bothRnnLayers f g (hsplit @s1 -> (s0,s1)) x = do
  (s0',y) <- f s0 x
  (s1',z) <- g s1 x
  return (happ s0' s1',concat1 y z)


infixr .++.
(.++.) = bothRnnLayers
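
-- For example, a bidirectional layer can be sketched by running two
-- cells over the sequence in opposite directions and concatenating
-- their outputs (assuming LSTM parameters @p1@ and @p2@ are in scope):
--
-- > biLstm = rnn (lstm p1) .++. rnnBackward (lstm p2)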

-- | Apply a function on the cell state(s) before running the cell itself.
onStates ::  (HTV (Flt t) xs -> HTV (Flt t) xs) -> RnnCell t xs a b -> RnnCell t xs a b
onStates f cell (h,x) = cell (f h, x)

-- | Stack two RNN cells (LHS is run first)
stackRnnCells, (.-.) :: forall s0 s1 a b c t. KnownLen s0 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s0 ++ s1) a c
stackRnnCells l1 l2 (hsplit @s0 -> (s0,s1),x) = do
  (s0',y) <- l1 (s0,x)
  (s1',z) <- l2 (s1,y)
  return ((happ s0' s1'),z)

(.-.) = stackRnnCells

-- | Run the cell, and additionally forward its input to the output, by
-- concatenating the input with the output of the cell (a bypass
-- connection).
withBypass :: RnnCell b s0 (T '[x,bs] t) (T '[y,bs] t) -> RnnCell b s0 (T '[x,bs] t) (T '[x+y,bs] t)
withBypass cell (s,x) = do
  (s',y) <- cell (s,x)
  return (s',concat0 x y)
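
-- The effect on the cell's type is only to grow the output (a sketch,
-- with hypothetical sizes):
--
-- > cell            :: RnnCell b s (T '[20,bs] t) (T '[30,bs] t)
-- > withBypass cell :: RnnCell b s (T '[20,bs] t) (T '[50,bs] t)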

--------------------------------------
-- Cells

-- | Convert a pure function (feed-forward layer) to an RNN cell by
-- ignoring the RNN state.
timeDistribute :: (a -> b) -> RnnCell t '[] a b
timeDistribute pureLayer = timeDistribute' (return . pureLayer)
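
-- For example, a dense layer (from "TypedFlow.Layers.Core") can be
-- applied at every time step of an rnn; a sketch, assuming
-- @d :: DenseP t x y@ is in scope:
--
-- > timeDistribute (d #) :: RnnCell t '[] (T '[x,bs] (Flt t)) (T '[y,bs] (Flt t))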

-- | Convert a stateless generator into an RNN cell by ignoring the
-- RNN state.
timeDistribute' :: (a -> Gen b) -> RnnCell t '[] a b
timeDistribute' stateLess (Unit,a) = do
  b <- stateLess a
  return (Unit,b)

-- | Standard RNN gate initializer. (The recurrent kernel is
-- orthogonal to avoid divergence; the input kernel is glorot)
cellInitializerBit :: ∀ n x t. (KnownNat n, KnownNat x, KnownBits t) => DenseP t (n + x) n
cellInitializerBit = DenseP (concat0 recurrentInitializer kernelInitializer) biasInitializer
  where
        recurrentInitializer :: Tensor '[n, n] ('Typ 'Float t)
        recurrentInitializer = randomOrthogonal
        kernelInitializer :: Tensor '[x, n] ('Typ 'Float t)
        kernelInitializer = glorotUniform
        biasInitializer = zeros

-- | Parameter for an LSTM
data LSTMP t n x = LSTMP (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n)

instance (KnownNat n, KnownNat x, KnownBits t) => KnownTensors (LSTMP t n x) where
  travTensor f s (LSTMP x y z w) = LSTMP <$> travTensor f (s<>"_f") x <*> travTensor f (s<>"_i") y <*> travTensor f (s<>"_c") z <*> travTensor f (s<>"_o") w
instance (KnownNat n, KnownNat x, KnownBits t) => ParamWithDefault (LSTMP t n x) where
  defaultInitializer = LSTMP forgetInit cellInitializerBit cellInitializerBit cellInitializerBit
    where forgetInit = DenseP (denseWeights cellInitializerBit) ones

-- | Standard LSTM
lstm :: ∀ n x bs t. LSTMP t n x ->
        RnnCell t '[ '[n,bs], '[n,bs]] (Tensor '[x,bs] (Flt t)) (Tensor '[n,bs] (Flt t))
lstm (LSTMP wf wi wc wo) (VecPair ht1 ct1, input) = do
  hx <- assign (concat0 ht1 input)
  let f = sigmoid (wf # hx)
      i = sigmoid (wi # hx)
      cTilda = tanh (wc # hx)
      o = sigmoid (wo # hx)
  c <- assign ((f ⊙ ct1) + (i ⊙ cTilda))
  h <- assign (o ⊙ tanh c)
  return (VecPair h c, h)
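
-- A usage sketch, assuming a parameter-creation helper such as
-- @parameterDefault@ is in scope ('LSTMP' has a 'ParamWithDefault'
-- instance, so its weights can be initialized that way):
--
-- > mkLstmCell = do
-- >   params <- parameterDefault "lstm1"
-- >   return (lstm params)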

-- -- | LSTM for an attention model. The result of attention is combined using + to generate output (bad!)
-- attentiveLstmPlus :: forall x n bs t. KnownNat bs =>
--   AttentionFunction t bs n n ->
--   LSTMP t n x ->
--   RnnCell t '[ '[n,bs], '[n,bs]] (Tensor '[x,bs] (Flt t)) (Tensor '[n,bs] (Flt t))
-- attentiveLstmPlus att w x = do
--   (VecPair ht ct, _ht) <- lstm w x
--   a <- att ht
--   let ht' = ht ⊕ a -- alternatively add a dense layer to combine
--   return (VecPair ht' ct, a)

-- | Parameter for a GRU
data GRUP t n x = GRUP (T [n+x,n] ('Typ 'Float t)) (T [n+x,n] ('Typ 'Float t)) (T [n+x,n] ('Typ 'Float t))

instance (KnownNat n, KnownNat x, KnownBits t) => KnownTensors (GRUP t n x) where
  travTensor f s (GRUP x y z) = GRUP <$> travTensor f (s<>"_z") x <*> travTensor f (s<>"_r") y <*> travTensor f (s<>"_w") z
instance (KnownNat n, KnownNat x, KnownBits t) => ParamWithDefault (GRUP t n x) where
  defaultInitializer = GRUP (denseWeights cellInitializerBit) (denseWeights cellInitializerBit) (denseWeights cellInitializerBit)


-- | Standard GRU cell
gru :: ∀ n x bs t. (KnownNat bs, KnownNat n, KnownBits t) => GRUP t n x ->
        RnnCell t '[ '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[n,bs] (Flt t))
gru (GRUP wz wr w) (VecSing ht1, xt) = do
  hx <- assign (concat0 ht1 xt)
  let zt = sigmoid (wz ∙ hx)
      rt = sigmoid (wr ∙ hx)
      hTilda = tanh (w ∙ (concat0 (rt ⊙ ht1) xt))
  ht <- assign ((ones ⊝ zt) ⊙ ht1 + zt ⊙ hTilda)
  return (VecSing ht, ht)
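
-- In conventional notation, the cell above computes:
--
-- > z_t = sigmoid (W_z ∙ [h_{t-1}; x_t])
-- > r_t = sigmoid (W_r ∙ [h_{t-1}; x_t])
-- > h~_t = tanh (W ∙ [r_t ⊙ h_{t-1}; x_t])
-- > h_t = (1 - z_t) ⊙ h_{t-1} + z_t ⊙ h~_t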

----------------------------------------------
-- "Attention" layers


-- | An attention scoring function. This function should produce a
-- score for each of the @nValues@ entries of size @valueSize@ (the
-- scores are typically normalized with a softmax afterwards).
type AttentionScoring t batchSize keySize valueSize nValues = 
  Tensor '[keySize,batchSize] ('Typ 'Float t) -> Tensor '[nValues,valueSize,batchSize] ('Typ 'Float t) -> Tensor '[nValues,batchSize] ('Typ 'Float t)

-- | A function which attends to an external input. Typically a
-- function of this type is a closure which has the attended input in
-- its environment.
type AttentionFunction t batchSize keySize valueSize =
  T '[keySize,batchSize] (Flt t) -> Gen (T '[valueSize,batchSize] (Flt t))

{- NICER, SLOW

type AttentionScoring t batchSize keySize valueSize =
  Tensor '[keySize,batchSize] ('Typ 'Float t) -> Tensor '[valueSize,batchSize] ('Typ 'Float t) -> Tensor '[batchSize] ('Typ 'Float t)


-- | @attnExample1 θ h st@ combines each element of the vector h with
-- s, and applies a dense layer with parameters θ. The "winning"
-- element of h (using softmax) is returned.
uniformAttn :: ∀ valueSize m keySize batchSize t. KnownNat m => KnownBits t =>
               T '[batchSize] Int32 ->
               AttentionScoring t batchSize keySize valueSize ->
               T '[m,valueSize,batchSize] (Flt t) -> AttentionFunction t batchSize keySize valueSize
uniformAttn lengths score hs_ ht = do
  xx <- mapT (score ht) hs_
  let   αt :: T '[m,batchSize] (Flt t)
        αt = softmax0 (mask ⊙ xx)
        ct :: T '[valueSize,batchSize] (Flt t)
        ct = squeeze0 (matmul hs_ (expandDim0 αt))
        mask = cast (sequenceMask @m lengths) -- mask according to length
  return ct



-- | A multiplicative scoring function. See 
-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism
-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a
multiplicativeScoring :: forall valueSize keySize batchSize t.
  T [keySize,valueSize] ('Typ 'Float t) ->  AttentionScoring t batchSize keySize valueSize
multiplicativeScoring w dt h = h · ir
  where ir :: T '[valueSize,batchSize] ('Typ 'Float t)
        ir = w ∙ dt


additiveScoring :: AdditiveScoringP sz keySize valueSize t -> AttentionScoring t batchSize valueSize keySize
additiveScoring (AdditiveScoringP v w1 w2) dt h = squeeze0 (v ∙ tanh ((w1 ∙ h) ⊕ (w2 ∙ dt)))

-}

-- | @uniformAttn score lengths hs@ returns an attention function:
-- given a key, it scores each of the @m@ entries of @hs@ using
-- @score@, masks the scores according to @lengths@, and returns the
-- softmax-weighted sum of the values.
uniformAttn :: ∀ valueSize m keySize batchSize t. KnownNat m => KnownBits t
            => AttentionScoring t batchSize keySize valueSize m -- ^ scoring function
            -> T '[batchSize] Int32 -- ^ lengths of the inputs
            -> T '[m,valueSize,batchSize] (Flt t) -- ^ inputs
            -> AttentionFunction t batchSize keySize valueSize
uniformAttn score lengths hs_ ht = do
  let   αt :: T '[m,batchSize] (Flt t)
        xx = score ht hs_
        αt = softmax0 (mask ⊙ xx)
        ct :: T '[valueSize,batchSize] (Flt t)
        ct = squeeze0 (matmul hs_ (expandDim0 αt))
        mask = cast (sequenceMask @m lengths) -- mask according to length
  return ct
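
-- A sketch of how an attention function is assembled (assuming a
-- scoring weight matrix @w@, the attended values @hs@, and their
-- @lengths@ are in scope):
--
-- > attn = uniformAttn (multiplicativeScoring w) lengths hs
-- > -- attn :: AttentionFunction t batchSize keySize valueSize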

-- | Add some attention to an RnnCell, and feed the attention vector to
-- the next iteration in the rnn. (This follows the diagram at
-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism
-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a).
attentiveWithFeedback :: forall attSize cellSize inputSize bs w ss.
  AttentionFunction w bs cellSize attSize ->
  RnnCell w ss                      (T '[inputSize+attSize,bs] (Flt w)) (T '[cellSize,bs] (Flt w)) ->
  RnnCell w ('[attSize,bs] ': ss)   (T '[inputSize        ,bs] (Flt w)) (T '[attSize,bs] (Flt w))
attentiveWithFeedback attn cell ((F prevAttnVector :* s),x) = do
  (s',y) <- cell (s,concat0 x prevAttnVector)
  focus <- attn y
  return ((F focus :* s'),focus)

-- -- | LSTM for an attention model. The result of attention is fed to the next step.
-- attentiveLstm :: forall attSize n x bs t. KnownNat bs =>
--   AttentionFunction t bs n attSize ->
--   LSTMP t n (x+attSize) ->
--   RnnCell t '[ '[attSize,bs], '[n,bs], '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[attSize,bs] (Flt t))
-- attentiveLstm att w = attentiveWithFeedback att (lstm w)


-- | Luong attention function (following
-- https://github.com/tensorflow/nmt#background-on-the-attention-mechanism
-- commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a).
-- Essentially a dense layer with tanh activation, on top of uniform attention.
luongAttention :: ∀ attnSize d m e batchSize w. KnownNat m => KnownBits w
  => Tensor '[d+e,attnSize] (Flt w)     -- ^ weights for the dense layer
  -> AttentionScoring w batchSize e d m -- ^ scoring function
  -> Tensor '[batchSize] Int32          -- ^ length of the input
  -> T '[m,d,batchSize] (Flt w)         -- ^ inputs
  -> AttentionFunction w batchSize e attnSize
luongAttention w scoring lens hs_ ht = do
  ct <- uniformAttn scoring lens hs_ ht
  return (tanh (w ∙ (concat0 ct ht)))
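
-- Combining the pieces: a decoder cell with Luong attention over the
-- encoder outputs @hs@ can be sketched as follows (assuming weights
-- @w@ and @wScore@, lstm parameters @p@, and input @lengths@ in scope):
--
-- > decoderCell = attentiveWithFeedback
-- >                 (luongAttention w (multiplicativeScoring wScore) lengths hs)
-- >                 (lstm p)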

-- | Multiplicative scoring function
multiplicativeScoring :: forall valueSize keySize batchSize nValues t.
  KnownNat batchSize => T [keySize,valueSize] ('Typ 'Float t) -- ^ weights
  ->  AttentionScoring t batchSize keySize valueSize nValues
multiplicativeScoring w dt hs = squeeze1 (matmul (expandDim1 ir) hs)
  where ir :: T '[valueSize,batchSize] ('Typ 'Float t)
        ir = w ∙ dt


data AdditiveScoringP sz keySize valueSize t = AdditiveScoringP
  (Tensor '[sz, 1]         ('Typ 'Float t))
  (Tensor '[keySize, sz]   ('Typ 'Float t))
  (Tensor '[valueSize, sz] ('Typ 'Float t))

instance (KnownNat n, KnownNat k, KnownNat v, KnownBits t) => KnownTensors (AdditiveScoringP k v n t) where
  travTensor f s (AdditiveScoringP x y z) = AdditiveScoringP <$> travTensor f (s<>"_v") x <*> travTensor f (s<>"_w1") y <*> travTensor f (s<>"_w2") z
instance (KnownNat n, KnownNat k, KnownNat v, KnownBits t) => ParamWithDefault (AdditiveScoringP k v n t) where
  defaultInitializer = AdditiveScoringP glorotUniform glorotUniform glorotUniform

-- | An additive scoring function. See https://arxiv.org/pdf/1412.7449.pdf
additiveScoring :: forall sz keySize valueSize t nValues batchSize. KnownNat sz => KnownNat keySize => (KnownNat nValues, KnownNat batchSize) =>
  AdditiveScoringP sz keySize valueSize t -> AttentionScoring t batchSize valueSize keySize nValues
additiveScoring (AdditiveScoringP v w1 w2) dt h = transpose r''
  where w1h :: Tensor '[sz,batchSize, nValues] ('Typ 'Float t)
        w1h = transposeN01 @'[sz] (reshape @'[sz,nValues, batchSize] w1h')
        w1h' = matmul (reshape @'[keySize, nValues*batchSize] (transpose01 h)) (transpose01 w1)
        w2dt = w2 ∙ dt
        z' = reshape @'[sz,batchSize*nValues] (tanh (w1h + w2dt))
        r'' = reshape @[batchSize,nValues] (matmul z' (transpose v))

---------------------------------------------------------
-- RNN unfolding


-- | Build a RNN by repeating a cell @n@ times.
rnn :: ∀ n state input output t u b.
       (KnownNat n, KnownShape input, KnownShape output) =>
       RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u
rnn cell s0 t = do
  xs <- unstack0 t
  (sFin,us) <- chainForward cell (s0,xs)
  return (sFin,stack0 us)
-- There will be lots of stacking and unstacking at each layer for no
-- reason; we should change the in/out from tensors to vectors of
-- tensors.
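
-- A complete recurrent layer can thus be sketched as follows (assuming
-- LSTM parameters @p@ in scope; the initial state is typically zeros):
--
-- > lstmLayer = rnn (lstm p)
-- > -- run it: (finalState, outputs) <- lstmLayer (VecPair zeros zeros) inputs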

-- | Build a RNN by repeating a cell @n@ times. However the state is
-- propagated in the right-to-left direction (decreasing indices in
-- the time dimension of the input and output tensors)
rnnBackward :: ∀ n state input output t u b.
       (KnownNat n, KnownShape input, KnownShape output) =>
       RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u
rnnBackward cell s0 t = do
  xs <- unstack0 t
  (sFin,us) <- chainBackward cell (s0,xs)
  return (sFin,stack0 us)



-- | RNN helper
chainForward :: ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (state , V n b)
chainForward _ (s0 , V []) = return (s0 , V [])
chainForward f (s0 , V (x:xs)) = do
  (s1,x') <- f (s0 , x)
  (sFin,V xs') <- chainForward f (s1 , V xs)
  return (sFin,V (x':xs'))

-- | RNN helper
chainBackward :: ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (state , V n b)
chainBackward _ (s0 , V []) = return (s0 , V [])
chainBackward f (s0 , V (x:xs)) = do
  (s1,V xs') <- chainBackward f (s0,V xs)
  (sFin, x') <- f (s1,x)
  return (sFin,V (x':xs'))

-- | RNN helper
chainForwardWithState :: ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (V n b, V n state)
chainForwardWithState _ (_s0 , V []) = return (V [], V [])
chainForwardWithState f (s0 , V (x:xs)) = do
  (s1,x') <- f (s0 , x)
  (V xs',V ss) <- chainForwardWithState f (s1 , V xs)
  return (V (x':xs'), V (s1:ss) )

-- -- | RNN helper
-- chainBackwardWithState ::
--   ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (state , V n b, V n state)
-- chainBackwardWithState _ (s0 , V []) = return (s0 , V [], V [])
-- chainBackwardWithState f (s0 , V (x:xs)) = do
--   (s1,V xs',V ss') <- chainBackwardWithState f (s0,V xs)
--   (sFin, x') <- f (s1,x)
--   return (sFin,V (x':xs'),V (sFin:ss'))

-- | RNN helper
transposeV :: forall n xs t. All KnownLen xs =>
               SList xs -> V n (HTV (Flt t) xs) -> HTV (Flt t) (Ap (FMap (Cons n)) xs)
transposeV LZ _ = Unit
transposeV (LS _ n) xxs  = F ys' :* yys'
  where (ys,yys) = help @(Tail xs) xxs
        ys' = stack0 ys
        yys' = transposeV n yys

        help :: forall ys x tt. V n (HTV tt (x ': ys)) -> (V n (T x tt) , V n (HTV tt ys))
        help (V xs) = (V (map (fromF . hhead) xs),V (map htail xs))

-- | @(gatherFinalStates dynLen states)[i] = states[dynLen[i]]@
gatherFinalStates :: KnownLen x => KnownNat n => LastEqual bs x => T '[bs] Int32 -> T (n ': x) t -> T x t
gatherFinalStates dynLen states = nth0 0 (reverseSequences dynLen states)

-- a more efficient algorithm (perhaps:)
-- gatherFinalStates' :: forall x n bs t. KnownLen x => KnownNat n => LastEqual bs x => T '[bs] Int32 -> T (x ++ '[n,bs]) t -> T x (x ++ '[bs])
-- gatherFinalStates' (T dynLen)t = gather (flattenN2 @x @n @bs t) indexInFlat
--  where indexInFlat = (dynLen - 1) + tf.range(0, bs) * n

gathers :: forall n bs xs t. All (LastEqual bs) xs => All KnownLen xs => KnownNat n =>
            SList xs -> T '[bs] Int32 -> HTV (Flt t) (Ap (FMap (Cons n)) xs) -> HTV (Flt t) xs
gathers LZ _ Unit = Unit
gathers (LS _ n) ixs (F x :* xs) = F (gatherFinalStates ixs x) :* gathers @n n ixs xs

-- | @rnnWithCull dynLen@ constructs an RNN as normal, but returns the
-- state after step @dynLen@ only.
rnnWithCull :: forall n bs x y t u ls b.
  KnownLen ls => KnownNat n => KnownLen x  => KnownLen y => All KnownLen ls =>
  All (LastEqual bs) ls =>
  T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u
rnnWithCull dynLen cell s0 t = do
  xs <- unstack0 t
  (us,ss) <- chainForwardWithState cell (s0,xs)
  let sss = transposeV @n (shapeSList @ls) ss
  return (gathers @n (shapeSList @ls) dynLen sss,stack0 us)

-- | Like @rnnWithCull@, but states are threaded backwards.
rnnBackwardsWithCull :: forall n bs x y t u ls b.
  KnownLen ls => KnownNat n => KnownLen x  => KnownLen y => All KnownLen ls =>
  All (LastEqual bs) ls => LastEqual bs x => LastEqual bs y =>
  T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u
rnnBackwardsWithCull dynLen cell s0 t = do
  (sFin,hs) <- rnnWithCull dynLen cell s0 (reverseSequences dynLen t)
  hs' <- assign (reverseSequences dynLen hs)
  return (sFin, hs')