PyTorch使用ResNet18提取图像特征并进行相似度计算

模型部分我参考的是https://blog.csdn.net/sunqiande88/article/details/80100891这篇文章,同样是在Cifar-10上训练。

一、不使用PyTorch中的预训练模型

将训练的模型保存下来接后面使用,保存方式:

torch.save(net.state_dict(), 'path')

加载方式

model = ResNet18()
model.load_state_dict(torch.load('path'))
model.eval()

由于不是使用预训练模型所以输出特征层还是很好输出的,只需要将输出从fc层改为前一层即可:

    def forward(self, x):
        out = self.conv1(x)
        out = self.layer1(out)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        out = F.avg_pool2d(out, 4)
        out = out.view(out.size(0), -1)
        features = out.detach()
        out = self.fc(out)
        return features

保存ONNX模型:

    model = ResNet18()
    model.load_state_dict(torch.load(from_path))
    model.eval()
    dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
    torch.onnx.export(model, dummy_input, to_path, export_params=True, opset_version=10, do_constant_folding=True
                      , input_names=["input"], output_names=['output'])

二、使用PyTorch中的预训练模型

模型的保存与加载与上面相同,但是由于使用预训练模型无法修改输出,所以需要使用其他方式修改模型输出。

将训练的模型保存下来接后面使用,保存方式:

torch.save(net.state_dict(), 'path')

加载方式:

model = models.resnet18(pretrained=False)
model.fc = nn.Linear(512, 10)
model.load_state_dict(torch.load(from_path))
model.eval()

class Identity(nn.Module):
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

model = models.resnet18(pretrained=False)
model.fc = Identity()
x = torch.randn(1, 3, 32, 32)
output = model(x)

保存ONNX模型:


    model = models.resnet18(pretrained=False)
    model.fc = nn.Linear(512, 10)
    model.load_state_dict(torch.load('path'))

    model.fc = Identity()

    model.eval()

    dummy_input = torch.randn(1, 3, 32, 32, requires_grad=True)
    torch.onnx.export(model, dummy_input, to_path, export_params=True, opset_version=10, do_constant_folding=True
                      , input_names=['input'], output_names=['output'])

三、测试

测试.pth模型部分代码


def load_image(img_path, transform=None):
    imgs = []
    for name in sorted(os.listdir(img_path)):
        img = Image.open(img_path + name).convert('RGB')
        if transform is not None:
            img = transform(img)
        else:
            img = transforms.ToTensor()(img)
        imgs.append(img)
    return imgs

def predict(imgs):
    model = ResNet18()
    model.load_state_dict(torch.load('path'))
    model.to(device)
    model.eval()
    imgs = torch.stack(imgs, 0).to(device)
    with torch.no_grad():
        predicts = model(imgs)
        print(predicted)
    return predicted

测试ONNX部分分为以下几种方式,
1、使用Python中的onnxruntime库

session = onnxruntime.InferenceSession('ONNX Path')
    inputs = {session.get_inputs()[0].name: img.numpy()}
    outs = session.run(None, inputs)
    return outs

2、使用Python中OpenCV(Version:4.2.0)中的接口


def load_image_cv(img_path, mean=None, std=None):
    img = cv2.imread(img_path, cv2.IMREAD_ANYCOLOR)
    if img.shape[2] > 1:
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    img_ = (img / 255. - mean) / std
    return img_.astype(np.float32)

def run_onnx_cv(from_path, img):
    net = cv2.dnn.readNetFromONNX(from_path)
    input = cv2.dnn.blobFromImage(img)
    net.setInput(input)
    outs = net.forward()
    return outs

3、使用C++版本的OpenCV(Version:4.2.0)


cv::Mat load_image_cv(const std::string& fileName, cv::Scalar mean, cv::Scalar std)
{
    cv::Mat img = cv::imread(fileName, cv::IMREAD_ANYCOLOR);
    if (img.empty()) {
        return cv::Mat();
    }
    img.convertTo(img, CV_32F, 1 / 255.);
    cv::subtract(img, mean, img);
    cv::divide(img, std, img);
    return img;
}

cv::Mat img = load_image_cv("path", cv::Scalar(0.4914, 0.4822, 0.4465), cv::Scalar(0.2023, 0.1994, 0.2010));
cv::dnn::Net net = cv::dnn::readNetFromONNX("ONNX Path");
net.setPreferableBackend(cv::dnn::Backend::DNN_BACKEND_OPENCV);
net.setPreferableTarget(cv::dnn::Target::DNN_TARGET_CPU);
cv::Mat input = cv::dnn::blobFromImage(img);
net.setInput(input);
cv::Mat predicted = net.forward();

效果图片

PyTorch使用ResNet18提取图像特征并进行相似度计算
PyTorch使用ResNet18提取图像特征并进行相似度计算
PyTorch使用ResNet18提取图像特征并进行相似度计算

Original: https://blog.csdn.net/qq_37299618/article/details/121486682
Author: Zzzzzzzzzzzzzz—
Title: PyTorch使用ResNet18提取图像特征并进行相似度计算

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/701394/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球