当前位置: 首页 > news >正文

多头自注意力机制的代码实现

文章目录

  • 1、自注意力机制
  • 2、多头注意力机制

  • transformer的整体结构:
    在这里插入图片描述

1、自注意力机制

  • 自注意力机制如下:
    在这里插入图片描述
  • 计算过程:
    在这里插入图片描述
  • 代码如下:
class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_size):super().__init__()self.W_q = nn.Linear(embed_dim, key_size, bias=False)self.W_k = nn.Linear(embed_dim, key_size, bias=False)self.W_v = nn.Linear(embed_dim, value_size, bias=False)def forward(self, x, attn_mask=None):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量attn_mask: (N, L, L),用于对注意力矩阵(L, L)进行mask输出:shape:(N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)scores = torch.matmul(query, key.transpose(1, 2)) / math.sqrt(query.size(2))if attn_mask is not None:scores = scores.masked_fill(attn_mask, 0)attn_weights = F.softmax(scores, dim=-1)	# dim为-1表示,对每个嵌入向量与其他所有向量的注意力权重,进行softmax,以使每一行的和为1return torch.matmul(attn_weights, value)

2、多头注意力机制

  • 结构如下:
    在这里插入图片描述
  • 计算过程如下:
class MultiHeadSelfAttention(nn.Module):def __init__(self, embed_dim, num_heads, key_size, value_size, bias=False):super().__init__()self.embed_dim = embed_dimself.num_heads = num_headsself.q_head_dim = key_size // num_headsself.k_head_dim = key_size // num_headsself.v_head_dim = value_size // num_headsself.W_q = nn.Linear(embed_dim, key_size, bias=bias)self.W_k = nn.Linear(embed_dim, key_size, bias=bias)self.W_v = nn.Linear(embed_dim, value_size, bias=bias)        self.q_proj = nn.Linear(key_size, key_size, bias=bias)self.k_proj = nn.Linear(key_size, key_size, bias=bias)self.v_proj = nn.Linear(value_size, value_size, bias=bias)self.out_proj = nn.Linear(value_size, embed_dim, bias=bias)def forward(self, x):"""Args:X: shape: (N, L, embed_dim), input sequence, 是经过input embedding后的输入序列,L个embed_dim维度的嵌入向量Returns:output: (N, L, embed_dim)"""query = self.W_q(x)  # (N, L, key_size)key = self.W_k(x)  # (N, L, key_size)value = self.W_v(x)  # (N, L, value_size)q, k, v = self.q_proj(query), self.k_proj(key), self.v_proj(value)N, L, value_size = v.size()q = q.reshape(N, L, self.num_heads, self.q_head_dim).transpose(1, 2)k = k.reshape(N, L, self.num_heads, self.k_head_dim).transpose(1, 2)v = v.reshape(N, L, self.num_heads, self.v_head_dim).transpose(1, 2)att = torch.matmul(q, k.transpose(-1, -2)) / math.sqrt(k.size(-1))att = F.softmax(att, dim=-1)output = torch.matmul(att, v)output = output.transpose(1, 2).reshape(N, L, value_size)output = self.out_proj(output)return output

相关文章:

多头自注意力机制的代码实现

文章目录 1、自注意力机制2、多头注意力机制 transformer的整体结构: 1、自注意力机制 自注意力机制如下: 计算过程: 代码如下: class ScaledDotProductAttention(nn.Module):def __init__(self, embed_dim, key_size, value_…...

抽象工厂模式

目录 了解抽象工厂模式前的前置知识 什么是抽象工厂模式? 为什么要提出抽象工厂模式? 抽象工厂模式中的四大角色? 抽象工厂模式的优缺点? 抽象工厂模式的适用场景? 了解抽象工厂模式前的前置知识 在讲抽象工厂模式…...

登录校验-Filter-详解

目录 执行流程 拦截路径 过滤器链 小结 执行流程 过滤器Filter拦截到请求之后,首先执行方放行之前的逻辑,然后执行放行操作(doFilter),然后会访问对应的Web资源(对应的Controller类)&#…...

堆栈方法区笔记记录

成员变量分两种: 1)实例变量:没有static修饰,属于对象,存储在堆中,有几个对象就有几份,通过对象点来访问 2)静态变量:由static修饰,属于类,存储在方法区中,只有一份,通过类名点来访…...

新版微信小程序获取用户手机号

小程序手机号验证组件有两种 手机号快速验证组件 //原生写法 <button open-type"getPhoneNumber" bindgetphonenumber"getPhoneNumber"></button>Page({getPhoneNumber (e) {console.log(e.detail.code)} })uniapp写法 <button open-type…...

CSS实践 —— 悬浮盒子阴影加上移效果

悬浮盒子阴影加上移效果 代码 代码 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><title>Title</title><style>body{background-color: #f5f5f5;}.shadow {width: 100px;height: 100px;margin:…...

安全测试基础知识

软件安全测试是评估和测试系统以发现系统及其数据的安全风险和漏洞的过程。没有通用术语&#xff0c;但出于我们的目的&#xff0c;我们将评估定义为分析和发现漏洞&#xff0c;而不尝试实际利用这些漏洞。我们将测试定义为发现和尝试利用漏洞。 安全测试通常根据要测试的漏洞…...

列表首屏毫秒级加载与自动滚动定位方案

引用自 摸鱼wiki 场景 <template><div ref"commentsRef"><divv-for"comment in displayComments":key"comment.id":data-cell-id"comment.id"class"card">{{ comment.data }}</div></div> &…...

小区物业业主管理信息系统设计的设计与实现(论文+源码)_kaic

摘 要 随着互联网的发展&#xff0c;网络技术的发展变得极其重要&#xff0c;所以依靠计算机处理业务成为了一种社会普遍的现状。管理方式也自然而然的向着现代化技术方向而改变&#xff0c;所以纯人工管理方式在越来越完善的现代化管理技术的比较之下也就显得过于繁琐&#x…...

Fortran 微分方程求解 --ODEPACK

最近涉及到使用Fortran对微分方程求解&#xff0c;我们知道MATLAB已有内置的函数&#xff0c;比如ode家族&#xff0c;ode15s&#xff0c;对应着不同的求解办法。通过查看odepack的官方文档&#xff0c;我尝试使用了dlsode求解刚性和非刚性常微分方程组。 首先是github网址&am…...

8路光栅尺磁栅尺编码器或16路高速DI脉冲信号转Modbus TCP网络模块 YL99-RJ45

特点&#xff1a; ● 光栅尺磁栅尺解码转换成标准Modbus TCP协议 ● 高速光栅尺磁栅尺4倍频计数&#xff0c;频率可达5MHz ● 模块可以输出5V的电源给光栅尺或传感器供电 ● 支持8个光栅尺同时计数&#xff0c;可识别正反转 ● 可以设置作为16路独立DI高速计数器 ● 可网…...

【Python】函数

None类型 思考&#xff1a;若函数没有使用return语句返回数据&#xff0c;那么函数有返回值吗&#xff1f; 答&#xff1a;实际上是有的&#xff0c;Python中有一个特殊的字面量None&#xff0c;其类型是<class ‘NoneType’>&#xff0c;无返回值的函数&#xff0c;实…...

centos安装MySQL 解压版完整教程(按步骤傻瓜式安装

一、卸载系统自带的 Mariadb 查看&#xff1a; rpm -qa|grep mariadb 卸载&#xff1a; rpm -e --nodeps mariadb-libs-5.5.68-1.el7.x86_64 二、卸载 etc 目录下的 my.cnf 文件 rm -rf /etc/my.cnf 三、检查MySQL是否存在 有则先删除 #卸载mysql服务以及删除所有mysql目录 #没…...

【后端速成 Vue】第一个 Vue 程序

1、为什么要学习 Vue&#xff1f; 为什么使用 Vue? 回想之前&#xff0c;前后端交互的时候&#xff0c;前端收到后端响应的数据&#xff0c;接着将数据渲染到页面上&#xff0c;之前使用的是 JavaScript 或者 基于 JavaScript 的 Jquery&#xff0c;但是这两个用起来还是不太…...

Macbook pro M1 安装Ubuntu教程

先讲下心路历程 由于版主最近刚切换到Mac&#xff0c;所以在安装的时候一上手就选择了virutalbox&#xff0c;结果报错“The installer has detected an unsupported architecture. VirtualBox only runs on the amd64 architecture.” 后来去Reddit论坛上一看&#xff0c;才知…...

前端console.log打印内容与后端请求返回数据不一致

后端传值num0 前端打印num1 ,如图&#xff0c;console.log后台显示的数据与展开后不一致 造成该问题原因是深拷贝与浅拷贝的问题。 var obj JSON.parse(JSON.stringify(res)) 修改后打印 正常...

SQL入门:多表查询

SQL&#xff0c;或者说结构化查询语言(Structured Query Language)&#xff0c;是用于管理和操作关系型数据库的标准语言。在本篇文章中&#xff0c;我们将重点介绍SQL中的多表查询&#xff0c;这是一种强大的工具&#xff0c;可以帮助我们从多个相关的表格中获取数据。 数据库…...

【C++】进一步认识模板

&#x1f3d6;️作者&#xff1a;malloc不出对象 ⛺专栏&#xff1a;C的学习之路 &#x1f466;个人简介&#xff1a;一名双非本科院校大二在读的科班编程菜鸟&#xff0c;努力编程只为赶上各位大佬的步伐&#x1f648;&#x1f648; 目录 前言一、非类型模板参数二、模板的特…...

Mysql Oracle 区别

1. oracle select *, id需要在星号前加别名&#xff0c;mysql则不需要 mysql语法&#xff1a; select *, id from xin_student_t;oracle语法&#xff1a; select st.*, st.id from xin_student_t st;2. oracle表定义了别名&#xff0c;在查询时可以不用别名指定字段&#xf…...

华为OD-第K长的连续字母字符串长度

题目描述 给定一个字符串&#xff0c;只包含大写字母&#xff0c;求在包含同一字母的子串中&#xff0c;长度第 k 长的子串的长度&#xff0c;相同字母只取最长的那个子串。 代码实现 # coding:utf-8 # 第K长的连续字母字符串长度 # https://www.nowcoder.com/discuss/353150…...

【编程题】有效三角形的个数

文章目录 一、题目二、算法讲解三、题目链接四、补充 一、题目 给定一个包含非负整数的数组 nums &#xff0c;返回其中可以组成三角形三条边的三元组个数。 示例1&#xff1a; 输入: nums [2,2,3,4] 输出: 3 **解释:**有效的组合是: 2,3,4 (使用第一个 2) 2,3,4 (使用第二个 …...

【mysql是怎样运行的】-EXPLAIN详解

文章目录 1.基本语法2. EXPLAIN各列作用1. table2. id3. select_type4. partitions5. type 1.基本语法 EXPLAIN SELECT select_options #或者 DESCRIBE SELECT select_optionsEXPLAIN 语句输出的各个列的作用如下&#xff1a; 列名描述id在一个大的查询语句中每个SELECT关键…...

数据结构例题代码及其讲解-链表

链表 单链表的结构体定义及其初始化。 typedef struct LNode {int data;struct LNode* next; }LNode, *LinkList;①强调结点 LNode *p; ②强调链表 LinkList p; //初始化 LNode* initList() {//定义头结点LNode* L (LNode*)malloc(sizeof(LNode));L->next NULL;return …...

[Open-source tool] 可搭配PHP和SQL的表單開源工具_Form tools(1):簡介和建置

Form tools是一套可搭配PHP和SQL的表單開源工具&#xff0c;可讓開發者靈活運用&#xff0c;同時其有數個表單模板和應用模組供挑選&#xff0c;方便且彈性。Form tools已開發超過20年&#xff0c;為不同領域的需求者或開發者提供一個自由和開放的平台&#xff0c;使他們可建構…...

移动数据业务价值链的整合

3G 时代移动数据业务开发体系的建立和发展&#xff0c;要求运营商从封闭、统一的业 务形态、单一提供业务&#xff0c;向开放的、个性化多元化的业务体系以及多方合作参与提 供业务的方向发展&#xff0c;不可避免的使通信价值链不断延长和升级&#xff0c;内容提供商、服务 …...

合并两个链表

题目描述 将两个升序链表合并为一个新的 升序 链表并返回。新链表是通过拼接给定的两个链表的所有节点组成的。 比如以下例子&#xff1a; 题目接口&#xff1a; /*** Definition for singly-linked list.* struct ListNode {* int val;* ListNode *next;* ListN…...

测试框架pytest教程(9)跳过测试skip和xfail

skip无条件跳过 使用装饰器 pytest.mark.skip(reason"no way of currently testing this") def test_example(faker):print("nihao")print(faker.words()) 方法内部调用 满足条件时跳过 def test_example():a1if a>0:pytest.skip("unsupported …...

HTML <textarea> 标签

实例 <textarea rows="3" cols="20"> 收拾收拾 </textarea>定义和用法 <textarea> 标签定义多行的文本输入控件。 文本区中可容纳无限数量的文本,其中的文本的默认字体是等宽字体(通常是 Courier)。 可以通过 cols 和 rows 属性来…...

探索图结构:从基础到算法应用

文章目录 理解图的基本概念学习图的遍历算法学习最短路径算法案例分析&#xff1a;使用 Dijkstra 算法找出最短路径结论 &#x1f389;欢迎来到数据结构学习专栏~探索图结构&#xff1a;从基础到算法应用 ☆* o(≧▽≦)o *☆嗨~我是IT陈寒&#x1f379;✨博客主页&#xff1a;I…...

Redis之GEO类型解读

目录 基本介绍 基本命令 geoadd 命令 geopos 命令 geodist 命令 georadius 命令 georadiusbymember 命令 geohash 命令 基本介绍 GEO 主要用于存储地理位置信息&#xff08;纬度、经度、名称&#xff09;添加到指定的key中。该功能在 Redis 3.2 版本新增。 GEO&…...

wordpress字号修改/seo点击

本文是对《【硬刚大数据之学习路线篇】从零到大数据专家的学习指南(全面升级版)》的ES部分补充。...

广州网站建设设计厂家/百度竞价托管

1.首先查看服务器是否支持伪静态。。查看方法&#xff1a;$arrapache_get_modules();$tempfalse;for($i0;$i{i f($arr[$i]"mod_rewrite"){$temptrue;}}如果temptrue则支持(可以使是否支持在网页上显示)&#xff0c;如果用wamp集成环境就可以直接查看Apache modules 里…...

微信企业号网站开发软件/河南网站关键词优化代理

2019独角兽企业重金招聘Python工程师标准>>> 这篇主要看的是mybatis对于数据库连接池的实现。实现类为org.apache.ibatis.datasource.pooled包下的PooledDataSource类。 数据库连接池的代码并不难&#xff0c;关键在于理解他的连接和释放的策略。 连接池参数 protec…...

app 网站可以做的免费推广/网址收录网站

本文主要介绍九江职业技术学院2020有哪些专业及什么专业好的相关信息&#xff0c;对学校感兴趣&#xff0c;想要报考该校的同学请信息的阅读文章&#xff0c;若有其他有关该校的招生方面的信息可以直接咨询网站的在线老师&#xff0c;向他们进行咨询.一、九江职业技术学院专业大…...

自己的网站建设/百度重庆营销中心

linux使用于广泛的体系结构&#xff0c;因此需要用一种与体系结构无关的方式来描述内存。linux用VM描述和管理内存。在VM中兽药的普遍概念就是非一致内存访问。对于大型机器而言&#xff0c;内存会分成许多簇&#xff0c;依据簇与处理器“距离”的不同&#xff0c;访问不同的簇…...

哈尔滨关键词优化推广/如何做谷歌seo推广

SRILM是一个统计和分析语言模型的工具&#xff0c;提供一些命令行工具&#xff0c;如ngram,ngram-count&#xff0c;可以很方便的统计NGRAM的语言模型。 1&#xff0c;下载 我开始在这个站上下载&#xff0c;感觉很慢。 http://www.speech.sri.com/projects/srilm/download.h…...