循环神经网络RNN

2023-06-20 10:59:29 来源:个人图书馆-汉无为

我们目前所学会的神经网络(包括CNN)的结构大致如下图所示。


(资料图片)

当前的输出向量y只取决于当前的输入向量x。

然而许多机器学习应用,比如机器翻译、语音识别、手势识别等应用涉及时间依赖,也就是说模型当前的输出不仅取决于当前的输入,而且还依赖于过去的输入。

再说一个更具体的应用——ChatGPT。

RNN是循环神经网络(Recurrent Neural Network)的简称。RNN的结构如下图所示。

t+2时刻的输出y_t+2不仅取决于t+2时刻的输入x_t+2,而且还依赖t+1时刻的隐藏状态s_t+1和t时刻的隐藏状态s_t。而s_t+1又依赖于t+1时刻的输入x_t+1和t时刻的隐藏状态s_t,s_t依赖于前一时刻的隐藏状态s_t-1和t时刻的输入x_t。

不过上述结构图中涉及的模型参数W_y,W_s和W_x都只有一套,并不会因为时间的变化而变化。

所以可以将上述结构图折叠起来可以简化为下图。

我们可以堆叠出多层的RNN。

我们通过反向传播(Backpropagation,BP)算法来训练CNN,类似的,我们通过基于时间的反向传播(Backpropagation Through Time,BPTT)算法训练RNN。

假设现在我们以梯度下降(Gradient Descent)方法更新模型参数。

假设我们计算出了t=3时刻的损失函数E_3,现在要根据损失函数更新W_y,W_s和W_x。

E_3关于W_y的梯度:

E_3关于W_s的梯度计算有点复杂,要找到E_3和W_s的关系,我们首先找到了s_3:

S_3又依赖之前的S_2和S_1,所以我们将之前状态的贡献累加起来,先考虑s_2:

再考虑s_1:

同样的计算E_3关于W_x的梯度也是要累加之前的贡献:

这样我们就可以通过BPTT来更新RNN中的参数了。

标签:
推荐阅读>