文章目录
保存加载模型
训练完模型之后,需要保存的,要不每次想测试的时候,都要走一遍训练,多麻烦呀。所以就需要保存以及加载。而且,有时候,模型跑着跑着就断了,还可以接着训练。保存加载模型
引包
以前一般是保存成 .h5
格式的文件,现在还有一种是 SaveModel
。
import os
import tensorflow as tf
from tensorflow import keras
print(tf.version.VERSION)
数据准备
用MNIST 数据集进行演示,数据量就选1000个,然后让数值在0-1之间。
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()
train_labels = train_labels[:1000]
test_labels = test_labels[:1000]
train_images = train_images[:1000].reshape(-1, 28 * 28) / 255.0
test_images = test_images[:1000].reshape(-1, 28 * 28) / 255.0
模型准备
一个简简单单的两层全连接模型。
def create_model():
model = tf.keras.models.Sequential([
keras.layers.Dense(512, activation='relu', input_shape=(784,)),
keras.layers.Dropout(0.2),
keras.layers.Dense(10)
])
model.compile(optimizer='adam',
loss=tf.losses.SparseCategoricalCrossentropy(from_logits=True),
metrics=[tf.metrics.SparseCategoricalAccuracy()])
return model
model = create_model()
model.summary()
跑起来并保存模型
用这个方法进行回调 tf.keras.callbacks.ModelCheckpoint
,
回调允许您在训练期间和结束时持续保存模型
方法定义如下
tf.keras.callbacks.ModelCheckpoint(
filepath,
monitor='val_loss',
verbose=0,
save_best_only=False,
save_weights_only=False,
mode='auto',
save_freq='epoch',
options=None,
initial_value_threshold=None,
**kwargs
)
回到例子里来,这里的 tf.keras.callbacks.ModelCheckpoint
只用了三个参数
filepath
保存路径
save_weights_only
只保存模型参数
verbose
输出信息参数
checkpoint_path = "training_1/cp.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
cp_callback = tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_path,
save_weights_only=True,
verbose=1)
model.fit(train_images,
train_labels,
epochs=10,
validation_data=(test_images, test_labels),
callbacks=[cp_callback])
执行完毕后,在文件同目录文件夹下,可以看见保存的文件了。
加载模型
再建立一个一摸一样的全新模型,来加载一下,再进行测试。
model = create_model()
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Untrained model, accuracy: {:5.2f}%".format(100 * acc))
输出,未经训练的模型,准确率只有7.4%
Untrained model, accuracy: 7.40%
加载下之前训好的,再测试一下。
model.load_weights(checkpoint_path)
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
输出,训练后的模型准确率就有86.5%
Restored model, accuracy: 86.50%
checkpoint 回调选项
回调里有一些选项,可以为 checkpoint
提供唯一名称,并且调整保存的频率。比如下面这个每五次保存依次。
checkpoint_path = "training_2/cp-{epoch:04d}.ckpt"
checkpoint_dir = os.path.dirname(checkpoint_path)
batch_size = 32
cp_callback = tf.keras.callbacks.ModelCheckpoint(
filepath=checkpoint_path,
verbose=1,
save_weights_only=True,
save_freq=5*batch_size)
model = create_model()
model.save_weights(checkpoint_path.format(epoch=0))
model.fit(train_images,
train_labels,
epochs=50,
batch_size=batch_size,
callbacks=[cp_callback],
validation_data=(test_images, test_labels),
verbose=0)
执行完毕会在项目中保存这样格式的文件。
将最近的
checkpoint
拿出来,就是 training_2/cp-0050.ckpt
latest = tf.train.latest_checkpoint(checkpoint_dir)
带入全新模型测试一下。
model = create_model()
model.load_weights(latest)
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
结果如下
Restored model, accuracy: 87.30%
保存模型的另外一些方式
手动保存
model.save_weights('./checkpoints/my_checkpoint')
model = create_model()
model.load_weights('./checkpoints/my_checkpoint')
loss, acc = model.evaluate(test_images, test_labels, verbose=2)
print("Restored model, accuracy: {:5.2f}%".format(100 * acc))
保存整个模型,不光是模型参数,还有模型结构、优化器等等,可以接着训练。两种文件格式 SavedModel
和 HDF5
SaveModel文件格式
新模型新训练,再保存。
以这种格式保存的模型可以使用
tf.keras.models.load_model
恢复
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('saved_model/my_model')
SavedModel 格式是一个包含 protobuf 二进制文件和 TensorFlow 检查点的目录
直接加载一个新的模型出来,不用先定义
new_model = tf.keras.models.load_model('saved_model/my_model')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
print(new_model.predict(test_images).shape)
输出
32/32 - 0s - loss: 0.4384 - sparse_categorical_accuracy: 0.8530
Restored model, accuracy: 85.30%
(1000, 10)
HDF5 文件格式
新建个模型,和刚刚一样的操作,就是保存变了。
model = create_model()
model.fit(train_images, train_labels, epochs=5)
model.save('my_model.h5')
直接从 .h5
文件中加载模型
new_model = tf.keras.models.load_model('my_model.h5')
new_model.summary()
loss, acc = new_model.evaluate(test_images, test_labels, verbose=2)
print('Restored model, accuracy: {:5.2f}%'.format(100 * acc))
结果
32/32 - 0s - loss: 0.4126 - sparse_categorical_accuracy: 0.8580
Restored model, accuracy: 85.80%
Original: https://blog.csdn.net/u010095372/article/details/124515424
Author: 赫凯
Title: Tensorflow2.0学习-保存和加载模型 (五)
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/521256/
转载文章受原作者版权保护。转载请注明原作者出处!