在各种主流的深度学习框架中,手写数字的识别均作为第一个入门教程,同样,在开始学习Keras这个近些年非常流行的深度学习框架时,也用全连接手写数字识别作为入门的第一个例子。
为了紧跟时代潮流,在之后我的所有博客中,均使用TensorFlow2以上的版本,而在TensorFlow2中的版本,Keras已经集成到tf中去了,可见谷歌也在逐步放弃TensorFlow1中的模型搭建方法,转而使用Keras来进行模型训练。因此可见,Keras的优势,也希望大家从一开始就学习Keras。
使用的TensorFlow版本:2.5.0
1 全连接手写数字模型的训练
训练代码如下所示:
from tensorflow.keras.datasets import mnist
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical
if __name__=='__main__':
(train_images,train_lables),(test_images,test_lables)=mnist.load_data()
train_images=train_images.reshape(train_images.shape[0],-1)
train_images=train_images.astype('float32')/255
test_images=test_images.reshape(test_images.shape[0],-1)
test_images=test_images.astype('float32')/255
train_lables=to_categorical(train_lables)
test_lables=to_categorical(test_lables)
network=models.Sequential()
network.add(layers.Dense(units=512,activation='relu',input_shape=(28*28,)))
network.add(layers.Dense(units=10,activation='softmax'))
network.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
network.summary()
history=network.fit(x=train_images,y=train_lables,epochs=5,batch_size=128)
print(history.history)
test_loss,test_acc=network.evaluate(x=test_images,y=test_lables)
print(test_acc)
network.save(filepath='./model/test.h5')
模型框架采用两层神经网络结构,数据输入结构为(784,)
[En]
The model framework adopts a two-layer neural network structure, and the structure of data input is (784,)
第一个隐藏层的权重参数W1形状为(784,512),偏置为(512,),
隐藏层的激活函数采用Relu函数,Relu函数的公式为:
输出图像为:
第二层为输出层,输出层的权重参数形状为(512,10),偏置为(10,),因为手写数字的识别属于分类问题,分类问题的输出层激活函数使用softmax函数,softmax函数的数学公式如下:
这一句代码也很重要:
network.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
optimizer=’rmsprop’表示在神经网络训练中,参数的更新采用的方法为rmsprop,该方法可以一开始多学,之后少学,也就是一开始的学习率设置的大一些,随着训练的进行,学习率会逐渐降低。
loss=’categorical_crossentropy’表示为损失函数为分类交叉熵,工时如下:
使用上面的代码训练,训练过程如下:
[En]
Using the code training above, the training process is as follows:
Epoch 1/5
469/469 [==============================] - 5s 8ms/step - loss: 0.2536 - accuracy: 0.9266
Epoch 2/5
469/469 [==============================] - 4s 8ms/step - loss: 0.1035 - accuracy: 0.9695
Epoch 3/5
469/469 [==============================] - 4s 9ms/step - loss: 0.0696 - accuracy: 0.9789
Epoch 4/5
469/469 [==============================] - 5s 10ms/step - loss: 0.0494 - accuracy: 0.9848
Epoch 5/5
469/469 [==============================] - 5s 10ms/step - loss: 0.0375 - accuracy: 0.9887
{'loss': [0.25358861684799194, 0.10348740965127945, 0.06957202404737473, 0.04940706863999367, 0.03746318444609642], 'accuracy': [0.9265999794006348, 0.9694666862487793, 0.9788500070571899, 0.9847833514213562, 0.9887333512306213]}
313/313 [==============================] - 1s 2ms/step - loss: 0.0676 - accuracy: 0.9797
0.9797000288963318
训练完成后,得到h5模型文件,我们使用Netron这个工具查看模型文件,可以很清晰的查看模型的网络结构图。
2 全连接手写数字模型的调用
在对模型进行训练后,我们可以使用该模型对手写数字图像进行预测,从而进一步检验该模型的效果。
[En]
After training the model, we can use the model to predict the handwritten digital image, which can further test the effect of the model.
调用代码如下:
from tensorflow.keras.datasets import mnist
from tensorflow.keras import models
import numpy as np
import cv2
if __name__=='__main__':
(train_images,train_lables),(test_images,test_lables)=mnist.load_data()
test_img=test_images[2]
test_img=test_img.reshape(test_img.shape+(1,))
cv2.imshow('test',test_img)
cv2.waitKey(0)
model=models.load_model('./model/test.h5')
test_data=test_img.reshape(1,-1)
test_data=test_data.astype('float32')/255
output=model.predict(test_data)
print(output)
output_argmax=output.argmax(axis=1)
print(output_argmax)
print('模型预测的值为:'+str(output_argmax[0]))
print('图片的真实标签值为:'+str(test_lables[2]))
识别的图像显示为:
预测的结果为:
[[6.3377036e-07 9.9490094e-01 4.5994075e-04 5.5514480e-05 9.1660942e-05
1.7706998e-05 3.7915866e-05 2.5628312e-03 1.8701523e-03 2.7018891e-06]]
[1]
模型预测的值为:1
图片的真实标签值为:1
由此可见,该模型对图像的预测结果是正确的。
[En]
It can be seen that the prediction result of the model for the picture is correct.
Original: https://blog.csdn.net/qq_37781464/article/details/122684455
Author: Keras深度学习
Title: Keras深度学习(1)-全连接手写数字的识别
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/511214/
转载文章受原作者版权保护。转载请注明原作者出处!