Pytorch统计网络参数计算工具、模型 FLOPs, MACs, MAdds 关系

Pytorch统计网络参数


def get_parameter_number(net):
    total_num = sum(p.numel() for p in net.parameters())
    trainable_num = sum(p.numel() for p in net.parameters() if p.requires_grad)
    return {'Total': total_num, 'Trainable': trainable_num}

print(model.state_dict())

FLOPs, MACs, MAdds 关系

Pytorch统计网络参数计算工具、模型 FLOPs, MACs, MAdds 关系
见文章:CNN模型复杂度(FLOPs、MAC)、参数量与运行速度

计算工具:

地址备注https://github.com/Lyken17/pytorch-OpCounterPytorchhttps://github.com/sovrasov/flops-counter.pytorchPytorchhttps://stackoverflow.com/questions/45085938/tensorflow-is-there-a-way-to-measure-flops-for-a-modelTensorFlow: 自带tf.RunMetadata()

另:在PyTorch中,可以使用 torchstat这个库来查看网络模型的一些信息,包括总的参数量params、MAdd、显卡内存占用量和FLOPs等。

!pip install torchstat
from torchstat import stat
from torchvision.models import resnet50, resnet101, resnet152, resnext101_32x8d

model = resnet50()

stat(model, (3, 224, 224))

total = sum([param.nelement() for param in model.parameters()])
print("Number of parameters: %.2fM" % (total/1e6))

也可以使用 torchsummary

!pip install torchsummary
from torchsummary import summary
summary(model, input_size=(ch, h, w), batch_size=-1)

Original: https://blog.csdn.net/user_lib/article/details/123452572
Author: 李代数
Title: Pytorch统计网络参数计算工具、模型 FLOPs, MACs, MAdds 关系

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/496849/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球