机器学习深度学习——NLP实战(自然语言推断——注意力机制实现)
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er
🌌上期文章:机器学习&&深度学习——NLP实战(自然语言推断——数据集)
📚订阅专栏:机器学习&&深度学习
希望文章对你们有所帮助
NLP实战(自然语言推断——注意力机制实现)
- 引入
- 模型
- 注意(Attending)
- 比较
- 聚合
- 整合代码
- 训练和评估模型
- 读取数据集
- 创建模型
- 训练和评估模型
- 使用模型
- 小结
引入
在之前已经介绍了什么是自然语言推断,并且下载并处理了SNLI数据集。由于许多模型都是基于复杂而深度的架构,因此提出用注意力机制解决自然语言推断问题,并且称之为“可分解注意力模型”。这使得模型没有循环层或卷积层,在SNLI数据集上以更少的参数实现了当时的最佳结果。下面就实现这种基于注意力的自然语言推断方法(使用MLP),如下图所述:
这里的任务就是要将预训练GloVe送到注意力和MLP的自然语言推断架构。
模型
与保留前提和假设中词元的顺序,我们可以将一个文本序列中的词元与另一个文本序列中的每个词元对齐,然后比较和聚合这些信息,以预测前提和假设之间的逻辑关系。这和机器翻译中源句和目标句之间的词元对齐类似,前提和假设之间的词元对齐可以通过注意力机制来灵活完成。如下所示就是使用注意力机制来实现自然语言推断的模型图:
上面的i和i相对,前提中的sleep会对应tired,假设中的tired对应的是need sleep。
从高层次讲,它由三个联合训练的步骤组成:对齐、比较和汇总,下面会通过代码来解释和实现。
import torch
from torch import nn
from torch.nn import functional as F
from d2l import torch as d2l
注意(Attending)
第一步是将一个文本序列中的词元与另一个序列中的每个词元对齐。假设前提是“我需要睡眠”,假设是“我累了”。由于语义上的相似性,我们不妨将假设中的“我”与前提中的“我”对齐,将假设中的“累”与前提中的“睡眠”对齐。同样,我们可能希望将前提中的“我”与假设中的“我”对齐,将前提中“需要睡眠”与假设中的“累”对齐。
注意,这种对齐是使用的加权平均的“软”对齐,其中理想情况下较大的权重与要对齐的词元相关联。为了便于演示,上图是用了“硬”对齐的方式来展示。
现在,我们要详细描述使用注意力机制的软对齐。
用
A = ( a 1 , . . . , a m ) 和 B = ( b 1 , . . . , b n ) A=(a_1,...,a_m)和B=(b_1,...,b_n) A=(a1,...,am)和B=(b1,...,bn)
分别表示前提和假设,其词元数量分别为m和n,其中:
a 1 , b j ∈ R d 是 d 维的词向量 a_1,b_j∈R^d是d维的词向量 a1,bj∈Rd是d维的词向量
关于软对齐,我们将注意力权重计算为:
e i j = f ( a i ) T f ( b j ) e_{ij}=f(a_i)^Tf(b_j) eij=f(ai)Tf(bj)
其中函数f是在下面的mlp函数中定义的多层感知机。输出维度f由mlp的num_hiddens参数指定。
def mlp(num_inputs, num_hiddens, flatten):net = []net.append(nn.Dropout(0.2))net.append(nn.Linear(num_inputs, num_hiddens))net.append(nn.ReLU())if flatten:net.append(nn.Flatten(start_dim=1))net.append(nn.Dropout(0.2))net.append(nn.Linear(num_hiddens, num_hiddens))net.append(nn.ReLU())if flatten:net.append(nn.Flatten(start_dim=1))return nn.Sequential(*net)
值得注意的是,上式中,f分别输入ai和bi,而不是把它们一对放在一起作为输入。这种分解技巧导致f只有m+n次计算(线性复杂度),而不是mn次计算(二次复杂度)。
对上式中的注意力权重进行规范化,我们计算假设中所有词元向量的加权平均值,以获得假设的表示,该假设与前提中索引i的词元进行软对齐:
β i = ∑ j = 1 n e x p ( e i j ) ∑ k = 1 n e x p ( e i k ) b j β_i=\sum_{j=1}^n\frac{exp(e_{ij})}{\sum_{k=1}^nexp(e_{ik})}b_j βi=j=1∑n∑k=1nexp(eik)exp(eij)bj
同理,我们计算假设中索引为j的每个词元与前提词元的软对齐:
α j = ∑ i = 1 m e x p ( e i j ) ∑ k = 1 m e x p ( e k j ) a i α_j=\sum_{i=1}^m\frac{exp(e_{ij})}{\sum_{k=1}^mexp(e_{kj})}a_i αj=i=1∑m∑k=1mexp(ekj)exp(eij)ai
下面,我们定义Attend类来计算假设(beta)与输入前提A的软对齐以及前提(alpha)与输入假设B的软对齐。
class Attend(nn.Module):def __init__(self, num_inputs, num_hiddens, **kwargs):super(Attend, self).__init__(**kwargs)self.f = mlp(num_inputs, num_hiddens, flatten=False)def forward(self, A, B):# A/B的形状:(批量大小,序列A/B的词元数,embed_size)# f_A/f_B的形状:(批量大小,序列A/B的词元数,num_hiddens)f_A = self.f(A)f_B = self.f(B)# e的形状:(批量大小,序列A的词元数,序列B的词元数)e = torch.bmm(f_A, f_B.permute(0, 2, 1))# beta的形状:(批量大小,序列A的词元数,embed_size),# 意味着序列B被软对齐到序列A的每个词元(beta的第1个维度)beta = torch.bmm(F.softmax(e, dim=-1), B)# beta的形状:(批量大小,序列B的词元数,embed_size),# 意味着序列A被软对齐到序列B的每个词元(alpha的第1个维度)alpha = torch.bmm(F.softmax(e.permute(0, 2, 1), dim=-1), A)return beta, alpha
比较
在下一步中,我们将一个序列中的词元与和该词元软对齐的另一个序列进行比较。注意,软对齐中,一个序列中的所有词元(尽管可能具有不同的注意力权重)将与另一个序列中的词元进行比较。
在比较步骤中,我们将来自一个序列的词元的连结(运算符[·,·])和来自另一个序列的对其的词元送入函数g(一个多层感知机):
v A , i = g ( [ a i , β i ] ) , i = 1 , . . . , m v B , j = g ( [ b j , α j ] ) , j = 1 , . . . , n 其中, v A , i 指:所有假设中的词元与前提中词元 i 软对齐,再与词元 i 的比较; v B , j 指:所有前提中的词元与假设中词元 j 软对齐,再与词元 j 的比较。 v_{A,i}=g([a_i,β_i]),i=1,...,m\\ v_{B,j}=g([b_j,α_j]),j=1,...,n\\ 其中,v_{A,i}指:所有假设中的词元与前提中词元i软对齐,再与词元i的比较;\\ v_{B,j}指:所有前提中的词元与假设中词元j软对齐,再与词元j的比较。 vA,i=g([ai,βi]),i=1,...,mvB,j=g([bj,αj]),j=1,...,n其中,vA,i指:所有假设中的词元与前提中词元i软对齐,再与词元i的比较;vB,j指:所有前提中的词元与假设中词元j软对齐,再与词元j的比较。
下面的Compare类定义了比较的步骤:
class Compare(nn.Module):def __init__(self, num_inputs, num_hiddens, **kwargs):super(Compare, self).__init__(**kwargs)self.g = mlp(num_inputs, num_hiddens, flatten=False)def forward(self, A, B, beta, alpha):V_A = self.g(torch.cat([A, beta], dim=2))V_B = self.g(torch.cat([B, alpha], dim=2))return V_A, V_B
聚合
现在我们有两组比较向量:
v A , i 和 v B , j v_{A,i}和v_{B,j} vA,i和vB,j
在最后一步中,我们将聚合这些信息以推断逻辑关系。我们首先求和这两组比较向量:
v A = ∑ i = 1 m v A , i , v B = ∑ j = 1 n v B , j v_A=\sum_{i=1}^mv_{A,i},v_B=\sum_{j=1}^nv_{B,j} vA=i=1∑mvA,i,vB=j=1∑nvB,j
接下来,我们将两个求和结果的连结提供给函数h(一个多层感知机),以获得逻辑关系的分类结果:
y ^ = h ( [ v A , v B ] ) \hat{y}=h([v_A,v_B]) y^=h([vA,vB])
聚合步骤在以下Aggregate类中定义。
class Aggregate(nn.Module):def __init__(self, num_inputs, num_hiddens, num_outputs, **kwargs):super(Aggregate, self).__init__(**kwargs)self.h = mlp(num_inputs, num_hiddens, flatten=True)self.linear = nn.Linear(num_hiddens, num_outputs)def forward(self, V_A, V_B):# 对两组比较向量分别求和V_A = V_A.sum(dim=1)V_B = V_B.sum(dim=1)# 将两个求和结果的连结送到多层感知机中Y_hat = self.linear(self.h(torch.cat([V_A, V_B], dim=1)))return Y_hat
整合代码
通过将注意步骤、比较步骤和聚合步骤组合在一起,我们定义了可分解注意力模型来联合训练这三个步骤:
class DecomposableAttention(nn.Module):def __init__(self, vocab, embed_size, num_hiddens, num_inputs_attend=100,num_inputs_compare=200, num_inputs_agg=400, **kwargs):super(DecomposableAttention, self).__init__(**kwargs)self.embedding = nn.Embedding(len(vocab), embed_size)self.attend = Attend(num_inputs_attend, num_hiddens)self.compare = Compare(num_inputs_compare, num_hiddens)# 有3种可能的输出:蕴涵、矛盾和中性self.aggregate = Aggregate(num_inputs_agg, num_hiddens, num_outputs=3)def forward(self, X):premises, hypotheses = XA = self.embedding(premises)B = self.embedding(hypotheses)beta, alpha = self.attend(A, B)V_A, V_B = self.compare(A, B, beta, alpha)Y_hat = self.aggregate(V_A, V_B)return Y_hat
训练和评估模型
现在,我们将在SNLI数据集上对定义好的可分解注意力模型进行训练和评估。我们从读取数据集开始。
读取数据集
我们使用上节定义的函数下载并读取SNLI数据集,批量大小和序列长度分别设为256和50:
batch_size, num_steps = 256, 50
train_iter, test_iter, vocab = d2l.load_data_snli(batch_size, num_steps)
创建模型
我们将预训练好的100维GloVe嵌入来表示输入词元。我们将向量ai和bj的维数定义为100。f和g的输出维度被设置为200。然后我们创建一个模型实例,初始化参数,并加载GloVe嵌入来初始化输入词元的向量。
embed_size, num_hiddens, devices = 100, 200, d2l.try_all_gpus()
net = DecomposableAttention(vocab, embed_size, num_hiddens)
glove_embedding = d2l.TokenEmbedding('glove.6b.100d')
embeds = glove_embedding[vocab.idx_to_token]
net.embedding.weight.data.copy_(embeds)
训练和评估模型
现在我们可以在SNLI数据集上训练和评估模型。
lr, num_epochs = 0.001, 4
trainer = torch.optim.Adam(net.parameters(), lr=lr)
loss = nn.CrossEntropyLoss(reduction="none")
d2l.train_ch13(net, train_iter, test_iter, loss, trainer, num_epochs,devices)
d2l.plt.show()
运行结果:
loss 0.495, train acc 0.805, test acc 0.826
443.5 examples/sec on [device(type=‘cpu’)]
运行图片:
使用模型
定义预测函数,输出一对前提和假设之间的逻辑关系。
#@save
def predict_snli(net, vocab, premise, hypothesis):"""预测前提和假设之间的逻辑关系"""net.eval()premise = torch.tensor(vocab[premise], device=d2l.try_gpu())hypothesis = torch.tensor(vocab[hypothesis], device=d2l.try_gpu())label = torch.argmax(net([premise.reshape((1, -1)),hypothesis.reshape((1, -1))]), dim=1)return 'entailment' if label == 0 else 'contradiction' if label == 1 \else 'neutral'
我们可以使用训练好的模型来获得对实例句子的自然语言推断结果:
print(predict_snli(net, vocab, ['he', 'is', 'good', '.'], ['he', 'is', 'bad', '.']))
预测结果:
‘contradiction’
小结
1、可分解注意模型包括三个步骤来预测前提和假设之间的逻辑关系:注意、比较和聚合。
2、通过注意力机制,我们可以将一个文本序列中的词元与另一个文本序列中的每个词元对齐,反之亦然。这种对齐是使用加权平均的软对齐,其中理想情况下,较大的权重与要对齐的词元相关联。
3、在计算注意力权重时,分解技巧会带来比二次复杂度更理想的线性复杂度。
4、我们可以使用预训练好的词向量作为下游自然语言处理任务的输入表示。
相关文章:
机器学习深度学习——NLP实战(自然语言推断——注意力机制实现)
👨🎓作者简介:一位即将上大四,正专攻机器学习的保研er 🌌上期文章:机器学习&&深度学习——NLP实战(自然语言推断——数据集) 📚订阅专栏:机器学习&…...
mac垃圾清理软件有哪些
随着使用时间的增加,mac系统会产生一些垃圾文件,影响系统的性能和稳定性。为了保持mac系统的高效,用户需要定期使用mac垃圾清理软件来清理系统缓存、日志、语言包等无用文件。CleanMyMac是一款功能强大的mac垃圾清理软件,它可以帮…...
8.18 校招 内推 面经
绿泡泡: neituijunsir 交流裙,内推/实习/校招汇总表格 1、校招 | 小米集团2024届全球校园招聘正式启动(内推) 校招 | 小米集团2024届全球校园招聘正式启动(内推) 2、2023校招总结--软件测试岗位 - 2 2…...
docker的web管理平台docker.ui
docker.ui安装 docker run --name docker.ui \ -p 8999:8999 \ --restartalways \ -v /var/run/docker.sock:/var/run/docker.sock \ -d joinsunsoft/docker.ui参数说明: docker run:启动container–name:容器命名–restartalwaysÿ…...
20230822 Windows上使用find_package引入OpenCV报错
报错信息 打开Cmake项目时,find_package 报错: Found OpenCV Windows Pack but it has no binaries compatible with yourconfiguration.You should manually point CMake variable OpenCV_DIR to your build of OpenCVlibrary.原因 大概率原项目是在 …...
MySQL下载安装配置
天行健,君子以自强不息;地势坤,君子以厚德载物。 每个人都有惰性,但不断学习是好好生活的根本,共勉! 文章均为学习整理笔记,分享记录为主,如有错误请指正,共同学习进步。…...
3D WEB轻量化引擎HOOPS产品助力NAPA打造船舶设计软件平台
NAPA(Naval Architectural PAckage,船舶建筑包),来自芬兰的船舶设计软件供应商,致力于提供世界领先的船舶设计、安全及运营的解决方案和数据分析服务。NAPA拥有超过30年的船舶设计经验,年营业额超过2560万欧…...
lesson9: C++多线程
1.线程库 1.1 thread类的简单介绍 C11 中引入了对 线程的支持 了,使得 C 在 并行编程时 不需要依赖第三方库 而且在原子操作中还引入了 原子类 的概念。要使用标准库中的线程,必须包含 < thread > 头文件 函数名 功能 thread() 构造一个线程对象…...
安卓修改SwitchCompat色值
SwitchCompat控件色值跟系统设置的主题有关,但是主题效果不是能轻易就能改的,因为涉及到整个APP的样式。网上方案基本都是通过修改style文件来改变色值,经过多次尝试修改最终觉得单独修改控件色值比较好。 一、控件属性 //修改开关色值就是最…...
pytorch内存泄漏
问题描述: 内存泄漏积累过多最终会导致内存溢出,当内存占用过大,进程会被killed掉。 解决过程: 在代码的运行阶段输出内存占用量,观察在哪一块存在内存剧烈增加或者显存异常变化的情况。但是在这个过程中要分级确认…...
20230821-字符串相乘-给树命名(unordered_map)
字符串相乘 有两个非负整数字符串num1,num2,计算num1和num2所表达整数的乘积,结果以字符串形式存储。注意:不能通过强制转换方法解题。 示例1: 输入: "4", "3" 输出: "12" …...
[Go版]算法通关村第十二关黄金——字符串冲刺题
目录 题目:最长公共前缀解法1:纵向对比-循环内套循环写法复杂度:时间复杂度 O ( n ∗ m ) O(n*m) O(n∗m)、空间复杂度 O ( 1 ) O(1) O(1)Go代码 解法2:横向对比-两两对比(类似合并K个数组、合并K个链表)复…...
neovim为工作区添加本地clangd配置
1 背景 尝试使用neovim开发stm32,使用clangd作为LSP提供代码补全等功能。 2 思路 使用stm32cubeMX生成一个基于makefile的stm32工程。 使用bear或compiledb基于makefile生成compile_commands.json文件。 为clangd配置--query-driver选项,使其使用arm…...
信号处理--基于EEG脑电信号的眼睛状态的分析
本实验为生物信息学专题设计小项目。项目目的是通过提供的14导联EEG 脑电信号,实现对于人体睁眼和闭眼两个状态的数据分类分析。每个脑电信号的时长大约为117秒。 目录 加载相关的库函数 读取脑电信号数据并查看数据的属性 绘制脑电多通道连接矩阵 绘制两类数据…...
Redis高可用:主从复制详解
目录 1.什么是主从复制? 2.优势 3.主从复制的原理 4.全量复制和增量复制 4.1 全量复制 4.2 增量复制 5.相关问题总结 5.1 当主服务器不进行持久化时复制的安全性 5.2 为什么主从全量复制使用RDB而不使用AOF? 5.3 为什么还有无磁盘复制模式ÿ…...
[Flutter]有的时候调用setState(() {})报错?
先看FlutterSDK的原生类State中有一个变量mounted。 abstract class State<T extends StatefulWidget> with Diagnosticable {/// mounted的作用是,此State对象当前是否在树中。/// 在创建State对象之后,在调用initState之前,框架通过…...
利用屏幕水印学习英语单词,无打扰英语单词学习
1、利用屏幕水印学习英语单词,不影响任何鼠标键盘操作,不影响工作 2、利用系统热键快速隐藏(ALT1键 隐藏与显示) 3、日积月累单词会有进步 4、软件下载地址: 免安装,代码未加密,安全的屏幕水印学习英语…...
开学必备物品清单!这几款优先考虑!
马上就要开学了,同学们也要准备一系列开学用品,方便我们的学习生活,那有哪些数码物品可以在开学前准备的呢,接下来给大家安利几款很不错很实用的数码好物! 推荐一:南卡00压开放式蓝牙耳机 南卡00压开放式…...
聊聊调制解调器
目录 1.什么是调制解调器 2.调制解调器的工作原理 3.调制解调器的作用 4.调制解调器未来发展 1.什么是调制解调器 调制解调器(Modem)是一种用于在数字设备和模拟设备之间进行数据传输的设备。调制解调器将数字数据转换为模拟信号进行传输,…...
Go语言入门指南:基础语法和常用特性(下)
上一节,我们了解Go语言特性以及第一个Go语言程序——Hello World,这一节就让我们更深入的了解一下Go语言的**基础语法**吧! 一、行分隔符 在 Go 程序中,一行代表一个语句结束。每个语句不需要像 C 家族中的其它语言一样以分号 ;…...
【MFC常用问题记录】
MFC 记录 MFC的edit control控件显示1.控件添加变量M_edit后:2.控件ID为IDC_EDIT1: 线程函数使用 MFC的edit control控件显示 1.控件添加变量M_edit后: CString str; int x 10; str.Format(_T("%d"),x); M_edit.SetWindowText(str)2.控件ID…...
ThreadLocal内存泄漏问题
引子: 内存泄漏:是指本应该被GC回收的无用对象没有被回收,导致内存空间的浪费,当内存泄露严重时会导致内存溢出。Java内存泄露的根本原因是:长生命周期的对象持有短生命周期对象的引用,尽管短生命周期对象已…...
微服务基础概念【内含图解】
目录 拓展补充: 单体架构 分布式架构 面向服务的体系结构 云原生 微服务架构 什么是微服务? 微服务定义 拓展补充: 单体架构 单体架构:将业务的所有功能集中在一个项目中开发,最终打成一个包部署 优点&#x…...
Dockerfile创建 LNMP 服务+Wordpress 网站平台
文章目录 一.环境及准备工作1.项目环境2.服务器环境3.任务需求 二.Linux 系统基础镜像三.docker构建Nginx1.建立工作目录上传安装包2.编写 Dockerfile 脚本3.准备 nginx.conf 配置文件4.生成镜像5.创建自定义网络6.启动镜像容器7.验证 nginx 四.docker构建Mysql1. 建立工作目录…...
消息中间件篇
消息中间件篇 RabbitMQ 如何保证消息不丢失 面试官: RabbitMQ如何保证消息不丢失 候选人: 嗯!我们当时MYSQL和Redis的数据双写一致性就是采用RabbitMQ实现同步的,这里面就要求了消息的高可用性,我们要保证消息的不…...
基本定时器
1.简介 1. 基本定时器 TIM6 和 TIM7 包含一个 16 位自动重载计数器 2. 可以专门用于驱动数模转换器 (DAC), 用于触发 DAC 的同步电路 3. 16 位自动重载递增计数器 4. 16 位可编程预分频器 5. 计数器溢出时, 会触发中断/DMA请求 从上往下看 1.开始RCC供给定时器的时钟 RCC_APB1…...
MySQL 中文全文检索
创建索引(MySQL 5.7.6后全文件索引可用WITH PARSER ngram,针对中文,日文,韩文) ALTER TABLE 表 ADD FULLTEXT 索引名 (字段) WITH PARSER ngram;或者CREATE FULLTEXT INDEX 索引名 ON 表 (字段) WITH PARSER ngram; …...
Redis——list类型详解
概要 Redis中的list类型相当于双端队列,支持头插,头删,尾插,尾删,并且列表中的内容是可以重复的。 如果搭配使用rpush和lpop,那么就相当于队列 如果搭配使用rpush和rpop,那么就相当于栈 lpu…...
npm 安装 git 仓库包
安装 #v1.0.0 代表版本, 例如打了仓库一个tag叫v1.0.0; 如果不指定版本则默认是最新的代码 npm install githttp://mygitlab.xxxx.net/chengchongzhen/hex-event-track.git#v1.0.0在项目根目录执行以下命令, 此时你的代码会被链接到npm的全局仓库, 类似执行了 npm install xxx …...
问题来了!你知道你穿的防砸劳保鞋的保护包头都是什么材料
防砸劳保鞋是较为常见的一种劳保鞋,用于作业过程中保护工人的脚,减少或避免被坠落物、重物砸伤或压伤脚部的工作鞋。防砸安全鞋鞋前头装有防护包头,具有耐压力和抗冲击性能。主要适用于矿山、机械、建筑、钢铁、冶金、运输等行业。 你穿的防砸…...
甘肃省第九建设集团网站首页/志鸿优化设计答案
朋友门!等待终于结束了! 我们的社区在短短一周内已经发展到了超过10,000名Twitter粉丝,12,000名Discord成员,以及7,000名Telegram成员!为了感谢您的大力支持,我们将继续努力。为了感谢您的大力支持,我们很高…...
汕头市网络推广报价/seo顾问阿亮
版权声明:您好,转载请留下本人博客的地址,谢谢 https://blog.csdn.net/hongbochen1223/article/details/47601525 在我们的上一篇博客中,我们介绍了首页中的app列表界面如何完成,这个ListView以及其Adapter会在我们后面…...
安徽网站建设哪家好/网站建设的公司
2019独角兽企业重金招聘Python工程师标准>>> Frink:为物理计算设计的编程语言 作者:Alen 翻译:赖信涛 责编:仲培艺 Frink是一个实用的计算工具,也算是一种专为物理计算而设计的编程语言。它能让物理计算变得…...
招生网站建设策划方案/营销运营主要做什么
如果你还没有安装好SciTE,可以参考《在ubuntu 12.04 中安装SciTE 文本编辑器》。刚安装好的SciTE文本编辑器非常简朴,需要经过适当配置才能成为真正称手的编程利器。下文中的所有配置项,可以直接拷贝,然后粘贴到SciTE的用户设置文…...
做网站建设的公司有哪些/百度直接打开
# ---------------------------------------- # 核心属性 # ----------------------------------------# 文件编码 banner.charset UTF-8 # 文件位置 banner.location classpath:banner.txt# 日志配置 # 日志配置文件的位置。 例如对于Logback的classpath:logback.x…...
网络推广排名/关键词优化技巧
组建团队 人员培养 知识分享...