Attention
1、简介
先回顾seq2seq模型, seq2seq模型有一个encoder和一个decoder。encoder的输入是英语, decoder把英语翻译成德语。encoder每次读入一个英语向量x, 在状态h中积累输入的信息。最后一个状态hm中积累了所有词向量x的信息。encoder输出最后ig状态hm, 把之前的状态向量全部扔掉。
decoder RNN的初始状态s0等于encoder RNN的最后一个状态hm。hm包含了输入英语句子的信息, 通过hm, decoder就知道了这句英语。然后decoder就像文本生成器一样, 逐字生成一句德语。这句德语就是模型生成的翻译。
但seq2seq模型有一个明显的缺陷, 就是如果输入的句子很长, 那么encoder会记不住完整的句子。encoder最后一个状态可能会漏掉一些信息。假如输入的英语里有个别词被忘记了, 那么decoder就无从得知完整的句子, 也就不可能产生正确的翻译。
如果你拿seq2seq模型做机器翻译, 你会得到这样的结果。 横轴是输入句子的长度。纵轴是BLEU, BLEU score是评价机器翻译好坏的标准。BLEU score越高, 说明机器翻译越准确。如果不用attention, 当输入的句子超过20个单词的时候, BLEU score就会往下掉。这是因为LSTM会遗忘, 造成翻译出错。
用attention的话会得到这条红色的线。
即便输入的句子很长, attention曲线中的BLEU score也不会往下掉。用attention改进seq2seq模型, 解决seq2seq遗忘问题。
attention是2015年提出的, 用attention, decoder每次更新状态时, 都会看一遍encoder所有状态,这样不会遗忘。
attention还会告诉decoder应该关注encoder哪个状态, 这就是attention名字的由来。
attention可以大幅度提高准确率, 但是attention的缺点是计算量非常大。
2、attention原理
在encoder已经结束工作后, attention与decoder同时开始工作, 回顾一下decoder的初始状态s0是encoder的最后一个状态hm。
encoder的所有状态h1、h2一直到hm都保留下来, 这里计算s0与每一个h的相关性。用align公式表示相关性, 计算encoder第i个状态hi与decoder当前状态s0的相关性。把结果记作αi, αi被称为权重weight。
encoder一共有m个状态, 所以一共算出m个α。从α1一直到αm, 都是介于0到1之间的实数, 所以α加起来等于1。
下面看一下如何进行计算。有多种方法计算hi与s0的相关性。第一种方法是这样的, 把hi与so做concatination得到更高的向量,然后计算矩阵w与这个向量的乘积,再用tanh应用到向量的每个元素上。把每个元素都压缩到-1和1之间, tanh的输出还是一个向量。 最后计算向量v与刚计算出来的向量的内积, 两个向量内积是个实数, 记作αi。
这里的向量v与矩阵w都是参数, 需要从训练数据中学习。
计算出α~i到α~m这m个实数之后, 对它们做softmax变换。把输出结果记为α1, α2, …, αm。由于α是softmax的输出, α1到αm都大于0, 而且相加等于1。这种计算权重α的方法是attention第一篇论文提出的。
之后很多论文提出过其他计算权重的方法。
第二种计算权重的方法:
输入还是hi和s0, 第一步分别用两个参数矩阵wk和wq对两个输入向量做线性变换, 得到ki和q0这两个向量,这两个参数矩阵要从训练数据中学习。
第二步是计算ki与q0的内积, 把结果记作α~i, 由于有m个k向量, 所以得到m个α~i。
第三步对α’i到α‘m做softmax变换, 把输出记作α1到αm。 α1到αm都是正数, 而且相加等于1。这种计算权重的方法被transformer模型采用, transform模型是当前很多nlp问题的state of the art。
刚才讲了两种方法计算hi与s0的相关性, 随便用哪种方法都会得到m个α。这些α被称为权重, 每个αi对应一个encoder状态hi。
利用这些权重α, 可以对这m个权重向量h做加权平均, 计算α1h1, 一直到αmhm。把加权平均的结果记作c0, c0称为context vector。每一个context vector都会对应一个decoder状态, context vector c0对应decoder状态s0。
decoder读入向量x’1, 然后需要把状态更新为s1。
具体怎么计算新的状态s1?回顾一下, 如果不用attention, 那么simple RNN是这样更新状态的。新的状态s1是输入x’1与旧的状态s0的函数。看一下simple RNN的计算公式, 把x‘1与s0做concatination, 然后乘到参数矩阵A’上, 加上intercept向量b, 得到新的状态s1。simple RNN在更新状态时, 只需知道新的输入x‘1与旧的状态s0, simple RNN并不会看encoder的状态。
用attention的话, 更新decoder状态的时候, 需要用到context vector c0, 把x’t, s0, c0做concatination, 用它们来计算新的状态s1。
已经更新了decoder的状态s1, 回忆一下c0是encoder所有状态h1到hm的加权平均。所以c0知道encoder的输入x1到xm的完整信息。decoder新的状态s1依赖于context vector c0, 这样一来, decoder也知道encoder完整的输入, 于是RNN遗忘问题就解决了。
下一步计算context vector c1, 跟之前一样。
先计算权重α, αi是encoder第i个状态hi与decoder当前状态s1的相关性, 把decoder状态s1与encoder所有m个状态对比, 计算出m个权重记作α1到αm。
注意, 虽然上一轮计算C0的时候算出了m个权重α, 但我们现在不能用那些α,必须要重新计算α。上一轮计算的是h与s0的相关性, 这一轮计算的是h与s1的相关性。 这里用了相同的符号α, 然而现在这m个权重α跟上一轮计算出来的α不一样, 现在不能重复使用上一轮算出来的α。
有了权重α, 就可以计算新的context vector c1, c1是encoder的m个状态向量h1到hm的加权平均, c1等于α1h1一直加到αmhm。
decoder接收新的输入x’2, 然后把状态s1更新到s2。
s2是新输入x‘2、旧状态s1以及context vector c1的函数, 还是用这个公式来计算decoder新的状态。
已经算出了新的状态s2, 下一步是计算新的context vector c2。
把decoder当前状态s2与encoder所有状态h1到hm做对比, 计算出权重α1一直到αm。
有了α1到αm这些权重, 把encoder的状态h1到hm做加权平均, 把加权平均的结果记作c2。
然后再更新状态s3, 然后计算c3, 不断重复, 依次更新状态s, 计算context vector c, 再更新状态s再计算context vector c一直到结束。在计算context vector c的过程中, 一共计算了多少个α。
每计算一个context vector c, 需要把decoder当前状态s与encoder所有m个状态h做对比, 计算出m个权重α1一直到αm,所以decoder每一轮更新都需要重新计算m个权重α。
假设decoder运行了t步, 那么一共计算了mt个权重α。所以attention的时间复杂度是mt。也就是encoder与decoder状态数量的乘积。这个时间复杂度很高的, attention避免了遗忘, 大幅度提高预测准确率, 但是代价也是巨大的。
已经介绍完attention的原理, 现在用这个例子来说明权重α的实际意义。这张图下面encoder输入是英语, 上面是decoder。把英语翻译成德语, attention会把decoder每个状态与encoder每个状态做对比, 得到二者的相关性, 也就是权重α。
下图中用线连接每个decoder状态与encoder状态, 每条线对应一个权重α。粗的线表示α很大, 细的线表示α很小。线越粗, 说明相关性很大。
这条粗线可以这样解释, 法语里面的zone就是英语里面的area。所以这两个状态的相似度很高。每当decoder想要生成一个状态时, 都会看一遍encoder的所有状态。这些权重α告诉decoder应该关注什么地方, 这就是attention名字的由来。当decoder需要计算这个状态的时候, 权重α告诉decoder应该关注encoder的这个状态。 这帮助decoder产生正确的状态, 从而生成正确的法语单词。
3、总结
如果使用标准的seq2seq模型, decoder基于当前状态产生下一个状态, 这样产生的新状态可能已经遗忘了encoder的部分输入。如果使用attention, decoder在产生下一个状态前, 会先看一遍encoder的所有状态, 于是decoder就知道encoder的完整信息, 并不会遗忘。除了解决遗忘问题, attention还能告诉decoder应该关注encoder的哪一个状态, 这就是attention名字的由来。
attention可以大幅度提升seq2seq模型的表现。但attention也是有缺点的, 缺点就是计算量太大了。假设输入encoder序列的长度为m, decoder输出的序列长度为t。标准的seq2seq模型只需要让encoder读一遍输入序列, 之后就不会再看encoder的输入或者状态了, 然后让decoder依次生成输出的序列, 所以时间复杂度是m+t。
但attention中, 时间复杂度会高很多。decoder每更新一个状态都会把encoder的状态先看一遍, 所以时间复杂度是m。decoder自己有t个状态, 所以总的时间复杂度是mt。使用attention可以提升准确率, 但需要付出更多的计算量。