在多卡GPU服务器上跑程序时,如果迭代次数或者epoch数足够大,我们通常会使用nn.DataParallel来利用多块GPU加速训练。一般我们会在代码中加入以下两种写法之一:
# Option 1: move the model to the GPU first, then wrap it in DataParallel.
model = model.cuda()
device_ids = [0, 1] # the two GPUs with ids 0 and 1
model = torch.nn.DataParallel(model, device_ids=device_ids)
# Option 2: wrap the model first, then move the wrapped model to the GPU.
# NOTE: options 1 and 2 are alternatives. Running both in sequence, as
# written here, would wrap the model in DataParallel twice; keep only one.
device_ids = [0, 1]
model = torch.nn.DataParallel(model, device_ids=device_ids).cuda()
函数定义:torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)
Parameters 参数:
module即表示你定义的模型;
device_ids表示你训练的device;
output_device这个参数表示输出结果的device;
最后一个参数output_device一般情况下是省略不写的,此时默认取device_ids[0],也就是第一块卡。各卡的计算结果最终都会汇总(gather)到这块卡上,这也就解释了为什么第一块卡的显存占用会比其他卡更多一些。
torch.nn.DataParallel源码解读:
class DataParallel(Module):
    """Wrap ``module`` so that each forward call is run across several devices.

    A forward call scatters the arguments over ``device_ids`` along ``dim``,
    replicates ``module`` onto the devices that received a chunk, applies the
    replicas in parallel and gathers the outputs onto ``output_device``.

    Args:
        module: the module to run in parallel.
        device_ids: devices to use. When ``None``, all device indices
            reported by ``_get_all_device_indices()`` are used.
        output_device: device the gathered output is placed on. When
            ``None``, ``device_ids[0]`` is used.
        dim: dimension passed to scatter and gather (default 0).

    Attributes set by ``__init__``:
        module, device_ids: always set.
        dim, output_device, src_device_obj: only set when a device type
        is available.
    """

    def __init__(self, module, device_ids=None, output_device=None, dim=0):
        super(DataParallel, self).__init__()

        device_type = _get_available_device_type()
        if device_type is None:
            # No device type available: become a transparent wrapper.
            # forward() detects the empty device_ids list and calls the
            # wrapped module directly.
            self.module = module
            self.device_ids = []
            return

        if device_ids is None:
            device_ids = _get_all_device_indices()

        if output_device is None:
            output_device = device_ids[0]

        self.dim = dim
        self.module = module
        self.device_ids = [_get_device_index(x, True) for x in device_ids]
        self.output_device = _get_device_index(output_device, True)
        # Parameters and buffers are required to live on device_ids[0];
        # forward() checks every tensor against this device object.
        self.src_device_obj = torch.device(device_type, self.device_ids[0])

        _check_balance(self.device_ids)

        if len(self.device_ids) == 1:
            self.module.to(self.src_device_obj)

    def forward(self, *inputs, **kwargs):
        """Run the wrapped module on the scattered inputs and gather the result.

        Raises:
            RuntimeError: if any parameter or buffer of the wrapped module
                is not on ``device_ids[0]``.
        """
        if not self.device_ids:
            return self.module(*inputs, **kwargs)

        for t in chain(self.module.parameters(), self.module.buffers()):
            if t.device != self.src_device_obj:
                raise RuntimeError("module must have its parameters and buffers "
                                   "on device {} (device_ids[0]) but found one of "
                                   "them on device: {}".format(self.src_device_obj, t.device))

        inputs, kwargs = self.scatter(inputs, kwargs, self.device_ids)
        # A call without any positional or keyword arguments leaves both
        # scattered sequences empty, so inputs[0] below would fail. Supply
        # one empty argument set so the module still runs once, on the
        # first device.
        if not inputs and not kwargs:
            inputs = ((),)
            kwargs = ({},)

        if len(self.device_ids) == 1:
            return self.module(*inputs[0], **kwargs[0])
        # Only replicate onto as many devices as there are input chunks.
        replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
        outputs = self.parallel_apply(replicas, inputs, kwargs)
        return self.gather(outputs, self.output_device)

    def replicate(self, module, device_ids):
        """Replicate ``module`` onto ``device_ids`` (thin wrapper around ``replicate``)."""
        # The third argument is True exactly when gradient mode is disabled.
        return replicate(module, device_ids, not torch.is_grad_enabled())

    def scatter(self, inputs, kwargs, device_ids):
        """Scatter positional and keyword arguments over ``device_ids`` along ``self.dim``."""
        return scatter_kwargs(inputs, kwargs, device_ids, dim=self.dim)

    def parallel_apply(self, replicas, inputs, kwargs):
        """Apply each replica to its own inputs, using one device per replica."""
        return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])

    def gather(self, outputs, output_device):
        """Gather ``outputs`` onto ``output_device`` along ``self.dim``."""
        return gather(outputs, output_device, dim=self.dim)
def data_parallel(module, inputs, device_ids=None, output_device=None, dim=0, module_kwargs=None):
    r"""Evaluate module(input) in parallel across the devices given in device_ids.

    This is the functional version of the DataParallel module.

    Args:
        module (Module): the module to evaluate in parallel
        inputs (Tensor or tuple): inputs to the module; a value that is not
            a tuple is wrapped into a one-element tuple
        device_ids (list of int or torch.device): GPU ids on which to
            replicate module (default: all available device indices)
        output_device (int or torch.device): GPU location of the output.
            Use -1 to indicate the CPU. (default: device_ids[0])
        dim (int): dimension passed to scatter and gather (default: 0)
        module_kwargs (dict): keyword arguments for the module, scattered
            together with ``inputs`` (default: None)

    Returns:
        a Tensor containing the result of module(input) located on
        output_device

    Raises:
        RuntimeError: if no device type is available, or if a parameter or
            buffer of ``module`` is not located on ``device_ids[0]``.
    """
    if not isinstance(inputs, tuple):
        inputs = (inputs,)

    device_type = _get_available_device_type()
    if device_type is None:
        # Fail with a clear message instead of an opaque TypeError further
        # down (device_ids[0] on None, or torch.device(None, ...)).
        raise RuntimeError("device type could not be determined")

    if device_ids is None:
        device_ids = _get_all_device_indices()

    if output_device is None:
        output_device = device_ids[0]

    device_ids = [_get_device_index(x, True) for x in device_ids]
    output_device = _get_device_index(output_device, True)
    src_device_obj = torch.device(device_type, device_ids[0])

    for t in chain(module.parameters(), module.buffers()):
        if t.device != src_device_obj:
            raise RuntimeError("module must have its parameters and buffers "
                               "on device {} (device_ids[0]) but found one of "
                               "them on device: {}".format(src_device_obj, t.device))

    inputs, module_kwargs = scatter_kwargs(inputs, module_kwargs, device_ids, dim)
    # With neither positional nor keyword arguments both scattered sequences
    # are empty, so inputs[0] below would fail. Supply one empty argument set
    # so the module still runs once, on the first device.
    if not inputs and not module_kwargs:
        inputs = ((),)
        module_kwargs = ({},)

    if len(device_ids) == 1:
        return module(*inputs[0], **module_kwargs[0])
    # Only use as many devices as there are input chunks.
    used_device_ids = device_ids[:len(inputs)]
    replicas = replicate(module, used_device_ids)
    outputs = parallel_apply(replicas, inputs, module_kwargs, used_device_ids)
    return gather(outputs, output_device, dim)
Original: https://blog.csdn.net/sazass/article/details/116615028
Author: 大黑山修道
Title: 【pytorch系列】torch.nn.DataParallel用法详解
原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/709410/
转载文章受原作者版权保护。转载请注明原作者出处!