当前位置: 首页 > news >正文

基于BERT的情感分析

基于BERT的情感分析

1. 项目背景

情感分析(Sentiment Analysis)是自然语言处理的重要应用之一,用于判断文本的情感倾向,如正面、负面或中性。随着深度学习的发展,预训练语言模型如BERT在各种自然语言处理任务中取得了显著的效果。本项目利用预训练语言模型BERT,构建一个能够对文本进行情感分类的模型。


2. 项目结构

sentiment-analysis/
├── data/
│   ├── train.csv        # 训练数据集
│   ├── test.csv         # 测试数据集
├── src/
│   ├── preprocess.py    # 数据预处理模块
│   ├── train.py         # 模型训练脚本
│   ├── evaluate.py      # 模型评估脚本
│   ├── inference.py     # 模型推理脚本
│   ├── utils.py         # 工具函数(可选)
├── models/
│   ├── bert_model.pt    # 保存的模型权重
├── logs/
│   ├── training.log     # 训练日志(可选)
├── README.md            # 项目说明文档
├── requirements.txt     # 依赖包列表
└── run.sh               # 一键运行脚本

3. 环境准备

3.1 系统要求

  • Python 3.6 或以上版本
  • GPU(可选,但建议使用以加速训练)

3.2 安装依赖

建议在虚拟环境中运行。安装所需的依赖包:

pip install -r requirements.txt

requirements.txt内容:

torch>=1.7.0
transformers>=4.0.0
pandas
scikit-learn
tqdm

4. 数据准备

4.1 数据格式

数据文件train.csvtest.csv的格式如下:

textlabel
I love this product.1
This is a bad movie.0
  • text:输入文本
  • label:目标标签,1为正面情感,0为负面情感

将数据文件保存至data/目录下。

4.2 数据集划分

可以使用train_test_split将数据划分为训练集和测试集。


5. 代码实现

5.1 数据预处理 (src/preprocess.py)

import pandas as pd
from transformers import BertTokenizer
from torch.utils.data import Dataset
import torchclass SentimentDataset(Dataset):"""自定义的用于情感分析的Dataset。"""def __init__(self, data_path, tokenizer, max_len=128):"""初始化Dataset。Args:data_path (str): 数据文件的路径。tokenizer (BertTokenizer): BERT的分词器。max_len (int): 最大序列长度。"""self.data = pd.read_csv(data_path)self.tokenizer = tokenizerself.max_len = max_lendef __len__(self):"""返回数据集的大小。"""return len(self.data)def __getitem__(self, idx):"""根据索引返回一条数据。Args:idx (int): 数据索引。Returns:dict: 包含input_ids、attention_mask和label的字典。"""text = str(self.data.iloc[idx]['text'])label = int(self.data.iloc[idx]['label'])encoding = self.tokenizer(text, padding='max_length', truncation=True, max_length=self.max_len, return_tensors="pt")return {'input_ids': encoding['input_ids'].squeeze(0),  # shape: [seq_len]'attention_mask': encoding['attention_mask'].squeeze(0),  # shape: [seq_len]'label': torch.tensor(label, dtype=torch.long)  # shape: []}

5.2 模型训练 (src/train.py)

import torch
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, AdamW, BertTokenizer, get_linear_schedule_with_warmup
from preprocess import SentimentDataset
import argparse
import os
from tqdm import tqdmdef train_model(data_path, model_save_path, batch_size=16, epochs=3, lr=2e-5, max_len=128):"""训练BERT情感分析模型。Args:data_path (str): 训练数据的路径。model_save_path (str): 模型保存的路径。batch_size (int): 批次大小。epochs (int): 训练轮数。lr (float): 学习率。max_len (int): 最大序列长度。"""# 初始化分词器和数据集tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')dataset = SentimentDataset(data_path, tokenizer, max_len=max_len)# 划分训练集和验证集train_size = int(0.8 * len(dataset))val_size = len(dataset) - train_sizetrain_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])# 数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)val_loader = DataLoader(val_dataset, batch_size=batch_size)# 初始化模型model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)# 优化器和学习率调度器optimizer = AdamW(model.parameters(), lr=lr)total_steps = len(train_loader) * epochsscheduler = get_linear_schedule_with_warmup(optimizer, num_warmup_steps=0, num_training_steps=total_steps)# 设备设置device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 训练循环for epoch in range(epochs):model.train()total_loss = 0progress_bar = tqdm(train_loader, desc=f"Epoch {epoch + 1}/{epochs}")for batch in progress_bar:optimizer.zero_grad()input_ids = batch['input_ids'].to(device)  # shape: [batch_size, seq_len]attention_mask = batch['attention_mask'].to(device)  # shape: [batch_size, seq_len]labels = batch['label'].to(device)  # shape: [batch_size]outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.lossloss.backward()optimizer.step()scheduler.step()total_loss += loss.item()progress_bar.set_postfix(loss=loss.item())avg_train_loss = total_loss / len(train_loader)print(f"Epoch {epoch + 1}/{epochs}, Average Loss: {avg_train_loss:.4f}")# 验证模型model.eval()val_loss = 0correct = 0total = 0with torch.no_grad():for batch in val_loader:input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids, attention_mask=attention_mask, labels=labels)loss = outputs.losslogits = outputs.logitsval_loss += loss.item()preds = torch.argmax(logits, dim=1)correct += (preds == labels).sum().item()total += labels.size(0)avg_val_loss = val_loss / len(val_loader)val_accuracy = correct / totalprint(f"Validation Loss: {avg_val_loss:.4f}, Accuracy: {val_accuracy:.4f}")# 保存模型os.makedirs(os.path.dirname(model_save_path), exist_ok=True)torch.save(model.state_dict(), model_save_path)print(f"Model saved to {model_save_path}")if __name__ == "__main__":parser = argparse.ArgumentParser(description="Train BERT model for sentiment analysis")parser.add_argument('--data_path', type=str, default='data/train.csv', help='Path to training data')parser.add_argument('--model_save_path', type=str, default='models/bert_model.pt', help='Path to save the trained model')parser.add_argument('--batch_size', type=int, default=16, help='Batch size')parser.add_argument('--epochs', type=int, default=3, help='Number of training epochs')parser.add_argument('--lr', type=float, default=2e-5, help='Learning rate')parser.add_argument('--max_len', type=int, default=128, help='Maximum sequence length')args = parser.parse_args()train_model(data_path=args.data_path,model_save_path=args.model_save_path,batch_size=args.batch_size,epochs=args.epochs,lr=args.lr,max_len=args.max_len)

5.3 模型评估 (src/evaluate.py)

import torch
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
from preprocess import SentimentDataset
from torch.utils.data import DataLoader
from transformers import BertForSequenceClassification, BertTokenizer
import argparse
from tqdm import tqdmdef evaluate_model(data_path, model_path, batch_size=16, max_len=128):"""评估BERT情感分析模型。Args:data_path (str): 测试数据的路径。model_path (str): 训练好的模型的路径。batch_size (int): 批次大小。max_len (int): 最大序列长度。"""# 初始化分词器和数据集tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')dataset = SentimentDataset(data_path, tokenizer, max_len=max_len)loader = DataLoader(dataset, batch_size=batch_size)# 加载模型model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))model.eval()# 设备设置device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)all_preds = []all_labels = []with torch.no_grad():for batch in tqdm(loader, desc="Evaluating"):input_ids = batch['input_ids'].to(device)attention_mask = batch['attention_mask'].to(device)labels = batch['label'].to(device)outputs = model(input_ids, attention_mask=attention_mask)logits = outputs.logitspreds = torch.argmax(logits, dim=1)all_preds.extend(preds.cpu().numpy())all_labels.extend(labels.cpu().numpy())accuracy = accuracy_score(all_labels, all_preds)precision, recall, f1, _ = precision_recall_fscore_support(all_labels, all_preds, average='binary')print(f"Accuracy: {accuracy:.4f}")print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1-score: {f1:.4f}")if __name__ == "__main__":parser = argparse.ArgumentParser(description="Evaluate BERT model for sentiment analysis")parser.add_argument('--data_path', type=str, default='data/test.csv', help='Path to test data')parser.add_argument('--model_path', type=str, default='models/bert_model.pt', help='Path to the trained model')parser.add_argument('--batch_size', type=int, default=16, help='Batch size')parser.add_argument('--max_len', type=int, default=128, help='Maximum sequence length')args = parser.parse_args()evaluate_model(data_path=args.data_path,model_path=args.model_path,batch_size=args.batch_size,max_len=args.max_len)

5.4 推理 (src/inference.py)

import torch
from transformers import BertTokenizer, BertForSequenceClassification
import argparsedef predict_sentiment(text, model_path, max_len=128):"""对输入的文本进行情感预测。Args:text (str): 输入的文本。model_path (str): 训练好的模型的路径。max_len (int): 最大序列长度。Returns:str: 预测的情感类别。"""# 初始化分词器和模型tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')model = BertForSequenceClassification.from_pretrained('bert-base-uncased', num_labels=2)model.load_state_dict(torch.load(model_path, map_location=torch.device('cpu')))model.eval()# 设备设置device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model.to(device)# 数据预处理inputs = tokenizer(text, return_tensors="pt", truncation=True, padding='max_length', max_length=max_len)inputs = {key: value.to(device) for key, value in inputs.items()}# 模型推理with torch.no_grad():outputs = model(**inputs)logits = outputs.logitsprediction = torch.argmax(logits, dim=1).item()sentiment = "Positive" if prediction == 1 else "Negative"return sentimentif __name__ == "__main__":parser = argparse.ArgumentParser(description="Inference script for sentiment analysis")parser.add_argument('--text', type=str, required=True, help='Input text for sentiment prediction')parser.add_argument('--model_path', type=str, default='models/bert_model.pt', help='Path to the trained model')parser.add_argument('--max_len', type=int, default=128, help='Maximum sequence length')args = parser.parse_args()sentiment = predict_sentiment(text=args.text,model_path=args.model_path,max_len=args.max_len)print(f"Input Text: {args.text}")print(f"Predicted Sentiment: {sentiment}")

6. 项目运行

6.1 一键运行脚本 (run.sh)

#!/bin/bash# 训练模型
python src/train.py --data_path=data/train.csv --model_save_path=models/bert_model.pt# 评估模型
python src/evaluate.py --data_path=data/test.csv --model_path=models/bert_model.pt# 推理示例
python src/inference.py --text="I love this movie!" --model_path=models/bert_model.pt

6.2 单独运行

6.2.1 训练模型
python src/train.py --data_path=data/train.csv --model_save_path=models/bert_model.pt --epochs=3 --batch_size=16
6.2.2 评估模型
python src/evaluate.py --data_path=data/test.csv --model_path=models/bert_model.pt
6.2.3 模型推理
python src/inference.py --text="This product is great!" --model_path=models/bert_model.pt

7. 结果展示

7.1 训练结果

  • 损失下降曲线:可以使用matplotlibtensorboard绘制训练过程中的损失变化。
  • 训练日志:在logs/training.log中记录训练过程。

7.2 模型评估

  • 准确率(Accuracy):模型在测试集上的准确率。
  • 精确率、召回率、F1-score:更全面地评估模型性能。

7.3 推理示例

示例:

python src/inference.py --text="I absolutely love this!" --model_path=models/bert_model.pt

输出:

Input Text: I absolutely love this!
Predicted Sentiment: Positive

8. 注意事项

  • 模型保存与加载:确保模型保存和加载时的路径正确,特别是在使用相对路径时。
  • 设备兼容性:代码中已考虑CPU和GPU的兼容性,确保设备上安装了相应的PyTorch版本。
  • 依赖版本:依赖的库版本可能会影响代码运行,建议使用requirements.txt中指定的版本。

9. 参考资料

  • BERT论文
  • Hugging Face Transformers文档
  • PyTorch官方文档

相关文章:

基于BERT的情感分析

基于BERT的情感分析 1. 项目背景 情感分析(Sentiment Analysis)是自然语言处理的重要应用之一,用于判断文本的情感倾向,如正面、负面或中性。随着深度学习的发展,预训练语言模型如BERT在各种自然语言处理任务中取得了…...

推荐15个2024最新精选wordpress模板

以下是推荐的15个2024年最新精选WordPress模板,轻量级且SEO优化良好,适合需要高性能网站的用户。中文wordpress模板适合搭建企业官网使用。英文wordpress模板,适合B2C网站搭建,功能强大且兼容性好,是许多专业外贸网站的…...

AWTK-WIDGET-WEB-VIEW 实现笔记 (2) - Windows

在 Windows 平台上的实现,相对比较顺利,将一个窗口嵌入到另外一个窗口是比较容易的事情。 1. 创建窗口 这里有点需要注意: 父窗口的大小变化时,子窗口也要跟着变化,否则 webview 显示不出来。创建时窗口的大小先设置…...

Linux四剑客及正则表达式

正则表达式 基础正则(使用四剑客命令时无需加任何参数即可使用) ^ # 匹配以某一内容开头 如:^grep匹配所有以grep开头的行。 $ # 匹配以某一内容结尾 如:grep$ 匹配所有以grep结尾的行。 ^$ # 匹配空行。 . # 匹配…...

ALS 推荐算法案例演示(python)

数学知识补充:矩阵 总结来说: Am*k X Bk*n Cm*n ----至于乘法的规则,是数学问题, 知道可以乘即可,不需要我们自己计算 反过来 Cm*n Am*k X Bk*n ----至于矩阵如何拆分/如何分解,是数学问题,知道可以拆/可以分解即可 ALS 推荐算法案例:电影推…...

labview中连接sql server数据库查询语句

当使用数据库查询功能时,我们需要用到数据库的查询语句,这里已调用sql server为例,我们需要按照时间来查询,这里在正常调用数据库查询语句时,我们需要在前面给他加一个限制条件这里用到了,数据库的查询语句…...

leetcode_二叉树最大深度

对二叉树的理解 对递归调用的理解 对内存分配的理解 基础数据结构(C版本) - 飞书云文档 每次函数的调用 都会进行一次新的栈内存分配 所以lmax和rmax的值不会混在一起 /*** Definition for a binary tree node.* struct TreeNode {* int val;* …...

Elasticsearch 重建索引 数据迁移

Elasticsearch 重建索引 数据迁移 处理流程创建临时索引数据迁移重建索引写在最后 大家都知道,es的索引创建完成之后就不可以再修改了,包括你想更改字段属性或者是分词方式等。那么随着业务数据量的发展,可能会出现需要修改索引,或…...

2411rust,异步函数

原文 Rust异步工作组很高兴地宣布,在实现在特征中使用异步 fn的目标方面取得了重大进度.将在下周发布稳定的Rust1.75版,会包括特征中支持impl Trait注解和async fn. 稳定化 自从RFC#1522在Rust1.26中稳定下来以来,Rust就允许用户按函数的返回类型(一般叫"RPIT")编…...

前端网络性能优化问题

DNS预解析 DNS 解析也是需要时间的&#xff0c;可以通过预解析的⽅式来预先获得域名所对应的 IP。 <link rel"dns-prefetch" href"//abcd.cn"> 缓存 强缓存 在缓存期间不需要请求&#xff0c; state code 为 200 可以通过两种响应头实现&#…...

优选算法——双指针

前言 本篇博客为大家介绍双指针问题&#xff0c;它属于优选算法中的一种&#xff0c;也是一种很经典的算法&#xff1b;算法部分的学习对我们来说至关重要&#xff0c;它可以让我们积累解题思路&#xff0c;同时也可以大大提升我们的编程能力&#xff0c;本文主要是通过一些题…...

【Rabbitmq篇】RabbitMQ⾼级特性----消息确认

目录 前言&#xff1a; 一.消息确认机制 • ⾃动确认 • ⼿动确认 手动确认方法又分为三种&#xff1a; 二. 代码实现&#xff08;spring环境&#xff09; 配置相关信息&#xff1a; 1&#xff09;. AcknowledgeMode.NONE 2 &#xff09;AcknowledgeMode.AUTO 3&…...

开源TTS语音克隆神器GPT-SoVITS_V2版本地整合包部署与远程使用生成音频

文章目录 前言1.GPT-SoVITS V2下载2.本地运行GPT-SoVITS V23.简单使用演示4.安装内网穿透工具4.1 创建远程连接公网地址 5. 固定远程访问公网地址 前言 本文主要介绍如何在Windows系统电脑使用整合包一键部署开源TTS语音克隆神器GPT-SoVITS&#xff0c;并结合cpolar内网穿透工…...

【idea】更换快捷键

因为个人习惯问题需要把快捷键替换一下。我喜欢用CTRLD删除一下&#xff0c;用CTRLY复制一样。恰好这两个快捷键需要互换一下。 打开file——>setting——>Keymap——>Edit Actions 找到CTRLY并且把它删除 找到CTRLD 并且把它删除 鼠标右键添加CTRLY 同样操作在Delet…...

最小的子数组(leetcode 209)

给定一个正整数数组&#xff0c;找到大于等于s的连续的最小长度的区间。 解法一&#xff1a;暴力解法 两层for循环&#xff0c;一个区间终止位置&#xff0c;一个区间起始位置&#xff0c;找到大于等于s的最小区间长度&#xff08;超时了&#xff09; 解法二&#xff1a;双指…...

IDEA-Plugins无法下载插件(网络连接问题-HTTP Proxy Settings)

IDEA-Plugins无法下载插件&#xff08;网络连接问题&#xff09; 改成如下配置&#xff1a; 勾选 添这个url即可&#xff1a;https://plugins.jetbrains.com/ 重启插件中心&#xff0c;问题解决。...

AWTK-WIDGET-WEB-VIEW 发布

awtk-widget-web-view 是通过 webview 提供的接口&#xff0c;实现的 AWTK 自定义控件&#xff0c;使得 AWTK 可以方便的显示 web 页面。 项目网址&#xff1a; https://gitee.com/zlgopen/awtk-widget-web-view webview 提供了一个跨平台的 webview 接口&#xff0c;是一个非…...

Mysql每日一题(if函数)

两种写法if()和case if()函数 select *,if(T.xT.y>T.z and T.xT.z>T.y and T.yT.z>T.x,Yes,No) as triangle from Triangle as T; case方法 select *, case when T.xT.y>T.z and T.xT.z>T.y and T.yT.z>T.x then Yes else No end as triangle from Trian…...

Spring Cloud Alibaba [Gateway]网关。

1 简介 网关作为流量的入口&#xff0c;常用功能包括路由转发、权限校验、限流控制等。而springcloudgateway 作为SpringCloud 官方推出的第二代网关框架&#xff0c;取代了Zuul网关。 1.1 SpringCloudGateway特点: &#xff08;1&#xff09;基于Spring5&#xff0c;支持响应…...

【动手学深度学习Pytorch】2. Softmax回归代码

零实现 导入所需要的包&#xff1a; import torch from IPython import display from d2l import torch as d2l定义数据集参数、模型参数&#xff1a; batch_size 256 # 每次随机读取256张图片 train_iter, test_iter d2l.load_data_fashion_mnist(batch_size) # 将展平每个…...

技术周总结 11.11~11.17 周日(Js JVM XML)

文章目录 一、11.11 周一1.1&#xff09;问题01&#xff1a;js中的prompt弹窗区分出来用户点击的是 确认还是取消进一步示例 1.2&#xff09;问题02&#xff1a;在 prompt弹窗弹出时默认给弹窗中写入一些内容 二、11.12 周二2.1) 问题02: 详解JVM中的本地方法栈本地方法栈的主要…...

MATLAB 使用教程 —— 矩阵和数组

矩阵和数组MATLAB 中矩阵和数组长什么样&#xff1f;MATLAB 怎么用矩阵计算&#xff1f;创建和操作矩阵矩阵运算示例串联 访问矩阵的元素 矩阵和数组 MATLAB 是“matrix laboratory”的缩写形式。MATLAB 主要用于处理 整个的矩阵和数组&#xff0c;而其他编程语言大多逐个处理…...

React教程第二节之虚拟DOM与Diffing算法理解

1、什么是虚拟DOM 虚拟DOM 是javascript的一个对象&#xff0c;是内存中的一种数据结构&#xff0c;以树的形式存储UI的状态&#xff0c;树中的每个节点都代表着真实的DOM&#xff0c;用来描述我们希望在页面看到的 HTML结构&#xff1b; 现在的MVVM 框架&#xff0c;大多使用…...

C++——类和对象(part2)

前言 本篇博客继续为大家介绍类与对象的知识&#xff0c;承接part1的内容&#xff0c;本篇内容是类与对象的核心内容&#xff0c;稍微有些复杂&#xff0c;如果你对其感兴趣&#xff0c;请继续阅读&#xff0c;下面进入正文部分。 1. 类的默认成员函数 默认成员函数就是用户…...

【FFmpeg系列】:音频处理

前言 在多媒体处理领域&#xff0c;FFmpeg无疑是一个不可或缺的利器。它功能强大且高度灵活&#xff0c;能够轻松应对各种音频和视频处理任务&#xff0c;无论是简单的格式转换&#xff0c;还是复杂的音频编辑&#xff0c;都不在话下。然而&#xff0c;要想真正发挥FFmpeg的潜…...

Python绘制雪花

文章目录 系列目录写在前面技术需求完整代码代码分析1. 代码初始化部分分析2. 雪花绘制核心逻辑分析3. 窗口保持部分分析4. 美学与几何特点总结 写在后面 系列目录 序号直达链接爱心系列1Python制作一个无法拒绝的表白界面2Python满屏飘字表白代码3Python无限弹窗满屏表白代码4…...

vue3 如何调用第三方npm包内部的 pinia 状态管理库方法

抛砖引玉: 如果在开发vue3项目是, 引用了npm第三方包 ,而且这个包内使用了Pinia 状态管理库,那我们如何去调用 npm内部的 Pinia 状态管理库呢? 实际遇到的问题: 今天在制作npm包时遇到的问题,之前Vue2版本的时候状态管理库用的Vuex ,当时调用npm包内的状态管理库很简单,直接引…...

uni-app快速入门(七)--组件路由跳转和API路由跳转及参数传递

uni-app有两种页面路由跳转模式&#xff0c;即使用navigator组件跳转和调用API跳转&#xff0c;API调转不要理解为调用后台接口的API&#xff0c;而是指脚本函数中使用跳转函数。 一、组件路由跳转 1.1 打开新页面 打开新页面使用组件的open-type"navigate",见下面…...

Flink升级程序和版本

Flink DataStream程序通常设计为长时间运行,如几周、几个月甚至几年。与所有长时间运行的服务一样,Flink streaming应用程序也需要维护,包括修复错误、实现改进或将应用程序迁移到更高版本的Flink集群。 这里就来描述下如何更新Flink streaming应用程序,以及如何将正在运行…...

从0安装mysql server

安装 MySQL Server 首先,你需要在 Ubuntu 上安装 MySQL 服务器。运行以下命令来安装:sudo apt update sudo apt install mysql-server安装完成后,MySQL 服务会自动启动。你可以通过以下命令检查 MySQL 服务是否正在运行: sudo systemctl status mysql如果 MySQL 正在运行,…...