Tensorflow2.0学习-保存和加载模型 (五)

文章目录

保存加载模型

训练完模型之后,需要保存的,要不每次想测试的时候,都要走一遍训练,多麻烦呀。所以就需要保存以及加载。而且,有时候,模型跑着跑着就断了,还可以接着训练。保存加载模型

引包

以前一般是保存成 .h5格式的文件,现在还有一种是 SaveModel

import os

import tensorflow as tf
from tensorflow import keras

print(tf.version.VERSION)

数据准备

MNIST 数据集进行演示,数据量就选1000个,然后让数值在0-1之间。

(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_labels = train_labels[:1000]
test_labels = test_labels[:1000]

train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0

模型准备

一个简简单单的两层全连接模型。


def create_model():
  model = tf.keras.models.Sequential([
    keras.layers.Dense(512, activation='relu', input_shape=(784,)),
    keras.layers.Dropout(0.2),
    keras.layers.Dense(10)
  ])

  model.compile(optimizer='adam',
                loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
                metrics=[tf.metrics.SparseCategoricalAccuracy()])

  return model

model = create_model()

model.summary()

跑起来并保存模型

用这个方法进行回调 tf.keras.callbacks.ModelCheckpoint

回调允许您在训练期间和结束时持续保存模型

方法定义如下

tf.keras.callbacks.ModelCheckpoint(
    filepath,
    monitor='val_loss',
    verbose=0,
    save_best_only=False,
    save_weights_only=False,
    mode='auto',
    save_freq='epoch',
    options=None,
    initial_value_threshold=None,
    **kwargs
)

回到例子里来,这里的 tf.keras.callbacks.ModelCheckpoint只用了三个参数
filepath保存路径
save_weights_only只保存模型参数
verbose输出信息参数

checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
                                                 save_weights_only=True,
                                                 verbose=1)

model.fit(train_images,
          train_labels,
          epochs=10,
          validation_data=(test_images, test_labels),
          callbacks=[cp_callback])

执行完毕后,在文件同目录文件夹下,可以看见保存的文件了。

Tensorflow2.0学习-保存和加载模型 (五)

加载模型

再建立一个一摸一样的全新模型,来加载一下,再进行测试。


model = create_model()

loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))

输出,未经训练的模型,准确率只有7.4%

Untrained model, accuracy:  7.40%

加载下之前训好的,再测试一下。


model.load_weights(checkpoint_path)

loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

输出,训练后的模型准确率就有86.5%

Restored model, accuracy: 86.50%

checkpoint 回调选项

回调里有一些选项,可以为 checkpoint提供唯一名称,并且调整保存的频率。比如下面这个每五次保存依次。


checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)

batch_size = 32

cp_callback = tf.keras.callbacks.ModelCheckpoint(
    filepath=checkpoint_path,
    verbose=1,
    save_weights_only=True,
    save_freq=5*batch_size)

model = create_model()

model.save_weights(checkpoint_path.format(epoch=0))

model.fit(train_images,
          train_labels,
          epochs=50,
          batch_size=batch_size,
          callbacks=[cp_callback],
          validation_data=(test_images, test_labels),
          verbose=0)

执行完毕会在项目中保存这样格式的文件。

Tensorflow2.0学习-保存和加载模型 (五)
将最近的 checkpoint拿出来,就是 training_2/cp-0050.ckpt
latest = tf.train.latest_checkpoint(checkpoint_dir)

带入全新模型测试一下。


model = create_model()

model.load_weights(latest)

loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

结果如下

Restored model, accuracy: 87.30%

保存模型的另外一些方式

手动保存


model.save_weights('./checkpoints/my_checkpoint')

model = create_model()

model.load_weights('./checkpoints/my_checkpoint')

loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))

保存整个模型,不光是模型参数,还有模型结构、优化器等等,可以接着训练。两种文件格式 SavedModelHDF5

SaveModel文件格式

新模型新训练,再保存。

以这种格式保存的模型可以使用 tf.keras.models.load_model 恢复


model = create_model()
model.fit(train_images, train_labels, epochs=5)

model.save('saved_model/my_model')

SavedModel 格式是一个包含 protobuf 二进制文件和 TensorFlow 检查点的目录

Tensorflow2.0学习-保存和加载模型 (五)
直接加载一个新的模型出来,不用先定义
new_model = tf.keras.models.load_model('saved_model/my_model')

new_model.summary()

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

print(new_model.predict(test_images).shape)

输出

32/32 - 0s - loss: 0.4384 - sparse_categorical_accuracy: 0.8530
Restored model, accuracy: 85.30%
(1000, 10)

HDF5 文件格式

新建个模型,和刚刚一样的操作,就是保存变了。


model = create_model()
model.fit(train_images, train_labels, epochs=5)

model.save('my_model.h5')

直接从 .h5文件中加载模型


new_model = tf.keras.models.load_model('my_model.h5')

new_model.summary()

loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))

结果

32/32 - 0s - loss: 0.4126 - sparse_categorical_accuracy: 0.8580
Restored model, accuracy: 85.80%

Original: https://blog.csdn.net/u010095372/article/details/124515424
Author: 赫凯
Title: Tensorflow2.0学习-保存和加载模型 (五)

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/521256/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球