【Pytorch-6】-模型保存与加载

其实Pytorch模型保存还是挺简单的,但是不同方式也有优劣之分吧。有时候,我们不仅仅需要保存模型参数,而有时需要保存训练的所有现场,包括优化器的内容。即有时候是只保存参数,但有时候需要保存模型训练的全过程。

我们实际上保存的是模型的参数,没有保存模型的结构的完整信息。

即,保存的模型是以字典形式保存的,所以被称作为state_dict。上面实际上我们按照已经定义好的模型进行加载,所以使用model.load_state_dict。其中的键信息实际是原本模型的层次的名字,因此模型在重新读取的时候,需要我们先实例化完全一致的结构,再进行参数的加载。

如果model是pytorch的nn.module继承而来的,那么如下:

model_path = os.path.join(output, 'model.pth')
torch.save(model.state_dict(), model_path)

这里有 .pth的格式存储,还有 .pkl格式,以及 .pt的格式。

之后,如果要进行推理或者使用时加载模型,只需要模型的结构对应,就可以直接加载:

model.load_state_dict(torch.load(args.model_path))

总结如下:

  • 保存模型时调用 state_dict() 获取模型的参数,而不保存结构
  • 加载模型时需要预先实例化一个对应的结构
  • 加载模型使用 load_state_dict 方法,其参数不是文件路径,而是 torch.load(PATH)

这是完整的存储了模型的信息的方法,包括模型的参数信息、模型的结构信息、参数等等所有内容。和方法一相比,弊端是会占用更大的信息,优势是,我们不需要知道文件中的模型究竟是什么样的,直接读取即可使用了:

torch.save(model, PATH)

model = torch.load(PATH)

有时我们不仅要保存模型,还要连带保存一些其他的信息。比如在训练过程中保存一些 checkpoint,往往除了模型,还要保存它的 epochlossoptimizer等信息,以便于加载后对这些 checkpoint 继续训练等操作;或者再比如,有时候需要将多个模型一起打包保存等。

这里我们主要将多个内容放入一个字典进行保存:

torch.save({
            'epoch': epoch,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': loss,
            ...

            }, PATH)

加载的时候,我们需要将各个对应的元素按照原本的类别,进行数据初始化,例如优化器必须还是之前的优化器,模型还是之前的模型结构(主要这里例子是state_dict,不然直接保存模型也是可以的)

model = TheModelClass(*args, **kwargs)
optimizer = TheOptimizerClass(*args, **kwargs)

checkpoint = torch.load(PATH)
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
epoch = checkpoint['epoch']
loss = checkpoint['loss']

我们时常会涉及到,在有GPU的服务器进行训练,但是在CPU上进行推理和使用的情况。正常的CPU训练、CPU加载或者GPU训练、GPU使用,都是没问题的,主要是设备不同时的问题。

GPU训,GPU加载

最为正常和一般的情况,照常操作,不过还是别忘记把模型放到GPU上去。

GPUidx=0
device = torch.device('cuda:{}'.format(GPUidx) if torch.cuda.is_available() else 'cpu')

BATCH_SIZE = 64
N_EPOCHS = 15
INPUT_DIM = 28 * 28
HIDDEN_DIM = 256
LATENT_DIM = 20

encoder = Encoder(INPUT_DIM, HIDDEN_DIM, LATENT_DIM)
decoder = Decoder(LATENT_DIM, HIDDEN_DIM, INPUT_DIM)
VAEmodel = VAE(encoder, decoder).to(device)

VAEmodel.load_state_dict(torch.load(modelpath))

GPU训练,CPU加载

保存的行为一致,我们只需要在torch.load时,对相应的参数 map_location进行设置即可:

torch.save(net.state_dict(), PATH)

device = torch.device("cpu")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))

CPU训练,GPU加载

虽然一般不太可能,但还是啰嗦一下

torch.save(net.state_dict(), PATH)

device = torch.device("cuda")

loaded_net = Net()
loaded_net.load_state_dict(torch.load(PATH, map_location=device))

loaded_net.to(device)

Original: https://blog.csdn.net/zeiyousao/article/details/123724936
Author: 临淮郡人
Title: 【Pytorch-6】-模型保存与加载

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/692549/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球