Python SVM手写数字识别

Python 基于sklearn – svm实现MNIST手写数字识别

一、数据集:MNIST

数据地址:http://yann.lecun.com/exdb/mnist/

训练数据:MNIST中的60000张图像,0-9的手写数字

测试数据:MNIST中的10000张图像,0-9的手写数字

注意:训练和测试代码直接使用了ubyte格式数据,即只对原数据进行了解压,没有先转换为png/jpg,但也附上png数据转换代码。

数据格式转换:从ubyte转换到png格式,存储格式:mnist_train>label>.png,代码如下:

提示:PIL不再支持新版本,要额外安装Pillow库

import numpy as np
import struct

from PIL import Image
import os

data_file = 'train-images.idx3-ubyte'
It's 47040016B, but we should set to 47040000B
data_file_size = 47040016
data_file_size = str(data_file_size - 16) + 'B'

data_buf = open(data_file, 'rb').read()

magic, numImages, numRows, numColumns = struct.unpack_from(
    '>IIII', data_buf, 0)
datas = struct.unpack_from(
    '>' + data_file_size, data_buf, struct.calcsize('>IIII'))
datas = np.array(datas).astype(np.uint8).reshape(
    numImages, 1, numRows, numColumns)

label_file = 'train-labels.idx1-ubyte'

It's 60008B, but we should set to 60000B
label_file_size = 60008
label_file_size = str(label_file_size - 8) + 'B'

label_buf = open(label_file, 'rb').read()

magic, numLabels = struct.unpack_from('>II', label_buf, 0)
labels = struct.unpack_from(
    '>' + label_file_size, label_buf, struct.calcsize('>II'))
labels = np.array(labels).astype(np.int64)

datas_root = 'mnist_train'
if not os.path.exists(datas_root):
    os.mkdir(datas_root)

for i in range(10):
    file_name = datas_root + os.sep + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

count = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
for ii in range(numLabels):
    img = Image.fromarray(datas[ii, 0, 0:28, 0:28])
    label = labels[ii]
    file_name = datas_root + os.sep + str(label) + os.sep + \
                str(label) + '_' + str(count[label]) + '.png'
    count[label] = count[label] + 1
    # file_name = datas_root + os.sep + str(label) + os.sep + \
    #             'mnist_train_' + str(ii) + '.png'
    img.save(file_name)

data_file = 't10k-images.idx3-ubyte'
It's 7840016B, but we should set to 7840000B
data_file_size = 7840016
data_file_size = str(data_file_size - 16) + 'B'

data_buf = open(data_file, 'rb').read()

magic, numImages, numRows, numColumns = struct.unpack_from(
    '>IIII', data_buf, 0)
datas = struct.unpack_from(
    '>' + data_file_size, data_buf, struct.calcsize('>IIII'))
datas = np.array(datas).astype(np.uint8).reshape(
    numImages, 1, numRows, numColumns)

label_file = 't10k-labels.idx1-ubyte'

It's 10008B, but we should set to 10000B
label_file_size = 10008
label_file_size = str(label_file_size - 8) + 'B'

label_buf = open(label_file, 'rb').read()

magic, numLabels = struct.unpack_from('>II', label_buf, 0)
labels = struct.unpack_from(
    '>' + label_file_size, label_buf, struct.calcsize('>II'))
labels = np.array(labels).astype(np.int64)

datas_root = 'mnist_test'
if not os.path.exists(datas_root):
    os.mkdir(datas_root)

for i in range(10):
    file_name = datas_root + os.sep + str(i)
    if not os.path.exists(file_name):
        os.mkdir(file_name)

count = np.array([1, 1, 1, 1, 1, 1, 1, 1, 1, 1])
for ii in range(numLabels):
    img = Image.fromarray(datas[ii, 0, 0:28, 0:28])
    label = labels[ii]
    file_name = datas_root + os.sep + str(label) + os.sep + \
                str(label) + '_' + str(count[label]) + '.png'
    count[label] = count[label] + 1
    # file_name = datas_root + os.sep + str(label) + os.sep + \
    #             'mnist_test_' + str(ii) + '.png'
    img.save(file_name)

转换后的数据如下图

Python SVM手写数字识别

二、训练模型

import numpy as np
import struct
import pickle
from sklearn import svm
###用于做数据预处理
from sklearn import preprocessing

##读取数据集
def load_mnist_train(labels_path, images_path):
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(labels), 784)
    return images, labels

if __name__ == '__main__':
    ##读取训练数据
    labels_path = "train-labels.idx1-ubyte"
    images_path = "train-images.idx3-ubyte"
    train_images, train_labels = load_mnist_train(labels_path, images_path)

    ##标准化
    X = preprocessing.StandardScaler().fit_transform(train_images)
    X_train = X[0:60000]
    y_train = train_labels[0:60000]

    ##定义并训练模型
    model_svc = svm.SVC()
    model_svc.fit(X_train, y_train)
    file = open("model.pickle", "wb")
    ##保存模型
    pickle.dump(model_svc, file)
    file.close()

三、测试模型

import numpy as np
import struct
import pickle
###用于做数据预处理
from sklearn import preprocessing

def test(images_path, labels_path, modelPath):
    # 读取测试图像
    with open(labels_path, 'rb') as lbpath:
        magic, n = struct.unpack('>II', lbpath.read(8))
        test_labels = np.fromfile(lbpath, dtype=np.uint8)
    with open(images_path, 'rb') as imgpath:
        magic, num, rows, cols = struct.unpack('>IIII', imgpath.read(16))
        test_images = np.fromfile(imgpath, dtype=np.uint8).reshape(len(test_labels), 784)

    ##读取模型
    file = open(modelPath, "rb")
    model_svc = pickle.load(file)
    file.close()

    ##评分并预测
    x = preprocessing.StandardScaler().fit_transform(test_images)
    x_test = x[0:10000]
    y_test = test_labels[0:10000]
    num = model_svc.predict(x_test)
    for i in range(10000):
        print("Real:", y_test[i], "Predict:", num[i])
    print("Accuracy:", model_svc.score(x_test, y_test))
    return num

if __name__ == '__main__':
    images_path = "t10k-images.idx3-ubyte"
    labels_path = "t10k-labels.idx1-ubyte"
    modelPath = "model.pickle"
    num = test(images_path, labels_path, modelPath)

四、参考资料

图片格式转换: MNIST数据集格式ubyte转png_haoji007的博客-CSDN博客_ubyte

模型训练及测试:图像处理基本库的学习笔记2–SVM,MATLAB,Tensorflow下分别对mnist数据集进行训练,并且进行预测 – 灰信网(软件开发博客聚合)

sklearn-svm模型参数设置:机器学习笔记(3)-sklearn支持向量机SVM – 简书

模型保存和调用: 基于sklearn的SVM模型保存与调用_hellosonny的博客-CSDN博客_svm保存模型

单个图片测试:基于svm机器学习的手写数字识别_Brinshy的博客-CSDN博客_基于svm的手写数字识别

Original: https://blog.csdn.net/weixin_43349279/article/details/124507662
Author: 跑路小饼
Title: Python SVM手写数字识别

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/623227/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球