Pest and Disease Classification with a CNN Based on TensorFlow 2.x (with GUI)

Walkthrough of pest and disease classification with ResNet-50 on TensorFlow 2.x (with GUI):

Walkthrough of pest and disease classification with a CNN on TensorFlow 2.x (with GUI): 基于Tensorflow2.x的CNN的病虫害分类(有界面)_天道酬勤者的博客-CSDN博客

Walkthrough of pest and disease classification with MobileNet on TensorFlow 2.x (with GUI):

基于Tensorflow2.x的MobileNet的病虫害分类(有界面)_songyang66的博客-CSDN博客

File download for the ResNet version (with GUI):

File download for the CNN version (with GUI): 基于Tensorflow2.x的CNN的病虫害分类(有界面)-深度学习文档类资源-CSDN文库

File download for the MobileNet version (with GUI): 基于Tensorflow2.x的mobilenet的病虫害分类(有界面)-深度学习文档类资源-CSDN文库


datasort.py renames the images in the dataset with sequential, zero-padded names. The code is as follows:

import os

class BatchRename():

    def __init__(self):
        self.path = "B:/BaiduNetdiskDownload/class/testing_data/Leaf-ulcer" #图片的路径

    def rename(self):
        filelist = os.listdir(self.path)
        filelist.sort()
        total_num = len(filelist)
        i = 0
        for item in filelist:
            if item.lower().endswith('.png'):  # keep the original filename casing for src; lowercasing it first would break on case-sensitive filesystems
                src = os.path.join(self.path, item)
                s = str(i)
                s = s.zfill(2)  # str.zfill() pads the string with leading zeros to the given width
                dst = os.path.join(os.path.abspath(self.path), s + '.png')

                '''
                    os.rename() renames a file or directory from src to dst;
                    if dst is an existing directory, an OSError is raised.
                    Syntax: os.rename(src, dst)
                    Parameters:
                        src -- the current name
                        dst -- the new name
                    Return value: none
                '''
                try:
                    os.rename(src, dst)
                    print ('converting %s to %s ...' % (src, dst))
                    i = i + 1
                except Exception as e:
                    print(e)
                    print('rename failed\r\n')

        print('total %d files, renamed %d pngs' % (total_num, i))

if __name__ == '__main__':
    demo = BatchRename()
    demo.rename()
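As a quick aside, the zero-padded names come from str.zfill(); a standalone sketch (not part of datasort.py) of what it generates:

for i in range(3):
    print(str(i).zfill(2) + '.png')  # prints 00.png, 01.png, 02.png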

Output of datasort.py:


data_progress.py splits the dataset into training, validation and test sets. The code is as follows:

import os
import random
from shutil import copy2

def data_set_split(initial_data_folder, target_data_folder, train_scale=0.8, val_scale=0.2, test_scale=0.0):
    '''
    Read the source data folder and create the split folders train, val and test.
    :param initial_data_folder: source folder, e.g. E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/initial_data
    :param target_data_folder: target folder, e.g. E:/biye/gogogo/note_book/torch_note/data/utils_test/data_split/target_data
    :param train_scale: proportion of the training set
    :param val_scale: proportion of the validation set
    :param test_scale: proportion of the test set
    :return:
    '''
    print("Starting dataset split")
    class_names = os.listdir(initial_data_folder)  # list of the class subfolders under initial_data_folder
    # create the split folders under the target directory
    split_names = ['train', 'val', 'test']
    for split_name in split_names:
        split_path = os.path.join(target_data_folder, split_name)
        if os.path.isdir(split_path):  # os.path.isdir() checks whether a path is a directory
            pass
        else:
            os.mkdir(split_path)  # os.mkdir() creates a single-level directory at the given path (os.rmdir() removes one)
        # then create one folder per class under split_path, e.g. Fruit-anthrax and the other four classes
        for class_name in class_names:
            class_split_path = os.path.join(split_path, class_name)
            if os.path.isdir(class_split_path):
                pass
            else:
                os.mkdir(class_split_path)

    # split each class by the given ratios and copy the images
    # iterate over the classes
    for class_name in class_names:
        current_class_data_path = os.path.join(initial_data_folder, class_name)
        current_all_data = os.listdir(current_class_data_path)  # list of the image files in current_class_data_path
        current_data_length = len(current_all_data)
        current_data_index_list = list(range(current_data_length))
        random.shuffle(current_data_index_list)

        train_folder = os.path.join(os.path.join(target_data_folder, 'train'), class_name)
        val_folder = os.path.join(os.path.join(target_data_folder, 'val'), class_name)
        test_folder = os.path.join(os.path.join(target_data_folder, 'test'), class_name)
        train_stop_flag = current_data_length * train_scale
        val_stop_flag = current_data_length * (train_scale + val_scale)
        current_idx = 0
        train_num = 0
        val_num = 0
        test_num = 0
        for i in current_data_index_list:
            initial_img_path = os.path.join(current_class_data_path, current_all_data[i])
            if current_idx <= train_stop_flag:
                copy2(initial_img_path, train_folder)
                train_num = train_num + 1
            elif (current_idx > train_stop_flag) and (current_idx <= val_stop_flag):
                copy2(initial_img_path, val_folder)
                val_num = val_num + 1
            else:
                copy2(initial_img_path, test_folder)
                test_num = test_num + 1
            current_idx = current_idx + 1

        print("*********************************{}*************************************".format(class_name))
        print("Class {} split with ratio {}:{}:{}, {} images in total".format(class_name, train_scale, val_scale, test_scale, current_data_length))
        print("train set {}: {} images".format(train_folder, train_num))
        print("val set {}: {} images".format(val_folder, val_num))
        print("test set {}: {} images".format(test_folder, test_num))

if __name__ == '__main__':
    initial_data_folder = "training_data"  # path to the original dataset
    target_data_folder = "split_data"  # path where the split data is written
    data_set_split(initial_data_folder, target_data_folder)
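For orientation, after running data_progress.py the target folder looks roughly like this (a sketch assuming the five class names used later in this article):

split_data/
├── train/
│   ├── Fruit-anthrax/
│   ├── Fruit-ulcer/
│   ├── Leaf-anthrax/
│   ├── Leaf-ulcer/
│   └── leaf_thyroid/
├── val/
│   └── (the same five class folders)
└── test/
    └── (the same five class folders)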

Output of data_progress.py:


train_cnn.py trains the CNN:

import tensorflow as tf
import matplotlib.pyplot as plt
from time import time

# Dataset loading: point to the dataset location, resize every image to img_height*img_width, and set the batch size
def data_load(data_dir, test_data_dir, img_height, img_width, batch_size):
    # load the training set
    # tf.keras.preprocessing.image_dataset_from_directory builds a tf.data.Dataset from the image files in a directory
    train_ds = tf.keras.preprocessing.image_dataset_from_directory(
        data_dir,  # data directory; with labels='inferred' (the default) it must contain one subdirectory per class, otherwise the directory structure is ignored
        # labels is omitted: 'inferred' generates labels from the directory structure; alternatively pass a list/tuple of integer labels, one per image file, ordered by the alphanumeric order of the image file paths (as obtained via os.walk(directory) in Python)
        label_mode='categorical',  # 'int': labels encoded as integers (e.g. for sparse_categorical_crossentropy); 'categorical': labels encoded as one-hot vectors (e.g. for categorical_crossentropy); 'binary': labels (two classes only) encoded as float32 scalars of 0 or 1 (e.g. for binary_crossentropy); None: no labels
        seed=123,  # optional random seed for shuffling and transformations
        image_size=(img_height, img_width),  # size to resize the images to after loading
        batch_size=batch_size)  # batch size; defaults to 32
    # load the validation set
    val_ds = tf.keras.preprocessing.image_dataset_from_directory(
        test_data_dir,
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    class_names = train_ds.class_names
    # return the processed training set, validation set and class names
    return train_ds, val_ds, class_names
# Notes:
# if label_mode is 'int', labels is an int32 tensor of shape (batch_size,)
# if label_mode is 'binary', labels is a float32 tensor of 1s and 0s of shape (batch_size, 1)
# if label_mode is 'categorical', labels is a float32 tensor of shape (batch_size, num_classes), a one-hot encoding of the class index
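# Optional sanity check (a sketch added here, not in the original script):
# one batch from train_ds should have images of shape
# (batch_size, img_height, img_width, 3) and one-hot labels of shape
# (batch_size, num_classes):
#   for images, labels in train_ds.take(1):
#       print(images.shape, labels.shape)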

# Build the CNN model
def model_load(IMG_SHAPE=(224, 224, 3), class_num=5):
    # assemble the model
    model = tf.keras.models.Sequential([
        # normalization: map pixel values from the 0-255 range into 0-1
        tf.keras.layers.experimental.preprocessing.Rescaling(1. / 255, input_shape=IMG_SHAPE),
        # convolutional layer: 32 output channels, 3*3 kernels, ReLU activation
        tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
        # max-pooling layer with a 2*2 pooling kernel
        tf.keras.layers.MaxPooling2D(2, 2),
        # Add another convolution
        # convolutional layer: 64 output channels, 3*3 kernels, ReLU activation
        tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
        # max-pooling layer over 2*2 regions
        tf.keras.layers.MaxPooling2D(2, 2),
        # flatten the two-dimensional feature maps into one dimension
        tf.keras.layers.Flatten(),
        # The same 128 dense layers, and 10 output layers as in the pre-convolution example:
        tf.keras.layers.Dense(128, activation='relu'),
        # output layer: one neuron per class, with softmax producing the class probabilities
        tf.keras.layers.Dense(class_num, activation='softmax')
    ])
    # print the model summary
    model.summary()
    # training configuration: SGD optimizer and cross-entropy loss
    model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
    # return the model
    return model
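# Shape walk-through for the default IMG_SHAPE=(224, 224, 3), added for reference:
#   Conv2D(32, 3x3)   -> (222, 222, 32)   'valid' padding trims 2 pixels per axis
#   MaxPooling2D(2,2) -> (111, 111, 32)
#   Conv2D(64, 3x3)   -> (109, 109, 64)
#   MaxPooling2D(2,2) -> (54, 54, 64)
#   Flatten           -> 54*54*64 = 186624 features
#   Dense(128) then Dense(class_num) map these to class probabilities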

# Plot the training curves
def show_loss_acc(history):
    # extract the training and validation accuracy and loss from history
    acc = history.history['accuracy']
    val_acc = history.history['val_accuracy']
    loss = history.history['loss']
    val_loss = history.history['val_loss']

    # lay the two plots out one above the other
    plt.figure(figsize=(8, 8))
    plt.subplot(2, 1, 1)
    plt.plot(acc, label='Training Accuracy')
    plt.plot(val_acc, label='Validation Accuracy')
    plt.legend(loc='lower right')
    plt.ylabel('Accuracy')
    plt.ylim([min(plt.ylim()), 1])
    plt.title('Training and Validation Accuracy')

    plt.subplot(2, 1, 2)
    plt.plot(loss, label='Training Loss')
    plt.plot(val_loss, label='Validation Loss')
    plt.legend(loc='upper right')
    plt.ylabel('Cross Entropy')
    plt.title('Training and Validation Loss')
    plt.xlabel('epoch')
    plt.savefig('results/results_cnn.png', dpi=100)
    plt.show()

def train(epochs):
    print("Starting training, recording the start time...")
    # record the start time
    begin_time = time()
    # todo load the dataset; change this to the path of your dataset
    print("Loading the dataset...")
    train_ds, val_ds, class_names = data_load("split_data/train",
                                              "split_data/val", 224, 224, 16)
    print(class_names)
    print("Loading the model...")
    # load the model
    model = model_load(class_num=len(class_names))
    # set the number of training epochs and start training
    print("Setting the number of epochs, training...")
    history = model.fit(train_ds, validation_data=val_ds, epochs=epochs)
    print("Saving the model...")
    # todo save the model; change this to the name you want
    model.save("results/cnn_orange.h5")
    print("Recording the end time...")
    # record the end time
    end_time = time()
    run_time = end_time - begin_time
    print('Training run time:', run_time, "s")  # e.g. 1.4201874732
    # plot the training curves
    show_loss_acc(history)

if __name__ == '__main__':
    train(epochs=30)
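Since the model is saved in HDF5 format, it can be reloaded later, exactly as model_test.py and design.py below do:

model = tf.keras.models.load_model("results/cnn_orange.h5")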

Training results of train_cnn.py:


model_test.py runs the trained CNN on the test set and reports its accuracy:

import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
# after testing, the accuracy of each model is printed to the console and the corresponding heatmap is written to the results directory
plt.rcParams['font.family'] = ['sans-serif']
plt.rcParams['font.sans-serif'] = ['SimHei']

# Data loading: build the test set from the test data folder
def data_load(test_data_dir, img_height, img_width, batch_size):
    # load the test set
    test_ds = tf.keras.preprocessing.image_dataset_from_directory(
        test_data_dir,
        label_mode='categorical',
        seed=123,
        image_size=(img_height, img_width),
        batch_size=batch_size)
    class_names = test_ds.class_names
    # return the processed test set and the class names
    return test_ds, class_names

# Evaluate the accuracy of the CNN model
def test_cnn():
    # todo load the data; change this to the path of your own dataset
    test_ds, class_names = data_load("B:\\class\\testing_data", 224, 224, 16)
    # todo load the model; change this to your model name
    model = tf.keras.models.load_model("results/cnn_orange.h5")
    # model.summary()
    # evaluate
    loss, accuracy = model.evaluate(test_ds)
    # print the result
    print('CNN test accuracy :', accuracy)

    # run inference with the model batch by batch
    test_real_labels = []
    test_pre_labels = []
    for test_batch_images, test_batch_labels in test_ds:
        test_batch_labels = test_batch_labels.numpy()
        test_batch_pres = model.predict(test_batch_images)
        # print(test_batch_pres)

        test_batch_labels_max = np.argmax(test_batch_labels, axis=1)
        test_batch_pres_max = np.argmax(test_batch_pres, axis=1)
        # print(test_batch_labels_max)
        # print(test_batch_pres_max)
        # collect the true and predicted labels
        for i in test_batch_labels_max:
            test_real_labels.append(i)

        for i in test_batch_pres_max:
            test_pre_labels.append(i)
        # break

    # print(test_real_labels)
    # print(test_pre_labels)
    class_names_length = len(class_names)
    heat_maps = np.zeros((class_names_length, class_names_length))
    for test_real_label, test_pre_label in zip(test_real_labels, test_pre_labels):
        heat_maps[test_real_label][test_pre_label] = heat_maps[test_real_label][test_pre_label] + 1

    print(heat_maps)
    heat_maps_sum = np.sum(heat_maps, axis=1).reshape(-1, 1)
    # print(heat_maps_sum)
    print()
    heat_maps_float = heat_maps / heat_maps_sum
    print(heat_maps_float)
    # title, x_labels, y_labels, harvest
    show_heatmaps(title="heatmap", x_labels=class_names, y_labels=class_names, harvest=heat_maps_float,
                  save_name="results/heatmap_cnn.png")

def show_heatmaps(title, x_labels, y_labels, harvest, save_name):
    # create a figure
    fig, ax = plt.subplots()
    # cmap https://blog.csdn.net/ztf312/article/details/102474190
    im = ax.imshow(harvest, cmap="OrRd")
    # set up the tick labels
    # We want to show all ticks...

    ax.set_xticks(np.arange(len(y_labels)))
    ax.set_yticks(np.arange(len(x_labels)))
    # ... and label them with the respective list entries
    ax.set_xticklabels(y_labels)
    ax.set_yticklabels(x_labels)

    # the x-axis labels are long, so rotate them for readability
    # Rotate the tick labels and set their alignment.

    plt.setp(ax.get_xticklabels(), rotation=45, ha="right",
             rotation_mode="anchor")

    # write the numeric value of each cell on the heatmap
    # Loop over data dimensions and create text annotations.

    for i in range(len(x_labels)):
        for j in range(len(y_labels)):
            text = ax.text(j, i, round(harvest[i, j], 2),
                           ha="center", va="center", color="black")
    ax.set_xlabel("Predict label")
    ax.set_ylabel("Actual label")
    ax.set_title(title)
    fig.tight_layout()
    plt.colorbar(im)
    plt.savefig(save_name, dpi=100)
    plt.show()

if __name__ == '__main__':
    test_cnn()

Output of model_test.py:


design.py implements the test GUI:

import tensorflow as tf
from PyQt5.QtGui import *
from PyQt5.QtCore import *
from PyQt5.QtWidgets import *
import sys
import cv2
from PIL import Image
import numpy as np
import shutil

class MainWindow(QTabWidget):
    # initialization
    def __init__(self):
        super().__init__()
        self.setWindowIcon(QIcon('images/logo.png'))
        self.setWindowTitle('Field Citrus Pest and Disease Recognition System')  # todo change the system name
        # model initialization
        self.model = tf.keras.models.load_model("results/mobilenet_orange.h5")  # todo change the model name: cnn_orange.h5, mobilenet_orange.h5 or resnet_orange.h5
        self.to_predict_name = "images/background.jpg"  # todo change the initial image; it must live in the images directory
        self.class_names = ['Fruit-anthrax', 'Fruit-ulcer', 'Leaf-anthrax', 'Leaf-ulcer', 'leaf_thyroid']  # todo change the class names; this list is printed at the start of training
        self.resize(900, 700)
        self.initUI()

    # initialize the UI and set up the layout
    def initUI(self):
        main_widget = QWidget()
        main_layout = QHBoxLayout()
        font = QFont('楷体', 18)

        # main page: create the widgets and place them in the layout
        left_widget = QWidget()
        left_layout = QVBoxLayout()
        img_title = QLabel("Sample")
        img_title.setFont(font)
        img_title.setAlignment(Qt.AlignCenter)
        self.img_label = QLabel()
        img_init = cv2.imread(self.to_predict_name)
        h, w, c = img_init.shape
        scale = 400 / h
        img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)
        cv2.imwrite("images/show.png", img_show)
        img_init = cv2.resize(img_init, (224, 224))
        cv2.imwrite('images/target.png', img_init)
        self.img_label.setPixmap(QPixmap("images/show.png"))
        left_layout.addWidget(img_title)
        left_layout.addWidget(self.img_label, 1, Qt.AlignCenter)
        left_widget.setLayout(left_layout)
        right_widget = QWidget()
        right_layout = QVBoxLayout()
        btn_change = QPushButton(" Upload Image ")
        btn_change.clicked.connect(self.change_img)
        btn_change.setFont(font)
        btn_predict = QPushButton(" Recognize ")
        btn_predict.setFont(font)
        btn_predict.clicked.connect(self.predict_img)
        label_result = QLabel(' Citrus Pest/Disease Name ')
        self.result = QLabel("Waiting for recognition")
        label_result.setFont(QFont('楷体', 16))
        self.result.setFont(QFont('楷体', 24))
        right_layout.addStretch()
        right_layout.addWidget(label_result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addWidget(self.result, 0, Qt.AlignCenter)
        right_layout.addStretch()
        right_layout.addStretch()
        right_layout.addWidget(btn_change)
        right_layout.addWidget(btn_predict)
        right_layout.addStretch()
        right_widget.setLayout(right_layout)
        main_layout.addWidget(left_widget)
        main_layout.addWidget(right_widget)
        main_widget.setLayout(main_layout)

        # "About" page: create the widgets and place them in the layout
        about_widget = QWidget()
        about_layout = QVBoxLayout()
        about_title = QLabel('Welcome to the Citrus Pest and Disease Recognition System')  # todo change the welcome message
        about_title.setFont(QFont('楷体', 18))
        about_title.setAlignment(Qt.AlignCenter)
        about_img = QLabel()
        about_img.setPixmap(QPixmap('images/bj.jpg'))
        about_img.setAlignment(Qt.AlignCenter)
        label_super = QLabel("Author: 宋扬")  # todo change the author info
        label_super.setFont(QFont('楷体', 15))
        # label_super.setOpenExternalLinks(True)
        label_super.setAlignment(Qt.AlignRight)
        about_layout.addWidget(about_title)
        about_layout.addStretch()
        about_layout.addWidget(about_img)
        about_layout.addStretch()
        about_layout.addWidget(label_super)
        about_widget.setLayout(about_layout)

        # add the tabs
        self.addTab(main_widget, 'Home')
        self.addTab(about_widget, 'About')
        self.setTabIcon(0, QIcon('images/主页面.png'))
        self.setTabIcon(1, QIcon('images/关于.png'))

    # upload and display an image
    def change_img(self):
        openfile_name = QFileDialog.getOpenFileName(self, 'choose file', '',
                                                    'Image files(*.jpg *.png *jpeg)')  # open a file dialog to pick an image
        img_name = openfile_name[0]  # get the image path
        if img_name == '':
            pass
        else:
            target_image_name = "images/tmp_up." + img_name.split(".")[-1]  # &#x5C06;&#x56FE;&#x7247;&#x79FB;&#x52A8;&#x5230;&#x5F53;&#x524D;&#x76EE;&#x5F55;
            shutil.copy(img_name, target_image_name)
            self.to_predict_name = target_image_name
            img_init = cv2.imread(self.to_predict_name)  # &#x6253;&#x5F00;&#x56FE;&#x7247;
            h, w, c = img_init.shape
            scale = 400 / h
            img_show = cv2.resize(img_init, (0, 0), fx=scale, fy=scale)  # &#x5C06;&#x56FE;&#x7247;&#x7684;&#x5927;&#x5C0F;&#x7EDF;&#x4E00;&#x8C03;&#x6574;&#x5230;400&#x7684;&#x9AD8;&#xFF0C;&#x65B9;&#x4FBF;&#x754C;&#x9762;&#x663E;&#x793A;
            cv2.imwrite("images/show.png", img_show)
            img_init = cv2.resize(img_init, (224, 224))  # &#x5C06;&#x56FE;&#x7247;&#x5927;&#x5C0F;&#x8C03;&#x6574;&#x5230;224*224&#x7528;&#x4E8E;&#x6A21;&#x578B;&#x63A8;&#x7406;
            cv2.imwrite('images/target.png', img_init)
            self.img_label.setPixmap(QPixmap("images/show.png"))
            self.result.setText("&#x7B49;&#x5F85;&#x8BC6;&#x522B;")

    # predict the image
    def predict_img(self):
        img = Image.open('images/target.png')  # read the image
        img = np.asarray(img)  # convert the image to a numpy array
        outputs = self.model.predict(img.reshape(1, 224, 224, 3))  # add a batch dimension and feed the image to the model
        result_index = int(np.argmax(outputs))
        result = self.class_names[result_index]  # look up the corresponding class name
        self.result.setText(result)  # display the result in the UI

    # window close event: ask the user whether to quit
    def closeEvent(self, event):
        reply = QMessageBox.question(self,
                                     'Quit',
                                     "Do you want to quit the program?",
                                     QMessageBox.Yes | QMessageBox.No,
                                     QMessageBox.No)
        if reply == QMessageBox.Yes:
            self.close()
            event.accept()
        else:
            event.ignore()

if __name__ == "__main__":
    app = QApplication(sys.argv)
    x = MainWindow()
    x.show()
    sys.exit(app.exec_())
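Note that predict_img feeds raw 0-255 pixel values to the network; for the CNN trained above that is correct because the Rescaling(1./255) layer is saved as part of the model. For completeness, a minimal headless prediction sketch (the same steps as predict_img, assuming a prepared images/target.png):

import numpy as np
import tensorflow as tf
from PIL import Image

# mirror predict_img without the GUI
model = tf.keras.models.load_model("results/cnn_orange.h5")
img = Image.open("images/target.png").convert("RGB").resize((224, 224))
x = np.asarray(img).reshape(1, 224, 224, 3)  # add the batch dimension
pred = model.predict(x)
print(int(np.argmax(pred)))  # index into the class_names list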

The design.py interface:


Original: https://blog.csdn.net/songyang66/article/details/124106875
Author: 天道酬勤者
Title: 基于Tensorflow2.x的CNN的病虫害分类(有界面)
