[oneAPI] 使用Bert进行中文文本分类
[oneAPI] 使用Bert进行中文文本分类
- Intel® Optimization for PyTorch
- 基于BERT的文本分类模型
- 数据预处理
- 数据集
- 定义tokenize
- 建立词表
- 转换为Token序列
- padding处理与mask
- 模型
- 结果
- OneAPI
- 参考资料
比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517
Intel® DevCloud for oneAPI:https://devcloud.intel.com/oneapi/get_started/aiAnalyticsToolkitSamples/
Intel® Optimization for PyTorch
在本次实验中,我们利用PyTorch和Intel® Optimization for PyTorch的强大功能,对PyTorch进行了精心的优化和扩展。这些优化举措极大地增强了PyTorch在各种任务中的性能,尤其是在英特尔硬件上的表现更加突出。通过这些优化策略,我们的模型在训练和推断过程中变得更加敏捷和高效,显著地减少了计算时间,提高了整体效能。我们通过深度融合硬件和软件的精巧设计,成功地释放了硬件潜力,使得模型的训练和应用变得更加快速和高效。这一系列优化举措为人工智能应用开辟了新的前景,带来了全新的可能性。
基于BERT的文本分类模型
基于BERT的文本分类模型就是在原始的BERT模型后再加上一个分类层即可,同时,对于分类层的输入(也就是原始BERT的输出),默认情况下取BERT输出结果中[CLS]位置对于的向量即可,当然也可以修改为其它方式,例如所有位置向量的均值等。因此,对于基于BERT的文本分类模型来说其输入就是BERT的输入,输出则是每个类别对应的logits值。
数据预处理
在构建数据集之前,我们首先需要知道的是模型到底应该接收什么样的输入,然后才能构建出正确的数据形式。在上面我们说到,基于BERT的文本分类模型的输入就等价于BERT模型的输入,同时BERT模型的输入如图1所示:
数据集
在这里,我们使用到的数据集是今日头条开放的一个新闻分类数据集(https://github.com/aceimnorstuvwxz/toutiao-text-classfication-dataset),一共包含有382688条数据,15个类别,经过处理后数据集格式为:
千万不要乱申请网贷,否则后果很严重_!_4
10年前的今天,纪念5.12汶川大地震10周年_!_11
怎么看待杨毅在一NBA直播比赛中说詹姆斯的球场统治力已经超过乔丹、伯德和科比?_!_3
戴安娜王妃的车祸有什么谜团?_!_2
其中_!_左边为新闻标题,也就是后面需要用到的分类文本,右边为类别标签。
定义tokenize
将输入进来的文本序列tokenize到字符级别。对于中文语料来说就是将每个字和标点符号都给切分开。在这里,我们可以借用transformers包中的BertTokenizer方法来完成,如下所示:
1 if __name__ == '__main__':
2 model_config = ModelConfig()
3 tokenizer = BertTokenizer.from_pretrained(model_config.pretrained_model_dir).tokenize
4 print(tokenizer("青山不改,绿水长流,我们月来客栈见!"))
5 print(tokenizer("10年前的今天,纪念5.12汶川大地震10周年"))
6
7 # ['青', '山', '不', '改', ',', '绿', '水', '长', '流', ',', '我', '们', '月', '来', '客', '栈', '见', '!']
8 # ['10', '年', '前', '的', '今', '天', ',', '纪', '念', '5', '.', '12', '汶', '川', '大', '地', '震', '10', '周', '年']
建立词表
将vocab.txt中的内容读取进来形成一个词表即可
1 class Vocab:2 UNK = '[UNK]'3 def __init__(self, vocab_path):4 self.stoi = {}5 self.itos = []6 with open(vocab_path, 'r', encoding='utf-8') as f:7 for i, word in enumerate(f):8 w = word.strip('\n')9 self.stoi[w] = i
10 self.itos.append(w)
11
12 def __getitem__(self, token):
13 return self.stoi.get(token, self.stoi.get(Vocab.UNK))
14
15 def __len__(self):
16 return len(self.itos)
转换为Token序列
在得到构建的字典后,便可以通过如下函数来将训练集、验证集和测试集转换成Token序列:
1 def data_process(self, filepath):2 raw_iter = open(filepath, encoding="utf8").readlines()3 data = []4 max_len = 05 for raw in tqdm(raw_iter, ncols=80):6 line = raw.rstrip("\n").split(self.split_sep)7 s, l = line[0], line[1]8 tmp = [self.CLS_IDX] + [self.vocab[token] for token in self.tokenizer(s)]9 if len(tmp) > self.max_position_embeddings - 1:
10 tmp = tmp[:self.max_position_embeddings - 1] # BERT预训练模型只取前512个字符
11 tmp += [self.SEP_IDX]
12 tensor_ = torch.tensor(tmp, dtype=torch.long)
13 l = torch.tensor(int(l), dtype=torch.long)
14 max_len = max(max_len, tensor_.size(0))
15 data.append((tensor_, l))
16 return data, max_len
padding处理与mask
对原始文本序列tokenize转换为Token ID后还需要对其进行padding处理。对于这一处理过程可以通过如下代码来完成:
1 def pad_sequence(sequences, batch_first=False, max_len=None, padding_value=0):2 if max_len is None:3 max_len = max([s.size(0) for s in sequences])4 out_tensors = []5 for tensor in sequences:6 if tensor.size(0) < max_len:7 tensor = torch.cat([tensor, torch.tensor(8 [padding_value] * (max_len - tensor.size(0)))], dim=0)9 else:
10 tensor = tensor[:max_len]
11 out_tensors.append(tensor)
12 out_tensors = torch.stack(out_tensors, dim=1)
13 if batch_first:
14 return out_tensors.transpose(0, 1)
15 return out_tensors
模型
class BertModel(nn.Module):""""""def __init__(self, config):super().__init__()self.bert_embeddings = BertEmbeddings(config)self.bert_encoder = BertEncoder(config)self.bert_pooler = BertPooler(config)self.config = configself._reset_parameters()def forward(self,input_ids=None,attention_mask=None,token_type_ids=None,position_ids=None):"""***** 一定要注意,attention_mask中,被mask的Token用1(True)表示,没有mask的用0(false)表示这一点一定一定要注意:param input_ids: [src_len, batch_size]:param attention_mask: [batch_size, src_len] mask掉padding部分的内容:param token_type_ids: [src_len, batch_size] # 如果输入模型的只有一个序列,那么这个参数也不用传值:param position_ids: [1,src_len] # 在实际建模时这个参数其实可以不用传值:return:"""embedding_output = self.bert_embeddings(input_ids=input_ids,position_ids=position_ids,token_type_ids=token_type_ids)# embedding_output: [src_len, batch_size, hidden_size]all_encoder_outputs = self.bert_encoder(embedding_output,attention_mask=attention_mask)# all_encoder_outputs 为一个包含有num_hidden_layers个层的输出sequence_output = all_encoder_outputs[-1] # 取最后一层# sequence_output: [src_len, batch_size, hidden_size]pooled_output = self.bert_pooler(sequence_output)# 默认是最后一层的first token 即[cls]位置经dense + tanh 后的结果# pooled_output: [batch_size, hidden_size]return pooled_output, all_encoder_outputsdef _reset_parameters(self):r"""Initiate parameters in the transformer model.""""""初始化"""for p in self.parameters():if p.dim() > 1:normal_(p, mean=0.0, std=self.config.initializer_range)@classmethoddef from_pretrained(cls, config, pretrained_model_dir=None):model = cls(config) # 初始化模型,cls为未实例化的对象,即一个未实例化的BertModel对象pretrained_model_path = os.path.join(pretrained_model_dir, "pytorch_model.bin")if not os.path.exists(pretrained_model_path):raise ValueError(f"<路径:{pretrained_model_path} 中的模型不存在,请仔细检查!>\n"f"中文模型下载地址:https://huggingface.co/bert-base-chinese/tree/main\n"f"英文模型下载地址:https://huggingface.co/bert-base-uncased/tree/main\n")loaded_paras = torch.load(pretrained_model_path)state_dict = deepcopy(model.state_dict())loaded_paras_names = list(loaded_paras.keys())[:-8]model_paras_names = list(state_dict.keys())[1:]if 'use_torch_multi_head' in config.__dict__ and config.use_torch_multi_head:torch_paras = format_paras_for_torch(loaded_paras_names, loaded_paras)for i in range(len(model_paras_names)):logging.debug(f"## 成功赋值参数:{model_paras_names[i]},形状为: {torch_paras[i].size()}")if "position_embeddings" in model_paras_names[i]:# 这部分代码用来消除预训练模型只能输入小于512个字符的限制if config.max_position_embeddings > 512:new_embedding = replace_512_position(state_dict[model_paras_names[i]],loaded_paras[loaded_paras_names[i]])state_dict[model_paras_names[i]] = new_embeddingcontinuestate_dict[model_paras_names[i]] = torch_paras[i]logging.info(f"## 注意,正在使用torch框架中的MultiHeadAttention实现")else:for i in range(len(loaded_paras_names)):logging.debug(f"## 成功将参数:{loaded_paras_names[i]}赋值给{model_paras_names[i]},"f"参数形状为:{state_dict[model_paras_names[i]].size()}")if "position_embeddings" in model_paras_names[i]:# 这部分代码用来消除预训练模型只能输入小于512个字符的限制if config.max_position_embeddings > 512:new_embedding = replace_512_position(state_dict[model_paras_names[i]],loaded_paras[loaded_paras_names[i]])state_dict[model_paras_names[i]] = new_embeddingcontinuestate_dict[model_paras_names[i]] = loaded_paras[loaded_paras_names[i]]logging.info(f"## 注意,正在使用本地MyTransformer中的MyMultiHeadAttention实现,"f"如需使用torch框架中的MultiHeadAttention模块可通过config.__dict__['use_torch_multi_head'] = True实现")model.load_state_dict(state_dict)return model
结果
OneAPI
import intel_extension_for_pytorch as ipexmodel = model.to(config.device)
optimizer = torch.optim.Adam(model.parameters(), lr=5e-5)'''
Apply Intel Extension for PyTorch optimization against the model object and optimizer object.
'''
model, optimizer = ipex.optimize(model, optimizer=optimizer)
参考资料
基于BERT预训练模型的中文文本分类任务: https://www.ylkz.life/deeplearning/p10979382/
相关文章:
[oneAPI] 使用Bert进行中文文本分类
[oneAPI] 使用Bert进行中文文本分类 Intel Optimization for PyTorch基于BERT的文本分类模型数据预处理数据集定义tokenize建立词表转换为Token序列padding处理与mask 模型 结果OneAPI参考资料 比赛:https://marketing.csdn.net/p/f3e44fbfe46c465f4d9d6c23e38e0517…...
【数据治理】什么是数据库归档
文章目录 前言什么是数据归档 前言 如果您的日常工作中需要对数据库进行管理,那您肯定已经或即将遭遇这样的困惑:随着业务的蓬勃发展,数据库文件的大小逐渐增大,您需要为在线业务提供越来越大的高性能磁盘容量,但数据…...
AI代码补全 案例 - 阿里云智能编码插件Cosy
文章目录 Cosy简介Cosy安装Marketplace安装【推荐】离线安装安装效果Cosy功能体验代码智能补全代码示例搜索API搜索自然语言搜索控制台异常搜索优质文档搜索Cosy体验有感参考Cosy简介 阿里云智能编码插件(Alibaba Cloud AI Coding Assistant)是一款AI编程助手,提供代码智能…...
【Linux】进程信号篇Ⅰ:信号的产生(signal、kill、raise、abort、alarm)、信号的保存(core dump)
文章目录 一、 signal 函数:用户自定义捕捉信号二、信号的产生1. 通过中断按键产生信号2. 调用系统函数向进程发信号2.1 kill 函数:给任意进程发送任意信号2.2 raise 函数:给调用进程发送任意信号2.3 abort 函数:给调用进程发送 6…...
漏洞指北-VulFocus靶场专栏-中级03
漏洞指北-VulFocus靶场专栏-初级03 中级009 🌸gxlcms-cve_2018_14685🌸step1:安装系统 密码rootstep2 进入后台页面 账号密码:admin amdin888step3 查看详细 有phpinfo() 中级010 🌸dedecms-cnvd_2018_01221dz…...
【leetcode 力扣刷题】数组交集(数组、set、map都可实现哈希表)
数组交集 349. 两个数组的交集排序+双指针数组实现哈希表unordered_setunordered_map 350. 两个数组的交集Ⅱ排序 双指针数组实现哈希表unordered_map 349. 两个数组的交集 题目链接:349. 两个数组的交集 题目内容如下,理解题意:…...
MySQL 8.0.31 登录提示caching_sha2_password问题解决方法
MySQL 8.0.31 登录提示caching_sha2_password问题解决方法 MySQL 8.0.31 使用了 caching_sha2_password 作为默认的身份验证插件,这可能导致一些旧的客户端和库无法连接到服务器。以下是一些解决此类问题的常见步骤和建议: 确保MySQL服务正在运行&#…...
[Google] DeepMind Gemini: 新一代LLM结合AlphaGo技术将力压 GPT-4|未来 AI 领域的新巨头
2016年,Google DeepMind 人工智能实验室孕育出的 AlphaGo 人工智能程序在围棋赛场上一举击败冠军选手,成为历史的见证者。如今,DeepMind 联合创始人兼首席执行官 Demis Hassabis 表示,他们的工程师正借鉴 AlphaGo 的技术研发一款名…...
Maven高级
目录 一、分模块开发与设计 1. 分模块开发的意义 2. 分模块开发(模块拆分) (1)创建Maven模块 (2)书写模块代码 (3)通过maven指令安装模块到本地仓库(install指令&…...
【视觉SLAM入门】5.2. 2D-3D PNP 3D-3D ICP BA非线性优化方法 数学方法SVD DLT
"养气之学,戒之躁急" 1. 3D-2D PNP1.1 代数法1.1.1 DLT(直接线性变换法)1.1.2. P3P 1.2 优化法BA (Bundle Adjustment)法 2. 3D-3D ICP2.1 代数法2.1.1 SVD方法 2.2 优化(BA)法2.2.2 非线性优化方法 前置事项: 1. 3D-2D PNP 该问题描述为&am…...
人脸老化预测(Python)
本次项目的文件 main.py主程序如下 导入必要的库和模块: 导入 TensorFlow 库以及自定义的 FaceAging 模块。导入操作系统库和参数解析库。 定义 str2bool 函数: 自定义函数用于将字符串转换为布尔值。 创建命令行参数解析器: 使用 argparse.A…...
AWS SDK 3.x for .NET Framework 4.0 可行性测试
前言 为了应对日益增长的网络安全挑战, 越来越多的互联网厂商已经陆续开始或者已经彻底停止了对 SSL 3 / TLS 1.0 / TLS1.1 等上古加密算法的支持. 而对于一些同样拥有悠久历史的和 AWS 服务相关联的应用程序, 是否可以通过仅更新 SDK 版本的方式来适应新的环境. 本文将以 Win…...
两个list。如何使用流的写法将一个list中的对象中的某些属性根据另外一个list中的属性值赋值进去?
两个list。如何使用流的写法将一个list中的对象中的某些属性根据另外一个list中的属性值赋值进去? 你可以使用Java 8以上版本中的流(Stream)和Lambda表达式来实现这个需求。假设有两个List,一个是sourceList,包含要赋值属性的对象;另一个是…...
美国陆军希望大数据技术能够帮助保护其云安全
随着陆军采用更大型的云服务,一位高级官员警告说,一些在私营部门有效的快速软件开发技巧和简单解决方案(例如开放代码库)如果没有额外的安全性,将无法为军队工作。 我们知道现代软件开发确实依赖于第三方库ÿ…...
vue 文字跑马灯
<template><div class"marquee-container"><div class"marquee-content"><div>{{ marqueeText }}</div><div>{{ marqueeText }}</div> <!-- 复制一份文本,用于无缝衔接 --></div></d…...
开源ChatGPT系统源码 采用NUXT3+Laravel9后端开发 前后端分离版本
开源ChatGPT系统源码 采用NUXT3Laravel9后端开发 前后端分离版本 ChatGPT是一种基于AI的聊天机器人技术,它可以帮助用户与聊天机器人进行自然语言交流,以解决用户的问题或满足用户的需求。ChatGPT的核心技术是使用自然语言处理(NLPÿ…...
【LeetCode|数据结构】剑指 Offer 33. 二叉搜索树的后序遍历序列
题目链接 剑指 Offer 33. 二叉搜索树的后序遍历序列 标签 二叉搜索树、后序遍历 步骤 二叉搜索树的左子树的节点值 ≤ \le ≤根节点值 ≤ \le ≤右子树的节点值;对于后序遍历序列最后一个元素的值为根节点的值; 由上面的两个性质可以得出ÿ…...
自定义协程
难点 自己写了一遍协程,困难的地方在于unity中的执行顺序突然发现unity里面可以 yield return 的其实有很多 WaitForSeconds WaitForSecondsRealtime WaitForEndOfFrame WaitForFixedUpdate WaitUntil WaitWhile IEnumerator(可以用于协程嵌套…...
【Atcoder】 [ABC240Ex] Sequence of Substrings
题目链接 Atcoder方向 Luogu方向 题目解法 先考虑一个性质,选出的子串长度不会超过 2 n \sqrt {2n} 2n 考虑最劣的选法是选出长度为 1 , 2 , 3 , . . . 1,2,3,... 1,2,3,... 的子串(如果后一个选出的串比前一个子串长度大超过1,那么后…...
真机二阶段之堆叠技术
堆叠技术 --- 可以将多台真实的物理设备逻辑上抽象成一台 思科 -- VPC 华为 -- iStack和CSS 华三 -- IRF 锐捷 -- VSU iStack和CSS的区别: CSS --- 集群 --- 它仅支持将两台支持集群的交换机逻辑上整合成一台设备。 iStack --- 堆叠 --- 可以将多台支持堆叠的交换…...
简单、快速、无需注册的在线 MockJs 工具
简单、快速、无需注册的 MockJs 工具。通过参数来返回数据,传入什么参数就返回什么数据。 使用 接口只支持返回文本类数据,不支持图片、流数据等。 json 调用接口 https://mock.starxg.com/?responseBody{“say”:“hello”}&contentTypeapplic…...
【Linux取经路】探索进程状态之僵尸进程 | 孤儿进程
文章目录 一、进程状态概述1.1 运行状态详解1.2 阻塞状态详解1.3 挂起状态详解 二、具体的Linux操作系统中的进程状态2.1 Linux内核源代码2.2 查看进程状态2.3 D磁盘休眠状态(Disk sleep)2.4 T停止状态(stopped) 三、僵尸进程3.1 僵尸进程危害总结 四、孤儿进程五、结语 一、进…...
第十二章MyBatis动态SQL
if标签与where标签 if标签 test如果为true就会拼接查询条件,否则不会 当没有使用Param,test出现arg0/param1当使用Param,test为Param指定的值当使用Pojo,test为对象的属性名 select * from car where <if test"name!n…...
redis--发布订阅
redis的发布和订阅 在Redis中,发布-订阅(Publish-Subscribe,简称Pub/Sub)是一种消息传递模式,用于在不同的客户端之间传递消息,允许一个消息发布者将消息发送给多个订阅者。这种模式适用于解耦消息发送者和…...
链表2-两两交换链表中的节点删除链表的倒数第N个节点链表相交环形链表II
今天记录的题目: ● 24. 两两交换链表中的节点 ● 19.删除链表的倒数第N个节点 ● 面试题 02.07. 链表相交 ● 142.环形链表II 两两交换链表中的节点 题目链接:24. 两两交换链表中的节点 这题比较简单,记录好两个节点,交换其nex…...
数据结构之并查集
并查集 1. 并查集原理2. 并查集实现3. 并查集应用3.1 省份数量3.2 等式方程的可满足性 4. 并查集的优缺点及时间复杂度 1. 并查集原理 并查表原理是一种树型的数据结构,用于处理一些不相交集合的合并及查询问题。并查集的思想是用一个数组表示了整片森林࿰…...
[element-ui] el-date-picker a-range-picker type=“daterange“ rules 校验
项目场景: 在项目中表单提交有时间区间校验 问题描述 想当然的就和其他单个输入框字符串校验,导致提交保存的时候 ,初次日期未选择,规则提示。后续在同一表单上继续提交时,校验失效。走进了死胡同,一直以…...
Dockers搭建个人网盘、私有仓库,Dockerfile制作Nginx、Lamp镜像
目录 1、使用mysql:5.6和 owncloud 镜像,构建一个个人网盘。 (1)下载mysql:5.6和owncloud镜像 (2)创建启动mysql:5.6和owncloud容器 (3)在浏览器中输入网盘服务器的IP地址,进行账…...
2023 CCPC 华为云计算挑战赛 hdu7401 流量监控(树形dp)
题目 流量监控 - HDU 7401 - Virtual Judge 简单来说,T(T<20)组样例,sumn不超过2e4 每次给定一棵n(n<2000)个点的树,两问: ①将n个点恰拆成n/2个pair(u,v),要求一个点是另一个点的祖先,求方案数 …...
01.Django入门
1.创建项目 1.1基于终端创建Django项目 打开终端进入文件路径(打算将项目放在哪个目录,就进入哪个目录) E:\learning\python\Django 执行命令创建项目 F:\Anaconda3\envs\pythonWeb\Scripts\django-admin.exe(Django-admin.exe所…...
青岛网站推广招商/企业培训心得体会
使用树莓派4B搭建NAS(二):基于Manjaro 20.042020-07-07 19:14:1614点赞81收藏5评论【写作说明】:本文图片较少,具体操作欢迎参考作者的其他文章以及文中链接。上一篇我们谈了如何使用树莓派4B Ubuntu Server 20.04搭建NAS,本篇主…...
wordpress分类排序号/全自动推广软件
被调合约(通过call回调)支持接收以太币的案例: 被调合约(通过call回调)支持接收以太币的案例:pragma solidity >0.4.0 <0.6.0;contract Test001 {// 这个合约会保留所有发送给它的以太币,没有办法返还。// 必须实现Fallback回退函数,才能支持cal…...
武汉网站建站/重庆seo搜索引擎优化优与略
1、配置连接池通过IP/console进入管理控制台(如果不知道用户名和密码可以通过以下方式进入:右击StartWebLogic.sh快捷方式,选择“编辑”,在文本中可以找到用户名和密码)在左侧菜单中依次进入mydomain(自定义的域名称)-服务&#x…...
网站建设与管理可以专升本吗/chrome网页版入口
谈谈SQL Server高可用的常见问题每次谈到SQL Server的高可用,很多的DBA,特别是SQL Server DBA心里一痛:因为大家都认为SQL Server无法或者很难实现SQL Server。也有很多的DBA朋友脑袋一拍,给出答案“高可用不就是微软的那几个技术…...
门户网站直接登录系统/如何网页优化
案例介绍本章节主要用java实现;方法调用指令、返回指令、解析方法符号引用、参数传递等。实现新的指令后我们的虚拟机就可以执行稍微复杂的运算并输出结果。从调用的角度来看,方法可以分为两类:静态方法(或者类方法)和实例方法。静态方法通过…...
可以做任务赚钱的网站有哪些/百度一下知道官网
QQ邮箱 设置-账户,生成授权码,确认POP3/SMTP服务是已开启的状态。 OutLook 文件 - 添加账户 - 手动设置或其他服务类型 - POP或IMAP 用户名填邮箱账户 其他设置勾选使用与接收邮件服务器相同的配置...