Gradient clipping code
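Gradient clipping must be applied after the loss has been back-propagated (loss.backward()) and before optimizer.step().
The function below clamps every gradient element to the range [-grad_clip, grad_clip].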
```python
def clip_gradient(optimizer, grad_clip):
    """
    Clips gradients computed during backpropagation to avoid explosion of gradients.
    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
    """
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)
```
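A minimal sketch of where the clipping call sits in one training step, assuming a hypothetical toy model (nn.Linear), random data, and SGD chosen only for illustration. It clamps param.grad directly instead of going through .data, which behaves the same here because gradients do not themselves require grad. It also shows that torch.nn.utils.clip_grad_value_ performs the same element-wise clamp.

```python
import torch
from torch import nn


def clip_gradient(optimizer, grad_clip):
    """Clamp every gradient element to [-grad_clip, grad_clip] in place."""
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.clamp_(-grad_clip, grad_clip)


torch.manual_seed(0)
# Hypothetical toy model and data, for illustration only.
model = nn.Linear(4, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
x = torch.randn(8, 4) * 100.0  # large inputs produce large gradients
y = torch.randn(8, 1)

optimizer.zero_grad()
loss = nn.functional.mse_loss(model(x), y)
loss.backward()                # 1. gradients are computed here
clip_gradient(optimizer, 0.5)  # 2. clip after backward() ...
optimizer.step()               # 3. ... and before step() consumes them

max_abs = max(p.grad.abs().max().item() for p in model.parameters())
print(max_abs <= 0.5)  # True

# torch.nn.utils.clip_grad_value_ applies the same element-wise clamp.
w = nn.Parameter(torch.zeros(5))
w.grad = torch.tensor([-3.0, -0.25, 0.0, 0.25, 7.0])
nn.utils.clip_grad_value_([w], clip_value=0.5)
print(w.grad.tolist())  # [-0.5, -0.25, 0.0, 0.25, 0.5]
```

Element-wise clamping bounds each component independently, so when several components are clipped the direction of the gradient vector can change.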
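Alternatively, clip by the global gradient norm, which rescales all gradients together so that their combined norm does not exceed max_norm:
```python
from torch import nn

# model is the nn.Module being trained; call this between backward() and step().
# Rescales all gradients so that their combined L2 norm is at most 20.
nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)
```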
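A rough sketch of what norm-based clipping computes, written as a hypothetical helper manual_clip_grad_norm for illustration (it is not a PyTorch API). The norm is taken over all gradients as if they were concatenated into one vector; if it exceeds max_norm, every gradient is multiplied by the same factor max_norm / (total_norm + eps), capped at 1. The eps = 1e-6 value is assumed to mirror PyTorch's implementation. The toy models and data are also assumptions, used only to cross-check the helper against nn.utils.clip_grad_norm_.

```python
import torch
from torch import nn


def manual_clip_grad_norm(parameters, max_norm, norm_type=2.0, eps=1e-6):
    """Hypothetical re-implementation of global-norm clipping (illustration only)."""
    grads = [p.grad for p in parameters if p.grad is not None]
    # Norm of all gradients concatenated == norm of the per-tensor norms.
    total_norm = torch.linalg.vector_norm(
        torch.stack([torch.linalg.vector_norm(g, norm_type) for g in grads]),
        norm_type,
    )
    # Cap the factor at 1 so gradients already under max_norm are untouched.
    clip_coef = torch.clamp(max_norm / (total_norm + eps), max=1.0)
    for g in grads:
        g.mul_(clip_coef)  # same factor for every tensor: direction is preserved
    return total_norm      # norm *before* clipping, as clip_grad_norm_ returns


torch.manual_seed(0)
model_a = nn.Linear(4, 2)
model_b = nn.Linear(4, 2)
model_b.load_state_dict(model_a.state_dict())  # identical weights

x = torch.randn(16, 4) * 50.0
for m in (model_a, model_b):
    m(x).pow(2).mean().backward()  # identical gradients in both models

norm_a = nn.utils.clip_grad_norm_(model_a.parameters(), max_norm=1.0, norm_type=2)
norm_b = manual_clip_grad_norm(list(model_b.parameters()), max_norm=1.0)

print(torch.allclose(norm_a, norm_b))
for pa, pb in zip(model_a.parameters(), model_b.parameters()):
    print(torch.allclose(pa.grad, pb.grad))

clipped = torch.linalg.vector_norm(
    torch.cat([p.grad.flatten() for p in model_a.parameters()])
)
print(clipped.item() <= 1.0 + 1e-5)
```

Because one shared factor scales every gradient, norm clipping keeps the update direction and only shortens the step. This is the main practical difference from the element-wise clamp above.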
Original: https://www.cnblogs.com/yuzhoutaiyang/p/16215614.html
Author: 宇宙•太阳
Title: 梯度截断代码 (Gradient clipping code)