结合代码看Vision Transformer【ViT】

参考仓库:

论文:An Image is Worth 16×16 Words: Transformers for Image Recognition at Scale

有相关问题搜索知识星球号:1453755 【CV老司机】加入星球提问。扫码也可加入:

结合代码看Vision Transformer【ViT】

也可以搜索关注微信公众号: CV老司机

结合代码看Vision Transformer【ViT】

相关代码和详细资源或者相关问题,可联系牛先生小猪wx号: jishudashou

结构介绍:

结合代码看Vision Transformer【ViT】

ViT: Transformer + Head

Transformer: Embeddings [1x197x768] + Encoder

Encoder: N x { Block_Sequence + layerNorm [非全局均值方差,有的实现没做】}

Block: LayerNorm + MultiHeadAttension + LayerNorm + Mlp [中间有两次残差累加]

以输入224x224x3为例,embedding :196+1 个patch , 768 通道【embedding dimension】

Embedding说明

patch embedding: 将图像分为16×16的小块,然后把16x16x3的小块拉平并跟一个FC做特征映射【有的实现是直接使用了kernel为16×16,stride为16的卷积实现,当然这里和我们说的Vision Transformer就有一定出入】

clstokens: 如果上面的patch embedding 为 1x196x768 , 这里的class token就是 1x1x768, 然后cat在一块儿。作为分类特征得汇总。【作用可以视为在CNN里面最后一层的FC logits】【注:维度也可以是 1x196x1024,:超参调节】

position embedding: NLP的实现中,position embedding 用来标记这个patch 在全局中的哪个位置,用于学习一定的结构信息。这里Vision Transformer 遵从原本设计,加入了这个可学习的position embedding.

参考的path embedding 代码:

undefined

        self.to_patch_embedding = nn.Sequential(
            Rearrange('b c (h p1) (w p2) -> b (h w) (p1 p2 c)', p1=patch_height, p2=patch_width),
            nn.Linear(patch_dim, dim),
        )

维度变化: 1x3x224x224  --->   1x196x768

LayerNorm说明

结合代码看Vision Transformer【ViT】

涉及到的算子如下:也就是上面的公式:减均值,除方差,乘以scale,加bias

结合代码看Vision Transformer【ViT】

multi-Head Attention 实现

取得全图注意力。

结合代码看Vision Transformer【ViT】

在这里实现没有mask,NLP中用于句子不一样时,还有填空时,做掩码。

参考代码1:【和上图除了mask部分,流程部分基本一致】

undefined

 def forward(self, hidden_states):
        # 1 x 197 x 768   197 个patch
        mixed_query_layer = self.query(hidden_states)  # fc implement query
        mixed_key_layer = self.key(hidden_states)
        mixed_value_layer = self.value(hidden_states)

        # 1 x 12 x 197 x 64  197pathch  12组特征,每组64维
        query_layer = self.transpose_for_scores(mixed_query_layer)
        key_layer = self.transpose_for_scores(mixed_key_layer)
        value_layer = self.transpose_for_scores(mixed_value_layer)

        # 1 x 12 x 197 x 197 ,感受野添加至 每个patch两两之间通道特征相关性
        attention_scores = torch.matmul(query_layer, key_layer.transpose(-1, -2))

        # 类似于使用温度系数添加注意力多样性
        attention_scores = attention_scores / math.sqrt(self.attention_head_size)

        # patch 内部通道【特征】重要程度使用softmax添加注意力
        attention_probs = self.softmax(attention_scores)
        weights = attention_probs if self.vis else None
        attention_probs = self.attn_dropout(attention_probs)

        # 感受野范围添加至全图所有patch
        context_layer = torch.matmul(attention_probs, value_layer)
        context_layer = context_layer.permute(0, 2, 1, 3).contiguous()
        new_context_layer_shape = context_layer.size()[:-2] + (self.all_head_size,)
        context_layer = context_layer.view(*new_context_layer_shape)
        attention_output = self.out(context_layer)
        attention_output = self.proj_dropout(attention_output)
        return attention_output, weights

参考代码2:

class Attention(nn.Module):
    def __init__(self, dim, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        project_out = not (heads == 1 and dim_head == dim)

        self.heads = heads
        self.scale = dim_head ** -0.5

        self.attend = nn.Softmax(dim=-1)
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, dim),
            nn.Dropout(dropout)
        ) if project_out else nn.Identity()

    def forward(self, x):
        b, n, _, h = *x.shape, self.heads
        qkv = self.to_qkv(x).chunk(3, dim=-1)
        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b h n d', h=h), qkv)

        dots = einsum('b h i d, b h j d -> b h i j', q, k) * self.scale

        attn = self.attend(dots)

        out = einsum('b h i j, b h j d -> b h i d', attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')
        return self.to_out(out)

undefined

整体流程描述:

结合代码看Vision Transformer【ViT】

从整理流程聊聊为什么多头注意力可以获取全图的感受野。首先第一个矩阵乘,以上图为例,每一个输出点,获取到了输入图与输入图两两对应的乘累加结果【我们这里叫相关性,也可以叫感受野就是和另外一个patch通道维度】,第二次矩阵乘,输入就是两个patch相关性,单个点具备两个patch所有通道信息,乘累加过程就是,乘以对应通道权重,然后与其他所有patch乘以对应通道权重结果的累加,这样之后就影响这个点结果的因子覆盖了全图。也就是我们常说的感受野是全图。

attention 每一行,patchN 与所有patch相关性,乘以对应通道权重的累加和。

MLP实现

两个FC+一个激活【gelu激活】

class Mlp(nn.Module):
    def __init__(self, config):
        super(Mlp, self).__init__()
        self.fc1 = Linear(config.hidden_size, config.transformer["mlp_dim"])
        self.fc2 = Linear(config.transformer["mlp_dim"], config.hidden_size)
        self.act_fn = ACT2FN["gelu"]
        self.dropout = Dropout(config.transformer["dropout_rate"])

        self._init_weights()

    def _init_weights(self):
        nn.init.xavier_uniform_(self.fc1.weight)
        nn.init.xavier_uniform_(self.fc2.weight)
        nn.init.normal_(self.fc1.bias, std=1e-6)
        nn.init.normal_(self.fc2.bias, std=1e-6)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act_fn(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

undefined

gelu实现:

x * 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))

结合代码看Vision Transformer【ViT】

全文差不多就这些了。Vit用尽量接近Transformer的方式来做了视觉任务最基本的分类任务,并且也取得了十分SOTA的效果。十分新颖!

Original: https://blog.csdn.net/m0_62789066/article/details/120773011
Author: 牛先生
Title: 结合代码看Vision Transformer【ViT】

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/532183/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球