浅谈深度学习:如何计算模型以及中间变量的显存占用大小

前言

亲,显存炸了,你的显卡快冒烟了!

torch.FatalError: cuda runtime error (2) : out of memory at /opt/conda/conda-bld/pytorch_1524590031827/work/aten/src/THC/generic/THCStorage.cu:58

想必这是所有炼丹师们最不想看到的错误,没有之一。

OUT OF MEMORY,显然是显存装不下你那么多的模型权重还有中间变量,然后程序奔溃了。怎么办,其实办法有很多,及时清空中间变量,优化代码,减少batch,等等等等,都能够减少显存溢出的风险。

但是这篇要说的是上面这一切优化操作的基础,如何去计算我们所使用的显存。学会如何计算出来我们设计的模型以及中间变量所占显存的大小,想必知道了这一点,我们对自己显存也就会得心应手了。

如何计算

首先我们应该了解一下基本的数据量信息:

  • 1 G = 1000 MB
  • 1 M = 1000 KB
  • 1 K = 1000 Byte
  • 1 B = 8 bit

好,肯定有人会问为什么是1000而不是1024,这里不过多讨论,只能说两种说法都是正确的,只是应用场景略有不同。这里统一按照上面的标准进行计算。

然后我们说一下我们平常使用的向量所占的空间大小,以Pytorch官方的数据格式为例(所有的深度学习框架数据格式都遵循同一个标准):

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

我们只需要看左边的信息,在平常的训练中,我们经常使用的一般是这两种类型:

  • float32 单精度浮点型
  • int32 整型

一般一个8-bit的整型变量所占的空间为 1B也就是 8bit。而32位的float则占 4B也就是 32bit。而双精度浮点型double和长整型long在平常的训练中我们一般不会使用。

ps:消费级显卡对单精度计算有优化,服务器级别显卡对双精度计算有优化。

也就是说,假设有一幅RGB三通道真彩色图片,长宽分别为500 x 500,数据类型为单精度浮点型,那么这张图所占的显存的大小为:500 x 500 x 3 x 4B = 3M。

而一个(256,3,100,100)-(N,C,H,W)的FloatTensor所占的空间为256 x 3 x 100 x 100 x 4B = 31M

不多是吧,没关系,好戏才刚刚开始。

显存去哪儿了

看起来一张图片(3x256x256)和卷积层(256x100x100)所占的空间并不大,那为什么我们的显存依旧还是用的比较多,原因很简单,占用显存比较多空间的并不是我们输入图像,而是神经网络中的中间变量以及使用optimizer算法时产生的巨量的中间参数。

我们首先来简单计算一下Vgg16这个net需要占用的显存:

通常一个模型占用的显存也就是两部分:

  • 模型自身的参数(params)
  • 模型计算产生的中间变量(memory)

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

图片来自cs231n,这是一个典型的sequential-net,自上而下很顺畅,我们可以看到我们输入的是一张224x224x3的三通道图像,可以看到一张图像只占用 150x4k,但上面是 150k,这是因为这里在计算的时候默认的数据格式是8-bit而不是32-bit,所以最后的结果要乘上一个4。

我们可以看到,左边的memory值代表:图像输入进去,图片以及所产生的中间卷积层所占的空间。我们都知道,这些形形色色的深层卷积层也就是深度神经网络进行”思考”的过程:

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

图片从3通道变为64 –> 128 –> 256 –> 512 …. 这些都是卷积层,而我们的显存也主要是他们占用了。

还有上面右边的params,这些是神经网络的权重大小,可以看到第一层卷积是3×3,而输入图像的通道是3,输出通道是64,所以很显然,第一个卷积层权重所占的空间是 (3 x 3 x 3) x 64。

另外还有一个需要注意的是中间变量在backward的时候会翻倍!

举个例子,下面是一个计算图,输入 x,经过中间结果 z,然后得到最终变量 L

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

我们在backward的时候需要保存下来的中间值。输出是 L,然后输入 x,我们在backward的时候要求 Lx的梯度,这个时候就需要在计算链 Lx中间的 z

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

dz/dx这个中间值当然要保留下来以用于计算,所以粗略估计, backward的时候中间变量的占用了是 forward的两倍!

优化器和动量

要注意,优化器也会占用我们的显存!

为什么,看这个式子:

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

浅谈深度学习:如何计算模型以及中间变量的显存占用大小

当然这只是SGD优化器,其他复杂的优化器如果在计算时需要的中间变量多的时候,就会占用更多的内存。

模型中哪些层会占用显存

有参数的层即会占用显存的层。我们一般的卷积层都会占用显存,而我们经常使用的激活层Relu没有参数就不会占用了。

占用显存的层一般是:

  • 卷积层,通常的conv2d
  • 全连接层,也就是Linear层
  • BatchNorm层
  • Embedding层

而不占用显存的则是:

  • 刚才说到的激活层Relu等
  • 池化层
  • Dropout层

具体计算方式:

  • Conv2d(Cin, Cout, K): 参数数目:Cin × Cout × K × K
  • Linear(M->N): 参数数目:M×N
  • BatchNorm(N): 参数数目: 2N
  • Embedding(N,W): 参数数目: N × W

额外的显存

总结一下,我们在总体的训练中,占用显存大概分以下几类:

  • 模型中的参数(卷积层或其他有参数的层)
  • 模型在计算时产生的中间参数(也就是输入图像在计算时每一层产生的输入和输出)
  • backward的时候产生的额外的中间参数
  • 优化器在优化时产生的额外的模型参数

但其实,我们占用的显存空间为什么比我们理论计算的还要大,原因大概是因为深度学习框架一些额外的开销吧,不过如果通过上面公式,理论计算出来的显存和实际不会差太多的。

如何优化

优化除了算法层的优化,最基本的优化无非也就一下几点:

  • 减少输入图像的尺寸
  • 减少batch,减少每次的输入图像数量
  • 多使用下采样,池化层
  • 一些神经网络层可以进行小优化,利用relu层中设置 inplace
  • 购买显存更大的显卡
  • 从深度学习框架上面进行优化

撩我吧

  • 如果你与我志同道合于此,老潘很愿意与你交流;
  • 如果你喜欢老潘的内容,欢迎关注和支持。
  • 如果你喜欢我的文章,希望点赞👍 收藏 📁 评论 💬 三连一下~

想知道老潘是如何学习踩坑的,想与我交流问题~请关注公众号「oldpan博客」。
老潘也会整理一些自己的私藏,希望能帮助到大家,点击神秘传送门获取。

Original: https://www.cnblogs.com/bigoldpan/p/14458169.html
Author: 老潘的博客
Title: 浅谈深度学习:如何计算模型以及中间变量的显存占用大小

原创文章受到原创版权保护。转载请注明出处:https://www.johngo689.com/685763/

转载文章受原作者版权保护。转载请注明原作者出处!

(0)

大家都在看

  • Seata启动seata-server.bat闪退

    解决方法:创建logs/seata_gc.log 文件夹及文件 Original: https://www.cnblogs.com/mingforyou/p/15742889.ht…

    技术杂谈 2023年5月31日
    0110
  • 基于阿里开源的COLA架构和DDD领域驱动设计构建货物运输系统

    COLA 是 Clean Object-Oriented and Layered Architecture的缩写,代表”整洁面向对象分层架构”,是来自阿里技…

    技术杂谈 2023年6月1日
    097
  • FPGA学习-2,一点理解

    1、Wire只能赋一次值,Reg可以多次改变2、#100这种是在仿真系统下有效。3、同一个文件下也可以写多个module. 本博客是个人工作中记录,遇到问题可以互相探讨,没有遇到的…

    技术杂谈 2023年6月1日
    087
  • 操作系统复习错题集合

    操作系统复习错题集合 ​ 主要记一下这个写操作,是增删目录中的目录项 ​ 文件有逻辑结构和物理结构,逻辑结构有流式和记录式,物理结构有顺序式、索引式、链接式 UNIX题目一概背记。…

    技术杂谈 2023年7月11日
    070
  • lambda表达式常用00

    交集 并集 差集 List集合的过滤之lambda表达式 lambda表达式将List对象某个字段转换以逗号分隔的String类型 Original: https://www.cn…

    技术杂谈 2023年7月24日
    080
  • Qt error: ‘class Ui::XXXXX‘ has no member named ‘XXXXX‘

    这个原因是因为 设计界面对应的 ui_xx.h文件未更新造成的(原因:比如我们工程从一台机器复制到另一台机器,有可能造成该文件不再更新了)(在我们的main.cpp同级目录那个ui…

    技术杂谈 2023年5月31日
    0141
  • 测试执行和软件缺陷

    测试执行 1.基本概念 测试执行就是执行测试用例、提交Bug 单、测试结论的评估和总结等一系列测试活动,测试执行不仅包含测试用例的执行,还包括其它测试活动. 2.注意事项 (1) …

    技术杂谈 2023年7月25日
    079
  • AndroidC/C++层hook和java层hook原理以及比较

    作者:Denny Qiao(乔喜铭),云智慧/架构师。 云智慧集团成立于2009年,是全栈智能业务运维解决方案服务商。经过多年自主研发,公司形成了从IT运维、电力运维到IoT运维的…

    技术杂谈 2023年7月24日
    076
  • k8s整合Traefik2入门(一)

    k8s整合Traefik入门(一) 安装 首先下载helm,根据自己的k8s版本来选择相应的版本 [root@k8s-master1 ~]# tar -zvxf helm-v3.6…

    技术杂谈 2023年5月31日
    0114
  • Windows下USB磁盘开发系列二:枚举系统中所有USB设备

    上篇 《Windows下USB磁盘开发系列一:枚举系统中U盘的盘符》介绍了很简单的获取系统U盘盘符的办法,现在介绍下如何枚举系统中所有USB设备(不光是U盘)。 主要调用的API如…

    技术杂谈 2023年5月31日
    084
  • quartz框架(五)-Trigger相关内容

    上篇博文,博主介绍了Job的相关内容。本篇博文,博主将介绍Trigger相关的内容。 Trigger是触发器的意思,它只定义Trigger相关属性的Get方法。一个Trigger只…

    技术杂谈 2023年7月24日
    073
  • 读经典【1】重构:改善既有代码的设计

    五星好评。很实用。 最近读了重构原版书,同时也在使用其中的一些技巧来改善工作中的项目,自己改完代码会有成就感。 这本书改变了我原有的思想钢印:代码能成功跑起来就不要去动它。实际上,…

    技术杂谈 2023年7月25日
    089
  • C++ 回调函数及 std::function 与 std::bind

    回调函数是做为参数传递的一种函数,在早期C样式编程当中,回调函数必须依赖函数指针来实现。 而后的C++语言当中,又引入了 std::function 与 std::bind 来配合…

    技术杂谈 2023年6月21日
    0109
  • 71.底细

    dfs posted @2022-09-28 08:47 随遇而安== 阅读(6 ) 评论() 编辑 Original: https://www.cnblogs.com/55zjc…

    技术杂谈 2023年6月21日
    0105
  • 使用ssl_exporter监控K8S集群证书

    使用kubeadm搭建的集群默认证书有效期是1年,续费证书其实是一件很快的事情。但是就怕出事了才发现,毕竟作为专业搬砖工程师,每天都很忙的。 鉴于此,监控集群证书有效期是一件不得不…

    技术杂谈 2023年6月1日
    089
  • Sonarqube安装(Docker)

    一,拉取相关镜像并运行 拉取sonarqube镜像 docker pull sonarqube:9.1.0-community 在运行之前要提前安装postgres并允许,新建数据…

    技术杂谈 2023年7月10日
    082
亲爱的 Coder【最近整理,可免费获取】👉 最新必读书单  | 👏 面试题下载  | 🌎 免费的AI知识星球