# 深度学习100例 | 第4例：水果识别 – PyTorch实现

### 文章目录

PyTorch VS TensorFlow

• TensorFlow：简单，模块封装比较好， 容易上手，对新手比较友好。在工业界最重要的是模型落地，目前国内的大部分企业支持TensorFlow模型在线部署，不支持Pytorch。
• PyTorch前沿算法多为PyTorch版本，如果是你高校学生or研究人员，建议学这个。相对于TensorFlow，Pytorch在易用性上更有优势，更加方便调试。

[En]

Of course, if you have plenty of time, I suggest that both models need to be understood, both of which are important.

🍨 本文的重点：将讲解如何使用PyTorch构建神经网络模型（将对这一块展开详细的讲解）

🍖 我的环境：

• 语言环境：Python3.8
• 编译器：Jupyter Lab
• 深度学习环境：
• torch==1.10.0+cu113
• torchvision==0.11.1+cu113

👉 往期精彩内容

## 一、导入数据

from torchvision.transforms import transforms
from torchvision            import datasets
import torchvision.models   as models
import torch.nn.functional  as F
import torch.nn             as nn
import torch,torchvision


### 1. 获取类别名字

import os,PIL,random,pathlib

data_dir = './04-data/'
data_dir = pathlib.Path(data_dir)

data_paths = list(data_dir.glob('*'))
classeNames = [str(path).split("\\")[1] for path in data_paths]
classeNames

['Apple',
'Banana',
'Carambola',
'Guava',
'Kiwi',
'Mango',
'muskmelon',
'Orange',
'Peach',
'Pear',
'Persimmon',
'Pitaya',
'Plum',
'Pomegranate',
'Tomatoes']


### 2. 加载数据文件

total_datadir = './04-data/'

train_transforms = transforms.Compose([
transforms.Resize([224, 224]),
transforms.ToTensor(),
transforms.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225])
])

total_data

Dataset ImageFolder
Number of datapoints: 12000
Root location: ./04-data/
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=bilinear, max_size=None, antialias=None)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)


### 3. 划分数据

train_size = int(0.8 * len(total_data))
test_size  = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data, [train_size, test_size])
train_dataset, test_dataset

(<torch.utils.data.dataset.subset at 0x24bbdb84ac0>,
<torch.utils.data.dataset.subset at 0x24bbdb84610>)
</torch.utils.data.dataset.subset></torch.utils.data.dataset.subset>

train_size,test_size

(9600, 2400)

train_loader = torch.utils.data.DataLoader(train_dataset,
batch_size=16,
shuffle=True,
num_workers=1)
batch_size=16,
shuffle=True,
num_workers=1)

print("The number of images in a training set is: ", len(train_loader)*16)
print("The number of images in a test set is: ", len(test_loader)*16)
print("The number of batches per epoch is: ", len(train_loader))

The number of images in a training set is:  9600
The number of images in a test set is:  2400
The number of batches per epoch is:  600

for X, y in test_loader:
print("Shape of X [N, C, H, W]: ", X.shape)
print("Shape of y: ", y.shape, y.dtype)
break

Shape of X [N, C, H, W]:  torch.Size([16, 3, 224, 224])
Shape of y:  torch.Size([16]) torch.int64


## 二、自建模型

nn.Conv2d()函数：

• 第一个参数（in_channels）是输入的channel数量，彩色图片为3，黑白图片为1。
• 第二个参数（out_channels）是输出的channel数量
• 第三个参数（kernel_size）是卷积核大小
• 第四个参数（stride）是步长，就是卷积操作时每次移动的格子数，默认为1

class Network_bn(nn.Module):
def __init__(self):
super(Network_bn, self).__init__()
"""
nn.Conv2d()函数：
第一个参数（in_channels）是输入的channel数量
第二个参数（out_channels）是输出的channel数量
第三个参数（kernel_size）是卷积核大小
第四个参数（stride）是步长，默认为1
"""
self.conv1 = nn.Conv2d(in_channels=3, out_channels=12, kernel_size=5, stride=1, padding=0)
self.bn1 = nn.BatchNorm2d(12)
self.conv2 = nn.Conv2d(in_channels=12, out_channels=12, kernel_size=5, stride=1, padding=0)
self.bn2 = nn.BatchNorm2d(12)
self.pool = nn.MaxPool2d(2,2)
self.conv4 = nn.Conv2d(in_channels=12, out_channels=24, kernel_size=5, stride=1, padding=0)
self.bn4 = nn.BatchNorm2d(24)
self.conv5 = nn.Conv2d(in_channels=24, out_channels=24, kernel_size=5, stride=1, padding=0)
self.bn5 = nn.BatchNorm2d(24)
self.fc1 = nn.Linear(24*50*50, len(classeNames))

def forward(self, x):
x = F.relu(self.bn1(self.conv1(x)))
x = F.relu(self.bn2(self.conv2(x)))
x = self.pool(x)
x = F.relu(self.bn4(self.conv4(x)))
x = F.relu(self.bn5(self.conv5(x)))
x = self.pool(x)
x = x.view(-1, 24*50*50)
x = self.fc1(x)

return x

device = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))

model = Network_bn().to(device)
model

Using cuda device

Network_bn(
(conv1): Conv2d(3, 12, kernel_size=(5, 5), stride=(1, 1))
(bn1): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv2): Conv2d(12, 12, kernel_size=(5, 5), stride=(1, 1))
(bn2): BatchNorm2d(12, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
(conv4): Conv2d(12, 24, kernel_size=(5, 5), stride=(1, 1))
(bn4): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(conv5): Conv2d(24, 24, kernel_size=(5, 5), stride=(1, 1))
(bn5): BatchNorm2d(24, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
(fc1): Linear(in_features=60000, out_features=15, bias=True)
)


## 三、模型训练

### 1. 优化器与损失函数

optimizer  = torch.optim.Adam(model.parameters(), lr=0.0001, weight_decay=0.0001)
loss_model = nn.CrossEntropyLoss()

from torch.autograd import Variable

model.eval()
test_loss, correct = 0, 0
X, y = X.to(device), y.to(device)
pred = model(X)
test_loss += loss_model(pred, y).item()
correct += (pred.argmax(1) == y).type(torch.float).sum().item()
test_loss /= num_batches
correct /= size
print(f"Test Error: \n Accuracy: {(100*correct):>0.1f}%, Avg loss: {test_loss:>8f} \n")
return correct,test_loss

model=model.to(device)
model.train()

for i, (images, labels) in enumerate(train_loader, 0):

images = Variable(images.to(device))
labels = Variable(labels.to(device))

outputs = model(images)
loss = loss_model(outputs, labels)
loss.backward()
optimizer.step()

if i % 1000 == 0:
print('[%5d] loss: %.3f' % (i, loss))


### 2. 模型的训练

test_acc_list  = []
epochs = 30

for t in range(epochs):
print(f"Epoch {t+1}\n-------------------------------")
test_acc_list.append(test_acc)
print("Done!")

Epoch 1
[    0] loss: 0.468
Test Error:
Accuracy: 89.2%, Avg loss: 0.377265

......

Epoch 29
[    0] loss: 0.000
Test Error:
Accuracy: 91.8%, Avg loss: 0.660563

Done!



## 四、结果分析

import numpy as np
import matplotlib.pyplot as plt

x = [i for i in range(1,31)]

plt.plot(x, test_acc_list, label="Accuracy", alpha=0.8)

plt.xlabel("Epoch")
plt.ylabel("Accuracy")

plt.legend()
plt.show()


Original: https://blog.csdn.net/qq_38251616/article/details/125132612
Author: K同学啊
Title: 深度学习100例 | 第4例：水果识别 – PyTorch实现

(0)

### 大家都在看

• #### tensorflow2 tf2 DDPG算法玩立棍小游戏

DDPG算法就不做过多解读了,就是用来进行连续值预测,本文是使用DDPG进行立棍小游戏,详细过程解读注释在代码中,算法和模型都非常简单,考验的是基础,使用两个全连接模型,相互配合更…

人工智能 2023年5月24日
0205
• #### 关于论文《A Survey on Knowledge Graphs Representation, Acquisition and Applications》的心得复写一

关于论文《A Survey on Knowledge Graphs: Representation, Acquisition and Applications》的心得复写一 写在前…

人工智能 2023年6月10日
0137
• #### Python 数据分析函数汇总

人工智能 2023年7月7日
0132
• #### ARM S5PV210 汇编实现时钟设置代码详解

一、时钟设置的步骤分析 第1步：CLK_SRC寄存器的设置分析 先选择不使用 PLL。让外部 24MHz 原始时钟直接过去，绕过 APLL 那条路。 CLK_SRC 寄存器其实是用…

人工智能 2023年7月29日
0111
• #### Knowledge Graph Identification(知识图谱识别)

读论文 Knowledge Graph Identification(知识图谱识别) ; 摘要 大规模信息处理系统能够提取大量相关的事实，将这些候选事实转化为知识是挑战。在本文中，…

人工智能 2023年6月1日
0148
• #### 知识图谱笔记1：知识图谱概述

什么是知识图谱 是一种通用的语义知识描述框架：知识图谱是以三元组为基础，有向图作为数据结构，从知识本体和知识实例两个层次，对世界万物进行体系化、规范化的描述，并且实现高效的知识推理…

人工智能 2023年6月10日
0182
• #### Matplotlib 进阶（三）

一、Pandas绘图 Series和DataFrame是Pandas库中主要的两种数据结构，都内置了plot方法，可以绘制图形 1．Series.plot Series是一个一维数…

人工智能 2023年7月8日
0151
• #### 计算机视觉教程2-7：天使与恶魔?图文详解图像形态学运算(附代码)

目录 1 图像形态学运算 2 腐蚀 3 膨胀 4 开运算与闭运算 5 顶帽运算与底帽运算 6 恶魔与天使 1 图像形态学运算 在计算机视觉教程2-2：详解图像滤波算法(附Pytho…

人工智能 2023年7月27日
0163
• #### 通信网信息传输与分发技术国家级重点实验室2021年度预研基金项目申请指南

通信网信息传输与分发技术国家级重点实验室2021年度预研基金项目申请指南 人工智能技术与咨询 人工智能技术与咨询 北京龙腾亚太教育咨询有限公司依托中国管理科学研究院职业资格认证培训…

人工智能 2023年6月1日
0143
• #### 生物信息学概论——聚类分析TCGA-BRCA数据

人工智能 2023年6月2日
0190
• #### educoder-数据预处理基础

一、第1关引言-根深之树不怯风折，泉深之水不会涸竭 背景： 现实世界中数据大体上都是不完整，不一致的脏数据，无法直接进行数据挖掘，或挖掘结果差强人意。为了提高数据挖掘的质量产生了数…

人工智能 2023年7月18日
0213
• #### import tensorflow.keras as keras 报错No Module named keras

问题描述 环境:win10+anaconda+tf 1.2.0+keras 2.0.6+py 3.6.2 import tensorflow.keras as keras 在使用t…

人工智能 2023年5月25日
0133
• #### SCS【7】单细胞转录组之轨迹分析 (Monocle 3) 聚类、分类和计数细胞

点击关注，桓峰基因 桓峰基因公众号推出单细胞系列教程，有需要生信分析的老师可以联系我们！首选看下转录分析教程整理如下： Topic 6. 克隆进化之 Canopy Topic 7….

人工智能 2023年6月30日
0218
• #### 浅谈 NLP 细粒度情感分析（ABSA）

作者 | 周俊贤整理 | NewBeeNLP 最近在调研细粒度情感分析的论文，主要对一些深度学习方法进行调研，看论文的同时记录下自己的一些想法。 首先，何为细粒度的情感分析？如下图…

人工智能 2023年5月28日
0180
• #### MATLAB算法实战应用案例精讲-【回归算法】XGBoost算法（附Java、Python和R语言代码）

抵扣说明： 1.余额是钱包充值的虚拟货币，按照1:1的比例进行支付金额的抵扣。2.余额无法直接购买下载，可以购买VIP、C币套餐、付费专栏及课程。 Original: https:…

人工智能 2023年6月17日
0169
• #### 机器学习——梯度提升决策树（GBDT）

相关文章链接： 机器学习——人工神经网络（NN） 机器学习——卷积神经网络（CNN） 机器学习——循环神经网络（RNN） 机器学习——长短期记忆（LSTM） 机器学习——决策树（d…

人工智能 2023年6月25日
0184