二、现有网络模型(torchvision.models.vgg16)的修改与使用

二、现有网络模型(torchvision.models.vgg16)的修改使用、保存加载

1.torchvision.models.vgg16

官方文档 : https://pytorch.org/vision/stable/models.html#id2

二、现有网络模型(torchvision.models.vgg16)的修改与使用

pretrained (bool) – If True, returns a model pre-trained on ImageNet

ImageNet数据集太大不好下载

; 2.pretrained设置不同时网络模型的差别

二、现有网络模型(torchvision.models.vgg16)的修改与使用

3.如何修改现有网络结构

修改vgg16_true网路结构,添加linear层

import torchvision
from torch.nn import Linear

vgg16_false = torchvision.models.vgg16(pretrained=False)
vgg16_true = torchvision.models.vgg16(pretrained=True)

print(vgg16_true)

vgg16_true.classifier.add_module("add_linear",Linear(1000,10))
print(vgg16_true)

二、现有网络模型(torchvision.models.vgg16)的修改与使用
修改vgg16_false网路结构,更改分类器第6层为指定linear层
print(vgg16_false)
vgg16_false.classifier[6]=Linear(4096,10)
print(vgg16_false)

二、现有网络模型(torchvision.models.vgg16)的修改与使用

4.模型的保存、加载

vgg16_method1 结构+参数

import torch
import torchvision

vgg16 = torchvision.models.vgg16(pretrained=False)

torch.save(vgg16, "vgg16_method1.pth")

model = torch.load("vgg16_method1.pth")
print(model)

方式1的陷阱
自定义网络结构如下:

import torch
import torchvision
from torch import nn

class Qu(nn.Module):
    def __init__(self):
        super(Qu, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return

qu = Qu()
torch.save(qu, "qu_method1.pth")

在另一个文件加载该模型,会报错
AttributeError: Can’t get attribute ‘Qu’ on

正确的调用格式需要复制原模型的类定义

class Qu(nn.Module):
    def __init__(self):
        super(Qu, self).__init__()
        self.conv1 = nn.Conv2d(3, 64, kernel_size=3)

    def forward(self, x):
        x = self.conv1(x)
        return

model = torch.load("qu_method1.pth")
print(model)

或者用import

from model_save import *

model = torch.load("qu_method1.pth")
print(model)

vgg16_method2 参数(官方推荐)

import torch
import torchvision
from torch import nn

vgg16 = torchvision.models.vgg16(pretrained=False)

torch.save(vgg16.state_dict(), "vgg16_method2.pth")

vgg16 = torchvision.models.vgg16(pretrained=False)
vgg16.load_state_dict(torch.load("vgg16_method2.pth"))
print(vgg16)

Original: https://blog.csdn.net/weixin_44987829/article/details/122955436
Author: weixin_44987829
Title: 二、现有网络模型(torchvision.models.vgg16)的修改与使用

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/628932/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球