文章目录
TensorFlow2.x学习笔记—Keras高层接口
在 TensorFlow2.x
版本中, Keras
被正式确定为 TensorFlow
的高层 API
唯一接口,取代了 TensorFlow1.x
版本中自带的 tf.layers
等高层接口。也就是说,现在只能使用 Keras
的接口来完成 TensorFlow
层方式的模型搭建与训练。在 TensorFlow
中, Keras
被实现在 tf.keras
子模块中。对于使用 TensorFlow
的开发者来说, tf.keras
可以理解为一个普通的子模块,与其他子模块,如 tf.math
, tf.data
等并没有什么差别。
1. 常见功能模块
- 常见数据集加载函数
- 网络层类
- 模型容器
- 损失函数类
- 优化器类
- 经典模型类
1.1 常见数据集加载函数
该路径下面有一个 mnist.py
文件
from tensorflow.keras.datasets.mnist import load_data
data = load_data("mnist.npz")
x_train, y_train = data[0][0], data[0][1]
x_test, y_test = data[1][0], data[1][1]
1.2 网络层类
import tensorflow as tf
from tensorflow.keras import layers
x = tf.constant([2., 1.])
layer = layers.Softmax(axis = -1)
layer(x)
1.3 网络容器
Keras
网络容器 Sequential
将多个网络层封装成一个大网络模型,只需要调用网络模型的实例一次即可完成数据从第一层到最末层的顺序运算。
from tensorflow.keras import layers, Sequential
network = Sequential([layers.Dense(3, activation = None),
layers.ReLU(),
layers.Dense(2, activation = None),
layers.ReLU()])
x = tf.random.normal([4, 3])
network(x)
追加网络层
layer_num = 2
network = Sequential([])
for _ in range(layer_num):
network.add(layers.Dense(3))
network.add(layers.ReLU())
network.build(input_shape = (None, 4))
network.summary()
Model: "sequential"
_________________________________________________________________
Layer (type) Output Shape Param
=================================================================
dense (Dense) (None, 3) 15 (4 * 3 + 3)
_________________________________________________________________
re_lu (ReLU) (None, 3) 0
_________________________________________________________________
dense_1 (Dense) (None, 3) 12 (3 * 3 + 3)
_________________________________________________________________
re_lu_1 (ReLU) (None, 3) 0
=================================================================
Total params: 27
Trainable params: 27
Non-trainable params: 0
_________________________________________________________________
for p in network.trainable_variables:
print(p.name, p.shape)
dense_2/kernel:0 (4, 3)
dense_2/bias:0 (3,)
dense_3/kernel:0 (3, 3)
dense_3/bias:0 (3,)
2. 模型装配、训练与测试
2.1 模型装配
keras.Model
keras.layers.Layer
network = Sequential([layers.Dense(256, activation = "relu"),
layers.Dense(128, activation = "relu"),
layers.Dense(64, activation = "relu"),
layers.Dense(32, activation = "relu"),
layers.Dense(10)])
network.build(input_shape = (None, 28 * 28))
network.summary()
Model: "sequential_4"
_________________________________________________________________
Layer (type) Output Shape Param
=================================================================
dense_15 (Dense) (None, 256) 200960
_________________________________________________________________
dense_16 (Dense) (None, 128) 32896
_________________________________________________________________
dense_17 (Dense) (None, 64) 8256
_________________________________________________________________
dense_18 (Dense) (None, 32) 2080
_________________________________________________________________
dense_19 (Dense) (None, 10) 330
=================================================================
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
________________________________________________________________
782 × 256 + 256 = 200960 782\times 256 + 256 = 200960 7 8 2 ×2 5 6 +2 5 6 =2 0 0 9 6 0
256 × 128 + 128 = 32896 256\times 128 + 128 = 32896 2 5 6 ×1 2 8 +1 2 8 =3 2 8 9 6
128 × 64 + 64 = 8256 128\times 64 +64 = 8256 1 2 8 ×6 4 +6 4 =8 2 5 6
64 × 32 + 32 = 2080 64\times 32+ 32 = 2080 6 4 ×3 2 +3 2 =2 0 8 0
32 ∗ 10 + 10 = 330 32 * 10 +10 = 330 3 2 ∗1 0 +1 0 =3 3 0
- *通过compile()函数指定网络使用的优化器对象,损失函数,评价指标等
from tensorflow.keras import optimizers, losses
network.compile(optimizer = optimizers.Adam(lr = 0.01),
loss = losses.CategoricalCrossentropy(from_logits = True),
metrics = ["accuracy"])
2.2 模型训练
- *通过fit()函数送入待训练的数据和验证用的数据集
history = network.fit(train, epochs = 5, validation_data = val, validation_freq = 2)
history.history
2.3 模型测试
- *通过Model.predict(x)方法完成模型的预测
x, y = next(iter(db_test))
print("predict x:", x.shape)
out = network.predict(x)
print(out)
2.4 模型保存与加载
- *Tensor方式
network.save_weights("weights.ckpt")
print("saved weights.")
del network
network = Sequential([layers.Dense(256, activation = "relu"),
layers.Dense(128, activation = "relu"),
layers.Dense(64, activation = "relu"),
layers.Dense(32, activation = "relu"),
layers.Dense(10)])
network.compile(optimizer = optimizers.Adam(lr = 0.01),
loss = tf.losses.CategoricalCrossentropy(from_logits = True),
metrics = ['accuracy'])
network.load_weights("weights.cpkt")
print("loaded weights!")
- *网络方式
network.save("model.h5")
print("saved total model.")
del network
network = tf.keras.models.load_model("model.h5")
- *Save Model 方式
tf.keras.experimental.export_saved_model(network, 'model-savedmodel')
print('export saved model.')
del network
network = tf.keras.experimental.load_from_saved_model('model-savedmodel')
2.5 自定义类
- 创建自定义网络层类,需要继承自 layers.Layer 基类
- 创建自定义的网络类,需要继承自keras.Model 基类
P180
2.6 模型乐园
P181
2.7 测量工具
- *新建测量器
from tensorflow.keras import metrics
loss_meter = metrics.Mean()
- *写入数据
loss_meter.update_state(float(loss))
- *读取统计数据
print(step, "loss:", loss_meter.result())
- *清零测量器
if step % 100 == 0:
print(step, "loss:", loss_meter.result())
loss_meter.reset_states()
实战
acc_meter = metrics.Accuracy()
out = network(x)
pred = tf.argmax(out, axis = 1)
pred = tf.cast(pred, dtype = tf.int32)
acc_meter.update_state(y, pred)
print(step, "Evaluate Acc:", acc_meter.result().numpy())
acc_meter.reset_states()
2.8 可视化
- *模型端
summary_writer = tf.summary.create_file_writer(log_dir)
with summary_writer.as_default():
tf.summary.scalar('train-loss', float(loss), step=step)
with summary_writer.as_default():
tf.summary.scalar('test-acc', float(total_correct/total), step=step)
tf.summary.image("val-onebyone-images:", val_images, max_outputs=9, step=step)
P185
- 浏览器端
tensorboard --logdir path
with summary_writer.as_default():
tf.summary.scalar('train-loss', float(loss), step=step)
tf.summary.histogram('y-hist', y, step=step)
tf.summary.text('loss-text', str(float(loss)))
Facebook 的 Visdom
Original: https://blog.csdn.net/m0_46459047/article/details/122092949
Author: LittleFish0820
Title: 【TensorFlow2.x】Keras高层接口
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/520847/
转载文章受原作者版权保护。转载请注明原作者出处!