文章
*
– 一、MNIST数据集及Softmax
–
+ 1.MNIST数据集
+ 2.Softmax
– 二、MNIST数据集分类
–
+ 1.导入第三方库
+ 2.加载数据及数据预处理
+ 3.训练模型
一、MNIST数据集及Softmax
1.MNIST数据集
大多数示例使用手写数字的 MNIST数据集。该数据集包含 60,000个用于 训练的示例和 10,000个用于 测试的示例。
每一张 图片包含 28*28个像素,在MNIST 训练数据集 中是一个 形状为[60000,28,28] 的张量,我们 首先需要把数据集 转成[60000,784],然后 才能放到网络中 训练。 第一个维度数字用来 索引图片, 第二个维度数字用来 索引每张图片中的 像素点。一般我们还需要把图片中的 数据归一化0~1之间。
MNIST数据集的 标签是 介于0-9的数字,我们要把标签 转化为”one-hotvectors”。一个one-hot向量除了 一位数字 是1外, 其余维度数字 都是0, 比如标签0将表示为([1,0,0,0,0,0,0,0,0,0]),标签3将表示为([0,0,0,1,0,0,0,0,0,0])。
因此,MNIST数据集的标签是一个[60000,10]的数字矩阵。
2828=784, 每张图片有 784个像素点,对应着 784个神经元。最后输出 10个神经元对应着 10个数字*。
; 2.Softmax
Softmax作用就是把神经网络的输出转化为概率值。
我们知道MNIST的结果是0-9,我们模型可能推测出一张图片的数字9的概率是80%,是数字8的概率是10%,然后其他数字的概率更小,总体概率加起来等于1。这是一个使用softmax回归模型的经典案例。softmax模型可以用来给不同的对象分配概率。
二、MNIST数据集分类
代码运行平台为jupyter-notebook,文章中的代码块,也是按照jupyter-notebook中的划分顺序进行书写的,运行文章代码,直接分单元粘入到jupyter-notebook即可。
1.导入第三方库
import numpy as np
from keras.datasets import mnist
from keras.utils import np_utils
from keras.models import Sequential
from keras.layers import Dense
from tensorflow.keras.optimizers import SGD
2.加载数据及数据预处理
(x_train,y_train),(x_test,y_test) = mnist.load_data()
print("x_shape:\n",x_train.shape)
print("y_shape:\n",y_train.shape)
x_train = x_train.reshape(x_train.shape[0],-1)/255.0
x_test = x_test.reshape(x_test.shape[0],-1)/255.0
y_train = np_utils.to_categorical(y_train,num_classes=10)
y_test = np_utils.to_categorical(y_test,num_classes=10)
3.训练模型
model = Sequential([
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax"),
])
sgd = SGD(lr=0.2)
model.compile(
optimizer=sgd,
loss="mse",
metrics=['accuracy']
)
model.fit(x_train,y_train,batch_size=32,epochs=10)
loss,accuracy = model.evaluate(x_test,y_test)
print("\ntest loss",loss)
print("accuracy:",accuracy)
最终运行结果:
注意
Dense(units=10,input_dim=784,bias_initializer='one',activation="softmax")
这里用到了 softmax激活函数。- 这里我们使用的
fit
方法进行的模型训练,之前的线性回归和非线性回归的模型训练方式和这不同。
代码:
model.compile(
optimizer=sgd,
loss="mse",
metrics=['accuracy']
)
中 添加 metrics=['accuracy']
, 可以在训练过程中计算准确率。
Original: https://blog.csdn.net/booze_/article/details/125621175
Author: booze-J
Title: 3.MNIST数据集分类
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/691235/
转载文章受原作者版权保护。转载请注明原作者出处!