-- | RNN cells, layers, combinators, and attention mechanisms.
module TypedFlow.Layers.RNN (
  -- * Types
  RnnCell, RnnLayer,
  -- * Combinators
  stackRnnCells, (.-.),
  stackRnnLayers, (.--.),
  bothRnnLayers, (.++.),
  withBypass,
  onStates,
  timeDistribute, timeDistribute',
  -- * Cells
  cellInitializerBit,
  LSTMP(..),
  lstm,
  GRUP(..),
  gru,
  -- * RNN unfolding
  rnn,
  rnnBackward,
  rnnBackwardsWithCull,
  rnnWithCull,
  -- * Attention mechanisms
  AttentionScoring,
  multiplicativeScoring,
  AdditiveScoringP(..), additiveScoring,
  AttentionFunction,
  uniformAttn,
  luongAttention,
  attentiveWithFeedback
)
where
import Prelude hiding (tanh,Num(..),Floating(..),floor)
import GHC.TypeLits
import TypedFlow.TF
import TypedFlow.Types
import TypedFlow.Layers.Core (DenseP(..),(#))
import Data.Monoid ((<>))

-- | A cell in an RNN. @states@ is the list of shapes of the cell state
-- threaded through time.
type RnnCell t states input output = (HTV (Flt t) states , input) -> Gen (HTV (Flt t) states , output)

-- | A layer in an RNN. @n@ is the length of the time sequence; the
-- state is threaded through all time steps.
type RnnLayer b n state input t output u = HTV (Flt b) state -> Tensor (n ': input) t -> Gen (HTV (Flt b) state , Tensor (n ': output) u)
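
-- For instance, a stateless cell has an empty state list; 'timeDistribute'
-- (below) builds one from any pure function, e.g.:
--
-- > tanhCell :: RnnCell t '[] (T '[x,bs] (Flt t)) (T '[x,bs] (Flt t))
-- > tanhCell = timeDistribute tanh
--
-- and 'rnn' (below) promotes any cell to an 'RnnLayer' over @n@ steps.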

-- | Compose two RNN layers in sequence: the output sequence of the
-- first becomes the input sequence of the second.
(.--.),stackRnnLayers :: forall s1 s2 a t b u c v n bits. KnownLen s1 =>
  RnnLayer bits n s1 a t b u -> RnnLayer bits n s2 b u c v -> RnnLayer bits n (s1 ++ s2) a t c v
stackRnnLayers f g (hsplit @s1 -> (s0,s1)) x = do
(s0',y) <- f s0 x
(s1',z) <- g s1 y
return (happ s0' s1',z)
infixr .--.
(.--.) = stackRnnLayers
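
-- Example (sketch): a two-layer recurrent encoder, where p1 and p2 are
-- hypothetical LSTM parameters:
--
-- > encoder = rnn (lstm p1) .--. rnn (lstm p2)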

-- | Run two RNN layers in parallel on the same input sequence and
-- concatenate their outputs (useful to combine a forward and a
-- backward layer).
bothRnnLayers,(.++.) :: forall s1 s2 a t b u c n bs bits. KnownLen s1 =>
  RnnLayer bits n s1 a t '[b,bs] u -> RnnLayer bits n s2 a t '[c,bs] u -> RnnLayer bits n (s1 ++ s2) a t '[b+c,bs] u
bothRnnLayers f g (hsplit @s1 -> (s0,s1)) x = do
(s0',y) <- f s0 x
(s1',z) <- g s1 x
return (happ s0' s1',concat1 y z)
infixr .++.
(.++.) = bothRnnLayers
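
-- Example (sketch): a bidirectional layer concatenating the outputs of
-- a forward and a backward pass (pFwd and pBwd are hypothetical
-- parameters):
--
-- > biLayer = rnn (lstm pFwd) .++. rnnBackward (lstm pBwd)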

-- | Apply a function to the cell state(s) before running the cell itself.
onStates :: (HTV (Flt t) xs -> HTV (Flt t) xs) -> RnnCell t xs a b -> RnnCell t xs a b
onStates f cell (h,x) = do
cell (f h, x)

-- | Stack two RNN cells: the left-hand cell runs first and its output
-- feeds the right-hand cell.
stackRnnCells, (.-.) :: forall s0 s1 a b c t. KnownLen s0 => RnnCell t s0 a b -> RnnCell t s1 b c -> RnnCell t (s0 ++ s1) a c
stackRnnCells l1 l2 (hsplit @s0 -> (s0,s1),x) = do
(s0',y) <- l1 (s0,x)
(s1',z) <- l2 (s1,y)
return ((happ s0' s1'),z)
(.-.) = stackRnnCells
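
-- Example (sketch): two LSTM cells stacked into one deep cell (p1 and
-- p2 are hypothetical parameters; the output size of the first must
-- match the input size of the second):
--
-- > deepCell = lstm p1 .-. lstm p2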

-- | Run the cell and pass its input through to the output, by
-- concatenating the input with the cell's output.
withBypass :: RnnCell b s0 (T '[x,bs] t) (T '[y,bs] t) -> RnnCell b s0 (T '[x,bs] t) (T '[x+y,bs] t)
withBypass cell (s,x) = do
(s',y) <- cell (s,x)
return (s',concat0 x y)
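
-- Example (sketch): @withBypass (lstm p)@ yields a cell whose output is
-- the input concatenated with the LSTM output, of shape '[x+n,bs].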

-- | Convert a pure function (typically a feed-forward layer) into a
-- stateless RNN cell.
timeDistribute :: (a -> b) -> RnnCell t '[] a b
timeDistribute pureLayer = timeDistribute' (return . pureLayer)

-- | Convert an effectful function into a stateless RNN cell.
timeDistribute' :: (a -> Gen b) -> RnnCell t '[] a b
timeDistribute' stateLess (Unit,a) = do
b <- stateLess a
return (Unit,b)
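
-- Example (sketch): lifting a dense layer (applied with '#') into a
-- stateless cell, for a hypothetical parameter θ :: DenseP t m k:
--
-- > denseCell = timeDistribute (θ #)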

-- | Standard initializer for RNN cells: the recurrent kernel is random
-- orthogonal (a common choice to mitigate vanishing/exploding
-- recurrence), the input kernel is Glorot-uniform, and the bias is zero.
cellInitializerBit :: ∀ n x t. (KnownNat n, KnownNat x, KnownBits t) => DenseP t (n + x) n
cellInitializerBit = DenseP (concat0 recurrentInitializer kernelInitializer) biasInitializer
where
recurrentInitializer :: Tensor '[n, n] ('Typ 'Float t)
recurrentInitializer = randomOrthogonal
kernelInitializer :: Tensor '[x, n] ('Typ 'Float t)
kernelInitializer = glorotUniform
biasInitializer = zeros

-- | Parameters for an LSTM cell: weights for the forget, input, cell
-- (candidate) and output gates.
data LSTMP t n x = LSTMP (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n) (DenseP t (n+x) n)
instance (KnownNat n, KnownNat x, KnownBits t) => KnownTensors (LSTMP t n x) where
travTensor f s (LSTMP x y z w) = LSTMP <$> travTensor f (s<>"_f") x <*> travTensor f (s<>"_i") y <*> travTensor f (s<>"_c") z <*> travTensor f (s<>"_o") w
instance (KnownNat n, KnownNat x, KnownBits t) => ParamWithDefault (LSTMP t n x) where
  -- The forget gate's bias starts at ones (a common trick that prevents
  -- the cell from forgetting everything early in training); the other
  -- gates use 'cellInitializerBit'.
  defaultInitializer = LSTMP forgetInit cellInitializerBit cellInitializerBit cellInitializerBit
    where forgetInit = DenseP (denseWeights cellInitializerBit) ones

-- | Standard LSTM cell. The RNN state is the pair @(h, c)@ of hidden
-- and cell states, both of shape '[n,bs].
lstm :: ∀ n x bs t. LSTMP t n x ->
  RnnCell t '[ '[n,bs], '[n,bs]] (Tensor '[x,bs] (Flt t)) (Tensor '[n,bs] (Flt t))
lstm (LSTMP wf wi wc wo) (VecPair ht1 ct1, input) = do
hx <- assign (concat0 ht1 input)
let f = sigmoid (wf # hx)
i = sigmoid (wi # hx)
cTilda = tanh (wc # hx)
o = sigmoid (wo # hx)
c <- assign ((f ⊙ ct1) + (i ⊙ cTilda))
h <- assign (o ⊙ tanh c)
return (VecPair h c, h)
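
-- Example (sketch): constructing default LSTM parameters and unrolling
-- the cell, assuming a 'parameterDefault' helper that creates the
-- underlying variables from 'defaultInitializer':
--
-- > mkEncoder = do
-- >   params <- parameterDefault "lstm1"
-- >   return (rnn (lstm params))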

-- | Parameters for a GRU cell: update gate, reset gate, and candidate
-- weights.
data GRUP t n x = GRUP (T [n+x,n] ('Typ 'Float t)) (T [n+x,n] ('Typ 'Float t)) (T [n+x,n] ('Typ 'Float t))
instance (KnownNat n, KnownNat x, KnownBits t) => KnownTensors (GRUP t n x) where
travTensor f s (GRUP x y z) = GRUP <$> travTensor f (s<>"_z") x <*> travTensor f (s<>"_r") y <*> travTensor f (s<>"_w") z
instance (KnownNat n, KnownNat x, KnownBits t) => ParamWithDefault (GRUP t n x) where
defaultInitializer = GRUP (denseWeights cellInitializerBit) (denseWeights cellInitializerBit) (denseWeights cellInitializerBit)

-- | Standard GRU cell.
gru :: ∀ n x bs t. (KnownNat bs, KnownNat n, KnownBits t) => GRUP t n x ->
  RnnCell t '[ '[n,bs] ] (Tensor '[x,bs] (Flt t)) (Tensor '[n,bs] (Flt t))
gru (GRUP wz wr w) (VecSing ht1, xt) = do
hx <- assign (concat0 ht1 xt)
let zt = sigmoid (wz ∙ hx)
rt = sigmoid (wr ∙ hx)
hTilda = tanh (w ∙ (concat0 (rt ⊙ ht1) xt))
ht <- assign ((ones ⊝ zt) ⊙ ht1 + zt ⊙ hTilda)
return (VecSing ht, ht)

-- | An attention scoring function, producing one score for each of the
-- @nValues@ values, given a key.
type AttentionScoring t batchSize keySize valueSize nValues =
  Tensor '[keySize,batchSize] ('Typ 'Float t) -> Tensor '[nValues,valueSize,batchSize] ('Typ 'Float t) -> Tensor '[nValues,batchSize] ('Typ 'Float t)

-- | A function which attends to an external input. Typically a closure
-- capturing the attended values in its environment.
type AttentionFunction t batchSize keySize valueSize =
  T '[keySize,batchSize] (Flt t) -> Gen (T '[valueSize,batchSize] (Flt t))

-- | @uniformAttn score lengths values@ builds an attention function:
-- the key is scored against each of the values, the scores are masked
-- past each sequence's length and normalized with a softmax, and the
-- result is the weighted sum of the values.
uniformAttn :: ∀ valueSize m keySize batchSize t. KnownNat m => KnownBits t
  => AttentionScoring t batchSize keySize valueSize m -- ^ scoring function
  -> T '[batchSize] Int32                             -- ^ lengths of the sequences
  -> T '[m,valueSize,batchSize] (Flt t)               -- ^ values to attend to
  -> AttentionFunction t batchSize keySize valueSize
uniformAttn score lengths hs_ ht = do
let αt :: T '[m,batchSize] (Flt t)
xx = score ht hs_
αt = softmax0 (mask ⊙ xx)
ct :: T '[valueSize,batchSize] (Flt t)
ct = squeeze0 (matmul hs_ (expandDim0 αt))
mask = cast (sequenceMask @m lengths)
return ct
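
-- Example (sketch): an attention function over values hs with per-batch
-- lengths lens, using multiplicative scoring with a hypothetical weight
-- matrix w:
--
-- > attnFn = uniformAttn (multiplicativeScoring w) lens hs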

-- | Add attention to an RNN cell, and feed the attention vector back as
-- part of the cell's input at the next time step ("input feeding").
attentiveWithFeedback :: forall attSize cellSize inputSize bs w ss.
  AttentionFunction w bs cellSize attSize ->
  RnnCell w ss (T '[inputSize+attSize,bs] (Flt w)) (T '[cellSize,bs] (Flt w)) ->
  RnnCell w ('[attSize,bs] ': ss) (T '[inputSize,bs] (Flt w)) (T '[attSize,bs] (Flt w))
attentiveWithFeedback attn cell ((F prevAttnVector :* s),x) = do
(s',y) <- cell (s,concat0 x prevAttnVector)
focus <- attn y
return ((F focus :* s'),focus)
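
-- Example (sketch): an attentive decoder cell, for a hypothetical
-- attention function attnFn and LSTM parameters p (the LSTM input size
-- must be inputSize+attSize):
--
-- > decoderCell = attentiveWithFeedback attnFn (lstm p)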

-- | Luong-style attention: compute a context vector with 'uniformAttn'
-- and pass its concatenation with the key through a @tanh@ dense layer.
luongAttention :: ∀ attnSize d m e batchSize w. KnownNat m => KnownBits w
  => Tensor '[d+e,attnSize] (Flt w)      -- ^ weights of the output layer
  -> AttentionScoring w batchSize e d m  -- ^ scoring function
  -> Tensor '[batchSize] Int32           -- ^ lengths of the sequences
  -> T '[m,d,batchSize] (Flt w)          -- ^ values to attend to
  -> AttentionFunction w batchSize e attnSize
luongAttention w scoring lens hs_ ht = do
ct <- uniformAttn scoring lens hs_ ht
return (tanh (w ∙ (concat0 ct ht)))
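
-- Example (sketch): Luong attention over values hs, with multiplicative
-- scoring (w and w' are hypothetical weight matrices):
--
-- > attnFn = luongAttention w (multiplicativeScoring w') lens hs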

-- | Multiplicative scoring: the key is first transformed by the weight
-- matrix and then dotted with each value.
multiplicativeScoring :: forall valueSize keySize batchSize nValues t.
  KnownNat batchSize => T [keySize,valueSize] ('Typ 'Float t) -- ^ weights
  -> AttentionScoring t batchSize keySize valueSize nValues
multiplicativeScoring w dt hs = squeeze1 (matmul (expandDim1 ir) hs)
where ir :: T '[valueSize,batchSize] ('Typ 'Float t)
ir = w ∙ dt

-- | Parameters for additive attention scoring: @v@, @w1@ and @w2@ in
-- @score = vᵀ · tanh (w1·h + w2·key)@.
data AdditiveScoringP sz keySize valueSize t = AdditiveScoringP
  (Tensor '[sz, 1] ('Typ 'Float t))
  (Tensor '[keySize, sz] ('Typ 'Float t))
  (Tensor '[valueSize, sz] ('Typ 'Float t))
instance (KnownNat n, KnownNat k, KnownNat v, KnownBits t) => KnownTensors (AdditiveScoringP k v n t) where
travTensor f s (AdditiveScoringP x y z) = AdditiveScoringP <$> travTensor f (s<>"_v") x <*> travTensor f (s<>"_w1") y <*> travTensor f (s<>"_w2") z
instance (KnownNat n, KnownNat k, KnownNat v, KnownBits t) => ParamWithDefault (AdditiveScoringP k v n t) where
defaultInitializer = AdditiveScoringP glorotUniform glorotUniform glorotUniform

-- | Additive scoring function: @score = vᵀ · tanh (w1·h + w2·key)@.
additiveScoring :: forall sz keySize valueSize t nValues batchSize. KnownNat sz => KnownNat keySize => (KnownNat nValues, KnownNat batchSize) =>
  AdditiveScoringP sz keySize valueSize t -> AttentionScoring t batchSize valueSize keySize nValues
additiveScoring (AdditiveScoringP v w1 w2) dt h = transpose r''
where w1h :: Tensor '[sz,batchSize, nValues] ('Typ 'Float t)
w1h = transposeN01 @'[sz] (reshape @'[sz,nValues, batchSize] w1h')
w1h' = matmul (reshape @'[keySize, nValues*batchSize] (transpose01 h)) (transpose01 w1)
w2dt = w2 ∙ dt
z' = reshape @'[sz,batchSize*nValues] (tanh (w1h + w2dt))
r'' = reshape @[batchSize,nValues] (matmul z' (transpose v))

-- | Build an RNN layer by unrolling a cell over @n@ time steps,
-- threading the state from left to right.
rnn :: ∀ n state input output t u b.
  (KnownNat n, KnownShape input, KnownShape output) =>
  RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u
rnn cell s0 t = do
xs <- unstack0 t
(sFin,us) <- chainForward cell (s0,xs)
return (sFin,stack0 us)
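
-- Example (sketch): a complete LSTM layer with zero initial state; the
-- LSTM state is the pair of hidden and cell states:
--
-- > runLstm params xs = rnn (lstm params) (VecPair zeros zeros) xs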

-- | Like 'rnn', but with the state threaded from right to left (the
-- time dimension is processed backwards).
rnnBackward :: ∀ n state input output t u b.
  (KnownNat n, KnownShape input, KnownShape output) =>
  RnnCell b state (T input t) (T output u) -> RnnLayer b n state input t output u
rnnBackward cell s0 t = do
xs <- unstack0 t
(sFin,us) <- chainBackward cell (s0,xs)
return (sFin,stack0 us)

-- | Unrolling helper: thread the state forwards through a vector of
-- inputs.
chainForward :: ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (state , V n b)
chainForward _ (s0 , V []) = return (s0 , V [])
chainForward f (s0 , V (x:xs)) = do
(s1,x') <- f (s0 , x)
(sFin,V xs') <- chainForward f (s1 , V xs)
return (sFin,V (x':xs'))

-- | Unrolling helper: thread the state backwards, from the last input
-- to the first.
chainBackward :: ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (state , V n b)
chainBackward _ (s0 , V []) = return (s0 , V [])
chainBackward f (s0 , V (x:xs)) = do
(s1,V xs') <- chainBackward f (s0,V xs)
(sFin, x') <- f (s1,x)
return (sFin,V (x':xs'))

-- | Like 'chainForward', but also return the state after each step.
chainForwardWithState :: ∀ state a b n. ((state , a) -> Gen (state , b)) → (state , V n a) -> Gen (V n b, V n state)
chainForwardWithState _ (_s0 , V []) = return (V [], V [])
chainForwardWithState f (s0 , V (x:xs)) = do
(s1,x') <- f (s0 , x)
(V xs',V ss) <- chainForwardWithState f (s1 , V xs)
return (V (x':xs'), V (s1:ss) )

-- | Turn a time-indexed vector of state tuples into a tuple of
-- time-stacked state tensors.
transposeV :: forall n xs t. All KnownLen xs =>
  SList xs -> V n (HTV (Flt t) xs) -> HTV (Flt t) (Ap (FMap (Cons n)) xs)
transposeV LZ _ = Unit
transposeV (LS _ n) xxs = F ys' :* yys'
where (ys,yys) = help @(Tail xs) xxs
ys' = stack0 ys
yys' = transposeV n yys
help :: forall ys x tt. V n (HTV tt (x ': ys)) -> (V n (T x tt) , V n (HTV tt ys))
help (V xs) = (V (map (fromF . hhead) xs),V (map htail xs))

-- | @gatherFinalStates dynLen states@ selects, along the time
-- dimension, the state at position @dynLen - 1@ for each batch element.
gatherFinalStates :: KnownLen x => KnownNat n => LastEqual bs x => T '[bs] Int32 -> T (n ': x) t -> T x t
gatherFinalStates dynLen states = nth0 0 (reverseSequences dynLen states)

-- | Apply 'gatherFinalStates' to every tensor in a heterogeneous list
-- of states.
gathers :: forall n bs xs t. All (LastEqual bs) xs => All KnownLen xs => KnownNat n =>
  SList xs -> T '[bs] Int32 -> HTV (Flt t) (Ap (FMap (Cons n)) xs) -> HTV (Flt t) xs
gathers LZ _ Unit = Unit
gathers (LS _ n) ixs (F x :* xs) = F (gatherFinalStates ixs x) :* gathers @n n ixs xs

-- | @rnnWithCull dynLen@ constructs an RNN as usual, but returns, for
-- each batch element, the state at its true final step @dynLen@ rather
-- than the state after the full unrolling.
rnnWithCull :: forall n bs x y t u ls b.
  KnownLen ls => KnownNat n => KnownLen x => KnownLen y => All KnownLen ls =>
  All (LastEqual bs) ls =>
  T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u
rnnWithCull dynLen cell s0 t = do
xs <- unstack0 t
(us,ss) <- chainForwardWithState cell (s0,xs)
let sss = transposeV @n (shapeSList @ls) ss
return (gathers @n (shapeSList @ls) dynLen sss,stack0 us)
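
-- Example (sketch): encoding variable-length sequences, keeping each
-- batch element's state at its actual final step:
--
-- > (finalStates, outputs) <- rnnWithCull lens (lstm params) s0 xs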

-- | Like 'rnnWithCull', but processing the sequence backwards: the
-- input is reversed up to each sequence's length, run forwards, and the
-- outputs are reversed again.
rnnBackwardsWithCull :: forall n bs x y t u ls b.
  KnownLen ls => KnownNat n => KnownLen x => KnownLen y => All KnownLen ls =>
  All (LastEqual bs) ls => LastEqual bs x => LastEqual bs y =>
  T '[bs] Int32 -> RnnCell b ls (T x t) (T y u) -> RnnLayer b n ls x t y u
rnnBackwardsWithCull dynLen cell s0 t = do
(sFin,hs) <- rnnWithCull dynLen cell s0 (reverseSequences dynLen t)
hs' <- assign (reverseSequences dynLen hs)
return (sFin, hs')