当前位置: 首页 > news >正文

TR4 - Transformer中的多头注意力机制


目录

  • 前言
  • 自注意力机制
    • Self-Attention层的具体机制
    • Self-Attention 矩阵计算
  • 多头注意力机制
    • 例子解析
  • 代码实现
  • 总结与心得体会


前言

多头注意力机制可以说是Transformer中最主要的模块,没有之一。这次我们来仔细分析一下注意力机制与多头注意力机制。

自注意力机制

在Transformer模型中,输入的文本序列经过输入处理转换为一个向量的序列,然后就会被送到第1层的编码器,第一层的编码器的输出同样是一个向量的序列,再送到下一层编码器。
encoder向量流动
通过上图可以发现,向量在层间流动时,向量的数量和维度都是不变的。单层编码器接收到上一层的输入,然后进入自注意力层计算,然后再输入到前馈神经网络中,最后得到每个位置的新向量。

Self-Attention层的具体机制

例如想要翻译的句子为:“The animal didn’t cross the street because it was too tired”。

句子中的it是一个代词,想要知道它具体代指什么,对模型来说并不容易。通过引用Self-Attention机制,模型就会最终计算出it代指的是animal。同样的,当模型处理句子中其他词时,Self-Attention机制也可以让模型不仅仅关注当前位置的词,还关注句中其它位置相关的词,进而更好地理解当前位置的词。

通过一个简单的例子来解释自注意力机制的计算过程:假设一句话为"Thinking Machines"。

自注意力会计算:Thinking-Thinking、Thinking-Machines、Machines-Thinking、Machines-Machines共2的2次方种组合。

具体的计算过程如下:

  • 1 对输入编码器的词向量进行线性变换,得到Query、Key和Value向量。变换的过程是通过词向量分别和3个参数矩阵相乘,参数矩阵可以通过模型训练学习到。

向量计算

  • 2 计算 Attention Score (注意力分数 )。

假如我们现在计算Thinking的Attention Score,需要根据Thinking对应的词向量,对句子中的其他词向量都计算一个分数,这些分数决定了在编码Thinking这个词时,对句子中其它位置的词向量的权重。

Attention Score 是根据Thinking对应的Query向量和其他位置的每个词的Key向量进行点积得到的。Thinking的第一个Attention Score 就是q1和k1的点积,第二个分数是 q 1 q_1 q1 k 2 k_2 k2的点积。
Attention Score计算

  • 3 把得到的每个分数除以 d k \sqrt{d_k} dk d k d_k dk是Key向量的维度。这一步的目的是为了在反向传播时,求梯度时更加稳定。

s c o r e 11 = q 1 ⋅ k 1 d k score_{11} = \frac{q_1 \cdot k_1}{\sqrt{d_k}} score11=dk q1k1

s c o r e 12 = q 1 ⋅ k 2 d k score_{12} = \frac{q_1 \cdot k_2}{\sqrt{d_k}} score12=dk q1k2

  • 4 然后把分数经过一个Softmax函数,通过Softmax将分数归一化,使分数都是正数并且加起来等于1。

s c o r e 11 = s o f t m a x ( s c o r e 11 ) score_{11} = softmax(score_{11}) score11=softmax(score11)

s c o r e 12 = s o f t m a x ( s c o r e 12 ) score_{12} = softmax(score_{12}) score12=softmax(score12)

Softmax 计算Score

  • 5 得到每个词向量的分数后,将分数分别与对应的Value向量相乘。对于分数高的位置,相乘后的值就越大,我们把更多的注意力放到了它们的身上;对于分数低的位置,相乘后的值就越小,这些位置的词可能就相关性不大。
    计算sum
  • 6 把第5步得到的Value向量相加,就得到了Self-Attention在当前位置对应的输出

z 1 = v 1 × s c o r e 11 + v 2 × s c o r e 12 z_1 = v_1 \times score_{11} + v_2 \times score_{12} z1=v1×score11+v2×score12

最后整体看一下Self-Attention计算的全过程

Self-Attention全过程

Self-Attention 矩阵计算

具体的实现时,并不会像上面那样阶段分明的分成6个步骤,而是将向量合并到一起,进行矩阵运算。

X 1 X_1 X1: 第一个单词的输入向量
X 2 X_2 X2: 第二个单词的输入向量
X = [ X 1 ; X 2 ] X = [X_1;X_2] X=[X1;X2] 将两个向量合并为矩阵

具体来说分为了两步:

  • 1:计算Query、Key、Value的矩阵。

    Q = X W Q Q = XW^Q Q=XWQ:计算Query

    K = X W K K = XW^K K=XWK:计算Key

    V = X W V V = XW^V V=XWV:计算Value

    把所有的词向量放到一个矩阵X中,然后分别和3个权重矩阵 W Q W^Q WQ W K W^K WK W V W^V WV相乘,得到 Q Q Q K K K V V V矩阵。矩阵X中的每一行,表示句子中的每一个词的词向量。 Q Q Q K K K V V V矩阵中的每一行表示Query向量、Key向量、Value向量,向量的维度是 d k d_k dk

    QKV矩阵乘法

  • 2:矩阵计算把上面第2步到第6步压缩为一步,直接得到Self-Attention的输出

    Z = s o f t m a x ( Q K T d k ) × V Z = softmax(\frac {QK^T} {\sqrt{d_k}}) \times V Z=softmax(dk QKT)×V

    计算Z

多头注意力机制

Transformer的论文中,通过增加多头注意力机制(一组注意力称为一个Attention Head),进一步完善了Self-Attention。这种机制从如下两个方面增强了Attention层的能力:

  • 扩展了模型关注不同位置的能力

    在上面的例子中,第一个位置的输出 z 1 z_1 z1包含了句子中其他每个位置的很小一部分信息。但 z 1 z_1 z1仅仅是单个向量,所以可能仅由第1个位置的信息主导了。而当我们翻译句子:The animal didn't cross the street because it was too tired时,我们不仅希望模型关注到it本身,还希望模型关注到Theanimal,甚至关注到tired

  • 多头注意力机制赋予了Attention层多个“子表示空间”

    多头注意力机制会有多组 W Q W^Q WQ W K W^K WK W V W^V WV的权重矩阵,因此可以将 X X X变换到更多种子空间中进行表示 。
    多头注意力机制
    每组注意力设定单独的 W Q W^Q WQ W K W^K WK W V W^V WV参数矩阵。将输入 X X X与它们相乘,得到多组 Q Q Q K K K V V V矩阵。接下来把每组的 Q Q Q K K K V V V计算得到各自的 Z Z Z
    Z矩阵计算
    由于前馈神经网络层接收的是1个矩阵(其中每行的向量表示一个词),而不是8个矩阵,所以要直接把8个子矩阵拼接得到一个大矩阵,然后和另一个权重矩阵 W O W^O WO相乘做一次变换,映射到前馈神经网络层所需要的维度。
    子矩阵映射变换
    把多头注意力放到一张图中:
    多头注意力机制运算

例子解析

再来看一下上面提到的it的例子,不同的Attention Heads对应的it attention了哪些内容。
It的Attention
图中绿色和橙色线条分别表示2组不同的Attention Heads。可以看到,当我们编码单词it时,其中一个Attention Head(橙色)最关注的是the animal,另外一个绿色Attention Head关注的是tired。因此在某种意义上,it在模型中的表示,融合了animaltire的部分表达。

代码实现

class MultiHeadAttention(nn.Module):def __init__(self, hid_dim, n_heads, dropout):super().__init__()self.hid_dim = hid_dimself.n_heads = n_heads# hid_dim必须整除assert hid_dim % n_heads == 0# 定义wqself.w_q = nn.Linear(hid_dim, hid_dim)# 定义wkself.w_k = nn.Linear(hid_dim, hid_dim)# 定义wvself.w_v = nn.Linear(hid_dim, hid_dim)self.fc = nn.Linear(hid_dim, hid_dim)self.do = nn.Dropout(dropout)self.scale = torch.sqrt(torch.FloatTensor([hid_dim//n_heads]))def forward(self, query, key, value, mask=None):# Q与KV在句子长度这一个维度上数值可以不一样bsz = query.shape[0]Q = self.w_q(query)K = self.w_k(key)V = self.w_v(value)# 将QKV拆成多组,方案是将向量直接拆开了# (64, 12, 300) -> (64, 12, 6, 50) -> (64, 6, 12, 50)# (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)# (64, 10, 300) -> (64, 10, 6, 50) -> (64, 6, 10, 50)Q = Q.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)K = K.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)V = V.view(bsz, -1, self.n_heads, self.hid_dim//self.n_heads).permute(0, 2, 1, 3)# 第1步,Q x K / scale# (64, 6, 12, 50) x (64, 6, 50, 10) -> (64, 6, 12, 10)attention = torch.matmul(Q, K.permute(0, 1, 3, 2)) / self.scale# 需要mask掉的地方,attention设置的很小很小if mask is not None:attention = attention.masked_fill(mask == 0, -1e10)# 第2步,做softmax 再dropout得到attentionattention = self.do(torch.softmax(attention, dim=-1))# 第3步,attention结果与k相乘,得到多头注意力的结果# (64, 6, 12, 10) x (64, 6, 10, 50) -> (64, 6, 12, 50)x = torch.matmul(attention, V)# 把结果转回去# (64, 6, 12, 50) -> (64, 12, 6, 50)x = x.permute(0, 2, 1, 3).contiguous()# 把结果合并# (64, 12, 6, 50) -> (64, 12, 300)x = x.view(bsz, -1, self.n_heads * (self.hid_dim // self.n_heads))x = self.fc(x)return x        

测试一下是否能输出

query = torch.rand(64, 12, 300)
key = torch.rand(64, 10, 300)
value = torch.rand(64, 10, 300)
attention = MultiHeadAttention(hid_dim=300, n_heads=6, dropout=0.1)
output = attention(query, key, value)
print(output.shape)

输出

总结与心得体会

通过对多头注意力机制的学习,有一个让我印象深刻的地方就是,它的多头注意力机制不是像其它模块设计思路一样,对同一个输入做了多组运算,而是将输入切分成不同的部分,每部分分别做了多组运算。由于自然语言处理中,一个单词的词向量往往是很长的,所以这种方式比CV的那种堆叠的方式能减少很多计算量,并且在效果方面不会损失太多。

个人感觉:词向量的不同分组之间的关系有点像计算机视觉中,彩色图像的多个通道,多头注意力机制有点像后面的通道注意力的计算。

相关文章:

TR4 - Transformer中的多头注意力机制

目录 前言自注意力机制Self-Attention层的具体机制Self-Attention 矩阵计算 多头注意力机制例子解析 代码实现总结与心得体会 前言 多头注意力机制可以说是Transformer中最主要的模块,没有之一。这次我们来仔细分析一下注意力机制与多头注意力机制。 自注意力机制…...

three.js跟着教程实现VR效果(四)

参照教程:https://juejin.cn/post/6973865268426571784(作者:大帅老猿) 1.WebGD3D引擎 用three.js (1)使用立方体6面图 camera放到 立方体的中间 like “回” 让贴图向内翻转 (2)使…...

AI预测体彩排3第1弹【2024年4月12日预测--第1套算法开始计算第1次测试】

前面经过多个模型几十次对福彩3D的预测,积累了一定的经验,摸索了一些稳定的规律,有很多彩友让我也出一下排列3的预测结果,我认为目前时机已成熟,且由于福彩3D和体彩排列3的玩法完全一样,我认为3D的规律和模…...

spring 中的控制反转

在Spring框架中,控制反转(IoC,Inversion of Control)是指将对象的创建和管理交给了容器,而不是在应用程序代码中直接创建对象。在传统的编程模式中,应用程序代码通常负责创建对象并管理它们的生命周期&…...

GO并发总是更快吗?

许多开发人员的一个误解是,并发解决方案总是比串行更快,大错特错。解决方案的整体性能取决于许多因素,例如,结构的效率(并发)、可以并行处理的部分以及计算单元的竞争程度。 1. GO调度 线程是操作系统可以执行的最小单元。如果一个进程想要同时执行多个动作,它可以启动…...

echarts折线图自定义打点标记小工具

由于没研究明白echarts怎么用label和lableLine实现自定义打点标记&#xff0c;索性用markPoint把长方形压扁成线模拟了一番自定义打点标记&#xff0c;记录下来备用。&#xff08;markLine同理也能实现&#xff09; 实现代码如下&#xff1a; <!DOCTYPE html> <html…...

【图论】Leetcode 200. 岛屿数量【中等】

岛屿数量 给你一个由 ‘1’&#xff08;陆地&#xff09;和 ‘0’&#xff08;水&#xff09;组成的的二维网格&#xff0c;请你计算网格中岛屿的数量。 岛屿总是被水包围&#xff0c;并且每座岛屿只能由水平方向和/或竖直方向上相邻的陆地连接形成。 此外&#xff0c;你可以…...

酒店大厅装水离子雾化壁炉前和装后对比

在酒店大厅装水离子雾化壁炉之前和之后&#xff0c;大厅的氛围和体验会有显著的对比&#xff1a; 装水离子雾化壁炉之前&#xff1a; 传统感&#xff1a;在壁炉安装之前&#xff0c;大厅可能会有传统的装饰或者简单的暖气设备&#xff0c;缺乏现代化的元素。这种传统感可能会…...

城市内涝与海绵城市规划设计中的水文水动力模拟

原文链接&#xff1a;城市内涝与海绵城市规划设计中的水文水动力模拟https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247601198&idx5&sn35b9e5e3961ea2f190f9742236a7217f&chksmfa820dc9cdf584df97633f64d19bdc3e5f7d1a5a85000c8f040e1953c51b9b39c87b5…...

C++项目实战与经验分享

在编程世界中,C++ 是一种功能强大且灵活的编程语言,广泛应用于系统级编程、游戏开发、嵌入式系统以及高性能计算等领域。本文将分享一个基于C++的图像处理系统项目实战经验,并深入探讨在开发过程中遇到的问题及解决方案。 一、项目概述 本次项目实战的目标是开发一个基于C…...

Day17_学点JavaEE_转发、重定向、Get、POST、乱码问题总结

1 转发 转发&#xff1a;一般查询了数据之后&#xff0c;转发到一个jsp页面进行展示 req.setAttribute("list", list); req.getRequestDispatcher("student_list.jsp").forward(req, resp);2 重定向 重定向&#xff1a;一般添加、删除、修改之后重定向到…...

Mouse IFN-α ELISA kit (Quick Test)

干扰素α&#xff08;IFN-α&#xff09;是一类由免疫细胞分泌的内源性调节因子&#xff0c;也被称为白细胞干扰素&#xff0c;主要参与响应病毒感染的先天性免疫。 基于结构特征、受体、细胞来源和生物活性的不同&#xff0c;干扰素可被分为Ⅰ、Ⅱ、Ⅲ三种类型&#xff0c;其中…...

AMD Tensile 简介与示例

按照知其然&#xff0c;再知其所以然的认知次序进行 1&#xff0c;下载代码 git clone --recursive https://github.com/ROCm/Tensile.git 2&#xff0c;安装 Tensile cd Tensile mkdir build cd build ../Tensile/bin/Tensile ../Tensile/Configs/rocblas_dgemm_nn_asm_full…...

Rust语言

文章目录 Rust语言一&#xff0c;Rust语言是什么二&#xff0c;Rust语言能做什么&#xff1f;Rust语言的设计使其适用于许多不同的领域&#xff0c;包括但不限于以下几个方面&#xff1a;1. 传统命令行程序&#xff1a;2. Web 应用&#xff1a;3. 网络服务器&#xff1a;4. 嵌入…...

排序算法之冒泡排序

目录 一、简介二、代码实现三、应用场景 一、简介 算法平均时间复杂度最好时间复杂度最坏时间复杂度空间复杂度排序方式稳定性冒泡排序O(n^2 )O(n)O(n^2)O(1)In-place稳定 稳定&#xff1a;如果A原本在B前面&#xff0c;而AB&#xff0c;排序之后A仍然在B的前面&#xff1b; 不…...

js打印页面源码 ,打印选取的容器里的内容,打印指定内容

js打印页面源码 &#xff0c;打印选取的容器里的内容&#xff0c;打印指定内容 效果 代码 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta http-equiv"X-UA-Compatible" content"IEedge&…...

算法练习第五十天|123.买卖股票的最佳时机III、188.买卖股票的最佳时机IV

123. 买卖股票的最佳时机 III 188. 买卖股票的最佳时机 IV 123.买卖股票的最佳时机III class Solution {public int maxProfit(int[] prices) {//dp[i][j] 第i天买卖股票获得的最大利润/**j0不操作j1第一次持有j2第一次不持有j3第二次持有j4第二次不持有dp[i][0] dp[i-1][0]d…...

细胞世界:4.细胞分化(划区域)与细胞衰老(设施磨损)

(1)细胞凋亡 1. 概念&#xff1a;细胞凋亡可以比作城市的规划者主动拆除某些建筑来更新城市或防止危险建筑对市民的潜在伤害。这是一个有序的过程&#xff0c;由城市&#xff08;细胞内部&#xff09;的特定规划&#xff08;基因&#xff09;所决定。 2. 特征&#xff1a;细…...

c语言:操作符

操作符 一.算术操作符: + - * % / 1.除了%操作符之外,其他的几个操作符可以作用与整数和浮点数,如:5%2.0//error. 2.对于操作符,如果两个操作数都为整数,执行整数除法而只要有浮点数执行的就是浮点数除法。 3.%操作符的两个操作数必须为整数。 二.移位操作符:<&…...

谷歌seo自然搜索排名怎么提升快?

要想在谷歌上排名快速上升&#xff0c;关键在于运用GPC爬虫池跟高低搭配的外链组合 首先你要做的&#xff0c;就是让谷歌的蜘蛛频繁来你的网站&#xff0c;网站需要被谷歌蜘蛛频繁抓取和索引&#xff0c;那这时候GPC爬虫池就能派上用场了&#xff0c;GPC爬虫池能够帮你大幅度提…...

使用VSCode开发Django指南

使用VSCode开发Django指南 一、概述 Django 是一个高级 Python 框架&#xff0c;专为快速、安全和可扩展的 Web 开发而设计。Django 包含对 URL 路由、页面模板和数据处理的丰富支持。 本文将创建一个简单的 Django 应用&#xff0c;其中包含三个使用通用基本模板的页面。在此…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

&#x1f31f; 什么是 MCP&#xff1f; 模型控制协议 (MCP) 是一种创新的协议&#xff0c;旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议&#xff0c;它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

大数据零基础学习day1之环境准备和大数据初步理解

学习大数据会使用到多台Linux服务器。 一、环境准备 1、VMware 基于VMware构建Linux虚拟机 是大数据从业者或者IT从业者的必备技能之一也是成本低廉的方案 所以VMware虚拟机方案是必须要学习的。 &#xff08;1&#xff09;设置网关 打开VMware虚拟机&#xff0c;点击编辑…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台&#xff08;Launchpad&#xff09;多出来了&#xff1a;Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显&#xff0c;都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

Maven 概述、安装、配置、仓库、私服详解

目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...

代理篇12|深入理解 Vite中的Proxy接口代理配置

在前端开发中,常常会遇到 跨域请求接口 的情况。为了解决这个问题,Vite 和 Webpack 都提供了 proxy 代理功能,用于将本地开发请求转发到后端服务器。 什么是代理(proxy)? 代理是在开发过程中,前端项目通过开发服务器,将指定的请求“转发”到真实的后端服务器,从而绕…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中&#xff0c;车辆不再仅仅是传统的交通工具&#xff0c;而是逐步演变为高度智能的移动终端。这一转变的核心支撑&#xff0c;来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒&#xff08;T-Box&#xff09;方案&#xff1a;NXP S32K146 与…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

【Nginx】使用 Nginx+Lua 实现基于 IP 的访问频率限制

使用 NginxLua 实现基于 IP 的访问频率限制 在高并发场景下&#xff0c;限制某个 IP 的访问频率是非常重要的&#xff0c;可以有效防止恶意攻击或错误配置导致的服务宕机。以下是一个详细的实现方案&#xff0c;使用 Nginx 和 Lua 脚本结合 Redis 来实现基于 IP 的访问频率限制…...