使用 BERT 和逻辑回归进行文本分类及示例验证
使用 BERT 和逻辑回归进行文本分类及示例验证
一、引言
在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。
二、环境准备
为了运行本文中的代码,你需要安装以下库:
pandas
:用于数据处理。sklearn
:包含机器学习算法。torch
:用于深度学习任务。transformers
:用于加载预训练语言模型。
三、代码实现
(一)读取数据集
首先,从 CSV 文件中读取数据集。假设该数据集包含两列,分别是content
(文本内容)和labels
(文本标签)。
import pandas as pd# 从 CSV 文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))
(二)分割数据集
接着,提取特征和目标,并将数据集分割为训练集和测试集。
# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))
(三)加载 BERT 模型和分词器
然后,加载 BERT 模型和分词器,以便将文本转化为特征向量。
import torch
from transformers import BertTokenizer, BertModel# 加载 BERT 模型和分词器
print("加载 BERT 模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')
(四)文本转化为特征向量
定义一个函数get_embeddings
,用于将文本转化为特征向量。该函数利用 BERT 模型对文本进行编码,然后获取[CLS]
标记的输出作为文本的特征向量。
# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()
(五)训练分类模型
使用逻辑回归算法作为分类模型。先将训练集转化为 BERT 特征,然后训练分类模型。
from sklearn.linear_model import LogisticRegression# 转换训练集和测试集为 BERT 特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000) # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")
(六)预测
使用训练好的分类模型对测试集进行预测,并打印预测结果。
# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)
(七)示例数据验证
最后,添加一些示例数据进行验证。将示例数据转化为 BERT 特征,然后使用分类模型进行预测,并打印预测结果。
# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为 BERT 特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")
四、总结
本文介绍了如何运用 BERT 和逻辑回归进行文本分类,并通过示例数据进行了验证。借助 BERT 模型学习到的文本上下文信息,能够显著提高文本分类的准确性。同时,逻辑回归算法的快速性使得我们可以高效地对大量文本进行分类。
五、完整代码
text_categorize_and_tag.py
import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
import torch
from transformers import BertTokenizer, BertModel# 从CSV文件读取数据集
print("正在读取数据集...")
df = pd.read_csv('training_data.csv', encoding='utf-8-sig')
print("数据集读取完成,共包含 {} 条数据.".format(len(df)))# 提取特征和目标
X = df['content']
y = df['labels']# 分割数据集
print("正在分割数据集...")
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)
print("训练集大小: {}, 测试集大小: {}".format(len(X_train), len(X_test)))# 加载BERT模型和分词器
print("加载BERT模型和分词器...")
tokenizer = BertTokenizer.from_pretrained('D:\\bert-base-chinese')
model = BertModel.from_pretrained('D:\\bert-base-chinese')# 文本转化为特征向量
def get_embeddings(texts):print("正在生成文本特征向量...")inputs = tokenizer(texts.tolist(), padding=True, truncation=True, return_tensors='pt')with torch.no_grad():outputs = model(**inputs)# 获取[CLS]标记的输出作为文本的特征向量return outputs.last_hidden_state[:, 0, :].numpy()# 转换训练集和测试集为BERT特征
X_train_bert = get_embeddings(X_train)
X_test_bert = get_embeddings(X_test)# 训练分类模型
print("正在训练分类模型...")
classifier = LogisticRegression(max_iter=1000) # 使用逻辑回归
classifier.fit(X_train_bert, y_train)
print("模型训练完成.")# 预测
print("正在进行预测...")
predictions = classifier.predict(X_test_bert)# 打印预测结果
print("预测结果:", predictions)# 添加示例数据进行验证
sample_texts = ["音乐有助力放松大脑,心情愉悦。","热爱生活,享受人生",
]# 将示例数据转换为BERT特征
print("正在对示例数据进行预测...")
sample_embeddings = get_embeddings(pd.Series(sample_texts))
sample_predictions = classifier.predict(sample_embeddings)# 打印示例数据预测结果
for text, prediction in zip(sample_texts, sample_predictions):print(f"文本: \"{text}\" 预测标签: {prediction}")
training_data.csv
content,labels
"Python 是一种广泛使用的高级编程语言。","编程"
"自然语言处理是人工智能领域的重要研究方向。","NLP"
"机器学习是分析数据的重要工具。","机器学习"
"数据科学结合了统计学和计算机科学。","数据科学"
"人工智能正在改变我们的生活方式。","人工智能"
"深度学习能够处理复杂的数据集。","机器学习"
"很多企业开始应用人工智能技术以提高效率。","人工智能"
"数据分析是理解客户行为的重要工具。","数据科学"
"编程不仅是技术,更是一种思维方式。","编程"
"算法在大数据时代发挥着重要作用。","数据科学"
"音乐可以影响人的情绪和认知。","音乐"
"学习音乐可以提高学生的创造力。","教育"
"现场音乐会可以提供独特的视听体验。","娱乐"
"教育科技正在变革传统的学习方式。","教育"
"学习一门乐器有助于提升专注力。","音乐"
"电影和电视节目是现代娱乐的重要部分。","娱乐"
"音乐治疗被广泛应用于心理健康。","音乐"
"在线教育平台为学习者提供灵活的选择。","教育"
"综艺节目为观众提供了丰富的娱乐内容。","娱乐"
"这是一篇关于机器学习的文章。","科技"
"我喜欢户外活动和旅游。","生活"
"COVID-19疫情对全球经济产生了深远的影响。","财经"
"人工智能正在改变我们的生活方式。","科技"
"旅游是一种能让人开阔视野的活动。","生活"
"金融科技让我们的投资变得更加智能。","财经"
"环境保护对我们的未来至关重要。","环保"
相关文章:
使用 BERT 和逻辑回归进行文本分类及示例验证
使用 BERT 和逻辑回归进行文本分类及示例验证 一、引言 在自然语言处理领域中,文本分类是一项至关重要的任务。本文将详细介绍如何结合 BERT 模型与逻辑回归算法来实现文本分类,并通过实际示例进行验证。 二、环境准备 为了运行本文中的代码…...
【skywalking 】监控 Spring Cloud Gateway 数据
使用Spring Cloud 开发,用Skywalking 监控服务,但是Skywalking 默认是不支持 Spring Cloud Gateway 网关服务的,需要手动将 Gateway 的插件添加到 Skywalking 启动依赖 jar 中。 skywalking相关版本信息 jdk:17skywalking&#x…...
SpringWeb
SpringWeb SpringWeb 概述 SpringWeb 是 spring 框架中的一个模块,基于 Servlet API 构建的 web 框架. springWeb 是 Spring 为 web 层开发提供的一整套完备的解决方案。 在 web 层框架历经 Strust1,WebWork,Strust2 等诸多产品的历代更…...
嵌入式刷题(day21)
MySQL和sqlite的区别 MySQL和SQLite是两种常见的关系型数据库管理系统(RDBMS),但它们在特性、使用场景和架构方面有显著的区别: 1. 架构 MySQL:是一个基于服务器的数据库系统,遵循客户端-服务器架构。MySQL服务器运行在主机上,客户端通过网络连接并发送查询。它可以并…...
OpenAI 下一代旗舰模型现身?奥尔特曼亲自辟谣“猎户座“传闻
在人工智能领域最受瞩目的ChatGPT即将迎来两周岁之际,一场关于OpenAI新旗舰模型的传闻再次引发业界热议。然而,这场喧嚣很快就被OpenAI掌门人奥尔特曼亲自澄清。 事件源于科技媒体The Verge的一则报道。据多位知情人士透露,OpenAI可能会在11…...
【C++】STL初识
【C】STL初识 文章目录 【C】STL初识前言一、STL基本概念二、STL六大组件简介三、STL三大组件四、初识STL总结 前言 本篇文章将讲到STL基本概念,STL六大组件简介,STL三大组件,初识STL。 一、STL基本概念 STL(Standard Template Library,标准…...
框架篇补充(东西多 需要重新看网课)
什么是AOP 面向切面编程 降低耦合 提高代码的复用 Spring的bean的生命周期 实例化bean 赋值 初始化bean 使用bean 销毁bean SpringMVC的执行流程 Springboot自动装配原理 实际上就是为了从spring.factories文件中 获取到对应的需要 进行自动装配的类 并生成相应的Bean…...
合约门合同全生命周期管理系统:企业合同管理的数字化转型之道
合约门合同全生命周期管理系统:企业合同管理的数字化转型之道 1. 引言 在现代企业中,合同管理已经不再是简单的文件存储和审批流程,而是企业合规性、风险管理和业务流程的关键环节之一。随着企业规模的扩大和合同数量的增加,传统…...
等保测评与风险管理:识别、评估和缓解潜在的安全威胁
在信息化时代,数据已成为企业最宝贵的资产之一,而信息安全则成为守护这份资产免受侵害的重中之重。等保测评(信息安全等级保护测评)作为保障信息系统安全的重要手段,其核心在于通过科学、规范、专业的评估手段…...
Golang Agent 可观测性的全面升级与新特性介绍
作者:张海彬(古琦) 背景 自 2024 年 6 月 26 日,ARMS 发布了针对 Golang 应用的可观测性监控功能以来,阿里云 ARMS 团队与程序语言与编译器团队一直致力于不断优化和提升该系统的各项功能,旨在为开发者提…...
SpringBoot的开篇 特点 初始化 ioc 配置文件
文章目录 前言SpringBoot发展历程SpringBoot前置准备SpringBoot特点 SpringBoot项目初始化项目启动Springboot的核心概念IOC概念介绍Bean对象通过注解扫描包 例子配置文件 前言 SpringBoot发展历程 最初,Spring框架的使用需要大量的XML配置,这使得开发…...
docker 可用镜像服务地址(2024.10.25亲测可用)
1.错误 Error response from daemon: Get “https://registry-1.docker.io/v2/” 原因:镜像服务器地址不可用。 2.可用地址 编辑daemon.json: vi /etc/docker/daemon.json内容修改如下: {"registry-mirrors": ["https://…...
【SQL实验】表的更新和简单查询
完整代码在文章末尾 在上次实验创建的educ数据库基础上,用SQL语句为student表、course表和sc表中添加以下记录 【SQL实验】数据库、表、模式的SQL语句操作_创建一个名为educ数据库,要求如下: (下面三个表中属性的数据类型需要自己设计合适-CSDN博客在这篇博文中已经…...
【C++】 string的了解及使用
标准库中的string类 在使用string类时,必须包含#include头文件以及using namespace std; string类的常用接口说明 C中string为我们提供了丰富的接口来供我们使用 – string接口文档 这里我们只介绍一些常见的接口 string类对象的常见构造 #include <iostrea…...
【K8S】kubernetes-dashboard.yaml
https://raw.githubusercontent.com/kubernetes/dashboard/v3.0.0-alpha0/charts/kubernetes-dashboard.yaml 以下链接的内容: 由于国内访问不了,找到一些方法下载了这个文件内容, 部署是mages 对象的镜像 WEB docker.io/kubernetesui/dash…...
远程root用户访问服务器中的MySQL8
一、Ubuntu下的MySQL8安装 在Ubuntu系统中安装MySQL 8.0可以通过以下步骤进行1. 更新包管理工具的仓库列表: sudo apt update 2. 安装MySQL 8.0,root用户默认没有密码: sudo apt install mysql-server sudo apt install mysql-client 【…...
解释一下 Java 中的静态变量(Static Variable)和静态方法(Static Method)?
今天来和大家深入探讨一下 Java 中的静态变量和静态方法,并通过一些具体的例子来理解它们在实际开发中的应用。 静态变量(Static Variable) 静态变量,也称为类变量,是在类的层次上共享的变量。这意味着无论创建了多少…...
【Linux】————磁盘与文件系统
作者主页: 作者主页 本篇博客专栏:Linux 创作时间 :2024年10月17日 一、磁盘的物理结构 磁盘的物理结构如图所示: 其中具体的物理存储结构如下: 磁盘中存储的基本单位为扇区,一个扇区的大小一般为512字…...
平衡控制——直立环——速度环
目录 平衡控制原理 平衡控制模型 平衡控制中基于模型设计与自动代码生成技术 速度环应用原理 速度控制模型 平衡控制原理 下图是一个单摆模型,对其进行受力分析如图。 在重力作用下,单摆受到和角度成正比,运动方向相反的回复力。而且在空气中运动的单摆,由于受…...
面试简要介绍hashMap
jdk8之前,hashmap采用的数据结构是数组链表,jdk8之后采用的数据结构是数组链表/红黑树。hashmap的数据以键值对的形式存在,如果两个元素的hash值相同,就会发生hash冲突,被放到同一个链表上--->如何解决hash冲突---&…...
HTTPS如何实现加密以及SSL/TSL加密的详细过程
通过将服务器从 HTTP 提升到 HTTPS 加密,数据在客户端和服务器之间的传输过程中的确得到了安全保护。以下是这种实现加密的机制以及客户端需要做的事情的详细说明。 为什么这样就实现了加密 SSL/TLS 协议: HTTPS 使用 SSL(安全套接层&#x…...
Golang | Leetcode Golang题解之第516题最长回文子序列
题目: 题解: func longestPalindromeSubseq(s string) int {n : len(s)dp : make([][]int, n)for i : range dp {dp[i] make([]int, n)}for i : n - 1; i > 0; i-- {dp[i][i] 1for j : i 1; j < n; j {if s[i] s[j] {dp[i][j] dp[i1][j-1] …...
(done) 什么 RPC 协议? remote procedure call 远程调用协议
来源:https://www.bilibili.com/video/BV1Qv4y127B4/?spm_id_from333.337.search-card.all.click&vd_source7a1a0bc74158c6993c7355c5490fc600 可以理解为,调用远程服务器上的一个方法/函数/服务的方式,同时隐藏网络细节 一个 python3 …...
PCL 基于Ransac提取误匹配点对
目录 一、概述 1.1原理 1.2实现步骤 1.3应用场景 二、代码实现 2.1关键函数 2.1.1 基于RANSAC的误匹配点对提出函数 2.1.2 点云可视化函数 2.2完整代码 三、实现效果 PCL点云算法汇总及实战案例汇总的目录地址链接: PCL点云算法与项目实战案例汇总(长期更新) 一、…...
光速写作 2.0.5 | 专注AI写作,海量素材库
光速写作是一款专为解决写作难题设计的应用。它具有以下功能:- 「AI写作」:帮助分析题目、整理写作思路,合成作文,写出好文章。- 「作文批改」:拍照上传作文后,进行全文点评和分句点评,并进行全…...
【已解决,含泪总结】非root权限在服务器上配置python和torch环境,代码最终成功训练(一)
配置Python环境 没有root权限服务器上有多个python环境但没有自己想要的怎么办 之前跑别的实验的时候改过指定的python3.7版本,但是居然我过了一段时间之后,再次打开,python版本居然又回到2.7(服务器/usr/下的默认python版本&am…...
公安基础知识-通哥
公安机关办理行政案件能力 考点一 治安案件追溯失效 6个月 派出所只有警告和500块以下罚款 公安是行政机关 1、治安小事、刑事案件大事 2、殴打他人-轻伤-(刑事案件)、轻微伤(治安案件) 3、《治安处罚法》《刑法》 4、只能构…...
Python画图|极坐标下的散点图动态输出
【1】引言 前序已经学习过散点图输出和极坐标图输出,文章链接包括但不限于下述部分: python画散点图|scatter()函数小试牛刀(入门级教程)_python ax.scatter-CSDN博客 python画图|极坐标中画散点图_极坐标上的散点图-CSDN博客 …...
揭开MySQL并发中的“死锁”之谜:从原理到解决方案的深度解析
目录 1. 环境准备:创建“账户”和“标记”表1.1 创建 dl_account_t 表1.2 创建 dl_mark_t 表 2. 死锁详解2.1 死锁情景一:相反加锁顺序导致的死锁2.2 死锁情景二:唯一索引冲突引发的死锁 3. 事务隔离级别与锁机制4. 预防与解决死锁的方法4.1 …...
【论文阅读】Reliable, Adaptable, and Attributable Language Models with Retrieval
文章目录 OverviewCurrent Retrieval-Augmented LMsArchitectureTraining Limitations & Future Work Overview Parametic language models的缺点: 事实性错误的普遍存在验证的难度(可溯源性差)难以在有顾虑的情况下排除某些序列适应调整…...
wordpress的开发者/自己建网站要多少钱
一.为什么学习节点操作 获取元素通常使用两种方式: 1.利用DOM提供的方法获取元素 document.getElementByld()document.getElementsByTagName()document.querySelector等逻辑性不强、繁琐 2.利用节点层级关系获取元素 利用父子兄节点关系获…...
ui做网站流程/网站提交
PublicClassTestClass Test Private_classid AsString <summary> 设置和获取分类ID </summary>PublicPropertyclassid() GetReturn_classid EndGetSet(ByValvalue) _classid value EndSetEnd PropertyEnd Class...
oa系统的概念/谷歌seo外包
添加多台压力机1、前置条件1)保证压力机上都安装了loadrunner Agent,并启动,状态栏中会有小卫星。2)添加的压力机与controller所在机器是否在同一个网段,建议关闭防火墙。在controller压力机上 ping 下连接压力机&…...
不用ftp可以做网站吗/深圳百度推广关键词推广
Next.js 教程阐述标签式导航 Router模块进行跳转阐述 学完如何编写组件和页面后,下一步应该了解的就是路由体系,每个框架都有着不同的路由体系,本文先学习最基础的页面如何跳转。 页面跳转一般有两种形式: 第一种是利用标签 <…...
同一个域名可以做几个网站吗/seo网站关键词排名提升
最近在做一个分类的任务,输入为3通道车型图片,输出要求将这些图片对车型进行分类,最后分类类别总共是30个。 开始是试用了实验室师姐的方法采用了VGGNet的模型对车型进行分类,据之前得实验结果是训练后最高能达到92%的正确率&…...
b2c电子商务网站的收益模式主要有/泉州排名推广
帮助您构建高质量的应用,是我们长期努力的一个方向。为此,我们经常寻找可以在工具和资源上投入精力的领域,这些工具和资源可以使您更加深刻地了解应用的性能。重大更新在 Android 11 上,我们引入了两个新工具——"数据访问审…...