代驾网站开发/新型营销方式
在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。
1.Transformer中decoder的流程
在论文《Attention is all you need》中,关于encoder及self attention有较为详细的论述,这也是网上很多教程在谈及transformer时候会重点讨论的部分。但是关于transformer的decoder部分,他的结构上与encoder实际非常像,但其中有一些巧妙的设计。本文会详细谈谈。首先给出一个完整transformer的结构图:
上图左侧为encoder部分,右侧为decoder部分。对于decoder部分,将enc_input经过multi head attention后得到的张量,以K,V送入decoder中。而decoder阶段的masked multi head attention需要解决如何将dec_input编码成Q。最终输出的logits实际是与Q的维度一致。对于Scaled Dot-Product Attention,其公式如下:
A t t e n t i o n ( Q , K , V ) = s o f t m a x ( Q K T d k ) V Attention(Q, K, V) = softmax\left(\frac{QK^T}{\sqrt{d_k}}\right)V Attention(Q,K,V)=softmax(dkQKT)V
在《Transformer的PyTorch实现之若干问题探讨(一)》中,decoder阶段,Q的维度为[2,8,6,64](2为batch size,8为head数,6为句子长度,64为向量长度),K的维度为[2,8,5,64],V的维度为[2,8,5,64]。其中, Q K T QK^T QKT的维度为[2,8,6,5] 的,可以理解每个查询张量Q对每个键值张K的注意力权重。之后乘以V,维度为[2,8,6,64]。可以看到最终的维度是根据查询张量Q来加权值向量V。Q就是dec_input经过masked multi head attention得来。那么,dec_input中实际是包含了所有的标签的。那么dec_input是如何mask掉不需要的token的呢?
2.Decoder中的self attention mask
class Decoder(nn.Module):def __init__(self):super(Decoder, self).__init__()self.tgt_emb = nn.Embedding(tgt_vocab_size, d_model)self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([DecoderLayer() for _ in range(n_layers)])def forward(self, dec_inputs, enc_inputs, enc_outputs):'''这三个参数对应的不是Q、K、V,dec_inputs是Q,enc_outputs是K和V,enc_inputs是用来计算padding mask的dec_inputs: [batch_size, tgt_len]enc_inpus: [batch_size, src_len]enc_outputs: [batch_size, src_len, d_model]'''dec_outputs = self.tgt_emb(dec_inputs)#词序号编码成向量dec_outputs = self.pos_emb(dec_outputs).cuda()#位置编码dec_self_attn_pad_mask = get_attn_pad_mask(dec_inputs, dec_inputs).cuda() #[2, 6, 6]dec_self_attn_subsequence_mask = get_attn_subsequence_mask(dec_inputs).cuda() #[2, 6, 6],上三角矩阵# 将两个mask叠加,布尔值可以视为0和1,和大于0的位置是需要被mask掉的,赋为True,和为0的位置是有意义的为Falsedec_self_attn_mask = torch.gt((dec_self_attn_pad_mask +dec_self_attn_subsequence_mask), 0).cuda()# 这是co-attention部分,为啥传入的是enc_inputs而不是enc_outputs:enc_outputs是向量,这儿是需要通过词编码来判断是否需要mask掉dec_enc_attn_mask = get_attn_pad_mask(dec_inputs, enc_inputs) #[2, 6, 5]for layer in self.layers:dec_outputs = layer(dec_outputs, enc_outputs, dec_self_attn_mask, dec_enc_attn_mask)return dec_outputs # dec_outputs: [batch_size, tgt_len, d_model]
上述代码为Decoder部分。可以看到有两个mask:dec_self_attn_pad_mask(用于将dec_inputs中的P mask掉)与dec_self_attn_subsequence_mask(用于实现decoder的self attention)。这两个mask在后面会相加合并。这儿可以分别展示二者的值,其中:
dec_self_attn_pad_mask:
tensor([[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]],[[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False],[False, False, False, False, False, False]]], device='cuda:0')#[2, 6, 6]
dec_self_attn_subsequence_mask:
tensor([[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]],[[0, 1, 1, 1, 1, 1],[0, 0, 1, 1, 1, 1],[0, 0, 0, 1, 1, 1],[0, 0, 0, 0, 1, 1],[0, 0, 0, 0, 0, 1],[0, 0, 0, 0, 0, 0]]], device='cuda:0', dtype=torch.uint8)#[2, 6, 6]
可以看到,dec_self_attn_pad_mask全为false,这是因为dec_input中不包含P,而dec_self_attn_subsequence_mask为上三角矩阵,对于每个token,需要mask掉它之后的token(本代码中,为1或True的位置会被mask掉)。接下来进一步追问,为什么上三角矩阵就可以mask掉该token之后的token?具体是如何实现的呢?
对于前文的Scaled Dot-Product Attention公式,代码中的表述实际为:
def forward(self, Q, K, V, attn_mask):'''Q: [batch_size, n_heads, len_q, d_k]K: [batch_size, n_heads, len_k, d_k]V: [batch_size, n_heads, len_v(=len_k), d_v] 全文两处用到注意力,一处是self attention,另一处是co attention,前者不必说,后者的k和v都是encoder的输出,所以k和v的形状总是相同的attn_mask: [batch_size, n_heads, seq_len, seq_len]'''# 1) 计算注意力分数QK^T/sqrt(d_k)scores = torch.matmul(Q, K.transpose(-1, -2)) / np.sqrt(d_k) # scores: [batch_size, n_heads, len_q, len_k]# 2) 进行 mask 和 softmax# mask为True的位置会被设为-1e9scores.masked_fill_(attn_mask, -1e9) # 把True设为-1e9attn = nn.Softmax(dim=-1)(scores) # attn: [batch_size, n_heads, len_q, len_k]# 3) 乘V得到最终的加权和context = torch.matmul(attn, V) # context: [batch_size, n_heads, len_q, d_v], [2, 8, 5, 64]'''得出的context是每个维度(d_1-d_v)都考虑了在当前维度(这一列)当前token对所有token的注意力后更新的新的值,换言之每个维度d是相互独立的,每个维度考虑自己的所有token的注意力,所以可以理解成1列扩展到多列返回的context: [batch_size, n_heads, len_q, d_v]本质上还是batch_size个句子,只不过每个句子中词向量维度512被分成了8个部分,分别由8个头各自看一部分,每个头算的是整个句子(一列)的512/8=64个维度,最后按列拼接起来'''return context # context: [batch_size, n_heads, len_q, d_v]
其中,Q,K,V的维度都是[2, 8, 6, 64], score的维度为[2, 8, 6, 6],即每个token之间的注意力分数。这儿取出一个batch中的一个head下的注意力分数a为例,a的维度为[6, 6],如图所示:
如上图所示,在得分score中,标黄的0.71和0.24分别是S与S,以及S与I的词向量相乘得到。由于I在S后面,所以需要通过mask将其置为负无穷大,而0.71需要保留,因为是S与S在同一个位置上。因此这个mask矩阵为上三角矩阵。
相关文章:

Transformer的PyTorch实现之若干问题探讨(二)
在《Transformer的PyTorch实现之若干问题探讨(一)》中探讨了Transformer的训练整体流程,本文进一步探讨Transformer训练过程中teacher forcing的实现原理。 1.Transformer中decoder的流程 在论文《Attention is all you need》中࿰…...

解释Python中的GIL(全局解释器锁)及其影响。描述Python中的垃圾回收机制。Python中的类变量和实例变量有什么区别
解释Python中的GIL(全局解释器锁)及其影响 Python中的GIL(全局解释器锁)是CPython解释器中的一个机制,用于同步线程的执行。GIL确保任何时候只有一个线程在执行Python字节码。这意味着,即使在多核或多处理器…...

Appium使用初体验之参数配置,简单能够运行起来
一、服务器配置 Appium Server配置与Appium Server GUI(可视化客户端)中的配置对应,尤其是二者如果不在同一台机器上,那么就需要配置Appium Server GUI所在机器的IP(Appium Server GUI的HOST也需要配置本机IP…...

Java:JDK8新特性(Stream流)、File类、递归 --黑马笔记
一、JDK8新特性(Stream流) 接下来我们学习一个全新的知识,叫做Stream流(也叫Stream API)。它是从JDK8以后才有的一个新特性,是专业用于对集合或者数组进行便捷操作的。有多方便呢?我们用一个案…...

【Unity ShaderGraph】| 物体靠近时局部溶解,根据坐标控制溶解的位置【文末送书】
前言 【Unity ShaderGraph】| 物体靠近时局部溶解,根据坐标控制溶解的位置一、效果展示二、根据坐标控制溶解的位置,物体靠近局部溶解三、应用实例👑评论区抽奖送书 前言 本文将使用ShaderGraph制作一个根据坐标控制溶解的位置,物…...

测试OpenSIPS3.4.3的lua模块
这几天测试OpenSIPS3.4.3的lua模块,记录如下: 有bug,但能用 但现实世界就是这样,总是不完美的,发现之后马上提了issue 下面这段代码运行报错: function func1(msg) xlog("ERR","…...

【机器学习】数据清洗之处理缺失点
🎈个人主页:甜美的江 🎉欢迎 👍点赞✍评论⭐收藏 🤗收录专栏:机器学习 🤝希望本文对您有所裨益,如有不足之处,欢迎在评论区提出指正,让我们共同学习、交流进步…...

Linux 命令行的世界 :2.文件系统中跳转
我们需要学习的第一件事(除了打字之外)是如何在 Linux 文件系统中跳转。在这一章节中,我们将介绍以下命令:pwd 打印出当前工作目录名 cd 更改目录 ls 列出目录内容 Linux以分层目录结构来组织所有文件。这就意味着所有文件…...

R语言:箱线图绘制(添加平均值趋势线)
箱线图绘制 1. 写在前面2.箱线图绘制2.1 相关R包导入2.2 数据导入及格式转换2.3 ggplot绘图 1. 写在前面 今天有时间把之前使用过的一些代码和大家分享,其中箱线图绘制我认为是非常有用的一个部分。之前我是比较喜欢使用origin进行绘图,但是绘制的图不太…...

Open3D 模型切片
目录 一、算法原理1、算法过程2、主要函数二、代码实现三、结果展示1、原始数据2、切片结果本文由CSDN点云侠原创,原文链接。如果你不是在点云侠的博客中看到该文章,那么此处便是不要脸的爬虫与GPT。 一、算法原理...

KtConnect 本地连接连接K8S工具
KT Connect简介 Kt Connect (Kubernetes Developer Tool)是一个阿里开源、轻量级的面向 Kubernetes 用户的开发测试环境治理辅助工具。其核心是通过建立本地到集群以及集群到本地的双向通道。 1.阿里开源,轻量级, 2. 安装快捷简单…...

【Java万花筒】数据的安全钥匙:Java的加密与保护方法
编码的盾牌:Java开发人员的安全性武器库 前言 在当今数字化时代,保护用户数据和信息的安全已成为开发人员的首要任务。无论是在Web应用程序开发还是安全测试中,加密和安全性都是至关重要的。本文将介绍六个Java库和工具,它们为开…...

【Java多线程案例】实现阻塞队列
1. 阻塞队列简介 1.1 阻塞队列概念 阻塞队列:是一种特殊的队列,具有队列"先进先出"的特性,同时相较于普通队列,阻塞队列是线程安全的,并且带有阻塞功能,表现形式如下: 当队列满时&…...

【制作100个unity游戏之24】unity制作一个3D动物AI生态系统游戏3(附项目源码)
最终效果 文章目录 最终效果系列目录前言随着地面法线旋转在地形上随机生成动物不同部位颜色不同最终效果源码完结系列目录 前言 欢迎来到【制作100个Unity游戏】系列!本系列将引导您一步步学习如何使用Unity开发各种类型的游戏。在这第24篇中,我们将探索如何用unity制作一…...

home work day5
第四章 堆与拷贝构造函数 一 、程序阅读题 1、给出下面程序输出结果。 #include <iostream.h> class example {int a; public: example(int b5){ab;} void print(){aa1;cout <<a<<"";} void print()const {cout<<a<<endl;} …...

c#安全-nativeAOT
文章目录 前记AOT测试反序列化Emit 前记 JIT\AOT JIT编译器(Just-in-Time Complier),AOT编译器(Ahead-of-Time Complier)。 AOT测试 首先编译一段普通代码 using System; using System.Runtime.InteropServices; namespace co…...

【Java】案例:检测MySQL是否存在某数据库,没有则创建
1.代码 package hello; import java.sql.*;public class CeShi {//定义基本数据static final String JDBC_DRIVER "com.mysql.cj.jdbc.Driver";static final String DB_URL "jdbc:mysql://localhost/";static final String USER "your_username&q…...

内网渗透靶场02----Weblogic反序列化+域渗透
网络拓扑: 攻击机: Kali: 192.168.111.129 Win10: 192.168.111.128 靶场基本配置:web服务器双网卡机器: 192.168.111.80(模拟外网)10.10.10.80(模拟内网)域成员机器 WIN7PC192.168.…...

[嵌入式系统-9]:C语言程序调用汇编语言程序的三种方式
目录 1. 使用函数声明和函数调用: 2. 使用汇编内联(Inline Assembly): 3. 使用汇编代码文件和链接器: C语言程序可以调用汇编程序的方式有多种,下面列举了几种常见的方式: 1. 使用函数声明和…...

备战蓝桥杯---搜索(完结篇)
再看一道不完全是搜索的题: 解法1:贪心并查集: 把冲突事件从大到小排,判断是否两个在同一集合,在的话就返回,不在的话就合并。 下面是AC代码: #include<bits/stdc.h> using namespace …...

深入浅出:Golang的Crypto/SHA256库实战指南
深入浅出:Golang的Crypto/SHA256库实战指南 介绍crypto/sha256库概览主要功能应用场景库结构和接口实例 基础使用教程字符串哈希化文件哈希化处理大型数据 进阶使用方法增量哈希计算使用Salt增强安全性多线程哈希计算 实际案例分析案例一:安全用户认证系…...

Unity_ShaderGraph节点问题
Unity_ShaderGraph节点问题 Unity版本:Unity2023.1.19 为什么在Unity2023.1.19的Shader Graph中找不见PBR Master节点? 以下这个PBR Maste从何而来?...

Java集合 Collection接口
这里写目录标题 集合Collection接口创建一个性表增加元素删除元素修改元素判断元素遍历集合实例判断元素是否存在 集合 Java中的Collection接口是集合类的一个顶级接口,它定义了一些基本的操作,如添加、删除、查找等。Collection接口主要有以下几个常用…...

C# Task的使用
C#中的Task类是.NET框架中用于实现异步编程的核心组件之一,它在.NET Framework 4及更高版本以及.NET Core中广泛使用。Task对象代表一个异步操作,并提供了跟踪异步操作状态、获取结果和处理完成通知的方法。 Task 类提供了对异步操作的封装,…...

尚硅谷Ajax笔记
一天拿下 介绍二级目录三级目录 b站链接 介绍 ajax优缺点 http node.js下载配置好环境 express框架 切换到项目文件夹,执行下面两条命令 有报错,退出用管理员身份打开 或者再命令提示符用管理员身份打开 npm init --yes npm i express请求 <script>//引…...

【MATLAB源码-第138期】基于matlab的D2D蜂窝通信仿真,对比启发式算法,最优化算法和随机算法的性能。
操作环境: MATLAB 2022a 1、算法描述 D2D蜂窝通信介绍 D2D蜂窝通信允许在同一蜂窝网络覆盖区域内的终端设备直接相互通信,而无需数据经过基站或网络核心部分转发。这种通信模式具有几个显著优点:首先,它可以显著降低通信延迟&…...

AcWing 第 142 场周赛 B.最有价值字符串(AcWing 5468) (Java)
AcWing 第 142 场周赛 B.最有价值字符串(AcWing 5468) (Java) 比赛链接:AcWing 第 142 场周赛 x题传送门:B.最有价值字符串 题目:不展示 分析: 题目不难,不过有坑😭。 我们可以定义一个数组记录每个字…...

滑块识别验证
滑块识别 1. 获取图片 测试网站:https://www.geetest.com/adaptive-captcha-demo 2. 点击滑块拼图并开始验证 # 1.打开首页 driver.get(https://www.geetest.com/adaptive-captcha-demo)# 2.点击【滑动拼图验证】 tag WebDriverWait(driver, 30, 0.5).until(la…...

每日五道java面试题之java基础篇(四)
第一题. 访问修饰符 public、private、protected、以及不写(默认)时的区别? Java 中,可以使⽤访问控制符来保护对类、变量、⽅法和构造⽅法的访问。Java ⽀持 4 种不同的访问权限。 default (即默认,什么也不写&…...

我的docker随笔43:问答平台answer部署
本文介绍开源问答社区平台Answer的容器化部署。 起因 笔者一直想搭建一个类似stack overflower这样的平台,自使用了Typora,就正式全面用MarkdownTyporagit来积累自己的个人知识库,但没有做到web化,现在也还在探索更好的方法。 无…...