pytorch实现Resnet系列的分类任务

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

文章目录

前言

最近需要一系列传统分类方法做对比,所以就顺便把自己复现resnet系列分类实验的过程记录一下,还是老传统:文末有源码。

一、数据集格式

pytorch实现Resnet系列的分类任务
其中training文件夹下的basal、her2都是要分类的类别,basal里面就是一张张图片了。
在这里给大家一个公开的10种猴子的分类数据集,已经分好类了,大家可以直接下载使用。数据集地址:https://www.kaggle.com/slothkong/10-monkey-species

; 二、训练部分

train.py文件下可以自行修改权值文件保存的地址和batch size:

if __name__ == '__main__':

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using {} device.".format(device))

    if not os.path.exists('./logs'):
        os.makedirs('./logs')

    BATCH_SIZE = 16

修改数据集的地址

    train_dataset = datasets.ImageFolder("./datasets/training", transform=data_transform["train"])
    train_loader = torch.utils.data.DataLoader(dataset=train_dataset, batch_size=BATCH_SIZE, shuffle=True,
                                               num_workers=2)
    len_train = len(train_dataset)
    val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform["val"])
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                             num_workers=2)
    len_val = len(val_dataset)

可以自行选择网络类型,resnet50或者resnet34或者其他都可以,损失函数就只有一个CEloss,优化器使用的是adam,epoch根据自己的数据多少和bacth size自行修改。

    net = resnet50()
    loss_function = nn.CrossEntropyLoss()
    optimizer = optim.Adam(net.parameters(), lr=0.0001)
    epoch = 100

三、评价

改完上面的参数后就可以训练了,训练结束之后可以对于分类结果进行评价,需要续改evaluate.py的相关内容,首先修改训练生产的权值文件的路径。

if __name__ == '__main__':
    model = torch.load("./logs/best.pth")

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    class_correct = [0.] * 10
    class_total = [0.] * 10
    y_test, y_pred = [], []
    X_test = []

下面修改验证集或者测试集的指向路径。

    data_transform = transforms.Compose([transforms.Resize(256),
                                       transforms.CenterCrop(224),
                                       transforms.ToTensor(),
                                       transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])])
    val_dataset = datasets.ImageFolder("./datasets/validation", transform=data_transform)
    val_loader = torch.utils.data.DataLoader(dataset=val_dataset, batch_size=BATCH_SIZE, shuffle=False,
                                                 num_workers=2)

    classes = val_dataset.classes

最后运行evaluate.py就可以。

结果

pytorch实现Resnet系列的分类任务

; 总结

以上就是今天要讲的内容,本文仅仅简单介绍了resnet分类网络的使用。
源码分享在网盘里面:网盘
提取码:57hs

Original: https://blog.csdn.net/wangshuhuan1/article/details/126034312
Author: 七白学长
Title: pytorch实现Resnet系列的分类任务

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/666904/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球