tensorflow导入自己的数据集

在构建tensorflow模型过程中,可谓是曲折颇多,一些教程上教会了我们如何使用下载的现成数据集,但却没有提及如何构建自己的数据集。我自己在学习过程中也走了不少弯路,希望这一系列的博客能解决大家的一些困惑。

我们通过以下步骤在本地构建数据集

[En]

We build the dataset locally in the following steps

1.数据处理
2.数据增强
3.数据导入
4.构建模型
5.训练模型

本文首先介绍了数据处理的一些操作,以下几个步骤将慢慢展开。

[En]

This article first talks about some operations of data processing, and the following steps will be sent out slowly.

1.导入第三方库

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import math
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np

这里会注意到,我在导入os库时,在后面加了

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'

这句话的作用是避免报错:This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN)

2.导入数据路径

data_root = pathlib.Path('./image')
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]

我这里的./image是我本地图片集所在的文件夹,image文件夹下是两个分别保存不同种类图片的文件夹,因为我这里是做二分类,所以只有两个不同种类的文件夹,如果大家需要构建识别多种图片的模型,可以添加其他文件夹。

tensorflow导入自己的数据集

3.随机打乱图片,这一步的目的是为了让图片集去特殊化,提高模型的准确率,因为如果你的图片中有比较相近的,而且数量比较多,会影响模型的学习。这一步是调用了random的shuffle,传入图片集列表,随机打乱。

random.shuffle(all_image_paths)

4.构建标签及索引

其实是构建了一个字典

#列出标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#创建列表,存放标签和索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]

5.加载和格式化图片

我们可以看到,tf.image.decode_jpeg(image,channels=3,这句话的作用是把图片变成三通道图,即RGB式图片。需要强调一下,tf.image.resize()这个小东西好用的很,可以把你的图片统一大小,这在后面我们训练模型是必须的,统一大小的图片更有利于我们的模型学习。而image/255.0是为了使图像进行归一化,得到的数值范围为[0, 1],彩色图片会变成灰图。

load_and_prepro_image()这个函数就是读取传入路径的图片集,然后返回值是经过了preprocess_image 这个函数的调用,将返回的图片处理为灰度图,比较简单暴力。

#加载和格式化图片
def preprocess_image(image):
    image = tf.image.decode_jpeg(image,channels=3)
    image = tf.image.resize(image,[192,192])
    image /= 255.0
    return image

def load_and_prepro_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
for i in range(len(all_image_paths)):
    image_path = all_image_paths[i]
    label = all_image_labels[i]
    plt.imshow(load_and_prepro_image(image_path))
    plt.grid(False)
    plt.xlabel(image_path)
    plt.title(label_names[label].title())
    #plt.show()

然后关于这个for循环,其实不是必须的,只是为了方便我们检查图片的处理效果,调用的库是matplotlib,python比较有名的绘图库。

就先到这,后会有期。

下面是全部源码,tensorflow版本是2.5,py版本3.7,cuda11.6。

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import datasets, layers, optimizers, Sequential, metrics
from tensorflow.python.framework.convert_to_constants import convert_variables_to_constants_v2
import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2'
import math
import pathlib
import random
import matplotlib.pyplot as plt
import numpy as np

#数据处理
#导入数据路径
data_root = pathlib.Path('./image')
all_image_paths = list(data_root.glob('*/*'))
all_image_paths = [str(path) for path in all_image_paths]
#随机打乱图片
random.shuffle(all_image_paths)
#列出标签
label_names = sorted(item.name for item in data_root.glob('*/') if item.is_dir())
#为标签分配索引
label_to_index = dict((name, index) for index, name in enumerate(label_names))
#创建列表,存放标签和索引
all_image_labels = [label_to_index[pathlib.Path(path).parent.name]
                    for path in all_image_paths]
#加载和格式化图片
def preprocess_image(image):
    image = tf.image.decode_jpeg(image,channels=3)
    image = tf.image.resize(image,[192,192])
    image = image/255.0
    return image

def load_and_prepro_image(path):
    image = tf.io.read_file(path)
    return preprocess_image(image)
for i in range(len(all_image_paths)):
    image_path = all_image_paths[i]
    label = all_image_labels[i]
    plt.imshow(load_and_prepro_image(image_path))
    plt.grid(False)
    plt.xlabel(image_path)
    plt.title(label_names[label].title())
    #plt.show()

Original: https://blog.csdn.net/m0_58775709/article/details/123884294
Author: 冯简
Title: tensorflow导入自己的数据集

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/496707/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球