mnist手写数字模型训练、保存、加载及图片预测

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档

非专业程序员,主业PLC单片机,2019年想扩充知识体系,紧跟潮流,带学生参加了人工智能大赛,才开始接触tensorflow以及深度学习的基本过程,非常艰难。后来比赛完了之后,学生也毕业了,因为感觉难度过大,而且自己将来也不准备转到这个行业,干脆就放弃了。最近疫情关在家里,想了一个晚上,对于本专业,自己干过项目,参加过大赛,虽然没有掌握PLC单片机所有的知识,但是掌握了方法论,能够快速的学习新的设备和目前尚未掌握的功能,没有应用的点,学习那些知识也就没有太大的必要。所以就决定利用疫情,继续开拓对自己来说仍然是新的领域,人工智能,从头开始。

现在再看tensorflow,已经改头换面了,1.x版本太过复杂,难以理解,2.x改善了很多,入门容易。

今天还是从手写数字开始,入门代码非常多,大多都是有关模型训练的,官网也有保存模型及加载模型的代码,这里就不再多写

预测部分,我在网上找到的都是使用mnist自己的测试数据来进行预测,有的是使用加载模型的方法测试准确率,有的是预测测试集中的数据,但是没有针对一个自己的图片(可以是摄像头拍的,可以是自己在画图里写的数字)的预测方法,这里主要解决这个问题。

这个想法产生的原因很简单,其实就是需要将我们的工作应用到现实中,整个过程我想是这几个步骤:准备训练数据、数据预处理来适配模型网络、搭建深度学习网络、训练模型、模型保存,到这实际上开发工作已经完成,下面的步骤就是要应用了,准备数据、加载模型、预测结果,预测的结果将用到后面的业务逻辑。

import tensorflow as tf
from tensorflow import keras
import cv2
from keras.preprocessing.image import img_to_array
import numpy as np

加载数据

def loadData():
    mnist = tf.keras.datasets.mnist
    (x_train, y_train),(x_test, y_test) = mnist.load_data()
    x_train, x_test = x_train / 255.0, x_test / 255.0

创建模型和训练

此部分大多来自官方文档和网络,就是一层全连接,理解也较为容易,最后将模型存为h5文件。

def create_model():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=(28, 28)),
      tf.keras.layers.Dense(128, activation='relu'),
      tf.keras.layers.Dropout(0.2),
      tf.keras.layers.Dense(10, activation='softmax')
    ])

    model.compile(optimizer='adam',
                  loss='sparse_categorical_crossentropy',
                  metrics=['accuracy'])
    return model
def train():
    model = create_model()
    model.summary()
    model.fit(x_train, y_train, epochs=5)
    model.evaluate(x_test,  y_test, verbose=2)
    model.save('my_model.h5')

模型应用

现在我们有了训练好的模型,正常逻辑就是考虑如何应用,官方文档有加载模型的方法,加载好之后预测就是一个predict函数,将预测的数据传进去就能得出结果,因为输入的尺寸是28*28,所以我考虑到图片大小不一,需要转换尺寸,这里我想到了用opencv,所以下面的函数就是使用cv处理图片。
为了简化操作过程探寻方法论,我使用黑底图片。

def imgTool():
    img = cv2.imread("D:/workspace/MNIST_data/1.jpg")
    img = cv2.resize(img,(28,28))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)

即使是黑底图片,也一定要转灰度,否则维度不对,在这浪费了不少时间。
步骤:读图片、塑形、灰度、图转矩阵、展开矩阵。
这个函数只用来测试图片处理结果。

最后的预测函数,在主程序里调用即可。

def predict():
    img = cv2.imread("D:/workspace/MNIST_data/1.jpg")
    img = cv2.resize(img,(28,28))
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    img = img_to_array(img)
    img = np.expand_dims(img, axis=0)
    img = img / 255.0
    new_model = tf.keras.models.load_model("my_model.h5")
    new_model.summary()
    pre= new_model.predict(img)
    print(np.argmax(pre))

这里不讨论模型的准确度及网络的合理性,主要问题是训练好的模型如何应用。

Original: https://blog.csdn.net/bobbycumt/article/details/124476990
Author: bobbycumt
Title: mnist手写数字模型训练、保存、加载及图片预测

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/690182/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球