梯度截断代码

梯度截断代码

需要添加在loss反向传播后,optimizer.step()前

将梯度裁剪到-grad_clip和grad_clip之间

def clip_gradient(optimizer, grad_clip):
"""
    Clips gradients computed during backpropagation to avoid explosion of gradients.

    :param optimizer: optimizer with the gradients to be clipped
    :param grad_clip: clip value
"""
    for group in optimizer.param_groups:
        for param in group["params"]:
            if param.grad is not None:
                param.grad.data.clamp_(-grad_clip, grad_clip)

或者

            nn.utils.clip_grad_norm_(model.parameters(), max_norm=20, norm_type=2)

Original: https://www.cnblogs.com/yuzhoutaiyang/p/16215614.html
Author: 宇宙•太阳
Title: 梯度截断代码

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/567759/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球