计算 LSTM 的参数量
目录
理论上的参数量
之前翻译了 Christopher Olah 的那篇著名的 Understanding LSTM Networks,这篇文章对于整体理解 LSTM 很有帮助,但是在理解 LSTM 的参数数量这种细节方面,略有不足。本文就来补充一下,讲讲如何计算 LSTM 的参数数量。
首先来回顾下 LSTM。一层 LSTM 如下:
这里的 $x_t$ 实际上是一个句子的 embedding(不考虑 batch 维度),shape 一般为 [seq_length, embedding_size]
。图中的 $A$ 就是 cell,$x_t$ 中的词依次进入这个 cell 中进行处理。可以看到其实只有这么一个 cell,所以每次词进去处理的时候,权重是共享的,将这个过程平铺展开,就是下面这张图了:
实际上我觉得这里 $x_t$ 并不准确,第一个 $x_t$ 应该指的是整句话,而第二个 $x_t$ 应该指的是这句话中最后一个词,所以为了避免歧义,我认为可以将第一个 $x_t$ 重命名为 $x$,第二个仍然保留,即现在 $x$ 表示一句话,该句话有 $t+1$ 个词,$x_t$ 表示该句话的第 $t+1$ 个词,$t \in [0, t]$。
始终要记住这么多 $A$ 都是一样的,权重是一样的,$x_0$ 到 $x_t$ 是一个个词,每一次的处理都依赖前一个词的处理结果,这也是 RNN 系的网络难以像 CNN 一样并行加速的原因。同时, 这就像一个递归过程,如果把求 $h_t$ 的公式展开写,$A$ 里的权重记为 $W$,那么就会发现需要 $t$ 个 $W$ 相乘,即 $W^t$,这是非常恐怖的:
一个不那么小的数被多次相乘之后会变得很小,一个不那么大的数被多次相乘之后会变得很大。所以,这也是普通 RNN 容易出现梯度消失/爆炸的问题的原因。
扯远了点。
那么 LSTM 的参数很明显了,就是这个 $A$ 中的参数。这个 $A$ 内部具体是这样的:
从这张图来理解参数的数量你可能有点懵逼,一步一步来看,实际上这里面有 4 个非线性变换(3 个 门 + 1 个 tanh),每一个非线性变换说白了就是一个两层的全连接网络。重点来了,第一层是 $x_i$ 和 $h_i$ 的结合,维度就是 embedding_size + hidden_size
,第二层就是输出层,维度为 hidden_size
,所以该网络的参数量就是:
1 | (embedding_size + hidden_size) * hidden_size + hidden_size |
一个 cell 有 4 个这样结构相同的网络,那么一个 cell 的总参数量就是直接 × 4:
1 | ((embedding_size + hidden_size) * hidden_size + hidden_size) * 4 |
注意这 4 个权重可不是共享的,都是独立的网络。
所以,一般来说,一层 LSTM 的参数量计算公式是:
其中 4 表示有 4 个非线性映射层,$d_h + d_x$ 即 Understanding LSTM Networks 中的 $[h_{t-1}, x_t]$ 的维度,后面的 $d_h$ 表示 bias 的数量。所以,LSTM 层的参数数量只与输入维度 $d_x$ 和输出维度 $d_h$ 相关,和普通全连接层相同。
那么显而易见,一层双向 LSTM 的参数量就是上述公式 × 2。
TensorFlow 中的实现
在 TensorFlow 中,这些 $d_x$、$d_h$ 如何与代码对应上呢?
我们可以如下实现一个简单的以 LSTM 为核心的网络:
1 | import tensorflow as tf |
输入如下:
1 | Model: "sequential" |
可以看到 TF 给出的 LSTM 层参数量是 49408。我们来根据上面的公式验证下。
- $d_x$:输入维度,在这里就对应于 128,就是词向量维度。
- $d_h$:输出维度,在这里就是 LSTM 的参数 64,在 TF 这里叫
units
。
所以,参数量就是 $4 \times \left[64 \times \left(64+128\right) + 64 \right] = 49408$,和 TF 给出的一样。
另外,tf.keras.layers.LSTM()
的默认输出大小为 [batch_size, units]
,就是只使用最后一个 time step 的输出。假如我们想要得到每个 time step 的输出($h_0,\cdots,h_t$)和最终的 cell state($C_t$),那么我们可以指定另外两个参数 return_sequences=True
和 return_state=True
:
1 | inputs = tf.random.normal([64, 100, 128]) # [batch_size, seq_length, embedding_size] |
输出:
1 | whole_seq_output.shape=TensorShape([32, 100, 64]) # 100 表示有 100 个词,即 100 个 time step |
OK,LSTM 的参数量应该挺清晰了,欢迎在评论区留下你的想法。😋
Reference
- Counting No. of Parameters in Deep Learning Models by Hand
- deep learning - Number of parameters in an LSTM model - Data Science Stack Exchange
- machine learning - How to calculate the number of parameters of an LSTM network? - Stack Overflow
- tensorflow - In Keras, what exactly am I configuring when I create a stateful
LSTM
layer with Nunits
? - Stack Overflow - 理解 LSTM 网络 · Alan Lee
- Recurrent Neural Networks (RNN) with Keras | TensorFlow Core
- LSTM is dead. Long Live Transformers! - YouTube