Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

文章目录

一、概述

🔥本项目使用 Pytroch,并基于 ResNet50模型,实现了对天气图片的识别,过程详细,十分适合基础阶段的同学阅读。

项目目录结构

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

核心步骤

  • 数据处理
  • 准备配置文件
  • 构建自定义 DataSetDataloader
  • 构建模型
  • 训练模型
  • 编写预测模块
  • 效果展示

; 二、代码编写

1. 数据处理

本项目数据来源:
https://www.heywhale.com/mw/dataset/60d9bd7c056f570017c305ee/file
http://vcc.szu.edu.cn/research/2017/RSCM.html

由于数据是直接下载,且目录分的很规整,本项目的数据处理部分较为简单,直接手动复制,合并两个数据集即可。

数据概览

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

总数据量约 7万张

; 2. 准备配置文件

配置文件的主要存储一些各个模块通用的一些全局变量,如各种文件的存放位置等等(本人Java程序员出身,一些Python的代码规范不太熟悉,望见谅)。

config.py

import time

import torch

class Common:
    '''
    通用配置
    '''
    basePath = "D:/Data/weather/source/all/"
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    imageSize = (224,224)
    labels = ["cloudy","haze","rainy","shine","snow","sunny","sunrise","thunder"]

class Train:
    '''
    训练相关配置
    '''
    batch_size = 128
    num_workers = 0
    lr = 0.001
    epochs = 40
    logDir = "./log/" + time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime())
    modelDir = "./model/"

3. 自定义DataSet和DataLoader

dada_loader.py


import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from config import Common
from config import Train
import os
from PIL import Image
import torch.utils.data as Data
import numpy

transform = transforms.Compose([
    transforms.Resize(Common.imageSize),
    transforms.ToTensor()
])

def loadDataFromDir():
    '''
    从文件夹中获取数据
    '''
    images = []
    labels = []

    for d in os.listdir(Common.basePath):
        for imagePath in os.listdir(Common.basePath + d):

            image = Image.open(Common.basePath + d + "/" + imagePath).convert('RGB')
            print("加载数据" + str(len(images)) + "条")

            images.append(transform(image))

            categoryIndex = Common.labels.index(d)
            label = [0] * 8
            label[categoryIndex] = 1
            label = torch.tensor(label,dtype=torch.float)

            labels.append(label)

            image.close()

    return images, labels

class WeatherDataSet(Dataset):
    '''
    自定义DataSet
    '''

    def __init__(self):
        '''
        初始化DataSet
        :param transform: 自定义转换器
        '''
        images, labels = loadDataFromDir()
        self.images = images
        self.labels = labels

    def __len__(self):
        '''
        返回数据总长度
        :return:
        '''
        return len(self.images)

    def __getitem__(self, idx):
        image = self.images[idx]
        label = self.labels[idx]
        return image, label

def splitData(dataset):
    '''
    分割数据集
    :param dataset:
    :return:
    '''

    total_length = len(dataset)

    train_length = int(total_length * 0.8)
    validation_length = total_length - train_length

    train_dataset,validation_dataset = Data.random_split(dataset=dataset, lengths=[train_length, validation_length])
    return train_dataset, validation_dataset

train_dataset, validation_dataset = splitData(WeatherDataSet())

trainLoader = DataLoader(train_dataset, batch_size=Train.batch_size, shuffle=True, num_workers=Train.num_workers)

valLoader = DataLoader(validation_dataset, batch_size=Train.batch_size, shuffle=False,
                       num_workers=Train.num_workers)

主要步骤:

  1. 读取图片使用的是Python自带的 PIL

PIL教程:https://blog.csdn.net/weixin_43790276/article/details/108478270

  1. 由于使用的是残差网络,其图片尺寸必须是 3*224*224,故需要使用Pytroch的 transforms工具进行处理

transforms教程:https://blog.csdn.net/qq_38410428/article/details/94719553

  1. 自定义 DataSet(继承DataSet类,并实现重写三个核心方法)
  2. 分割数据
  3. 创建验证集和训练集各自的加载器

4. 构建模型

model.py

import torch
from torch import nn
import torchvision.models as models
from config import Common, Train

net = models.resnet50()
net.load_state_dict(torch.load("./model/resnet50-11ad3fa6.pth"))

class WeatherModel(nn.Module):
    def __init__(self, net):
        super(WeatherModel, self).__init__()

        self.net = net
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(0.1)
        self.fc = nn.Linear(1000, 8)
        self.output = nn.Softmax(dim=1)

    def forward(self, x):
        x = self.net(x)
        x = self.relu(x)
        x = self.dropout(x)
        x = self.fc(x)
        x = self.output(x)
        return x

model = WeatherModel(net)

主要步骤:

  1. 引入 Pytorch官方的残差网络预训练模型

关于新版本的引入方法:https://blog.csdn.net/Sihang_Xie/article/details/125646287

  1. 添加自己的全连接输出层
  2. 创建模型

5. 训练模型

train.py


import time
import torch
from torch import nn
import matplotlib.pyplot as plt
from torch.utils.tensorboard import SummaryWriter
from config import Common, Train
from model import model as weatherModel
from data_loader import trainLoader, valLoader
from torch import optim

model = weatherModel
model.to(Common.device)

criterion = nn.CrossEntropyLoss()

optimizer = optim.Adam(model.parameters(), lr=0.001)

writer = SummaryWriter(log_dir=Train.logDir, flush_secs=500)

def train(epoch):
    '''
    训练函数
    '''

    loader = trainLoader

    model.train()
    print()
    print('========== Train Epoch:{} Start =========='.format(epoch))
    epochLoss = 0
    epochAcc = 0
    correctNum = 0
    for data, label in loader:
        data, label = data.to(Common.device), label.to(Common.device)
        batchAcc = 0
        batchCorrectNum = 0
        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, label)
        loss.backward()
        optimizer.step()
        epochLoss += loss.item() * data.size(0)

        labels = torch.argmax(label, dim=1)
        outputs = torch.argmax(output, dim=1)
        for i in range(0, len(labels)):
            if labels[i] == outputs[i]:
                correctNum += 1
                batchCorrectNum += 1
        batchAcc = batchCorrectNum / data.size(0)
        print("Epoch:{}\t TrainBatchAcc:{}".format(epoch, batchAcc))

    epochLoss = epochLoss / len(trainLoader.dataset)
    epochAcc = correctNum / len(trainLoader.dataset)
    print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))
    writer.add_scalar("train_loss", epochLoss, epoch)
    writer.add_scalar("train_acc", epochAcc, epoch)
    return epochAcc

def val(epoch):
    '''
    验证函数
    :param epoch: 轮次
    :return:
    '''

    loader = valLoader

    valLoss = []
    valAcc = []

    model.eval()
    print()
    print('========== Val Epoch:{} Start =========='.format(epoch))
    epochLoss = 0
    epochAcc = 0
    correctNum = 0
    with torch.no_grad():
        for data, label in loader:
            data, label = data.to(Common.device), label.to(Common.device)
            batchAcc = 0
            batchCorrectNum = 0
            output = model(data)
            loss = criterion(output, label)
            epochLoss += loss.item() * data.size(0)

            labels = torch.argmax(label, dim=1)
            outputs = torch.argmax(output, dim=1)
            for i in range(0, len(labels)):
                if labels[i] == outputs[i]:
                    correctNum += 1
                    batchCorrectNum += 1
            batchAcc = batchCorrectNum / data.size(0)
            print("Epoch:{}\t ValBatchAcc:{}".format(epoch, batchAcc))

        epochLoss = epochLoss / len(valLoader.dataset)
        epochAcc = correctNum / len(valLoader.dataset)
        print("Epoch:{}\t Loss:{} \t Acc:{}".format(epoch, epochLoss, epochAcc))
        writer.add_scalar("val_loss", epochLoss, epoch)
        writer.add_scalar("val_acc", epochAcc, epoch)
    return epochAcc

if __name__ == '__main__':
    maxAcc = 0.75
    for epoch in range(1,Train.epochs + 1):
        trainAcc = train(epoch)
        valAcc = val(epoch)
        if valAcc > maxAcc:
            maxAcc = valAcc

            torch.save(model, Train.modelDir + "weather-" + time.strftime('%Y-%m-%d-%H-%M-%S', time.gmtime()) + ".pth")

    torch.save(model,Train.modelDir+"weather-"+time.strftime('%Y-%m-%d-%H-%M-%S',time.gmtime())+".pth")

主要步骤

  1. 加载模型
  2. 准备损失函数及优化器
  3. 创建 tensorboard的writer

关于 tensorboard的使用:https://blog.csdn.net/weixin_43637851/article/details/116003280

  1. 编写训练函数及验证函数,同时记录损失和正确率

验证函数和训练函数的区别就是是否需要更新参数

  1. 循环训练 epochs次,不断保存正确率最大的模型,以及最后一次的训练模型
  2. 开始训练
  3. 不断调参(我就只训了3次),知道有一个比较满意的效果

训练过程中电脑的状态:

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)
查看训练日志(tensorboard)
Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)
Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)
保存的模型
Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

; 6. 编写预测模块

pridect.py

import torch
import torchvision.transforms as transforms
from PIL import Image
from config import Common
def pridect(imagePath, modelPath):
    '''
    预测函数
    :param imagePath: 图片路径
    :param modelPath: 模型路径
    :return:
    '''

    image = Image.open(imagePath)

    image = image.resize(Common.imageSize)
    image.show()

    model = torch.load(modelPath)
    model = model.to(Common.device)

    transform = transforms.ToTensor()
    x = transform(image)
    x = torch.unsqueeze(x, 0)
    x = x.to(Common.device)

    output = model(x)

    output = torch.argmax(output)
    print("预测结果:",Common.labels[output.item()])

if __name__ == '__main__':
    pridect("D:/Download/76ee4c5e833499949eac41561dcb487d.jpeg","./model/weather-2022-10-14-07-36-57.pth")

三、效果展示

去网上随便找的图片:

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)
Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

; 四、源码地址

https://github.com/mengxianglong123/weather-recognition

欢迎交流学习🥰🥰🥰

Original: https://blog.csdn.net/mengxianglong123/article/details/127330721
Author: 落花雨时
Title: Pytorch深度学习基础 实战天气图片识别(基于ResNet50预训练模型,超详细)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/770180/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • Python pip修改镜像源为豆瓣源的两种方法

    Python pip修改镜像源为豆瓣源常常遇到pip装包时速度过慢或者无法安装(请求超时)等问题,这个时候你就需要考虑一下给pip换源了 一、临时的方法 pip -i https:…

    Python 2023年9月17日
    073
  • webRTC demo

    准备: 信令服务 前端页面用于视频通话 demo github 地址。 前端页面 为了使 demo 尽量简单,功能页面如下,即包含登录、通过对方手机号拨打电话的功能。在实际生成过程…

    Python 2023年10月16日
    048
  • python3.7的下载,以及详细的安装教程

    目录如下 * – + 一、去Python官网下载你想要安装的python版本 + 二、python的安装 + 三、cmd打开页面运行python可能存在的异常 一、去P…

    Python 2023年8月1日
    064
  • ChatGPT/InstructGPT详解

    404. 抱歉,您访问的资源不存在。 可能是网址有误,或者对应的内容被删除,或者处于私有状态。 代码改变世界,联系邮箱 contact@cnblogs.com 弹尽粮绝,会员救园:…

    Python 2023年10月28日
    022
  • 从numpy掩码到pytorch掩码

    一、numpy布尔索引掩码 import numpy as np data = np.arange(1, 11) mask=[True,False,True,False,True,…

    Python 2023年8月25日
    039
  • 蓝桥杯有必要参赛吗?

    昨天和群里的小伙伴在群里聊,有的小伙伴竟然说蓝桥杯一等奖没有含量,我也是醉了! 就像去年看了一个号主写的:研究生遍地都是! 放眼全国14亿人口,别说研究生了,本科生占比有多少? &…

    Python 2023年9月15日
    038
  • 与众不同的异域年夜饭体验,你最中意哪一款?

    年夜饭,中国人一年中最重要的一顿团圆聚餐,不仅丰富多彩,还充满了各种吉祥寓意。如果你选择的是出境旅游过春节,那么一次异域年夜饭体验也可以让你的旅行充满乐趣,收获与众不同的别样回忆。…

    Python 2023年11月6日
    036
  • 房地产特征价格评估的次市场效应模型: 一种概率方法撰写

    文章目录 一、数据预处理部分 * (一)使用到的库 (二)使用到的函数 (三)实现流程 二、POI数据处理部分 * (一)使用到的库 (二)使用到的函数 (三)实现流程 三、BN …

    Python 2023年9月6日
    068
  • Linux–多线程(三)

    概念: 生产者消费者模式就是通过一个容器来解决生产者和消费者的强耦合问题。生产者和消费者彼此之间不直接通讯,而通过一个来进行通讯,所以生产者生产完数据之后不用等待消费者处理,直接扔…

    Python 2023年10月16日
    030
  • 【中秋征文】使用Python创意中秋节画月饼《花好月圆》

    大家好,我是猿童学🐵,又是一年中秋至——花好月圆夜,祝大家中秋节快乐!欢迎收看中秋创造第一期。今年是我在CSDN第一次过中秋节,特意为此去学习了用Python来画月饼,不仅可以学习…

    Python 2023年8月2日
    059
  • 【Flask+Echarts】使用Flask框架可视化的案例

    回答1: 和MySQL来实现数据 ,可以通过 步骤来实现: 1. 来搭建Web应用程序,根据需要设置路由和视图函数。 2. MySQL数据库来存储数据,建立需要的数据表,并通过 的…

    Python 2023年8月9日
    058
  • 【双目视觉】 SGBM算法应用(Python版)

    文章目录 * – 流程图 – 相机标定 – 立体匹配 – 效果 – + 1.原图像 + 2.深度图 + 3.代码链接 流…

    Python 2023年8月22日
    076
  • Conda 创建和删除虚拟环境Conda 创建和删除虚拟环境

    Conda 创建和删除虚拟环境 一.检验当前conda的版本 conda -V 二.conda常用的命令 1.查看已有的虚拟环境 conda env list 2.创建虚拟环境和删…

    Python 2023年9月9日
    072
  • DataFrame数据分析

    注:文中用到的数据文件可以在资源中免费获取。 基本统计 常用统计函数表 非空元素计算 最小值 最大值 最小值的位置 最大值的位置 )10%分位数 中位数 标准差 平均绝对偏差 一次…

    Python 2023年8月18日
    037
  • 机器学习练习题

    单项选择题 1.在NumPy中创建一个元素均为0的数组可以使用( )函数。 [A]A.zeros( ) B.arange( ) C.linspace( ) D.logspace( …

    Python 2023年9月29日
    076
  • tomcat的搭建和介绍

    第19章 tomcat的搭建 19.1 tomcat 学习之前的预备知识 19.1.1 什么是 JVM 和 JDK,JRE JVM java虚拟机,实现一份代码可以在不同的平台执行…

    Python 2023年6月10日
    065
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球