Human Keypoint Detection (Keypoints Detection)

1. Overview

PyTorch's torchvision library includes a keypoint detection network, keypointrcnn_resnet50_fpn(), which detects 17 human body keypoints.
The 17 keypoints, in the order the model outputs them, are:
nose, left_eye, right_eye, left_ear, right_ear,
left_shoulder, right_shoulder, left_elbow, right_elbow,
left_wrist, right_wrist, left_hip, right_hip,
left_knee, right_knee, left_ankle, right_ankle.
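In the result figure each keypoint is labeled with a number from 1 to 17. A minimal sketch of how those labels map back to the names above, assuming (as in the drawing code's `str(ii+1)`) that a label is simply the 0-based output index plus one; `label_to_name` is a hypothetical helper:

```python
# The 17 COCO person keypoints, in the order keypointrcnn_resnet50_fpn outputs them
COCO_PERSON_KEYPOINT_NAMES = [
    "nose", "left_eye", "right_eye", "left_ear", "right_ear",
    "left_shoulder", "right_shoulder", "left_elbow", "right_elbow",
    "left_wrist", "right_wrist", "left_hip", "right_hip",
    "left_knee", "right_knee", "left_ankle", "right_ankle",
]


def label_to_name(label: int) -> str:
    """Map a 1-based figure label (1..17) to its keypoint name (hypothetical helper)."""
    n = len(COCO_PERSON_KEYPOINT_NAMES)
    if not 1 <= label <= n:
        raise ValueError(f"label must be in 1..{n}, got {label}")
    return COCO_PERSON_KEYPOINT_NAMES[label - 1]


# Print a legend for the numbered red dots
for label in range(1, 18):
    print(label, label_to_name(label))
```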

We tested the model on an image from the COCO dataset.
Original image:

Image after keypoint detection:
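The detected object is outlined with a blue box labeled with its class and confidence score (person, 1.0). All 17 keypoints are visible and are marked as red dots, each annotated with its index from 1 to 17.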
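The figure is produced in two steps: keep only detections whose score exceeds a threshold, then mark each kept person's keypoints with a 1-based label. A minimal, framework-free sketch of that selection logic, assuming one element of the model's output has already been converted to plain Python lists under the same keys torchvision uses ('boxes', 'labels', 'scores', 'keypoints'); `select_people` and the sample numbers are made up for illustration:

```python
from typing import Dict, List


def select_people(pred: Dict[str, list], threshold: float = 0.5) -> List[dict]:
    """Keep detections scoring above `threshold` and collect their visible keypoints.

    `pred` mirrors one element of the model output, with tensors already
    converted to Python lists (an assumed preprocessing step).
    """
    people = []
    # enumerate yields the true index even when two detections share a score
    for i, score in enumerate(pred["scores"]):
        if score <= threshold:
            continue
        visible = [
            (k + 1, x, y)  # 1-based label, as drawn next to each red dot
            for k, (x, y, v) in enumerate(pred["keypoints"][i])
            if v > 0
        ]
        people.append({
            "box": pred["boxes"][i],
            "label": pred["labels"][i],
            "score": round(score, 2),
            "keypoints": visible,
        })
    return people


# Made-up example: one confident person and one low-score detection
fake_pred = {
    "boxes": [[10.0, 20.0, 110.0, 320.0], [200.0, 40.0, 260.0, 180.0]],
    "labels": [1, 1],
    "scores": [0.998, 0.31],
    "keypoints": [
        [(50.0 + k, 40.0 + 10 * k, 1.0) for k in range(17)],
        [(220.0, 60.0 + k, 1.0) for k in range(17)],
    ],
}
people = select_people(fake_pred)
print(len(people), people[0]["score"], len(people[0]["keypoints"]))
```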

2. Code for single-image detection:


import numpy as np
import torchvision
import torch
import torchvision.transforms as transforms
from PIL import Image, ImageDraw, ImageFont
import matplotlib.pyplot as plt

"""
    fire hydrant 消防栓,stop sign 停车标志, parking meter 停车收费器, bench 长椅。
    zebra 斑马, giraffe 长颈鹿, handbag 手提包, suitcase 手提箱, frisbee (游戏用)飞盘(flying disc)。
    skis 滑雪板(ski的复数),snowboard 滑雪板(ski是单板滑雪,snowboarding 是双板滑雪。)
    kite 风筝, baseball bat 棒球棍, baseball glove 棒球手套, skateboard 滑板, surfboard 冲浪板, tennis racket 网球拍。
    broccoli 西蓝花,donut甜甜圈,炸面圈(doughnut,空心的油炸面包), cake 蛋糕、饼, couch 长沙发(靠chi)。
    potted plant 盆栽植物。 dining table 餐桌。 laptop 笔记本电脑,remote 遥控器(=remote control),
    cell phone 移动电话(=mobile phone)(cellular 细胞的、蜂窝状的), oven 烤炉、烤箱。 toaster 烤面包器(toast 烤面包片)
    sink 洗碗池, refrigerator 冰箱。(=fridge), scissor剪刀(see, zer), teddy bear 泰迪熊。 hair drier 吹风机。
    toothbrush 牙刷。
"""
COCO_INSTANCE_CATEGORY_NAMES = [
    '__BACKGROUND__', 'person', 'bicycle', 'car', 'motorcycle',
    'airplane', 'bus', 'train', 'truck', 'boat', 'traffic light',
    'fire hydrant', 'N/A', 'stop sign', 'parking meter', 'bench',
    'bird', 'cat', 'dog', 'horse', 'sheep', 'cow', 'elephant',
    'bear', 'zebra', 'giraffe', 'N/A', 'backpack', 'umbrella', 'N/A',
    'N/A', 'handbag', 'tie', 'suitcase', 'frisbee', 'skis', 'snowboard',
    'sports ball', 'kite', 'baseball bat', 'baseball glove', 'skateboard',
    'surfboard', 'tennis racket', 'bottle', 'N/A', 'wine glass',
    'cup', 'fork', 'knife', 'spoon', 'bowl', 'banana', 'apple',
    'sandwich', 'orange', 'broccoli', 'carrot', 'hot dog', 'pizza',
    'donut', 'cake', 'chair', 'couch', 'potted plant', 'bed', 'N/A',
    'dining table', 'N/A', 'N/A', 'toilet', 'N/A', 'tv', 'laptop',
    'mouse', 'remote', 'keyboard', 'cell phone', 'microwave', 'oven',
    'toaster', 'sink', 'refrigerator', 'N/A', 'book', 'clock',
    'vase', 'scissors', 'teddy bear', 'hair drier', 'toothbrush'
]

"""
    elbow 胳膊肘,wrist 手腕,hip 臀部
"""
COCO_PERSON_KEYPOINT_NAMES = ['nose', 'left_eye', 'right_eye', 'left_ear',
                              'right_ear', 'left_shoulder', 'right_shoulder', 'left_elbow',
                              'right_elbow', 'left_wrist', 'right_wrist', 'left_hip', 'right_hip',
                              'left_knee', 'right_knee', 'left_ankle', 'right_ankle']

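# torchvision >= 0.13 selects pretrained weights via `weights`;
# older versions use pretrained=True instead
model = torchvision.models.detection.keypointrcnn_resnet50_fpn(weights="DEFAULT")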
model.eval()

def Object_Detect(model, image_path, COCO_INSTANCE_CATEGORY_NAMES, threshold=0.5):
    # Load the image as RGB (some COCO images are grayscale) and convert it
    # to a float tensor of shape [3, H, W] with values in [0, 1]
    image = Image.open(image_path).convert("RGB")
    transform_d = transforms.Compose([transforms.ToTensor()])
    image_t = transform_d(image)

    # Inference only, so gradient tracking is not needed
    with torch.no_grad():
        pred = model([image_t])

    # pred[0] holds 'boxes', 'labels', 'scores', 'keypoints' and 'keypoints_scores'
    pred_class = [COCO_INSTANCE_CATEGORY_NAMES[ii] for ii in pred[0]['labels'].tolist()]
    pred_score = pred[0]['scores'].tolist()
    pred_boxes = pred[0]['boxes'].tolist()  # each box is [x1, y1, x2, y2]

    # Indices of detections above the threshold (enumerate stays correct
    # even when two detections have the same score)
    pred_index = [i for i, s in enumerate(pred_score) if s > threshold]

    def load_font(size):
        # Use FreeMono if installed, otherwise fall back to PIL's built-in font
        try:
            return ImageFont.truetype("/usr/share/fonts/gnu-free/FreeMono.ttf", size)
        except OSError:
            return ImageFont.load_default()

    # Draw each kept box and its "class:score" label in blue
    font1 = load_font(max(1, int(image.size[1] / 20)))
    draw = ImageDraw.Draw(image)
    for index in pred_index:
        box = pred_boxes[index]
        draw.rectangle(box, outline="blue")
        texts = pred_class[index] + ":" + str(round(pred_score[index], 2))
        draw.text((box[0], box[1]), texts, fill="blue", font=font1)

    # Keypoints of the kept detections: each is 17 rows of [x, y, visibility]
    all_keypoints = pred[0]["keypoints"].tolist()
    pred_keypoint = [all_keypoints[i] for i in pred_index]

    font1 = load_font(max(1, int(image.size[1] / 50)))
    r = max(1, int(image.size[1] / 150))  # dot radius in pixels

    image3 = image.copy()
    draw = ImageDraw.Draw(image3)
    for keypoints in pred_keypoint:
        for ii, (x, y, visi) in enumerate(keypoints):
            if visi > 0:
                # Red dot plus the 1-based keypoint index next to it
                draw.ellipse(xy=(x - r, y - r, x + r, y + r), fill=(255, 0, 0))
                draw.text((x + r, y - r), str(ii + 1), fill="red", font=font1)

    return image3

if __name__ == '__main__':
    image_path = "/mnt/COCO2017/val2017/000000000785.jpg"
    image = Object_Detect(model, image_path, COCO_INSTANCE_CATEGORY_NAMES)
    plt.imshow(image)
    plt.axis("off")

    plt.savefig('./skiing woman.png', bbox_inches='tight', pad_inches=0.0)
    plt.show()
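A caveat on the `visi > 0` check in the code above: in torchvision's Keypoint R-CNN implementation, the third column of the predicted `keypoints` appears to be set to 1 for every keypoint, so the check keeps all 17 points of every kept person, which is consistent with all 17 being drawn in the figure. The model also returns a per-keypoint `keypoints_scores` tensor of shape [N, 17]. A minimal sketch of filtering on that score instead, assuming the values have been converted to plain lists; `confident_keypoints` is a hypothetical helper, and the cut-off of 2.0 is an illustrative assumption rather than a torchvision default, since these scores look like raw heatmap values rather than probabilities:

```python
from typing import List, Tuple


def confident_keypoints(keypoints: List[Tuple[float, float, float]],
                        kp_scores: List[float],
                        min_score: float = 2.0) -> List[Tuple[int, float, float]]:
    """Return (1-based label, x, y) for keypoints whose score exceeds min_score.

    min_score = 2.0 is a hypothetical cut-off that would need tuning.
    """
    if len(keypoints) != len(kp_scores):
        raise ValueError("keypoints and kp_scores must have the same length")
    return [
        (k + 1, x, y)
        for k, ((x, y, _v), s) in enumerate(zip(keypoints, kp_scores))
        if s > min_score
    ]


# Made-up scores: the occluded left ankle (label 16) gets a low score
kps = [(float(10 * k), float(5 * k), 1.0) for k in range(17)]
scores = [8.0] * 17
scores[15] = -1.5
kept = confident_keypoints(kps, scores)
print(len(kept), [label for label, _, _ in kept][-2:])
```

Inside Object_Detect, the matching scores could be gathered the same way as the keypoints, for example `all_scores = pred[0]["keypoints_scores"].tolist()` followed by `[all_scores[i] for i in pred_index]`, and passed in alongside each person's keypoints.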
3. Code for detecting multiple images:
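Detecting keypoints in many images amounts to running the single-image function once per file. A minimal sketch of the file-handling part, assuming every input image sits directly in one folder and results are saved as PNG files; `list_image_jobs` and the `_keypoints.png` suffix are hypothetical choices, not the original author's code:

```python
import os
import tempfile
from typing import List, Tuple

IMAGE_EXTS = (".jpg", ".jpeg", ".png")


def list_image_jobs(input_dir: str, output_dir: str) -> List[Tuple[str, str]]:
    """Pair each image in input_dir with an output path in output_dir (hypothetical helper)."""
    os.makedirs(output_dir, exist_ok=True)
    jobs = []
    for name in sorted(os.listdir(input_dir)):
        stem, ext = os.path.splitext(name)
        # Case-insensitive extension check, so "a.JPG" is included
        if ext.lower() in IMAGE_EXTS:
            out_path = os.path.join(output_dir, stem + "_keypoints.png")
            jobs.append((os.path.join(input_dir, name), out_path))
    return jobs


# Self-contained demo on a throwaway folder with two images and one non-image file
with tempfile.TemporaryDirectory() as tmp:
    src_dir = os.path.join(tmp, "val2017")
    os.makedirs(src_dir)
    for fname in ("b.jpg", "a.JPG", "notes.txt"):
        open(os.path.join(src_dir, fname), "w").close()
    demo_jobs = list_image_jobs(src_dir, os.path.join(tmp, "results"))
    for src, dst in demo_jobs:
        print(os.path.basename(src), "->", os.path.basename(dst))
```

Each pair can then be processed with the single-image function, which returns a PIL image, e.g. `Object_Detect(model, src, COCO_INSTANCE_CATEGORY_NAMES).save(dst)` inside a `for src, dst in list_image_jobs(...)` loop.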

4. Result images (from the COCO2017 dataset)

[Figure: 12 keypoint detection results on COCO2017 images]

5. References

Sun Yulin et al., 《PyTorch深度学习入门与实战》 (PyTorch Deep Learning: Introduction and Practice).

Original: https://blog.csdn.net/csdnliwenqi/article/details/121694973
Author: 爱学习的大白菜
Title: 人体关键点检测(Keypoints Detection)
