当前位置: 首页 > news >正文

深度学习PyTorch 之 transformer-中文多分类

transformer的原理部分在前面基本已经介绍完了,接下来就是代码部分,因为transformer可以做的任务有很多,文本的分类、时序预测、NER、文本生成、翻译等,其相关代码也会有些不同,所以会分别进行介绍

但是对于不同的任务其流程是一样的,所以一些重复的步骤就不过多解释了。

1、 前期准备

数据和之前LSTM是一样的,同时我们还使用上次训练好的词嵌入模型

以下是代码

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import numpy as np
from gensim.models import KeyedVectors
from sklearn.model_selection import train_test_split
import pandas as pd
import jieba
import re
from sklearn.preprocessing import LabelEncoder# 加载数据
file_path = './data/news.csv'
data = pd.read_csv(file_path)# 显示数据的前几行
data.head()# 文本清洗和分词函数
def clean_and_cut(text):# 删除特殊字符和数字text = re.sub(r'[^a-zA-Z\u4e00-\u9fff]', '', text)# 使用jieba进行分词words = jieba.cut(text)return ' '.join(words)X_train_cut = data["text"].apply(clean_and_cut)
# 显示处理后的文本
data.head()# 将标签转换为数值形式
label_encoder = LabelEncoder()
data["label"] = label_encoder.fit_transform(data["label"])
# 加载保存的word vectors
loaded_wv = KeyedVectors.load('word_vector', mmap='r') class Word2VecDataset(Dataset):def __init__(self, texts, labels, word2vec, max_len=100):self.texts = textsself.labels = labelsself.word2vec = word2vecself.max_len = max_lendef __len__(self):return len(self.texts)def __getitem__(self, idx):text = self.texts[idx]label = self.labels[idx]embeds = [self.word2vec[word] if word in self.word2vec else np.zeros(self.word2vec.vector_size) for word in text]if len(embeds) > self.max_len:embeds = embeds[:self.max_len]else:embeds += [np.zeros(self.word2vec.vector_size) for _ in range(self.max_len - len(embeds))]return torch.tensor(embeds, dtype=torch.float), torch.tensor(label, dtype=torch.long)# texts和labels是数据集中的文本和标签列表
texts = X_train_cut.tolist()
labels = data['label'].tolist()# 划分数据集
train_texts, test_texts, train_labels, test_labels = train_test_split(texts, labels, test_size=0.2)

2、位置编码和主模型

import mathclass PositionalEncoding(nn.Module):def __init__(self, d_model, max_len=100):super(PositionalEncoding, self).__init__()# 创建一个位置编码矩阵pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)  # (1, max_len, d_model)self.register_buffer('pe', pe)def forward(self, x):# x: (batch_size, max_len, d_model)x = x + self.pe.expand(x.size(0), -1, -1)return x

2.1 PositionalEncoding 类

这个类用于创建和提供位置编码。位置编码是 Transformer 模型中用于注入序列中单词的位置信息的机制。这种位置信息对于模型理解单词的顺序很重要。

初始化方法 __init__
  • d_model:模型的维度,也是词嵌入的维度。
  • max_len:序列的最大长度。
  • pe:位置编码矩阵,大小为 (1, max_len, d_model)。这个矩阵被注册为一个缓冲区,这意味着它会被保存和加载与模型的其他参数一起。
前向传播方法 forward
  • 输入 x 的形状是 (batch_size, max_len, d_model)
  • self.pe.expand(x.size(0), -1, -1):这个操作将位置编码矩阵扩展为 (batch_size, max_len, d_model),以便它可以与输入数据相加。
  • 最后,将扩展后的位置编码矩阵加到输入数据上,并返回结果。
#修改Transformer模型以添加位置编码
class TransformerClassifierWithPE(nn.Module):def __init__(self, num_classes, d_model=100, nhead=2, num_layers=2, dim_feedforward=2048, dropout=0.1):super(TransformerClassifierWithPE, self).__init__()# 位置编码self.pos_encoder = PositionalEncoding(d_model)# Transformer编码器层encoder_layers = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, dim_feedforward=dim_feedforward, dropout=dropout)self.transformer_encoder = nn.TransformerEncoder(encoder_layers, num_layers=num_layers)# 分类器self.classifier = nn.Linear(d_model, num_classes)def forward(self, x):# x: (batch_size, max_len, d_model)x = self.pos_encoder(x)x = x.permute(1, 0, 2)  # (max_len, batch_size, d_model)x = self.transformer_encoder(x)  # (max_len, batch_size, d_model)x = x.mean(dim=0)  # (batch_size, d_model)x = self.classifier(x)  # (batch_size, num_classes)return x

2.2 TransformerClassifierWithPE 类

这个类定义了一个带有位置编码的 Transformer 分类器模型。

初始化方法 __init__
  • num_classes:分类任务的类别数量。
  • d_model:模型的维度,也是词嵌入的维度。
  • nhead:多头注意力的头数。
  • num_layers:Transformer 编码器层的数量。
  • dim_feedforward:前馈网络中的隐藏层维度。
  • dropout:Dropout 的概率。
  • pos_encoder:PositionalEncoding 实例,用于位置编码。
  • transformer_encoder:Transformer 编码器,由多个 TransformerEncoderLayer 组成。
  • classifier:线性分类器,用于生成最终的分类结果。
前向传播方法 forward
  • 输入 x 的形状是 (batch_size, max_len, d_model)
  • 首先,使用 self.pos_encoder(x) 获取位置编码后的输入。
  • 然后,将输入的维度从 (batch_size, max_len, d_model) 转换为 (max_len, batch_size, d_model),这是因为 PyTorch 的 Transformer 编码器期望的输入维度是这样的。
  • 接下来,通过 self.transformer_encoder(x) 应用 Transformer 编码器。
  • 然后,使用 x.mean(dim=0) 获取每个序列的平均表示。
  • 最后,通过 self.classifier(x) 应用线性分类器,得到最终的分类结果。
    这个模型可以用于文本分类任务,其中输入是文本序列的词嵌入表示。

3、训练模型


# 模型参数
d_model = 512
nhead = 8
num_encoder_layers = 3
dim_feedforward = 2048
num_classes = len(data.label.unique())  # 假设label_dict是我们的标签字典
max_len = 256model = TransformerClassifierWithPE( d_model=d_model, nhead=nhead, num_layers=num_encoder_layers, dim_feedforward=dim_feedforward, num_classes=num_classes, max_len=max_len,dropout=0.1)-----------------------------
TransformerModel((pos_encoder): PositionalEncoding()(transformer_encoder): TransformerEncoder((layers): ModuleList((0-2): 3 x TransformerEncoderLayer((self_attn): MultiheadAttention((out_proj): NonDynamicallyQuantizableLinear(in_features=512, out_features=512, bias=True))(linear1): Linear(in_features=512, out_features=2048, bias=True)(dropout): Dropout(p=0.1, inplace=False)(linear2): Linear(in_features=2048, out_features=512, bias=True)(norm1): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((512,), eps=1e-05, elementwise_affine=True)(dropout1): Dropout(p=0.1, inplace=False)(dropout2): Dropout(p=0.1, inplace=False))))(decoder): Linear(in_features=512, out_features=10, bias=True)
)
# 训练模型
num_epochs = 20
for epoch in range(num_epochs):for inputs, labels in train_loader:# 清除梯度optimizer.zero_grad()# 前向传播outputs = model(inputs)# 计算损失loss = criterion(outputs, labels)# 反向传播loss.backward()# 更新参数optimizer.step()print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}')
# 在测试集上评估模型
model.eval()
with torch.no_grad():correct = 0total = 0for inputs, labels in test_loader:outputs = model(inputs)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print(f'Accuracy of the model on the test set: {100 * correct / total}%')

相关文章:

深度学习PyTorch 之 transformer-中文多分类

transformer的原理部分在前面基本已经介绍完了,接下来就是代码部分,因为transformer可以做的任务有很多,文本的分类、时序预测、NER、文本生成、翻译等,其相关代码也会有些不同,所以会分别进行介绍 但是对于不同的任务…...

STC 51单片机烧录程序遇到一直检测单片机的问题

准备工作 一,需要一个USB-TTL的下载器 ,并安装好对应的驱动程序 二、对应的下载软件,stc软件需要官方的软件(最好是最新的,个人遇到旧的下载软件出现问题) 几种出现一直检测的原因 下载软件图标&#xf…...

后端系统开发之——接口参数校验

今天难得双更,大家点个关注捧个场 原文地址:后端系统开发之——接口参数校验 - Pleasure的博客 下面是正文内容: 前言 在上一篇文章中提到了接口的开发,虽然是完成了,但还是缺少一些细节——传入参数的校验。 即用户…...

IDEA 配置阿里规范检测

IDEA中安装插件 配置代码风格检查规范 使用代码风格检测 在代码类中,右键 然后会给出一些不符合规范的修改建议: 保存代码时自动格式化代码 安装插件: 配置插件:...

数据仓库系列总结

一、数据仓库架构 1、数据仓库的概念 数据仓库(Data Warehouse)是一个面向主题的、集成的、相对稳定的、反映历史变化的数据集合,用于支持管理决策。 数据仓库通常包含多个来源的数据,这些数据按照主题进行组织和存储&#x…...

gitlab runner没有内网的访问权限应该怎么解决

如果你的GitLab Runner没有内网访问权限,但你需要访问内部资源(如私有仓库或其他服务),你可以考虑以下几种方法: VPN 或 SSH 隧道: 在允许的情况下,通过VPN或SSH隧道连接到内部网络。这将允许Gi…...

el-tree 设置默认展开指定层级

el-tree默认关闭所有选项&#xff0c;但是有添加或者编辑删除的情况下&#xff0c;需要刷新接口&#xff0c;此时会又要关闭所有选项&#xff1b; 需求&#xff1a;在编辑时、添加、删除 需要将该内容默认展开 <el-tree :default-expanded-keys"expandedkeys":da…...

python便民超市管理系统flask-django-nodejs-php

随着人们生活节奏的加快&#xff0c;以前传统的购物方式发生了巨大的改变&#xff0c;以前一个超市要想经营好自己的门店&#xff0c;每天都要忙着记账出账&#xff0c;尤其是出库入库统计&#xff0c;如果忙中出乱&#xff0c;可能导致今天所有的营业流水&#xff0c;要重新换…...

HarmonyOS — BusinessError 不能被 JSON.stringify转换

在鸿蒙中BusinessError 继承于Error&#xff0c;而在JavaScript&#xff08;以及TypeScript&#xff0c;因为它是JavaScript的超集&#xff09;中&#xff0c;Error 对象包含一些不能被 JSON.stringify 直接序列化的属性。JSON.stringify 方法会将一个JavaScript对象或者值转换…...

JupyterNotebook 如何切换使用的虚拟环境kernel

在Jupyter Notebook中&#xff0c;如果需要修改使用的虚拟环境Kernel&#xff1a; 首先&#xff0c;需要确保虚拟环境已经安装conda上【conda基本操作】 打开Jupyter Notebook。 在Jupyter Notebook的顶部菜单中&#xff0c;选择 “New” 在弹出的窗口中&#xff0c;列出了…...

预防GPT-3和其他复杂语言模型中的“幻觉”

标题&#xff1a;预防GPT-3和其他复杂语言模型中的“幻觉” 正文&#xff1a; “假新闻”的一个显著特征是它经常在事实正确信息的环境中呈现虚假信息&#xff0c;通过一种文学渗透的方式&#xff0c;使不真实的数据获得感知权威&#xff0c;这是半真半假力量令人担忧的展示。…...

从源码解析AQS

前置概念 要彻底了解AQS的底层实现就必须要了解一下线程相关的知识。 包括voliatevoliate 我们使用翻译软件翻译一下volatile&#xff0c;会发现它有以下几个意思&#xff1a;易变的;无定性的;无常性的;可能急剧波动的;不稳定的;易恶化的;易挥发的;易发散的。这也正式使用vola…...

基于Spring Boot的云上水果超市的设计与实现

摘 要 伴随着我国社会的发展&#xff0c;人民生活质量日益提高。于是对云上水果超市进行规范而严格是十分有必要的&#xff0c;所以许许多多的信息管理系统应运而生。此时单靠人力应对这些事务就显得有些力不从心了。所以本论文将设计一套云上水果超市&#xff0c;帮助商家进行…...

游戏引擎中的动画基础

一、动画技术简介 视觉残留理论 - 影像在我们的视网膜上残留1/24s。 游戏中动画面临的挑战&#xff1a; 交互&#xff1a;游戏中的玩家动画需要和场景中的物体进行交互。实时&#xff1a;最慢需要在1/30秒内算完所有的场景渲染和动画数据。&#xff08;可以用动画压缩解决&am…...

springboot3快速入门案例2024最新版

前边 springboot3 系统要求 技术&工具版本&#xff08;or later&#xff09;maven3.6.3 or later 3.6.3 或更高版本Tomcat10.0Servlet9.0JDK17 SpringBoot的主要目标是&#xff1a; 为所有 Spring 开发提供更快速、可广泛访问的入门体验。开箱即用&#xff0c;设置合理的…...

软考 系统架构设计师系列知识点之系统性能(1)

所属章节&#xff1a; 第2章. 计算机系统基础知识 第9节. 系统性能 系统性能是一个系统提供给用户的所有性能指标的集合。它既包括硬件性能&#xff08;如处理器主频、存储器容量、通信带宽等&#xff09;和软件性能&#xff08;如上下文切换、延迟、执行时间等&#xff09;&a…...

Trent-FPGA硬件设计课程

本课程涵盖FPGA硬件设计的基础概念和实践应用。学生将学习Verilog语言编程、数字电路设计原理、FPGA架构和开发工具的使用。通过项目实践&#xff0c;掌握FPGA设计流程和调试技巧&#xff0c;为硬件加速和嵌入式系统开发打下坚实基础。 课程大小&#xff1a;4.3G 课程下载&am…...

【大模型学习记录】db-gpt源码安装问题汇总

1、首次源码安装时安装的其实dbgpt到conda环境中&#xff0c;会将路径一起安装。 如果有其他的路径使用同样的conda环境会报错&#xff0c;一直读取的就是原先的路径的内容。需要自己新创建一个conda env 2、界面中配置知识库问答时&#xff0c;报错 # 1、报的错如下&#x…...

QB PHP 多语言配置

1&#xff1a; 下载QBfast .exe 的文件 2&#xff1a; 安装的时候 &#xff0c;一定点击 仅为我 安装 而不是 所有人 3&#xff1a; 如果提示 更新就 更新 &#xff0c; 安装如2 4&#xff1a; 如果遇到 新增 或者编辑已经 配置的项目时 不起作用 &#xff1a; 右…...

Kubernetes实战(三十一)-使用开源CEPH作为后端StorageClass

1 引言 K8S在1.13版本开始支持使用Ceph作为StorageClass。其中云原生存储Rook和开源Ceph应用都非常广泛。本文主要介绍K8S如何对接开源Ceph使用RBD卷。 K8S对接Ceph的技术栈如下图所示。K8S主要通过容器存储接口CSI和Ceph进行交互。 Ceph官方文档&#xff1a;Block Devices a…...

【Python爬虫】详解BeautifulSoup()及其方法

文章目录 &#x1f354;准备工作&#x1f339;BeautifulSoup()⭐代码实现✨打印标签里面的内容✨快速拿到一个标签里的属性✨打印整个文档&#x1f386;获取特定标签的特定内容 &#x1f339;查找标签&#x1f388;在文档查找标签 find_all&#x1f388;正则表达式搜索 &#x…...

C语言经典算法-8

文章目录 其他经典例题跳转链接41.基数排序法42.循序搜寻法&#xff08;使用卫兵&#xff09;43.二分搜寻法&#xff08;搜寻原则的代表&#xff09;44.插补搜寻法45.费氏搜寻法 其他经典例题跳转链接 C语言经典算法-1 1.汉若塔 2. 费式数列 3. 巴斯卡三角形 4. 三色棋 5. 老鼠…...

Panasonic松下PLC如何数据采集?如何实现快速接入IIOT云平台?

在工业自动化领域&#xff0c;数据采集与远程控制是提升生产效率、优化资源配置的关键环节。对于使用Panasonic松下PLC的用户来说&#xff0c;如何实现高效、稳定的数据采集&#xff0c;并快速接入IIOT云平台&#xff0c;是摆在他们面前的重要课题。HiWoo Box工业物联网关以其强…...

高性能 MySQL 第四版(GPT 重译)(四)

第十一章&#xff1a;扩展 MySQL 在个人项目中运行 MySQL&#xff0c;甚至在年轻公司中运行 MySQL&#xff0c;与在市&#xfffd;&#xfffd;已经建立并且“呈现指数增长”业务中运行 MySQL 大不相同。在高速业务环境中&#xff0c;流量可能每年增长数倍&#xff0c;环境变得…...

整型数组按个位值排序 - 华为OD统一考试(C卷)

OD统一考试&#xff08;C卷&#xff09; 分值&#xff1a; 100分 题解&#xff1a; Java / Python / C 题目描述 给定一个非空数组(列表)&#xff0c;其元素数据类型为整型&#xff0c;请按照数组元素十进制最低位从小到大进行排序&#xff0c;十进制最低位相同的元素&#xf…...

【React】Diff算法

1. React15 Diff算法&#xff08;递归进行&#xff09; 一句话概括&#xff1a;新虚拟DOM和旧虚拟DOM对比&#xff0c;找出差异&#xff0c;根据差异更新真实DOM Diff过程描述&#xff1a; 1. 树比较(DOM) 同层节点之间相互比较&#xff0c;不会跨层级比较。&#xff08;当发现…...

【物联网】Modbus 协议及应用

Modbus 协议简介 QingHub设计器在设计物联网数据采集时不可避免的需要针对Modbus协议的设备做相关数据采集&#xff0c;这里就我们的实际项目经验分享Modbus协议 简介 Modbus由MODICON公司于1979年开发&#xff0c;是一种工业现场总线协议标准。1996年施耐德公司推出基于以太…...

Docker容器引擎

1、Docker是什么。 Docker是在Linux容器里运行应用的开源工具&#xff0c;是一种轻量级的"虚拟机"。Docker的logo设计为蓝色鲸鱼&#xff0c;拖着许多集装箱。鲸鱼可以看作宿主机&#xff0c;而集装箱可以理解为相互隔离的容器&#xff0c;每个集装箱中都包含自己的应…...

2.28线程

注意被抢占时是返回原队列&#xff0c;优先级不变。越往下优先级越小。往下没有优先级时&#xff0c;在最低的优先级队列里循环 到达了不一定会被服务&#xff0c;会进入就绪态进行等待 。核心等式就是周转时间运行时间等待时间&#xff0c;带权就是周转/运行&#xff0c; 随着…...

TCP/IP ⽹络模型

TCP/IP ⽹络模型 对于同⼀台设备上的进程间通信&#xff0c;有很多种⽅式&#xff0c;⽐如有管道、消息队列、共享内存、信号等⽅式&#xff0c;⽽对于不同设备上的进程间通信&#xff0c;就需要⽹络通信&#xff0c;⽽设备是多样性的&#xff0c;所以要兼容多种多样的设备&am…...

唐山网站建设公司/营销方式有哪几种

我见过java运行在手机上,包括很廉价的山寨手机,但是却暂时没发现.net在手机上有什么作为。wp7可能是个转机&#xff0c;但是按照《Java的跨平台就是一句谎言。 》&#xff08;http://www.cnblogs.com/hack/archive/2010/05/30/1747513.html&#xff09;的标准&#xff0c;那.ne…...

免费企业建站选哪家/互动营销公司

如今在线客服系统智能化水平正在不断提高&#xff0c;企业通过使用人工智能以及各类大数据统计等技术&#xff0c;能够帮助客服中心去快速解决大量复杂的客服工作&#xff0c;并且有效提高客服的工作效率。那么企业也会存在一些疑问&#xff0c;智能客服是如何去改善客户服务模…...

营销网站建设软件下载/百度网站提交

墨墨导读&#xff1a;本文来自墨天轮用户 肖杰 的投稿&#xff0c;介绍用BBED恢复删除数据的全过程。墨天轮主页&#xff1a;https://www.modb.pro/u/6722Oracle中delete行时&#xff0c;数据实际上并没有被删除。而是将行标记为已删除&#xff0c;并相应地调整空闲空间计数器和…...

unity做网站/百度网

首先说明一下过程&#xff1a;准备好35G以上的空间&#xff0c;否则你别打算预览这个2010。 第一步下载&#xff1a;这个VPC 2007镜像有11个文件&#xff0c;10个700MB的&#xff0c;一个286MB的。这里大概需要7.2G。 附上下载地址(想批量的朋友注意第一个文件是exe&#xff…...

吉野家网站谁做的/百度贴吧网页版入口

多态性&#xff08;Polymorphism&#xff09;一词最早用于生物学,指同一种族的生物体具有相同的特性。在C#中多态性的定义是&#xff1a;同一操作作用于不同的类的实例、不同的类将进行不同的解释、最后产生不同的执行结果。C#支持两种类型的多态性&#xff1a;编译时的多态性&…...

微信公司网站/微信小程序开发详细步骤

E 题意&#xff1a; 就是给你一个筛子&#xff0c;然后你最多可以晒n次&#xff0c;比如你第i次晒到x&#xff0c;你停止得分为x&#xff0c;要么继续抛。小A会尽量让自己的得分值最大。问你这个最大值的期望是多少。 思考&#xff1a; 这题有意思的地方就是&#xff0c;小A…...