跳转至

RNN进阶知识总结

Truncated BPTT

RNN基本公式为:

\[ \begin{split}\begin{aligned} \mathbf{h}_t &= \mathbf{W}_{hx} \mathbf{x}_t + \mathbf{W}_{hh} \mathbf{h}_{t-1},\\ \mathbf{o}_t &= \mathbf{W}_{qh} \mathbf{h}_{t},\\ L &= \frac{1}{T} \sum_{t=1}^T l(\mathbf{o}_t, y_t) \end{aligned}\end{split} \]

其中:

\[ \mathbf{W}_{hx} \in \mathbb{R}^{h \times d} , \mathbf{W}_{hh} \in \mathbb{R}^{h \times h} , \mathbf{W}_{qh} \in \mathbb{R}^{q \times h} \]

接下来求偏导,先求关于 \(\mathbf{W}_{qh}\) 的,比较简单:

首先计算中间变量:

\[ \frac{\partial L}{\partial \mathbf{o}_t} = \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \in \mathbb{R}^q \]

所以:

\[ \frac{\partial L}{\partial \mathbf{W}_{qh}} = \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{W}_{qh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{o}_t} \mathbf{h}_t^\top = \sum_{t=1}^T \frac{\partial l (\mathbf{o}_t, y_t)}{T \cdot \partial \mathbf{o}_t} \mathbf{h}_t^\top \]

由公式可以全部直接求出。

下面求中间变量 \(\partial L/\partial \mathbf{h}_t \in \mathbb{R}^h\) 。为了求这个,我们先从最后时间T开始求起:

\[ \frac{\partial L}{\partial \mathbf{h}_T} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_T}, \frac{\partial \mathbf{o}_T}{\partial \mathbf{h}_T} \right) = \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_T} \]

由于目标函数 \(L\) 依赖于 \(\mathbf{h}_t\),而 \(\mathbf{h}_t\) 又依赖于 \(\mathbf{h}_{t+1}\)\(\mathbf{o}_t\),所以我们由链式法则:

\[ \frac{\partial L}{\partial \mathbf{h}_t} = \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_{t+1}}, \frac{\partial \mathbf{h}_{t+1}}{\partial \mathbf{h}_t} \right) + \text{prod}\left(\frac{\partial L}{\partial \mathbf{o}_t}, \frac{\partial \mathbf{o}_t}{\partial \mathbf{h}_t} \right) = \mathbf{W}_{hh}^\top \frac{\partial L}{\partial \mathbf{h}_{t+1}} + \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_t} \]

进一步递归分析,得:

\[ \frac{\partial L}{\partial \mathbf{h}_t}= \sum_{i=t}^T {\left(\mathbf{W}_{hh}^\top\right)}^{T-i} \mathbf{W}_{qh}^\top \frac{\partial L}{\partial \mathbf{o}_{T+t-i}} \]

接下来就可以计算关于 \(\mathbf{W}_{hx}\)\(\mathbf{W}_{hh}\) 的偏导了,得:

\[ \begin{split}\begin{aligned} \frac{\partial L}{\partial \mathbf{W}_{hx}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hx}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{x}_t^\top,\\ \frac{\partial L}{\partial \mathbf{W}_{hh}} &= \sum_{t=1}^T \text{prod}\left(\frac{\partial L}{\partial \mathbf{h}_t}, \frac{\partial \mathbf{h}_t}{\partial \mathbf{W}_{hh}}\right) = \sum_{t=1}^T \frac{\partial L}{\partial \mathbf{h}_t} \mathbf{h}_{t-1}^\top, \end{aligned}\end{split} \]

其中 \(\frac{\partial L}{\partial \mathbf{h}_t}\) 需要带入上面

我们发现,\(\frac{\partial L}{\partial \mathbf{h}_t}\)\(\mathbf{W}_{hh}^\top\)幂非常大。在这个幂中,小于1的特征值将会消失,大于1的特征值将会发散。这在数值上是不稳定的,表现形式为梯度消失或梯度爆炸。

首先我们还可以通过“梯度裁剪”的方法来限制梯度 \(\mathbf{g}\) 不超过 \(\theta\) : $$ \mathbf{g} \leftarrow \min\left(1, \frac{\theta}{|\mathbf{g}|}\right) \mathbf{g} $$

此外,我们可以采用Truncated BPTT

在实际应用中,采用常规截断或者随机截断。

下面是三种方法的示意图:

truncated-bptt

常规截断

每次处理一个时间步,每前向传播 \(k_1\) 步,后向传播 \(k_2\) 步。

for t from 1 to T do 
    Run the RNN for one step, computing h_t and z_t
    if t divides k_1 then  
        Run BPTT, from t down to t−k_2
    end if
end for

参数确定: - TBPTT(n, n): 传统的BPTT

  • TBPTT(1, n): 每向前处理一个时间步,便后向传播所有已看到的时间步。(Williams and Peng提出的经典的TBPTT)

  • TBPTT( \(k_1\) ,1): 网络并没有足够的时序上下文来学习,严重的依赖内部状态和输入。

  • TBPTT( \(k_1\) , \(k_2\)), where \(k_1 < k_2 < n\) :对于每个序列,都进行了多次更新,可以加速训练。

  • TBPTT( \(k_1\) , \(k_2\)), where \(k_1 = k_2\) : 类似于预处理的时候分batch,TensorFlow默认是用这种

在以下博文中,有上面五种方式的对比: https://r2rt.com/styles-of-truncated-backpropagation.html

随机截断

用一个随机变量替换 \(\partial h_t/\partial w_h\), 使用序列 \(\xi_t\) 来实现。序列预定义了 \(0 \leq \pi_t \leq 1\) ,其中 \(P(\xi_t = 0) = 1-\pi_t\)\(P(\xi_t = \pi_t^{-1}) = \pi_t\) , 因此 \(E[\xi_t] = 1\)

效果并不好。