tensorflow笔记(二十六)——tf.estimator模型文件保存和加载

Estimator可以保存 ckptsaved_model两种格式的模型。
ckpt方式与session.run模型下保存模型格式一样(在sess.run模式下,通常使用saver = tf.train.Saver()和saver.save()保存模型),这种模型文件需要原始模型代码才能运行,一般用于训练中保存/加载权重。
saved_model格式是一种轻量化的模型,不仅包含权重值,还包含计算。它不需要原始模型构建代码就可以运行,因此,对共享和部署(使用 TFLite、TensorFlow.js、TensorFlow Serving 或 TensorFlow Hub)非常有用。比如使用spark进行infer的时候可以加载这种格式的模型,或者用TensorFlow Serving在线推理。(用于推理的模型导出格式还有FrozenGraph、HDF5、tfLite等,可以参考tensorflow 模型导出总结)

1.1 ckpt文件

整个模型其实包含4个文件:

  • model.ckpt-xxxxx.data-00000-of-00001: 保存当前参数值。比如网络的权值,偏置,操作等等。
  • model.ckpt.index :保存当前参数名。二进制或者其他格式,不可直接查看 。一个不可变的字符串表,每一个键是张量的名称,它的值是一个序列化的BundleEntryProto, 每个BundleEntryProto描述张量的元数据:”数据”文件中的哪个文件包含张量的内容,该文件的偏移量,校验和,一些辅助数据等等。
  • model.ckpt.meta:某个ckpt的meta数据 二进制 或者其他格式 不可直接查看,保存了TensorFlow计算图的结构信息。model.ckpt-200.meta:保存图结构。通俗地讲就是神经网络的网络结构。
  • checkpoint:文本文件,记录了保存的最新的checkpoint文件以及其它checkpoint文件列表。

如何保存ckpt文件?train()中包含了ckpt的保存,直接调用train就好,不需要额外的保存操作:


run_config = tf.estimator.RunConfig(
        model_dir=args.model_dir,
        save_checkpoints_steps=1000,
        keep_checkpoint_max=3)

estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

estimator.train(input_fn=train_input_fn)

1.2 saved_model文件

包含saved_model.pb文件和一个variables目录,该目录下有两个文件:variables.data-00000-of-00001和variables.index两个文件。

  • saved_model.pb:保存模型结构
  • variables.data-00000-of-00001:保存变量值
  • variables.index:保存变量名

需要显式调用export_saved_model函数来保存,需要制定保存路径,制定数据input的格式(推理的时候需要根据格式处理数据),as_text指定是否按照ASCII编码格式写入到文件里。

estimator.train(input_fn=train_input_fn)

estimator.export_saved_model(
                args.pb_export_dir,
                tf.estimator.export.build_parsing_serving_input_receiver_fn(
                    feature_spec),
                as_text=False)

关于saved_model,我们后面再写一篇用spark加载saved_model模型进行离线infer的文章,这里知道不再细讲。

正如tf.estimator没有显式地保存ckpt模型,也不需要显式地load_model,而是通过定义Estimator的时候来指定模型路径 model_dir,训练、评估、预测都会用到这个model_dir,具体来说训练的时候保存到model_dir,评估和预测从model_dir读取模型文件。
指定model_dir也有两种方式,一是通过RunConfig配置,二是Estimator初始化参数model_dir传递,如下:


run_config = tf.estimator.RunConfig(
        model_dir=args.train_log_dir,
        session_config=config,
        save_checkpoints_steps=args.check_point_num,
        log_step_count_steps=args.log_every_n_steps)
estimator = tf.estimator.Estimator(model_fn=model_fn, config=run_config)

tf.estimator.Estimator(model_fn, model_dir=None, config=None, params=None, warm_start_from=None)

之前一次测试的时候,发现预测概率全部在0.5左右,最后发现是模型加载没有成功,预测结果其实是随机初始化的模型预测结果,所以概率都是0.5。怎么发现是模型没有加载成功呢?我把模型路径下的文件清空了之后预测,没有报错且预测概率就是0.5附近。然后为什么没有加载成功呢,是因为我的模型从平台存储空间上拷贝到测试机上时,只拷贝了model.ckpt-30000.data-00000-of-00001文件,这是不完整的。为什么没有模型也能预测呢?
看estimator.predict函数源码注释:

 def predict(self,
              input_fn,
              predict_keys=None,
              hooks=None,
              checkpoint_path=None,
              yield_single_examples=True):

checkpoint_path: Path of a specific checkpoint to predict. If None, the latest checkpoint in model_dir is used. If there are no checkpoints in model_dir, prediction is run with newly initialized Variables instead of ones restored from checkpoint.

注释清楚地说如果model_dir路径下没有checkpoint文件,就用最新初始化的参数进行预测,那预测出来的结果就是随机了。
为什么要这样设计?我猜测是在训练和评估的时候,计算loss的时候也要调用predict来计算输出值,但计算loss的时候未必有checkpoint,所以也是需要在没有模型的情况下来预测的。

Original: https://blog.csdn.net/hongxingabc/article/details/119977777
Author: starxhong
Title: tensorflow笔记(二十六)——tf.estimator模型文件保存和加载

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/514568/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球