当前位置: 首页 > news >正文

第N2周:中文文本分类-Pytorch实现

目录

  • 一、前言
  • 二、准备工作
  • 三、数据预处理
    • 1.加载数据
    • 2.构建词典
    • 3.生成数据批次和迭代器
  • 三、模型构建
    • 1. 搭建模型
    • 2. 初始化模型
    • 3. 定义训练与评估函数
  • 四、训练模型
    • 1. 拆分数据集并运行模型

一、前言

🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍖 原作者:K同学啊|接辅导、项目定制

● 难度:夯实基础⭐⭐
● 语言:Python3、Pytorch3
● 时间:4月23日-4月28日
🍺要求:
1、熟悉NLP的基础知识

二、准备工作

环境搭建
Python 3.8
pytorch == 1.8.1
torchtext == 0.9.1

三、数据预处理

1.加载数据

在这里插入图片描述

import torch
import torch.nn as nn
import os,PIL,pathlib,warningswarnings.filterwarnings("ignore")             #忽略警告信息# win10系统
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
import pandas as pd# 加载自定义中文数据
train_data = pd.read_csv('./data/train.csv', sep='\t', header=None)
train_data.head()
# 构造数据集迭代器
def coustom_data_iter(texts, labels):for x, y in zip(texts, labels):yield x, ytrain_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])

2.构建词典

from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
# conda install jieba -y
import jieba# 中文分词方法
tokenizer = jieba.lcutdef yield_tokens(data_iter):for text,_ in data_iter:yield tokenizer(text)vocab = build_vocab_from_iterator(yield_tokens(train_iter), specials=["<unk>"])
vocab.set_default_index(vocab["<unk>"]) # 设置默认索引,如果找不到单词,则会选择默认索引
vocab(['我','想','看','和平','精英','上','战神','必备','技巧','的','游戏','视频'])
label_name = list(set(train_data[1].values[:]))
print(label_name)
text_pipeline  = lambda x: vocab(tokenizer(x))
label_pipeline = lambda x: label_name.index(x)print(text_pipeline('我想看和平精英上战神必备技巧的游戏视频'))
print(label_pipeline('Video-Play'))

3.生成数据批次和迭代器

from torch.utils.data import DataLoaderdef collate_batch(batch):label_list, text_list, offsets = [], [], [0]for (_text,_label) in batch:# 标签列表label_list.append(label_pipeline(_label))# 文本列表processed_text = torch.tensor(text_pipeline(_text), dtype=torch.int64)text_list.append(processed_text)# 偏移量,即语句的总词汇量offsets.append(processed_text.size(0))label_list = torch.tensor(label_list, dtype=torch.int64)text_list  = torch.cat(text_list)offsets    = torch.tensor(offsets[:-1]).cumsum(dim=0) #返回维度dim中输入元素的累计和return text_list.to(device),label_list.to(device), offsets.to(device)# 数据加载器,调用示例
dataloader = DataLoader(train_iter,batch_size=8,shuffle   =False,collate_fn=collate_batch)

三、模型构建

1. 搭建模型

from torch import nnclass TextClassificationModel(nn.Module):def __init__(self, vocab_size, embed_dim, num_class):super(TextClassificationModel, self).__init__()self.embedding = nn.EmbeddingBag(vocab_size,   # 词典大小embed_dim,    # 嵌入的维度sparse=False) # self.fc = nn.Linear(embed_dim, num_class)self.init_weights()def init_weights(self):initrange = 0.5self.embedding.weight.data.uniform_(-initrange, initrange) # 初始化权重self.fc.weight.data.uniform_(-initrange, initrange)        self.fc.bias.data.zero_()                                  # 偏置值归零def forward(self, text, offsets):embedded = self.embedding(text, offsets)return self.fc(embedded)

2. 初始化模型

num_class  = len(label_name)
vocab_size = len(vocab)
em_size    = 64
model      = TextClassificationModel(vocab_size, em_size, num_class).to(device)

3. 定义训练与评估函数

import timedef train(dataloader):model.train()  # 切换为训练模式total_acc, train_loss, total_count = 0, 0, 0log_interval = 50start_time   = time.time()for idx, (text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)optimizer.zero_grad()                    # grad属性归零loss = criterion(predicted_label, label) # 计算网络输出和真实值之间的差距,label为真实值loss.backward()                          # 反向传播torch.nn.utils.clip_grad_norm_(model.parameters(), 0.1) # 梯度裁剪optimizer.step()  # 每一步自动更新# 记录acc与losstotal_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)if idx % log_interval == 0 and idx > 0:elapsed = time.time() - start_timeprint('| epoch {:1d} | {:4d}/{:4d} batches ''| train_acc {:4.3f} train_loss {:4.5f}'.format(epoch, idx, len(dataloader),total_acc/total_count, train_loss/total_count))total_acc, train_loss, total_count = 0, 0, 0start_time = time.time()def evaluate(dataloader):model.eval()  # 切换为测试模式total_acc, train_loss, total_count = 0, 0, 0with torch.no_grad():for idx, (text,label,offsets) in enumerate(dataloader):predicted_label = model(text, offsets)loss = criterion(predicted_label, label)  # 计算loss值# 记录测试数据total_acc   += (predicted_label.argmax(1) == label).sum().item()train_loss  += loss.item()total_count += label.size(0)return total_acc/total_count, train_loss/total_count

四、训练模型

1. 拆分数据集并运行模型

from torch.utils.data.dataset import random_split
from torchtext.data.functional import to_map_style_dataset
# 超参数
EPOCHS     = 10 # epoch
LR         = 5  # 学习率
BATCH_SIZE = 64 # batch size for trainingcriterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=LR)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gamma=0.1)
total_accu = None# 构建数据集
train_iter = coustom_data_iter(train_data[0].values[:], train_data[1].values[:])
train_dataset = to_map_style_dataset(train_iter)split_train_, split_valid_ = random_split(train_dataset,[int(len(train_dataset)*0.8),int(len(train_dataset)*0.2)])train_dataloader = DataLoader(split_train_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)valid_dataloader = DataLoader(split_valid_, batch_size=BATCH_SIZE,shuffle=True, collate_fn=collate_batch)for epoch in range(1, EPOCHS + 1):epoch_start_time = time.time()train(train_dataloader)val_acc, val_loss = evaluate(valid_dataloader)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']if total_accu is not None and total_accu > val_acc:scheduler.step()else:total_accu = val_accprint('-' * 69)print('| epoch {:1d} | time: {:4.2f}s | ''valid_acc {:4.3f} valid_loss {:4.3f} | lr {:4.6f}'.format(epoch,time.time() - epoch_start_time,val_acc,val_loss,lr))print('-' * 69)

相关文章:

第N2周:中文文本分类-Pytorch实现

目录 一、前言二、准备工作三、数据预处理1.加载数据2.构建词典3.生成数据批次和迭代器 三、模型构建1. 搭建模型2. 初始化模型3. 定义训练与评估函数 四、训练模型1. 拆分数据集并运行模型 一、前言 &#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客 …...

Salesforce许可证和版本有什么区别,购买帐号时应该如何选择?

Salesforce许可证分配给特定用户&#xff0c;授予他们访问Salesforce产品和功能的权限。Salesforce版本和许可证是不同的概念&#xff0c;但极易混淆。 Salesforce版本&#xff1a;这是对组织购买的Salesforce产品和功能的访问权限。大致可分为Essentials、Professional、Ente…...

接口测试怎么做?全网最详细从接口测试到接口自动化详解,看这篇就够了...

目录&#xff1a;导读 前言一、Python编程入门到精通二、接口自动化项目实战三、Web自动化项目实战四、App自动化项目实战五、一线大厂简历六、测试开发DevOps体系七、常用自动化测试工具八、JMeter性能测试九、总结&#xff08;尾部小惊喜&#xff09; 前言 抛出一个问题&…...

DataStore入门及在项目中的使用

首先给个官网的的地址&#xff1a;应用架构&#xff1a;数据层 - DataStore - Android 开发者 | Android Developers 小伙伴们可以直接看官网的资料&#xff0c;本篇文章是对官网的部分细节进行补充 一、为什么要使用DataStore 代替SharedPreferences SharedPreferences&a…...

用Python爬取中国各省GDP数据

介绍 在数据分析和经济研究中&#xff0c;了解中国各省份的GDP数据是非常重要的。然而&#xff0c;手动收集这些数据可能是一项繁琐且费时的任务。幸运的是&#xff0c;Python提供了一些强大的工具和库&#xff0c;使我们能够自动化地从互联网上爬取数据。本文将介绍如何使用P…...

深度学习-第T5周——运动鞋品牌识别

深度学习-第T5周——运动鞋品牌识别 深度学习-第T5周——运动鞋品牌识别一、前言二、我的环境三、前期工作1、导入数据集2、查看图片数目3、查看数据 四、数据预处理1、 加载数据1、设置图片格式2、划分训练集3、划分验证集4、查看标签 2、数据可视化3、检查数据4、配置数据集 …...

自媒体的孔雀效应:插根鸡毛还是专业才华?

自媒体时代&#xff0c;让许多原本默默无闻的人找到了表达自己的平台。有人声称&#xff0c;现在这个时代&#xff0c;“随便什么人身上插根鸡毛就可以当孔雀了”。可是&#xff0c;事实真的如此吗&#xff1f; 首先&#xff0c;我们不能否认的是&#xff0c;自媒体确实为大众提…...

Linux系统优化

一、系统启动流程 1.centos6 centos6开机启动流程&#xff0c;传送门 2.centos7启动流程 二、系统启动运行级别 2.1 什么是运行级别 运行级别&#xff1a;指操作系统当前正在运行的功能级别&#xff1b; [rootweb01 ~]# ll /usr/lib/systemd/system lrwxrwxrwx. 1 root root…...

Java笔记_22(反射和动态代理)

Java笔记_22 一、反射1.1、反射的概述1.2、获取class对象的三种方式1.3、反射获取构造方法1.4、反射获取成员变量1.5、反射获取成员方法1.6、综合练习1.6.1、保存信息1.6.2、跟配置文件结合动态创建 一、反射 1.1、反射的概述 什么是反射? 反射允许对成员变量&#xff0c;成…...

前端web入门-HTML-day01

(创作不易&#xff0c;感谢有你&#xff0c;你的支持&#xff0c;就是我前行的最大动力&#xff0c;如果看完对你有帮助&#xff0c;请留下您的足迹&#xff09; 目录 HTML初体验 HTML 定义 标签语法 总结&#xff1a; HTML 基本骨架 基础知识&#xff1a; 总结&#…...

创建一个Go项目

创建一个Go项目 1.创建项目 package mainfunc main() {println("你好啊&#xff0c;简单点了&#xff01;") }如果是本地的话可以采用go run 项目名的方式。 可以采用go run --work 项目名的方式&#xff0c;此时可以展示日志信息。 如果是只编译的话 go build 项…...

从 Spring 的创建到 Bean 对象的存储、读取

目录 创建 Spring 项目&#xff1a; 1.创建一个 Maven 项目&#xff1a; 2.添加 Spring 框架支持&#xff1a; 3.配置资源文件&#xff1a; 4.添加启动类&#xff1a; Bean 对象的使用&#xff1a; 1.存储 Bean 对象&#xff1a; 1.1 创建 Bean&#xff1a; 1.2 存储 B…...

【一文吃透归并排序】基本归并·原地归并·自然归并 C++

目录 1 引入情境基本归并排序实现 C 2 原地归并排序2-1 死板的解法2-2 原地工作区2-3 链表归并排序 3 自底向上归并排序4 两路自然归并排序4-1 形式化描述4-2 代码实现 1 引入情境 归并思想&#xff1a;假设有两队小孩&#xff0c;都是从矮到高排序&#xff0c;现在通过一扇门后…...

读《Spring Boot 3核心技术与最佳实践》有感

我是谁&#xff1f; &#x1f468;‍&#x1f393;作者&#xff1a;bug菌 ✏️博客&#xff1a;CSDN、掘金、infoQ、51CTO等 &#x1f389;简介&#xff1a;CSDN/阿里云/华为云/51CTO博客专家&#xff0c;C站历届博客之星Top50&#xff0c;掘金/InfoQ/51CTO等社区优质创作者&am…...

板子短路了?

有段时间没更新了&#xff0c;主要是最近有点忙&#xff0c;当然也因为有点“懒”。 做这行业的都知道&#xff0c;下半年都是比较忙的&#xff0c;相信大家也是&#xff01; 相信做硬件的小伙伴们&#xff0c;遇到过短路的板子已经不计其数了。 短路带来的危害&#xff1a;…...

一行代码绘制高分SCI限制立方图

一、概述 Restricted cubic splines (RCS)是一种基于样条函数的非参数化模型&#xff0c;它可以可靠地拟合非线性关系&#xff0c;可以自适应地调整分割结点。在统计学和机器学习领域&#xff0c;RCS通常用来对连续型自变量进行建模&#xff0c;并在解释自变量与响应变量的关系…...

spring 容器结构/机制debug分析--Spring 学习的核心内容和几个重要概念--IOC 的开发模式--综合解图

目录 Spring Spring 学习的核心内容 解读上图: Spring 几个重要概念 ● 传统的开发模式 解读上图 ● IOC 的开发模式 解读上图 代码示例—入门 xml代码 注意事项和细节 1、说明 2、解释一下类加载路径 3、debug 看看 spring 容器结构/机制 综合解图 Spring Spr…...

excel实战小测第四

【项目背景】 本项目为某招聘网站部分招聘信息&#xff0c;要求对“数据分析师”岗位进行招聘需求分析&#xff0c;通过对城市、行业、学历要求、薪资待遇等不同方向进行相关性分析&#xff0c;加深对数据分析行业的了解。 结合企业真实招聘信息&#xff0c;可以帮助有意转向数…...

什么是SpringBoot自动配置

概述&#xff1a; 现在的Java面试基本都会问到你知道什么是Springboot的自动配置。为什么面试官要问这样的问题&#xff0c;主要是在于看你有没有对Springboot的原理有没有深入的了解&#xff0c;有没有看过Springboot的源码&#xff0c;这是区别普通程序员与高级程序员最好的…...

基于IC5000烧录器使用winIDEA烧写+调试程序(S32K324的软件烧写与调试)

目录 一、iSYSTEM简介二、如何使用iSYSTEM winIDEA烧写调试程序2.1 打开winIDEA&#xff1a;2.2 新建一个Workspace;2.3 硬件配置:2.4 选择CPU芯片型号&#xff1a;2.5 加载烧写文件&#xff1a;2.6 开始烧录程序&#xff1a;2.7 程序调试Debug&#xff1a;2.7.1 运行程序&…...

新手开始学【网络安全】要怎么入门?

前言&#xff1a;网络安全如何从零开始学习&#xff0c;少走弯路&#xff1f; 目录&#xff1a; 一&#xff0c;怎么入门&#xff1f; 1、Web 安全相关概念&#xff08;2 周&#xff09;2、熟悉渗透相关工具&#xff08;3 周&#xff09;3、渗透实战操作&#xff08;5 周&…...

Linux指令 快捷键

热键 上一次我们说到了linux的基本指令&#xff0c;这次我们先说一下热键 TAB TAB键在linux中有什么作用呢&#xff1f;&#xff1f; 在Linux中&#xff0c;假设我们想要输入的指令忘记了&#xff0c;我们可以TAB两下&#xff0c;帮我们补全命令或者假如命令太多&#xff0…...

Testing and fault tolerence考试要点

文章目录 ATPGFault modelScanFunctional testMemory BISTLogic BISTboundary scanATEIddq testingFault tolerant designRisk analysis ATPG ATPG工作流程fault collapsing的原则 Fault model 有哪些fault model以及他们的工作原理 Scan Scan寄存器结构Scan Chain的连接方…...

记一次springboot项目漏洞挖掘

前言 前段时间的比赛将该cms作为了题目考察&#xff0c;这个cms的洞也被大佬们吃的差不多了&#xff0c;自己也就借此机会来浅浅测试下这个cms残余漏洞&#xff0c;并记录下这一整个流程&#xff0c;谨以此记给小白师傅们分享下思路&#xff0c;有错误的地方还望大佬们请以指正…...

R语言 | 数据框

目录 一、认识数据框 7.1 建立第一个数据框 7.2 验证与设定数据框的列名和行名 二、认识数据框的结构 三、获取数据框内容 3.1 一般获取 3.2 特殊字符$ 3.3 再看取得的数据 四、使用rbind()函数增加数据框的行数据 五、使用cbind()函数增加数据框的列数据 5.1 使用$符号…...

基于SpringBoot的招生管理系统的设计与实现

背景 本次设计任务是要设计一个招生管理系统&#xff0c;通过这个系统能够满足管理员和学生的招生公告管理功能。系统的主要功能包括首页、个人中心、学生管理、专业信息管理、专业报名管理、录取通知管理、系统管理等功能。 管理员可以根据系统给定的账号进行登录&#xff0…...

Oracle Profile详解

Profile的作用主要表现在三个方面 1、密码策略 2、对用户所能使用的资源进行管理 3、profile存放在数据字典里面&#xff0c;默认有一个名字为default的profile set linesize 160 set pagesize 30 select resource_name,resource_type,limit from dba_profiles where profile‘…...

r语言tidyverse教程:5 字符串处理stringr

文章目录 R语言系列&#xff1a; 编程基础&#x1f48e;循环语句&#x1f48e;向量、矩阵和数组&#x1f48e;列表、数据帧排序函数&#x1f48e;apply系列函数tidyverse&#xff1a;readr&#x1f48e;tibble&#x1f48e;tidyr&#x1f48e;dplyr&#x1f48e;stringr stri…...

知识变现海哥:知识变现的本质就是卖

知识变现的本质就是卖&#xff0c;而有人买的本质&#xff0c;就是你解决了某方面的需求。 好的成交&#xff0c;从来都是相互的&#xff0c; 只靠一边主动推销来维系是远远不够的。 绝对不是靠忽悠&#xff0c;而是靠实力。 先讲一个故事。 19世纪时&#xff0c;一个年轻的…...

jdbc和druid和mybatis之间的关系

第一种方式 jdbc整合了:加载数据库驱动,创建连接,写原生语句,执行,关闭这些东西. 第二种方式 mybatis对jdbc进行封装,他允许你通过配置的形式,配置数据库参数,并且允许你通过xml来写动态sql语句.if:test让你可以把sql变得灵活起来.并且还能将你的查询结果直接映射到你想要的…...