DABDetr论文解读+核心源码解读

文章目录

前言

本文主要介绍下发表在ICLR2022的DAB-Detr论文的基本思想以及代码的实现。
1、代码地址
2、论文地址
另外,感兴趣可以看下本人写的关于detr其他文章:
1、nn.Transformer使用
2、mmdet解读Detr
3、DeformableDetr
4、ConditionalDetr

1、论文解读

整体模型结构图和Detr很相似:

DABDetr论文解读+核心源码解读

; 1.1.空间注意力热图可视化

DABDetr论文解读+核心源码解读
本文认为原始的Detr系列论文中:可学习的object queries仅仅是给model预测bbox提供了参考点(中心点)信息,却没有提供box的宽和高信息。于是,本文考虑引入可学习的锚框来使model能够适配不同尺寸的物体。上图是可视化的三个模型的空间注意力热图(pk*pq),若读者对热图如何产生的,可参考Detr热图可视化。从图中可以看出,引入可学习锚框后,DAB-Detr能够很好覆盖不同尺寸的物体。本文所得出的一个结论:query中content query和key计算相似度完成特征提取,而pos query则用于限制提取区域的范围及大小。

1.2.模型草稿

DABDetr论文解读+核心源码解读
图中紫色是改动的区域,大体流程是:DAB-Detr直接预设了N个可学习的anchor,这点类似于SparseRCNN。然后经过宽高调制交叉注意力模块,预测出每个锚框四个元素偏移量来更新anchor。

; 1.3.详细模型

DABDetr论文解读+核心源码解读
上图是我做的一张PPT,展示的是一层DecoderLayer。简单说下流程:首先设定了N个可学习的4维的anchors,然后经过PE和MLP将其映射成Pq。
1) 在self-attn部分:常规的自注意力,使用的是Cq和Pq做加法;
2) 在cross-attn部分:参考点(x,y)部分完全和ConditionalDetr一样,Cq和Pq使用拼接来生成Qq;唯一区别是”宽和高调制交叉注意力模块”:在计算Pk和Pq的权重相似度时引入了一个(1/w,1/h)的一个尺度变换操作。

1.4.设置温度系数

Detr中给特征图每个位置生成位置Pk完全使用的是Transformer中温度系数,而Transformer针对的是单词的嵌入向量设计的,而特征图中像素值大多分布在[0,1]之间,因此,贸然采用10000不合适,所以,本文采用了20。算是个trick吧,能涨一个点左右。

DABDetr论文解读+核心源码解读

; 1.5.实验

在四个backbone比较了性能,总体来看,达到最优。

DABDetr论文解读+核心源码解读

2、代码讲解

感觉这套代码质量非常高,因为作者基本上开源了每个实验的代码,值得反复看(包括deformable attn的算子、分布式训练等等)。

2.1.Decoder

首先看下整体Decoder的forward函数部分:

def forward(self, tgt, memory,
            tgt_mask: Optional[Tensor] = None,
            memory_mask: Optional[Tensor] = None,
            tgt_key_padding_mask: Optional[Tensor] = None,
            memory_key_padding_mask: Optional[Tensor] = None,
            pos: Optional[Tensor] = None,
            refpoints_unsigmoid: Optional[Tensor] = None,
            ):

    output = tgt

    intermediate = []
    reference_points = refpoints_unsigmoid.sigmoid()
    ref_points = [reference_points]

    for layer_id, layer in enumerate(self.layers):

        obj_center = reference_points[..., :self.query_dim]

        query_sine_embed = gen_sineembed_for_position(obj_center)
        query_pos = self.ref_point_head(query_sine_embed)

        if self.query_scale_type != 'fix_elewise':
            if layer_id == 0:
                pos_transformation = 1

            else:
                pos_transformation = self.query_scale(output)
        else:
            pos_transformation = self.query_scale.weight[layer_id]

        query_sine_embed = query_sine_embed[...,:self.d_model] * pos_transformation

        if self.modulate_hw_attn:

            refHW_cond = self.ref_anchor_head(output).sigmoid()

            query_sine_embed[..., self.d_model // 2:] *= (refHW_cond[..., 0] / obj_center[..., 2]).unsqueeze(-1)
            query_sine_embed[..., :self.d_model // 2] *= (refHW_cond[..., 1] / obj_center[..., 3]).unsqueeze(-1)

        output = layer(output, memory, tgt_mask=tgt_mask,
                       memory_mask=memory_mask,
                       tgt_key_padding_mask=tgt_key_padding_mask,
                       memory_key_padding_mask=memory_key_padding_mask,
                       pos=pos, query_pos=query_pos, query_sine_embed=query_sine_embed,
                       is_first=(layer_id == 0))

        if self.bbox_embed is not None:
            if self.bbox_embed_diff_each_layer:

                tmp = self.bbox_embed[layer_id](output)
            else:
                tmp = self.bbox_embed(output)

            tmp[..., :self.query_dim] += inverse_sigmoid(reference_points)

            new_reference_points = tmp[..., :self.query_dim].sigmoid()
            if layer_id != self.num_layers - 1:

                ref_points.append(new_reference_points)

            reference_points = new_reference_points.detach()

        if self.return_intermediate:
            intermediate.append(self.norm(output))

    if self.norm is not None:
        output = self.norm(output)
        if self.return_intermediate:
            intermediate.pop()
            intermediate.append(output)

    if self.return_intermediate:
        if self.bbox_embed is not None:
            return [
                torch.stack(intermediate).transpose(1, 2),
                torch.stack(ref_points).transpose(1, 2),
            ]
        else:
            return [
                torch.stack(intermediate).transpose(1, 2),
                reference_points.unsqueeze(0).transpose(1, 2)
            ]

    return output.unsqueeze(0)

2.2.DecoderLayer

内部就是调用了self-attn和cross-attn,pq,pk,cq,ck按照论文中相加或者拼接即可。

def forward(self, tgt, memory,
                 tgt_mask: Optional[Tensor] = None,
                 memory_mask: Optional[Tensor] = None,
                 tgt_key_padding_mask: Optional[Tensor] = None,
                 memory_key_padding_mask: Optional[Tensor] = None,
                 pos: Optional[Tensor] = None,
                 query_pos: Optional[Tensor] = None,
                 query_sine_embed = None,
                 is_first = False):

    if not self.rm_self_attn_decoder:

        q_content = self.sa_qcontent_proj(tgt)
        q_pos = self.sa_qpos_proj(query_pos)
        k_content = self.sa_kcontent_proj(tgt)
        k_pos = self.sa_kpos_proj(query_pos)
        v = self.sa_v_proj(tgt)

        num_queries, bs, n_model = q_content.shape
        hw, _, _ = k_content.shape

        q = q_content + q_pos
        k = k_content + k_pos

        tgt2 = self.self_attn(q, k, value=v, attn_mask=tgt_mask,
                            key_padding_mask=tgt_key_padding_mask)[0]

        tgt = tgt + self.dropout1(tgt2)
        tgt = self.norm1(tgt)

    q_content = self.ca_qcontent_proj(tgt)
    k_content = self.ca_kcontent_proj(memory)
    v = self.ca_v_proj(memory)

    num_queries, bs, n_model = q_content.shape
    hw, _, _ = k_content.shape

    k_pos = self.ca_kpos_proj(pos)

    if is_first or self.keep_query_pos:
        q_pos = self.ca_qpos_proj(query_pos)
        q = q_content + q_pos
        k = k_content + k_pos
    else:
        q = q_content
        k = k_content

    q = q.view(num_queries, bs, self.nhead, n_model//self.nhead)
    query_sine_embed = self.ca_qpos_sine_proj(query_sine_embed)
    query_sine_embed = query_sine_embed.view(num_queries, bs, self.nhead, n_model//self.nhead)

    q = torch.cat([q, query_sine_embed], dim=3).view(num_queries, bs, n_model * 2)
    k = k.view(hw, bs, self.nhead, n_model//self.nhead)
    k_pos = k_pos.view(hw, bs, self.nhead, n_model//self.nhead)
    k = torch.cat([k, k_pos], dim=3).view(hw, bs, n_model * 2)

    tgt2 = self.cross_attn(query=q,
                               key=k,
                               value=v, attn_mask=memory_mask,
                               key_padding_mask=memory_key_padding_mask)[0]

    tgt = tgt + self.dropout2(tgt2)
    tgt = self.norm2(tgt)
    tgt2 = self.linear2(self.dropout(self.activation(self.linear1(tgt))))
    tgt = tgt + self.dropout3(tgt2)
    tgt = self.norm3(tgt)
    return tgt

总结

后面会介绍DN-DETR,敬请期待。若有问题欢迎+vx:wulele2541612007,拉你进群探讨交流。

Original: https://blog.csdn.net/wulele2/article/details/124251533
Author: 武乐乐~
Title: DABDetr论文解读+核心源码解读

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/706748/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球