一、代码中使用的数据集可以通过原文中提供的链接下载（原文地址见文末）
二、代码运行环境
tensorflow-gpu==2.4.0
Python==3.7
三、数据集处理的代码（data_loader.py）如下所示
import tensorflow as tf
import os
import glob
import random
# TensorFlow runtime switches: enable XLA devices and let GPU memory grow on
# demand instead of reserving the whole card up front.
os.environ.update({
    'TF_XLA_FLAGS': '--tf_xla_enable_xla_devices',
    'TF_FORCE_GPU_ALLOW_GROWTH': 'true',
})

# Class-folder name -> integer label used for training targets.
label_to_index = {
    'airplane': 0,
    'lake': 1,
}

# Inverse lookup: integer label -> class-folder name.
index_to_label = {index: name for name, index in label_to_index.items()}
def load_img(path):
    """Load one JPEG file and return it as a float32 image tensor.

    The file is decoded with 3 colour channels, resized to 256x256 and
    divided by 255.

    Args:
        path: file path (Python string or scalar string tensor).

    Returns:
        A float32 tensor of shape (256, 256, 3).
    """
    raw_bytes = tf.io.read_file(path)
    pixels = tf.image.decode_jpeg(raw_bytes, channels=3)
    resized = tf.cast(tf.image.resize(pixels, [256, 256]), tf.float32)
    return resized / 255
def load_dataset(pattern=r'dataset/*/*.jpg', batch_size=16, test_ratio=0.2):
    """Build train/test ``tf.data`` pipelines from a class-per-folder layout.

    Image files are found with ``glob.glob(pattern)``. Each image's label is
    the name of its parent directory, looked up in ``label_to_index``. Files
    whose parent directory is not a known class are skipped.

    Args:
        pattern: glob pattern matching the image files.
        batch_size: number of images per batch, for both splits.
        test_ratio: fraction of the images (rounded down) held out for testing.

    Returns:
        ``(train_dataset, test_dataset, train_count, test_count)``.
        ``train_dataset`` repeats forever, so callers must bound each epoch
        (e.g. with ``steps_per_epoch``). ``test_dataset`` is finite.

    Raises:
        FileNotFoundError: if no file under a known class directory matches
            ``pattern``.
    """
    # Take the class from the parent directory via os.path so this works with
    # both POSIX '/' and Windows '\\' separators. The previous
    # split('\\')[1] only worked on Windows.
    labelled = []
    for path in glob.glob(pattern):
        label = label_to_index.get(os.path.basename(os.path.dirname(path)))
        if label is not None:
            labelled.append((path, label))
    if not labelled:
        raise FileNotFoundError(f'No images of a known class match {pattern!r}')

    # Shuffle once so that the take/skip split below is a random hold-out.
    random.shuffle(labelled)
    all_image_path = [path for path, _ in labelled]
    all_image_labels = [label for _, label in labelled]

    img_dataset = tf.data.Dataset.from_tensor_slices(all_image_path).map(load_img)
    label_dataset = tf.data.Dataset.from_tensor_slices(all_image_labels)
    dataset = tf.data.Dataset.zip((img_dataset, label_dataset))

    image_count = len(all_image_path)
    test_count = int(image_count * test_ratio)
    train_count = image_count - test_count
    # take() and skip() on the same ordered dataset give disjoint splits.
    train_dataset = dataset.skip(test_count)
    test_dataset = dataset.take(test_count)

    # shuffle() is applied after repeat(), so its 100-element buffer can mix
    # samples across epoch boundaries.
    train_dataset = train_dataset.repeat().shuffle(100).batch(batch_size)
    test_dataset = test_dataset.batch(batch_size)
    return train_dataset, test_dataset, train_count, test_count
if __name__ == '__main__':
    # Smoke run: build both pipelines and print the dataset objects.
    train, test, *_counts = load_dataset()
    for pipeline in (train, test):
        print(pipeline)
四、模型的构建代码（model.py）如下所示
import tensorflow as tf
import os
# Enable XLA devices and on-demand GPU memory growth for TensorFlow.
os.environ.update(
    TF_XLA_FLAGS='--tf_xla_enable_xla_devices',
    TF_FORCE_GPU_ALLOW_GROWTH='true',
)
def make_model():
    """Build an uncompiled VGG-style binary classifier.

    Input is (256, 256, 3). Four stages of 3x3 'same' ReLU convolutions, each
    followed by max pooling, feed three further 512-filter convolutions,
    global average pooling, two ReLU dense layers (1024, 256) and a single
    sigmoid output unit.

    Returns:
        A ``tf.keras.Sequential`` model (not compiled).
    """
    layers = tf.keras.layers
    conv_kwargs = {'activation': 'relu', 'padding': 'same'}
    model = tf.keras.Sequential()

    # The first convolution also declares the input shape.
    model.add(layers.Conv2D(64, (3, 3), input_shape=(256, 256, 3), **conv_kwargs))

    # (filters, number of additional convs) per stage; every stage ends in pooling.
    for filters, n_convs in ((64, 1), (128, 2), (256, 2), (512, 2)):
        for _ in range(n_convs):
            model.add(layers.Conv2D(filters, (3, 3), **conv_kwargs))
        model.add(layers.MaxPooling2D())

    # Final convolution stage, without pooling.
    for _ in range(3):
        model.add(layers.Conv2D(512, (3, 3), **conv_kwargs))

    model.add(layers.GlobalAveragePooling2D())
    for units in (1024, 256):
        model.add(layers.Dense(units, activation='relu'))
    model.add(layers.Dense(1, activation='sigmoid'))
    return model
if __name__ == '__main__':
    # Smoke run: building the model validates the layer stack. Print the
    # summary so the run produces visible output (layer shapes and parameter
    # counts) instead of silently discarding the model.
    my_model = make_model()
    my_model.summary()
五、模型的训练代码如下所示
import tensorflow as tf
import os
from data_loader import load_dataset
from model import make_model
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Must match the batch size used inside data_loader.load_dataset() (16).
BATCH_SIZE = 16
MODEL_PATH = 'model_data/class_model.h5'

train_data, test_data, train_count, test_count = load_dataset()
model = make_model()
model.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.0001),
    loss=tf.keras.losses.BinaryCrossentropy(),
    metrics=['acc'],
)

# train_data repeats forever, so Keras needs an explicit epoch length. Keep at
# least one step so that datasets smaller than one batch still train.
steps_per_epoch = max(1, train_count // BATCH_SIZE)

# Create the output directory before the long training run so that saving
# cannot fail at the very end.
os.makedirs(os.path.dirname(MODEL_PATH), exist_ok=True)

# test_data is finite (not repeated), so leaving validation_steps unset
# evaluates every test image. The old `test_count // 16` skipped the final
# partial batch. `workers` was dropped: Keras only uses it for generator /
# Sequence inputs, not tf.data datasets.
history = model.fit(
    train_data,
    epochs=100,
    steps_per_epoch=steps_per_epoch,
    validation_data=test_data if test_count else None,
)
model.save(MODEL_PATH)
六、模型的预测代码如下所示
import tensorflow as tf
import os
from data_loader import load_img, index_to_label
import matplotlib.pyplot as plt
os.environ['TF_XLA_FLAGS'] = '--tf_xla_enable_xla_devices'
os.environ['TF_FORCE_GPU_ALLOW_GROWTH'] = 'true'

# Model file written by the training script.
model = tf.keras.models.load_model(r'model_data/class_model.h5')

# Interactive loop: classify one image per entered path. End of input
# (Ctrl+D / Ctrl+Z) or Ctrl+C leaves the loop cleanly.
while True:
    try:
        path = input('请输入检测图片的路径:\n').strip()
    except (EOFError, KeyboardInterrupt):
        break
    try:
        img = load_img(path)
    except (tf.errors.OpError, ValueError):
        # Missing, unreadable or undecodable files raise tf.errors.OpError
        # subclasses (e.g. NotFoundError, InvalidArgumentError). ValueError
        # covers paths that cannot be encoded. The former bare `except:` also
        # swallowed KeyboardInterrupt/SystemExit and real programming errors.
        print('图片路径输入错误!请重新输入正确文件路径!')
        continue
    plt.imshow(img.numpy())
    # Add a leading batch dimension before calling predict().
    batch = tf.expand_dims(img, axis=0)
    prob = model.predict(batch)[0][0]
    # Threshold the single sigmoid output: > 0.5 -> index 1, otherwise 0.
    result = index_to_label.get(int(prob > 0.5))
    plt.title(result)
    plt.show()
七、代码的运行结果请参见原文中的截图
八、代码的整体工程文件下载链接请参见原文。
Original: https://blog.csdn.net/qq_44961869/article/details/122327273
Author: 水哥很水
Title: Tensorflow—使用Tensorflow进行机场与湖泊的二分类
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/664798/
转载文章受原作者版权保护。转载请注明原作者出处!