分享

LSTM前向传播与反向传播算法推导

 印度阿三17 2019-09-21

1.长短期记忆网络LSTM

LSTM(Long short-term memory)通过刻意的设计来避免长期依赖问题,是一种特殊的RNN。长时间记住信息实际上是 LSTM 的默认行为,而不是需要努力学习的东西!

所有递归神经网络都具有神经网络的链式重复模块。在标准的RNN中,这个重复模块具有非常简单的结构,例如只有单个tanh层,如下图所示。
[外链图片转存失败(img-EwKxtSFp-1569051242265)(./images/lstm-rnn.jpg)]
LSTM具有同样的结构,但是重复的模块拥有不同的结构,如下图所示。与RNN的单一神经网络层不同,这里有四个网络层,并且以一种非常特殊的方式进行交互。

1.1 LSTM–遗忘门在这里插入图片描述

LSTM 的第一步要决定从细胞状态中舍弃哪些信息。这一决定由所谓“遗忘门层”的 S 形网络层做出。它接收ht−1h_{t-1}ht−1 和xtx_txt,并且对细胞状态Ct−1C_{t−1}Ct−1 中的每一个数来说输出值都介于 0 和 1 之间。1 表示“完全接受这个”,0 表示“完全忽略这个”。

1.2 LSTM–输入门在这里插入图片描述

下一步就是要确定需要在细胞状态中保存哪些新信息。这里分成两部分。第一部分,一个所谓“输入门层”的 S 形网络层确定哪些信息需要更新。第二部分,一个 tanh 形网络层创建一个新的备选值向量——C~t\tilde{C}_tC~t,可以用来添加到细胞状态。在下一步中我们将上面的两部分结合起来,产生对状态的更新。

1.3 LSTM–细胞状态更新在这里插入图片描述

现在更新旧的细胞状态Ct−1C_{t−1}Ct−1 更新到CtC_tCt。先前的步骤已经决定要做什么,我们只需要照做就好。
我们对旧的状态乘以ftf_tft,用来忘记我们决定忘记的事。然后我们加上it⊙C~ti_t\odot\tilde{C}_tit⊙C~t,这是新的候选值,根据我们对每个状态决定的更新值按比例进行缩放。

1.4 LSTM–输出门在这里插入图片描述

最后,我们需要确定输出值。输出依赖于我们的细胞状态,但会是一个“过滤的”版本。首先我们运行 S 形网络层,用来确定细胞状态中的哪些部分可以输出。然后,我们把细胞状态输入 tanh(把数值调整到 −1 和 1 之间)再和 S 形网络层的输出值相乘,部这样我们就可以输出想要输出的分。

1.5 LSTM的变种

目前我所描述的还只是一个相当一般化的 LSTM 网络。但并非所有 LSTM 网络都和之前描述的一样。事实上,几乎所有文章都会改进 LSTM 网络得到一个特定版本。差别是次要的,但有必要认识一下这些变种。

(1) 一个流行的 LSTM 变种由 Gers 和 Schmidhuber 提出,在 LSTM 的基础上添加了一个“窥视孔连接”,这意味着我们可以让门网络层输入细胞状态。
[外链图片转存失败(img-MPofz7mK-1569051242268)(./images/lstm-5.jpg)]
上图中我们为所有门添加窥视孔,但许多论文只为部分门添加.

(2)另一个变种把遗忘和输入门结合起来。同时确定要遗忘的信息和要添加的新信息,而不再是分开确定。当输入的时候才会遗忘,当遗忘旧信息的时候才会输入新数据。
[外链图片转存失败(img-1wlmhlj8-1569051242268)(./images/lstm-6.jpg)]
(3)一个更有意思的 LSTM 变种称为 Gated Recurrent Unit(GRU),由 Cho 等人提出。GRU 把遗忘门和输入门合并成为一个“更新门”,把细胞状态和隐含状态合并,还有其他变化。这样做使得 GRU 比标准的 LSTM 模型更简单,因此正在变得流行起来。
[外链图片转存失败(img-KGPO73y8-1569051242268)(./images/lstm-7.jpg)]

2.LSTM前向传播与反向传播

本小节只推导添加“窥视孔连接”的变种LSTM,如下图所示,其它LSTM变种的推导方法与该方法类似,这里不做过多介绍。对反向传播算法了解不够透彻的,请参考https://zhuanlan.zhihu.com/p/79657669 ,这里有详细的推导过程,本文将直接使用https://zhuanlan.zhihu.com/p/79657669的结论。

为了更直观的推导反向传播算法,将其转化为右图所示形式。

2.1 LSTM前向传播

LSTM在t时刻的前向传播公式为:
{it=σ(i~t)=σ(WxixtWhiht−1Wcict−1bi)ft=σ(f~t)=σ(WxfxtWhfht−1Wcfct−1bf)gt=tanh⁡(g~t)=tanh⁡(WxgxtWhght−1bg)ot=σ(o~t)=σ(WxoxtWhoht−1Wcoctbo)ct=ct−1⊙ftgt⊙itmt=tanh⁡(ct)ht=ot⊙mtyt=Wyhhtby\left\{ \begin{array}{l} {i_t=\sigma(\tilde{i}_t)=\sigma(W_{xi}x_t W_{hi}h_{t-1} W_{ci}c_{t-1} b_i)} \ {f_t=\sigma(\tilde{f}_t)=\sigma(W_{xf}x_t W_{hf}h_{t-1} W_{cf}c_{t-1} b_f) }\ {g_t=\tanh(\tilde{g}_t)=\tanh(W_{xg}x_t W_{hg}h_{t-1} b_g)} \ {o_t=\sigma(\tilde{o}_t)=\sigma(W_{xo}x_t W_{ho}h_{t-1} W_{co}c_{t} b_o) }\ {c_t=c_{t-1}\odot f_t g_t\odot i_t}\ {m_t=\tanh(c_t)}\ {h_t=o_t\odot m_t}\ {y_t=W_{yh}h_t b_y} \end{array}\right.⎩⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎧it=σ(i~t)=σ(Wxixt Whiht−1 Wcict−1 bi)ft=σ(f~t)=σ(Wxfxt Whfht−1 Wcfct−1 bf)gt=tanh(g~t)=tanh(Wxgxt Whght−1 bg)ot=σ(o~t)=σ(Wxoxt Whoht−1 Wcoct bo)ct=ct−1⊙ft gt⊙itmt=tanh(ct)ht=ot⊙mtyt=Wyhht by

2.2 LSTM反向传播

已知:∂J∂yt,∂J∂ct1,∂J∂o~t1,,∂J∂f~t1,∂J∂i~t1,∂J∂g~t1\frac{\partial J}{\partial y_t},\frac{\partial J}{\partial c_{t 1}},\frac{\partial J}{\partial \tilde{o}_{t 1}},,\frac{\partial J}{\partial \tilde{f}_{t 1}},\frac{\partial J}{\partial \tilde{i}_{t 1}},\frac{\partial J}{\partial \tilde{g}_{t 1}}∂yt∂J,∂ct 1∂J,∂o~t 1∂J,,∂f~t 1∂J,∂i~t 1∂J,∂g~t 1∂J,求某个节点梯度时,首先应该找到该节点的输出节点,然后分别计算所有输出节点的梯度乘以输出节点对该节点的梯度,最后相加即可得到该节点的梯度。如计算∂J∂ht\frac{\partial J}{\partial h_t}∂ht∂J时,找到hth_tht节点的所有输出节点yt、o~t1、f~t1、i~t1、g~t1y_t、 \tilde{o}_{t 1}、\tilde{f}_{t 1}、\tilde{i}_{t 1}、\tilde{g}_{t 1}yt、o~t 1、f~t 1、i~t 1、g~t 1,然后分别计算输出节点的梯度(如∂J∂yt\frac{\partial J}{\partial y_t}∂yt∂J)与输出节点对hth_tht的梯度的乘积(如∂J∂ytWyhT\frac{\partial J}{\partial y_t}W_{yh}^T∂yt∂JWyhT),最后相加即可得到节点hth_tht的梯度:
∂J∂ht=∂J∂ytWyhT∂J∂o~t1WhoT∂J∂f~t1WhfT∂J∂i~t1WhiT∂J∂g~t1WhgT\frac{\partial J}{\partial h_t}=\frac{\partial J}{\partial y_t}W_{yh}^T \frac{\partial J}{\partial \tilde{o}_{t 1}}W_{ho}^T \frac{\partial J}{\partial \tilde{f}_{t 1}}W_{hf}^T \frac{\partial J}{\partial \tilde{i}_{t 1}}W_{hi}^T \frac{\partial J}{\partial \tilde{g}_{t 1}}W_{hg}^T∂ht∂J=∂yt∂JWyhT ∂o~t 1∂JWhoT ∂f~t 1∂JWhfT ∂i~t 1∂JWhiT ∂g~t 1∂JWhgT
同理可得t时刻其它节点的梯度:
{∂J∂ht=∂J∂ytWyhT∂J∂o~t1WhoT∂J∂f~t1WhfT∂J∂i~t1WhiT∂J∂g~t1WhgT∂J∂mt=∂J∂ht⊙ot∂J∂ct=∂J∂mtdmtdct∂J∂ct1⊙ft1∂J∂f~t1WcfT∂J∂i~t1WciT∂J∂gt=∂J∂ct⊙it∂J∂it=∂J∂ct⊙gt∂J∂ft=∂J∂ct⊙ct−1∂J∂ot=∂J∂ht⊙mt}⇒{∂J∂g~t=∂J∂gt(1−gt2)∂J∂i~t=∂J∂itit(1−it)∂J∂f~t=∂J∂ftft(1−ft)∂J∂o~t=∂J∂otit(1−ot)∂J∂xt=∂J∂o~tWxoT∂J∂f~tWxfT∂J∂i~tWxiT∂J∂g~tWxgT\left \{\begin{array}{l} \frac{\partial J}{\partial h_t}=\frac{\partial J}{\partial y_t}W_{yh}^T \frac{\partial J}{\partial \tilde{o}_{t 1}}W_{ho}^T \frac{\partial J}{\partial \tilde{f}_{t 1}}W_{hf}^T \frac{\partial J}{\partial \tilde{i}_{t 1}}W_{hi}^T \frac{\partial J}{\partial \tilde{g}_{t 1}}W_{hg}^T \\ \ \frac{\partial J}{\partial m_t} = \frac{\partial J}{\partial h_t} \odot o_t \\ \ \frac{\partial J}{\partial c_t} = \frac{\partial J}{\partial m_t}\frac{dm_t}{dc_t} \frac{\partial J}{\partial c_{t 1}}\odot f_{t 1} \frac{\partial J}{\partial \tilde{f}_{t 1}}W_{cf}^T \frac{\partial J}{\partial \tilde{i}_{t 1}}W_{ci}^T \\ \ \left. \begin{array}{l} \frac{\partial J}{\partial g_t} = \frac{\partial J}{\partial c_t}\odot i_t \ \frac{\partial J}{\partial i_t} = \frac{\partial J}{\partial c_t} \odot g_t \ \frac{\partial J}{\partial f_t} = \frac{\partial J}{\partial c_t} \odot c_{t-1} \ \frac{\partial J}{\partial o_t} = \frac{\partial J}{\partial h_t} \odot m_t \end{array} \right \} \Rightarrow \left\{ \begin{array}{l} \frac{\partial J}{\partial \tilde{g}_t} = \frac{\partial J}{\partial g_t}(1-g_t^2) \ \frac{\partial J}{\partial \tilde{i}_t} = \frac{\partial J}{\partial i_t}i_t(1-i_t) \ \frac{\partial J}{\partial \tilde{f}_t} = \frac{\partial J}{\partial f_t}f_t(1-f_t) \ \frac{\partial J}{\partial \tilde{o}_t} = \frac{\partial J}{\partial o_t}i_t(1-o_t) \ \end{array}\right. \\ \ \frac{\partial J}{\partial x_t} = \frac{\partial J}{\partial \tilde{o}_t}W_{xo}^T \frac{\partial J}{\partial \tilde{f}_t}W_{xf}^T \frac{\partial J}{\partial \tilde{i}_t}W_{xi}^T \frac{\partial J}{\partial \tilde{g}_t}W_{xg}^T\\end{array}\right.⎩⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎨⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎪⎧∂ht∂J=∂yt∂JWyhT ∂o~t 1∂JWhoT ∂f~t 1∂JWhfT ∂i~t 1∂JWhiT ∂g~t 1∂JWhgT∂mt∂J=∂ht∂J⊙ot∂ct∂J=∂mt∂Jdctdmt ∂ct 1∂J⊙ft 1 ∂f~t 1∂JWcfT ∂i~t 1∂JWciT∂gt∂J=∂ct∂J⊙it∂it∂J=∂ct∂J⊙gt∂ft∂J=∂ct∂J⊙ct−1∂ot∂J=∂ht∂J⊙mt⎭⎪⎪⎪⎬⎪⎪⎪⎫⇒⎩⎪⎪⎪⎨⎪⎪⎪⎧∂g~t∂J=∂gt∂J(1−gt2)∂i~t∂J=∂it∂Jit(1−it)∂f~t∂J=∂ft∂Jft(1−ft)∂o~t∂J=∂ot∂Jit(1−ot)∂xt∂J=∂o~t∂JWxoT ∂f~t∂JWxfT ∂i~t∂JWxiT ∂g~t∂JWxgT

对参数的梯度:
{∂J∂Who=htT∂J∂o~t1∂J∂Whf=htT∂J∂f~t1∂J∂Whi=htT∂J∂i~t1∂J∂Whg=htT∂J∂g~t1{∂J∂Wyh=htT∂J∂yt∂J∂Wcf=ctT∂J∂f~t1∂J∂Wci=ctT∂J∂i~t1∂J∂Wco=ctT∂J∂o~t{∂J∂Wxo=xtT∂J∂o~t∂J∂Wxf=xtT∂J∂f~t∂J∂Wxi=xtT∂J∂i~t∂J∂Wxg=xtT∂J∂g~t\left \{\begin{array}{l} \frac{\partial J}{\partial W_{ho}} = h_t^T\frac{\partial J}{\partial \tilde{o}_{t 1}} \ \frac{\partial J}{\partial W_{hf}} = h_t^T\frac{\partial J}{\partial \tilde{f}_{t 1}} \ \frac{\partial J}{\partial W_{hi}} = h_t^T\frac{\partial J}{\partial \tilde{i}_{t 1}} \ \frac{\partial J}{\partial W_{hg}} = h_t^T\frac{\partial J}{\partial \tilde{g}_{t 1}} \end{array} \right. \left \{\begin{array}{l} \frac{\partial J}{\partial W_{yh}} = h_t^T\frac{\partial J}{\partial y_t} \ \frac{\partial J}{\partial W_{cf}} = c_t^T\frac{\partial J}{\partial \tilde{f}_{t 1}} \ \frac{\partial J}{\partial W_{ci}} = c_t^T\frac{\partial J}{\partial \tilde{i}_{t 1}} \ \frac{\partial J}{\partial W_{co}} = c_t^T\frac{\partial J}{\partial \tilde{o}_{t}} \end{array} \right. \left \{\begin{array}{l} \frac{\partial J}{\partial W_{xo}} = x_t^T\frac{\partial J}{\partial \tilde{o}_{t}} \ \frac{\partial J}{\partial W_{xf}} = x_t^T\frac{\partial J}{\partial \tilde{f}_{t}} \ \frac{\partial J}{\partial W_{xi}} = x_t^T\frac{\partial J}{\partial \tilde{i}_{t}} \ \frac{\partial J}{\partial W_{xg}} = x_t^T\frac{\partial J}{\partial \tilde{g}_{t}} \\end{array} \right.⎩⎪⎪⎪⎨⎪⎪⎪⎧∂Who∂J=htT∂o~t 1∂J∂Whf∂J=htT∂f~t 1∂J∂Whi∂J=htT∂i~t 1∂J∂Whg∂J=htT∂g~t 1∂J⎩⎪⎪⎪⎨⎪⎪⎪⎧∂Wyh∂J=htT∂yt∂J∂Wcf∂J=ctT∂f~t 1∂J∂Wci∂J=ctT∂i~t 1∂J∂Wco∂J=ctT∂o~t∂J⎩⎪⎪⎪⎨⎪⎪⎪⎧∂Wxo∂J=xtT∂o~t∂J∂Wxf∂J=xtT∂f~t∂J∂Wxi∂J=xtT∂i~t∂J∂Wxg∂J=xtT∂g~t∂J

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约