SwinTransformer
1、简介
Swin Transformer是2012年微软研究院在ICCV上发表的一篇文章, 并荣获2021 ICCV最佳论文称号。 Swin Transformer网络是Transformer模型在视觉领域的又一次碰撞。
论文名称:Swin Transformer: Hierarchical Vision Transformer using Shifted Windows
论文地址: Swin Transformer
2、论文整体框架
在正文开始之前, 先对比一下Swin transformer与Vision Transformers, 下图是swin Transformers文章中给出的图, 左边是Swin Transformers,右边是之前的Vision Transformers,通过对比可以发现有两点不同。
- Swin Transformers使用了之前类似卷积神经网络的层次化构建方法(Hierachical feature maps), 比如特征图尺寸有对图像下采样4倍的, 8倍的以及16倍的, 这样的backbone有助于在此基础上构建目标检测, 实例分割等任务。 而之前的Vision Transformer中一开始就是直接下采样16倍, 后面的特征图也是维持这样的下采样率不变。
- 在Swin Transformer中使用了Window Multi-Head Self-Attention(W-MSA)的概念, 比如在4倍下采样和8倍下采样中, 将特征图划分成多个不相交的区域(windows), 并且Multi-Head Self-Attention只在每个窗口内进行, 相对于Vision Transformer中直接对整个特征图进行Multi-Head Self-Attention, 这样做的目的是能够减少计算量, 尤其在浅层特征图很大的时候。这样做虽然减少了计算量, 但是也会隔绝不同窗口之间的信息传递, 所以在论文中作者又提出Shifted Windows Multi-Head Self-Attention(SW-MSA)的概念, 通过该方法能够让信息在相邻的窗口中进行传递。
接下来,简单看下原论文中给出的关于Swin Transformer网络的架构图, 通过图(a)可以看到整个架构的基本历程如下:
首先将图片输入到Patch Partition模块中进行分块, 即每4×4相邻的像素为一个Patch, 然后再channel方向展平(flatten)。假设输入的是RGB三通道图片, 那么每个patch就有4×4=16个像素, 然后每个像素有R/G/B三个值, 所以展平后是16×3=48, 所以通过patch partition后图像shape由[H, W, 3]变成了[H/4, W/4, 48]。 然后通过Linear Embedding层对每个像素的channel数据做线性变换, 由48变成C, 即图像shape再由[H/4, W/4, 48]变成了[H/4 ,W/4, C]。 其实在源码中Patch Patittion和Linear Embedding就是直接通过一个卷积实现的, 和之前VisionTransformer中讲的Embedding层结构一模一样。
然后通过四个Stage构建不同大小的特征图, 除了Stage1中先通过Linear Embedding层外, 其他三个都是先通过一个Patch Merging层进行下采样。 然后都是重复堆叠Swin Transformer Block。注意这里的Block有两种结构, 如图(b)所示, 这两种结构的不同之处在于一个使用了W-MSA结构, 一个使用了SW-MSA结构。 而且这两种结构都是成对使用的,先使用一个W-MSA结构再使用一个SW-MSA结构。 所以会发现堆叠SwinTransformer Block的次数都是偶数(成对使用)
最后对于分类网络, 还会接上一个Layer Norm层, 全局池化层以及全连接层得到最终输出。
下面分别对Patch Merging , W-MSA, SW-MAS以及使用到的相对位置偏置(relative position bias)进行详解。 关于Swin Transformer Block中的MLP结构和Vision Transformer中的结构一样。
3、Patch Merging详解
前面提到, 在每个Stage中首先经过一个Patch Mering层进行下采样(stage1除外). 如下图所示, 假设输入Patch Merging的是一个4×4大小的单通道特征图(feature map), Patch Merging会将每个2×2的相邻像素划分为一个patch, 然后将每个patch中相同位置(同一颜色)像素拼接到一起就得到了4个feature map。接着将这四个feature map在深度方向进行concat拼接, 然后通过一个LayerNorm层, 最后通过一个全连接层在feature map的深度方向做线性变换, 将feature map的深度有C变成C/2, 通过这个简单的例子可以看出, 通过Patch Merging层后, feature map的高和宽会减半, 深度会翻倍。
4、W-MSA详解
引入Windows Multi-Head Self-Attention(W-MSA)模块是为了减少计算量, 如下如所示, 左侧使用的是普通的Multi-Head Self-Attention(MSA)模块, 对于feature map中的每个像素(或称作token, patch), 在Self-Attention计算过程中需要和所有的像素去计算, 但在图右侧, 在使用Windows Multi-Head Self-Attention (W-MSA)模块时, 首先将feature map按照MxM(例子中M=2)大小划分成一个个Windowns, 然后单独对每个windows内部进行Self-Attention。
两者的计算量具体差多少呢?原论文中给出了下面两个公式, 这里忽略了Softmax的计算复杂度:
h, w代表feature map的高度和宽度
C代表feature map的深度
M代表每个窗口(Windows)的大小
回顾一下Self-Attention的公式
MSA模块计算量
对于feature map中的每个像素(或者称为token, patch), 都要通过生成对应的query(q), key(q)以及value(v), 这里假设q, k, v的向量长度与feature map的深度C保持一致, 那么对应所有像素生成Q的过程如下式:
Q^{hw \times C}.K^{T(c\times hw)}=X^{hw \ time hw}
那么对应单头的Self-Attention模块, 总共需要。而实际使用过程中, 使用的是多头的Multi-Head Self-Attention模块, 在之前的文章中尽显过实验对比, 多头注意力模块相比单头注意力模块的计算量仅多了最后一个融合矩阵的计算量
所以总共加起来是:
5、W-MSA模块计算量
对于W-MSA模块首先要将feature map划分到一个窗口(windows)中, 假设每个窗口的高宽都是M, 那么总共会得到, 对于每个窗口内使用多头注意力模块, 刚刚计算高为h, 宽为w, 深度为C的feature map的计算量为, 这里每个窗口的高为M, 宽为M, 带人公式得:
又因为有个窗口, 则:
故使用W-MAS模块的计算量为:
假设feature map的h, w都是112, M=7, C=112, 采用W-MSA模块相比MSA模块能够减少40124743680FLOPs
6、SW-MAS详解
采用W-MSA模块时, 只会在每个窗口内进行注意力计算, 所以窗口与窗口之间是无法进行消息传递的。 为了解决这问题, 作者引入了shifted windows Multi-Head Self-Attention(SW-MSA)模块, 即进行偏移的W-MSA。如下图所示, 左侧使用的是刚刚讲的W-MSA(假设是第L层), 那么根据之前介绍的W-SMA和SW-MAS是称帝使用的, 那么第L+1层使用的就是SW-MSA(右侧图)。 根据左右两幅图对比能够发信啊窗口发生了偏移(可以理解为创库从左上角分别向右侧和下方各偏移了个像素)。 看下偏移后的窗口, 比如对于第一行第2列的2x4的窗口, 它能够使用第L层的第一排的两个窗口信息进行交流。 再比如, 第二行第二列的4x4的窗口, 它能够使第L层的四个窗口信息进行交流, 其他的同理, 那么这就解决了不同窗口之间无法进行信息交流的问题。
根据上图, 可以发现通过将窗口进行偏移后, 由原来的4个窗口变成9个窗口了, 后面又要对每个窗口内部进行MSA, 这样做感觉又变麻烦了, 为解决这个麻烦, 作者又提出了Efficient batch computation for shifted configuration, 一种更加高效的计算方法, 下面是原论文给出的示意图。
为更好的解释, 作如下示意图, 下图左侧是刚刚通过偏移窗口后得到的新窗口, 右侧是为了方便大家理解, 在每个窗口上加 另一个标识, 然后0对应的窗口表示区域为A, 3和6对应的窗口标记区域为B, 1和2对应的窗口标记为区域C。
将A和C移到最下方
然后再将A和B移到最右侧
移动完成后,4是一个单独的窗口, 将5和3合并到一起, 7和1合并成一个窗口, 8, 6, 2, 0合并成一个窗口。 这样又和原来一样是4个4×4的窗口了。 所以保证计算量是一样的。 这里肯定会有人疑惑, 把不同的区域合并在一起(比如5和3)进行Multi-Head Self-Attention,这信息不就乱窜了吗?是的, 为了放置这个问题, 实际计算中使用的是masked MSA即带蒙板mask的Multi-Head Self-Attention, 这样就能够设置蒙板来隔绝不同区域的信息了。 关于mask如何使用, 可以看下图, 下图以上图的区域5和区域3为例。
对于该窗口内的每一个像素(或称token, patch)在进行Multi-Head Self-Attention计算时, 都要先生成对应的query(q), key(k), value(v)。 假设对于上图的像素0而言, 得到后要与每一个像素的k进行匹配(match), 假设代表与像素0对应的进行匹配的结果, 那么同理可以得到至。按照普通的MSA计算, 接下来就是Softmax操作了, 但对于这里的masked Multi-Head Self-Attention, 像素0属于区域5的, 我们只想让它和区域5的像素进行匹配。那么将像素0与区域3中的所有像素匹配结果都减去100(例如等等)。 由于的值很小, 一般都是零点几的数, 将其中一些数减去100后再通过softmax得到对应的权重都等于0了。 所以对于像素0而言实际上还是只和区域5内的像素进行了Multi-Head Self-Attention, 那么对于其他像素也是同理。 注意,在计算完后还要把数据给挪回到原来的位置上(例如上述的A, B, C区域
7、Relative Position Bias详解
关于相对位置偏置, 论文中没有细讲, 只是说使用了相对位置偏置后能够带来明显的提升。
相对位置偏置是如何使用的,论文中提供了如下的公式:
由于论文中并没有详细讲解这个相对位置偏置, 所以根据阅读源码做了简单的总结。 如下图, 假设输入的feature map高和宽都是2, 那么首先可以通过构建每个像素的相对位置(左下方的矩阵), 对于每个像素的绝对位置是使用行号和列号表示。 比如蓝色的像素对应的是第0行和第0列, 所以绝对位置索引是(0, 0)。
下面来看相对位置索引, 首先看蓝色的像素, 在蓝色像素使用q与所有像素k进行匹配过程中, 是以蓝色像素为参考点。然后使用蓝色像素的绝对位置索引与其他位置索引进行相减, 就得到其他位置相对于蓝色像素的相对位置索引。例如黄色像素的绝对位置索引是(0, 1), 则它相对于蓝色像素的相对位置索引为(0, 0) - (0, 1)= (0, -1)。同理可以得到其他位置相对于蓝色像素的相对位置索引矩阵。 同样, 也能得到相对黄色, 红色以及绿色像素的相对位置索引矩阵。 接下来将每个相对位置索引矩阵按行展开, 并拼接在一起可以得到下面的4×4矩阵。
注意, 这里描述的都是相对位置所以, 并不是相对位置偏置参数。因为后面后面会根据相对位置索引去取对应的参数。比如黄色像素在蓝色像素的右边, 所以相对蓝色像素相对位置索引为(0, -1)。绿色像素在红色像素的右边, 所以相对红色像素的相对位置索引为(0,-1)。可以发现二者的相对位置索引都是(0, -1), 所以它们使用相同的相对位置偏置参数。
在源码中作者为了方便吧二维索引转成一维索引, 具体怎么转的呢?, 简单的想法是直接把行、列索引相加不就可以变成一维了吗?比如上面的相对位置索引中有(0, -1)和(-1, 0), 在二维的相对位置索引中明显代表不同的位置, 但如果简单相加都等于-1那不就出问题了吗, 下面看看源码中是怎么做到的。 首先在原始的相对位置索引上加上M-1(M为窗口的大小, 在本例中M=2), 加上之后索引中就不会有负数了。
然后将所有的行标都乘以2M-1
最后将行标和列标都相加, 这样即保证了相对位置关系, 而且不会出现上述0 + (-1) = (-1) + 0的问题了。
上面提到, 之前计算的是相对位置索引, 并不是相对位置偏置参数, 真正使用到的可训练参数B保存在relative position bias table表中的额, 这个表的长度等于(2M-1)×(2M-1)的。 那么上述公式中的相对位置偏置参数B是根据上面的相对位置索引表查relative position bias table表得到的。 如下图所示。
8、模型配置参数
回顾一下swin-transformer的网络架构
下表是原论文给出的关于不同swin Transformers的配置, T(tiny), S(small), B(Base), L(Large), 其中:
win. sz. 7×7表示使用的窗口(Windows)的大小
dim表示feature map的channel深度(或者说token的向量长度)
head表示多头注意力模块中head的个数