1、简介

先回顾seq2seq模型, seq2seq模型有一个encoder和一个decoder。encoder的输入是英语, decoder把英语翻译成德语。encoder每次读入一个英语向量x, 在状态h中积累输入的信息。最后一个状态hm中积累了所有词向量x的信息。encoder输出最后ig状态hm, 把之前的状态向量全部扔掉。

image-20231205170348294

decoder RNN的初始状态s0等于encoder RNN的最后一个状态hm。hm包含了输入英语句子的信息, 通过hm, decoder就知道了这句英语。然后decoder就像文本生成器一样, 逐字生成一句德语。这句德语就是模型生成的翻译。

image-20231205170659672

但seq2seq模型有一个明显的缺陷, 就是如果输入的句子很长, 那么encoder会记不住完整的句子。encoder最后一个状态可能会漏掉一些信息。假如输入的英语里有个别词被忘记了, 那么decoder就无从得知完整的句子, 也就不可能产生正确的翻译。

如果你拿seq2seq模型做机器翻译, 你会得到这样的结果。 横轴是输入句子的长度。纵轴是BLEU, BLEU score是评价机器翻译好坏的标准。BLEU score越高, 说明机器翻译越准确。如果不用attention, 当输入的句子超过20个单词的时候, BLEU score就会往下掉。这是因为LSTM会遗忘, 造成翻译出错。

image-20231205171208705

用attention的话会得到这条红色的线。

image-20231205171251689

即便输入的句子很长, attention曲线中的BLEU score也不会往下掉。用attention改进seq2seq模型, 解决seq2seq遗忘问题。

attention是2015年提出的, 用attention, decoder每次更新状态时, 都会看一遍encoder所有状态,这样不会遗忘。

attention还会告诉decoder应该关注encoder哪个状态, 这就是attention名字的由来。

attention可以大幅度提高准确率, 但是attention的缺点是计算量非常大。

2、attention原理

在encoder已经结束工作后, attention与decoder同时开始工作, 回顾一下decoder的初始状态s0是encoder的最后一个状态hm。

encoder的所有状态h1、h2一直到hm都保留下来, 这里计算s0与每一个h的相关性。用align公式表示相关性, 计算encoder第i个状态hi与decoder当前状态s0的相关性。把结果记作αi, αi被称为权重weight。

image-20231205174131738

encoder一共有m个状态, 所以一共算出m个α。从α1一直到αm, 都是介于0到1之间的实数, 所以α加起来等于1。

下面看一下如何进行计算。有多种方法计算hi与s0的相关性。第一种方法是这样的, 把hi与so做concatination得到更高的向量,然后计算矩阵w与这个向量的乘积,再用tanh应用到向量的每个元素上。把每个元素都压缩到-1和1之间, tanh的输出还是一个向量。 最后计算向量v与刚计算出来的向量的内积, 两个向量内积是个实数, 记作αi。

image-20231205174849776

这里的向量v与矩阵w都是参数, 需要从训练数据中学习。

计算出α~i到α~m这m个实数之后, 对它们做softmax变换。把输出结果记为α1, α2, …, αm。由于α是softmax的输出, α1到αm都大于0, 而且相加等于1。这种计算权重α的方法是attention第一篇论文提出的。

image-20231205175259698

之后很多论文提出过其他计算权重的方法。

第二种计算权重的方法:

输入还是hi和s0, 第一步分别用两个参数矩阵wk和wq对两个输入向量做线性变换, 得到ki和q0这两个向量,这两个参数矩阵要从训练数据中学习。

第二步是计算ki与q0的内积, 把结果记作α~i, 由于有m个k向量, 所以得到m个α~i。

第三步对α’i到α‘m做softmax变换, 把输出记作α1到αm。 α1到αm都是正数, 而且相加等于1。这种计算权重的方法被transformer模型采用, transform模型是当前很多nlp问题的state of the art。

image-20231205175830385

刚才讲了两种方法计算hi与s0的相关性, 随便用哪种方法都会得到m个α。这些α被称为权重, 每个αi对应一个encoder状态hi。

利用这些权重α, 可以对这m个权重向量h做加权平均, 计算α1h1, 一直到αmhm。把加权平均的结果记作c0, c0称为context vector。每一个context vector都会对应一个decoder状态, context vector c0对应decoder状态s0。

image-20231206091948919

decoder读入向量x’1, 然后需要把状态更新为s1。

image-20231206092612386

具体怎么计算新的状态s1?回顾一下, 如果不用attention, 那么simple RNN是这样更新状态的。新的状态s1是输入x’1与旧的状态s0的函数。看一下simple RNN的计算公式, 把x‘1与s0做concatination, 然后乘到参数矩阵A’上, 加上intercept向量b, 得到新的状态s1。simple RNN在更新状态时, 只需知道新的输入x‘1与旧的状态s0, simple RNN并不会看encoder的状态。

image-20231206093039714

用attention的话, 更新decoder状态的时候, 需要用到context vector c0, 把x’t, s0, c0做concatination, 用它们来计算新的状态s1。

image-20231206093511051

已经更新了decoder的状态s1, 回忆一下c0是encoder所有状态h1到hm的加权平均。所以c0知道encoder的输入x1到xm的完整信息。decoder新的状态s1依赖于context vector c0, 这样一来, decoder也知道encoder完整的输入, 于是RNN遗忘问题就解决了。

image-20231206093854163

下一步计算context vector c1, 跟之前一样。

先计算权重α, αi是encoder第i个状态hi与decoder当前状态s1的相关性, 把decoder状态s1与encoder所有m个状态对比, 计算出m个权重记作α1到αm。

image-20231206094239683

注意, 虽然上一轮计算C0的时候算出了m个权重α, 但我们现在不能用那些α,必须要重新计算α。上一轮计算的是h与s0的相关性, 这一轮计算的是h与s1的相关性。 这里用了相同的符号α, 然而现在这m个权重α跟上一轮计算出来的α不一样, 现在不能重复使用上一轮算出来的α。

image-20231206094829890

有了权重α, 就可以计算新的context vector c1, c1是encoder的m个状态向量h1到hm的加权平均, c1等于α1h1一直加到αmhm。

image-20231206095359354

decoder接收新的输入x’2, 然后把状态s1更新到s2。

image-20231206095616450

s2是新输入x‘2、旧状态s1以及context vector c1的函数, 还是用这个公式来计算decoder新的状态。

image-20231206100033530

已经算出了新的状态s2, 下一步是计算新的context vector c2。

image-20231206100356914

把decoder当前状态s2与encoder所有状态h1到hm做对比, 计算出权重α1一直到αm。

image-20231206100526507

有了α1到αm这些权重, 把encoder的状态h1到hm做加权平均, 把加权平均的结果记作c2。

image-20231206101626812

然后再更新状态s3, 然后计算c3, 不断重复, 依次更新状态s, 计算context vector c, 再更新状态s再计算context vector c一直到结束。在计算context vector c的过程中, 一共计算了多少个α。

image-20231206101942514

每计算一个context vector c, 需要把decoder当前状态s与encoder所有m个状态h做对比, 计算出m个权重α1一直到αm,所以decoder每一轮更新都需要重新计算m个权重α。

假设decoder运行了t步, 那么一共计算了mt个权重α。所以attention的时间复杂度是mt。也就是encoder与decoder状态数量的乘积。这个时间复杂度很高的, attention避免了遗忘, 大幅度提高预测准确率, 但是代价也是巨大的。

image-20231206102325594

已经介绍完attention的原理, 现在用这个例子来说明权重α的实际意义。这张图下面encoder输入是英语, 上面是decoder。把英语翻译成德语, attention会把decoder每个状态与encoder每个状态做对比, 得到二者的相关性, 也就是权重α。

下图中用线连接每个decoder状态与encoder状态, 每条线对应一个权重α。粗的线表示α很大, 细的线表示α很小。线越粗, 说明相关性很大。

这条粗线可以这样解释, 法语里面的zone就是英语里面的area。所以这两个状态的相似度很高。每当decoder想要生成一个状态时, 都会看一遍encoder的所有状态。这些权重α告诉decoder应该关注什么地方, 这就是attention名字的由来。当decoder需要计算这个状态的时候, 权重α告诉decoder应该关注encoder的这个状态。 这帮助decoder产生正确的状态, 从而生成正确的法语单词。

image-20231206103127572

image-20231206105328780

3、总结

如果使用标准的seq2seq模型, decoder基于当前状态产生下一个状态, 这样产生的新状态可能已经遗忘了encoder的部分输入。如果使用attention, decoder在产生下一个状态前, 会先看一遍encoder的所有状态, 于是decoder就知道encoder的完整信息, 并不会遗忘。除了解决遗忘问题, attention还能告诉decoder应该关注encoder的哪一个状态, 这就是attention名字的由来。

attention可以大幅度提升seq2seq模型的表现。但attention也是有缺点的, 缺点就是计算量太大了。假设输入encoder序列的长度为m, decoder输出的序列长度为t。标准的seq2seq模型只需要让encoder读一遍输入序列, 之后就不会再看encoder的输入或者状态了, 然后让decoder依次生成输出的序列, 所以时间复杂度是m+t。

但attention中, 时间复杂度会高很多。decoder每更新一个状态都会把encoder的状态先看一遍, 所以时间复杂度是m。decoder自己有t个状态, 所以总的时间复杂度是mt。使用attention可以提升准确率, 但需要付出更多的计算量。