文本匹配SimCSE模型代码详解以及训练自己的中文数据集
前言
在上一篇博客文本匹配中的示例代码中使用到了一个SimCSE模型,用来提取短文本的特征,然后计算特征相似度,最终达到文本匹配的目的。但是该示例代码中的短文本是用的英文短句,其实SimCSE模型也可以用于中文短文本的特征提取,本篇博客就基于苏沐剑发表于科学空间的中文任务还是SOTA吗?我们给SimCSE补充了一些实验博客中使用到的代码,来记录一下代码梳理的笔记,并且使用自己的数据集在这篇代码上进行训练。另外,关于这个模型的原理细节等,可以参考别的博主写的内容,还有就是作者的论文,这些会附在最后的参考链接。
代码详解
数据导入部分
数据导入部分的代码主要有三个步骤,(1)从txt中读取文本数据,常规操作,这里没什么可说的;
datasets = {'%s-%s' % (task_name, f):load_data('%s%s/%s.%s.data' % (data_path, task_name, task_name, f))for f in ['train', 'valid', 'test']
}
(2)将读取到的文本句子转换成id向量,同样也是常规操作;
def convert_to_ids(data, tokenizer, maxlen=64):"""转换文本数据为id形式"""a_token_ids, b_token_ids, labels = [], [], []for d in tqdm(data):token_ids = tokenizer.encode(d[0], maxlen=maxlen)[0]a_token_ids.append(token_ids)token_ids = tokenizer.encode(d[1], maxlen=maxlen)[0]b_token_ids.append(token_ids)labels.append(d[2])a_token_ids = sequence_padding(a_token_ids)b_token_ids = sequence_padding(b_token_ids)return a_token_ids, b_token_ids, labels
(3)第三步则是写了一个class,使用了一个生成器,完成数据batch读取。这里需要注意的是,每个batch中,同一个文本数据,输入了两次,一个batch中的两个一样的文本输入,由于模型最后一层的加入了dropout,模型输出结果是有些许差别的,这样有差别的输出,则可以互为label,这也是SimCSE模型巧妙的地方。
class data_generator(DataGenerator):"""训练语料生成器"""def __iter__(self, random=False):batch_token_ids = []for is_end, token_ids in self.sample(random):batch_token_ids.append(token_ids) ##同一条文本输入两次batch_token_ids.append(token_ids) ##同一条文本输入两次if len(batch_token_ids) == self.batch_size * 2 or is_end:batch_token_ids = sequence_padding(batch_token_ids)batch_segment_ids = np.zeros_like(batch_token_ids)batch_labels = np.zeros_like(batch_token_ids[:, :1])yield [batch_token_ids, batch_segment_ids], batch_labelsbatch_token_ids = []
模型定义部分
这个模型的定义其实很简单,就是用bert作为特征提取的基础模型,然后再bert模型输出的基础上加上一个dropout操作,就是代码中的pooling层,核心代码就是下面几行
bert = build_transformer_model(config_path,checkpoint_path,model=model,with_pool='linear',dropout_rate=dropout_rate)
outputs, count = [], 0
while True:try:output = bert.get_layer('Transformer-%d-FeedForward-Norm' % count).outputoutputs.append(output)count += 1except:break
output = bert.output
# 最后的编码器
encoder = Model(bert.inputs, output)
模型的损失函数
模型的损失函数是所有代码中最难理解的部分,虽然代码只有十几行,但是最需要花费时间去理解的。
在阐述这个SimCSE模型的损失函数代码之前,首先要搞清楚,这个模型是要解决什么问题,其目的主要是为了提取短文本的特征,使得相似的句子,提取出来的特征距离更近,不同语义的句子,特征距离越远,这样使得提取出来的文本特征更具有辨识度,和人脸识别原理很类似,这就是对比学习模型系列想要达到的目的。
在了解了对比学习的大致原理之后,再来看代码,下面是解释
idxs = K.arange(0, K.shape(y_pred)[0])
这行代码就是模型输出的一个维度(模型输入的batchsize),构建一个索引,比如,模型输入batchsize为6,那idxs则就是[0,1,2,3,4,5]
idxs_1 = idxs[None, :]
这就是给idxs增加一个维度,使其变成[[0,1,2,3,4,5]]
idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
这行代码比较关键,目的是让idxs向量中数值是奇数的赋值为它的前一个数,数值为偶数的则赋值为它后一个索引值,这个一前一后的赋值,就是它相似度最大的索引值(排除自己)。这里需要解释一下的是,这里每个索引值背后代表的是SimCSE模型输出的一个个的提取到的文本特征向量,维度是1*738,和bert模型输出应该是一样的维度。而这里为什么要取一前一后的赋值索引,这因为数据导入时候,在每个batch里面同一条文本被相邻的导入了两次,那么这两个相邻的文本,经过SimCSE模型提取到的特征也是最为相似的,其相似度要接近1,而每个batch里面不相邻的模型输出,则应该是0,这样模型才能达到收敛的效果
y_true = K.equal(idxs_1, idxs_2)
y_true = K.cast(y_true, K.floatx())
这两行代码就是可以将y_true变成一个batchsize * batchsize大小的相似度矩阵,相似度的规则和上面描述的一样
生成y_true的中间值,其实可以打印出来看看,设定 y_pred为[‘a’, ‘a’, ‘b’, ‘b’, ‘c’, ‘c’]时候,整个调试代码如下:
from bert4keras.backend import keras, Kimport tensorflow as tfy_pred = ['a', 'a', 'b', 'b', 'c', 'c']session = tf.Session()
# 张量转化为ndarrayidxs = K.arange(0, K.shape(y_pred)[0])
array = session.run(idxs)
print('1', array)idxs_1 = idxs[None, :]
array = session.run(idxs_1)
print('2', array)idxs_2 = (idxs + 1 - idxs % 2 * 2)[:, None]
array = session.run(idxs_2)
print('3', array)y_true = K.equal(idxs_1, idxs_2)
array = session.run(y_true)
print('4', array)y_true = K.cast(y_true, K.floatx())array = session.run(y_true)
print('5',array)
y_pred = K.l2_normalize(y_pred, axis=1)
similarities = K.dot(y_pred, K.transpose(y_pred))
similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12
similarities = similarities * 20
这几行代码就是计算SimCSE模型预测出来每个batch里的每个文本特征之间的相似度,特征越相似,K.dot(y_pred, K.transpose(y_pred)),特征向量点乘越接近1,similarities = similarities - tf.eye(K.shape(y_pred)[0]) * 1e12,则是为了消除相似度矩阵对角线上的元素,即同一条特征自身与自身点乘的结果。
loss = K.categorical_crossentropy(y_true, similarities, from_logits=True)
最后用交叉熵损失来定义模型最后的输出损失
训练自己的数据
在这个模型需要训练自己的数据,首先是环境搭建:
jieba-0.42.1
bert4keras-0.10.5
keras-2.3.1
cudatoolkit 10.0.130
cudnn 7.6.0
tensorflow-gpu 1.13.1
然后准备数据集,格式如下:
txt这个标签,0,1可以有,也可以没有
接着就是下载预训练模型,bert的模型,下载之后,修改eval.py中的数据集和预训练模型的路径,将其修改成自己的路径
最后运行代码训练模型即可得到预测结果
参考链接
SimCSE论文及源码解读
SimCSE的loss实现源码解读
SimCSE: Simple Contrastive Learning of Sentence Embeddings
princeton-nlp/SimCSE
相关文章:
文本匹配SimCSE模型代码详解以及训练自己的中文数据集
前言 在上一篇博客文本匹配中的示例代码中使用到了一个SimCSE模型,用来提取短文本的特征,然后计算特征相似度,最终达到文本匹配的目的。但是该示例代码中的短文本是用的英文短句,其实SimCSE模型也可以用于中文短文本的特征提取&a…...
Biotin-PEG-FITC 生物素聚乙二醇荧光素;FITC-PEG-Biotin 科研用生物试剂
结构式: Biotin-PEG-FITC 生物素聚乙二醇荧光素 英文名称:Biotin-PEG-Fluorescein 中文名称:生物素聚乙二醇荧光素 外观:黄色液体、半固体或固体,取决于分子量。 溶剂:溶于大部分有机溶剂,…...
FISCO BCOS 搭建区块链,在SpringBoot中调用合约
一、搭建区块链 使用的是FISCO BCOS 和 WeBASE-Front来搭建区块链,详细教程: https://blog.csdn.net/yueyue763184/article/details/128924144?spm1001.2014.3001.5501 搭建好能达到下图效果即可: 二、部署智能合约与导出java文件、SDK证…...
面试官:int和Integer有什么区别?
回答思路: 原始数据类型和包装类介绍 主要区别(数据使用内存) 自动装箱、自动拆箱机制和实践原则 回答总结: int 是8种基本数据类型(byte、boolean、char、short、int、long、float、double)之一ÿ…...
MFC常用技巧
MFC常用技巧1、句柄MFC中如何获取窗口的句柄2、字符串CString转char*Unicode下char *转换为CString3、Visual C 64 位迁移的常见问题(数据类型、指针类型的长度问题)4、c - 将_beginthread返回的uintptr_t转换为HANDLE是否安全1、句柄 MFC中如何获取窗口…...
C++ —— 多态
目录 1.多态的概念 2.多态的定义及实现 2.1构成多态的两个硬性条件 2.2虚函数的重写 2.3override和final 3.抽象类 3.1接口继承和实现继承 4.多态原理 4.1虚函数表 4.2原理 4.3静态绑定和动态绑定 5.单继承和多继承体系的虚函数表 5.1单继承体系的虚函数表 5.2多继…...
java agent设计开发概要
agent开发设计 agent 开发的一些心得,适合熟悉agent或者有agent开发需求的同学 1 有个基础的agent,是java 标准的agent。这是agent代码入口 2 设计包结构, 基础agent agent下有plugin,加载plugin可以自己定义一个类加载器 plugin࿱…...
node.js笔记-模块化(commonJS规范),包与npm(Node Package Manager)
目录 模块化 node.js中模块的分类 模块的加载方式 模块作用域 向外共享模块作用域中的成员 向外共享成员 包与npm(Node package Manager) 什么是包? 包的来源 为什么需要包? 查找和下载包 npm下载和卸载包命令 配置np…...
Linux 磁盘坏块修复处理(错误:read error: Input/output error)
当磁盘出现坏块时,你对所关联的文件进行读取时,一般会出现 read error: Input/output error 这样的错误。 反过来讲,当你看到 read error: Input/output error 这种错误时,很大可能就是磁盘出现了坏块问题。 解决步骤:…...
API 面试四连杀:接口如何设计?安全如何保证?签名如何实现?防重如何实现?
下面我们就来讨论下常用的一些API设计的安全方法,可能不一定是最好的,有更牛逼的实现方式,但是这篇是我自己的经验分享. 一、token 简介 Token:访问令牌access token, 用于接口中, 用于标识接口调用者的身份、凭证,减…...
操作系统题目收录(六)
1、某系统采用基于优先权的非抢占式进程调度策略,完成一次进程调度和进程切换的系统时间开销为1us。在T时刻就绪队列中有3个进程P1P_1P1、P2P_2P2和P3P_3P3,其在就绪队列中的等待时间、需要的CPU时间和优先权如下表所示。若优先权值大的进程优先获…...
2023年十款开源测试开发工具推荐!
今天为大家奉献一篇测试开发工具集锦干货。在本篇文章中,将给大家推荐10款日常工作中经常用到的测试开发工具神器,涵盖了自动化测试、性能压测、流量复制、混沌测试、造数据等。 1、AutoMeter-API 自动化测试平台 AutoMeter 是一款针对分布式服务&…...
MySQL慢查询分析和性能优化
1 背景我们的业务服务随着功能规模扩大,用户量扩增,流量的不断的增长,经常会遇到一个问题,就是数据存储服务响应变慢。导致数据库服务变慢的诱因很多,而RD最重要的工作之一就是找到问题并解决问题。下面以MySQL为例子&…...
C++学习笔记(四)
组合、继承。委托(类与类之间的关系) 复合 queue类里有一个deque,那么他们的关系叫做复合。右上角的图表明复合的概念。上图的特例表明,queue中的功能都是通过调用c进行实现(adapter)。 复合关系下的构造和…...
【4】深度学习之Pytorch——如何使用张量处理时间序列数据集(共享自行车数据集)
表格数据 表格中的每一行都独立于其他行,他们的顺序页没有任何关系。并且,没有提供有关行之前和行之后的列编码信息。 表格类型的数据是指通过表格的形式表示的数据,它以行和列的方式组织数据。表格中的每一行代表一个数据项,每…...
mulesoft MCIA 破釜沉舟备考 2023.02.10.01
mulesoft MCIA 破釜沉舟备考 2023.02.10.01 1. What is a defining charcateristic of an integration-Platform-as-a-Service(iPaaS)?2. An application deployed to a runtime fabric environment with two cluster replicas is designed to periodically trigger of flow f…...
干货 | PCB拼板,那几条很讲究的规则!
拼板指的是将一张张小的PCB板让厂家直接给拼做成一整块。一、为什么要拼板呢,也就是说拼板的好处是什么?1.为了满足生产的需求。有些PCB板太小,不满足做夹具的要求,所以需要拼在一起进行生产。2.提高SMT贴片的焊接效率。只需要过一…...
笔试题-2023-思远半导体-数字IC设计【纯净题目版】
回到首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 推荐内容:数字IC设计学习比较实用的资料推荐 题目背景 笔试时间:2022.08.20应聘岗位:数字IC设计工程师笔试时长:90min笔试平台:牛客网题目类型:填空题(2道),不定项选择题(3道),单选题(2道),问…...
canvas根据坐标点位画图形-canvas拖拽编辑单个图形形状
首先在选中图形的时候需要用鼠标右击来弹出选择框,实现第一个编辑节点功能 在components文件夹下新建右键菜单 RightMenu文件: <template><div v-show"show" class"right-menu" :style"top:this.ypx;left:this.xpx…...
JavaEE 初阶 — 确认应答机制
文章目录确认应答机制(安全机制)1 什么是后发先至问题1 如何解决后发先至问题确认应答机制(安全机制) 确认应答 是实现可靠传输的最核心机制。 这里指的 可靠传输 不是说 100% 可以把消息发给接收方,而是尽力而为&…...
0207 事件
事件监听事件监听版本事件类型事件概念事件在编程时系统内发生的动作或者发生的事情例子点击按钮鼠标经过拖拽鼠标事件监听(注册事件,绑定事件)让程序员检测是否有事件产生,一旦有事件触发,就立即调用一个函数做出响应…...
SpringBoot整合Swagger
目录 一、swagger介绍 二、springboot集成swagger 1、创建一个springboot-web项目 2、导入相关依赖 3、编写一个Hellow工程 4、配置swagger --->config 5、启动springboot工程 6、配置swagger信息 7、配置swagger扫描接口 8、如何设置Swagger在生产环境中使用&…...
20230210英语学习
Why Do So Many Cats Have White ‘Socks’ on Their Paws? 为什么好多猫咪脚上都“穿着白袜子”? If you see a house cat, the odds are high that it will have white paws, a look that many owners affectionately call "socks."But socks are rar…...
【图像处理OpenCV(C++版)】——4.5 全局直方图均衡化
前言: 😊😊😊欢迎来到本博客😊😊😊 🌟🌟🌟 本专栏主要结合OpenCV和C来实现一些基本的图像处理算法并详细解释各参数含义,适用于平时学习、工作快…...
2022年API安全研究报告
目录 导读 2022年API安全风险概况 2022年平均每月遭受攻击的API数量超21万...
【内网安全-横向移动】基于SMB协议-PsExec
目录 一、SMB协议 1、简述: 2、工具: 二、PsExec 1、简述: 2、使用: 1、常用参数: 2、情况: 3、插件 三、PsExec(impacket) 1、简述: 1、impacket࿱…...
whistle 一个神奇的前端调试工具(抓包\代理工具)
在进行前端开发过程中,我们常常需要对一些接口进行处理,以及当后端接口没有弄好需要我们mock一些假数据,针对这些场景,我们就可以使用whistle 来解决。首先,我们要知道能满足我们需求的工具有很多,例如&…...
node.js下载和vite项目创建以及可能遇到的错误
目录 一、node.js的下载 1、去官网下载 节点.js (nodejs.org) 2、下载过程 第一步: 第二步: 第三步: 第四步: 第五步: 二、vite项目的创建(使用的工具是Hbuilder x) 第一步: 出现报错…...
如何使用python画一个爱心
1 问题 如何使用python画一个爱心。 2 方法 桌面新建一个文本文档,文件后缀改为.py,输入相关代码ctrls保存,关闭,最后双击运行。 代码清单 1 from turtle import * def curvemove(): for i in range(200): right(1) …...
1 Flutter UI Container和 Text 和图片组件
一 Text 组件Text 文本组件的一些属性如下body: const Text("this is leonardo fibonacci",// 文本对齐的方式textAlign: TextAlign.center,// 文本方向textDirection: TextDirection.rtl,// 字体显示最大的行数maxLines: 2,// 文字超出屏幕之后的显示方式 ellipsi…...
wordpress 4.7优化精减/兰蔻搜索引擎营销案例
你觉得最好用的地图导航软件是哪一个?这个问题,回答的人千千万万,现在国内两大巨头是,高德地图,百度地图。个人感觉是,二者皆有优势和缺点。 百度地图在手机导航中也算是做得比较好的一款地图类应用。最突出的方面是…...
高端制作网站公司/重庆seo研究中心
当查询变得很慢很慢,建立索引已经无法提高查询速度时。那么,最常见的MySQL优化方案,你造吗?SQL查询语句优化的六大方案:1、使用索引2、借助explain(查询优化神器)选择更好的索引和优化查询语句3…...
有什么办法可以在备案期间网站不影响seo/免费seo公司
参考资料:http://www.cnblogs.com/dreamvibe/p/4349886.html 为什么转换成对偶问题: 首先是我们有不等式约束方程,这就需要我们写成min max的形式来得到最优解。而这种写成这种形式对x不能求导,所以我们需要转换成max min的形式&a…...
网站怎么做才美观/百度网盘app下载安装 官方下载
Androidstudio 编译jpeglib 和 pnglig的 so的简单记录编译多个SO: 用ADD_SUBDIRECTORY 添加多个目录, 生产多个SOmylibrary/CMakeLists.txt png 和 jpeg 目录下有源码和CMakeLists.txt 生产SO,cmake_minimum_required(VERSION 3.4.1)#设置…...
网站抄袭/昆明网络营销公司哪家比较好
第10章 对象和类 面向对象编程OOP特性: 抽象封装和数据隐藏多态继承代码的可重用性 10.1 过程性编程和面向对象编程 采用过程性编程: 首先考虑遵循的步骤,然后考虑如何表示这些数据。 采用OOP: 首先考虑数据,考虑如何表示数…...
全球网站域名后缀/seo权威入门教程
这个进度条可以反映真实进度,并且完成百分比的文字时随着进度增加而移动的,所在位置也恰好是真实完成的百分比位置,效果如下:思路如下:第一部分是左侧的蓝色直线,代表已经完成的进度;第二部分是…...