当前位置: 首页 > news >正文

多头自注意力机制的代码实现

文章目录

  • 1、自注意力机制
  • 2、多头注意力机制

  • transformer的整体结构:
    在这里插入图片描述

1、自注意力机制

  • 自注意力机制如下:
    在这里插入图片描述
  • 计算过程:
    在这里插入图片描述
  • 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1)	# dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)

2、多头注意力机制

  • 结构如下:
    在这里插入图片描述
  • 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias)        self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output

相关文章:

多头自注意力机制的代码实现

文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...

抽象工厂模式

目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...

登录校验-Filter-详解

目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...

堆栈方法区笔记记录

成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...

新版微信小程序获取用户手机号

小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...

CSS实践 —— 悬浮盒子阴影加上移效果

悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...

安全测试基础知识

软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语&#xff0c;但出于我们的目的&#xff0c;我们将评估定义为分析和发现漏洞&#xff0c;而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...

列表首屏毫秒级加载与自动滚动定位方案

引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...

小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic

摘 要 随着互联网的发展&#xff0c;网络技术的发展变得极其重要&#xff0c;所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变&#xff0c;所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...

Fortran 微分方程求解 --ODEPACK

最近涉及到使用Fortran对微分方程求解&#xff0c;我们知道MATLAB已有内置的函数&#xff0c;比如ode家族&#xff0c;ode15s&#xff0c;对应着不同的求解办法。通过查看odepack的官方文档&#xff0c;我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...

8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45

特点&#xff1a; ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数&#xff0c;频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数&#xff0c;可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...

【Python】函数

None类型 思考&#xff1a;若函数没有使用return语句返回数据&#xff0c;那么函数有返回值吗&#xff1f; 答&#xff1a;实际上是有的&#xff0c;Python中有一个特殊的字面量None&#xff0c;其类型是<class ‘NoneType’>&#xff0c;无返回值的函数&#xff0c;实…...

centos安装MySQL 解压版完整教程(按步骤傻瓜式安装

一、卸载系统自带的 Mariadb 查看&#xff1a; rpm -qa|grep mariadb 卸载&#xff1a; rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...

【后端速成 Vue】第一个 Vue 程序

1、为什么要学习 Vue&#xff1f; 为什么使用 Vue? 回想之前&#xff0c;前后端交互的时候&#xff0c;前端收到后端响应的数据&#xff0c;接着将数据渲染到页面上&#xff0c;之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery&#xff0c;但是这两个用起来还是不太…...

Macbook pro M1 安装Ubuntu教程

先讲下心路历程 由于版主最近刚切换到Mac&#xff0c;所以在安装的时候一上手就选择了virutalbox&#xff0c;结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看&#xff0c;才知…...

前端console.log打印内容与后端请求返回数据不一致

后端传值num0 前端打印num1 ,如图&#xff0c;console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...

SQL入门:多表查询

SQL&#xff0c;或者说结构化查询语言(Structured Query Language)&#xff0c;是用于管理和操作关系型数据库的标准语言。在本篇文章中&#xff0c;我们将重点介绍SQL中的多表查询&#xff0c;这是一种强大的工具&#xff0c;可以帮助我们从多个相关的表格中获取数据。 数据库…...

【C++】进一步认识模板

&#x1f3d6;️作者&#xff1a;malloc不出对象 ⛺专栏&#xff1a;C的学习之路 &#x1f466;个人简介&#xff1a;一名双非本科院校大二在读的科班编程菜鸟&#xff0c;努力编程只为赶上各位大佬的步伐&#x1f648;&#x1f648; 目录 前言一、非类型模板参数二、模板的特…...

Mysql Oracle 区别

1. oracle select *, id需要在星号前加别名&#xff0c;mysql则不需要 mysql语法&#xff1a; select *, id from xin_student_t;oracle语法&#xff1a; select st.*, st.id from xin_student_t st;2. oracle表定义了别名&#xff0c;在查询时可以不用别名指定字段&#xf…...

华为OD-第K长的连续字母字符串长度

题目描述 给定一个字符串&#xff0c;只包含大写字母&#xff0c;求在包含同一字母的子串中&#xff0c;长度第 k 长的子串的长度&#xff0c;相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...

工程地质软件市场:发展现状、趋势与策略建议

一、引言 在工程建设领域&#xff0c;准确把握地质条件是确保项目顺利推进和安全运营的关键。工程地质软件作为处理、分析、模拟和展示工程地质数据的重要工具&#xff0c;正发挥着日益重要的作用。它凭借强大的数据处理能力、三维建模功能、空间分析工具和可视化展示手段&…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本&#xff1a; 3.8.1 语言&#xff1a; JavaScript/TypeScript、C、Java 环境&#xff1a;Window 参考&#xff1a;Java原生反射机制 您好&#xff0c;我是鹤九日&#xff01; 回顾 在上篇文章中&#xff1a;CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

EtherNet/IP转DeviceNet协议网关详解

一&#xff0c;设备主要功能 疆鸿智能JH-DVN-EIP本产品是自主研发的一款EtherNet/IP从站功能的通讯网关。该产品主要功能是连接DeviceNet总线和EtherNet/IP网络&#xff0c;本网关连接到EtherNet/IP总线中做为从站使用&#xff0c;连接到DeviceNet总线中做为从站使用。 在自动…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题

分区配置 (ptab.json) img 属性介绍&#xff1a; img 属性指定分区存放的 image 名称&#xff0c;指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件&#xff0c;则以 proj_name:binary_name 格式指定文件名&#xff0c; proj_name 为工程 名&…...

安宝特案例丨Vuzix AR智能眼镜集成专业软件,助力卢森堡医院药房转型,赢得辉瑞创新奖

在Vuzix M400 AR智能眼镜的助力下&#xff0c;卢森堡罗伯特舒曼医院&#xff08;the Robert Schuman Hospitals, HRS&#xff09;凭借在无菌制剂生产流程中引入增强现实技术&#xff08;AR&#xff09;创新项目&#xff0c;荣获了2024年6月7日由卢森堡医院药剂师协会&#xff0…...

Go 语言并发编程基础:无缓冲与有缓冲通道

在上一章节中&#xff0c;我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道&#xff0c;它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好&#xff0…...

学习一下用鸿蒙​​DevEco Studio HarmonyOS5实现百度地图

在鸿蒙&#xff08;HarmonyOS5&#xff09;中集成百度地图&#xff0c;可以通过以下步骤和技术方案实现。结合鸿蒙的分布式能力和百度地图的API&#xff0c;可以构建跨设备的定位、导航和地图展示功能。 ​​1. 鸿蒙环境准备​​ ​​开发工具​​&#xff1a;下载安装 ​​De…...

LCTF液晶可调谐滤波器在多光谱相机捕捉无人机目标检测中的作用

中达瑞和自2005年成立以来&#xff0c;一直在光谱成像领域深度钻研和发展&#xff0c;始终致力于研发高性能、高可靠性的光谱成像相机&#xff0c;为科研院校提供更优的产品和服务。在《低空背景下无人机目标的光谱特征研究及目标检测应用》这篇论文中提到中达瑞和 LCTF 作为多…...

Modbus RTU与Modbus TCP详解指南

目录 1. Modbus协议基础 1.1 什么是Modbus? 1.2 Modbus协议历史 1.3 Modbus协议族 1.4 Modbus通信模型 🎭 主从架构 🔄 请求响应模式 2. Modbus RTU详解 2.1 RTU是什么? 2.2 RTU物理层 🔌 连接方式 ⚡ 通信参数 2.3 RTU数据帧格式 📦 帧结构详解 🔍…...