transformer系列5---transformer显存占用分析
Transformer显存占用分析
- 1 影响因素概述
- 2 前向计算临时Tensor显存占用
- 2.1 self-attention显存占用
- 2.2 MLP显存占用
- 3 梯度和优化器显存占用
- 3.1 模型训练过程两者显存占用
- 3.2 模型推理过程两者显存占用
1 影响因素概述
- 模型训练框架:例如pytorch框架的cuda context会占用大约几百MB显存,与版本有关;
- 模型参数大小,比如7B的模型以FP16格式要占用14GB显存;
- 前向计算过程中产生的临时Tensor:这部分Tensor需要被临时保存,以便在反向传播计算梯度时使用
- 反向传播计算得到的梯度:
- 优化器状态:全量微调的情况下,梯度与参数一样大,普通SGD没有动量,一阶动量优化器的自身参数大小与模型大小一样,比如momentum-SGD,二阶动量优化器一般为模型大小的两倍,比如Adam, transformer系列的大模型最常用的是Adam优化器
2 前向计算临时Tensor显存占用
2.1 self-attention显存占用
这部分Tensor的大小和模型的每一层结构形状有关(必须根据具体模型的每层形状来计算)也和具体的batch_size大小以及输入数据input_data的大小有关。
- 输入矩阵I:首先计算 Q = I ∗ W q Q =I * W^{q} Q=I∗Wq, K = I ∗ W k K = I * W^{k} K=I∗Wk, V = I ∗ W v V = I * W^{v} V=I∗Wv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
- Q K T QK^{T} QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
- softmax: A = Q K T A=QK^{T} A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes* b h s 2 bhs^{2} bhs2= 2 b h s 2 2bhs^{2} 2bhs2 bytes。
- dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。
- score* V加权:score矩阵的形状与A相同,占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd bytes。
- W O W^{O} WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
- dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
综上步骤,self-attention块的占用显存大小为2bsd+4bsd+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd+2bsd+2bsd=11bsd+ 5 b h s 2 5bhs^{2} 5bhs2
2.2 MLP显存占用
- 第一个线性层需要保存其输入,输入形状为[b, s, d],占用显存大小为 2bytes*bsd=2bsd bytes。
- 激活函数需要保存其输入,为第一步的输出形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
- 第二个线性层需要保存其输入,输入形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
- 最后有一个dropout操作,需要保存mask矩阵,形状是上一步的输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
综上步骤,MLP的占用显存大小为2bsd+8bsd+8bsd+bsd=19bsd.
3 梯度和优化器显存占用
3.1 模型训练过程两者显存占用
参数占用显存 = 参数数目 × n
n = 2 : float16
n = 4 : float32
n = 8 : double64
其中,float32是最常用的类型,n是数据类型占用的bytes。
训练过程通常为模型参数前向传播,反向传播计算梯度,优化器更新,以Adam优化器为例分析,假如模型参数量为P:
- 混合精度训练:
1)使用float16的模型参数进行前向传递和反向传播,计算得到float16的梯度;
2)在优化器更新模型参数时,使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
3)对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是2bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是2bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。
每个参数占用(2+4)+(2+4)+8 = 20bytes。模型参数量M时总计20P bytes。
- 普通训练:
上述步骤1)2)均使用float32类型。对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是4bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是4bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。
每个参数占用(4+4)+(4+4)+8 = 24bytes,模型参数量M时总计24P bytes。
3.2 模型推理过程两者显存占用
推理占用显存主要是模型参数,假如模型参数量为P,使用float16来进行推理,推理阶段模型参数占用的显存约2P bytes,使用float32来进行推理,推理阶段模型参数占用的显存约 4P bytes。
参考文章:https://zhuanlan.zhihu.com/p/624740065?utm_id=0
相关文章:
transformer系列5---transformer显存占用分析
Transformer显存占用分析 1 影响因素概述2 前向计算临时Tensor显存占用2.1 self-attention显存占用2.2 MLP显存占用 3 梯度和优化器显存占用3.1 模型训练过程两者显存占用3.2 模型推理过程两者显存占用 1 影响因素概述 模型训练框架:例如pytorch框架的cuda context…...

Docker项目部署
目录 一、前端项目部署 1、上传文件 2、开启容器 3、测试 二、后端项目部署 1、打包java项目 2、将jar包和Dockerfile文件长传到Linux系统 3、构建镜像 4、开启容器 5、测试 三、DockerCompose快速部署 基本语法 一、前端项目部署 1、上传文件 里面包括页面和配置文…...
vue3实现文本超出鼠标移入的时候文本滚动
判断文本长度是否大于容器长度 鼠标移入的时候判断,此处使用了tailwindcss,注意一下要设置文本不换行。 <divref"functionsItems"mouseenter"enterFunctionsItem($event, index)"><img class"w-5 h-5" :src&quo…...

光伏系统MPPT、恒功率控制切换Simulink仿真
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

mysql双主互从通过KeepAlived虚拟IP实现高可用
mysql双主互从通过KeepAlived虚拟IP实现高可用 在mysql 双主互从的基础上, 架构图: Keepalived有两个主要的功能: 提供虚拟IP,实现双机热备通过LVS,实现负载均衡 安装 # 安装 yum -y install keepalived # 卸载 …...

苹果应用高版本出现:“无法安装此app,因为无法验证其完整性”是怎么回事?竟然是错误的?
最近经常有同学私聊我问苹果应用签名后用落地页下载出现高版本是什么意思?我一脸懵!还有这个操作?高版本是个啥玩意!所以我就上了一下科技去搜索引擎搜索了下,哈哈哈,然后了解下来发现是这样的首先我们确定…...
AF_UNIX和127.0.0.1(AF_INET)回环地址写数据速度对比
在linux下,存在着这样的情况,本地的进程间通信,并且其中一个是服务端,另外的都是客户端。 服务端通过绑定端口,客户端往127.0.0.1的对应端口发送,即可办到,不过这样会浪费一个端口,同…...

我在 NPM 发布了新包: con-colors
链接地址:npmjs.com con-colors 安装依赖 yarn add con-colors使用 导入: import { print } from "con-colors";使用: print.succ("成功的消息"); print.err("失败的消息")例子: import { p…...
【python数据建模】Scipy库
常用模块列表 模块名功能scipy.constants数学常量scipy.fft离散傅里叶变换scipy.integrate积分scipy.interpolate插值scipy.interpolate线性代数scipy.cluster聚类分析、向量量化scipy.io数据输入输出scipy.misc图像处理scipy.ndimagen维图像scipy.odr正交距离回归scipy.optim…...
C# App.xaml.cs的一些操作
一、保证只有一个进程 1.1 关闭旧的,打开新的 protected override void OnStartup(StartupEventArgs e) {base.OnStartup(e);var process Process.GetProcessesByName("Dog");if (process.Count() > 1) {var list process.ToList();list.Sort((p1,p2…...

【ORACLE】ORA-00972:标识符过长
问题 执行创建表结构sql,提示 ORA-00972:标识符过长; 如图所示,约束名称超过30个字符了 原因 一、11G and before 在使用11G数据库时,经常会遇到报错ORA-00972,原因是因为对象名称定义太长,…...

【Vue】Vue快速入门、Vue常用指令、Vue的生命周期
🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaEE 操作系统 Redis 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Vue 一、 Vue快速入门二、Vue常用指令2.1 v…...
Pandas 数据处理 类别数据和数值数据
要是作深度学习的话,可以直接用tensoflow框架的预处理层,我试过,比PyTorch自己写出来的会好一点,主要是简单好用。处理CSV文件 它类别的处理逻辑是onehot,比较标准稀疏,数值的话就是归一化了。 有时候不需…...

Android攻城狮学鸿蒙 -- 点击事件
具体参考:华为官网学习地址 1、点击事件,界面跳转 对于一个按钮设置点击事件,跳转页面。但是onclick中,如果pages前边加上“/”,就没法跳转。但是开发工具加上“/”才会给出提示。不知道是不是开发工具的bug。&#…...

jmeter性能测试常见的一些问题
一、request 请求超时设置 timeout 超时时间是可以手动设置的,新建一个 http 请求,在“高级”设置中找到“超时”设置,设置连接、响应时间为2000ms。 1. 请求连接超时,连不上服务器。 现象: Jmeter表现形式为ÿ…...
利用国外 vps 为 switch 设置代理服务器加速游戏下载
switch 在国内通过 wifi 连网后如果直接下载游戏的话速度特别慢,据说要挂一个晚上才能下载成功一个游戏。当我尝试下载时发现进度条基本不动,怀疑软件源是在国外的原因,于是想到可以通过国外 vps 代理中转的方式。具体步骤如下(以…...

云计算安全的新挑战:零信任架构的应用
文章目录 云计算的安全挑战什么是零信任架构?零信任架构的应用1. 多因素身份验证(MFA)2. 访问控制和策略3. 安全信息和事件管理(SIEM)4. 安全的应用程序开发 零信任架构的未来 🎉欢迎来到云计算技术应用专栏…...

基于SSM的药房药品采购集中管理系统的设计与实现
末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…...
【GIT版本控制】--远程仓库
一、连接远程仓库 连接到远程仓库是在GIT中进行协作和备份的关键步骤。以下是连接到远程仓库的基本步骤: 获取远程仓库的URL:首先,你需要获得远程仓库的URL。通常,这是远程仓库提供给你的,可以是HTTPS或SSH URL。例如…...
1:Allotment,2:FeeSell,3:混合Allotment+FreeSell
根据您的描述,这似乎是与酒店预订相关的三种不同的方式。下面是对这三种方式的解释: Allotment(配额):这是一种酒店预订方式,其中您可以与酒店签订协议,并购买其一定数量的房间或床位。在此之后…...

什么是库存周转?如何用进销存系统提高库存周转率?
你可能听说过这样一句话: “利润不是赚出来的,是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业,很多企业看着销售不错,账上却没钱、利润也不见了,一翻库存才发现: 一堆卖不动的旧货…...
linux 错误码总结
1,错误码的概念与作用 在Linux系统中,错误码是系统调用或库函数在执行失败时返回的特定数值,用于指示具体的错误类型。这些错误码通过全局变量errno来存储和传递,errno由操作系统维护,保存最近一次发生的错误信息。值得注意的是,errno的值在每次系统调用或函数调用失败时…...
第25节 Node.js 断言测试
Node.js的assert模块主要用于编写程序的单元测试时使用,通过断言可以提早发现和排查出错误。 稳定性: 5 - 锁定 这个模块可用于应用的单元测试,通过 require(assert) 可以使用这个模块。 assert.fail(actual, expected, message, operator) 使用参数…...

如何将联系人从 iPhone 转移到 Android
从 iPhone 换到 Android 手机时,你可能需要保留重要的数据,例如通讯录。好在,将通讯录从 iPhone 转移到 Android 手机非常简单,你可以从本文中学习 6 种可靠的方法,确保随时保持连接,不错过任何信息。 第 1…...

BCS 2025|百度副总裁陈洋:智能体在安全领域的应用实践
6月5日,2025全球数字经济大会数字安全主论坛暨北京网络安全大会在国家会议中心隆重开幕。百度副总裁陈洋受邀出席,并作《智能体在安全领域的应用实践》主题演讲,分享了在智能体在安全领域的突破性实践。他指出,百度通过将安全能力…...
Python如何给视频添加音频和字幕
在Python中,给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加,包括必要的代码示例和详细解释。 环境准备 在开始之前,需要安装以下Python库:…...

ardupilot 开发环境eclipse 中import 缺少C++
目录 文章目录 目录摘要1.修复过程摘要 本节主要解决ardupilot 开发环境eclipse 中import 缺少C++,无法导入ardupilot代码,会引起查看不方便的问题。如下图所示 1.修复过程 0.安装ubuntu 软件中自带的eclipse 1.打开eclipse—Help—install new software 2.在 Work with中…...

Android15默认授权浮窗权限
我们经常有那种需求,客户需要定制的apk集成在ROM中,并且默认授予其【显示在其他应用的上层】权限,也就是我们常说的浮窗权限,那么我们就可以通过以下方法在wms、ams等系统服务的systemReady()方法中调用即可实现预置应用默认授权浮…...
Rapidio门铃消息FIFO溢出机制
关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...
基于matlab策略迭代和值迭代法的动态规划
经典的基于策略迭代和值迭代法的动态规划matlab代码,实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...