TF的模型文件

标签(由空格分隔):TensorFlow

[En]

Tags (separated by spaces): TensorFlow

Saver

tensorflow模型保存函数为:

tf.train.Saver()

当然,除了上面最简单的保存方法外,您还可以指定要保存的步骤数、保存频率以及磁盘上最多几个型号(删除之前的型号以保留固定数量),如下所示:

[En]

Of course, in addition to the simplest save method above, you can also specify the number of steps to save, how often to save, and a maximum of several models on disk (delete the previous ones to keep a fixed number), as follows:

创建保存程序时指定参数:

[En]

Specify parameters when creating a saver:

saver = tf.train.Saver(savable_variables, max_to_keep=n, keep_checkpoint_every_n_hours=m)

在哪里:

[En]

Where:

  • savable_variables指定待保存的变量,比如指定为tf.global_variables()保存所有global变量;指定为[v1, v2]保存v1和v2两个变量;如果省略,则保存所有;
  • max_to_keep指定磁盘上最多保有几个模型;
  • keep_checkpoint_every_n_hours指定多少小时保存一次。

保存模型时指定参数:

[En]

Specify parameters when saving the model:

saver.save(sess, 'model_name', global_step=step,write_meta_graph=False)

如上,其中可以指定模型文件名,步数,write_meta_graph则用来指定是否保存meta文件记录graph等等。

示例:

[En]

Example:

import tensorflow as tf
​
v1= tf.Variable(tf.random_normal([784, 200], stddev=0.35), name="v1")
v2= tf.Variable(tf.zeros([200]), name="v2")
v3= tf.Variable(tf.zeros([100]), name="v3")
saver = tf.train.Saver()
with tf.Session() as sess:
    init_op = tf.global_variables_initializer()
    sess.run(init_op)
    saver.save(sess,"checkpoint/model.ckpt",global_step=1)

运行后,保存模型并获得四个文件:

[En]

After running, save the model and get four files:

  • checkpoint
  • model.ckpt-1.data-00000-of-00001
  • model.ckpt-1.index
  • model.ckpt-1.meta

checkpoint中记录了已存储(部分)和最近存储的模型:

model_checkpoint_path: "model.ckpt-1"
all_model_checkpoint_paths: "model.ckpt-1"
...

meta file保存了graph结构,包括 GraphDef,SaverDef等,当存在meta file,我们可以不在文件中定义模型,也可以运行,而如果没有meta file,我们需要定义好模型,再加载data file,得到变量值。

index file为一个string-string table,table的key值为tensor名,value为serialized BundleEntryProto。每个BundleEntryProto表述了tensor的metadata,比如那个data文件包含tensor、文件中的偏移量、一些辅助数据等。

data file保存了模型的所有变量的值,TensorBundle集合。

Restore

Restore模型的过程可以分为两个部分,首先是创建模型,可以手动创建,也可以从meta文件里加载graph进行创建。

模型将加载为:

[En]

The model is loaded as:

with tf.Session() as sess:
    saver = tf.train.import_meta_graph('/xx/model.ckpt.meta')
    saver.restore(sess, "/xx/model.ckpt")

.meta文件中保存了图的结构信息,因此需要在导入checkpoint之前导入它。否则,程序不知道checkpoint中的变量对应的变量。另外也可以:

# Recreate the EXACT SAME variables
v1 = tf.Variable(..., name="v1")
v2 = tf.Variable(..., name="v2")
...

# Now load the checkpoint variable values
with tf.Session() as sess:
    saver = tf.train.Saver()
    saver.restore(sess, "/xx/model.ckpt")
    #saver.restore(sess, tf.train.latest_checkpoint('./'))

PS: 不存在model.ckpt文件,saver.py中:Users only need to interact with the user-specified prefix… instead of any physical pathname.

当然,重要的是要注意,并不是所有的TensorFlow模型都可以将图形输出到元文件或从元文件加载图形,如果模型的某些部分无法序列化,则此方法可能不起作用。

[En]

Of course, it’s important to note that not all TensorFlow models can output graph to or load from a meta file, and this approach may not work if there are parts of the model that cannot be serialized.

使用Restore的模型

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
  saver.restore(sess, tf.train.latest_checkpoint('./'))
  tvs = [v for v in tf.trainable_variables()]
  for v in tvs:
    print(v.name)
    print(sess.run(v))

顾名思义,上面就是查看模型中的可训练变量;或者我们也可以查看模型中的所有张量或运算,如下所示:

[En]

As the name suggests, the above is to view the trainable variables; in the model, or we can also view all the tensor or operations in the model, as follows:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
  saver.restore(sess, tf.train.latest_checkpoint('./'))
  gv = [v for v in tf.global_variables()]
  for v in gv:
    print(v.name)

上面通过global_variables()获得的与前trainable_variables类似,只是多了一些非trainable的变量,比如定义时指定为trainable=False的变量,或Optimizer相关的变量。

几乎所有与运算相关的张量都可以如下获得:

[En]

Almost all operations-related tensor can be obtained as follows:

with tf.Session() as sess:
  saver = tf.train.import_meta_graph('model.ckpt-1000.meta')
  saver.restore(sess, tf.train.latest_checkpoint('./'))
  ops = [o for o in sess.graph.get_operations()]
  for o in ops:
    print(o.name)

首先,上面的sess.graph.get_operations()可以换为tf.get_default_graph().get_operations(),二者区别无非是graph明确的时候可以直接使用前者,否则需要使用后者。

通过这种方法得到的张量是相对完整的,由此我们可以对整个模型有一个大概的了解。但是,最方便的方式是使用tensorboard查看,当然,这需要您提前输出sess.graph。

[En]

The tensor obtained by this method is relatively complete, from which we can get a glimpse of the whole model. However, the most convenient way is to use tensorboard to view, of course, this requires you to output the sess.graph in advance.

这个操作比较简单,无非是找出原始模型的输入输出。

[En]

This operation is relatively simple, nothing more than to find the input and output of the original model.

只要搞清楚输入输出的tensor名字,即可直接使用TensorFlow中graph的get_tensor_by_name函数,建立输入输出的tensor:

with tf.get_default_graph() as graph:
  data = graph.get_tensor_by_name('data:0')
  output = graph.get_tensor_by_name('output:0')

在找到模型的输入和输出后,您可以直接使用它来继续训练整个模型,或者将输入数据反馈到模型中并转发以获得测试输出。

[En]

After finding the input and output from the model, you can use it directly to continue to train the entire model, or feed the input data into the model and forward to get the test output.

需要注意的是,有时从图中找到输入和输出张量的名称并不容易,因此在定义图时,最好给相应的张量一个明显的名称,例如:

[En]

It is important to note that sometimes it is not easy to find the names of input and output tensor from a graph, so when defining a graph, it is best to give the corresponding tensor an obvious name, such as:

data = tf.placeholder(tf.float32, shape=shape, name='input_data')
preds = tf.nn.softmax(logits, name='output')

诸如此类。这样,就可以直接使用tf.get_tensor_by_name(‘input_data:0’)之类的来找到输入输出了。

除了直接使用原始模型外,还可以扩展原始模型,例如继续处理1中的输出,并添加新的操作来完成对原始模型的扩展,例如:

[En]

In addition to directly using the original model, you can also extend the original model, such as continuing to process the output in 1 and adding new operations to complete the extension of the original model, such as:

with tf.get_default_graph() as graph:
  data = graph.get_tensor_by_name('data:0')
  output = graph.get_tensor_by_name('output:0')
  logits = tf.nn.softmax(output)

有时候,我们有对某模型的一部分进行fine-tune的需求,比如使用一个VGG的前面提取特征的部分,而微调其全连层,或者将其全连层更换为使用convolution来完成,等等。TensorFlow也提供了这种支持,可以使用TensorFlow的stop_gradient函数,将模型的一部分进行冻结。

with tf.get_default_graph() as graph:
  graph.get_tensor_by_name('fc1:0')
  fc1 = tf.stop_gradient(fc1)
  # add new procedure on fc1

Original: https://www.cnblogs.com/houkai/p/9723988.html
Author: 侯凯
Title: TF的模型文件

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/6317/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

发表回复

登录后才能评论
免费咨询
免费咨询
扫码关注
扫码关注
联系站长

站长Johngo!

大数据和算法重度研究者!

持续产出大数据、算法、LeetCode干货,以及业界好资源!

2022012703491714

微信来撩,免费咨询:xiaozhu_tec

分享本页
返回顶部