Building, Training, and Testing a Simple Neural Network with Keras

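This post builds a simple neural network with Keras, using the MNIST dataset as the example to test how well it learns handwritten digits, and then tries a simple application.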

Environment

The tests were run on Windows. The main packages installed are:

  • tensorflow_gpu==2.2.0
  • imutils==0.5.4
  • opencv_python==4.5.3.56
  • scikit_image==0.18.3
  • scikit_learn==0.24.2
  • numpy==1.21.2
  • py_sudoku==1.0.1

The directory structure is as follows:

[Figure: project directory structure]

Building the Network

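We build a small network of a few layers with Keras. You can use the Keras bundled with TensorFlow (tf.keras, as in the code below) or install the standalone Keras package. The model is trained on the MNIST dataset to recognize digits.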

# mnistnet.py: a small CNN for 28x28 grayscale digit images
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import MaxPooling2D
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Flatten
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import Dropout

class MnistNet:
    @staticmethod
    def build(width, height, depth, classes):

        model = Sequential()
        # channels-last input shape: (rows, cols, channels)
        inputShape = (height, width, depth)

        # block 1: CONV (32 filters, 5x5) => RELU => 2x2 max-pool
        model.add(Conv2D(32, (5, 5), padding="same",
            input_shape=inputShape))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))

        # block 2: CONV (32 filters, 3x3) => RELU => 2x2 max-pool
        model.add(Conv2D(32, (3, 3), padding="same"))
        model.add(Activation("relu"))
        model.add(MaxPooling2D(pool_size=(2, 2)))

        # fully connected layer of 64 units with 50% dropout
        model.add(Flatten())
        model.add(Dense(64))
        model.add(Activation("relu"))
        model.add(Dropout(0.5))

        # second fully connected layer of 64 units with 50% dropout
        model.add(Dense(64))
        model.add(Activation("relu"))
        model.add(Dropout(0.5))

        # softmax classifier: one probability per class
        model.add(Dense(classes))
        model.add(Activation("softmax"))

        return model
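A quick way to check the architecture above is to work out the output shapes and parameter counts by hand. The sketch below does this in plain Python (no TensorFlow needed); the helper names `conv2d_params`, `dense_params` and `mnistnet_params` are illustrative and not part of the project. It assumes the 28×28×1 input and 10 classes used in the training script: `padding="same"` keeps each convolution's output at its input size, and each 2×2 max-pool halves it (28 → 14 → 7), so `Flatten()` yields 7 × 7 × 32 = 1568 features. Activation, pooling, dropout and flatten layers carry no weights.

```python
def conv2d_params(kernel, in_ch, filters):
    # weights (kh * kw * in_ch * filters) plus one bias per filter
    kh, kw = kernel
    return kh * kw * in_ch * filters + filters


def dense_params(n_in, n_out):
    # weight matrix plus one bias per output unit
    return n_in * n_out + n_out


def mnistnet_params(height=28, width=28, depth=1, classes=10):
    layers = []
    layers.append(("conv1 5x5x32", conv2d_params((5, 5), depth, 32)))
    height, width = height // 2, width // 2  # "same" conv keeps size, 2x2 pool halves it
    layers.append(("conv2 3x3x32", conv2d_params((3, 3), 32, 32)))
    height, width = height // 2, width // 2
    flat = height * width * 32               # size of the Flatten() output
    layers.append(("dense1 64", dense_params(flat, 64)))
    layers.append(("dense2 64", dense_params(64, 64)))
    layers.append(("output", dense_params(64, classes)))
    return flat, layers


flat, layers = mnistnet_params()
print("flattened features:", flat)
for name, n in layers:
    print(f"{name:<14}{n:>8}")
print("total:", sum(n for _, n in layers))
```

The total (115,306) should match the "Total params" line printed by `model.summary()`; most of the weights (100,416) sit in the first dense layer.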

Training the Network


# train_classifier.py
from mnistnet import MnistNet
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.datasets import mnist
from sklearn.preprocessing import LabelBinarizer
from sklearn.metrics import classification_report
import argparse

ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True,
    help="model output path")
args = vars(ap.parse_args())

# hyperparameters: initial learning rate, number of epochs, batch size
INIT_LR = 1e-3
EPOCHS = 16
BATCH_SIZE = 160

print("[LOGS] loading MNIST, please wait...")
((trainData, trainLabels), (testData, testLabels)) = mnist.load_data()
print("[LOGS] training samples:", trainData.shape[0])

# add a channel axis: (N, 28, 28) -> (N, 28, 28, 1)
trainData = trainData.reshape((trainData.shape[0], 28, 28, 1))
testData = testData.reshape((testData.shape[0], 28, 28, 1))

# scale pixel values from [0, 255] to [0, 1]
trainData = trainData.astype("float32") / 255.0
testData = testData.astype("float32") / 255.0

# one-hot encode the labels 0-9
le = LabelBinarizer()
trainLabels = le.fit_transform(trainLabels)
testLabels = le.transform(testLabels)

print("[LOGS] compiling model...")
# use "learning_rate": "lr" is a deprecated alias that newer Keras versions reject
opt = Adam(learning_rate=INIT_LR)
model = MnistNet.build(width=28, height=28, depth=1, classes=10)
model.compile(loss="categorical_crossentropy", optimizer=opt,
    metrics=["accuracy"])

print("[LOGS] training network...")
H = model.fit(
    trainData, trainLabels,
    validation_data=(testData, testLabels),
    batch_size=BATCH_SIZE,
    epochs=EPOCHS,
    verbose=1)

print("[LOGS] evaluating network...")
predictions = model.predict(testData)
print(classification_report(
    testLabels.argmax(axis=1),
    predictions.argmax(axis=1),
    target_names=[str(x) for x in le.classes_]))

# save in HDF5 format; this .h5 file is what the test scripts load
model.save(args["model"], save_format="h5")
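The preprocessing in the script (adding a channel axis, scaling pixels to [0, 1], one-hot encoding the labels) can be tried out without downloading MNIST. The NumPy sketch below runs the same steps on a tiny fake batch; the function name `preprocess` is hypothetical, and `np.eye(10)[labels]` stands in for `LabelBinarizer`. The two give the same 10-column encoding when all ten digits 0-9 appear in the labels the binarizer is fitted on, as they do in the MNIST training set.

```python
import numpy as np


def preprocess(images, labels, num_classes=10):
    # images: uint8 array of shape (N, 28, 28), as returned by mnist.load_data()
    x = images.reshape((images.shape[0], 28, 28, 1))  # add the channel axis
    x = x.astype("float32") / 255.0                   # scale to [0, 1]
    # one-hot encoding; a stand-in for LabelBinarizer when classes 0-9 all occur
    y = np.eye(num_classes, dtype="float32")[labels]
    return x, y


# a fake batch of three "images" so the example runs without MNIST
images = np.zeros((3, 28, 28), dtype=np.uint8)
images[0, 10, 10] = 255
labels = np.array([3, 0, 9])

x, y = preprocess(images, labels)
print(x.shape, x.dtype, x.max())  # (3, 28, 28, 1) float32 1.0
print(y[0])                       # one-hot row for label 3
```

The one-hot rows are what `categorical_crossentropy` expects, and `argmax(axis=1)` turns them back into digit indices, which is how the script feeds `classification_report`.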

Start training from the command line:

python train_classifier.py --model model/model_mnist.h5

Wait for training to finish, as illustrated below:
[Figure: training output]
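At the end of training, `classification_report` prints four numbers per digit: precision (of the samples predicted as that digit, the fraction that really are), recall (of the samples that really are that digit, the fraction found), F1-score (the harmonic mean of the two) and support (the number of true samples). A plain-Python sketch of those per-class numbers, run on a few made-up labels rather than real MNIST output, might look like this; the name `per_class_report` is hypothetical. It reports 0.0 wherever a ratio would divide by zero, which is also what scikit-learn's default `zero_division` setting reports (with a warning).

```python
def per_class_report(y_true, y_pred, num_classes=10):
    """Precision, recall, F1 and support for each class, from integer labels."""
    pairs = list(zip(y_true, y_pred))
    rows = []
    for c in range(num_classes):
        tp = sum(1 for t, p in pairs if t == c and p == c)  # correctly predicted as c
        fp = sum(1 for t, p in pairs if t != c and p == c)  # wrongly predicted as c
        fn = sum(1 for t, p in pairs if t == c and p != c)  # c predicted as another class
        precision = tp / (tp + fp) if tp + fp else 0.0
        recall = tp / (tp + fn) if tp + fn else 0.0
        f1 = (2 * precision * recall / (precision + recall)
              if precision + recall else 0.0)
        rows.append((c, precision, recall, f1, tp + fn))  # support = true count of c
    return rows


# made-up labels for three classes, standing in for
# testLabels.argmax(axis=1) and predictions.argmax(axis=1)
y_true = [0, 0, 1, 1, 2]
y_pred = [0, 1, 1, 1, 2]
for c, p, r, f1, n in per_class_report(y_true, y_pred, num_classes=3):
    print(f"{c}  precision={p:.2f}  recall={r:.2f}  f1={f1:.2f}  support={n}")
```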

Testing the Model

For a simple test, write some digits from 0 to 9 by hand and let the model recognize them.

from tensorflow.keras.models import load_model
import cv2
import imutils
from imutils.contours import sort_contours
import numpy as np

imgPath = "image/test3.jpg"
model_path = "model/model_mnist.h5"
is_show = True

vs_img = cv2.imread(imgPath)
if vs_img is None:
    raise SystemExit("could not load image: " + imgPath)

model = load_model(model_path)
model.summary()

frame = imutils.resize(vs_img,width=200)

gray = cv2.cvtColor(frame,cv2.COLOR_BGR2GRAY)

bl = cv2.GaussianBlur(gray,(5,5),0)

edge_canny = cv2.Canny(bl, 85, 200)

kernel = np.ones((3,3),np.uint8)
edge_canny = cv2.dilate(edge_canny,kernel)

if is_show:
    cv2.imshow("edge_canny", edge_canny)
    cv2.waitKey(10)
# OpenCV 4 returns (contours, hierarchy); OpenCV 3 returns (image, contours, hierarchy)
items = cv2.findContours(edge_canny.copy(), cv2.RETR_EXTERNAL,cv2.CHAIN_APPROX_SIMPLE)

conts = items[0] if len(items) == 2 else items[1]

conts,_ = sort_contours(conts,method="left-to-right")

find_chars = []

for i in conts:

    (x,y,w,h) = cv2.boundingRect(i)

    # keep only bounding boxes of a plausible digit size
    if(w>2 and w< 100) and (h>5 and h< 100):

        # crop the digit and binarize it (Otsu) into white-on-black, like MNIST
        roi = gray[y:y+h,x:x+w]

        _, th = cv2.threshold(roi, 0 ,255,cv2.THRESH_BINARY_INV | cv2.THRESH_OTSU)

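        # scale the longer side to 28 pixels, keeping the aspect ratio
        th_H,th_W = th.shape

        if th_H < th_W:
            th = imutils.resize(th,width=28)
        else:
            th = imutils.resize(th,height=28)

        # pad the shorter side with black to center the digit, then force
        # exactly 28x28 (the padding is one pixel short when 28 - size is odd)
        th_H, th_W = th.shape
        dx = int(max(0,28-th_W)/2)
        dy = int(max(0,28-th_H)/2)

        padding = cv2.copyMakeBorder(th,top=dy,bottom=dy,left=dx,right=dx,
                                     borderType=cv2.BORDER_CONSTANT,value=0)
        padding = cv2.resize(padding,(28,28))

        # scale to [0, 1] and add the channel axis: (28, 28, 1)
        padding = padding.astype("float32")/255.0
        padding = np.expand_dims(padding,axis=-1)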

        print(((x,y,w,h)))
        find_chars.append((padding,(x,y,w,h)))
    else:
        print("next ... ")
        continue

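# np.array(...) is never None, so check for an empty list before building the batch
if len(find_chars) == 0:
    raise SystemExit("can not find chars ...")

boxes = [b[1] for b in find_chars]
find_chars = np.array([f[0] for f in find_chars], dtype="float32")

# one softmax vector of 10 class probabilities per character crop
predicts = model.predict(find_chars)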

labels = "0123456789"

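for (pred, (x,y,w,h)) in zip(predicts,boxes):

    p = np.argmax(pred)   # index of the most likely class
    pre = pred[p]         # its probability (confidence)
    label = labels[p]

    cv2.rectangle(frame,(x,y),(x+w,y+h),(255,0,0),2)
    cv2.putText(frame,label,(x-10,y-10),cv2.FONT_HERSHEY_SIMPLEX,1,(0,255,0),2)

# show the annotated frame once and wait for a key press before exiting
cv2.imshow("result",frame)
cv2.waitKey(0)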

The test result is shown below:

[Figure: test image with each detected digit boxed and labeled with its prediction]
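Before a crop is fed to the network it has to resemble an MNIST sample: a white digit on a black background in a 28×28 frame. The script does this with `imutils.resize`, `cv2.copyMakeBorder` and a final `cv2.resize`. The sketch below shows the same idea (scale the longer side to fit, then center the result on a black canvas) in plain NumPy with nearest-neighbour sampling, so the geometry is easy to follow; the function name `fit_to_canvas` and its `box` parameter are hypothetical. Because it computes the top/left offsets directly, it does not need the final resize that the script uses to absorb odd padding. For reference, the MNIST page describes its digits as size-normalized to fit a 20×20 box and centered in the 28×28 frame, so `box=20` may match the training data more closely than filling the whole frame.

```python
import numpy as np


def fit_to_canvas(crop, size=28, box=28):
    """Scale a binary digit crop so its longer side is `box` pixels (box <= size),
    then center it on a black size x size canvas using nearest-neighbour sampling."""
    h, w = crop.shape
    scale = box / max(h, w)
    new_h = max(1, int(round(h * scale)))
    new_w = max(1, int(round(w * scale)))
    # map each output pixel back to the nearest source pixel
    rows = np.minimum((np.arange(new_h) / scale).astype(int), h - 1)
    cols = np.minimum((np.arange(new_w) / scale).astype(int), w - 1)
    resized = crop[rows[:, None], cols[None, :]]

    canvas = np.zeros((size, size), dtype=crop.dtype)
    top = (size - new_h) // 2
    left = (size - new_w) // 2
    canvas[top:top + new_h, left:left + new_w] = resized
    return canvas


# a tall 40x10 white bar, like a thresholded handwritten "1"
crop = np.full((40, 10), 255, dtype=np.uint8)
digit = fit_to_canvas(crop)
print(digit.shape)  # (28, 28)
x = digit.astype("float32")[..., None] / 255.0  # (28, 28, 1), as the model expects
```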

A Simple Puzzle-Game Application

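Here the digit model trained above is used to recognize the digits on a Sudoku board and then solve the puzzle.
The process is as follows:

  1. Input an image of the Sudoku puzzle to be solved;
  2. Locate each digit in the image;
  3. Divide the board into a grid, normally 9×9, and compute the position of each cell;
  4. Check whether each cell contains a digit and, if it does, recognize it with OCR;
  5. Solve the puzzle with a Sudoku-solving algorithm;
  6. Display the result.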
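In the recognition code, step 5 is handed to the `py-sudoku` package (`Sudoku(...).solve()`). To make "a Sudoku algorithm" concrete, here is a minimal backtracking solver in plain Python. It is an illustrative stand-in, not necessarily what `py-sudoku` does internally, and the names `candidates` and `solve` are hypothetical. It takes the same 9×9 list-of-lists layout that the recognition code builds with `board_9.tolist()`, where 0 marks an empty cell: find the first empty cell, try each digit not yet used in its row, column and 3×3 box, recurse, and undo the choice on a dead end.

```python
def candidates(board, r, c):
    # digits not already used in row r, column c, or the 3x3 box containing (r, c)
    used = set(board[r]) | {board[i][c] for i in range(9)}
    br, bc = 3 * (r // 3), 3 * (c // 3)
    used |= {board[i][j] for i in range(br, br + 3) for j in range(bc, bc + 3)}
    return [d for d in range(1, 10) if d not in used]


def solve(board):
    """Fill `board` (9x9 lists, 0 = empty) in place; return True if solved."""
    for r in range(9):
        for c in range(9):
            if board[r][c] == 0:
                for d in candidates(board, r, c):
                    board[r][c] = d
                    if solve(board):
                        return True
                board[r][c] = 0  # undo this cell and backtrack
                return False
    return True  # no empty cells left


puzzle = [
    [5, 3, 0, 0, 7, 0, 0, 0, 0],
    [6, 0, 0, 1, 9, 5, 0, 0, 0],
    [0, 9, 8, 0, 0, 0, 0, 6, 0],
    [8, 0, 0, 0, 6, 0, 0, 0, 3],
    [4, 0, 0, 8, 0, 3, 0, 0, 1],
    [7, 0, 0, 0, 2, 0, 0, 0, 6],
    [0, 6, 0, 0, 0, 0, 2, 8, 0],
    [0, 0, 0, 4, 1, 9, 0, 0, 5],
    [0, 0, 0, 0, 8, 0, 0, 7, 9],
]
board = [row[:] for row in puzzle]  # solve a copy, keep the OCR result intact
solved = solve(board)
print(solved)  # True
for row in board:
    print(row)
```

If OCR misreads a digit so that the board has no solution, `solve` returns False, and because every tentative placement is undone, the board is left as it was.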

The main recognition code is as follows:


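import argparse

import cv2
import imutils
import numpy as np
from sudoku import Sudoku  # from the py-sudoku package
from tensorflow.keras.models import load_model
from tensorflow.keras.preprocessing.image import img_to_array

# NOTE: extract_number() is a helper from the full project code;
# it is not shown in this excerpt.

ap = argparse.ArgumentParser()
ap.add_argument("-m", "--model", required=True,
    help="path to trained digit classifier")
ap.add_argument("-i", "--image", required=True,
    help="path to input sudoku puzzle image")
ap.add_argument("-d", "--is_show", type=int, default=-1,
    help="whether to show each processing step")
args = vars(ap.parse_args())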

print("loading digit classifier...")
model = load_model(args["model"])

print("processing image...")
image = cv2.imread(args["image"])
if image is None:
    raise SystemExit("could not load image ...")

image = imutils.resize(image, width=400)
src = image.copy()
gray = cv2.cvtColor(image,cv2.COLOR_BGR2GRAY)

board_9 = np.zeros((9, 9), dtype="int")

stepX = gray.shape[1] // 9
stepY = gray.shape[0] // 9

each_loc = []

for y in range(0, 9):

    c_row = []

    for x in range(0, 9):

        startX = x * stepX
        startY = y * stepY
        endX = (x + 1) * stepX
        endY = (y + 1) * stepY

        c_row.append((startX, startY, endX, endY))

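        grid_img = gray[startY:endY, startX:endX]
        # extract_number() gives the digit image for this cell, or None if the
        # cell is empty (judging from how its result is used here)
        number = extract_number(grid_img, is_show=args["is_show"] > 0)

        if number is not None:
            # match the network input: 28x28, scaled to [0, 1], shape (1, 28, 28, 1)
            roi = cv2.resize(number, (28, 28))
            roi = roi.astype("float") / 255.0
            roi = img_to_array(roi)
            roi = np.expand_dims(roi, axis=0)

            pred = model.predict(roi).argmax(axis=1)[0]
            board_9[y, x] = pred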

    each_loc.append(c_row)

print("OCR sudoku board:")
puzzle = Sudoku(3, 3, board=board_9.tolist())  # 0 marks an empty cell in board_9
puzzle.show()

print("solving sudoku puzzle...")
solution = puzzle.solve()
solution.show_full()

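# write each digit of the solved board into its cell on the original image
for (grid, b) in zip(each_loc, solution.board):

    for (box, n) in zip(grid, b):

        startX, startY, endX, endY = box

        # place the text 30% in from the cell's left edge and 25% up from its bottom
        textX = int((endX - startX) * 0.3)
        textY = int((endY - startY) * -0.25)
        textX += startX
        textY += endY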

        cv2.putText(src, str(n), (textX, textY),
            cv2.FONT_HERSHEY_SIMPLEX, 0.8, (0, 255, 0), 2)

cv2.imshow("Results", src)
cv2.waitKey(0)
cv2.imwrite("output/res.jpg",src)

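Test result (the digits drawn in green are the solution obtained after recognition):

[Figure: Sudoku image with the solved digits drawn in green]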
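The recognition loop relies on `extract_number()` to tell empty cells from filled ones (step 4 of the flow), but that helper is not part of the excerpt. As a hypothetical illustration of one common heuristic, not the author's implementation: after binarizing a cell so that ink is white, ignore a thin margin where leftover grid lines usually sit, and treat the cell as empty if too few of the remaining pixels are lit. The function name `looks_empty`, the 10% margin and the 3% fill threshold are all assumptions to tune on your own images.

```python
import numpy as np


def looks_empty(cell_bin, margin=0.1, min_fill=0.03):
    """Hypothetical empty-cell test for a binarized grid cell (ink = 255).

    margin:   fraction of the cell height/width ignored on each side,
              where leftover grid lines tend to show up
    min_fill: minimum fraction of lit pixels inside the margin for the
              cell to count as containing a digit
    """
    h, w = cell_bin.shape
    dy, dx = int(h * margin), int(w * margin)
    inner = cell_bin[dy:h - dy, dx:w - dx]
    fill = np.count_nonzero(inner) / inner.size
    return fill < min_fill


cell = np.zeros((44, 44), dtype=np.uint8)
cell[0, :] = 255            # a grid line along the top edge
print(looks_empty(cell))    # True: only the border is lit
cell[12:32, 18:26] = 255    # a blob in the middle, like a digit stroke
print(looks_empty(cell))    # False
```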

Code

Full code:
https://github.com/ssggle/keras_mnistnet

Reference

https://keras.io/examples/
http://yann.lecun.com/exdb/mnist/

Original: https://blog.csdn.net/y459541195/article/details/115052939
Author: 圆滚熊
Title: keras简单神经网络搭建并训练测试 (Building, Training, and Testing a Simple Neural Network with Keras)
