函数格式为: torch.load(f, map_location=None, pickle_module=pickle, **pickle_load_args)
,一般我们使用的时候,基本只使用前两个参数。
- 模型保存有两种形式,一种是保存模型的
state_dict()
,只是保存模型的参数。那么加载时需要先创建一个模型的实例model
,之后通过torch.load()
将保存的模型参数加载进来,得到dict
,再通过model.load_state_dict(dict)
将模型的参数更新。 - 另一种是将整个模型保存下来,之后加载的时候只需要通过
torch.load()
将模型加载,即可返回一个加载好的模型。
具体可参考:PyTorch模型的保存与加载。
具体来说, map_location
参数是用于重定向,比如此前模型的参数是在 cpu
中的,我们希望将其加载到 cuda:0
中。或者我们有多张卡,那么我们就可以将卡1中训练好的模型加载到卡2中,这在数据并行的分布式深度学习中可能会用到。
-
首先定义一个AlexNet,并使用
cuda:0
将其训练了一个猫狗分类,之后把模型存储起来。 -
我们先把
state_dict
加载进来。
model_path = "./cuda_model.pth"
model = torch.load(model_path)
print(next(model.parameters()).device)
结果为:
cuda:0
因为保存的时候就是模型就是 cuda:0
的,所以加载进来也是。
model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location=torch.device('cpu'))
print(next(model.parameters()).device)
结果为:
cpu
模型从 cuda:0
变成了 cpu
。
model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:0':'cuda:1'})
print(next(model.parameters()).device)
结果为:
cuda:1
模型从 cuda:0
变成了 cuda:1
。
model_path = "./cuda_model.pth"
model = torch.load(model_path, map_location={'cuda:2':'cpu'})
print(next(model.parameters()).device)
结果为:
cuda:0
模型还是 cuda:0
,并没有变成 cpu
。因为这个 map_location
的映射是不对的,原始的模型就是 cuda:0
,而映射是 cuda:2
到 cpu
,是不对的。这种情况下, map_location
返回 None
,也就是和不加 map_location
相同。
Original: https://blog.csdn.net/qq_43219379/article/details/123675375
Author: eecspan
Title: torch.load()加载模型及其map_location参数
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/716181/
转载文章受原作者版权保护。转载请注明原作者出处!