【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

目录

名称:Pyramid Vision Transformer: A Versatile Backbone for Dense Prediction without Convolutions
论文:原文
代码:官方代码

笔记参考:
1. 语义分割中的Transformer(第三篇):PVT — 用于密集预测任务的金字塔 Vision Transformer
2 作者本人的解释!!-大白话Pyramid Vision Transformer
3 金字塔Transformer,更适合稠密预测任务的Transformer骨干架构
4. 对pvt的思考!
5 简洁版-PVT(Pyramid Vision Transformer)算法整理
6. 翻译版

  1. 简述

  2. 之前的所总结的ViT backbone,本身并没有针对视觉中诸如分割、检测等密集预测型的任务的特定,设计合适结构。
    后续 SERT等论文也只是简单的 将VIT作为Encoder,将其提取到的单尺度特征通过一些简单的Decoder的处理, 验证了transformer在语义分割任务上的效果。
    但是,我们知道,在语义分割任务上,多尺度的特征是非常重要的,因此在PVT中提出了一种能够 提取多尺度特征的vision transformer backbone

    【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

; 2.主要工作

2.1 ViT遗留的问题

ViT的结构是下面这样的,它和原版Transformer一样是柱状结构的 。这就意味着,
1) 缺乏多尺度特征
ViT 输出的特征图和输入大小基本保持一致。这导致ViT作为Encoder时,只能输出单尺度的特征。
它全程只能输出16-stride或者32-stride的feature map;
2) 计算开销剧增
一旦输入图像的分辨率稍微大点,占用显存就会很高甚至显存溢出。
分割和检测相对于分类任务而言,往往需要较大的分辨率图片输入。
因此,一方面,我们需要相对于分类任务而言 划分更多个patch才能得到相同粒度的特征。如果仍然保持同样的patch数量,那么特征的粒度将会变粗,从而导致性能下降
另一方面,我们知道, Transformer的计算开销与token化后的patch数量正相关, patch数量越大,计算开销越大。所以,如果我们增大patch数量,可能就会让我们本就不富裕的计算资源雪上加霜。
以上是ViT应用于密集预测任务上的第一个缺陷。

解决
针对方案简单粗暴:

  • 输出分辨率不够,那么就加大;
  • patch token序列太长,导致attention矩阵的计算量太大,那么就针对性的缩减总体序列长度,或者是 仅仅缩减k和v的长度(如下图)。

2.2 引入金字塔结构

计算机视觉中CNN backbone经过多年的发展,沉淀了一些通用的设计模式。
最为典型的就是金字塔结构。
简单的概括就是:

1)feature map的 分辨率随着网络加深,逐渐减小

2)feature map的 channel数随着网络加深,逐渐增大
几乎所有的密集预测(dense prediction)算法都是围绕着特征金字塔设计的

这个结构怎么才能引入到Transformer里面呢?
最终还是发现: 简单地堆叠多个独立的Transformer encoder效果是最好的。
然后我们就得到了PVT,如下图所示。在每个Stage中 通过Patch Embedding来逐渐降低输入的分辨率

其中,除了金字塔结构以外。为了 可以以更小的代价处理高分辨率(4-stride或8-stride)的feature map,我们对Multi-Head Attention也做了一些调整。

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
为了在 保证feature map分辨率和全局感受野的同时降低计算量,我们把key(K)和value(V)的长和宽分别缩小到以前的1/R_i。通过这种方法,我们就可以以一个较小的代价处理4-stride,和8-stride的feature map了。

; 3.PVT的设计方案

模型总体上由4个stage组成用于生成不同尺度的特征,每个stage包含Patch Embedding和若干个Transformer模块(相对于原本的transformer有所改动)组成。

  • Patch Embedding:目的在于将信息分块,降低单张图的图片大小,但会增加数据的深度
  • Transformer Encoder:目的在于计算图片的attention value,由于深度变大了,计算复杂度会变大,所以在这里作者使用了Special Reduction来减小计算复杂度

在第一个阶段,给定尺寸为 H X WX3 的输入图像,我们按照如下流程进行处理:

首先,将其划分为HW/4的平方的块(这里是为了与ResNet对标,最大输出特征的尺寸为原始分辨率的1/4),每个块的大小为;
然后,将展开后的块送入到线性投影曾得到尺寸为HW/4的平方 xC1 的嵌入块;
其次,将前述嵌入块与位置嵌入信息送入到Transformer的Encoder,其输出将为reshap为.H/4 X W/4 X C1

采用类似的方式,我们以前一阶段的输出作为输入即可得到特征F2 F3 F4。

H * W * 3 -> stage1 block -> H/4 * W/4 * C1 -> stage2 block -> H/8 * W/8 * C2 -> stage3 block -> H/16 * W/16 * C3 -> stage3 block -> H/32 * W/32 * C4

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

3.1 Patch embedding

Patch Embedding部分与ViT中对与图片的分块操作是一样的,即:

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
通过这种方式我们就可以灵活的调整每个阶段的特征尺寸,使其可以针对Transformer构建特征金字塔。

在每个stage开始,首先像ViT一样对输入图像进行token化,即进行patch embedding,patch大小除第1个stage的是4×4外,其余均采用2×2大小。这个思想有些类似于池化或者带步长的卷积操作,减小图像的分辨率,使得模型能够提取到更为抽象的信息。这意味着 每个stage(第一个stage除外)最终得到的特征图维度是减半的,tokens数量对应减少4倍

每个patch随后会送入一层Linear中,调整通道数量,然后再reshape以将patch token化。

这使得PVT总体上与resnet看起来类似,4个stage得到的特征图相比原图大小分别是1/4,1/8,1/16 和 1/32。这也意味着PVT可以产生不同尺度的特征。

Note:由于不同的stage的tokens数量不一样,所以每个stage采用不同的position embeddings,在patch embed之后加上各自的position embedding,当输入图像大小变化时,position embeddings也可以通过插值来自适应。

; 代码

1、首先输入的data的shape是(bs,channal,H,W),为了方便直接用batchsize是1的图片做例子,因此输入是(1,3,224,224)
code对应:

model = pvt_small(**cfg)
data = torch.randn((1, 3, 224, 224))
output = model(data)

2、输入数据首先经过stage 1 block的Patch emb操作,这个操作首先把224 _224的图像分成4_4的一个个小patch,这个实现是用卷积实现的,用4 _4的卷积和对224_224的图像进行卷积,步长为4即可。
code对应:

self.proj = nn.Conv2d(in_chans=3, embed_dim=64, kernel_size=4, stride=4)

print(x.shape)
x = self.proj(x)
print(x.shape)

这样就可以用56 _56的矩阵每一个点表示原来4_4的patch

3、对1 _64_56*56的矩阵在进行第二个维度展平
code对应:

print(x.shape)
x = x.flatten(2)
print(x.shape)

这时候就可以用3136这个一维的向量来表示224*224的图像了

4、为了方便计算调换下第二第三两个维度,然后对数据进行layer norm。
code对应:

print(x.shape)
x = x.transpose(1, 2)
print(x.shape)
x = self.norm(x)
print(x.shape)

以上就完成了Patch emb的操作,完整代码对应:

def forward(self, x):
    B, C, H, W = x.shape
    x = self.proj(x)
    x = x.flatten(2)
    x = x.transpose(1, 2)
    x = self.norm(x)
    H, W = H // 4, W // 4
    return x, (H, W)

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

3.2position embedding

1、这部分和Vit的位置编码基本是一样的,创建一个可学习的参数,大小和patch emb出来的tensor的大小一致就是(1,3136,64),这是个可学习的参数。
code对应:

pos_embed = nn.Parameter(torch.zeros(1, 3136, 64))

2、位置编码的使用也是和Vit一样,直接和输出的x进行矩阵加,因此shape不变化。
code对应:

print(x.shape)
x = x + pos_embed
print(x.shape)

3、相加完后,作者加了个dropout进行正则化。
code对应:

pos_drop = nn.Dropout(p=drop_rate)
x = pos_drop(x)

以上就完成了position embedding的操作,完整代码对应:

x = x + pos_embed
x = pos_drop(x)

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

3.3Encoder

第i个stage的encoder部分由depth[i]个block构成,对于pvt_tiny到pvt_large来说主要就是depth的参数的不同:

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
例如对于pvt_tiny来说,每个encoder都是由两个block构成,每个block的结构如下图所示:
【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
对于第一个encoder的第一个block的输入就是我们前面分析的经过position embedding后拿到的tensor,因此他的输入的大小是(1,3136,64),与此同时图像经过Patch emb后变成了56*56的大小。

1、首先从上图可以看出先对输入拷贝一份,给残差结构用。然后输入的x先经过一层layer norm层,此时维度不变,然后经过作者修改的Multi head attention层(SRA,后面再讲)与之前拷贝的输入叠加。
code对应:

print(x.shape)
x = x + self.drop_path(self.attn(self.norm1(x), H, W))
print(x.shape)

2、经过SRA层的特征拷贝一份留给残差结构,然后将输入经过layer norm层,维度不变,再送入feed forward层(后面再讲),之后与之前拷贝的输入叠加。
code对应:

print(x.shape)
x = x + self.drop_path(self.mlp(self.norm2(x)))
print(x.shape)

因此可以发现经过一个block,tensor的shape是不发生变化的。完整的代码对应:

def forward(self, x, H, W):
      x = x + self.drop_path(self.attn(self.norm1(x), H, W))
      x = x + self.drop_path(self.mlp(self.norm2(x)))
      return x

3、这样 经过depth[i]个block之后拿到的tensor的大小仍然是(1,3136,64),只需要 将它的shape还原成图像的形状就可以输入给下一个stage了。而还原shape,直接调用reshape函数即可,这时候的特征就还原成(bs,channal,H,W)了,数值为(1,64,56,56)
code对应:

print(x.shape)
x = x.reshape(B, H, W, -1)
print(x.shape)
x = x.permute(0, 3, 1, 2).contiguous()
print(x.shape)

这时候stage2输入的tensor就是(1,64,56,56),就完成了数据输出第一个stage的完整分析。
最后只要在不同的encoder中堆叠不同个数的block就可以构建出pvt_tiny、pvt_small、pvt_medium、pvt_large了。
完整图示如下:

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
所以!经过stage1,输入为(1,3,224,224)的tensor变成了(1,64,56,56)的tensor,这个tensor可以再次输入给下一个stage重复上述的计算就完成了PVT的设计。

3.2 Spatial-reduction attention(SRA)

在Patch embedding之后,需要将token化后的patch输入到若干个transformer 模块中进行处理。

不同的stage的tokens数量不同,越靠前的stage的patchs数量越多,self-attention的计算量与sequence的长度N的平方成正比,如果PVT和ViT一样,所有的transformer encoders均采用相同的参数,那么计算量肯定是无法承受的。

1.PVT 为了减少计算量,不同的stages采用的网络参数是不同的
PVT不同系列的网络参数设置如下所示,这里P为patch的size,C为特征维度大小,N为MHA(multi-head attention)的heads数量,E为FFN的扩展系数,transformer中默认为4。

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
可以见到随着stage,特征的维度是逐渐增加的,比如stage1的特征维度只有64,而stage4的特征维度为512,这种设置和常规的CNN网络设置是类似的,所以前面stage的patchs数量虽然大,但是特征维度小,所以计算量也不是太大。不同体量的PVT其差异主要体现在各个stage的transformer encoder的数量差异。

2.为了进一步减少计算量,作者将multi-head attention (MHA)用所提出的spatial-reduction attention (SRA)来替换。

SRA的核心是减少attention层的key和value对的数量,常规的MHA在attention层计算时key和value对的数量为sequence的长度,但是SRA将其降低为原来的1/Rd 平方。

SRA的处理过程可以描述如下:

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
也就是说每个head的维度等于Ci/Ni, SR( )为空间尺度下采样操作,定义如下:
【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
(
总而言之,所提方案涉及的超参数包含如下:

P:阶段i的块尺寸;
C:阶段i的通道数;
L:阶段i的encoder层数;
R:阶段i中SRA的下采样比例;
N:阶段i的head数量;
E:阶段i中MLP的扩展比例。
)
具体来说:
在实现上,

  • 首先将维度为(HW,C)的K,V通过 reshape变换到维度为(H,W, C)的3-D特征图,
  • 然后均分大小为R * R的patchs, 每个patchs通过线性变换将得到维度为(H W / R R,C)的patch embeddings(这里实现上其实和patch emb操作类似,等价于一个卷积操作),
  • 最后应用一个layer norm层,这样就可以大大降低K和V的数量。其核心代码也是这么实现的:
self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
self.norm = nn.LayerNorm(dim)`

每个stage,经过若干个SRA模块的处理后,将得到的特征,再次reshape成3D特征图的形式输入到下一个Stage中。

1、首先如果参数self.sr_ratio为1的话,那么pvt的attetion就和vit的attetion一模一样了:

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

2、因此分析不一样的地方

self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
x_ = self.norm(x_)
kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

2.1:首先输入进来的x的shape是(1,3136,64),
2.2:先permute置换维度得到(1,64,3136),
2.3:reshape得到(1,64,56,56)
2.4:self.sr(x_)是一个卷积操作,卷积的步长和大小都是sr_ratio,这个数值这里是8因此相当于将56*56的大小长宽缩小到8分之一,也就是面积缩小到64分之一,因此输出的shape是(1,64,7,7)
2.5:reshape(B, C, -1)得到(1,64,49)
2.6:permute(0, 2, 1)得到(1,49,64)
2.7:经过layer norm后shape不变
2.8:kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)这一行就是和vit一样用x生成k和v,不同的是这里的x通过卷积的方式降低了x的大小。
这一行shape的变化是这样的:(1,49,64)->(1,49,128)->(1,49,2,1,64)->(2,1,1,49,64)
2.9:拿到kv:(2,1,1,49,64)分别取index为0和1就可以得到k和v对应k, v = kv[0], kv[1]因此k和v的shape为(1,1,49,64)

3、之后的代码与vit相同,主要就是拿到了x生成的q,k,v之后,q和所有的k矩阵乘之后算softmax,然后加权到v上。
code对应:

attn = (q @ k.transpose(-2, -1)) * self.scale
attn = attn.softmax(dim=-1)
attn = self.attn_drop(attn)

x = (attn @ v).transpose(1, 2).reshape(B, N, C)
x = self.proj(x)
x = self.proj_drop(x)

4、所以对于attention模块输入进来的x的大小是(1,3136,64)输出的shape也是(1,3136,64)

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

完整代码

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio

        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

feed forward

这部分比较简单了,其实就是一个mlp构成的模块。

1、完整代码:
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x
首先forward函数的输入是attention的输出和原始输入残差相加的结果,
输入大小是(1,3136,64)
fc1输出(1,3136,512)
act是GELU激活函数,输出(1,3136,512)
drop输出(1,3136,512)
fc2输出(1,3136,64)
drop输出(1,3136,64)

实验结果

语义分割

我们选择语义FPN[21]作为基线,这是一种简单的分割方法,无需特殊操作(例如,扩展卷积)。因此,用它作为基线可以很好地检验主干的原始有效性。与目标检测中的实现类似,我们将特征金字塔直接输入到语义FPN中,并使用双线性插值来调整预训练位置嵌入的大小。

实验设置

我们选择了ADE20K[63],
这是一个具有挑战性的场景分析基准,用于语义分割。ADE20K包含150个细粒度语义类别,其中分别有20210、2000和3352个图像用于训练、验证和测试。我们将PVT主干应用于Semantic FPN21,这是一种没有扩展卷积的简单分割方法[57],以此来评估我们的PVT主干。在训练阶段,主干用ImageNet上预先训练的权重初始化[9],其他新添加的层用Xavier初始化[13]。我们使用AdamW[33]优化模型,初始学习率为1e-4。按照常见设置[21,6],我们在4个V100 gpu上训练80k迭代的模型,批量大小为16。学习率按polynomial衰减规律衰减,幂为0.9。我们随机调整训练图像的大小并将其裁剪为512× 512并在测试期间将图像缩放到短边长为512。

结果。 如表5所示,在不同的参数尺度下,我们的PVT始终优于ResNet[15]和ResNeXt[56],使用语义FPN进行语义分割。例如,在参数数目几乎相同的情况下,我们的PVT Tiny/Small/Medium至少比ResNet-18/50/101高出280万。此外,尽管我们的Semantic FPN+PVT Large的参数数目比Semantic FPN+ResNeXt101-64x4d的参数数目低了20%,但mIoU仍然高出1.9(42.1比40.2),这表明对于语义分段来说,我们的PVT可以提取比CNN更好的特征,这得益于全局注意机制。

最终,论文采用PVT为Encoder,FPN为Decoder的形式,测试了模型在ADE20K数据集上的性能。具体如下:

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

; 思考

反映出来了一个重要的现象, 逐步缩小Token序列的结构其实也是可以实现与Deit以及ViT那种定长Token序列模型相同的效果,甚至更好。

针对图像构造的Token实际上仍是有待进一步优化以更加高效的提炼图像的局部上下文信息,这一点在 Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet 中的T2T Module,以及 Visual Transformers: Token-based Image Representation and Processing for Computer Vision 中的Tokenizer等结构的设计中也是有所体现的。

抛开其他的不说,这里的 多分辨率特征的提取实际上既有传统又结合新潮:

  • 既有 CNN多层级特征提取以为后续结构提供丰富的多尺度信息的常见密集预测任务的增强策略
  • 亦有最近的SETR: Rethinking Semantic Segmentation from a Sequence-to-Sequence Perspective with Transformers 这种 利用多层Transformer Layer中间的输出特征来送入CNN解码器服务于分割预测的恢复的有效尝试

二者拍拍手,诶,似乎思路顺其自然。但是想尝试和能做出也并不等价,具体实操细节过程中,遇到的问题,针对性的解决手段,策略的尝试与调整,都不足为外人道也。不过从结果而言,倒也确实可以说是很好的 结合了现有的CNN和Transformer的工作的探索。

作者想法

1. 为什么PVT在同样参数量下比CNN效果好?

我认为有两点
1)全局感受野和
2)动态权重。

其实本质上,Multi-Head Attention(MHA)和Conv有一些相通的地方。 MHA可以大致看作是一个具备全局感受野的,且结果是按照attention weight加权平均的卷积。因此Transformer的特征表达能力会更强。

2. 后续可扩展的思路

1)效率更高的Attention:随着输入图片的增大,PVT消耗资源的增长率要比ResNet要高,所以PVT更适合处理中等输入分辨率的图片(具体见PVT的Ablation Study)。所以找到一种效率更高的Attention方案是很重要的。

2)Position Embedding:PVT的position embedding是和ViT一样,都是随机的参数,然后硬学的。而且在改变输入图像的分辨率的时候,position embedding还需要通过插值来调整大小。所以我觉得这也是可以改进的地方,找到一种更适合2D图像的方法。

3)金字塔结构:PVT只是一种较简单的金字塔式Tranformer。 中间是通过Patch Embedding连接的,或许有更优美的方案。

【PVT v2】PVTv2: Improved Baselines with Pyramid Vision Transformer

论文
代码

笔记参考:
PVT,PVTv2
简洁版
代码解释版

动机

出发点:对PVT1进行优化

  • vit和pvt_v1对图像用44大小的patch进行编码,这样 忽略了一定的图像局部连续性*。
  • vit和pvt_v1都是用 固定大小的位置编码,这样对处理任意大小的图像不友好。
  • 计算量还是大

改进

1.Overlapping Patch Embedding

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
之前 ViT 以及 PVT 采用的都是正好切分的方式,有时边界部分的信息无法得到完整解读。同时,也丧失了这些分块的局部连续性。
改进:
对patch窗口进行放大,使相邻窗口的面积重叠一半,并在feature map上填充0以保持分辨率。在这项工作中,作者使用0 padding卷积来实现重叠的patch嵌入。

具体来说,给定一个大小为h×w×c的输入,将其输入到一个步长为S、卷积核大小为2S−1、padding大小为S−1、卷积核数目为的卷积c中。输出大小为H/S x W/S x C 。

具体可见代码:

1、针对PVT_V1的patch emb模块,作者将原来的卷积操作进行了修改:

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)

self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride, padding=(patch_size[0] // 2, patch_size[1] // 2))

这样修改之后的patch emb输入是(1,3,224,224)输出还是(bs,channal,56,56)和原来用步长为4大小为4的卷积核卷积的结果一致。 不一致在于编码图像结合了每个patch和上下左右相邻的patch信息。从上图下部分可看出。感觉和Swin有相似的思想,但是实现更简单。

2.卷积的前向传播

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
删除固定大小位置编码,并引入0 padding位置编码进pvt如图1所示(b),
在前馈网络中的FC层与GELU之间添加一个3×3的depth-wise卷积。
depth-wise卷积讲解很好!
depth-wise卷积
depthwise separable convolution:
先用M个3 _3卷积核一对一卷积输入的M个feature map,不求和,生成M个结果;
然后用N个1_1的卷积核正常卷积前面生成的M个结果,求和,
最后生成N个结果。
因此文章中将depthwise separable convolution分成两步,
一步叫depthwise convolution,
另一步是pointwise convolution。
class DWConv(nn.Module):
    def __init__(self, dim=768):
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x

class Mlp(nn.Module):
    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

3.Linear Spatial Reduction Attention

进一步降低PVT的计算成本
把PVT的 SRA结构中的卷积降低分辨率的操作换成了池化加卷积来进行,节省计算量。

线性SRA在进行注意操作之前使用平均池化 将空间维度(即h×w)减小到固定大小(即P ×P) ,P是线性SRA的池大小(pooling size)。所以线性SRA像卷积层一样需要线性的计算和内存开销。

【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2
print(x.shape)
x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
print(x_.shape)
x_ = self.pool(x_)
print(x_.shape)
x_ = self.sr(x_)
print(x_.shape)
x_ = x_.reshape(B, C, -1)
print(x_.shape)
x_ = x_.permute(0, 2, 1)
print(x_.shape)

第一步把输入的x从tokens还原成二维,且完整shape为(bs,channal,H,W)
第二步经过尺寸为7的池化层
第三步经过卷积层
第四步至最后:还原成(1,H*W,dim)

完整代码

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False):
        if not linear:
            if sr_ratio > 1:
                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
                self.norm = nn.LayerNorm(dim)
        else:
            self.pool = nn.AdaptiveAvgPool2d(7)
            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
            self.norm = nn.LayerNorm(dim)
            self.act = nn.GELU()
        self.apply(self._init_weights)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if not self.linear:
            if self.sr_ratio > 1:
                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
                x_ = self.norm(x_)
                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
            else:
                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)

        else:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            x_ = self.act(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

注意

需要注意的是,在PVTv2中只有pvt_v2_b2_li作者用了这个linear SRA,而作者发布的最好的模型pvt_v2_b5是没用这个linear SRA的,所以主要涨点的贡献来自于动机一用更大尺寸的卷积核加强patch之间的联系。由此可见对于图像的任务,patch间的关系还是很重要的,针对patch emb应该还有更好的方法!

pvt_v1完整测试pvt.py代码


python3 -m pip install timm

python3 pvt.py

import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from timm.models.registry import register_model
from timm.models.vision_transformer import _cfg

__all__ = [
    'pvt_tiny', 'pvt_small', 'pvt_medium', 'pvt_large'
]

class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x

class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = nn.LayerNorm(dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)

            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x

class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

class PatchEmbed(nn.Module):
    """ Image to Patch Embedding
"""

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)

        self.img_size = img_size
        self.patch_size = patch_size

        self.H, self.W = img_size[0] // patch_size[0], img_size[1] // patch_size[1]
        self.num_patches = self.H * self.W
        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        self.norm = nn.LayerNorm(embed_dim)

    def forward(self, x):
        B, C, H, W = x.shape

        x = self.proj(x)
        x = x.flatten(2)
        x = x.transpose(1, 2)
        x = self.norm(x)

        H, W = H // self.patch_size[0], W // self.patch_size[1]

        return x, (H, W)

class PyramidVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0

        for i in range(num_stages):
            patch_embed = PatchEmbed(img_size=img_size if i == 0 else img_size // (2 ** (i + 1)),
                                     patch_size=patch_size if i == 0 else 2,
                                     in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                     embed_dim=embed_dims[i])
            num_patches = patch_embed.num_patches if i != num_stages - 1 else patch_embed.num_patches + 1
            pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dims[i]))
            pos_drop = nn.Dropout(p=drop_rate)

            block = nn.ModuleList([Block(
                dim=embed_dims[i], num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                qk_scale=qk_scale, drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j],
                norm_layer=norm_layer, sr_ratio=sr_ratios[i])
                for j in range(depths[i])])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"pos_embed{i + 1}", pos_embed)
            setattr(self, f"pos_drop{i + 1}", pos_drop)
            setattr(self, f"block{i + 1}", block)

        self.norm = norm_layer(embed_dims[3])

        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dims[3]))

        self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        for i in range(num_stages):
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            trunc_normal_(pos_embed, std=.02)
        trunc_normal_(self.cls_token, std=.02)
        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):

        return {'cls_token'}

    def get_classifier(self):
        return self.head

    def reset_classifier(self, num_classes, global_pool=''):
        self.num_classes = num_classes
        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()

    def _get_pos_embed(self, pos_embed, patch_embed, H, W):
        if H * W == self.patch_embed1.num_patches:
            return pos_embed
        else:
            return F.interpolate(
                pos_embed.reshape(1, patch_embed.H, patch_embed.W, -1).permute(0, 3, 1, 2),
                size=(H, W), mode="bilinear").reshape(1, -1, H * W).permute(0, 2, 1)

    def forward_features(self, x):
        B = x.shape[0]

        for i in range(self.num_stages):
            patch_embed = getattr(self, f"patch_embed{i + 1}")
            pos_embed = getattr(self, f"pos_embed{i + 1}")
            pos_drop = getattr(self, f"pos_drop{i + 1}")
            block = getattr(self, f"block{i + 1}")
            x, (H, W) = patch_embed(x)
"""
            stage0:
"""

            if i == self.num_stages - 1:
                cls_tokens = self.cls_token.expand(B, -1, -1)
                x = torch.cat((cls_tokens, x), dim=1)
                pos_embed_ = self._get_pos_embed(pos_embed[:, 1:], patch_embed, H, W)
                pos_embed = torch.cat((pos_embed[:, 0:1], pos_embed_), dim=1)
            else:
                pos_embed = self._get_pos_embed(pos_embed, patch_embed, H, W)

            x = pos_drop(x + pos_embed)

            for blk in block:
                x = blk(x, H, W)
            if i != self.num_stages - 1:
                x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()

        x = self.norm(x)

        return x[:, 0]

    def forward(self, x):
        x = self.forward_features(x)
        x = self.head(x)

        return x

def _conv_filter(state_dict, patch_size=16):
    """ convert patch embedding weight from manual patchify + linear proj to conv"""
    out_dict = {}
    for k, v in state_dict.items():
        if 'patch_embed.proj.weight' in k:
            v = v.reshape((v.shape[0], 3, patch_size, patch_size))
        out_dict[k] = v

    return out_dict

@register_model
def pvt_tiny(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model

@register_model
def pvt_small(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], **kwargs)
    model.default_cfg = _cfg()

    return model

@register_model
def pvt_medium(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model

@register_model
def pvt_large(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
        **kwargs)
    model.default_cfg = _cfg()

    return model

@register_model
def pvt_huge_v2(pretrained=False, **kwargs):
    model = PyramidVisionTransformer(
        patch_size=4, embed_dims=[128, 256, 512, 768], num_heads=[2, 4, 8, 12], mlp_ratios=[8, 8, 4, 4], qkv_bias=True,
        norm_layer=partial(nn.LayerNorm, eps=1e-6), depths=[3, 10, 60, 3], sr_ratios=[8, 4, 2, 1],

        **kwargs)
    model.default_cfg = _cfg()

    return model

if __name__ == '__main__':
    cfg = dict(
        num_classes = 2,
        pretrained=False
    )
    model = pvt_small(**cfg)
    data = torch.randn((1, 3, 224, 224))
    output = model(data)
    print(output.shape)

Original: https://blog.csdn.net/zhe470719/article/details/124807854
Author: sky_柘
Title: 【深度学习】(ICCV-2021)PVT-金字塔 Vision Transformer及PVT_V2

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/649055/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球