当前位置: 首页 > news >正文

transformer系列5---transformer显存占用分析

Transformer显存占用分析

  • 1 影响因素概述
  • 2 前向计算临时Tensor显存占用
    • 2.1 self-attention显存占用
    • 2.2 MLP显存占用
  • 3 梯度和优化器显存占用
    • 3.1 模型训练过程两者显存占用
    • 3.2 模型推理过程两者显存占用

1 影响因素概述

  1. 模型训练框架:例如pytorch框架的cuda context会占用大约几百MB显存,与版本有关;
  2. 模型参数大小,比如7B的模型以FP16格式要占用14GB显存;
  3. 前向计算过程中产生的临时Tensor:这部分Tensor需要被临时保存,以便在反向传播计算梯度时使用
  4. 反向传播计算得到的梯度:
  5. 优化器状态:全量微调的情况下,梯度与参数一样大,普通SGD没有动量,一阶动量优化器的自身参数大小与模型大小一样,比如momentum-SGD,二阶动量优化器一般为模型大小的两倍,比如Adam, transformer系列的大模型最常用的是Adam优化器

2 前向计算临时Tensor显存占用

2.1 self-attention显存占用

这部分Tensor的大小和模型的每一层结构形状有关(必须根据具体模型的每层形状来计算)也和具体的batch_size大小以及输入数据input_data的大小有关。

  1. 输入矩阵I:首先计算 Q = I ∗ W q Q =I * W^{q} Q=IWq K = I ∗ W k K = I * W^{k} K=IWk V = I ∗ W v V = I * W^{v} V=IWv,输入I是临时Tensor,假设输入I的形状为 [b, s, d],元素个数为 bsd,占用显存大小为2bytes*bsd=2bsd bytes.
  2. Q K T QK^{T} QKT:Q和K是临时Tensor,假设形状为 [b, s, d],元素个数为 bsd,占用显存大小为22bytesbsd=4bsd bytes。
  3. softmax: A = Q K T A=QK^{T} A=QKT,输入形状[b, h, s, d] × [b, h, s, d],A矩阵输出形状为 [b, h, s, s],h是头个数。保存A矩阵占用的显存大小为=2bytes* b h s 2 bhs^{2} bhs2= 2 b h s 2 2bhs^{2} 2bhs2 bytes。
  4. dropout:需要保存一个mask矩阵,mask矩阵的形状与A相同,mask矩阵的元素为0或1,用1个byte表示,占用显存大小为 b h s 2 bhs^{2} bhs2 bytes。
  5. score* V加权:score矩阵的形状与A相同,占用显存大小为 2 b h s 2 2bhs^{2} 2bhs2 bytes。V矩阵形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。该步骤占用显存大小为 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd bytes。
  6. W O W^{O} WO输出映射:需要临时保存输入矩阵,形状[b, s, d],占用显存大小为2bytes*bsd=2bsd bytes。
  7. dropout:需要保存一个mask矩阵,mask矩阵的形状为上一步输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。
    综上步骤,self-attention块的占用显存大小为2bsd+4bsd+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 2bhs^{2} 2bhs2+ 2 b h s 2 + 2 b s d 2bhs^{2}+2bsd 2bhs2+2bsd+2bsd+2bsd=11bsd+ 5 b h s 2 5bhs^{2} 5bhs2

2.2 MLP显存占用

  1. 第一个线性层需要保存其输入,输入形状为[b, s, d],占用显存大小为 2bytes*bsd=2bsd bytes。
  2. 激活函数需要保存其输入,为第一步的输出形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  3. 第二个线性层需要保存其输入,输入形状为[b, s, 4d],占用显存大小为2bytes*4bsd=8bsd bytes。
  4. 最后有一个dropout操作,需要保存mask矩阵,形状是上一步的输出形状[b, s, d],mask矩阵的元素为0或1,用1个byte表示,占用显存大小为1bytes*bsd=bsd bytes。

综上步骤,MLP的占用显存大小为2bsd+8bsd+8bsd+bsd=19bsd.

3 梯度和优化器显存占用

3.1 模型训练过程两者显存占用

参数占用显存 = 参数数目 × n
n = 2 : float16
n = 4 : float32
n = 8 : double64
其中,float32是最常用的类型,n是数据类型占用的bytes。
训练过程通常为模型参数前向传播,反向传播计算梯度,优化器更新,以Adam优化器为例分析,假如模型参数量为P:

  1. 混合精度训练:
    1)使用float16的模型参数进行前向传递和反向传播,计算得到float16的梯度;
    2)在优化器更新模型参数时,使用float32的优化器状态、float32的梯度、float32的模型参数来更新模型参数。
    3)对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是2bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是2bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(2+4)+(2+4)+8 = 20bytes。模型参数量M时总计20P bytes。

  1. 普通训练:
    上述步骤1)2)均使用float32类型。对于每个可训练模型参数,模型参数在步骤1)和步骤2)分别是4bytes,4bytes;梯度在步骤1)和步骤2)分别是分别是4bytes,4bytes;优化器状态是2* 模型大小=2*4bytes=8bytes。

每个参数占用(4+4)+(4+4)+8 = 24bytes,模型参数量M时总计24P bytes。

3.2 模型推理过程两者显存占用

推理占用显存主要是模型参数,假如模型参数量为P,使用float16来进行推理,推理阶段模型参数占用的显存约2P bytes,使用float32来进行推理,推理阶段模型参数占用的显存约 4P bytes。

参考文章:https://zhuanlan.zhihu.com/p/624740065?utm_id=0

相关文章:

transformer系列5---transformer显存占用分析

Transformer显存占用分析 1 影响因素概述2 前向计算临时Tensor显存占用2.1 self-attention显存占用2.2 MLP显存占用 3 梯度和优化器显存占用3.1 模型训练过程两者显存占用3.2 模型推理过程两者显存占用 1 影响因素概述 模型训练框架:例如pytorch框架的cuda context…...

Docker项目部署

目录 一、前端项目部署 1、上传文件 2、开启容器 3、测试 二、后端项目部署 1、打包java项目 2、将jar包和Dockerfile文件长传到Linux系统 3、构建镜像 4、开启容器 5、测试 三、DockerCompose快速部署 基本语法 一、前端项目部署 1、上传文件 里面包括页面和配置文…...

vue3实现文本超出鼠标移入的时候文本滚动

判断文本长度是否大于容器长度 鼠标移入的时候判断&#xff0c;此处使用了tailwindcss&#xff0c;注意一下要设置文本不换行。 <divref"functionsItems"mouseenter"enterFunctionsItem($event, index)"><img class"w-5 h-5" :src&quo…...

光伏系统MPPT、恒功率控制切换Simulink仿真

&#x1f4a5;&#x1f4a5;&#x1f49e;&#x1f49e;欢迎来到本博客❤️❤️&#x1f4a5;&#x1f4a5; &#x1f3c6;博主优势&#xff1a;&#x1f31e;&#x1f31e;&#x1f31e;博客内容尽量做到思维缜密&#xff0c;逻辑清晰&#xff0c;为了方便读者。 ⛳️座右铭&a…...

mysql双主互从通过KeepAlived虚拟IP实现高可用

mysql双主互从通过KeepAlived虚拟IP实现高可用 在mysql 双主互从的基础上&#xff0c; 架构图&#xff1a; Keepalived有两个主要的功能&#xff1a; 提供虚拟IP&#xff0c;实现双机热备通过LVS&#xff0c;实现负载均衡 安装 # 安装 yum -y install keepalived # 卸载 …...

​苹果应用高版本出现:“无法安装此app,因为无法验证其完整性”是怎么回事?竟然是错误的?

最近经常有同学私聊我问苹果应用签名后用落地页下载出现高版本是什么意思&#xff1f;我一脸懵&#xff01;还有这个操作&#xff1f;高版本是个啥玩意&#xff01;所以我就上了一下科技去搜索引擎搜索了下&#xff0c;哈哈哈&#xff0c;然后了解下来发现是这样的首先我们确定…...

AF_UNIX和127.0.0.1(AF_INET)回环地址写数据速度对比

在linux下&#xff0c;存在着这样的情况&#xff0c;本地的进程间通信&#xff0c;并且其中一个是服务端&#xff0c;另外的都是客户端。 服务端通过绑定端口&#xff0c;客户端往127.0.0.1的对应端口发送&#xff0c;即可办到&#xff0c;不过这样会浪费一个端口&#xff0c;同…...

我在 NPM 发布了新包: con-colors

链接地址&#xff1a;npmjs.com con-colors 安装依赖 yarn add con-colors使用 导入&#xff1a; import { print } from "con-colors";使用&#xff1a; print.succ("成功的消息"); print.err("失败的消息")例子&#xff1a; import { p…...

【python数据建模】Scipy库

常用模块列表 模块名功能scipy.constants数学常量scipy.fft离散傅里叶变换scipy.integrate积分scipy.interpolate插值scipy.interpolate线性代数scipy.cluster聚类分析、向量量化scipy.io数据输入输出scipy.misc图像处理scipy.ndimagen维图像scipy.odr正交距离回归scipy.optim…...

C# App.xaml.cs的一些操作

一、保证只有一个进程 1.1 关闭旧的&#xff0c;打开新的 protected override void OnStartup(StartupEventArgs e) {base.OnStartup(e);var process Process.GetProcessesByName("Dog");if (process.Count() > 1) {var list process.ToList();list.Sort((p1,p2…...

【ORACLE】ORA-00972:标识符过长

问题 执行创建表结构sql&#xff0c;提示 ORA-00972&#xff1a;标识符过长&#xff1b; 如图所示&#xff0c;约束名称超过30个字符了 原因 一、11G and before 在使用11G数据库时&#xff0c;经常会遇到报错ORA-00972&#xff0c;原因是因为对象名称定义太长&#xff0c…...

【Vue】Vue快速入门、Vue常用指令、Vue的生命周期

&#x1f40c;个人主页&#xff1a; &#x1f40c; 叶落闲庭 &#x1f4a8;我的专栏&#xff1a;&#x1f4a8; c语言 数据结构 javaEE 操作系统 Redis 石可破也&#xff0c;而不可夺坚&#xff1b;丹可磨也&#xff0c;而不可夺赤。 Vue 一、 Vue快速入门二、Vue常用指令2.1 v…...

Pandas 数据处理 类别数据和数值数据

要是作深度学习的话&#xff0c;可以直接用tensoflow框架的预处理层&#xff0c;我试过&#xff0c;比PyTorch自己写出来的会好一点&#xff0c;主要是简单好用。处理CSV文件 它类别的处理逻辑是onehot&#xff0c;比较标准稀疏&#xff0c;数值的话就是归一化了。 有时候不需…...

Android攻城狮学鸿蒙 -- 点击事件

具体参考&#xff1a;华为官网学习地址 1、点击事件&#xff0c;界面跳转 对于一个按钮设置点击事件&#xff0c;跳转页面。但是onclick中&#xff0c;如果pages前边加上“/”&#xff0c;就没法跳转。但是开发工具加上“/”才会给出提示。不知道是不是开发工具的bug。&#…...

jmeter性能测试常见的一些问题

一、request 请求超时设置 timeout 超时时间是可以手动设置的&#xff0c;新建一个 http 请求&#xff0c;在“高级”设置中找到“超时”设置&#xff0c;设置连接、响应时间为2000ms。 1. 请求连接超时&#xff0c;连不上服务器。 现象&#xff1a; Jmeter表现形式为&#xff…...

利用国外 vps 为 switch 设置代理服务器加速游戏下载

switch 在国内通过 wifi 连网后如果直接下载游戏的话速度特别慢&#xff0c;据说要挂一个晚上才能下载成功一个游戏。当我尝试下载时发现进度条基本不动&#xff0c;怀疑软件源是在国外的原因&#xff0c;于是想到可以通过国外 vps 代理中转的方式。具体步骤如下&#xff08;以…...

云计算安全的新挑战:零信任架构的应用

文章目录 云计算的安全挑战什么是零信任架构&#xff1f;零信任架构的应用1. 多因素身份验证&#xff08;MFA&#xff09;2. 访问控制和策略3. 安全信息和事件管理&#xff08;SIEM&#xff09;4. 安全的应用程序开发 零信任架构的未来 &#x1f389;欢迎来到云计算技术应用专栏…...

基于SSM的药房药品采购集中管理系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;采用Vue技术开发 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#x…...

【GIT版本控制】--远程仓库

一、连接远程仓库 连接到远程仓库是在GIT中进行协作和备份的关键步骤。以下是连接到远程仓库的基本步骤&#xff1a; 获取远程仓库的URL&#xff1a;首先&#xff0c;你需要获得远程仓库的URL。通常&#xff0c;这是远程仓库提供给你的&#xff0c;可以是HTTPS或SSH URL。例如…...

1:Allotment,2:FeeSell,3:混合Allotment+FreeSell

根据您的描述&#xff0c;这似乎是与酒店预订相关的三种不同的方式。下面是对这三种方式的解释&#xff1a; Allotment&#xff08;配额&#xff09;&#xff1a;这是一种酒店预订方式&#xff0c;其中您可以与酒店签订协议&#xff0c;并购买其一定数量的房间或床位。在此之后…...

NFT Insider#110:The Sandbox与TB Media Global合作,YGG Web3游戏峰会阵容揭晓

引言&#xff1a;NFT Insider由NFT收藏组织WHALE Members、BeepCrypto出品&#xff0c;浓缩每周NFT新闻&#xff0c;为大家带来关于NFT最全面、最新鲜、最有价值的讯息。每期周报将从NFT市场数据&#xff0c;艺术新闻类&#xff0c;游戏新闻类&#xff0c;虚拟世界类&#xff0…...

在硅云上主机搭建wordpress并使用Astra主题和avada主题

目录 前言 准备 操作 DNS解析域名 云主机绑定域名 安装wordpress网站程序 网站内Astra主题设计操作 安装主题 网站内avada主题安装 上传插件 上传主题 选择网站主题 前言 一开始以为云虚拟主机和云服务器是一个东西&#xff0c;只不过前者是虚拟的后者是不是虚拟的…...

基于SSM+Vue的物流管理系统的设计与实现

末尾获取源码 开发语言&#xff1a;Java Java开发工具&#xff1a;JDK1.8 后端框架&#xff1a;SSM 前端&#xff1a;VueHTML 数据库&#xff1a;MySQL5.7和Navicat管理工具结合 服务器&#xff1a;Tomcat8.5 开发软件&#xff1a;IDEA / Eclipse 是否Maven项目&#xff1a;是 …...

【洛谷】P1114 “非常男女”计划

思路&#xff1a;思路和上一篇一模一样哒~&#xff08;这里就不多解释啦&#xff09; ACcode: #include <iostream> #include <cstring> #include <algorithm> using namespace std; const int N 2e510; int n,a[N],f[N]; int main() { ios::sync_with_st…...

list中符合 多条件中筛选符合条件的值

//查找身高在1.8米及以上的男生 // List<SsxlwdBean> boys list.stream().filter(s->s.getGender() && s.getHeight() > 1.8).collect(Collectors.toList()); xlseachitem list.stream().filter(list->list.xlname.contains(Upstrquery)||list.xlbm.…...

Amber中的信息传递——章节1.2-第三部分

程序列表 Amber 包含大量旨在帮助您进行化学系统计算研究的程序&#xff0c;而且发布的工具数量还在定期增加。 本节列出了 AmberTools 包含的主要程序。 这里列出了套件中包含的每个程序&#xff0c;并简要介绍了其主要功能&#xff0c;同时提供了相关文档参考。 对于大多数程…...

【嵌入式】常用串口协议与转换芯片详解

文章目录 0 前言1 一个通信的协议的组成2 常用协议名词解释2.1 UART2.2 RS-2322.3 RS-4852.4 RS-422 3 常用的芯片3.1 MAX2323.2 CP21023.3 CH3403.4 FT232 0 前言 最近有点想研究USB协议&#xff0c;正好也看到有评论说对如何选择USB转串口模块有些疑惑&#xff0c;其实我也一…...

缓存与数据库双写一致性问题解决方案

其实如果使用缓存&#xff0c;就会出现缓存和数据库的不一致问题&#xff0c;关键在于我们可以接受不一致的时间是多少&#xff0c;根据不同的需求采取不同的实现方案。 第一种&#xff1a;先更新数据库后更新缓存 做法简单&#xff0c;但是并发写情况下&#xff0c;会出现数…...

Java中的transient关键字是什么意思?

Java中的transient关键字是什么意思&#xff1f; 在Java中&#xff0c;transient 是一个关键字&#xff0c;用于修饰实例变量&#xff08;成员变量&#xff09;。当一个实例变量被声明为transient 时&#xff0c;它的值不会被持久化&#xff08;即不会被序列化&#xff09;。 …...

内存溢出和内存泄漏

内存溢出和内存泄漏 内存溢出 内存溢出相对于内存泄漏来说&#xff0c;尽管更容易被理解&#xff0c;但是同样的&#xff0c;内存溢出也是引发程序崩溃的罪魁祸首之一。由于GC一直在发展&#xff0c;所以一般情况下&#xff0c;除非应用程序占用的内存增长速度非常快&#xf…...

wordpress+翻页函数/电商网站搭建

神目智能迎宾系统是通过人脸识别&#xff0c;对访客进行欢迎/告警的智能系统&#xff0c;该系统分为两部分&#xff0c;前端部分可进行访客人脸检测、欢迎界面呈现、智能语音播报欢迎词&#xff1b;后台部分可接入神目弼马温管理平台&#xff0c;对数据进行多样化分类管理&…...

手机能用的网站/内部优化

信号处理 信号处理是指信号的表示&#xff0c;变换和运算以及提取它们所包含的信息。如我们可以分开两个或多个混在一起的信号&#xff0c;或者增强信号中某些成分的参数。 信号处理基础 信号分为数字信号和模拟信号&#xff0c;在计算机中连续信号只能让信号的离散时间间隔…...

信息网站设计案例/猪肉价格最新消息

2月11日消息&#xff0c;日前&#xff0c;红黄蓝教育&#xff08;NYSE&#xff1a;RYB&#xff09;发布公告&#xff0c;宣布以1.25亿元收购新加坡一家民营儿童教育集团近70%的股权&#xff0c;并即将从“RYB Education”更名为“GEH Education”&#xff0c;目前尚未公布对应的…...

网站流量如何转化为钱/网页生成

&#xff08;注&#xff1a;简介基于IDEA的版本为&#xff1a;11.0&#xff0c;下载地址&#xff1a; http://www.jetbrains.com/idea/ &#xff09; 打开IDEA&#xff0c;&#xff08;当第一次打开的时候出现的是一个欢迎页面&#xff0c;随便创建一个project来进入到IDEA的主…...

石家庄建网站挣钱优帮云/湖南网络推广机构

理解Session可以从其名称入手&#xff0c;翻译成汉语具有"会话"的意思。所谓的会话是一段有始有终的交互操作&#xff0c;比如通过手机与朋友进行语音通话。特别说明&#xff1a;Session与sessionStorage并无联系&#xff0c;只因名称有所关联所以做一下介绍。一.HTT…...

dw中网站统计总访问量怎么做/冯站长之家官网

由于继电器不能直接用数字I/O驱动&#xff0c;因此常通过单片机IO口输出控制信号&#xff0c;控制三极管的截止或饱和来驱动继电器&#xff0c; 注意&#xff1a;由于三级管的功耗基本都消耗在集电极上&#xff0c;从半导体结构上看&#xff0c;晶体管的C极面积最大&#xff0c…...