TensorFlow Custom Training Functions

This post records a template for writing custom training functions in the TensorFlow framework and briefly discusses the advantages and disadvantages of using one.

First, a note on sources: the training-function template recorded here is based on the answers at https://stackoverflow.com/questions/59438904/applying-callbacks-in-a-custom-training-loop-in-tensorflow-2-0 and on Section 12.3.9 of Hands-On Machine Learning with Scikit-Learn, Keras, and TensorFlow. If you spot any errors or omissions, corrections are welcome.

Why and When You Need a Custom Training Function

Unless you really need the extra flexibility, you should prefer the fit() method rather than implementing your own training loop, especially when working in a team.

If you are still wondering why you might need a custom training function, then you do not need one yet. Usually it is only when building a model with an unusual structure that Model.fit() turns out not to fully meet your requirements. The first thing to try is to read the source code of the relevant parts of TensorFlow and check whether there are parameters or methods you have overlooked; only after that should you consider a custom training function. A custom training function will undoubtedly make the code longer, harder to maintain, and harder to understand.

However, a custom training function offers flexibility that the fit() method cannot match. For example, in a custom function you can implement a training loop that uses several different optimizers, or compute validation loops over multiple datasets.
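
As a small taste of that flexibility, below is a minimal sketch of a training step that updates two groups of variables with two different optimizers. The model, the variable split, and the learning rates are illustrative assumptions, not part of the template in the next section:

import tensorflow as tf
from tensorflow import keras

model = keras.Sequential([
    keras.layers.Dense(16, activation="relu", input_shape=(4,)),  # e.g. a pretrained block
    keras.layers.Dense(1),                                        # e.g. a freshly added head
])
slow_vars = model.layers[0].trainable_variables
fast_vars = model.layers[1].trainable_variables
slow_opt = keras.optimizers.SGD(learning_rate=1e-4)
fast_opt = keras.optimizers.Adam(learning_rate=1e-3)
loss_fn = keras.losses.MeanSquaredError()

def two_optimizer_step(x, y):
    with tf.GradientTape() as tape:
        loss = loss_fn(y, model(x, training=True))
    # one gradient computation, two apply_gradients calls
    slow_grads, fast_grads = tape.gradient(loss, [slow_vars, fast_vars])
    slow_opt.apply_gradients(zip(slow_grads, slow_vars))
    fast_opt.apply_gradients(zip(fast_grads, fast_vars))
    return loss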

Custom Training Function Template

The template is designed so that a custom training function can be completed quickly by reusing code blocks and filling in the blanks at the key points. This lets us focus on the structure of the training function itself rather than on the details (such as handling a training set of unknown length), while still providing some of the features that fit() supports (such as the use of Callback classes). The full template follows.

import tensorflow as tf
from tensorflow import keras


def train(model: keras.Model, train_batchs, epochs=1, initial_epoch=0,
          callbacks=None, steps_per_epoch=None, val_batchs=None):
    # wrap the user callbacks in a CallbackList; add_history=True attaches the
    # History callback that is retrieved and returned at the end, as fit() does
    callbacks = tf.keras.callbacks.CallbackList(
        callbacks, add_history=True, model=model)

    logs_dict = {}

    # init optimizer, loss function and metrics
    optimizer = keras.optimizers.Nadam(learning_rate=0.0005)
    loss_fn = keras.losses.MeanSquaredError()  # note the parentheses: the loss must be instantiated

    train_loss_tracker = keras.metrics.Mean(name="train_loss")
    val_loss_tracker = keras.metrics.Mean(name="val_loss")
    # train_acc_metric = tf.keras.metrics.BinaryAccuracy(name="train_acc")
    # val_acc_metric = tf.keras.metrics.BinaryAccuracy(name="val_acc")

    def count():  # infinite step counter, used when steps_per_epoch is unknown
        x = 0
        while True:
            yield x
            x += 1

    def print_status_bar(iteration, total, metrics=None):
        metrics = " - ".join(["{}:{:.4f}".format(m.name, m.result())
                              for m in (metrics or [])])
        # stay on the same line until the epoch is known to be complete
        end = "" if total is None or iteration < total else "\n"
        print("\r{}/{} - ".format(iteration, total if total is not None else "?") + metrics, end=end)

    def train_step(x, y, loss_tracker: keras.metrics.Metric):
        with tf.GradientTape() as tape:
            # training=True so that layers such as Dropout and BatchNorm
            # behave as they should during training
            outputs = model(x, training=True)
            main_loss = tf.reduce_mean(loss_fn(y, outputs))
            loss = tf.add_n([main_loss] + model.losses)  # add regularization losses
        gradients = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(gradients, model.trainable_variables))
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    def val_step(x, y, loss_tracker: keras.metrics.Metric):
        # a direct call with training=False is much cheaper than calling
        # model.predict() once per batch
        outputs = model(x, training=False)
        main_loss = tf.reduce_mean(loss_fn(y, outputs))
        loss = tf.add_n([main_loss] + model.losses)
        loss_tracker.update_state(loss)
        return {loss_tracker.name: loss_tracker.result()}

    # create the training iterator once, outside the epoch loop, so that
    # progress is kept across epochs when steps_per_epoch is given
    train_iter = iter(train_batchs)

    callbacks.on_train_begin(logs=logs_dict)
    for i_epoch in range(initial_epoch, epochs):
        callbacks.on_epoch_begin(i_epoch, logs=logs_dict)

        # choose the step iterator: a fixed range when steps_per_epoch is
        # given, otherwise an infinite counter stopped by StopIteration
        infinite_flag = False
        if steps_per_epoch is None:
            infinite_flag = True
            step_iter = count()
        else:
            step_iter = range(steps_per_epoch)

        # train_loop
        for i_step in step_iter:
            callbacks.on_batch_begin(i_step, logs=logs_dict)
            callbacks.on_train_batch_begin(i_step, logs=logs_dict)

            try:
                X_batch, y_batch = next(train_iter)
            except StopIteration:
                # the dataset is exhausted: restart it from the beginning
                train_iter = iter(train_batchs)
                if infinite_flag is True:
                    break  # with an unknown length, exhaustion marks the end of the epoch
                else:
                    X_batch, y_batch = next(train_iter)

            train_logs_dict = train_step(x=X_batch,y=y_batch,loss_tracker=train_loss_tracker)
            logs_dict.update(train_logs_dict)

            print_status_bar(i_step, steps_per_epoch, [train_loss_tracker])

            callbacks.on_train_batch_end(i_step, logs=logs_dict)
            callbacks.on_batch_end(i_step, logs=logs_dict)

        if steps_per_epoch is None:
            print()
            # record the measured epoch length so that later epochs (and the
            # status bar) can use a fixed step count
            steps_per_epoch = i_step

        if val_batchs is not None:
            # val_loop: run over the whole validation set once per epoch
            for i_step, (X_batch, y_batch) in enumerate(val_batchs):
                callbacks.on_batch_begin(i_step, logs=logs_dict)
                callbacks.on_test_batch_begin(i_step, logs=logs_dict)

                val_logs_dict = val_step(x=X_batch,y=y_batch,loss_tracker=val_loss_tracker)
                logs_dict.update(val_logs_dict)

                callbacks.on_test_batch_end(i_step, logs=logs_dict)
                callbacks.on_batch_end(i_step, logs=logs_dict)

        print_status_bar(steps_per_epoch, steps_per_epoch, [train_loss_tracker, val_loss_tracker])
        callbacks.on_epoch_end(i_epoch, logs=logs_dict)

        for metric in [train_loss_tracker, val_loss_tracker]:
            metric.reset_states()
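
        # end training early if a callback (e.g. EarlyStopping) has set
        # model.stop_training, mirroring the check that fit() performs
        if getattr(model, "stop_training", False):
            break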

    callbacks.on_train_end(logs=logs_dict)

    # Fetch the history object we normally get from keras.fit
    history_object = None
    for cb in callbacks:
        if isinstance(cb, tf.keras.callbacks.History):
            history_object = cb
    return history_object
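
A usage sketch follows. The model, the data, and the callback are illustrative assumptions; any tf.data pipeline that yields (x, y) batches will work, and the tf/keras imports from the template above are reused:

import numpy as np

model = keras.Sequential([
    keras.layers.Dense(32, activation="relu", input_shape=(8,)),
    keras.layers.Dense(1),
])

X, y = np.random.rand(1000, 8), np.random.rand(1000, 1)
train_ds = tf.data.Dataset.from_tensor_slices((X[:800], y[:800])).batch(32)
val_ds = tf.data.Dataset.from_tensor_slices((X[800:], y[800:])).batch(32)

history = train(model, train_ds, epochs=5,
                callbacks=[keras.callbacks.EarlyStopping(monitor="val_loss", patience=2)],
                val_batchs=val_ds)
print(history.history)  # the same History structure that fit() returns

If speed matters, train_step and val_step can additionally be wrapped in tf.function, at the cost of the usual tracing constraints.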

Original: https://www.cnblogs.com/yc0806/p/16534447.html
Author: 多事鬼间人
Title: TensorFlow Custom Training Functions
