Scaled_dot_product_attention(SDPA)使用详解
在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释。
1. Attention计算公式
Attention的计算主要涉及三个矩阵:Q、K、V。我们先不考虑multi-head attention,只考虑one head的self attention。在大模型的prefill阶段,这三个矩阵的维度均为N x d,N即为上下文的长度;在decode阶段,Q的维度为1 x d, KV还是N x d。然后通过下面的公式计算attention矩阵:
O = A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d ) V O=Attention(Q,K,V)=softmax(\frac{QK^T}{\sqrt d})V O=Attention(Q,K,V)=softmax(dQKT)V
在真正使用attention的时候,我们往往采用multi-head attention(MHA)。MHA的计算公式和one head attention基本一致,它改变了Q、K、V每一行的定义:将维度d的向量分成h组变成一个h x dk的矩阵,Q、K、V此时成为了 N ∗ h ∗ d k N * h * d_k N∗h∗dk的三维矩阵(不考虑batch维)。分别将Q、K、V的第一和第二维进行转置得到三个维度为 h ∗ N ∗ d k h * N * d_k h∗N∗dk的三维矩阵。此时的三个矩阵就是具有h个头的Q、K、V,我们就可以按照self attention的定义计算h个头的attention值。
不过,在真正进行大模型推理的时候就会发现KV Cache是非常占显存的,所以大家尝试各种手段压缩KV Cache,具体可以参考《大模型推理–KV Cache压缩》。一种手段就是将MHA替换成group query attention(GQA),这块在torch2.5以上的SDPA中也已经得到了支持。
2. SDPA伪代码
在SDPA的注释中,给出了伪代码:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0,is_causal=False, scale=None, enable_gqa=False) -> torch.Tensor:L, S = query.size(-2), key.size(-2)scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scaleattn_bias = torch.zeros(L, S, dtype=query.dtype)if is_causal:assert attn_mask is Nonetemp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))attn_bias.to(query.dtype)if attn_mask is not None:if attn_mask.dtype == torch.bool:attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))else:attn_bias += attn_maskif enable_gqa:key = key.repeat_interleave(query.size(-3)//key.size(-3), -3)value = value.repeat_interleave(query.size(-3)//value.size(-3), -3)attn_weight = query @ key.transpose(-2, -1) * scale_factorattn_weight += attn_biasattn_weight = torch.softmax(attn_weight, dim=-1)attn_weight = torch.dropout(attn_weight, dropout_p, train=True)return attn_weight @ value
可以看出,我们实际在使用SDPA时除了query、key和value之外,还有另外几个参数:attn_mask、dropout_p、is_causal、scale和enable_gqa。scale就是计算Attention时的缩放因子,一般无需传递。dropout_p表示Dropout概率,在推理阶段也不需要传递,不过官方建议如下输入:dropout_p=(self.p if self.training else 0.0)。我们着重看一下另外三个参数在使用时该如何设置。
先看enable_gqa。前面提到GQA是一种KV Cache压缩方法,MHA的KV和Q一样,也会有h个头,GQA则将KV的h个头进行压缩来减小KV Cache的大小。比如Qwen2-7B-Instruct这个模型,Q的h等于28,KV的h等于4,相当于把KV Cache压缩到之前的七分之一。GQA虽然压缩了KV Cache,但是真正要计算Attention的时候还是需要对齐KV与Q的head数,所以我们可以看到HF Transformer库中的qwen2.py在Attention计算时会有一个repeat_kv的操作,目的就是将QKV的head数统一。在torch2.5以后的版本中,我们无需再手动去执行repeat_kv,直接将SDPA的enable_gqa设置为True即可自动完成repeat_kv,而且速度比自己去做repaet_kv还要更快。
attn_mask和is_causal两个参数的作用相同,目的都是要给softmax之前的QKT矩阵添加mask。只不过attn_mask是自己在外面构造mask矩阵,is_causal则是根据大模型推理的阶段属于prefill还是decode来进行设置。通过看伪代码可以看出,SDPA会首先构造一个L x S的零矩阵attn_bias,L表示Q的上下文长度,S表示KV Cache的长度。在prefill阶段,L和S相等,在decode阶段,L为1,S还是N。所以在prefill阶段,attn_bias就是一个N x N的矩阵,将is_causal设置为True时就会构造一个下三角为0,上三角为负无穷的矩阵作为attn_bias,然后将其加到QKT矩阵上,这样就实现了因果关系的Attention计算。在decode阶段,attn_bias就是一个1 x N的向量,此时可以将is_causal设置为False,attn_bias始终为0就不会对 Q K T QK^T QKT行向量产生影响,表示KV Cache所有的行都参与计算,因果关系保持正确。
attn_mask作用和is_causal一样,但是需要我们自行构造,如果你对如何构造不了解建议就使用is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None。不过,如果prefill按照chunk来执行也即chunk_prefill阶段,我们会发现is_causal设置为True时的attn_bias设置的不正确,我们不是从左上角开始构造下三角矩阵,而是要从右下角开始构造下三角矩阵,这种情况下我们可以从外面自行构造attn_mask矩阵代替SDPA的构造。attn_mask有两种构造方式,一种是bool类型,True的位置会保持不变,False的位置会置为负无穷;一种是float类型,会直接将attn_mask加到SDPA内部的attn_bias上,和bool类型一样,我们一般是构造一个下三角为0上三角为负无穷的矩阵。总结来说,绝大多数情况下我们只需要设置is_causal选项,prefill阶段设置为True,decode阶段设置为False,attn_mask设置为None即可。如果推理阶段引入了chunk_prefill,则我们需要自行构造attn_mask,但是要注意构造的attn_mask矩阵是从右下角开始的下三角矩阵。
3. SDPA实现(翻译自SDPA注释)
目前SDPA有三种实现:
- 基于FlashAttention-2的实现;
- Memory-Efficient Attention(facebook xformers);
- Pytorch版本对上述伪代码的c++实现(对应MATH后端)。
针对CUDA后端,SDPA可能会调用经过优化的内核以提高性能。对于所有其他后端,将使用PyTorch实现。所有实现方式默认都是启用的,SDPA会尝试根据输入自动选择最优的实现方式。为了对使用哪种实现方式提供更细粒度的控制,torch提供了以下函数来启用和禁用各种实现方式:
- torch.nn.attention.sdpa_kernel:一个上下文管理器,用于启用或禁用任何一种实现方式;
- torch.backends.cuda.enable_flash_sdp:全局启用或禁用FlashAttention
- torch.backends.cuda.enable_mem_efficient_sdp:全局启用或禁用memory efficient attention
- torch.backends.cuda.enable_math_sdp:全局启用或禁用PyTorch的C++实现。
每个融合内核都有特定的输入限制。如果用户需要使用特定的融合实现方式,请使用torch.nn.attention.sdpa_kernel 禁用PyTorch 的C++实现。如果某个融合实现方式不可用,将会发出警告,说明该融合实现方式无法运行的原因。由于融合浮点运算的特性,此函数的输出可能会因所选择的后端内核而异。C++ 实现支持torch.float64,当需要更高精度时可以使用。对于math后端,如果输入是torch.half或torch.bfloat16类型,那么所有中间计算结果都会保持为torch.float类型。
4. SDPA使用示例
首先强调一点,灌入SDPA的QKV都是做过转置的,也即维度为batch x head x N x d,在老版本的torch中还需要QKV都是contiguous的,新版本下无此要求。SDPA注释中还给了两个示例,我们在此也给出:
# Optionally use the context manager to ensure one of the fused kernels is runquery = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")with sdpa_kernel(backends=[SDPBackend.FLASH_ATTENTION]):F.scaled_dot_product_attention(query,key,value)
上述示例中,给定的输入为batch等于32,head等于8,上下文长度128,embedding维度64,然后通过sdpa_kernel选择使用FlashAttention。
示例二:
# Sample for GQA for llama3
query = torch.rand(32, 32, 128, 64, dtype=torch.float16, device="cuda")
key = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
value = torch.rand(32, 8, 128, 64, dtype=torch.float16, device="cuda")
with sdpa_kernel(backends=[SDPBackend.MATH]):F.scaled_dot_product_attention(query,key,value,enable_gqa=True)
示例二演示了GQA的用法,给定的query head数为32,key和value均为8,此时我们可以通过enable_gqa选项来实现对GQA的支持,此外代码还通过sdpa_kernel选项使用了MATH后端。
5. 参考
- FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning
- Memory-Efficient Attention
- Grouped-Query Attention
- Attention Is All You Need
相关文章:

Scaled_dot_product_attention(SDPA)使用详解
在学习huggingFace的Transformer库时,我们不可避免会遇到scaled_dot_product_attention(SDPA)这个函数,它被用来加速大模型的Attention计算,本文就详细介绍一下它的使用方法,核心内容主要参考了torch.nn.functional中该函数的注释…...

Linux练级宝典->Linux进程概念介绍
目录 进程基本概念 PCB概念 task_struct tack_struct内容分类 PID和PPID fork函数创建子进程 进程优先级概念 4个名词 进程地址空间 进程地址空间的意义 内核进程调度队列 优先级 活动队列 过期队列 进程基本概念 一个正在执行的程序。担当分配系统资源的实体&#…...

OpenHarmony 5.0 mpegts封装的H265视频播放失败的解决方案
问题现象 OpenHarmony 5.0版本使用AVPlayer播放mpegts封装格式的H.265(HEVC)编码格式的视频时出现报错导致播放失败 问题原因 OpenHarmony 5.0版本AVPlayer播放器使用histreamer引擎,因为 libav_codec_hevc_parser.z.so 动态库未开源导致H265编码格式视频解析不到…...

Qt从入门到入土(九) -model/view(模型/视图)框架
简介 Qt的模型/视图(Model/View)架构是一种用于分离数据处理和用户界面展示的设计模式。它允许开发者将数据存储和管理(模型)与数据的显示和交互(视图)解耦,从而提高代码的可维护性和可扩展性。…...

缓存之美:Guava Cache 相比于 Caffeine 差在哪里?
大家好,我是 方圆。本文将结合 Guava Cache 的源码来分析它的实现原理,并阐述它相比于 Caffeine Cache 在性能上的劣势。为了让大家对 Guava Cache 理解起来更容易,我们还是在开篇介绍它的原理: Guava Cache 通过分段(…...

[漏洞篇]XSS漏洞详解
[漏洞篇]XSS漏洞 一、 介绍 概念 XSS:通过JS达到攻击效果 XSS全称跨站脚本(Cross Site Scripting),为避免与层叠样式表(Cascading Style Sheets, CSS)的缩写混淆,故缩写为XSS。这是一种将任意 Javascript 代码插入到其他Web用户页面里执行以…...

【Leetcode 每日一题】2269. 找到一个数字的 K 美丽值
问题背景 一个整数 n u m num num 的 k k k 美丽值定义为 n u m num num 中符合以下条件的 子字符串 数目: 子字符串长度为 k k k。子字符串能整除 n u m num num。 给你整数 n u m num num 和 k k k,请你返回 n u m num num 的 k k k 美丽值…...

IO进程线程(线程)
作业 1.创建两个线程,分支线程1拷贝文件的前一部分,分支线程2拷贝文件的后一部分 2.创建三个线程,实现线程A打印A,线程B打印B,线程C打印C;重复打印顺序ABC。 信号量实现: 条件变量实现&#x…...

1-002:MySQL InnoDB引擎中的聚簇索引和非聚簇索引有什么区别?
在 MySQL InnoDB 存储引擎 中,索引主要分为 聚簇索引(Clustered Index) 和 非聚簇索引(Secondary Index)。它们的主要区别如下: 1. 聚簇索引(Clustered Index) 定义 聚簇索引是表数…...

tomcat单机多实例部署
一、部署方法 多实例可以运行多个不同的应用,也可以运行相同的应用,类似于虚拟主机,但是他可以做负载均衡。 方式一: 把tomcat的主目录挨个复制,然后把每台主机的端口给改掉就行了。 优点是最简单最直接,…...

论文阅读分享——UMDF(AAAI-24)
概述 题目:A Unified Self-Distillation Framework for Multimodal Sentiment Analysis with Uncertain Missing Modalities 发表:The Thirty-Eighth AAAI Conference on Artificial Intelligence (AAAI-24) 年份:2024 Github:暂…...

解决asp.net mvc发布到iis下安全问题
解决asp.net mvc发布到iis下安全问题 环境信息1.The web/application server is leaking version information via the "Server" HTTP response2.确保您的Web服务器、应用程序服务器、负载均衡器等已配置为强制执行Strict-Transport-Security。3.在HTML提交表单中找不…...

概念|RabbitMQ 消息生命周期 待消费的消息和待应答的消息有什么区别
目录 消息生命周期 一、消息创建与发布阶段 二、消息路由与存储阶段 三、消息存活与过期阶段 四、消息投递与消费阶段 五、消息生命周期终止 关键配置建议 待消费的消息和待应答的消息 一、待消费的消息(Unconsumed Messages) 二、待应答的消息…...

springboot三层架构详细讲解
目录 springBoot三层架构 0.简介1.各层架构 1.1 Controller层1.2 Service层1.3 ServiceImpl1.4 Mapper1.5 Entity1.6 Mapper.xml 2.各层之间的联系 2.1 Controller 与 Service2.2 Service 与 ServiceImpl2.3 Service 与 Mapper2.4 Mapper 与 Mapper.xml2.5 Service 与 Entity2…...

2025最新群智能优化算法:云漂移优化(Cloud Drift Optimization,CDO)算法求解23个经典函数测试集,MATLAB
一、云漂移优化算法 云漂移优化(Cloud Drift Optimization,CDO)算法是2025年提出的一种受自然现象启发的元启发式算法,它模拟云在大气中漂移的动态行为来解决复杂的优化问题。云在大气中受到各种大气力的影响,其粒子的…...

2025年Draw.io最新版本下载安装教程,附详细图文
2025年Draw.io最新版本下载安装教程,附详细图文 大家好,今天给大家介绍一款非常实用的流程图绘制软件——Draw.io。不管你是平时需要设计流程图、绘制思维导图,还是制作架构图,甚至是简单的草图,它都能帮你轻松搞定。…...

记录--洛谷 P1451 求细胞数量
如果想查看完整题目,请前往洛谷 P1451 求细胞数量 P1451 求细胞数量 题目描述 一矩形阵列由数字 0 0 0 到 9 9 9 组成,数字 1 1 1 到 9 9 9 代表细胞,细胞的定义为沿细胞数字上下左右若还是细胞数字则为同一细胞,求给定矩形…...

Android Studio 配置国内镜像源
Android Studio版本号:2022.1.1 Patch 2 1、配置gradle国内镜像,用腾讯云 镜像源地址:https\://mirrors.cloud.tencent.com/gradle 2、配置Android SDK国内镜像 地址:Index of /AndroidSDK/...

做到哪一步才算精通SQL
做到哪一步才算精通SQL-Structured Query Language 数据定义语言 DDL for StructCREATE:用来创建数据库、表、索引等对象ALTER:用来修改已存在的数据库对象DROP:用来删除整个数据库或者数据库中的表TRUNCATE:用来删除表中所有的行…...

Manus演示案例: 英伟达财务估值建模 解锁投资洞察的深度剖析
在当今瞬息万变的金融投资领域,精准剖析企业价值是投资者决胜市场的关键。英伟达(NVIDIA),作为科技行业的耀眼明星,其在人工智能和半导体领域的卓越表现备受瞩目。Manus 凭借专业的财务估值建模能力,深入挖…...

postman接口请求中的 Raw是什么
前言 在现代的网络开发中,API 的使用已经成为数据交换的核心方式之一。然而,在与 API 打交道时,关于如何发送请求体(body)内容类型的问题常常困扰着开发者们,尤其是“raw”和“json”这两个术语之间的区别…...

DeepSeek大语言模型下几个常用术语
昨天刷B站看到复旦赵斌老师说的一句话“科幻电影里在人脑中植入芯片或许在当下无法实现,但当下可以借助AI人工智能实现人类第二脑”(大概是这个意思) 💞更多内容,可关注公众号“ 一名程序媛 ”,我们一起从 …...

ctf-WEB: 关于 GHCTF Message in a Bottle plus 与 Message in a Bottle 的非官方wp解法
Message in a Bottle from bottle import Bottle, request, template, runapp Bottle()# 存储留言的列表 messages [] def handle_message(message):message_items "".join([f"""<div class"message-card"><div class"me…...

测试用例详解
一、通用测试用例八要素 1、用例编号; 2、测试项目; 3、测试标题; 4、重要级别; 5、预置条件; 6、测试输入; 7、操作步骤; 8、预期输出 二、具体分析通…...

c#面试题整理7
1.UDP和TCP的区别 UDP是只要能连上终端就发送,至于终端是否收到,不管。 TCP则是会存在交换,即发送失败或成功,是可知的。 2.进程和线程的区别 双击一个程序的exe文件,程序执行了,这就是一个进程。 这个…...

OpenManus-通过源码方式本地运行OpenManus,含踩坑及处理方案,chrome.exe位置修改
前言:最近 Manus 火得一塌糊涂啊,OpenManus 也一夜之间爆火,那么作为程序员应该来尝尝鲜 1、前期准备 FastGithub:如果有科学上网且能正常访问 github 则不需要下载此软件,此软件是提供国内直接访问 githubGit&#…...

【性能测试】Jmeter下载安装、环境配置-小白使用手册(1)
本篇文章主要包含Jmeter的下载安装、环境配置 添加线程组、结果树、HTTP请求、请求头设置。JSON提取器的使用,用户自定义变量 目录 一:引入 1:软件介绍 2:工作原理 3:安装Jmeter 4:启动方式 …...

HTML星球大冒险之路线图
第一章:欢迎来到 HTML 星球! 1.1 宇宙的基石:HTML 是什么? 🌍 比喻:HTML 是网页世界的「乐高积木」,用标签搭建一切可见内容🎯 目标:理解 HTML 的作用,掌握…...

初识大模型——大语言模型 LLMBook 学习(一)
1. 大模型发展历程 🔹 1. 早期阶段(1950s - 1990s):基于规则和统计的方法 代表技术: 1950s-1960s:规则驱动的语言处理 早期的 NLP 主要依赖 基于规则的系统,如 Noam Chomsky 提出的 生成语法&…...

LabVIEW伺服阀高频振动测试
在伺服阀高频振动测试中,闭环控制系统的实时性与稳定性至关重要。针对用户提出的1kHz控制频率需求及Windows平台兼容性问题,本文重点分析NI PCIe-7842R实时扩展卡的功能与局限性,并提供其他替代方案的综合对比,以帮助用户选择适合…...