当前位置: 首页 > news >正文

pytorch bert实现文本分类

以imdb公开数据集为例,bert模型可以在huggingface上自行挑选

1.导入必要的库

import os
import torch
from torch.utils.data import DataLoader, TensorDataset, random_split
from transformers import BertTokenizer, BertModel, BertConfig
from torch import nn
from torch.optim import AdamW
import numpy as np
from sklearn.metrics import accuracy_score
import pandas as pd
from tqdm import tqdmdevice = torch.device("cuda:0")
print(device)

2.加载和预处理数据:读取数据,将其转换为适合BERT的格式,并将评分映射到三个类别。

import random
def load_imdb_dataset_and_create_multiclass_labels(path_to_data, split="train"):print(f"load start: {split}")reviews = []labels = []  # 0 for low, 1 for medium, 2 for highfor label in ["pos", "neg"]:labeled_path = os.path.join(path_to_data, split, label)for file in os.listdir(labeled_path):if file.endswith('.txt'):with open(os.path.join(labeled_path, file), 'r', encoding='utf-8') as f:reviews.append(f.read())if label == "neg":# Randomly assign negative reviews to low or mediumlabels.append(random.choice([0, 1]))  else:labels.append(2)  # Assign positive reviews to highreturn reviews[:1000], labels[:1000]
#加载数据集
train_texts, train_labels = load_imdb_dataset_and_create_multiclass_labels("./data/aclImdb", split="train")
test_texts, test_labels = load_imdb_dataset_and_create_multiclass_labels("./data/aclImdb", split="test")
print("load okk")
#样本数量
print("train_texts: ",len(train_texts))
print("test_texts: ",len(test_texts))

3.文本转换为BERT的输入格式

tokenizer = BertTokenizer.from_pretrained('./bert_pretrain')def encode_texts(tokenizer, texts, max_len=512):input_ids = []attention_masks = []for text in texts:encoded = tokenizer.encode_plus(text,add_special_tokens=True,max_length=max_len,pad_to_max_length=True,return_attention_mask=True,return_tensors='pt',)input_ids.append(encoded['input_ids'])attention_masks.append(encoded['attention_mask'])return torch.cat(input_ids, dim=0), torch.cat(attention_masks, dim=0)train_inputs, train_masks = encode_texts(tokenizer, train_texts)
test_inputs, test_masks = encode_texts(tokenizer, test_texts)
print("input transfromer encode done")

4.创建TensorDataset和DataLoader

train_labels = torch.tensor(train_labels)
test_labels = torch.tensor(test_labels)train_dataset = TensorDataset(train_inputs, train_masks, train_labels)
test_dataset = TensorDataset(test_inputs, test_masks, test_labels)# Split the dataset into train and validation sets
train_size = int(0.9 * len(train_dataset))
val_size = len(train_dataset) - train_size
train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])train_dataloader = DataLoader(train_dataset, batch_size=128, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=128, shuffle=False)
test_dataloader = DataLoader(test_dataset, batch_size=128, shuffle=False)

5.构建模型:使用BERT进行多分类任务

class BertForMultiLabelClassification(nn.Module):def __init__(self):super(BertForMultiLabelClassification, self).__init__()self.bert = BertModel.from_pretrained('./bert_pretrain')self.dropout = nn.Dropout(0.1)self.classifier = nn.Linear(self.bert.config.hidden_size, 3)  # 3类def forward(self, input_ids, attention_mask):_, pooled_output = self.bert(input_ids=input_ids, attention_mask=attention_mask, return_dict=False)pooled_output = self.dropout(pooled_output)return self.classifier(pooled_output)

6.训练和评估模型

# 初始化模型、优化器和损失函数
model = BertForMultiLabelClassification()
# 使用多GPU
# if MULTI_GPU:
#     model = nn.DataParallel(model)
model.to(device)optimizer = AdamW(model.parameters(), lr=2e-5)
loss_fn = nn.CrossEntropyLoss()# 训练函数
def train(model, dataloader, optimizer, loss_fn, device):model.train()total_loss = 0for batch in dataloader:batch = tuple(b.to(device) for b in batch)inputs, masks, labels = batchoptimizer.zero_grad()outputs = model(input_ids=inputs, attention_mask=masks)loss = loss_fn(outputs, labels)total_loss += loss.item()loss.backward()optimizer.step()average_loss = total_loss / len(dataloader)return average_loss# 评估函数
def evaluate(model, dataloader, loss_fn, device):model.eval()total_loss = 0predictions, true_labels = [], []with torch.no_grad():for batch in dataloader:batch = tuple(b.to(device) for b in batch)inputs, masks, labels = batchoutputs = model(input_ids=inputs, attention_mask=masks)loss = loss_fn(outputs, labels)total_loss += loss.item()logits = outputs.detach().cpu().numpy()label_ids = labels.to('cpu').numpy()predictions.append(logits)true_labels.append(label_ids)average_loss = total_loss / len(dataloader)flat_predictions = np.concatenate(predictions, axis=0)flat_predictions = np.argmax(flat_predictions, axis=1).flatten()flat_true_labels = np.concatenate(true_labels, axis=0)accuracy = accuracy_score(flat_true_labels, flat_predictions)return average_loss, accuracy# 训练和评估循环
for epoch in range(3):  # 假设训练3个周期train_loss = train(model, train_dataloader, optimizer, loss_fn, device)val_loss, val_accuracy = evaluate(model, val_dataloader, loss_fn, device)print(f"Epoch {epoch+1}")print(f"Train Loss: {train_loss:.3f}")print(f"Validation Loss: {val_loss:.3f}, Accuracy: {val_accuracy:.3f}")# 在测试集上评估模型性能
test_loss, test_accuracy = evaluate(model, test_dataloader, loss_fn, device)
print(f"Test Loss: {test_loss:.3f}, Accuracy: {test_accuracy:.3f}")
#保存模型
torch.save(model.state_dict(), "./model/bert_multiclass_imdb_model.pt")

7.模型预测

from transformers import BertModel
import torchdef predict(texts, model, tokenizer, device, max_len=128):# 将文本编码为BERT的输入格式def encode_texts(tokenizer, texts, max_len):input_ids = []attention_masks = []for text in texts:encoded = tokenizer.encode_plus(text,add_special_tokens=True,max_length=max_len,pad_to_max_length=True,return_attention_mask=True,return_tensors='pt',)input_ids.append(encoded['input_ids'])attention_masks.append(encoded['attention_mask'])return torch.cat(input_ids, dim=0), torch.cat(attention_masks, dim=0)model.eval()  # 将模型设置为评估模式predictions = []input_ids, attention_masks = encode_texts(tokenizer, texts, max_len)input_ids = input_ids.to(device)attention_masks = attention_masks.to(device)with torch.no_grad():outputs = model(input_ids, attention_mask=attention_masks)logits = outputs.detach().cpu().numpy()predictions = np.argmax(logits, axis=1)return predictions# 示例文本
texts = ["I very like the movie", "the movie is so bad"]# 调用预测函数# 初始化模型
device = torch.device("cuda:0")
model = BertForMultiLabelClassification()
model.to(device)# 加载模型状态
model.load_state_dict(torch.load('./model/bert_multiclass_imdb_model.pt'))# 将模型设置为评估模式
model.eval()# 加载tokenizer
tokenizer = BertTokenizer.from_pretrained('./bert_pretrain')predictions = predict(texts, model, tokenizer, device)# 输出预测结果
for text, pred in zip(texts, predictions):print(f"Text: {text}, Predicted category: {pred}")

相关文章:

pytorch bert实现文本分类

以imdb公开数据集为例,bert模型可以在huggingface上自行挑选 1.导入必要的库 import os import torch from torch.utils.data import DataLoader, TensorDataset, random_split from transformers import BertTokenizer, BertModel, BertConfig from torch import…...

《开箱元宇宙》:Madballs 解锁炫酷新境界,人物化身系列大卖

你是否曾想过,元宇宙是如何融入世界上最具代表性的品牌和名人的战略中的?在本期的《开箱元宇宙》 系列中,我们与 Madballs 的战略顾问 Derek Roberto 一起聊聊 Madballs 如何在 90 分钟内售罄 2,000 个人物化身系列,以及是什么原…...

4K-Resolution Photo Exposure Correction at 125 FPS with ~8K Parameters

MSLTNet开源 | 4K分辨率125FPS8K的参数量,怎养才可以拒绝这样的模型呢? 错误的曝光照片的校正已经被广泛使用深度卷积神经网络或Transformer进行广泛修正。尽管这些方法具有令人鼓舞的表现,但它们通常在高分辨率照片上具有大量的参数数量和沉…...

网络初识:局域网广域网网络通信基础

提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 前言一、局域网LAN是什么?二、广域网是什么:三. IP地址四.端口号五.认识协议5.1五元组 总结 前言 一、局域网LAN是什么? 局域网…...

JVM之jps虚拟机进程状态工具

jps虚拟机进程状态工具 1、jps jps:(JVM Process Status Tool),虚拟机进程状态工具,可以列出正在运行的虚拟机进程,并显示虚拟机执 行主类(Main Class,main()函数所在的类)的名称&#xff0c…...

C++实现顺序栈的基本操作(扩展)

#include <stdio.h> typedef char ElemType; #define StackSize 100 /*顺序栈的初始分配空间*/ typedef struct { ElemType data[StackSize]; /*保存栈中元素*/int top; /*栈顶指针*/ } SqStack; void InitStack(SqStack &st) {st.top-1; } …...

用python写一个简单的爬虫

爬虫是一种自动化程序&#xff0c;用于从互联网上获取数据。它能够模拟人类浏览网页的行为&#xff0c;访问网页并提取所需的信息。爬虫在很多领域都有广泛的应用&#xff0c;例如数据采集、信息监控、搜索引擎索引等。 下面是一个使用Python编写的简单爬虫示例&#xff1a; …...

分布式追踪

目录 文章目录 目录自定义指标1.删除标签2.添加指标3.禁用指标 分布式追踪上下文传递Jaeger 关于我最后最后 自定义指标 除了 Istio 自带的指标外&#xff0c;我们还可以自定义指标&#xff0c;要自定指标需要用到 Istio 提供的 Telemetry API&#xff0c;该 API 能够灵活地配…...

make -c VS make -f

make 是一个用于构建&#xff08;编译&#xff09;项目的工具&#xff0c;它通过读取一个名为 Makefile 的文件来执行构建任务。make 命令有很多选项和参数&#xff0c;其中包括 -c 和 -f。 make -c&#xff1a; 作用&#xff1a;指定进入指定的目录并执行相应的 Makefile。 示…...

Unity 代码控制Color无变化

Unity中&#xff0c;我们给Color的赋值比较常用的方法是&#xff1a; 1、使用预定义颜色常量&#xff1a; Color color Color.white; //白色 Color color Color.black; //黑色 Color color Color.red; //红色 Color color Color.green; //绿色 Color color Color.blue; …...

【Erlang进阶学习】2、匿名函数

受到其它一些函数式编程开发语言的影响&#xff0c;在Erlang语言中&#xff0c;将函数作为一个对象&#xff0c;赋予其“变量”的属性&#xff0c;即为我们的匿名函数 或 简称 fun&#xff0c;它具有以下特性&#xff1a; &#xff08;匿名函数&#xff1a;不是定义在Erlang模…...

肖sir__mysql之视图__009

mysql之视图 一、什么是视图 视图是一个虚拟表&#xff08;逻辑表&#xff09;&#xff0c;它不在数据库中以存储形式保存&#xff08;本身包含数据&#xff09;&#xff0c;是在使用视图的时候动态生成。 二、视图作用 1、查询数据库中的非常复的数据 例如&#xff1a;多表&a…...

FPGA falsh相关知识总结

1.存储容量是128M/8 Mb16MB 2.有256个sector扇区*每个扇区64KB16MB 3.一页256Byte 4.页编程地址0256 5&#xff1a;在调试SPI时序的时候一定注意&#xff0c;miso和mosi两个管脚只要没发送数据就一定要悬空&#xff08;处于高组态&#xff09;&#xff0c;不然指令会通过两…...

升辉清洁IPO:广东清洁服务“一哥”还需要讲好全国化的故事

近日&#xff0c;广东物业清洁服务“一哥”升辉清洁第四次冲击IPO成功&#xff0c;拟于12月5日在香港主板挂牌上市。自2021年4月第一次递交招股书&#xff0c;时隔两年半&#xff0c;升辉清洁终于拿到了上市的门票。 天眼查显示&#xff0c;升辉清洁成立于2000年&#xff0c;主…...

Python自动化办公:PDF文件的分割与合并

我们平时办公中&#xff0c;可能需要对pdf进行合并或者分割&#xff0c;但奈何没有可以白嫖的工具&#xff0c;此时python就是一个万能工具库。 其中PyPDF2是一个用于处理PDF文件的Python库&#xff0c;它提供了分割和合并PDF文件的功能。 在本篇博客中&#xff0c;我们将详细…...

破解app思路

1.会看smali代码逻辑 一.快速定位关键代码 1.分析流程 搜索特征字符串 搜索关键 api 通过方法名来判断方法的功能 2.快速定位关键代码 反编译 APK 程序 AndroidManifest.xml>包名/系统版本/组件 程序的主 activity(程序入口界面) 每个 Android 程序…...

背景特效插件:Background Effects

...

36.位运算符

一.什么是位运算符 按照二进制位来进行运算的运算符叫做位运算符&#xff0c;所以要先将操作数转换成二进制&#xff08;补码&#xff09;的形式在运算。C语言的中的位运算符有&#xff1a; 运算符作用举例结果& 按位与&#xff08;and&#xff09; 0&00; 0&10; …...

C#异常处理-throw语句

throw语句是我们手动引发异常的一个语句。 在程序执行过程中&#xff0c;当某些条件不符合我们的要求时&#xff0c;那么我们就可以使用throw语句手动抛出异常&#xff0c;那么就可以在异常发生的地方终止当前代码块的执行&#xff0c;此时我们就可以把控制权传递给调用堆栈中…...

PlantUML语法(全)及使用教程-时序图

目录 1. 参与者1.1、参与者说明1.2、背景色1.3、参与者顺序 2. 消息和箭头2.1、 文本对其方式2.2、响应信息显示在箭头下面2.3、箭头设置2.4、修改箭头颜色2.5、对消息排序 3. 页面标题、眉角、页脚4. 分割页面5. 生命线6. 填充区设置7. 注释8. 移除脚注9. 组合信息9.1、alt/el…...

231204 刷题日报

21. 合并两个有序链表 单调栈没看懂&#xff0c;晚上回家再说吧 380. O(1) 时间插入、删除和获取随机元素 今天被接雨水钉在耻辱柱&#xff0c;找时间再看吧...

PTA 7-229 sdut-C语言实验- 排序

给你N(N<100)个数,请你按照从小到大的顺序输出。 输入格式: 输入数据第一行是一个正整数N,第二行有N个整数。 输出格式: 输出一行&#xff0c;从小到大输出这N个数&#xff0c;中间用空格隔开。 输入样例: 5 1 4 3 2 5输出样例: 1 2 3 4 5 #include <stdio.h>…...

原生横向滚动条 吸附 页面底部

效果图 /** 横向滚动条 吸附 页面底部 */ export class StickyHorizontalScrollBar {constructor(options {}) {const { el, style } optionsthis.createScrollbar(style)this.insertScrollbar(el)this.setScrollbarSize()this.onEvent()}/** 创建滚轴组件元素 */createS…...

1+x网络系统建设与运维(中级)-练习3

一.设备命名 AR1 [Huawei]sysn AR1 [AR1] 同理可得&#xff0c;所有设备的命名如上图所示 二.VLAN LSW1 [LSW1]vlan 10 [LSW1-vlan10]q [LSW1]int g0/0/1 [LSW1-GigabitEthernet0/0/1]port link-type access [LSW1-GigabitEthernet0/0/1]port default vlan 10 [LSW1-GigabitEt…...

知识图谱07——图片中表格开源ocr识别

对比了多种ocr识别算法&#xff0c;最终选择了百度paddle官方的ocr算法 在所在的虚拟环境下运行 pip install paddleocr --userfrom paddleocr import PaddleOCR import os import csv# 创建 PaddleOCR 对象 ocr PaddleOCR(use_gpuTrue) # 无gpu时选择False# 指定图片文件夹…...

每日一练2023.12.4——正整数【PTA】

一时间网上一片求救声&#xff0c;急问这个怎么破。其实这段代码很简单&#xff0c;index数组就是arr数组的下标&#xff0c;index[0]2 对应 arr[2]1&#xff0c;index[1]0 对应 arr[0]8&#xff0c;index[2]3 对应 arr[3]0&#xff0c;以此类推…… 很容易得到电话号码是18013…...

golang之net/http模块学习

文章目录 开启服务开启访问静态文件获取现在时间按时间创建一个空的json文件按时间创建一个固定值的json文件 跨域请求处理输出是json 开启服务 package mainimport ("fmt""net/http" )//路由 func handler(w http.ResponseWriter, r *http.Request){fmt.…...

Python中format函数用法

嗨喽~大家好呀&#xff0c;这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 format优点 format是python2.6新增的一个格式化字符串的方法&#xff0c;相对于老版的%格式方法&#xff0c;它有很多优点。 1.不需要理会数据类型的问题&#…...

Android 断点调试

Android 调试 https://developer.android.google.cn/studio/debug?hlzh-cn 调试自己写的代码&#xff08;不在Android源码&#xff09; 点击 Attach debugger to Android process 图标 需要在添加断点界面手动输入函数名 但也可以不手动&#xff0c;有个技巧可以new 空proje…...

对抗神经网络 CGAN实战详解 完整数据代码可直接运行

代码视频讲解: 中文核心项目:对抗神经网络 CGAN实战详解 完整代码数据可直接运行_哔哩哔哩_bilibili 运行图: 完整代码: from keras.layers import Input, Dense, Reshape, Flatten, Dropout, multiply from keras.layers import BatchNormalization, Activation, Embedd…...

网站策划书格式/郑州客串seo

前言 cloudflare 是一家国外的 CDN 加速服务商&#xff0c;还是很有名气的。提供免费和付费的加速和网站保护服务。以前推荐过的百度云加速的国外节点就是和 cloudflare 合作使用的 cloudflare 的节点。 cloudflare 提供了不同类型的套餐&#xff0c;即使是免费用户&#xff0c…...

django网站开发源码/持啊传媒企业推广

推荐博客&#xff1a;付铭 day-01 HTML 1、HTML 基本语法 html标签 单标签 <img /> 、<img> 双标签 <html> </html> 属性 属于标签 <img src"图片的地址"><table width"100" height"200"></table> 1…...

特产网站设计/seo网站推广软件排名

http请求介绍 HTTP(HyperText Transfer Protocol)是一套计算机通过网络进行通信的规则。计算机专家设计出HTTP&#xff0c;使HTTP客户&#xff08;如Web浏览器&#xff09;能够从HTTP服务器(Web服务器)请求信息和服务&#xff0c;HTTP目前协议的版本是1.1.HTTP是一种无状态的协…...

网页设计实训报告记录和结果分析/桔子seo

QQ概念版&#xff0c;触摸是王道 转载于:https://www.cnblogs.com/nuddle/archive/2010/05/06/1728535.html...

做一个网站的基本步骤/一个新品牌如何推广

Q1. 环境预准备 绝大多数MON创建的失败都是由于防火墙没有关导致的&#xff0c;亦或是SeLinux没关闭导致的。一定一定一定要关闭每个每个每个节点的防火墙(执行一次就好&#xff0c;没安装报错就忽视)&#xff1a; CentOS sed -i s/SELINUX.*/SELINUXdisabled/ /etc/selinux…...

成都网站设计开发公司/百度搜索引擎广告位的投放

Ubuntu 3D目录了解 Ubuntu历史与发展过程特色Ubuntu的3D桌面开发意念安装设定从硬盘安装其他特色软件组件缺省套件私有版权软件的采用发布周期正式衍生版本非正式衍生版本中文译名各界评价免费获取Ubuntu CD软件组件缺省套件私有版权软件的采用发布周期正式衍生版本非正式衍生版…...