# 小熊飞桨练习册-01手写数字识别

## 数据集

• 运行脚本，包含以下步骤：获取数据，检查数据。

bash onekey.sh


bash get-data.sh


bash check-data.sh


## 网络模型

import paddle

# LeNet 网络模型
class LeNet(nn.Layer):
def __init__(self, num_classes=10):
super(LeNet, self).__init__()
if num_classes < 1:
raise Exception("分类数量 num_classes 必须大于 0: {}".format(num_classes))
self.num_classes = num_classes
self.conv1 = nn.Conv2D(
in_channels=1, out_channels=6, kernel_size=5, stride=1)
self.avg_pool1 = nn.AvgPool2D(kernel_size=2, stride=2)
self.conv2 = nn.Conv2D(
in_channels=6, out_channels=16, kernel_size=5, stride=1)
self.avg_pool2 = nn.AvgPool2D(kernel_size=2, stride=2)
self.conv3 = nn.Conv2D(
in_channels=16, out_channels=120, kernel_size=4, stride=1)
self.fc1 = nn.Linear(in_features=120, out_features=64)
self.fc2 = nn.Linear(in_features=64, out_features=num_classes)

def forward(self, x):
x = self.conv1(x)
x = F.relu(x)
x = self.avg_pool1(x)
x = self.conv2(x)
x = F.relu(x)
x = self.avg_pool2(x)
x = self.conv3(x)
x = F.relu(x)
x = self.fc1(x)
x = F.relu(x)
x = self.fc2(x)
return x


## 数据集解析

import paddle
import os
import struct
import numpy as np

"""
"""

def __init__(self,
images_path: str,
labels_path: str,
transform=None,
):
"""
构造函数，定义数据集大小

Args:
images_path (str): 图像集路径
labels_path (str): 标签集路径
transform (Compose, optional): 转换数据的操作组合, 默认 None
"""
super(MNIST, self).__init__()
self.images_path = images_path
self.labels_path = labels_path
self._check_path(images_path, "数据路径错误")
self._check_path(labels_path, "标签路径错误")
self.transform = transform
self.images, self.labels = self.parse_dataset(images_path, labels_path)

def __getitem__(self, idx):
"""
获取单个数据和标签

Args:
idx (Any): 索引

Returns:
image (float32): 图像
label (int64): 标签
"""
image, label = self.images[idx], self.labels[idx]
# 这里 reshape 是2维 [28 ,28]
image = np.reshape(image, [28, 28])
if self.transform is not None:
image = self.transform(image)
# label.astype 如果是整型，只能是 int64
return image.astype('float32'), label.astype('int64')

def __len__(self):
"""
数据数量

Returns:
int: 数据数量
"""
return len(self.labels)

def _check_path(self, path: str, msg: str):
"""
检查路径是否存在

Args:
path (str): 路径
msg (str, optional): 异常消息

Raises:
Exception: 路径错误, 异常
"""
if not os.path.exists(path):
raise Exception("{}: {}".format(msg, path))

@staticmethod
def parse_dataset(images_path: str, labels_path: str):
"""
数据集解析

Args:
images_path (str): 图像集路径
labels_path (str): 标签集路径

Returns:
images: 图像集
labels: 标签集
"""
with open(images_path, 'rb') as imgpath:
# 解析图像集
magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
# 这里 reshape 是1维 [786]
images = np.fromfile(
imgpath, dtype=np.uint8).reshape(num, rows * cols)
with open(labels_path, 'rb') as lbpath:
# 解析标签集
labels = np.fromfile(lbpath, dtype=np.uint8)
return images, labels


## 开始训练

python3 train.py

  --cpu             是否使用 cpu 计算，默认使用 CUDA
--learning-rate   学习率，默认 0.001
--epochs          训练几轮，默认 2 轮
--batch-size      一批次数量，默认 128
--num-workers     线程数量，默认 2
--no-save         是否保存模型参数，默认保存, 选择后不保存模型参数
--log             是否输出 VisualDL 日志，默认不输出
--summary         输出网络模型信息，默认不输出，选择后只输出信息，不会开启训练


## 测试模型

python3 test.py

  --cpu           是否使用 cpu 计算，默认使用 CUDA
--batch-size    一批次数量，默认 128
--num-workers   线程数量，默认 2
--load-dir      读取模型参数，读取 params 目录下的子文件夹, 默认 best 目录


## 查看结果报表

python3 report.py


## VisualDL 可视化分析工具

• 安装和使用说明参考：VisualDL
• 训练的时候加上参数 –log
• 如果是 AI Studio 环境训练的把 log 目录下载下来，解压缩后放到本地项目目录下 log 目录
• 在项目目录下运行下面命令
• 然后根据提示的网址，打开浏览器访问提示的网址即可
visualdl --logdir ./log


Original: https://www.cnblogs.com/cnhemiya/p/16137055.html
Author: 小熊宝宝啊
Title: 小熊飞桨练习册-01手写数字识别

## 相关阅读

(0)

### 大家都在看

• #### 推荐 5 个 yyds 的开源 Python Web 框架

Python 2023年1月2日
041
• #### 【Python 趣味习题】

python趣味习题目录 凯撒密码-层层加密 * – 运行结果如下： 个人成绩计算 猜数字游戏 解数学方程 * Pandas 每日一练： – 1.将下面的字…

2022年8月19日
090
• #### ChatGPT通俗笔记：从GPT-N、RL之PPO算法到instructGPT、ChatGPT

前言 自从我那篇BERT通俗笔记一经发布，然后就不断改、不断找人寻求反馈、不断改，其中一位朋友倪老师(之前我司NLP高级班学员现课程助教老师之一)在谬赞BERT笔记无懈可击的同时，…

Python 2023年2月5日
049
• #### python之高级数据结构Collections

collections模块包含了内建类型之外的一些有用的工具，例如Counter、defaultdict、OrderedDict、deque以及nametuple。其中Counte…

Python 2022年9月6日
0118
• #### 送你5个MindSpore算子使用经验

摘要：MindSpore给大家提供了很多算子进行使用，今天给大家简单介绍下常用的一些算子使用时需要注意的内容。 MindSpore给大家提供了很多算子进行使用，今天给大家简单介绍下…

Python 2023年2月2日
026
• #### 时间序列学习（5）：ARMA模型定阶（AIC、BIC准则、Ljung-Box检验）

时间序列学习（5）：ARMA模型定阶（AIC、BIC准则、Ljung-Box检验） * – + 1、信息量准则 + 2、寻找对数收益率序列的最佳阶数 + 3、构建模型 …

2022年8月31日
0411
• #### pytorch基础

tensor = torch.randn(2,3,4) print(tensor.type()) # 数据类型 torch.FloatTensor，是一个浮点型的张量 print(…

Python 2023年1月12日
036
• #### Python pandas中DataFrame添加列、获取行列、获取元素值

直接通过赋值为空，添加一列。 >>> import pandas as pd >>> df = pd.DataFrame(np.arange(1…

Python 2022年12月30日
059
• #### Pandas文本数据

一、合并功能 （一）merge：pd.merge() 类似于vlookup 函数的作用，只会返回两个表中都含有的元素。表达形式 pd.merge(left,right,how = …

Python 2022年12月31日
041
• #### python 播放mp3模块_Python基于pygame模块播放MP3的方法示例

本文实例讲述了Python基于pygame模块播放MP3的方法。分享给大家供大家参考，具体如下： 安装pygame(可参考：安装Python和pygame及相应的环境变量配置) p…

Python 2023年1月22日
020
• #### 从零开始Docker部署Web服务

Python 2023年1月6日
032
• #### python np array归一化_浅谈利用numpy对矩阵进行归一化处理的方法

Python 2023年1月12日
033
• #### Python OpenCV配置CUDA以支持GPU加速 (不使用Visual Studio)

Welcome to My Blog 文章唯一地址：https://blog.csdn.net/REAL_liudebai/article/details/119356958 问题…

2022年8月28日
0293
• #### Python学习笔记

Python pandas库㈢ 前言 一、数据清洗 * ①缺失值处理 – (1)查看缺失值 (2)处理缺失值 ②重复值处理 – (1)查看重复值 (2)处理…

Python 2023年1月8日
036
• #### Python 国家地震台网 地震数据集完整分析、pyecharts、plotly，分析强震次数、震级分布、震级震源关系、发生位置、发生时段、最大震级、平均震级

Original: ython ] PyQt5 PySide2 笔记Author: PythonTitle: YEUNGCHIE

Python 2022年9月3日
0126
• #### scrapy中文指南 第二章 项目初始化和第一个小例子

第二章 项目初始化和第一个小例子.md 初始化项目 项目目录介绍 自定义爬虫类 * 自定义第一个爬虫 录入代码 – 代码解释： 运行爬虫 – 运行结果 代码…

Python 2023年1月26日
033