LeNet-5-实现-cifar2

# LeNet-5 实现 cifar2 二分类（无注释的源代码见本文最下方）

import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses, Model

1)定义一个有参有返回值的函数用于加载图片

def load_img(file_path):
    """Load one CIFAR-2 JPEG and derive its binary label from the file path.

    Args:
        file_path: scalar string tensor holding an image path, as yielded by
            ``tf.data.Dataset.list_files``.

    Returns:
        ``(img, label)``: ``img`` is a float tensor of shape ``(32, 32, 3)``
        scaled by 1/255, and ``label`` is an int32 scalar that is 1 when the
        path contains "airplane" and 0 otherwise.
    """
    img = tf.io.read_file(file_path)
    # Force 3 channels so every element has a fully known, consistent shape
    # that matches the model's (None, 32, 32, 3) input; without it the
    # channel dimension is unknown and a grayscale JPEG would decode to 1.
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [32, 32]) / 255.

    # A pure tensor op instead of a Python `if` on a tensor, so the label does
    # not depend on AutoGraph rewriting the conditional inside Dataset.map.
    label = tf.cast(tf.strings.regex_full_match(file_path, '.*airplane.*'), tf.int32)
    return img, label

2)合理定义相关参数

# Training hyper-parameters: images per batch and passes over the training set.
batch_num, epochs = 100, 15

3）使用 tf.data 数据管道和自定义函数加载 cifar2 数据集

# Training pipeline: decode images in parallel, shuffle decoded images with a
# 1000-element buffer, batch by `batch_num`, and prefetch so input
# preparation overlaps with training steps.
train_data = (
    tf.data.Dataset.list_files('cifar2/train/*/*.jpg')
    .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=1000)
    .batch(batch_num)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Evaluation pipeline: element order does not change validation metrics, so
# disable file-order shuffling and skip the shuffle buffer entirely.
test_data = (
    tf.data.Dataset.list_files('cifar2/test/*/*.jpg', shuffle=False)
    .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_num)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

②模型搭建

class LeNet(Model):
    """LeNet-5-style CNN for two-class classification of 32x32 RGB images.

    Layer stack: Conv2D(6, 5x5) -> max-pool -> Conv2D(16, 5x5) -> max-pool ->
    flatten -> Dense(120, relu) -> Dense(84, relu) -> Dense(2, softmax).
    """

    def __init__(self):
        # Layers are created in the constructor; it must be named `__init__`
        # for Python (and Keras) to actually call it.
        super().__init__()
        self.c1 = layers.Conv2D(6, 5)
        self.s2 = layers.MaxPooling2D()
        self.c3 = layers.Conv2D(16, 5)
        self.s4 = layers.MaxPooling2D()
        self.f5 = layers.Flatten()
        self.d6 = layers.Dense(120, activation='relu')
        self.d7 = layers.Dense(84, activation='relu')
        self.d8 = layers.Dense(2, activation='softmax')

    # 3) Forward pass
    @tf.function
    def call(self, inputs):
        """Run the forward pass and return the 2-way softmax probabilities."""
        out = self.c1(inputs)
        out = self.s2(out)
        out = self.c3(out)
        out = self.s4(out)
        out = self.f5(out)
        out = self.d6(out)
        out = self.d7(out)
        out = self.d8(out)
        return out

③模型训练与评估

# Instantiate the LeNet network defined above.
model = LeNet()

1)查看每层的参数数量

# Build with an explicit input signature (variable batch size, 32x32 RGB
# images) so summary() can report per-layer parameter counts.
model.build(input_shape=(None, 32, 32, 3))
model.summary()

2)进行训练,选择Adam优化器,合理选择损失函数和迭代次数

# Sparse categorical cross-entropy pairs the integer 0/1 labels produced by
# load_img with the model's 2-unit softmax output; Keras documents `metrics`
# as a list, so pass one explicitly.
model.compile(optimizer='adam',
              loss=losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
history = model.fit(train_data, epochs=epochs, validation_data=test_data)

3)绘制训练集与测试集准确率对比图

# Compare per-epoch accuracy on the test (validation) and training sets.
plt.plot(history.history['val_accuracy'], label='test_accuracy')
plt.plot(history.history['accuracy'], label='train_accuracy')
plt.legend()
plt.show()

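'''
Model: "le_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
conv2d (Conv2D)              multiple                  456
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  2416
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0
_________________________________________________________________
flatten (Flatten)            multiple                  0
_________________________________________________________________
dense (Dense)                multiple                  48120
_________________________________________________________________
dense_1 (Dense)              multiple                  10164
_________________________________________________________________
dense_2 (Dense)              multiple                  170
=================================================================
Total params: 61,326
Trainable params: 61,326
Non-trainable params: 0
_________________________________________________________________
Epoch 1/15
100/100 [==============================] - 4s 40ms/step - loss: 0.4533 - accuracy: 0.7902 - val_loss: 0.3372 - val_accuracy: 0.8515
Epoch 2/15
100/100 [==============================] - 4s 40ms/step - loss: 0.3327 - accuracy: 0.8545 - val_loss: 0.2771 - val_accuracy: 0.8810
Epoch 3/15
100/100 [==============================] - 4s 43ms/step - loss: 0.2565 - accuracy: 0.8940 - val_loss: 0.2434 - val_accuracy: 0.9025
Epoch 4/15
100/100 [==============================] - 5s 45ms/step - loss: 0.2159 - accuracy: 0.9110 - val_loss: 0.2283 - val_accuracy: 0.9115
Epoch 5/15
100/100 [==============================] - 5s 46ms/step - loss: 0.1786 - accuracy: 0.9289 - val_loss: 0.2228 - val_accuracy: 0.9030
Epoch 6/15
100/100 [==============================] - 4s 45ms/step - loss: 0.1574 - accuracy: 0.9384 - val_loss: 0.2079 - val_accuracy: 0.9175
Epoch 7/15
100/100 [==============================] - 4s 45ms/step - loss: 0.1290 - accuracy: 0.9529 - val_loss: 0.2092 - val_accuracy: 0.9205
Epoch 8/15
100/100 [==============================] - 4s 41ms/step - loss: 0.1022 - accuracy: 0.9603 - val_loss: 0.2297 - val_accuracy: 0.9095
Epoch 9/15
100/100 [==============================] - 4s 43ms/step - loss: 0.0907 - accuracy: 0.9671 - val_loss: 0.2313 - val_accuracy: 0.9200
Epoch 10/15
100/100 [==============================] - 4s 44ms/step - loss: 0.0670 - accuracy: 0.9744 - val_loss: 0.2353 - val_accuracy: 0.9230
Epoch 11/15
100/100 [==============================] - 4s 40ms/step - loss: 0.0501 - accuracy: 0.9817 - val_loss: 0.2627 - val_accuracy: 0.9160
Epoch 12/15
100/100 [==============================] - 4s 39ms/step - loss: 0.0366 - accuracy: 0.9888 - val_loss: 0.2789 - val_accuracy: 0.9250
Epoch 13/15
100/100 [==============================] - 4s 39ms/step - loss: 0.0293 - accuracy: 0.9901 - val_loss: 0.2958 - val_accuracy: 0.9115
Epoch 14/15
100/100 [==============================] - 4s 39ms/step - loss: 0.0335 - accuracy: 0.9860 - val_loss: 0.3240 - val_accuracy: 0.9090
Epoch 15/15
100/100 [==============================] - 4s 40ms/step - loss: 0.0201 - accuracy: 0.9939 - val_loss: 0.3261 - val_accuracy: 0.9235

Process finished with exit code 0
'''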

LeNet-5-实现-cifar2
import tensorflow as tf
import matplotlib.pyplot as plt
from tensorflow.keras import layers, losses, Model
def load_img(file_path):
    """Load one CIFAR-2 JPEG and derive its binary label from the file path.

    Args:
        file_path: scalar string tensor holding an image path, as yielded by
            ``tf.data.Dataset.list_files``.

    Returns:
        ``(img, label)``: ``img`` is a float tensor of shape ``(32, 32, 3)``
        scaled by 1/255, and ``label`` is an int32 scalar that is 1 when the
        path contains "airplane" and 0 otherwise.
    """
    img = tf.io.read_file(file_path)
    # Force 3 channels so every element has a fully known, consistent shape
    # that matches the model's (None, 32, 32, 3) input; without it the
    # channel dimension is unknown and a grayscale JPEG would decode to 1.
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, [32, 32]) / 255.

    # A pure tensor op instead of a Python `if` on a tensor, so the label does
    # not depend on AutoGraph rewriting the conditional inside Dataset.map.
    label = tf.cast(tf.strings.regex_full_match(file_path, '.*airplane.*'), tf.int32)
    return img, label
# Training hyper-parameters: images per batch and passes over the training set.
batch_num, epochs = 100, 15
# Training pipeline: decode images in parallel, shuffle decoded images with a
# 1000-element buffer, batch by `batch_num`, and prefetch so input
# preparation overlaps with training steps. Parentheses replace the invalid
# `./` line continuations.
train_data = (
    tf.data.Dataset.list_files('cifar2/train/*/*.jpg')
    .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .shuffle(buffer_size=1000)
    .batch(batch_num)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
# Evaluation pipeline: element order does not change validation metrics, so
# disable file-order shuffling and skip the shuffle buffer entirely.
test_data = (
    tf.data.Dataset.list_files('cifar2/test/*/*.jpg', shuffle=False)
    .map(load_img, num_parallel_calls=tf.data.experimental.AUTOTUNE)
    .batch(batch_num)
    .prefetch(tf.data.experimental.AUTOTUNE)
)

class LeNet(Model):
    """LeNet-5 style network: two conv/pool stages followed by three dense layers.

    The final Dense(2, softmax) layer yields per-class probabilities for the
    two CIFAR-2 classes.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 5x5 convolutions with 6 and then 16 filters, each
        # followed by max pooling.
        self.c1 = layers.Conv2D(6, 5)
        self.s2 = layers.MaxPooling2D()
        self.c3 = layers.Conv2D(16, 5)
        self.s4 = layers.MaxPooling2D()
        # Classifier head: flatten, two ReLU hidden layers, softmax output.
        self.f5 = layers.Flatten()
        self.d6 = layers.Dense(120, activation='relu')
        self.d7 = layers.Dense(84, activation='relu')
        self.d8 = layers.Dense(2, activation='softmax')

    @tf.function
    def call(self, inputs):
        """Feed `inputs` through every layer in order and return the output."""
        x = inputs
        # A loop over a Python tuple is unrolled during tracing, giving the
        # same graph as calling each layer explicitly.
        for stage in (self.c1, self.s2, self.c3, self.s4,
                      self.f5, self.d6, self.d7, self.d8):
            x = stage(x)
        return x

model = LeNet()
# Build with an explicit input signature (variable batch size, 32x32 RGB
# images) so summary() can report per-layer parameter counts.
model.build(input_shape=(None, 32, 32, 3))
model.summary()

# Sparse categorical cross-entropy pairs the integer 0/1 labels produced by
# load_img with the model's 2-unit softmax output; Keras documents `metrics`
# as a list, so pass one explicitly.
model.compile(optimizer='adam',
              loss=losses.sparse_categorical_crossentropy,
              metrics=['accuracy'])
history = model.fit(train_data, epochs=epochs, validation_data=test_data)

# Overlay per-epoch test (validation) and training accuracy; plotting order is
# kept so each curve gets the same default colour as before.
for metric_key, curve_label in (('val_accuracy', 'test_accuracy'),
                                ('accuracy', 'train_accuracy')):
    plt.plot(history.history[metric_key], label=curve_label)
plt.legend()
plt.show()

Original: https://blog.csdn.net/Refuse_to_fail/article/details/121363364
Author: 狂风后的平静
Title: LeNet-5-实现-cifar2

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/514502/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球