当前位置: 首页 > news >正文

Transformer中的自注意力是怎么实现的?

在Transformer模型中,自注意力(Self-Attention)是核心组件,用于捕捉输入序列中不同位置之间的关系。自注意力机制通过计算每个标记与其他所有标记之间的注意力权重,然后根据这些权重对输入序列进行加权求和,从而生成新的表示。下面是实现自注意力机制的代码及其详细说明。

自注意力机制的实现

1. 计算注意力得分(Scaled Dot-Product Attention)

自注意力机制的基本步骤包括以下几个部分:

  1. 线性变换:将输入序列通过三个不同的线性变换层,得到查询(Query)、键(Key)和值(Value)矩阵。
  2. 计算注意力得分:通过点积计算查询与键的相似度,再除以一个缩放因子(通常是键的维度的平方根),以稳定梯度。
  3. 应用掩码:在计算注意力得分后,应用掩码(如果有),避免未来信息泄露(用于解码器中的自注意力)。
  4. 计算注意力权重:通过softmax函数将注意力得分转换为概率分布。
  5. 加权求和:使用注意力权重对值进行加权求和,得到新的表示。
2. 多头注意力机制(Multi-Head Attention)

为了捕捉不同子空间的特征,Transformer使用多头注意力机制。通过将查询、键和值分割成多个头,每个头独立地计算注意力,然后将所有头的输出连接起来,并通过一个线性层进行组合。

自注意力机制代码实现

import torch
import torch.nn as nn
import torch.nn.functional as F# Scaled Dot-Product Attention
def scaled_dot_product_attention(query, key, value, mask=None):d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) / torch.sqrt(torch.tensor(d_k, dtype=torch.float32))print(f"Scores shape: {scores.shape}")  # (batch_size, num_heads, seq_length, seq_length)if mask is not None:scores = scores.masked_fill(mask == 0, float('-inf'))attention_weights = F.softmax(scores, dim=-1)print(f"Attention weights shape: {attention_weights.shape}")  # (batch_size, num_heads, seq_length, seq_length)output = torch.matmul(attention_weights, value)print(f"Output shape after attention: {output.shape}")  # (batch_size, num_heads, seq_length, d_k)return output, attention_weights# Multi-Head Attention
class MultiHeadAttention(nn.Module):def __init__(self, d_model, num_heads):super(MultiHeadAttention, self).__init__()assert d_model % num_heads == 0self.d_model = d_modelself.num_heads = num_headsself.d_k = d_model // num_headsself.linear_query = nn.Linear(d_model, d_model)self.linear_key = nn.Linear(d_model, d_model)self.linear_value = nn.Linear(d_model, d_model)self.linear_out = nn.Linear(d_model, d_model)def forward(self, query, key, value, mask=None):batch_size = query.size(0)# Linear projectionsquery = self.linear_query(query)key = self.linear_key(key)value = self.linear_value(value)print(f"Query shape after linear: {query.shape}")  # (batch_size, seq_length, d_model)print(f"Key shape after linear: {key.shape}")      # (batch_size, seq_length, d_model)print(f"Value shape after linear: {value.shape}")  # (batch_size, seq_length, d_model)# Split into num_headsquery = query.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)key = key.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)value = value.view(batch_size, -1, self.num_heads, self.d_k).transpose(1, 2)print(f"Query shape after split: {query.shape}")   # (batch_size, num_heads, seq_length, d_k)print(f"Key shape after split: {key.shape}")       # (batch_size, num_heads, seq_length, d_k)print(f"Value shape after split: {value.shape}")   # (batch_size, num_heads, seq_length, d_k)# Apply scaled dot-product attentionx, attention_weights = scaled_dot_product_attention(query, key, value, mask)# Concatenate headsx = x.transpose(1, 2).contiguous().view(batch_size, -1, self.d_model)print(f"Output shape after concatenation: {x.shape}")  # (batch_size, seq_length, d_model)# Final linear layerx = self.linear_out(x)print(f"Output shape after final linear: {x.shape}")   # (batch_size, seq_length, d_model)return x, attention_weights# 示例用法
d_model = 512
num_heads = 8
batch_size = 64
seq_length = 10# 假设输入是随机生成的张量
query = torch.rand(batch_size, seq_length, d_model)
key = torch.rand(batch_size, seq_length, d_model)
value = torch.rand(batch_size, seq_length, d_model)# 创建多头注意力层
mha = MultiHeadAttention(d_model, num_heads)
output, attention_weights = mha(query, key, value)print("最终输出形状:", output.shape)  # 最终输出形状: (batch_size, seq_length, d_model)
print("注意力权重形状:", attention_weights.shape)  # 注意力权重形状: (batch_size, num_heads, seq_length, seq_length)

每一步的形状解释

  1. Linear Projections

    • Query, Key, Value分别经过线性变换。
    • 形状:[batch_size, seq_length, d_model]
  2. Split into Heads

    • 将Query, Key, Value分割成多个头。
    • 形状:[batch_size, num_heads, seq_length, d_k],其中d_k = d_model // num_heads
  3. Scaled Dot-Product Attention

    • 计算注意力得分(Scores)。
    • 形状:[batch_size, num_heads, seq_length, seq_length]
    • 计算注意力权重(Attention Weights)。
    • 形状:[batch_size, num_heads, seq_length, seq_length]
    • 使用注意力权重对Value进行加权求和。
    • 形状:[batch_size, num_heads, seq_length, d_k]
  4. Concatenate Heads

    • 将所有头的输出连接起来。
    • 形状:[batch_size, seq_length, d_model]
  5. Final Linear Layer

    • 通过一个线性层将连接的输出转换为最终的输出。
    • 形状:[batch_size, seq_length, d_model]

通过这种方式,我们可以清楚地看到每一步变换后的张量形状,理解自注意力和多头注意力机制的具体实现细节。

代码说明

  • scaled_dot_product_attention:实现了缩放点积注意力机制,计算查询和键的点积,应用掩码,计算softmax,然后使用权重对值进行加权求和。
  • MultiHeadAttention:实现了多头注意力机制,包括线性变换、分割、缩放点积注意力和最后的线性变换。

多头注意力机制的细节

  • 线性变换:将输入序列通过线性层转换为查询、键和值的矩阵。
  • 分割头:将查询、键和值的矩阵分割为多个头,每个头的维度是[batch_size, num_heads, seq_length, d_k]。
  • 缩放点积注意力:对每个头分别计算缩放点积注意力。
  • 连接头:将所有头的输出连接起来,得到[batch_size, seq_length, d_model]的张量。
  • 线性变换:通过一个线性层将连接的输出转换为最终的输出。

相关文章:

Transformer中的自注意力是怎么实现的?

在Transformer模型中,自注意力(Self-Attention)是核心组件,用于捕捉输入序列中不同位置之间的关系。自注意力机制通过计算每个标记与其他所有标记之间的注意力权重,然后根据这些权重对输入序列进行加权求和&#xff0c…...

LabVIEW鼠标悬停在波形图上的曲线来自动显示相应点的坐标

步骤 创建事件结构: 打开LabVIEW,创建一个新的VI。 在前面板上添加一个Waveform Graph控件。 在后面板上添加一个While Loop和一个事件结构(Event Structure)。 配置事件结构,选择Waveform Graph作为事件源&#xf…...

操作系统发展简史(Unix/Linux 篇 + DOS/Windows 篇)+ Mac 与 Microsoft 之风云争霸

操作系统发展简史(Unix/Linux 篇) 说到操作系统,大家都不会陌生。我们天天都在接触操作系统 —— 用台式机或笔记本电脑,使用的是 windows 和 macOS 系统;用手机、平板电脑,则是 android(安卓&…...

钡铼分布式 IO 系统 OPC UA边缘计算耦合器BL205

深圳钡铼技术推出的BL205耦合器支持OPC UA Server功能,以服务器形式对外提供数据。符合IEC 62541工业自动化统一架构通讯标准,数据可以选择加密(X.509证书)、身份验证方式传送。 安全策略支持basic128rsa15、basic256、basic256s…...

实现了一个心理测试的小程序,微信小程序学习使用问题总结

1. 如何在跳转页面中传递参数 ,在 onLoad 方法中通过 options 接收 2. radio 如何获取选中的值? bindchange 方法 参数e, e.detail.value 。 如果想要获取其他属性,使用data-xx 指定,然后 e.target.dataset.xx 获取。 3. 不刷…...

vue是如何进行监听数据变化的?vue2和vue3分别是什么?vue3为什么要更换?

Vue如何进行监听数据变化的? Vue.js 通过其响应式系统来监听数据变化。这个系统允许你声明式地将数据和 DOM 绑定,一旦数据发生变化,相关的 DOM 将自动更新。Vue 使用以下机制来实现数据的监听和响应: 响应式数据:在 …...

数据结构day3

一、思维导图 二、 #include "seqlist.h"#include<myhead.h> int main(int argc, const char *argv[]) {//创建一个顺序表SeqListPtr L list_create();if(NULL L){return -1;}//调用添加函数list_add(L,123);list_add(L,435);list_add(L,856);list_add(L,65…...

免费的数字孪生平台助力产业创新,让新质生产力概念有据可依

关于新质生产力的概念&#xff0c;在如今传统企业现代化发展中被反复提及。 那到底什么是新质生产力&#xff1f;它与哪些行业存在联系&#xff0c;我们又该使用什么工具来加快新质生产力的发展呢&#xff1f;今天我将介绍一款为发展新质生产力而量身定做的数字孪生工具。 新…...

mtsys2 编译 qemu 记录

参考链接 下载 MSYS2 MSYS2 MSYS2 换源 进入目录\msys64\etc\pacman.d&#xff0c; 在文件mirrorlist.msys的前面插入 Server http://mirrors.ustc.edu.cn/msys2/msys/$arch在文件mirrorlist.mingw32的前面插入 Server http://mirrors.ustc.edu.cn/msys2/mingw/i686在…...

【Python数据分析】数据分析三剑客:NumPy、SciPy、Matplotlib中常用操作汇总

文章目录 NumPy常见操作汇总SciPy常见操作汇总Matplotlib常见操作汇总官方文档链接NumPy常见操作汇总 在Python的NumPy库中,有许多常用的知识点,这里列出了一些核心功能和常见操作: 类别函数或特性描述基础操作np.array创建数组np.shape获取数组形状np.dtype查看数组数据类…...

STM32智能家居电力管理系统教程

目录 引言环境准备智能家居电力管理系统基础代码实现&#xff1a;实现智能家居电力管理系统 4.1 数据采集模块 4.2 数据处理与控制模块 4.3 通信与网络系统实现 4.4 用户界面与数据可视化应用场景&#xff1a;电力管理与优化问题解决方案与优化收尾与总结 1. 引言 智能家居电…...

C# 邮件发送

创建邮件类 // 有static时候 类名&#xff0c;方法名// MyEmail.方法名/// <summary>/// 给目标发送邮箱/// </summary>/// <param name"maiTo"></param>/// <param name"title"></param>/// <param name"con…...

Kotlin 协程简化回调

suspend 和 suspendCoroutine 实现 suspendCoroutine函数必须在协程作用域或挂起函数中才能调用&#xff0c;它接收一个Lambda表达式参数&#xff0c;主要作用是将当前协程立即挂起&#xff0c;然后在一个普通的线程中执行Lambda表达式中的代码。Lambda表达式的参数列表上会传…...

帝王蝶算法(EBOA)及Python和MATLAB实现

帝王蝶算法&#xff08;Emperor Butterfly Optimization Algorithm&#xff0c;简称EBOA&#xff09;是一种启发式优化算法&#xff0c;灵感来源于蝴蝶群体中的帝王蝶&#xff08;Emperor Butterfly&#xff09;。该算法模拟了帝王蝶群体中帝王蝶和其他蝴蝶之间的交互行为&…...

【学术会议征稿】第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024)

第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024) 2024 6th International Conference on Frontier Technologies of Information and Computer 第六届信息与计算机前沿技术国际学术会议(ICFTIC 2024)将在中国青岛举行&#xff0c;会期是2024年11月8-10日&#xff0c;为…...

PHP MySQL 读取数据

PHP MySQL 读取数据 PHP和MySQL是Web开发中的经典组合&#xff0c;广泛用于创建动态网站和应用程序。在PHP中读取MySQL数据库中的数据是一项基本技能&#xff0c;涉及到连接数据库、执行查询以及处理结果集。本文将详细介绍如何使用PHP从MySQL数据库中读取数据。 1. 环境准备…...

点亮 LED-I.MX6U嵌入式Linux C应用编程学习笔记基于正点原子阿尔法开发板

点亮 LED 应用层操控硬件的两种方式 背景 Linux系统将所有内容视作文件&#xff0c;包括硬件设备&#xff0c;通过文件I/O方式与硬件交互 设备文件&#xff0c;如字符设备文件与块设备文件&#xff0c;是硬件设备提供给应用层的接口 应用层通过设备文件进行I/O操作&#xff…...

从0到1搭建数据中台(4):neo4j初识及安装使用

在数据中台中&#xff0c;neo4j作为图数据库&#xff0c;可以用于数据血缘关系的存储 图数据库的其他用于主要用于知识图谱&#xff0c;人物关系的搭建&#xff0c;描述实体&#xff0c;关系&#xff0c;以及实体属性 安装 在官网 https://neo4j.com/ 下载安装包 neo4j-co…...

【20】读感 - 架构整洁之道(二)

概述 继上一篇文章讲了前两章的读感&#xff0c;已经归纳总结的重点&#xff0c;这章会继续跟进的看一下&#xff0c;深挖架构整洁之道。 编程范式 编程范式从早期到至今&#xff0c;提过哪些编程范式&#xff0c;结构化编程&#xff0c;面向对象编程&#xff0c;函数式编程…...

js vue axios post 数组请求参数获取转换, 后端go参数解析(gin框架)全流程示例

今天介绍的是前后端分离系统中的请求参数 数组参数的生成&#xff0c;api请求发送&#xff0c;到后端请求参数接收的全过程示例。 为何会有这个文章&#xff1a;后端同一个API接口同时处理单条或者多条数据&#xff0c;这样就要求我们在前端发送请求参数的时候需要统一将请…...

揭秘郭采洁浪漫升级

【揭秘&#xff01;郭采洁浪漫升级&#xff0c;与“莫拉怪乐”共谱爱情新篇章】在这个春意盎然的季节里&#xff0c;娱乐圈迎来了一则既意外又甜蜜的爆炸新闻——郭采洁&#xff0c;这位以独特气质与精湛演技著称的才女&#xff0c;悄然间迈入了人生的新阶段&#xff0c;而她的…...

数据结构(Java):力扣牛客 二叉树面试OJ题(一)

&#x1f449; ​​​​​​目录 &#x1f448; 1、题一&#xff1a;检查两棵树是否相同 1.1 思路分析 1.2 代码 2、题二&#xff1a;另一棵树的子树 2.1 思路分析 2.2 代码 3、题三&#xff1a;翻转二叉树 3.1 思路分析 3.2 代码 4、题四&#xff1a;判断树是否对称 …...

在国产芯片上实现YOLOv5/v8图像AI识别-【1.3】YOLOv5的介绍及使用(训练、导出)更多内容见视频

本专栏主要是提供一种国产化图像识别的解决方案&#xff0c;专栏中实现了YOLOv5/v8在国产化芯片上的使用部署&#xff0c;并可以实现网页端实时查看。根据自己的具体需求可以直接产品化部署使用。 B站配套视频&#xff1a;https://www.bilibili.com/video/BV1or421T74f 数据…...

逻辑门的题目怎么做?

FPGA语法练习——二输入逻辑门&#xff0c;一起来听~~ FPGA语法练习——二输入逻辑门 题目介绍&#xff1a;F学社-全球FPGA技术提升平台 (zzfpga.com)...

CentOS 7报错:yum命令报错 “ Cannot find a valid baseurl for repo: base/7/x86_6 ”

参考连接&#xff1a; 【linux】CentOS 7报错&#xff1a;yum命令报错 “ Cannot find a valid baseurl for repo: base/7/x86_6 ”_centos linux yum search ifconfig cannot find a val-CSDN博客 Centos7出现问题Cannot find a valid baseurl for repo: base/7/x86_64&…...

51单片机STC89C52RC——18.1 HC-SR04超声波测距

目的/效果 独立按键K1按下后开始测距&#xff0c;LCD显示距离&#xff08;mm&#xff09; 一&#xff0c;STC单片机模块 二&#xff0c;HC-SR04 超声波测距 2.1 HC-SR04 简介 HC-SR04超声波测距模块提供2cm~400cm的测距功能&#xff0c;精度达3mm。 2.2 时序 以上时序图表明…...

WordPress与 wp-cron.php

WordPress 傲居全球最流行的内容管理系统&#xff08;CMS&#xff09;之位&#xff0c;占据了互联网约43%的网站后台&#xff0c;这主要得益于其直观易用的用户界面以及丰富的扩展功能&#xff0c;特别是为新手用户提供了极大的便利。 然而&#xff0c;在畅享WordPress带来的便…...

bb-------

社保费申报及缴纳...

数据挖掘与分析部分实验与实训项目报告

一、机器学习算法的应用 1. 朴素贝叶斯分类器 相关代码 import pandas as pd from sklearn.model_selection import train_test_split from sklearn.naive_bayes import GaussianNB, MultinomialNB from sklearn.metrics import accuracy_score # 将数据加载到DataFrame中&a…...

Python中使用SpeechLib实现文本转换语音朗读的示例(修正bug)

一、修正SpeechLib的导入包顺序后的代码&#xff1a; from comtypes.client import CreateObjectengine CreateObject(SAPI.SpVoice) stream CreateObject(SAPI.SpFileStream)from comtypes.gen import SpeechLibinfile E:\\语音文档\\易经64卦读音.txt outfile E:\\demo.…...

包头做网站要多少钱/石狮seo

变量 变量的使用 php定义时不需要关键字&#xff0c;但必须使用$符号 这里的echo就是输出的意思&#xff0c;相当于python中的print 删除变量用unset()方法&#xff0c;为什么要干掉它&#xff0c;就是因为要释放内存 变量命名规则 1、变量名字必须以"$"开头&…...

网站建设运行问题及建议/苹果看国外新闻的app

AutoResetEvent 类 AutoResetEvent类的工作方式与ManualResetEvent类似。它会等超时事件发生或者信号事件发生然后通知正在等待的线程。ManualResetEvent和AutoResetEvent之间最重要差别之一是AutoResetEvent在WaitOne()方法执行完会改变自身状态。下面列表显示了如何使用AutoR…...

怎么将网站设置为首页/小说网站排名免费

文 | 我爱学Python简书 编辑 | EarlGrey推荐 | 编程派公众号(ID&#xff1a;codingpy)昨天在上厕所的时候突发奇想&#xff0c;当你把usb插进去的时候&#xff0c;能不能自动执行usb上的程序。查了一下&#xff0c;发现只有windows上可以&#xff0c;具体的大家也可以搜索(搜索…...

smluntan wordpress/地推拉新app推广怎么做

1、安装python2.7官网下载&#xff0c;安装&#xff0c;配置环境变量 path&#xff0c;命令行 执行python2、easy_install 安装 win7 64位必须使用ez_setup.py进行安装。方法是下载ez_setup.py后。在cmd下运行 python ez_setup.py&#xff0c;就可以自己主动安装setuptools。下…...

长春市网站制作/百度电脑版

描述 给一包含大写字母和整数(从 0 到 9)的字符串, 试写一函数返回有序的字母以及数字和. 样例 - 样例 1:输入 : str "AC2BEW3" 输出 : "ABCEW5" 说明 : 字母按字母表的顺序排列, 接着是整数的和(2 和 3)。解析 首先看了一下Java和Python的提交 Java…...

浙江建站优化品牌/青岛网站推广公司排名

1、ARRAY_SIZE 用来判断一个数组的 size&#xff0c;若传入的参数不是一个数组&#xff0c;编译将会报错。 使用此宏来安全的获取一个数组的 size。 include/linux/kernel.h#define ARRAY_SIZE(arr) (sizeof(arr) / sizeof((arr)[0]) __must_be_array(arr))2、__must_be_arr…...