saver.save和saver.restore

saver()与restore()只是保存了session中的相关变量对应的值,并不涉及模型的结构。

Saver的作用是将我们训练好的模型的参数保存下来,以便下一次继续用于训练或测试;Restore则是将训练好的参数提取出来。Saver类训练完后,是以checkpoints文件形式保存。提取的时候也是从checkpoints文件中恢复变量。Checkpoints文件是一个二进制文件,它把变量名映射到对应的tensor值 。

一般地,Saver会自动的管理Checkpoints文件。我们可以指定保存最近的N个Checkpoints文件,当然每一步都保存ckpt文件也是可以的,只是没必要,费存储空间。

  • saver()可以选择global_step参数来为ckpt文件名添加数字标记:
saver.save(sess, 'my-model', global_step=0) ==> filename: 'my-model-0'
...

saver.save(sess, 'my-model', global_step=1000) ==> filename: 'my-model-1000'
  • max_to_keep参数定义saver()将自动保存的最近n个ckpt文件,默认n=5,即保存最近的5个检查点ckpt文件。若n=0或者None,则保存所有的ckpt文件。
  • keep_checkpoint_every_n_hours与max_to_keep类似,定义每n小时保存一个ckpt文件。
...

saver = tf.train.Saver(...variables...)

sess = tf.Session()
for step in xrange(1000000):
    sess.run(..training_op..)
    if step % 1000 == 0:

        saver.save(sess, 'my-model', global_step=step)
restore(sess, save_path)

  • sess: 保存参数的会话。
  • save_path: 保存参数的路径。
  • 从文件恢复变量时,不需要提前初始化,因为恢复本身就是一种初始化变量的方式。
    [En]

    when restoring variables from a file, you do not need to initialize them in advance, because “restore” itself is a way to initialize variables.*

  • 可以使用tf.train.latest_checkpoint()来自动获取最后一次保存的模型。如:
model_file=tf.train.latest_checkpoint('ckpt/')saver.restore(sess,model_file)

在实验中,最后一代可能不是验证精度最高的一代,所以我们不想默认保存最后一代,而是希望保存验证精度最高的一代,只需添加一个中间变量和一个判断语句。

[En]

In the experiment, the last generation may not be the generation with the highest verification accuracy, so we do not want to save the last generation by default, but want to save the generation with the highest verification accuracy, just add an intermediate variable and a judgment statement.

saver=tf.train.Saver(max_to_keep=1)
max_acc=0
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
sess.close()

如果我们想保存验证精度最高的三代,且把每次的验证精度也随之保存下来,则我们可以生成一个txt文件用于保存。

saver=tf.train.Saver(max_to_keep=3)
max_acc=0
f=open('ckpt/acc.txt','w')
for i in range(100):
  batch_xs, batch_ys = mnist.train.next_batch(100)
  sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
  val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
  print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
  f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
  if val_acc>max_acc:
      max_acc=val_acc
      saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
f.close()
sess.close()

我们可以将代码的后半部分更改为:

[En]

We can change the second half of the code to:

sess=tf.InteractiveSession()
sess.run(tf.global_variables_initializer())

is_train=False
saver=tf.train.Saver(max_to_keep=3)

if is_train:
    max_acc=0
    f=open('ckpt/acc.txt','w')
    for i in range(100):
      batch_xs, batch_ys = mnist.train.next_batch(100)
      sess.run(train_op, feed_dict={x: batch_xs, y_: batch_ys})
      val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
      print('epoch:%d, val_loss:%f, val_acc:%f'%(i,val_loss,val_acc))
      f.write(str(i+1)+', val_acc: '+str(val_acc)+'\n')
      if val_acc>max_acc:
          max_acc=val_acc
          saver.save(sess,'ckpt/mnist.ckpt',global_step=i+1)
    f.close()

else:
    model_file=tf.train.latest_checkpoint('ckpt/')
    saver.restore(sess,model_file)
    val_loss,val_acc=sess.run([loss,acc], feed_dict={x: mnist.test.images, y_: mnist.test.labels})
    print('val_loss:%f, val_acc:%f'%(val_loss,val_acc))
sess.close()

参考:
https://www.cnblogs.com/denny402/p/6940134.html
https://blog.csdn.net/hellocsz/article/details/89097380

Original: https://blog.csdn.net/qq_40133431/article/details/121342927
Author: 泡泡龙的泡泡
Title: saver.save和saver.restore

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/511248/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球