TF2-Tips:自定义model.fit

官方示例

keras官方代码给的例子很详细:Customizing what happens in fit()

基础

class CustomModel(keras.Model):
    def train_step(self, data):

        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)

            loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.compiled_metrics.update_state(y, y_pred)

        return {m.name: m.result() for m in self.metrics}

import numpy as np

inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=3)
  • CustomModel继承keras.Model,重写了train_step方法
  • self.compiled_loss就是model.compile中的loss方法
  • self.compiled_metrics就是model.compile中的metrics方法

在train_step方法中自定义loss:

loss_tracker = keras.metrics.Mean(name="loss")
mae_metric = keras.metrics.MeanAbsoluteError(name="mae")

class CustomModel(keras.Model):
    def train_step(self, data):
        x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)

            loss = keras.losses.mean_squared_error(y, y_pred)

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        loss_tracker.update_state(loss)
        mae_metric.update_state(y, y_pred)
        return {"loss": loss_tracker.result(), "mae": mae_metric.result()}

    @property
    def metrics(self):

        return [loss_tracker, mae_metric]

inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)

model.compile(optimizer="adam")

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
model.fit(x, y, epochs=5)
  • loss_tracker有两个方法:
  • update_state:传loss
  • result:当前平均loss
  • property修饰的metrics方法:
  • 在每个epoch开始前调用reset_states方法
  • 如果去掉metrics,则训练中体现的loss不是每个epoch的累积平均loss,而是从训练开始时的累积平均loss
  • 注意:这种情况下,model.compile中不需要再写loss了
  • 踩坑:对于tf2.0和tf2.1,在fit时会报错:”ValueError: The model cannot be compiled because it has no loss to optimize.” TF2.2及以上没问题。
  • 参考文章:AI学习笔记–Tensorflow自定义

class weight&sample weight

class CustomModel(keras.Model):
    def train_step(self, data):

        if len(data) == 3:
            x, y, sample_weight = data
        else:
            sample_weight = None
            x, y = data

        with tf.GradientTape() as tape:
            y_pred = self(x, training=True)

            loss = self.compiled_loss(
                y,
                y_pred,
                sample_weight=sample_weight,
                regularization_losses=self.losses,
            )

        trainable_vars = self.trainable_variables
        gradients = tape.gradient(loss, trainable_vars)

        self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        self.compiled_metrics.update_state(y, y_pred, sample_weight=sample_weight)

        return {m.name: m.result() for m in self.metrics}

inputs = keras.Input(shape=(32,))
outputs = keras.layers.Dense(1)(inputs)
model = CustomModel(inputs, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["mae"])

x = np.random.random((1000, 32))
y = np.random.random((1000, 1))
sw = np.random.random((1000, 1))
model.fit(x, y, sample_weight=sw, epochs=3)

Idea

自监督任务没有label,loss需要自行设计,此场景适合自定义train_step方法。以对比学习为例:

  • 首先model.fit(x,y)中的x可以是一对正例,y可置None,此时train_step函数的输入为tuple:(x, )
  • 对一个batch设计compute_loss函数
  • call函数也需要自己设计,接受token id和seg id,返回embeding
  • 在train_step方法中调用call和compute_loss,使用loss_tracker.update_state传递loss

keras官方有一个关于clip算法的jupyter:Natural language image search with a Dual Encoder,其DualEncoder类的设计值得一读。
有空时我会仿照上面的思路写一个simcse的keras实现,欢迎follow~

Original: https://blog.csdn.net/weixin_44597588/article/details/123894936
Author: 一只用R的浣熊
Title: TF2-Tips:自定义model.fit

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/514508/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球