分享

pytorch循环神经网络RNN从结构原理到应用实例

 头号码甲 2022-10-01 发布于北京

一、 RNN概述

人工神经网络和卷积神经网络的假设前提都是:元素之间是相互独立的 ,但是在生活中很多情况下这种假设并不成立,比如你写一段有意义的话 “遇见一个人只需1秒,喜欢一个人只需3,秒,爱上一个人只需1分钟,而我却用我的[?]在爱你。” ,作为正常人我们知道这里应该填 “一生”,但之所以我们会这样填是因为我们读取了上下文,而普通的神经网络输入之间是相互独立的,网络没有记忆能力。扩展一下:训练样本是连续的序列且其长短不一,如一段连续的语音、一段连续的文本等,这些序列前面的输入与后面的输入有有一定的相关性,很难将其拆解为一个个单独的样本来进行DNN/CNN训练。

循环神经网络(Recurrent Neural Networks,简称RNN)广泛应用于:

  • 语义分析(Semantic Analysis):按照语法分析器识别语法范畴进行语义检查和处理,产生相应的中间代码或者目标代码
  • 情感分析(Sentiment Classification)
  • 图像标注(Image Captioning):对图片进行文本描述
  • 语言翻译(Language Translation)

二、RNN网络结构及原理

图中各个参数意义:

1)x(t)代表在序列索引号t时训练样本的输入。同样的,x(t−1)x(t+1)代表在序列索引号t−1t+1时训练样本的输入。

2)h(t)代表在序列索引号t时模型的隐藏状态。h(t)x(t)h(t−1)共同决定。

3)o(t)代表在序列索引号t时模型的输出。o(t)只由模型当前的隐藏状态h(t)决定。

4)L(t)代表在序列索引号t时模型的损失函数。

5)y(t)代表在序列索引号t时训练样本序列的真实输出。

6)U,W,V这三个矩阵是我们的模型的线性关系参数,它在整个RNN网络中是共享的,这点和DNN很不相同。 也正因为是共享了,它体现了RNN的模型的“循环反馈”的思想。 [1]

三、RNN前向传播原理

对于任何一个序列索引号t,隐藏状态\(h{(t)}\)\(h^{(t-1)}\)\(x^{(t)}\)得到:

\[h^{(t)} = \sigma(z^{(t)} = \sigma(Ux^{(t)}+Wh^{(t-1)}+b)) \]

其中σ为RNN的激活函数,b为偏置值(bias)

序列索引号为t的时候模型的输出\(o^{(t)}\)的表达式比较简单:

\[o^{(t)} = Vh^{(t)}+c \]

此时预测输出为:

\[\hat{y}^{(t)} = \sigma(o^{(t)}) \]

在上面这一过程中使用了两次激活函数(第一次获得隐藏状态\(h^{(t)}\),第二次获得预测输出\(\hat{y}^{(t)}\))通常在第一次使用tanh激活函数,第二次使用softmax激活函数

四、RNN反向传播推导

RNN的法向传播通过梯度下降一次次迭代得到合适的参数U、W、V、b、c。在RNN中U、W、V、b、c参数在序列的各个位置都是相同的,反向传播我们更新的是同样的参数。

对于RNN,我们在序列的每一个位置上都有损失,所以最终的损失L为:

\[L = \sum_{t=1}^{\tau}L^{(t)} \]

损失函数对更新的参数进行求偏导(注意我们这里使用的两个激活函数分别为softmaxtanh,使用的误差计算公式为交叉熵):

  • 首先考虑与损失函数直接相关的两个变量cV(即预测输出时的权值和偏置值),利用损失函数可以对这两个变量进行直接求偏导(即对softmax函数求导):

\[\frac{\partial{L}}{\partial{c}} = \sum_{t=1}^{\tau}\frac{\partial{L^{(t)}}}{\partial{c}} = \sum_{t =1}^{\tau}\hat{y}^{(t)}-y^{(t)} \]

\[\frac{\partial{L}}{\partial{V}} = \sum_{t=1}^{\tau}\frac{\partial{L^{(t)}}}{\partial{V}} = \sum_{t =1}^{\tau}(\hat{y}^{(t)}-y^{(t)})(h^{(t)})^T \]

  • 而损失函数对W、U、b的偏导数计算就比较复杂了:在反向传播时,某一序列位置t的梯度损失由当前位置的输出对应的梯度损失和序列索引位置t+1时的梯度损失两部分共同决定。
    从正向传播来看:

\[h^{(t+1)} = tanh(Ux^{(t+1)}+Wh^{(t)}+b)) \]

对于W、U、b在某一序列位置t的梯度损失需要反向传播一步步的计算。我们定义序列索引t位置的隐藏状态的梯度为:

\[\delta^{(t)} = \frac{\partial{L}}{\partial{(h^{(t)})}} \]

\(\delta^{(\tau+1)}\)递推\(\delta^{(t)}\)

\[\delta^{(t)} = (\frac{\partial{\delta^{(t)}}}{\partial{h^{(t)}}})^T \frac{\partial{L}}{\partial{o^{(t)}}} + (\frac{\partial{h^{(t+1)}}}{\partial{h^{(t)}}})^T \frac{\partial{L}}{\partial{h^{(t+1)}}} = V^T(\hat{y}^{(t)}-y^{(t)}) +W^Tdiag(1-(h^{(t+1)})^2)\delta^{(t+1)} \]

对于\(\delta{(\tau)}\),其后面没有其他的索引(最后一个输入),因此:

\[\delta^{(\tau)} = (\frac{\partial{\delta^{(\tau)}}}{\partial{h^{(\tau)}}})^T \frac{\partial{L}}{\partial{o^{(\tau)}}} = V^T(\hat{y}^{(\tau)}-y^{(t)}) \]

根据\(\delta{(t)}\),我们就可以计算W、U、b了:

\[\frac{\partial{L}}{\partial{W}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(h^{(t-1)})^T \]

\[\frac{\partial{L}}{\partial{b}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)} \]

\[\frac{\partial{L}}{\partial{V}} = \sum_{t=1}^{\tau}diag(1-(h^{(t)})^2)\delta^{(t)}(x^{(t)})^T \]

五、RNN梯度消失问题


假设时间序列只有三段,\(S_0\)为给定值,神经元没有激活函数,而RNN按照最简单的前向传播:

\[S_1 = W_xX_1 + W_sS_0+b_1 ; O_1 = W_0S_1 +b2 \]

\[S_2 = W_xX_2 + W_sS_1+b_1 ; O_2 = W_0S_2 +b2 \]

\[S_3 = W_xX_3 + W_sS_2+b_1 ; O_3 = W_0S_3 +b2 \]

假设在t=3时刻,损失函数为$$L_3 = \frac{1}{2}(Y_3-O_3)^2$$
对于一次训练,其损失函数值是累加的:$$L = \sum_{t = 0}{T}L_t$$
此处利用反向传播公式仅对Wx、Ws、W0求偏导数(Wx、Ws与输出Output相关,并非直接求损失函数Loss的偏导,在第四部分也已经说明了:

\[\frac{\partial{L}_3}{\partial{W}_0} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{W}_0} \]

\[\frac{\partial{L}_3}{\partial{W}_x} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{W}_x} + \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{W}_x}+ \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{S}_1}\frac{\partial{S}_1}{\partial{w}_x} \]

\[\frac{\partial{L}_3}{\partial{W}_s} = \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{W}_s} + \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{W}_s}+ \frac{\partial{L}_3}{\partial{O}_3} \frac{\partial{O}_3}{\partial{S}_3} \frac{\partial{S}_3}{\partial{S}_2}\frac{\partial{S}_2}{\partial{S}_1}\frac{\partial{S}_1}{\partial{w}_s} \]

从这冗长的公式中可以看见用梯度下降法对损失函数求W0的偏导数其没有很长的依赖(就是公式很短、求解简单)但是对于WxWs的公式就非常长了,上面仅仅推到了三层网络结构就已经如此繁杂了,推导任意时刻损失函数关于WxWs的偏导数公式:

\[\frac{\partial{L}_t}{\partial{W}_x} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_x} \]

\[\frac{\partial{L}_t}{\partial{W}_s} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_s} \]

如果再加上激活函数:$$S_j = tanh(W_xX_j + W_sS_{j-1}+b_1)$$

则$$\prod_{j=k+1}^{t}\frac{\partial{S}j}{\partial{S}{j-1}} = \prod_{j=k+1}^{t}W_s tanh^{'}$$

激活函数tanh[2]:

\[f(x) = tanh(x) = \frac{e^x - e^{-x}}{e^x + e^{-x}} \]

tanh函数导数:

\[f(x)^{'} = 1 - (tanh(x))^2 \]

tanh函数及其导数

根据激活函数及其导数的图像可见 [3]

  • \[tanh^{'}(x) ≤ 1 \]

  • 绝大部分情况下,tanh的导数都是小于1的。很少情况出现:

\[W_xX_j + W_sS_{j-1} + b_1 = 0 \]

  • 如果Ws是一个大于0小于1的值,当t很大的时候

\[\prod_{j=k+1}^{t}W_s tanh^{'} --> 0 \]

  • 如果Ws是一个很大的值,当t很大的时候

\[\prod_{j=k+1}^{t}W_s tanh^{'} --> ∞ \]

六、消除梯度爆炸和梯度消失

在公式:

\[\frac{\partial{L}_t}{\partial{W}_x} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_x} \]

\[\frac{\partial{L}_t}{\partial{W}_s} = \sum_{k=0}^{t}\frac{\partial{L}_t}{\partial{O}_t}\frac{\partial{O}_t}{\partial{S}_t}(\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}})\frac{\partial{S}_k}{\partial{W}_s} \]

导致梯度消失和梯度爆炸的原因在于:

\[\prod_{j=k+1}^{t}\frac{\partial{S}_j}{\partial{S}_{j-1}} \]

消除这个部分的影响一个考虑是使得

\[\frac{\partial{S}_j}{\partial{S}_{j-1}} ≈ 1 \]

另一种是使得:

\[\frac{\partial{S}_j}{\partial{S}_{j-1}} ≈ 0 \]


  1. 循环神经网络(RNN)模型与前向反向传播算法 ↩︎

  2. Tanh激活函数及求导过程 ↩︎

  3. RNN梯度消失和爆炸的原因 ↩︎

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多