transformer系列5---transformer显存占用分析
Transformer显存占用分析
- 1 影响因素概述
- 2 前向计算临时Tensor显存占用
- 2.1 self-attention显存占用
- 2.2 MLP显存占用
- 3 梯度和优化器显存占用
- 3.1 模型训练过程两者显存占用
- 3.2 模型推理过程两者显存占用
1 影响因素概述
- 模型训练框架:例如pytorch框架的cuda context会占用大约几百MB显存,与版本有关;
- 模型参数大小,比如7B的模型以FP16格式要占用14GB显存;
- 前向计算过程中产生的临时Tensor:这部分Tensor需要被临时保存,以便在反向传播计算梯度时使用
- 反向传播计算得到的梯度:
- 优化器状态:全量微调的情况下,梯度与参数一样大,普通SGD没有动量,一阶动量优化器的自身参数大小与模型大小一样,比如momentum-SGD,二阶动量优化器一般为模型大小的两倍,比如Adam, transformer系列的大模型最常用的是Adam优化器
2 前向计算临时Tensor显存占用
2.1 self-attention显存占用
这部分Tensor的大小和模型的每一层结构形状有关(必须根据具体模型的每层形状来计算)也和具体的batch_size大小以及输入数据input_data的大小有关。
- 输入矩阵I:首先计算 Q = I ∗ W q Q =I * W^{q} Q=I∗Wq, K = I ∗ W k K = I * W^{k} K=I∗Wk, V = I ∗ W v V = I * W^{v} V=I∗Wv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
- Q K T QK^{T} QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
- softmax: A = Q K T A=QK^{T} A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes* b h s 2 bhs^{2} bhs2= 2 b h s 2 2bhs^{2} 2bhs2 bytes。
- dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。
- score* V加权:score矩阵的形状与A相同,占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd bytes。
- W O W^{O} WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
- dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
综上步骤,self-attention块的占用显存大小为2bsd+4bsd+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd+2bsd+2bsd=11bsd+ 5 b h s 2 5bhs^{2} 5bhs2
2.2 MLP显存占用
- 第一个线性层需要保存其输入,输入形状为[b, s, d],占用显存大小为 2bytes*bsd=2bsd bytes。
- 激活函数需要保存其输入,为第一步的输出形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
- 第二个线性层需要保存其输入,输入形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
- 最后有一个dropout操作,需要保存mask矩阵,形状是上一步的输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
综上步骤,MLP的占用显存大小为2bsd+8bsd+8bsd+bsd=19bsd.
3 梯度和优化器显存占用
3.1 模型训练过程两者显存占用
参数占用显存 = 参数数目 × n
n = 2 : float16
n = 4 : float32
n = 8 : double64
其中,float32是最常用的类型,n是数据类型占用的bytes。
训练过程通常为模型参数前向传播,反向传播计算梯度,优化器更新,以Adam优化器为例分析,假如模型参数量为P:
- 混合精度训练:
1)使用float16的模型参数进行前向传递和反向传播,计算得到float16的梯度;
2)在优化器更新模型参数时,使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
3)对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是2bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是2bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。
每个参数占用(2+4)+(2+4)+8 = 20bytes。模型参数量M时总计20P bytes。
- 普通训练:
上述步骤1)2)均使用float32类型。对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是4bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是4bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。
每个参数占用(4+4)+(4+4)+8 = 24bytes,模型参数量M时总计24P bytes。
3.2 模型推理过程两者显存占用
推理占用显存主要是模型参数,假如模型参数量为P,使用float16来进行推理,推理阶段模型参数占用的显存约2P bytes,使用float32来进行推理,推理阶段模型参数占用的显存约 4P bytes。
参考文章:https://zhuanlan.zhihu.com/p/624740065?utm_id=0
相关文章:
transformer系列5---transformer显存占用分析
Transformer显存占用分析 1 影响因素概述2 前向计算临时Tensor显存占用2.1 self-attention显存占用2.2 MLP显存占用 3 梯度和优化器显存占用3.1 模型训练过程两者显存占用3.2 模型推理过程两者显存占用 1 影响因素概述 模型训练框架:例如pytorch框架的cuda context…...
Docker项目部署
目录 一、前端项目部署 1、上传文件 2、开启容器 3、测试 二、后端项目部署 1、打包java项目 2、将jar包和Dockerfile文件长传到Linux系统 3、构建镜像 4、开启容器 5、测试 三、DockerCompose快速部署 基本语法 一、前端项目部署 1、上传文件 里面包括页面和配置文…...
vue3实现文本超出鼠标移入的时候文本滚动
判断文本长度是否大于容器长度 鼠标移入的时候判断,此处使用了tailwindcss,注意一下要设置文本不换行。 <divref"functionsItems"mouseenter"enterFunctionsItem($event, index)"><img class"w-5 h-5" :src&quo…...
光伏系统MPPT、恒功率控制切换Simulink仿真
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...
mysql双主互从通过KeepAlived虚拟IP实现高可用
mysql双主互从通过KeepAlived虚拟IP实现高可用 在mysql 双主互从的基础上, 架构图: Keepalived有两个主要的功能: 提供虚拟IP,实现双机热备通过LVS,实现负载均衡 安装 # 安装 yum -y install keepalived # 卸载 …...
苹果应用高版本出现:“无法安装此app,因为无法验证其完整性”是怎么回事?竟然是错误的?
最近经常有同学私聊我问苹果应用签名后用落地页下载出现高版本是什么意思?我一脸懵!还有这个操作?高版本是个啥玩意!所以我就上了一下科技去搜索引擎搜索了下,哈哈哈,然后了解下来发现是这样的首先我们确定…...
AF_UNIX和127.0.0.1(AF_INET)回环地址写数据速度对比
在linux下,存在着这样的情况,本地的进程间通信,并且其中一个是服务端,另外的都是客户端。 服务端通过绑定端口,客户端往127.0.0.1的对应端口发送,即可办到,不过这样会浪费一个端口,同…...
我在 NPM 发布了新包: con-colors
链接地址:npmjs.com con-colors 安装依赖 yarn add con-colors使用 导入: import { print } from "con-colors";使用: print.succ("成功的消息"); print.err("失败的消息")例子: import { p…...
【python数据建模】Scipy库
常用模块列表 模块名功能scipy.constants数学常量scipy.fft离散傅里叶变换scipy.integrate积分scipy.interpolate插值scipy.interpolate线性代数scipy.cluster聚类分析、向量量化scipy.io数据输入输出scipy.misc图像处理scipy.ndimagen维图像scipy.odr正交距离回归scipy.optim…...
C# App.xaml.cs的一些操作
一、保证只有一个进程 1.1 关闭旧的,打开新的 protected override void OnStartup(StartupEventArgs e) {base.OnStartup(e);var process Process.GetProcessesByName("Dog");if (process.Count() > 1) {var list process.ToList();list.Sort((p1,p2…...
【ORACLE】ORA-00972:标识符过长
问题 执行创建表结构sql,提示 ORA-00972:标识符过长; 如图所示,约束名称超过30个字符了 原因 一、11G and before 在使用11G数据库时,经常会遇到报错ORA-00972,原因是因为对象名称定义太长,…...
【Vue】Vue快速入门、Vue常用指令、Vue的生命周期
🐌个人主页: 🐌 叶落闲庭 💨我的专栏:💨 c语言 数据结构 javaEE 操作系统 Redis 石可破也,而不可夺坚;丹可磨也,而不可夺赤。 Vue 一、 Vue快速入门二、Vue常用指令2.1 v…...
Pandas 数据处理 类别数据和数值数据
要是作深度学习的话,可以直接用tensoflow框架的预处理层,我试过,比PyTorch自己写出来的会好一点,主要是简单好用。处理CSV文件 它类别的处理逻辑是onehot,比较标准稀疏,数值的话就是归一化了。 有时候不需…...
Android攻城狮学鸿蒙 -- 点击事件
具体参考:华为官网学习地址 1、点击事件,界面跳转 对于一个按钮设置点击事件,跳转页面。但是onclick中,如果pages前边加上“/”,就没法跳转。但是开发工具加上“/”才会给出提示。不知道是不是开发工具的bug。&#…...
jmeter性能测试常见的一些问题
一、request 请求超时设置 timeout 超时时间是可以手动设置的,新建一个 http 请求,在“高级”设置中找到“超时”设置,设置连接、响应时间为2000ms。 1. 请求连接超时,连不上服务器。 现象: Jmeter表现形式为ÿ…...
利用国外 vps 为 switch 设置代理服务器加速游戏下载
switch 在国内通过 wifi 连网后如果直接下载游戏的话速度特别慢,据说要挂一个晚上才能下载成功一个游戏。当我尝试下载时发现进度条基本不动,怀疑软件源是在国外的原因,于是想到可以通过国外 vps 代理中转的方式。具体步骤如下(以…...
云计算安全的新挑战:零信任架构的应用
文章目录 云计算的安全挑战什么是零信任架构?零信任架构的应用1. 多因素身份验证(MFA)2. 访问控制和策略3. 安全信息和事件管理(SIEM)4. 安全的应用程序开发 零信任架构的未来 🎉欢迎来到云计算技术应用专栏…...
基于SSM的药房药品采购集中管理系统的设计与实现
末尾获取源码 开发语言:Java Java开发工具:JDK1.8 后端框架:SSM 前端:采用Vue技术开发 数据库:MySQL5.7和Navicat管理工具结合 服务器:Tomcat8.5 开发软件:IDEA / Eclipse 是否Maven项目&#x…...
【GIT版本控制】--远程仓库
一、连接远程仓库 连接到远程仓库是在GIT中进行协作和备份的关键步骤。以下是连接到远程仓库的基本步骤: 获取远程仓库的URL:首先,你需要获得远程仓库的URL。通常,这是远程仓库提供给你的,可以是HTTPS或SSH URL。例如…...
1:Allotment,2:FeeSell,3:混合Allotment+FreeSell
根据您的描述,这似乎是与酒店预订相关的三种不同的方式。下面是对这三种方式的解释: Allotment(配额):这是一种酒店预订方式,其中您可以与酒店签订协议,并购买其一定数量的房间或床位。在此之后…...
Python|GIF 解析与构建(5):手搓截屏和帧率控制
目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...
使用VSCode开发Django指南
使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架,专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用,其中包含三个使用通用基本模板的页面。在此…...
微软PowerBI考试 PL300-选择 Power BI 模型框架【附练习数据】
微软PowerBI考试 PL300-选择 Power BI 模型框架 20 多年来,Microsoft 持续对企业商业智能 (BI) 进行大量投资。 Azure Analysis Services (AAS) 和 SQL Server Analysis Services (SSAS) 基于无数企业使用的成熟的 BI 数据建模技术。 同样的技术也是 Power BI 数据…...
将对透视变换后的图像使用Otsu进行阈值化,来分离黑色和白色像素。这句话中的Otsu是什么意思?
Otsu 是一种自动阈值化方法,用于将图像分割为前景和背景。它通过最小化图像的类内方差或等价地最大化类间方差来选择最佳阈值。这种方法特别适用于图像的二值化处理,能够自动确定一个阈值,将图像中的像素分为黑色和白色两类。 Otsu 方法的原…...
ETLCloud可能遇到的问题有哪些?常见坑位解析
数据集成平台ETLCloud,主要用于支持数据的抽取(Extract)、转换(Transform)和加载(Load)过程。提供了一个简洁直观的界面,以便用户可以在不同的数据源之间轻松地进行数据迁移和转换。…...
GitHub 趋势日报 (2025年06月08日)
📊 由 TrendForge 系统生成 | 🌐 https://trendforge.devlive.org/ 🌐 本日报中的项目描述已自动翻译为中文 📈 今日获星趋势图 今日获星趋势图 884 cognee 566 dify 414 HumanSystemOptimization 414 omni-tools 321 note-gen …...
NLP学习路线图(二十三):长短期记忆网络(LSTM)
在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...
Rapidio门铃消息FIFO溢出机制
关于RapidIO门铃消息FIFO的溢出机制及其与中断抖动的关系,以下是深入解析: 门铃FIFO溢出的本质 在RapidIO系统中,门铃消息FIFO是硬件控制器内部的缓冲区,用于临时存储接收到的门铃消息(Doorbell Message)。…...
Web 架构之 CDN 加速原理与落地实践
文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 …...
如何在网页里填写 PDF 表格?
有时候,你可能希望用户能在你的网站上填写 PDF 表单。然而,这件事并不简单,因为 PDF 并不是一种原生的网页格式。虽然浏览器可以显示 PDF 文件,但原生并不支持编辑或填写它们。更糟的是,如果你想收集表单数据ÿ…...
