Copyright | (c) Jean-Philippe Bernardy 2017
---|---
License | LGPL-3
Maintainer | jean-philippe.bernardy@gu.se
Stability | experimental
Safe Haskell | None
Language | Haskell2010
- type RnnCell t states input output = (HTV (Flt t) states, input) -> Gen (HTV (Flt t) states, output)
- type RnnLayer b n state input t output u = HTV (Flt b) state -> Tensor (n ': input) t -> Gen (HTV (Flt b) state, Tensor (n ': output) u)
- stackRnnCells :: forall s0 s1 a b c t. KnownLen s0 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s0 ++ s1) a c
- (.-.) :: forall s0 s1 a b c t. KnownLen s0 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s0 ++ s1) a c
- stackRnnLayers :: forall s1 s2 a t b u c v n bits. KnownLen s1 => RnnLayer bits n s1 a t b u -> RnnLayer bits n s2 b u c v -> RnnLayer bits n (s1 ++ s2) a t c v
- (.--.) :: forall s1 s2 a t b u c v n bits. KnownLen s1 => RnnLayer bits n s1 a t b u -> RnnLayer bits n s2 b u c v -> RnnLayer bits n (s1 ++ s2) a t c v
- bothRnnLayers :: forall s1 s2 a t b u c n bs bits. KnownLen s1 => RnnLayer bits n s1 a t '[b, bs] u -> RnnLayer bits n s2 a t '[c, bs] u -> RnnLayer bits n (s1 ++ s2) a t '[b + c, bs] u
- (.++.) :: forall s1 s2 a t b u c n bs bits. KnownLen s1 => RnnLayer bits n s1 a t '[b, bs] u -> RnnLayer bits n s2 a t '[c, bs] u -> RnnLayer bits n (s1 ++ s2) a t '[b + c, bs] u
- withBypass :: RnnCell b s0 (T '[x, bs] t) (T '[y, bs] t) -> RnnCell b s0 (T '[x, bs] t) (T '[x + y, bs] t)
- onStates :: (HTV (Flt t) xs -> HTV (Flt t) xs) -> RnnCell t xs a b -> RnnCell t xs a b
- timeDistribute :: (a -> b) -> RnnCell t '[] a b
- timeDistribute' :: (a -> Gen b) -> RnnCell t '[] a b
- cellInitializerBit :: forall n x t. (KnownNat n, KnownNat x, KnownBits t) => DenseP t (n + x) n
- data LSTMP t n x = LSTMP (DenseP t (n + x) n) (DenseP t (n + x) n) (DenseP t (n + x) n) (DenseP t (n + x) n)
- lstm :: forall n x bs t. LSTMP t n x -> RnnCell t '['[n, bs], '[n, bs]] (Tensor '[x, bs] (Flt t)) (Tensor '[n, bs] (Flt t))
- data GRUP t n x = GRUP (T '[n + x, n] (Typ Float t)) (T '[n + x, n] (Typ Float t)) (T '[n + x, n] (Typ Float t))
- gru :: forall n x bs t. (KnownNat bs, KnownNat n, KnownBits t) => GRUP t n x -> RnnCell t '['[n, bs]] (Tensor '[x, bs] (Flt t)) (Tensor '[n, bs] (Flt t))
- rnn :: forall n state input output t u b. (KnownNat n, KnownShape input, KnownShape output) => RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u
- rnnBackward :: forall n state input output t u b. (KnownNat n, KnownShape input, KnownShape output) => RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u
- rnnBackwardsWithCull :: forall n bs x y t u ls b. KnownLen ls => KnownNat n => KnownLen x => KnownLen y => All KnownLen ls => All (LastEqual bs) ls => LastEqual bs x => LastEqual bs y => T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u
- rnnWithCull :: forall n bs x y t u ls b. KnownLen ls => KnownNat n => KnownLen x => KnownLen y => All KnownLen ls => All (LastEqual bs) ls => T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u
- type AttentionScoring t batchSize keySize valueSize nValues = Tensor '[keySize, batchSize] (Typ Float t) -> Tensor '[nValues, valueSize, batchSize] (Typ Float t) -> Tensor '[nValues, batchSize] (Typ Float t)
- multiplicativeScoring :: forall valueSize keySize batchSize nValues t. KnownNat batchSize => T '[keySize, valueSize] (Typ Float t) -> AttentionScoring t batchSize keySize valueSize nValues
- data AdditiveScoringP sz keySize valueSize t = AdditiveScoringP (Tensor '[sz, 1] (Typ Float t)) (Tensor '[keySize, sz] (Typ Float t)) (Tensor '[valueSize, sz] (Typ Float t))
- additiveScoring :: forall sz keySize valueSize t nValues batchSize. KnownNat sz => KnownNat keySize => (KnownNat nValues, KnownNat batchSize) => AdditiveScoringP sz keySize valueSize t -> AttentionScoring t batchSize valueSize keySize nValues
- type AttentionFunction t batchSize keySize valueSize = T '[keySize, batchSize] (Flt t) -> Gen (T '[valueSize, batchSize] (Flt t))
- uniformAttn :: forall valueSize m keySize batchSize t. KnownNat m => KnownBits t => AttentionScoring t batchSize keySize valueSize m -> T '[batchSize] Int32 -> T '[m, valueSize, batchSize] (Flt t) -> AttentionFunction t batchSize keySize valueSize
- luongAttention :: forall attnSize d m e batchSize w. KnownNat m => KnownBits w => Tensor '[d + e, attnSize] (Flt w) -> AttentionScoring w batchSize e d m -> Tensor '[batchSize] Int32 -> T '[m, d, batchSize] (Flt w) -> AttentionFunction w batchSize e attnSize
- attentiveWithFeedback :: forall attSize cellSize inputSize bs w ss. AttentionFunction w bs cellSize attSize -> RnnCell w ss (T '[inputSize + attSize, bs] (Flt w)) (T '[cellSize, bs] (Flt w)) -> RnnCell w ('[attSize, bs] ': ss) (T '[inputSize, bs] (Flt w)) (T '[attSize, bs] (Flt w))
Types
type RnnCell t states input output = (HTV (Flt t) states, input) -> Gen (HTV (Flt t) states, output) Source #
A cell in an RNN. states is the state propagated through time.
type RnnLayer b n state input t output u = HTV (Flt b) state -> Tensor (n ': input) t -> Gen (HTV (Flt b) state, Tensor (n ': output) u) Source #
A layer in an RNN. n is the length of the time sequence; state is the state propagated through time.
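To make the types concrete, here is a minimal sketch (not part of the library) of a cell written directly against RnnCell: it carries no recurrent state and simply applies a per-time-step function to its input. The function f is a hypothetical placeholder, and the sketch assumes that Gen has a Monad instance, as the Gen-returning combinators in this module suggest. This is exactly what timeDistribute (below) packages up.

```haskell
-- Sketch only: a cell with an empty state list '[] that applies a
-- hypothetical pure function `f` to each time step's input.
-- Equivalent to `timeDistribute f` from this module.
statelessCell :: (a -> b) -> RnnCell t '[] a b
statelessCell f (states, x) = return (states, f x)
```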
Combinators
stackRnnCells :: forall s0 s1 a b c t. KnownLen s0 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s0 ++ s1) a c Source #
Stack two RNN cells (LHS is run first)
(.-.) :: forall s0 s1 a b c t. KnownLen s0 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s0 ++ s1) a c Source #
Stack two RNN cells (LHS is run first)
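For example (a sketch, not from the library's documentation), two LSTM cells (documented below) can be stacked with (.-.): the first maps the x-sized input to an n-sized hidden vector, which the second then consumes. The parameter values p1 and p2 are hypothetical and assumed to be created elsewhere with the library's parameter machinery.

```haskell
-- Sketch: a two-layer LSTM cell. The composite state list is the
-- concatenation of the two cells' state lists (two hidden/cell-state pairs).
--   p1 :: LSTMP t n x   (first layer:  x-sized input, n-sized hidden state)
--   p2 :: LSTMP t m n   (second layer: n-sized input, m-sized hidden state)
twoLayerLstm p1 p2 = lstm p1 .-. lstm p2
```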
stackRnnLayers :: forall s1 s2 a t b u c v n bits. KnownLen s1 => RnnLayer bits n s1 a t b u -> RnnLayer bits n s2 b u c v -> RnnLayer bits n (s1 ++ s2) a t c v Source #
Compose two RNN layers. This is useful, for example, to combine forward and backward layers.
(.--.) :: forall s1 s2 a t b u c v n bits. KnownLen s1 => RnnLayer bits n s1 a t b u -> RnnLayer bits n s2 b u c v -> RnnLayer bits n (s1 ++ s2) a t c v infixr 9 Source #
Compose two RNN layers. This is useful, for example, to combine forward and backward layers.
bothRnnLayers :: forall s1 s2 a t b u c n bs bits. KnownLen s1 => RnnLayer bits n s1 a t '[b, bs] u -> RnnLayer bits n s2 a t '[c, bs] u -> RnnLayer bits n (s1 ++ s2) a t '[b + c, bs] u Source #
Compose two RNN layers in parallel: both layers receive the same input, and their outputs are concatenated.
(.++.) :: forall s1 s2 a t b u c n bs bits. KnownLen s1 => RnnLayer bits n s1 a t '[b, bs] u -> RnnLayer bits n s2 a t '[c, bs] u -> RnnLayer bits n (s1 ++ s2) a t '[b + c, bs] u infixr 9 Source #
Compose two RNN layers in parallel: both layers receive the same input, and their outputs are concatenated.
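For instance (a sketch, not from the library's documentation), a bidirectional layer can be built by unrolling one cell forwards and another backwards over the same input and concatenating their per-time-step outputs; rnn and rnnBackward are documented below, and pFwd, pBwd are hypothetical LSTM parameters of matching sizes.

```haskell
-- Sketch: a bidirectional LSTM layer. Both directions read the same input
-- sequence; their n-sized outputs are concatenated into an (n + n)-sized
-- output at every time step. `pFwd` and `pBwd` are hypothetical LSTMP values.
biLstmLayer pFwd pBwd = rnn (lstm pFwd) .++. rnnBackward (lstm pBwd)
```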
withBypass :: RnnCell b s0 (T '[x, bs] t) (T '[y, bs] t) -> RnnCell b s0 (T '[x, bs] t) (T '[x + y, bs] t) Source #
Run the cell, and additionally forward its input to the output by concatenating the input with the cell's output.
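As a small sketch (not from the library's documentation), wrapping an LSTM cell with withBypass yields a cell whose output carries both the original x-sized input and the n-sized LSTM output, i.e. a tensor of shape '[x + n, bs]; this is convenient for residual-style connections. The parameter p is hypothetical.

```haskell
-- Sketch: an LSTM cell whose input is forwarded alongside its output.
--   input:  T '[x, bs] (Flt t)
--   output: T '[x + n, bs] (Flt t)  -- input concatenated with the LSTM output
bypassedLstm p = withBypass (lstm p)
```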
onStates :: (HTV (Flt t) xs -> HTV (Flt t) xs) -> RnnCell t xs a b -> RnnCell t xs a b Source #
Apply a function on the cell state(s) before running the cell itself.
timeDistribute :: (a -> b) -> RnnCell t '[] a b Source #
Convert a pure function (feed-forward layer) to an RNN cell by ignoring the RNN state.
timeDistribute' :: (a -> Gen b) -> RnnCell t '[] a b Source #
Convert a stateless generator into an RNN cell by ignoring the RNN state.
RNN Cells
cellInitializerBit :: forall n x t. (KnownNat n, KnownNat x, KnownBits t) => DenseP t (n + x) n Source #
Standard RNN gate initializer. (The recurrent kernel is orthogonal to avoid divergence; the input kernel uses Glorot initialization.)
data LSTMP t n x Source #
Parameters for an LSTM cell (four gates, each a dense layer of type DenseP t (n + x) n).
lstm :: forall n x bs t. LSTMP t n x -> RnnCell t '['[n, bs], '[n, bs]] (Tensor '[x, bs] (Flt t)) (Tensor '[n, bs] (Flt t)) Source #
Standard LSTM
data GRUP t n x Source #
Parameters for a GRU cell (three weight matrices of shape '[n + x, n]).
gru :: forall n x bs t. (KnownNat bs, KnownNat n, KnownBits t) => GRUP t n x -> RnnCell t '['[n, bs]] (Tensor '[x, bs] (Flt t)) (Tensor '[n, bs] (Flt t)) Source #
Standard GRU cell
RNN unfolding functions
rnn :: forall n state input output t u b. (KnownNat n, KnownShape input, KnownShape output) => RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u Source #
Build an RNN by repeating a cell n times.
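As an illustration (a sketch, not from the library's documentation), the definition below unrolls an LSTM cell into a layer; the length of the time dimension and the hypothetical LSTM parameter p are assumed to be fixed by the surrounding code.

```haskell
-- Sketch: turn an LSTM cell into a full RNN layer. Given
--   lstm p :: RnnCell t '[ '[n,bs], '[n,bs]] (Tensor '[x,bs] (Flt t)) (Tensor '[n,bs] (Flt t))
-- the layer maps an input sequence of shape (len ': '[x,bs]) to an output
-- sequence of shape (len ': '[n,bs]), threading the state through time.
lstmLayer p = rnn (lstm p)
```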
rnnBackward :: forall n state input output t u b. (KnownNat n, KnownShape input, KnownShape output) => RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u Source #
Build an RNN by repeating a cell n times. However, the state is propagated in the right-to-left direction (decreasing indices in the time dimension of the input and output tensors).
rnnBackwardsWithCull :: forall n bs x y t u ls b. KnownLen ls => KnownNat n => KnownLen x => KnownLen y => All KnownLen ls => All (LastEqual bs) ls => LastEqual bs x => LastEqual bs y => T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u Source #
Like rnnWithCull, but states are threaded backwards.
rnnWithCull :: forall n bs x y t u ls b. KnownLen ls => KnownNat n => KnownLen x => KnownLen y => All KnownLen ls => All (LastEqual bs) ls => T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u Source #
rnnWithCull dynLen constructs an RNN as normal, but returns the state after step dynLen only.
Attention mechanisms
Scoring functions
type AttentionScoring t batchSize keySize valueSize nValues = Tensor '[keySize, batchSize] (Typ Float t) -> Tensor '[nValues, valueSize, batchSize] (Typ Float t) -> Tensor '[nValues, batchSize] (Typ Float t) Source #
An attention scoring function. This function should produce a score (between 0 and 1) for each of the nValues entries of size valueSize.
multiplicativeScoring Source #
:: KnownNat batchSize
=> T '[keySize, valueSize] (Typ Float t) | weights
-> AttentionScoring t batchSize keySize valueSize nValues
Multiplicative scoring function
data AdditiveScoringP sz keySize valueSize t Source #
Constructors
AdditiveScoringP (Tensor '[sz, 1] (Typ Float t)) (Tensor '[keySize, sz] (Typ Float t)) (Tensor '[valueSize, sz] (Typ Float t))
Instances
(KnownNat n, KnownNat k, KnownNat v, KnownBits t) => ParamWithDefault (AdditiveScoringP k v n t) Source #
(KnownNat n, KnownNat k, KnownNat v, KnownBits t) => KnownTensors (AdditiveScoringP k v n t) Source #
additiveScoring :: forall sz keySize valueSize t nValues batchSize. KnownNat sz => KnownNat keySize => (KnownNat nValues, KnownNat batchSize) => AdditiveScoringP sz keySize valueSize t -> AttentionScoring t batchSize valueSize keySize nValues Source #
An additive scoring function. See https://arxiv.org/pdf/1412.7449.pdf
Attention functions
type AttentionFunction t batchSize keySize valueSize = T '[keySize, batchSize] (Flt t) -> Gen (T '[valueSize, batchSize] (Flt t)) Source #
A function which attends to an external input. Typically a function of this type is a closure which has the attended input in its environment.
uniformAttn Source #

:: KnownNat m
=> KnownBits t
=> AttentionScoring t batchSize keySize valueSize m | scoring function
-> T '[batchSize] Int32 | lengths of the inputs
-> T '[m, valueSize, batchSize] (Flt t) | inputs
-> AttentionFunction t batchSize keySize valueSize
uniformAttn score lens h builds an attention function: each element of the inputs h is scored against the current state st using the scoring function score, and the "winning" element of h (using softmax) is returned.
luongAttention Source #

:: KnownNat m
=> KnownBits w
=> Tensor '[d + e, attnSize] (Flt w) | weights for the dense layer
-> AttentionScoring w batchSize e d m | scoring function
-> Tensor '[batchSize] Int32 | lengths of the inputs
-> T '[m, d, batchSize] (Flt w) | inputs
-> AttentionFunction w batchSize e attnSize
Luong attention function (following https://github.com/tensorflow/nmt#background-on-the-attention-mechanism commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a). Essentially a dense layer with tanh activation, on top of uniform attention.
Attention combinators
attentiveWithFeedback :: forall attSize cellSize inputSize bs w ss. AttentionFunction w bs cellSize attSize -> RnnCell w ss (T '[inputSize + attSize, bs] (Flt w)) (T '[cellSize, bs] (Flt w)) -> RnnCell w ('[attSize, bs] ': ss) (T '[inputSize, bs] (Flt w)) (T '[attSize, bs] (Flt w)) Source #
Add some attention to an RnnCell, and feed the attention vector to the next iteration in the RNN. (This follows the diagram at https://github.com/tensorflow/nmt#background-on-the-attention-mechanism commit 75aa22dfb159f10a1a5b4557777d9ff547c1975a.)
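Putting the pieces together, here is a sketch (not from the library's documentation) of an attentive decoder cell built from the functions documented above: a multiplicative scoring function over the encoder outputs, wrapped by uniformAttn into an attention function and attached with feedback to an LSTM cell. The bindings w (scoring weights), lens (encoder sequence lengths), hs (encoder outputs) and p (LSTM parameters) are hypothetical; note that the LSTM's input size must be inputSize + attSize so that the fed-back attention vector can be concatenated with the ordinary input.

```haskell
-- Sketch: an attention-augmented LSTM decoder cell.
--   w    :: T '[cellSize, attSize] (Typ Float t)    -- scoring weights (hypothetical)
--   lens :: T '[batchSize] Int32                    -- encoder sequence lengths
--   hs   :: T '[m, attSize, batchSize] (Flt t)      -- encoder outputs attended over
--   p    :: LSTMP t cellSize (inputSize + attSize)  -- LSTM parameters
attentiveDecoderCell w lens hs p =
  attentiveWithFeedback (uniformAttn (multiplicativeScoring w) lens hs) (lstm p)
```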