PyTorch Model Saving, Loading, and Resuming Training

import os
import torch
from torch import nn
import numpy as np

Define a simple three-layer MLP classifier (two linear layers with a ReLU in between):

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(64, 32)   # 64 input features -> 32 hidden units
        self.linear1 = nn.Linear(32, 10)  # 32 hidden units -> 10 class logits
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.linear(x)
        x = self.relu(x)
        x = self.linear1(x)
        return x
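A quick way to sanity-check the architecture is to run a dummy batch through it and inspect the output shape (a small sketch, not part of the original post):

m = MyModel()
print(m(torch.rand(4, 64)).shape)  # torch.Size([4, 10]) -- raw logits for 10 classes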

# Two independent random datasets: 100 samples with 64 features each,
# and random integer class labels in [0, 10)
rand1 = torch.rand((100, 64)).to(torch.float)
label1 = np.random.randint(0, 10, size=100)
label1 = torch.from_numpy(label1).to(torch.long)
rand2 = torch.rand((100, 64)).to(torch.float)
label2 = np.random.randint(0, 10, size=100)
label2 = torch.from_numpy(label2).to(torch.long)

model = MyModel()
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
loss = nn.CrossEntropyLoss()  # expects raw logits and long-typed class labels

epoch = 10
for i in range(epoch):
    output = model(rand1)
    my_loss = loss(output, label1)
    optimizer.zero_grad()
    my_loss.backward()
    optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))

The results are as follows. Note these loss values so we can compare them against the initial loss when training resumes later:

epoch:0 loss:2.3494179248809814
epoch:1 loss:2.287858009338379
epoch:2 loss:2.2486231327056885
epoch:3 loss:2.2189149856567383
epoch:4 loss:2.193182945251465
epoch:5 loss:2.167125940322876
epoch:6 loss:2.140075206756592
epoch:7 loss:2.1100614070892334
epoch:8 loss:2.0764594078063965
epoch:9 loss:2.0402779579162598

Models are saved with torch.save, generally in one of two modes: the first simply saves the entire model object; the second saves the individual components' parameters (state dicts) in a dictionary.


save_path = r'model_para/'
os.makedirs(save_path, exist_ok=True)  # make sure the target directory exists
torch.save(model, save_path + 'model_full.pth')  # mode 1: save the entire model object
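For the full-model file, torch.load returns the entire module object; note that the class definition (MyModel here) must be importable at load time. A minimal loading sketch (on recent PyTorch versions you may also need to pass weights_only=False, since newer releases default to weights-only loading):

full_model = torch.load(save_path + 'model_full.pth')  # weights_only=False may be needed on PyTorch >= 2.6
full_model.eval()  # switch to inference mode if not training further
print(type(full_model))  # <class '__main__.MyModel'>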

Mode 2: save the model parameters, the optimizer state, and the epoch counter together in one dictionary.

def save_model(save_path, epoch, optimizer, model):
    torch.save({'epoch': epoch+1,
                'optimizer_dict': optimizer.state_dict(),
                'model_dict': model.state_dict()},
                save_path)
    print("model save success")
save_model(save_path + 'model_dict.pth', epoch, optimizer, model)
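The saved checkpoint is just a Python dictionary, so its contents can be inspected right after saving (a quick sketch):

ckpt = torch.load(save_path + 'model_dict.pth')
print(ckpt.keys())   # dict_keys(['epoch', 'optimizer_dict', 'model_dict'])
print(ckpt['epoch'])  # 11, i.e. epoch + 1 as saved above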

The saved .pth checkpoint can be loaded back with torch.load, as follows:

def load_model(save_name, optimizer, model):
    model_data = torch.load(save_name)
    model.load_state_dict(model_data['model_dict'])
    optimizer.load_state_dict(model_data['optimizer_dict'])
    print("model load success")
    return model_data['epoch']  # hand back the saved epoch so the counter can resume

Inspect the weight parameters of the currently trained model:

print(model.state_dict()['linear.weight'])

tensor([[-0.0215,  0.0299, -0.0255,  ..., -0.0997, -0.0899,  0.0499],
        [-0.0113, -0.0974,  0.1020,  ...,  0.0874, -0.0744,  0.0801],
        [ 0.0471,  0.1373,  0.0069,  ..., -0.0573, -0.0199, -0.0654],
        ...,
        [ 0.0693,  0.1900,  0.0013,  ..., -0.0348,  0.1541,  0.1372],
        [ 0.1672, -0.0086,  0.0189,  ...,  0.0926,  0.1545,  0.0934],
        [-0.0773,  0.0645, -0.1544,  ..., -0.1130,  0.0213, -0.0613]])

Create a new model, load the previously saved checkpoint into it, and print the layer parameters:

new_model = MyModel()
new_optimizer = torch.optim.Adam(new_model.parameters(), lr=0.01)
load_model(save_path+'model_dict.pth', new_optimizer, new_model)
print(new_model.state_dict()['linear.weight'])

The new model's parameters are identical to the current model's, which shows the parameters were loaded successfully.

model load success
tensor([[-0.0215,  0.0299, -0.0255,  ..., -0.0997, -0.0899,  0.0499],
        [-0.0113, -0.0974,  0.1020,  ...,  0.0874, -0.0744,  0.0801],
        [ 0.0471,  0.1373,  0.0069,  ..., -0.0573, -0.0199, -0.0654],
        ...,
        [ 0.0693,  0.1900,  0.0013,  ..., -0.0348,  0.1541,  0.1372],
        [ 0.1672, -0.0086,  0.0189,  ...,  0.0926,  0.1545,  0.0934],
        [-0.0773,  0.0645, -0.1544,  ..., -0.1130,  0.0213, -0.0613]])
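Instead of eyeballing the printed tensors, the two state dicts can also be compared programmatically (a small sketch):

sd_old, sd_new = model.state_dict(), new_model.state_dict()
print(all(torch.equal(sd_old[k], sd_new[k]) for k in sd_old))  # True if loading succeeded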

With the saved parameters loaded into the new model, continue training and watch the loss: it keeps falling from where the previous run's final loss left off, which shows that training resumed successfully.

epoch = 10
for i in range(epoch):
    output = new_model(rand1)
    my_loss = loss(output, label1)
    new_optimizer.zero_grad()
    my_loss.backward()
    new_optimizer.step()
    print("epoch:{} loss:{}".format(i, my_loss))
epoch:0 loss:2.0036799907684326
epoch:1 loss:1.965193271636963
epoch:2 loss:1.924098253250122
epoch:3 loss:1.881495714187622
epoch:4 loss:1.835693359375
epoch:5 loss:1.7865667343139648
epoch:6 loss:1.7352293729782104
epoch:7 loss:1.6832704544067383
epoch:8 loss:1.6308385133743286
epoch:9 loss:1.5763107538223267

Here I also noticed an issue. Two random datasets were generated earlier, and the original model was trained on rand1, so the saved parameters only help if training continues on rand1. If we switch to rand2, the model effectively starts from scratch (see the losses below). Both datasets were generated at random, so their distributions are essentially unrelated, and a model trained on the first set carries almost nothing over to the second.

epoch:0 loss:2.523787498474121
epoch:1 loss:2.469816207885742
epoch:2 loss:2.4141526222229004
epoch:3 loss:2.379054069519043
epoch:4 loss:2.3563807010650635
epoch:5 loss:2.319946765899658
epoch:6 loss:2.271805763244629
epoch:7 loss:2.2274367809295654
epoch:8 loss:2.186885118484497
epoch:9 loss:2.144239902496338

In real scenarios, however, all batches are assumed to come from the same distribution, so this issue does not arise.

That covers the three key procedures of saving, loading, and resuming training for a PyTorch model. I hope it helps!

Happy training!

Original: https://blog.csdn.net/weixin_42327752/article/details/125405980
Author: Weiyaner
Title: PyTorch Model Saving, Loading, and Resuming Training
