Keras深度学习(1)-全连接手写数字的识别

在各种主流的深度学习框架中,手写数字的识别均作为第一个入门教程,同样,在开始学习Keras这个近些年非常流行的深度学习框架时,也用全连接手写数字识别作为入门的第一个例子。
为了紧跟时代潮流,在之后我的所有博客中,均使用TensorFlow2以上的版本,而在TensorFlow2中的版本,Keras已经集成到tf中去了,可见谷歌也在逐步放弃TensorFlow1中的模型搭建方法,转而使用Keras来进行模型训练。因此可见,Keras的优势,也希望大家从一开始就学习Keras。
使用的TensorFlow版本:2.5.0

1 全连接手写数字模型的训练

训练代码如下所示:

from tensorflow.keras.datasets import mnist
from tensorflow.keras import models
from tensorflow.keras import layers
from tensorflow.keras.utils import to_categorical

if __name__=='__main__':

    (train_images,train_lables),(test_images,test_lables)=mnist.load_data()

    train_images=train_images.reshape(train_images.shape[0],-1)

    train_images=train_images.astype('float32')/255

    test_images=test_images.reshape(test_images.shape[0],-1)
    test_images=test_images.astype('float32')/255

    train_lables=to_categorical(train_lables)
    test_lables=to_categorical(test_lables)

    network=models.Sequential()
    network.add(layers.Dense(units=512,activation='relu',input_shape=(28*28,)))
    network.add(layers.Dense(units=10,activation='softmax'))

    network.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])
    network.summary()

    history=network.fit(x=train_images,y=train_lables,epochs=5,batch_size=128)
    print(history.history)

    test_loss,test_acc=network.evaluate(x=test_images,y=test_lables)
    print(test_acc)

    network.save(filepath='./model/test.h5')

模型框架采用两层神经网络结构,数据输入结构为(784,)

[En]

The model framework adopts a two-layer neural network structure, and the structure of data input is (784,)

第一个隐藏层的权重参数W1形状为(784,512),偏置为(512,),
隐藏层的激活函数采用Relu函数,Relu函数的公式为:

Keras深度学习(1)-全连接手写数字的识别
输出图像为:
Keras深度学习(1)-全连接手写数字的识别

第二层为输出层,输出层的权重参数形状为(512,10),偏置为(10,),因为手写数字的识别属于分类问题,分类问题的输出层激活函数使用softmax函数,softmax函数的数学公式如下:

Keras深度学习(1)-全连接手写数字的识别
这一句代码也很重要:
network.compile(optimizer='rmsprop',loss='categorical_crossentropy',metrics=['accuracy'])

optimizer=’rmsprop’表示在神经网络训练中,参数的更新采用的方法为rmsprop,该方法可以一开始多学,之后少学,也就是一开始的学习率设置的大一些,随着训练的进行,学习率会逐渐降低。
loss=’categorical_crossentropy’表示为损失函数为分类交叉熵,工时如下:

Keras深度学习(1)-全连接手写数字的识别

使用上面的代码训练,训练过程如下:

[En]

Using the code training above, the training process is as follows:

Epoch 1/5
469/469 [==============================] - 5s 8ms/step - loss: 0.2536 - accuracy: 0.9266
Epoch 2/5
469/469 [==============================] - 4s 8ms/step - loss: 0.1035 - accuracy: 0.9695
Epoch 3/5
469/469 [==============================] - 4s 9ms/step - loss: 0.0696 - accuracy: 0.9789
Epoch 4/5
469/469 [==============================] - 5s 10ms/step - loss: 0.0494 - accuracy: 0.9848
Epoch 5/5
469/469 [==============================] - 5s 10ms/step - loss: 0.0375 - accuracy: 0.9887
{'loss': [0.25358861684799194, 0.10348740965127945, 0.06957202404737473, 0.04940706863999367, 0.03746318444609642], 'accuracy': [0.9265999794006348, 0.9694666862487793, 0.9788500070571899, 0.9847833514213562, 0.9887333512306213]}
313/313 [==============================] - 1s 2ms/step - loss: 0.0676 - accuracy: 0.9797
0.9797000288963318

训练完成后,得到h5模型文件,我们使用Netron这个工具查看模型文件,可以很清晰的查看模型的网络结构图。

Keras深度学习(1)-全连接手写数字的识别

2 全连接手写数字模型的调用

在对模型进行训练后,我们可以使用该模型对手写数字图像进行预测,从而进一步检验该模型的效果。

[En]

After training the model, we can use the model to predict the handwritten digital image, which can further test the effect of the model.

调用代码如下:

from tensorflow.keras.datasets import mnist
from tensorflow.keras import models
import numpy as np
import cv2

if __name__=='__main__':

    (train_images,train_lables),(test_images,test_lables)=mnist.load_data()

    test_img=test_images[2]
    test_img=test_img.reshape(test_img.shape+(1,))

    cv2.imshow('test',test_img)
    cv2.waitKey(0)

    model=models.load_model('./model/test.h5')

    test_data=test_img.reshape(1,-1)
    test_data=test_data.astype('float32')/255
    output=model.predict(test_data)
    print(output)

    output_argmax=output.argmax(axis=1)
    print(output_argmax)

    print('模型预测的值为:'+str(output_argmax[0]))
    print('图片的真实标签值为:'+str(test_lables[2]))

识别的图像显示为:

Keras深度学习(1)-全连接手写数字的识别
预测的结果为:
[[6.3377036e-07 9.9490094e-01 4.5994075e-04 5.5514480e-05 9.1660942e-05
  1.7706998e-05 3.7915866e-05 2.5628312e-03 1.8701523e-03 2.7018891e-06]]
[1]
模型预测的值为:1
图片的真实标签值为:1

由此可见,该模型对图像的预测结果是正确的。

[En]

It can be seen that the prediction result of the model for the picture is correct.

Original: https://blog.csdn.net/qq_37781464/article/details/122684455
Author: Keras深度学习
Title: Keras深度学习(1)-全连接手写数字的识别

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/511214/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球