注意力机制(四):多头注意力
专栏:神经网络复现目录
注意力机制
注意力机制(Attention Mechanism)是一种人工智能技术,它可以让神经网络在处理序列数据时,专注于关键信息的部分,同时忽略不重要的部分。在自然语言处理、计算机视觉、语音识别等领域,注意力机制已经得到了广泛的应用。
注意力机制的主要思想是,在对序列数据进行处理时,通过给不同位置的输入信号分配不同的权重,使得模型更加关注重要的输入。例如,在处理一句话时,注意力机制可以根据每个单词的重要性来调整模型对每个单词的注意力。这种技术可以提高模型的性能,尤其是在处理长序列数据时。
在深度学习模型中,注意力机制通常是通过添加额外的网络层实现的,这些层可以学习到如何计算权重,并将这些权重应用于输入信号。常见的注意力机制包括自注意力机制(self-attention)、多头注意力机制(multi-head attention)等。
总之,注意力机制是一种非常有用的技术,它可以帮助神经网络更好地处理序列数据,提高模型的性能。
文章目录
- 注意力机制
- 多头注意力
- 数学逻辑
- 实现
多头注意力
多头注意力(Multi-Head Attention)是注意力机制的一种扩展形式,可以在处理序列数据时更有效地提取信息。
在标准的注意力机制中,我们计算一个加权的上下文向量来表示输入序列的信息。而在多头注意力中,我们使用多组注意力权重,每组权重可以学习到不同的语义信息,并且每组权重都会产生一个上下文向量。最后,这些上下文向量会被拼接起来,再通过一个线性变换得到最终的输出。
多头注意力是Transformer模型中的一个重要组成部分,被广泛用于各种自然语言处理任务,如机器翻译、文本分类等。
数学逻辑
在实现多头注意力之前,让我们用数学语言将这个模型形式化地描述出来。 给定查询q∈Rdqq\in R^{d_q}q∈Rdq、 键k∈Rdkk\in R^{d_k}k∈Rdk和值v∈Rdvv\in R^{d_v}v∈Rdv, 每个注意力头的计算方法为:
hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpvh_i=f(W_i^{(q)}q,W_i^{(k)}k,W_i^{(v)}v)\in R^{pv}hi=f(Wi(q)q,Wi(k)k,Wi(v)v)∈Rpv
其中,可学习的参数包括 Wi(q)W_i^{(q)}Wi(q)、 Wi(k)W_i^{(k)}Wi(k)和 Wi(v)W_i^{(v)}Wi(v), 以及代表注意力汇聚的函数fff。 fff可以是加性注意力和缩放点积注意力。 多头注意力的输出需要经过另一个线性转换, 它对应着hhh个头连结后的结果,因此其可学习参数是 WoW_oWo:
实现
在实现过程中通常选择缩放点积注意力作为每一个注意力头。 为了避免计算代价和参数代价的大幅增长, 我们设定pq=pk=pv=pp/hp_q=p_k=p_v=p_p/hpq=pk=pv=pp/h。 值得注意的是,如果将查询、键和值的线性变换的输出数量设置为pqh=pkh=pvh=ppp_qh=p_kh=p_vh=p_ppqh=pkh=pvh=pp, 则可以并行计算hhh个头。 在下面的实现中,是通过参数pop_oponum_hiddens指定的。
#@save
class MultiHeadAttention(nn.Module):"""多头注意力"""def __init__(self, key_size, query_size, value_size, num_hiddens,num_heads, dropout, bias=False, **kwargs):super(MultiHeadAttention, self).__init__(**kwargs)self.num_heads = num_headsself.attention = d2l.DotProductAttention(dropout)self.W_q = nn.Linear(query_size, num_hiddens, bias=bias)self.W_k = nn.Linear(key_size, num_hiddens, bias=bias)self.W_v = nn.Linear(value_size, num_hiddens, bias=bias)self.W_o = nn.Linear(num_hiddens, num_hiddens, bias=bias)def forward(self, queries, keys, values, valid_lens):# queries,keys,values的形状:# (batch_size,查询或者“键-值”对的个数,num_hiddens)# valid_lens 的形状:# (batch_size,)或(batch_size,查询的个数)# 经过变换后,输出的queries,keys,values 的形状:# (batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)queries = transpose_qkv(self.W_q(queries), self.num_heads)keys = transpose_qkv(self.W_k(keys), self.num_heads)values = transpose_qkv(self.W_v(values), self.num_heads)if valid_lens is not None:# 在轴0,将第一项(标量或者矢量)复制num_heads次,# 然后如此复制第二项,然后诸如此类。valid_lens = torch.repeat_interleave(valid_lens, repeats=self.num_heads, dim=0)# output的形状:(batch_size*num_heads,查询的个数,# num_hiddens/num_heads)output = self.attention(queries, keys, values, valid_lens)# output_concat的形状:(batch_size,查询的个数,num_hiddens)output_concat = transpose_output(output, self.num_heads)return self.W_o(output_concat)
为了能够使多个头并行计算, 上面的MultiHeadAttention类将使用下面定义的两个转置函数。 具体来说,transpose_output函数反转了transpose_qkv函数的操作。
#@save
def transpose_qkv(X, num_heads):"""为了多注意力头的并行计算而变换形状"""# 输入X的形状:(batch_size,查询或者“键-值”对的个数,num_hiddens)# 输出X的形状:(batch_size,查询或者“键-值”对的个数,num_heads,# num_hiddens/num_heads)X = X.reshape(X.shape[0], X.shape[1], num_heads, -1)# 输出X的形状:(batch_size,num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)X = X.permute(0, 2, 1, 3)# 最终输出的形状:(batch_size*num_heads,查询或者“键-值”对的个数,# num_hiddens/num_heads)return X.reshape(-1, X.shape[2], X.shape[3])#@save
def transpose_output(X, num_heads):"""逆转transpose_qkv函数的操作"""X = X.reshape(-1, num_heads, X.shape[1], X.shape[2])X = X.permute(0, 2, 1, 3)return X.reshape(X.shape[0], X.shape[1], -1)
代码解释:
这段代码实现了多头注意力机制,其中 MultiHeadAttention 类实现了多头注意力的前向传播, transpose_qkv 函数将输入的 queries, keys, values 通过线性变换并按照 num_heads 进行分组,最终输出变换后的 queries, keys, values,在前向传播中使用这些变换后的 queries, keys, values 来计算注意力权重。在 transpose_qkv 函数的实现中,首先将 queries, keys, values 转换成形状为 (batch_size, queries/keys/values_num, num_hiddens) 的张量,然后根据 num_heads 将最后一维进行分组,变换成形状为 (batch_size, num_heads, queries/keys/values_num, num_hiddens/num_heads) 的张量,最后将第一维和第二维进行交换,输出形状为 (batch_size*num_heads, queries/keys/values_num, num_hiddens/num_heads) 的张量。transpose_output 函数实现了对 MultiHeadAttention 的输出进行逆转换的操作。
这么做的原因是因为多头注意力机制可以将输入张量进行 num_heads 个独立的注意力计算,将计算结果在最后一维拼接起来作为输出,这样可以提高模型的并行性,加快计算速度。同时,通过变换形状将 num_heads 独立处理,也可以增强模型对不同位置和特征的表征能力。
具体来说,这段代码实现的是一个MultiHeadAttention类,其中定义了一个forward方法。这个方法接收一个查询序列queries,一个键序列keys,一个值序列values和一个有效长度序列valid_lens作为输入,然后输出一个加权聚合的结果。
MultiHeadAttention类的初始化方法中,我们定义了几个线性层,以及注意力计算函数,然后用这些组件来定义一个多头注意力层。该层包括将输入queries、keys和values通过三个线性层进行变换,以便将它们的形状变为(batch_size * num_heads,查询或者“键-值”对的个数,num_hiddens/num_heads),其中num_heads表示注意力头的数量。然后,我们通过调用transpose_qkv函数对这些变换后的输入进行一次变换,以便在注意力计算函数中实现多头并行计算。最后,我们通过调用transpose_output函数将输出重构成(batch_size,查询的个数,num_hiddens),并通过一个线性层对其进行变换,输出最终结果。
transpose_qkv函数将输入的queries、keys和values通过reshape和permute操作进行变换,以便多头并行计算。具体来说,它将输入变换为(batch_size,查询或者“键-值”对的个数,num_heads,num_hiddens/num_heads)的形状,然后将第2和第3个轴进行交换。最后,它将输出变换为(batch_size * num_heads, 查询或者“键-值”对的个数, num_hiddens/num_heads)的形状。
transpose_output函数将多头并行计算得到的输出通过reshape和permute操作逆转回原来的形状,具体来说,它将输出变换为(batch_size,查询的个数,num_heads, num_hiddens/num_heads)的形状,然后将第2和第3个轴进行交换,最终将输出变换为(batch_size,查询的个数,num_hiddens)的形状。
这里似乎所有的单头都是同一些参数,这样不会导致每个单头的输出都是一样的吗?
这里的确有点难懂, 这里其实是把所有注意力头里面的参数拼起来, 变成了一个大的全连接层
相关文章:
![](https://img-blog.csdnimg.cn/ce588b4c14fa459797854759a528ac1f.png)
注意力机制(四):多头注意力
专栏:神经网络复现目录 注意力机制 注意力机制(Attention Mechanism)是一种人工智能技术,它可以让神经网络在处理序列数据时,专注于关键信息的部分,同时忽略不重要的部分。在自然语言处理、计算机视觉、语…...
![](https://www.ngui.cc/images/no-images.jpg)
【2023Unity游戏开发教程】零基础带你从小白到超神19——射线检测
文章目录 射线检测从某点发射一条射线从摄像机发射一条射线射线检测 游戏中的红外线,默认肉眼是看不到的,从某个初始点开始,沿着特定的方向发射一条不可见且无限长的射线,通过此射线检测是否有任何模型添加了Collider碰撞器组件。一旦检测到碰撞,停止射线继续发射。 碰撞检…...
![](https://img-blog.csdnimg.cn/1b921ab9984e4de6a56fbd8e21584728.png)
内存泄漏和内存溢出的区别
参考答案 内存溢出(out of memory):指程序在申请内存时,没有足够的内存空间供其使用,出现 out of memory。内存泄露(memory leak):指程序在申请内存后,无法释放已申请的内存空间,内存泄露堆积会导致内存被…...
![](https://www.ngui.cc/images/no-images.jpg)
文本三剑客之sed编辑器
文本三剑客:都是按行读取后处理。 grep 过滤行内容。awk 过滤字段。sed 过滤行内容;修改行内容。sed编辑器 sed是一种流编辑器,流编辑器会在编辑器处理数据之前基于预先提供的一组规则来编辑数据流。 sed编辑器可以根据命令来处理数据流中…...
![](https://img-blog.csdnimg.cn/514963c6d84945c385e591752b89fb02.png)
深度学习:GPT1、GPT2、GPT-3
深度学习:GPT1、GPT2、GPT3的原理与模型代码解读GPT-1IntroductionFramework自监督学习微调ExperimentGPT-2IntroductionApproachConclusionGPT-3GPT-1 Introduction GPT-1(Generative Pre-training Transformer-1)是由OpenAI于2018年发布的…...
![](https://img-blog.csdnimg.cn/bc7f8b5cd20b4ffabc707b6bf0c88373.png)
使用Docker 一键部署SpringBoot和SpringCloud项目
使用Docker 一键部署SpringBoot和SpringCloud项目 1. 准备工作2. 创建Dockerfile3. 创建Docker Compose文件4. 构建和运行Docker镜像5. 验证部署6. 总结Docker是一个非常流行的容器化技术,可以方便地将应用程序和服务打包成容器并运行在不同的环境中。在本篇博客中,我将向您展…...
![](https://img-blog.csdnimg.cn/844af129ad1146f2a05f5048d354aff7.png)
【数据结构】用栈实现队列
💯💯💯 本篇总结利用栈如何实现队列的相关操作,不难观察,栈和队列是可以相互转化的,需要好好总结它们的特性,构造出一个恰当的结构来实现即可,所以本篇难点不在代码思维,…...
![](https://img-blog.csdnimg.cn/3afb635dbeab437f9180eda75382f05f.png)
[Netty源码] 服务端启动过程 (二)
文章目录1.ServerBootstrap2.服务端启动过程3.具体步骤分析3.1 创建服务端Channel3.2 初始化服务端Channel3.3 注册selector3.4 端口绑定1.ServerBootstrap ServerBootstrap引导服务端启动流程: //主EventLoopGroup NioEventLoopGroup master new NioEventLoopGroup(); //从E…...
![](https://www.ngui.cc/images/no-images.jpg)
Week 14
代码源每日一题Div2 106. 订单编号 原题链接:订单编号 思路:这题本来没啥思路,直到获得了某位佬的提示才会做( 我们可以用set来维护一些区间,这些区间为 pair 类型,表示没有使用过的编号,每次…...
![](https://img-blog.csdnimg.cn/6aa685942b77487a82ba2a55ae150914.png#pic_center)
【微信小程序】-- 使用 Git 管理项目(五十)
💌 所属专栏:【微信小程序开发教程】 😀 作 者:我是夜阑的狗🐶 🚀 个人简介:一个正在努力学技术的CV工程师,专注基础和实战分享 ,欢迎咨询! &…...
![](https://www.ngui.cc/images/no-images.jpg)
leetcode每日一题:134. 加油站
系列:贪心算法 语言:java 题目来源:Leetcode134. 加油站 题目 在一条环路上有 n 个加油站,其中第 i 个加油站有汽油 gas[i] 升。 你有一辆油箱容量无限的的汽车,从第 i 个加油站开往第 i1 个加油站需要消耗汽油 cost[…...
![](https://img-blog.csdnimg.cn/img_convert/54b958bb366b59b4284c0ec2f13790ab.png)
开放式基金实时排行 API 数据接口
开放式基金实时排行 API 数据接口 多维度参数返回,实时数据,类型参数筛选。 1. 产品功能 返回实时开放式基金排行数据可定义查询基金类型参数;多个基金属性值返回多维指标,一次查询毫秒级返回;数据持续更新与维护&am…...
![](https://img-blog.csdnimg.cn/504aa1a6591b45bc92273693fe60071f.png)
Android开发中synchronized的实现原理
synchronized的三种使用方式 **1.修饰实例方法,**作用于当前实例加锁,进入同步代码前要获得当前实例的锁。 没有问题的写法: public class AccountingSync implements Runnable{//共享资源(临界资源)static int i0;/*** synchronized 修饰实例方法*/p…...
![](https://www.ngui.cc/images/no-images.jpg)
【华为OD机试 2023最新 】 统一限载货物数最小值(C++)
题目描述 火车站附近的货物中转站负责将到站货物运往仓库,小明在中转站负责调度2K辆中转车(K辆干货中转车,K辆湿货中转车)。 货物由不同供货商从各地发来,各地的货物是依次进站,然后小明按照卸货顺序依次装货到中转车,一个供货商的货只能装到一辆车上,不能拆装,但是…...
![](https://img-blog.csdnimg.cn/9d2118e4d10843adae5c350ef27f7b2d.png)
【生活工作经验 十】ChatGPT模型对话初探
最近探索了下全球大火的ChatGPT,想对此做个初步了解 一篇博客 当今社会,自然语言处理技术得到了迅速的发展,人工智能技术也越来越受到关注。其中,基于深度学习的大型语言模型,如GPT(Generative Pre-train…...
![](https://img-blog.csdnimg.cn/675e6a7223364124adc93b803af73dd6.png)
基于Spring Boot房产销售平台的设计与实现【源码+论文】分享
开发语言:Java 框架:springboot JDK版本:JDK1.8 服务器:tomcat7 数据库:mysql 5.7 数据库工具:Navicat11 开发软件:eclipse/myeclipse/idea Maven包:Maven3.3.9 摘要 信息技术的发展…...
![](https://img-blog.csdnimg.cn/24ec1eb9e0a04924a0bf417e305f6322.jpeg)
不同类型的电机的工作原理和控制方法汇总
电机控制是指对电机的启动、调速(加速、减速)、运转方向和停止进行的控制,不同类型的电机有着不同的工作原理和控制方法。 一、无刷电机 无刷电机是由电机主体和电机驱动板组成的一种没有电刷和换向器的机电一体化产品。在无刷电机中…...
![](https://img-blog.csdnimg.cn/24b696d76d374a9992017e1625389592.gif)
计算机网络管理 TCP三次握手的建立过程,Wireshark抓包分析并验证TCP三次握手建立连接的报文
⬜⬜⬜ ---🟧🟨🟩🟦🟪 (*^▽^*)欢迎光临 🟧🟨🟩🟦🟪---⬜⬜⬜ ✏️write in front✏️ 📝个人主页:陈丹宇jmu 🎁欢迎各位→…...
![](https://img-blog.csdnimg.cn/img_convert/aac7bd8f15d34416929c0c47ca30b811.png)
HTTP/2.x:最新的网页加载技术,快速提高您的SEO排名
2.1 http2概念HTTP/2.0(又称HTTP2)是HTTP协议的第二个版本。它是对HTTP/1.x的更新,旨在提高网络性能和安全性。HTTP/2.0是由互联网工程任务组(IETF)标准化的,并于2015年发布。2.2 http2.x与http1.x区别HTTP…...
![](https://www.ngui.cc/images/no-images.jpg)
机器学习----线性回归
第一关:简单线性回归与多元线性回归 1、下面属于多元线性回归的是? A、 求得正方形面积与对角线之间的关系。 B、 建立股票价格与成交量、换手率等因素之间的线性关系。 C、 建立西瓜价格与西瓜大小、西瓜产地、甜度等因素之间的线性关系。 D、 建立西瓜…...
![](https://img-blog.csdnimg.cn/c4e861f15f634b6aa0f5c56756d6081b.png)
MS2131 USB 3.0 高清音视频采集+HDMI 环出+混音处理芯片 应用网络直播一体机
MS2131 是一款 USB 3.0 高清视频和音频采集处理芯片,内部集成 USB 3.0 Device 控制器、 数据收发模块、音视频处理模块。MS2131 可以通过 USB 3.0 接口将 HDMI 输入的音视频信号传 送到 PC、智能手机、平板电脑上预览或采集。MS2131 支持 HDMI 环出功能,…...
![](https://img-blog.csdnimg.cn/img_convert/38c0bbb82b2fd56645b1a3bafc17a836.png)
基于堆与AdjustDown的TOP-K问题
TIPSTOP-K问题TOP-K问题:就是说现在比如说有n个数据,然后需要从这n个数据里面找到最大的或最小的前k个。一般来讲思路的话就是:先把这n个数据给他建一个堆,建堆完成之后,然后就去调堆,然后大概只需要调k次&…...
![](https://upic.fenxiangbe.com/uPic/2023/03/23/image-20230323191424001.png)
在CentOS上安装Docker引擎
1,先决条件#### 1-1操作系统要求1-2 卸载旧版本 2,安装方法2-1使用存储库安装设置存储库安装 Docker 引擎 本文永久更新地址: 官方地址:https://docs.docker.com/engine/install/centos/ 1,先决条件 #### 1-1操作系统要求 要安装 Docker Engine,您需要…...
![](https://img-blog.csdnimg.cn/ef751b0573a74d0c89769c9eb8a0111e.jpeg)
【10】核心易中期刊推荐——模式识别与机器学习
🚀🚀🚀NEW!!!核心易中期刊推荐栏目来啦 ~ 📚🍀 核心期刊在国内的应用范围非常广,核心期刊发表论文是国内很多作者晋升的硬性要求,并且在国内属于顶尖论文发表,具有很高的学术价值。在中文核心目录体系中,权威代表有CSSCI、CSCD和北大核心。其中,中文期刊的数…...
![](https://www.ngui.cc/images/no-images.jpg)
【数据结构】并查集
目录 一:用途 二:实现 O(1) 三:例题 例题1:集合 例题2:连通图无向 例题3:acwing 240 食物链 一:用途 将两个集合合并询问两个元素是否在一个集合当中 二:实现 O(1) 每…...
![](https://www.ngui.cc/images/no-images.jpg)
软考--网络攻击分类
网络攻击的主要手段包括口令入侵、放置特洛伊木马程序、拒绝服务(DoS)攻击、端口扫描、网络监听、欺骗攻击和电子邮件攻击等。口令入侵是指使用某些合法用户的账号和口令登录到目的主机,然后再实施攻击活动。特洛伊木马(Trojans)程序常被伪装…...
![](https://img-blog.csdnimg.cn/f8b284a4e903400a9f74036e9221054b.jpeg#pic_center)
蓝桥杯刷题冲刺 | 倒计时17天
作者:指针不指南吗 专栏:蓝桥杯倒计时冲刺 🐾马上就要蓝桥杯了,最后的这几天尤为重要,不可懈怠哦🐾 文章目录1.长草2.分考场1.长草 题目 链接: 长草 - 蓝桥云课 (lanqiao.cn) 题目描述 小明有一…...
![](https://img-blog.csdnimg.cn/88848a17a762461fb3e308fca286aefc.png)
冲击蓝桥杯-并查集,前缀和,字符串
目录 前言 一、并查集 1、并查集的合并(带路径压缩) 2、询问是否为同一个集合 3、例题 二、前缀和 1 、前缀和是什么 2、经典题目 三- 字符串处理 1、字符串的插入 2、字符串转化为int类型 3、字符反转 前言 并查集合前缀,字符串…...
![](https://img-blog.csdnimg.cn/85c27429cfa042b789381d1d350a840d.png)
【matlab学习笔记】线性方程组求解方法
线性方程组求解方法2.1 求逆法实现方式例子2.2 分解法LU分解(Doolittle分解)实现方法例子QR分解法实现方法例子Cholesky 分解法实现方法例子奇异值分解法实现方法例子Hessenberg 分解实现方法例子Schur 分解实现方法例子2.3 迭代法逐次迭代法里查森迭代法…...
![](https://img-blog.csdnimg.cn/ebd6368985ad4404acfc46a2de000eae.gif)
Python带你一键下载到最新章节,不付费也能看
前言 大家早好、午好、晚好吖 ❤ ~欢迎光临本文章 完整源码、素材皆可点击文章下方名片获取此处跳转 开发环境: python 3.8 运行代码 pycharm 2022.3 辅助敲代码 requests 发送请求/第三方模块 模块安装:win R 输入cmd 输入安装命令 pip install 模块名 如果…...
![](/images/no-images.jpg)
网站的基础服务/新手如何学seo
Problem Description输入1个正整数n,计算1(12)(123)...(123...n)Input输入正整数n(多组数据)Output输出1(12)(123)...(123...n)的值(每组数据一行)Sample Input2Sample Output4#includeusing namespace std;int main(){int n,t;long sum;while(cin>>n){sum;t;for(int i;i…...
![](/images/no-images.jpg)
网站开发需要学多久/谷歌推广开户多少费用
选择排序(Selection sort)跟插入排序一样,也是O(n^2)的复杂度,这个排序方式也可以用我们的扑克牌来解释。 概念 桌面上有一堆牌,也是杂乱无章的,现在我们想将牌由小到大排序,如果使用选择排序来做ÿ…...
![](/images/no-images.jpg)
运用asp做购物网站的心得/seo优化员
一个规则的实心十二面体,它的 20个顶点标出世界著名的20个城市,你从一个城市出发经过每个城市刚好一次后回到出发的城市。 Input前20行的第i行有3个数,表示与第i个城市相邻的3个城市.第20行以后每行有1个数m,m<20,m>1.m0退出. Output输出从第m个城…...
![](/images/no-images.jpg)
做网站的工作轻松吗/设计一个公司网站多少钱
1.环境概述 虚拟机系统:CentOS Linux release 7.3.1611 (Core) 宿主机系统:Mac Sierra version 10.12.3 nginx:1.10.3 php:7.1.2 2.虚拟机 为了使得虚拟机和主机互通且虚拟机能联网,在安装系统之前需要设置网络。在当前…...
![](/images/no-images.jpg)
昆明出入最新规定/湖北短视频seo营销
1、执行如下语句获取删除语句 SELECT CONCAT( drop table , table_name, ; ) from information_schema.tableswhere table_schema数据库名 and table_typebase table 2、拷贝语句,然后复制到nvicat进行执行 转载于:https://www.cnblogs.com/javabg/p/10083945.html…...
![](https://img-blog.csdnimg.cn/img_convert/631489e60d8806b91f3b68a3285dc3e5.png)
怎么做相亲网站/seo公司怎么推广宣传
一、什么是表?但凡是用过MySQL都知道,直观上看,MySQL的数据都存在数据表中。比如一条Update SQL:update user set username 白日梦 where id 999;它将user这张数据表中id为1的记录的username列修改成了‘白日梦这里的user其实就…...