当前位置: 首页 > news >正文

Pytorch手撸Attention

Pytorch手撸Attention

注释写的很详细了,对照着公式比较下更好理解,可以参考一下知乎的文章

注意力机制

在这里插入图片描述

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size):super(SelfAttention, self).__init__()self.embed_size = embed_size# 定义三个全连接层,用于生成查询(Q)、键(K)和值(V)# 用Linear线性层让q、k、y能更好的拟合实际需求self.value = nn.Linear(embed_size, embed_size)self.key = nn.Linear(embed_size, embed_size)self.query = nn.Linear(embed_size, embed_size)def forward(self, x):# x 的形状应为 (batch_size批次数量, seq_len序列长度, embed_size嵌入维度)batch_size, seq_len, embed_size = x.shapeQ = self.query(x)K = self.key(x)V = self.value(x)# 计算注意力分数矩阵# 使用 Q 矩阵乘以 K 矩阵的转置来得到原始注意力分数# 注意力分数的形状为 [batch_size, seq_len, seq_len]# K.transpose(1,2)转置后[batch_size, embed_size, seq_len]# 为什么不直接使用 .T 直接转置?直接转置就成了[embed_size, seq_len,batch_size],不方便后续进行矩阵乘法attention_scores = torch.matmul(Q, K.transpose(1, 2)) / torch.sqrt(torch.tensor(self.embed_size, dtype=torch.float32))# 应用 softmax 获取归一化的注意力权重,dim=-1表示基于最后一个维度做softmaxattention_weight = F.softmax(attention_scores, dim=-1)# 应用注意力权重到 V 矩阵,得到加权和# 输出的形状为 [batch_size, seq_len, embed_size]output = torch.matmul(attention_weight, V)return output

多头注意力机制

在这里插入图片描述

class MultiHeadAttention(nn.Module):def __init__(self, embed_size, num_heads):super().__init__()self.embed_size = embed_sizeself.num_heads = num_heads# 整除来确定每个头的维度self.head_dim = embed_size // num_heads# 加入断言,防止head_dim是小数,必须保证可以整除assert self.head_dim * num_heads == embed_sizeself.q = nn.Linear(embed_size, embed_size)self.k = nn.Linear(embed_size, embed_size)self.v = nn.Linear(embed_size, embed_size)self.out = nn.Linear(embed_size, embed_size)def forward(self, query, key, value):# N就是batch_size的数量N = query.shape[0]# *_len是序列长度q_len = query.shape[1]k_len = key.shape[1]v_len = value.shape[1]# 通过线性变换让矩阵更好的拟合queries = self.q(query)keys = self.k(key)values = self.v(value)# 重新构建多头的queries,permute调整tensor的维度顺序# 结合下文demo进行理解queries = queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)keys = keys.reshape(N, k_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)values = values.reshape(N, v_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3)# 计算多头注意力分数attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt(torch.tensor(self.head_dim, dtype=torch.float32))attention = F.softmax(attention_scores, dim=-1)# 整合多头注意力机制的计算结果out = torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size)# 过一遍线性函数out = self.out(out)return out

demo测试

self-attention测试
# 测试自注意力机制
batch_size = 2
seq_len = 3
embed_size = 4# 生成一个随机数据 tensor
input_tensor = torch.rand(batch_size, seq_len, embed_size)# 创建自注意力模型实例
model = SelfAttention(embed_size)# print输入数据
print("输入数据 [batch_size, seq_len, embed_size]:")
print(input_tensor)# 运行自注意力模型
output_tensor = model(input_tensor)# print输出数据
print("输出数据 [batch_size, seq_len, embed_size]:")
print(output_tensor)

=======print=========

输入数据 [batch_size, seq_len, embed_size]:
tensor([[[0.7579, 0.7342, 0.1031, 0.8610],[0.8250, 0.0362, 0.8953, 0.1687],[0.8254, 0.8506, 0.9826, 0.0440]],[[0.0700, 0.4503, 0.1597, 0.6681],[0.8587, 0.4884, 0.4604, 0.2724],[0.5490, 0.7795, 0.7391, 0.9113]]])输出数据 [batch_size, seq_len, embed_size]:
tensor([[[-0.3714,  0.6405, -0.0865, -0.0659],[-0.3748,  0.6389, -0.0861, -0.0706],[-0.3694,  0.6388, -0.0855, -0.0660]],[[-0.2365,  0.4541, -0.1811, -0.0354],[-0.2338,  0.4455, -0.1871, -0.0370],[-0.2332,  0.4458, -0.1867, -0.0363]]], grad_fn=<UnsafeViewBackward0>)
MultiHeadAttention

多头注意力机制务必自己debug一下,主要聚焦在理解如何拆分成多头的,不结合代码你很难理解多头的操作过程

1、queries.reshape(N, q_len, self.num_heads, self.head_dim).permute(0, 2, 1, 3) 处理之后的 size = torch.Size([64, 8, 10, 16])

  • 通过上述操作,queries 张量的最终形状变为 [N, self.num_heads, q_len, self.head_dim]。这样的排列方式使得每个注意力头可以单独处理对应的序列部分,而每个头的处理仅关注其分配到的特定维度 self.head_dim
  • 这个形状是为了后续的矩阵乘法操作准备的,其中每个头的查询将与对应的键进行点乘,以计算注意力分数

2、attention_scores = torch.matmul(queries, keys.transpose(-2, -1)) / torch.sqrt( torch.tensor(self.head_dim, dtype=torch.float32)) 将reshape后的quries的后两个维度进行转置后点乘,对应了 Q ⋅ K T Q \cdot K^T QKT ;根据demo这里的头数为8,所以公式中对应的下标 i i i 为8

3、在进行完多头注意力机制的计算后通过 torch.matmul(attention, values).permute(0, 2, 1, 3).reshape(N, q_len, self.embed_size) 整合,变回原来的 [batch_size,seq_length,embed_size]形状

# 测试多头注意力
embed_size = 128  # 嵌入维度
num_heads = 8    # 头数
attention = MultiHeadAttention(embed_size, num_heads)# 创建随机数据模拟 [batch_size, seq_length, embedding_dim]
batch_size = 64
seq_length = 10
dummy_values = torch.rand(batch_size, seq_length, embed_size)
dummy_keys = torch.rand(batch_size, seq_length, embed_size)
dummy_queries = torch.rand(batch_size, seq_length, embed_size)# 计算多头注意力输出
output = attention(dummy_values, dummy_keys, dummy_queries)
print(output.shape)  # [batch_size, seq_length, embed_size]

=======print=========

torch.Size([64, 10, 128])

如果你难以理解权重矩阵的拼接和拆分,推荐李宏毅的attention课程(YouTobe)

相关文章:

Pytorch手撸Attention

Pytorch手撸Attention 注释写的很详细了&#xff0c;对照着公式比较下更好理解&#xff0c;可以参考一下知乎的文章 注意力机制 import torch import torch.nn as nn import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size):super(S…...

PyCharm 2024.1 发布:全面升级,助力高效编程!

PyCharm 2024.1 发布&#xff1a;全面升级&#xff0c;助力高效编程&#xff01; 文章目录 PyCharm 2024.1 发布&#xff1a;全面升级&#xff0c;助力高效编程&#xff01;摘要引言 Hugging Face&#xff1a;模型和数据集的快速文档预览针对 JavaScript 和 TypeScript 的全行代…...

Nginx基础(06)

Nginx基础&#xff08;05&#xff09; uWSGI 介绍 uWSGI 是一个 Web服务器 主要用途是将Web应用程序部署到生产环境中 可以用来连接Nginx服务与Python动态网站 1. 用 uWSGI 部署 Python 网站项目 配置 Nginx 使其可以将动态访问转交给 uWSGI 安装 python 工具及依赖 安…...

【Qt 学习笔记】QWidget的windowOpacity属性 | cursor属性 | font属性

博客主页&#xff1a;Duck Bro 博客主页系列专栏&#xff1a;Qt 专栏关注博主&#xff0c;后期持续更新系列文章如果有错误感谢请大家批评指出&#xff0c;及时修改感谢大家点赞&#x1f44d;收藏⭐评论✍ QWidget的windowOpacity属性 | cursor属性 | font属性 文章编号&#…...

Python爬虫:requests模块的基本使用

学习目标&#xff1a; 了解 requests模块的介绍掌握 requests的基本使用掌握 response常见的属性掌握 requests.text和content的区别掌握 解决网页的解码问题掌握 requests模块发送带headers的请求掌握 requests模块发送带参数的get请求 1 为什么要重点学习requests模块&…...

C++traits

traits C的标准库提供了<type_traits>,它定义了一些编译时基于模板类的接口用于查询、修改类型的特征&#xff1a;输入的时类型&#xff0c;输出与该类型相关的属性 通过type_traits技术编译器可以回答一系列问题&#xff1a;它是否为数值类型&#xff1f;是否为函数对象…...

gitee和idea集成

1 集成插件 2 配置账号密码 3 直接将项目传到仓库 4直接从gitee下载项目...

阿维·威格德森(Avi Wigderson)研究成果对人工智能领域的应用有哪些影响

AI人工智能的影响 威格德森&#xff08;Avi Wigderson&#xff09;的研究成果对人工智能领域的应用产生了深远的影响。 首先&#xff0c;威格德森在计算复杂性理论、算法和优化方面的贡献为人工智能领域提供了高效、准确的计算模型和算法。他的研究帮助我们更好地理解计算问题…...

【免费领取源码】可直接复用的医院管理系统!

今天给大家分享一套基于SpringbootVue的医院管理系统源码&#xff0c;在实际项目中可以直接复用。(免费提供&#xff0c;文中自取) 系统运行图&#xff08;设计报告和接口文档&#xff09; 1、后台管理页面 2、排班管理页面 3、设计报告包含接口文档 源码免费领取方式 后台私信…...

leetcode代码记录(全排列 II

目录 1. 题目&#xff1a;2. 我的代码&#xff1a;小结&#xff1a; 1. 题目&#xff1a; 给定一个可包含重复数字的序列 nums &#xff0c;按任意顺序 返回所有不重复的全排列。 示例 1&#xff1a; 输入&#xff1a;nums [1,1,2] 输出&#xff1a; [[1,1,2], [1,2,1], [2,1…...

【数据结构与算法】之双向链表及其实现!

​ 个人主页&#xff1a;秋风起&#xff0c;再归来~ 数据结构与算法 个人格言&#xff1a;悟已往之不谏&#xff0c;知来者犹可追 克心守己&#xff0c;律己则安&#xff01; 目录 1、双向链表的结构及概念 2、双向链表的实现 2.1 要实现的接口…...

记一次奇妙的某个edu渗透测试

前话&#xff1a; 对登录方法的轻视造成一系列的漏洞出现&#xff0c;对接口确实鉴权造成大量的信息泄露。从小程序到web端网址的奇妙的测试就此开始。&#xff08;文章厚码&#xff0c;请见谅&#xff09; 1. 寻找到目标站点的小程序 进入登录发现只需要姓名加学工号就能成功…...

设计模式学习笔记 - 设计模式与范式 -总结:1.回顾23中设计模式的原理、背后的思想、应用场景等

1.创建型设计模式 创建型设计模式包括&#xff1a;单例模式、工厂模式、建造者模式、原型模式。它主要解决对象的创建问题&#xff0c;封装复杂的创建过程&#xff0c;解耦对象的创建代码和使用代码。 1.单例模式 单例模式用来创建全局唯一的对象。一个类只允许创建一个对象…...

22 文件系统

了解了被打开的文件&#xff0c;肯定还有没被打开的文件&#xff0c;就是磁盘上的文件。先从磁盘开始认识 磁盘 概念 内存是掉电易失存储介质&#xff0c;磁盘是永久性存储介质 磁盘的种类有SSD&#xff0c;U盘&#xff0c;flash卡&#xff0c;光盘&#xff0c;磁带。磁盘是…...

OVITO-2.9版本

关注 M r . m a t e r i a l , \color{Violet} \rm Mr.material\ , Mr.material , 更 \color{red}{更} 更 多 \color{blue}{多} 多 精 \color{orange}{精} 精 彩 \color{green}{彩} 彩&#xff01; 主要专栏内容包括&#xff1a; †《LAMMPS小技巧》&#xff1a; ‾ \textbf…...

【Java开发指南 | 第一篇】类、对象基础概念及Java特征

读者可订阅专栏&#xff1a;Java开发指南 |【CSDN秋说】 文章目录 类、对象基础概念Java特征 Java 是一种面向对象的编程语言&#xff0c;它主要通过类和对象来组织和管理代码。 类、对象基础概念 类&#xff1a;类是一个模板&#xff0c;它描述一类对象的行为和状态。例如水…...

Neo4j 图形数据库中有哪些构建块?

Neo4j 图形数据库具有以下构建块 - 节点属性关系标签数据浏览器 节点 节点是 Graph 的基本单位。 它包含具有键值对的属性&#xff0c;如下图所示。 NEmployee 节点 在这里&#xff0c;节点 Name "Employee" &#xff0c;它包含一组属性作为键值对。 属性 属性是…...

002 springboot整合mybatis-plus

文章目录 TestMybatisGenerate.javapom.xmlapplication.yamlReceiveAddressMapper.xmlreceive_address.sqlReceiveAddress.javaReceiveAddressMapper.javaIReceiveAddressServiceReceiveAddressServiceImpl.javaReceiveAddressController.javaTestAddressService.javaSpringboo…...

代码随想录训练营第三十五期|第天16|二叉树part03|104.二叉树的最大深度 ● 111.二叉树的最小深度● 222.完全二叉树的节点个数

104. 二叉树的最大深度 - 力扣&#xff08;LeetCode&#xff09; 递归&#xff0c;可以前序遍历&#xff0c;也可以后序遍历 前序遍历是backtracking 下面是后序遍历的代码&#xff1a; /*** Definition for a binary tree node.* public class TreeNode {* int val;* …...

Mac版2024 CleanMyMac X 4.15.2 核心功能详解 cleanmymac这个软件怎么样?cleanmymac到底好不好用?

近些年伴随着苹果生态的蓬勃发展&#xff0c;越来越多的用户开始尝试接触Mac电脑。然而很多人上手Mac后会发现&#xff0c;它的使用逻辑与Windows存在很多不同&#xff0c;而且随着使用时间的增加&#xff0c;一些奇奇怪怪的文件也会占据有限的磁盘空间&#xff0c;进而影响使用…...

【华为OD机试】执行任务赚积分【C卷|100分】

题目描述 现有N个任务需要处理&#xff0c;同一时间只能处理一个任务&#xff0c;处理每个任务所需要的时间固定为1。 每个任务都有最晚处理时间限制和积分值&#xff0c;在最晚处理时间点之前处理完成任务才可获得对应的积分奖励。 可用于处理任务的时间有限&#xff0c;请问在…...

mybatis分页实现总结

1.mybatis拦截器相关知识 1.作用 mybatis的拦截器是mybatis提供的一个拓展机制&#xff0c;允许用户在使用时根据各自的需求对sql执行的各个阶段进行干预。比较常见的如对执行的sql进行监控&#xff0c;排查sql的执行时间&#xff0c;对sql进行拦截拼接需要的场景&#xff0c…...

Vue3——html-doc-js(html导出为word的js库)

一、下载 官方地址 html-doc-js - npm npm install html-doc-js 二、使用方法 // 使用页面中引入 import exportWord from html-doc-js// 配置项以及实现下载方法 const wrap document.getElementById(test)const config {document:document, //默认当前文档的document…...

第19天:信息打点-小程序应用解包反编译动态调试抓包静态分析源码架构

第十九天 本课意义 1.如何获取到目标小程序信息 2.如何从小程序中提取资产信息 一、Web&备案信息&单位名称中发现小程序 1.国内主流小程序平台 微信 百度 支付宝 抖音头条 2.小程序结构 1.主体结构 小程序包含一个描述整体程序的app和多个描述各自页面的page …...

外观模式:简化复杂系统的统一接口

在面向对象的软件开发中&#xff0c;外观模式是一种常用的结构型设计模式&#xff0c;旨在为复杂的系统提供一个简化的接口。通过创建一个统一的高级接口&#xff0c;这个模式帮助客户端通过一个简单的方式与复杂的子系统交互。本文将详细介绍外观模式的定义、实现、应用场景以…...

PHP数组去重

public function array_unique_key($arr,$key) {$tmp_arrarray();foreach($arr as $k > $v){if(in_array($v[$key],$tmp_arr)){ //判断是否重复unset($arr[$k]); //重复则删除}else{$tmp_arr[]$v[$key]; //将值存储在临时数组中}}return $arr; } public function array…...

论软件系统的架构风格,使用三段论 写一篇系统架构师论文

软件系统的架构风格是指在软件系统设计与开发过程中&#xff0c;采用的一组相互协调的设计原则、模式和实践。这些风格不仅影响着系统的技术实现&#xff0c;还关乎到系统的可维护性、可扩展性和可靠性等关键质量属性。通过三段论的结构&#xff0c;本文旨在探讨软件系统架构风…...

深度挖掘响应式模式的潜力,从而精准优化AI与机器学习项目的运行效能,引领技术革新潮流

​&#x1f308; 个人主页&#xff1a;danci_ &#x1f525; 系列专栏&#xff1a;《设计模式》 &#x1f4aa;&#x1f3fb; 制定明确可量化的目标&#xff0c;坚持默默的做事。 &#x1f525; 转载自热榜文章&#xff1a;探索设计模式的魅力&#xff1a;深度挖掘响应式模式的…...

企业级网络安全:入侵防御实时阻止,守护您的业务安全

随着互联网技术的快速发展&#xff0c;企业级网络安全问题日益凸显。在这个数字化时代&#xff0c;企业的业务安全不仅关系到企业的形象和声誉&#xff0c;还直接影响到企业的生存和发展。因此&#xff0c;加强企业级网络安全&#xff0c;预防和抵御各种网络攻击已成为企业的重…...

(一)Java八股——Redis

1 Redis缓存 1.1 什么是缓存穿透 ? 怎么解决 ?&#xff08;穿透&#xff09; ​ 缓存穿透是指查询一个一定不存在的数据&#xff0c;如果从存储层查不到数据则不写入缓存&#xff0c;这将导致这个不存在的数据每次请求都要到 DB 去查询&#xff0c;可能导致 DB 挂掉。这种情…...

杭州网站设计的公司/百度推广最简单方法

背景 通过添加下列类&#xff0c;可以快捷的变换背景颜色&#xff0c;如果是链接的话&#xff0c;鼠标移动上去会变暗 bg-primary 被修饰元素将会应到primary类&#xff0c;显示吃淡蓝色&#xff0c;文本颜色会变成白色。bg-success 被修饰元素表示成功的信息&#xff0c;背景变…...

开发一个游戏大约要多少钱/长治seo顾问

2017年4月4日&#xff0c;美国计算机协会&#xff08;ACM&#xff09;公布了2016年图灵奖的获得者&#xff1a;万维网&#xff08;World Wide Web&#xff09;的发明人、麻省理工学院教授 Tim Berners-Lee。 美国计算机协会给 Tim Berners-Lee 的获奖理由是&#xff1a;发明Wor…...

网站建设 东方网景/网站管理与维护

Expression TreesExpression Trees它的API核心是继承一个抽象基类叫Expression.从其他的Expression中创建一个Expression你能取一个表达式树和修改它,并用它创建另一个表达式.在下面的事例中,我们将开始使用一个x*x的lambda expression并且修改这个表达式,给它添加2.让我们看看…...

一个做任务赚钱的网站/saascrm国内免费pdf

一.time模块 import time # 当前时间戳 print(time.time()) # 获取当前时间 print(time.localtime()) # 获取格式化时间 print(time.strftime("%Y-%m-%d %H:%M:%S"))二.datetime模块 import datetime # 获取当前日期 print(datetime.date.today()) print(datetime.…...

网站排名超快/什么是软文

前言:本博文主要讲述vtkPolygon的相关内容,详细介绍vtkPolygon进行三角剖分时的具体步骤,以及vtkPolygon三角剖分存在的问题及解决方案。 目录 描述 存在问题 三角剖分算法 解决方案 描述 vtkPolygon是vtkCell的一个具体实现,用于表示2D的n多边形。多边形不能有任何内…...

独立商城网站建设/深圳关键词优化报价

生产环境下&#xff0c;有时候需要访问图片&#xff0c;正常需要应用ftp、nginx等配套使用&#xff0c;但是有时候为了简化&#xff0c;可以用以下的两种简单的访问&#xff0c;说实话&#xff0c;就是为了偷懒&#xff0c;但是效果是能有的&#xff0c;这就行了&#xff0c;所…...