Deformable DETR 实战(训练及预测)

开源地址:
https://github.com/fundamentalvision/deformable-detr

超级小白,摸索了几天,感谢批评指正!!!

一、数据集准备

1.下载数据集:

train_2017:

http://images.cocodataset.org/zips/train2017.zip

val_2017:

http://images.cocodataset.org/zips/val2017.zip

2.下载标注文件(instances_train2017.json instances_val2017.json)

http://images.cocodataset.org/annotations/annotations_trainval2017.zip

3.数据集文件夹

Deformable DETR 实战(训练及预测)

二、环境配置(命令)

  1. 创建python环境:

conda create -n deformable_detr python=3.7 pip

  1. 激活环境:

conda activate deformable_detr

PyTorch>=1.5.1, torchvision>=0.6.1,自行配置,不赘述

Deformable DETR 实战(训练及预测)
  1. 安装必要的包:

pip install -r requirements.txt

  1. 编译cuda操作:

cd ./models/ops

sh ./make.sh

编译成功后可 pip list结果如下:

Deformable DETR 实战(训练及预测)

有 MultiScaleDeformableAttention 包

  1. 测试 pyt hon test.py (可省略):

运行test.py的时间太长,我直接Kill了

  1. 运行 python main.py

也可以使用官方给的命令:

GPUS_PER_NODE=8 ./tools/run_dist_launch.sh 8 ./configs/r50_deformable_detr.sh

进行修改,如两张卡进行训练:

GPUS_PER_NODE=2 ./tools/run_dist_launch.sh 2 ./configs/r50_deformable_detr.sh

(我的环境是Linux,此处会出现chmod文件权限问题,百度即可自行解决,用到了chmod 777)

  1. 训练过程:训练Epoch:[0] 结束后会进行Test,然后接着Epoch:[1]训练

Deformable DETR 实战(训练及预测)

三、预测

由于训练时间太长,我直接Kill了,使用官方给的权重进行预测

  1. 下载权重文件:r50_deformable_detr-checkpoint.pth

如图点击model下载(需要梯子)

https://drive.google.com/file/d/1nDWZWHuRwtwGden77NLM9JoWe-YisJnA/view

Deformable DETR 实战(训练及预测)
  1. 待预测图片及其位置:

Deformable DETR 实战(训练及预测)

(我自己从COCO数据集随机复制的几张图片)

  1. 运行如下代码 predict.py(代码非原创,参考网上修改):
import cv2
from PIL import Image
import numpy as np
import os
import time

import torch
from torch import nn
import torchvision.transforms as T
from main import get_args_parser as get_main_args_parser
from models import build_model

torch.set_grad_enabled(False)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("[INFO] 当前使用{}做推断".format(device))

图像数据处理
transform = T.Compose([
    T.Resize(800),
    T.ToTensor(),
    T.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
])

plot box by opencv
def plot_result(pil_img, prob, boxes, save_name=None, imshow=False, imwrite=False):
    opencvImage = cv2.cvtColor(np.array(pil_img), cv2.COLOR_RGB2BGR)
    LABEL =['N/A', 'person', 'bicycle', 'car', 'motorcycle', 'airplane', 'bus',
            'train', 'truck', 'boat', 'traffic light', 'fire hydrant', 'N/A',
            'stop sign', 'parking meter', 'bench', 'bird', 'cat', 'dog', 'horse',
            'sheep', 'cow', 'elephant', 'bear', 'zebra', 'giraffe', 'N/A', 'backpack',
            'umbrella', 'N/A', 'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis',
            'snowboard', 'sports ball', 'kite', 'baseball bat', 'baseball glove',
            'skateboard', 'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
            'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple', 'sandwich',
            'orange', 'broccoli', 'carrot', 'hot dog', 'pizza', 'donut', 'cake',
            'chair', 'couch', 'potted plant', 'bed', 'N/A', 'dining table', 'N/A',
            'N/A', 'toilet', 'N/A', 'tv', 'laptop', 'mouse', 'remote', 'keyboard',
            'cell phone', 'microwave', 'oven', 'toaster', 'sink', 'refrigerator', 'N/A',
            'book', 'clock', 'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush']
    for p, (xmin, ymin, xmax, ymax) in zip(prob, boxes):
        cl = p.argmax()
        label_text = '{}: {}%'.format(LABEL[cl], round(p[cl] * 100, 2))

        cv2.rectangle(opencvImage, (int(xmin), int(ymin)), (int(xmax), int(ymax)), (255, 255, 0), 2)
        cv2.putText(opencvImage, label_text, (int(xmin) + 10, int(ymin) + 30), cv2.FONT_HERSHEY_SIMPLEX, 1,
                    (255, 255, 0), 2)

    if imshow:
        cv2.imshow('detect', opencvImage)
        cv2.waitKey(0)

    if imwrite:
        if not os.path.exists("./result/pred"):
            os.makedirs('./result/pred')
        cv2.imwrite('./result/pred/{}'.format(save_name), opencvImage)

将xywh转xyxy
def box_cxcywh_to_xyxy(x):
    x_c, y_c, w, h = x.unbind(1)
    b = [(x_c - 0.5 * w), (y_c - 0.5 * h),
         (x_c + 0.5 * w), (y_c + 0.5 * h)]
    return torch.stack(b, dim=1)

def rescale_bboxes(out_bbox, size):
    img_w, img_h = size
    b = box_cxcywh_to_xyxy(out_bbox)
    b = b.cpu().numpy()
    b = b * np.array([img_w, img_h, img_w, img_h], dtype=np.float32)
    return b

def load_model(model_path , args):
    model, _, _ = build_model(args)
    model.cuda()
    model.eval()
    state_dict = torch.load(model_path) #  prob_threshold

    probas = probas.cpu().detach().numpy()
    keep = keep.cpu().detach().numpy()

    # convert boxes from [0; 1] to image scales
    bboxes_scaled = rescale_bboxes(outputs['pred_boxes'][0, keep], im.size)
    end = time.time()
    return probas[keep], bboxes_scaled, end - start

if __name__ == "__main__":

    main_args = get_main_args_parser().parse_args()
    # 加载模型
    dfdetr = load_model('DDETR/r50_deformable_detr-checkpoint.pth',main_args) #
  1. 预测结果及预览:

Deformable DETR 实战(训练及预测)

使用Deformable DETR进行预测:

Deformable DETR 实战(训练及预测)

Deformable DETR 实战(训练及预测)

参考:

https://www.jianshu.com/p/b364534fd0a7

Windows下运行Deformable-DETR_harold_du的博客-CSDN博客_deformable detr

Deformable DETR环境配置和应用_Alaso_soso的博客-CSDN博客

DETR导出onnx模型,并进行推理(cpu环境)_athrunsunny的博客-CSDN博客

Original: https://blog.csdn.net/dystsp/article/details/125949720
Author: dystsp
Title: Deformable DETR 实战(训练及预测)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/720163/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球