当前位置: 首页 > news >正文

从0开始深度学习(33)——循环神经网络的简洁实现

本章使用Pytorch的API实现RNN上的语言模型训练

0 导入库

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from collections import Counter
import re
import math
from tqdm import tqdm

1 准备数据

需要对文本进行预处理,比如转换为小写、去除标点符号等,以减少词汇量并简化问题,然后构建词汇表,即创建一个字符到索引的映射和一个索引到字符的映射,最后将将文本转换为整数序列,这些整数代表词汇表中的位置。

# 1. 加载数据
def load_data(file_path):with open(file_path, 'r') as f:lines = f.readlines()text = ''.join([line.strip().lower() for line in lines])# 使用正则表达式去除标点符号和数字text = re.sub(r'[^\w\s]', '', text)  # 去除标点符号text = re.sub(r'\d+', '', text)      # 去除数字return text# 2. 文本预处理
def preprocess_text(text):tokens = list(text)  # 将文本切分为字符vocab = sorted(set(tokens))  # 构建词表token_to_idx = {token: idx for idx, token in enumerate(vocab)}  # 词元到索引的映射idx_to_token = {idx: token for token, idx in token_to_idx.items()}  # 索引到词元的映射token_indices = [token_to_idx[token] for token in tokens]  # 把文本转化为索引列表return token_indices, token_to_idx, idx_to_token, vocab

2 创建数据集

从文本中提取固定长度的子序列作为输入,并将紧随其后的字符作为目标输出,最后将这些序列转换为适合输入到RNN模型的张量格式

# 数据集类
class TextDataset(Dataset):def __init__(self, token_indices, seq_len):self.data = token_indicesself.seq_len = seq_lendef __len__(self):return len(self.data) - self.seq_lendef __getitem__(self, idx):# 输入数据是从当前位置到指定序列长度的位置的数据,即一个序列x = self.data[idx:idx + self.seq_len]# 目标数据是输入数据的下一个位置的数据,即单个字符y = self.data[idx + 1:idx + self.seq_len + 1]return torch.tensor(x, dtype=torch.long), torch.tensor(y, dtype=torch.long)# 转化为Tensor

3 构建RNN模型

使用Pytorch构建RNN模型

class SimpleRNN(nn.Module):def __init__(self, vocab_size, hidden_size):super(SimpleRNN, self).__init__()self.hidden_size = hidden_size # 隐藏层形状self.rnn = nn.RNN(vocab_size, hidden_size, batch_first=True)'''vocab_size:特征的数量,即词汇表的大小hidden_size:隐藏层的状态向量的维度batch_first:决定了输入和输出张量的形状如果batch_first=True,输入和输出张量的形状将是(batch_size,sequence_length, input_size)。如果batch_first=False,输入和输出张量的形状将是 (sequence_length, batch_size, input_size)。'''self.fc = nn.Linear(hidden_size, vocab_size)def forward(self, x, hidden=None):out, hidden = self.rnn(x, hidden)  # RNN层out = self.fc(out)  # 全连接层return out, hidden

4 训练模型

在训练前,需要把数据转化为one-hot编码,以增强特征属性,添加困惑度作为评价指标,使用早停法提前结束训练,避免过拟合

# 4. 训练模型
def train_model(model, dataloader, val_dataloader, criterion, vocab_size, optimizer, device, num_epochs=100, patience=5, min_delta=0.001):assert vocab_size is not None, "vocab_size must be provided"model.to(device)  # 将模型移动到指定设备model.train()  # 设置模型为训练模式best_val_loss = float('inf')epochs_no_improve = 0for epoch in range(num_epochs):total_loss = 0# 训练阶段with tqdm(dataloader, desc=f'Epoch {epoch+1}/{num_epochs} (Training)', unit='batch') as tepoch:for inputs, targets in tepoch:# 将数据移动到指定设备inputs, targets = inputs.to(device), targets.to(device)  # 将输入数据转换为 one-hot 编码inputs_one_hot = F.one_hot(inputs, num_classes=vocab_size).float()# 清零梯度optimizer.zero_grad()  # 前向传播outputs, _ = model(inputs_one_hot)# 计算损失loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))# 反向传播和优化loss.backward()optimizer.step()total_loss += loss.item()tepoch.set_postfix(loss=loss.item())average_loss = total_loss / len(dataloader)perplexity = math.exp(average_loss)  # 计算困惑度print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {average_loss:.4f}, Perplexity: {perplexity:.4f}')# 验证阶段model.eval()val_loss = 0with torch.no_grad():with tqdm(val_dataloader, desc=f'Epoch {epoch+1}/{num_epochs} (Validation)', unit='batch') as tepoch:for inputs, targets in tepoch:inputs, targets = inputs.to(device), targets.to(device)inputs_one_hot = F.one_hot(inputs, num_classes=vocab_size).float()outputs, _ = model(inputs_one_hot)loss = criterion(outputs.view(-1, vocab_size), targets.view(-1))val_loss += loss.item()tepoch.set_postfix(loss=loss.item())average_val_loss = val_loss / len(val_dataloader)print(f'Validation Loss: {average_val_loss:.4f}')# 检查是否需要早停if average_val_loss < best_val_loss - min_delta:best_val_loss = average_val_lossepochs_no_improve = 0else:epochs_no_improve += 1if epochs_no_improve >= patience:print(f'Early stopping at epoch {epoch+1}')breakmodel.train()  # 回到训练模式

5 预测模型

我们的输入必须大于seq_len,不然就不符合输入格式(可以使用补全,这里不展开),对于单词或者句子,需要把他们分割为字符,然后转换为token序列,作为输入

def predict(model, token_to_idx, idx_to_token, start_text, length, device, unk_token='<UNK>'):model.to(device)model.eval()# 将起始文本转换为字符 token 序列input_tokens = []for char in start_text:if char in token_to_idx:input_tokens.append(token_to_idx[char])else:if unk_token in token_to_idx:input_tokens.append(token_to_idx[unk_token])  # 使用 <UNK> 表示未知字符else:raise ValueError(f"Character '{char}' not in vocabulary and no '<UNK>' token provided.")# 转换为 PyTorch Tensorinput_tensor = torch.tensor(input_tokens, dtype=torch.long).unsqueeze(0).to(device)generated_tokens = []with torch.no_grad():hidden = Nonefor i in range(length):# 将输入数据转换为 one-hot 编码inputs_one_hot = F.one_hot(input_tensor, num_classes=len(token_to_idx)).float()# 前向传播outputs, hidden = model(inputs_one_hot, hidden)# 获取最后一个时间步的输出output = outputs[0, -1, :]# 获取最大概率的 token_, top_index = output.topk(1)predicted_token = idx_to_token[top_index.item()]# 添加预测的 token 到生成的序列中generated_tokens.append(predicted_token)# 更新输入 tensorinput_tensor = torch.tensor([[top_index.item()]], dtype=torch.long).to(device)# 将生成的字符序列拼接成字符串generated_text = ''.join(generated_tokens)return start_text + generated_text

6 主函数

# 读取数据
file_path = '/home/caser/code/data/timemachine.txt'
text = load_data(file_path)
# 预处理数据
token_indices, token_to_idx, idx_to_token, vocab=preprocess_text(text)# 参数设置
seq_len = 5
batch_size = 64
hidden_size = 128
learning_rate = 0.01
num_epochs = 100
patience = 5  # 早停法的耐心值
min_delta = 0.001  # 早停法的最小改进阈值# 创建数据集和数据加载器
dataset = TextDataset(token_indices, seq_len)
train_size = int(0.9 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dataloader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)# 初始化模型和优化器
vocab_size = len(vocab)
model = SimpleRNN(vocab_size, hidden_size)
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)# 选择设备
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 训练模型
train_model(model, train_dataloader, val_dataloader, criterion, vocab_size, optimizer, device, num_epochs, patience, min_delta)# 进行预测
start_text = 'the time traveller '
predicted_text = predict(model, token_to_idx, idx_to_token, start_text, length=50, device=device)
print(predicted_text)

运行结果:
在这里插入图片描述

相关文章:

从0开始深度学习(33)——循环神经网络的简洁实现

本章使用Pytorch的API实现RNN上的语言模型训练 0 导入库 import torch import torch.nn as nn import torch.nn.functional as F from torch.utils.data import Dataset, DataLoader from collections import Counter import re import math from tqdm import tqdm1 准备数据 …...

【FAQ】HarmonyOS SDK 闭源开放能力 — 公共模块

1.问题描述&#xff1a; 文档哪里能找到所有的权限查看该权限是用户级的还是系统级的。 解决方案&#xff1a; 您好&#xff0c;可以看一下下方链接是否可以解决问题&#xff1a; https://developer.huawei.com/consumer/cn/doc/harmonyos-guides-V5/permissions-for-all-V…...

百度 文心一言 vs 阿里 通义千问 哪个好?

背景介绍&#xff1a; 在当前的人工智能领域&#xff0c;随着大模型技术的快速发展&#xff0c;市场上涌现出了众多的大规模语言模型。然而&#xff0c;由于缺乏统一且权威的评估标准&#xff0c;很多关于这些模型能力的文章往往基于主观测试或自行设定的排行榜来评价模型性能…...

内网不出网上线cs

一:本地正向代理目标 如下&#xff0c;本地(10.211.55.2)挂好了基于 reGeorg 的 http 正向代理。代理为: Socks5 10.211.55.2 1080python2 reGeorgSocksProxy.py -l 0.0.0.0 -p 1080 -u http://10.211.55.3:8080/shiro/tunnel.jsp 二&#xff1a;虚拟机配置proxifer 我们是…...

ubuntu22开机自动登陆和开机自动运行google浏览器自动打开网页

一、开机自动登陆 1、打开settings->点击Users 重启系统即可自动登陆桌面 二、开机自动运行google浏览器自动打开网页 1、安装google浏览器 sudo wget https://dl.google.com/linux/direct/google-chrome-stable_current_amd64.deb sudo dpkg -i ./google-chrome-stable…...

企业建站高性能的内容管理系统

AnQiCMS 是一款高性能的内容管理系统&#xff0c;基于Go语言开发。它支持多站点、多语言管理&#xff0c;提供灵活的内容发布和模板管理功能&#xff0c;同时&#xff0c;系统内置丰富的利于SEO操作的功能&#xff0c;支持包括自定义字段、文档分类、批量导入导出等功能 AnQiC…...

【爬虫框架:feapder,管理系统 feaplat】

github&#xff1a;https://github.com/Boris-code/feapder 爬虫管理系统 feaplat&#xff1a;http://feapder.com/#/feapder_platform/feaplat 爬虫在线工具库 &#xff1a;http://www.spidertools.cn &#xff1a;https://www.kgtools.cn/1、feapder 简介 对于学习 Python…...

faiss库中ivf-sq(ScalarQuantizer,标量量化)代码解读-5

训练过程 通过gdb调试得到这个ivfsq的训练过程&#xff0c;我尝试对这个内容具体训练过程进行解析&#xff0c;对每个调用栈里面的逻辑和代码进行解读。 步骤函数名称调用位置说明1faiss::IndexIVF::train/faiss/IndexIVF.cpp:1143开始训练&#xff0c;判断是否需要训练第一级…...

代码随想录算法训练营第六十天|Day60 图论

Bellman_ford 队列优化算法&#xff08;又名SPFA&#xff09; https://www.programmercarl.com/kamacoder/0094.%E5%9F%8E%E5%B8%82%E9%97%B4%E8%B4%A7%E7%89%A9%E8%BF%90%E8%BE%93I-SPFA.html 本题我们来系统讲解 Bellman_ford 队列优化算法 &#xff0c;也叫SPFA算法&#xf…...

在嵌入式Linux下如何用QT开发UI

在嵌入式 Linux 环境下使用 Qt 开发用户界面 (UI) 是一个常见的选择。Qt 提供了丰富的功能、跨平台支持以及优秀的图形界面开发能力&#xff0c;非常适合用于嵌入式系统。以下是开发流程的详细步骤&#xff1a; 1. 准备开发环境 硬件环境 一块运行嵌入式 Linux 的开发板&…...

【JavaScript】Promise详解

Promise 是 JavaScript 中处理异步操作的一种强大机制。它提供了一种更清晰、更可控的方式来处理异步代码&#xff0c;避免了回调地狱&#xff08;callback hell&#xff09;和复杂的错误处理。 基本概念 状态&#xff1a; Pending&#xff1a;初始状态&#xff0c;既不是成功…...

1062 Talent and Virtue

About 900 years ago, a Chinese philosopher Sima Guang wrote a history book in which he talked about peoples talent and virtue. According to his theory, a man being outstanding in both talent and virtue must be a "sage&#xff08;圣人&#xff09;"…...

C++《二叉搜索树》

在初阶数据结构中我学习了树基础的概念以及了解了顺序结构的二叉树——堆和链式结构二叉树该如何实现&#xff0c;那么接下来我们将进一步的学习二叉树&#xff0c;在此会先后学习到二叉搜索树、AVL树、红黑树&#xff1b;通过这些的学习将让我们更易于理解后面set、map、哈希等…...

机器学习-神经网络(BP神经网络前向和反向传播推导)

1.1 神经元模型 神经网络(neural networks)方面的研究很早就已出现,今天“神经网络”已是一个相当大的、多学科交叉的学科领域.各相关学科对神经网络的定义多种多样,本书采用目前使用得最广泛的一种,即“神经网络是由具有适应性的简单单元组成的广泛并行互连的网络,它的组织能够…...

基于智能物联网关的车辆超重AI检测应用

超重超载是严重的交通违法行为&#xff0c;超重超载车辆的交通安全风险极高&#xff0c;像是一颗行走的“不定时炸弹”&#xff0c;威胁着社会公众的安全。但总有一些人受到利益驱使&#xff0c;使超重超载的违法违规行为时有发生。 随着物联网和AI技术的发展&#xff0c;针对预…...

记录pbootcms提示:登录失败:表单提交校验失败,请刷新后重试的解决办法

问题描述 pbootcms后台登录的时候提示“登录失败&#xff1a;表单提交校验失败,请刷新后重试!” 解决办法 删除runtime目录&#xff0c;或尝试切换PHP版本&#xff0c;选择7.3或5.6一般就能解决了。...

【JavaScript】同步异步详解

同步和异步是编程中处理任务执行顺序的两种不同方式。理解这两种概念对于编写高效和响应式的应用程序至关重要。 同步&#xff08;Synchronous&#xff09; 定义&#xff1a;同步操作是指一个任务必须在下一个任务开始之前完成。换句话说&#xff0c;代码按顺序执行&#xff…...

vue 使用el-button 如何实现多个button 单选

在 Vue 中&#xff0c;如果你想要实现多个 el-button 按钮的 单选&#xff08;即只能选择一个按钮&#xff09;&#xff0c;可以通过绑定 v-model 或使用事件来处理按钮的选中状态。 下面是两种实现方式&#xff0c;分别使用 v-model 和事件监听来实现单选按钮效果&#xff1a…...

HarmonyOS-初级(二)

文章目录 应用程序框架UIAbilityArkUI框架 &#x1f3e1;作者主页&#xff1a;点击&#xff01; &#x1f916;HarmonyOS专栏&#xff1a;点击&#xff01; ⏰️创作时间&#xff1a;2024年11月28日13点10分 应用程序框架 应用程序框架可以被看做是应用模型的一种实现方式。 …...

Unity开启外部EXE程序

Unity开启外部EXE using System; using System.Collections; using System.Collections.Generic; using System.Diagnostics; using System.Runtime.InteropServices; using System.Threading.Tasks; using UnityEditor; using UnityEngine;public class Unity_OpenExe : Mono…...

CTF之密码学(埃特巴什码 )

一、基本原理 埃特巴什码的原理是&#xff1a;字母表中的最后一个字母代表第一个字母&#xff0c;倒数第二个字母代表第二个字母&#xff0c;以此类推。在罗马字母表中&#xff0c;对应关系如下&#xff1a; 常文&#xff08;明文&#xff09;&#xff1a;A B C D E F G H I …...

深入解析 PyTorch 的 torch.load() 函数:用法、参数与实际应用示例

深入解析 PyTorch 的 torch.load() 函数&#xff1a;用法、参数与实际应用示例 函数 torch.load() 是一个在PyTorch中用于加载通过 torch.save() 保存的序列化对象的核心功能。这个函数广泛应用于加载预训练模型、模型的状态字典&#xff08;state dictionaries&#xff09;、…...

ros2键盘实现车辆: 简单的油门_刹车_挡位_前后左右移动控制

参考: ROS python 实现键盘控制 底盘移动 https://blog.csdn.net/u011326325/article/details/131609340游戏手柄控制 1.背景与需求 1.之前实现过 键盘控制 底盘移动的程序, 底盘是线速度控制, 效果还不错. 2.新的底盘 只支持油门控制, 使用线速度控制问题比较多, 和底盘适配…...

ubuntu安装chrome无法打开问题

如果在ubuntu安装chrome后&#xff0c;点击chrome打开没反应&#xff0c;可以先试着在terminal上用命令打开 google-chrome 如果运行命令显示 Chrome has locked the profile so that it doesnt get corrupted. If you are sure no other processes are using this profile…...

CTF-RE 从0到N:Chacha20逆向实战 2024 强网杯青少年专项赛 EnterGame WP (END)

只想解题的看最后就好了,前面是算法分析 Chacha20 c语言是如何利用逻辑运算符拆分变量和合并的 通过百度网盘分享的文件&#xff1a;EnterGame_9acdc7c33f85832082adc6a4e... 链接&#xff1a;https://pan.baidu.com/s/182SRj2Xemo63PCoaLNUsRQ?pwd1111 提取码&#xff1a;1…...

vue3 ajax获取json数组排序举例

使用axios获取接口数据 可以在代码中安装axios包&#xff0c;并写入到package.json文件&#xff1a; npm install axios -S接口调用代码举例如下&#xff1a; const fetchScore async () > {try {const res await axios.get(http://127.0.0.1:8000/score/${userInput.v…...

web安全之信息收集

在信息收集中,最主要是就是收集服务器的配置信息和网站的敏感信息,其中包括域名及子域名信息,目标网站系统,CMS指纹,目标网站真实IP,开放端口等。换句话说,只要是与目标网站相关的信息,我们都应该去尽量搜集。 1.1收集域名信息 知道目标的域名之后,获取域名的注册信…...

报错:java: 无法访问org.springframework.boot.SpringApplication

idea报错内容&#xff1a; java: 无法访问org.springframework.boot.SpringApplication 报错原因&#xff1a; <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>3.4…...

线上+线下≠新零售,6大互通诠释新零售的核心要点-亿发

新零售&#xff0c;这个词汇在近年来频繁出现在我们的视野中&#xff0c;它不仅仅是线上与线下的简单相加&#xff0c;而是一场深刻的商业变革。本文将通过6大互通的核心要点&#xff0c;为您揭示新零售的真正内涵。 1. 商品的互联互通 新零售模式下&#xff0c;商品的互联互…...

GitHub Copilot革命性更新:整合顶尖AI模型,如何重塑开发体验?

在技术快速发展的今天&#xff0c;代码辅助工具已成为提升开发效率的利器。今天&#xff0c;我们带来了一个激动人心的消息——GitHub Copilot宣布引入多模型选择功能&#xff0c;这不仅是技术上的一次飞跃&#xff0c;更是对开发者工作流程的一次革新。 多模型选择&#xff1a…...

wordpress拿shell/谷歌下载安装

Dom4j and Sax difference。 Dom4j 解析的速度慢&#xff0c;而且消耗内存&#xff0c;因为在解析之前要先把文件放到内存中。并采用基于对象的模型解析有以下几点&#xff1a; 1. Dom4J parse loads the entire XML file into memory before parsing.2. It uses Object based …...

网站浏览器兼容性/中国四大软件外包公司

请大家帮个忙&#xff0c;集思广议能不能有什么好方法&#xff0c;解决了这个问题。。。...

日本男女做受网站/百度客服人工电话24

使用java操作HDFS需要使用到的jar包将hadoop的tar.gz包解压&#xff0c;里面的lib下的所有jar包&#xff0c;share/hadoop目录下的common和hdfs文件下的所有jar包以及Hadoop-common-2.7.7、Hadoop-hdfs-2.7.7、hadoop-client-2.7.7这三个jar包。常用的操作1.连接至hdfsTestpubl…...

招聘网站开发需要多长时间/网络营销方法有哪几种

Rust语言是一门系统编程语言&#xff0c;专注于安全和高性能。在保证性能的同时提供更好的内存安全。 Rust性能与标准C性能不相上下。 Rust不像Go、Java以及.NET那样使用自动垃圾回收系统&#xff0c;而是所有权系统来管理内存。 Rust对协程&#xff0c;异步&#xff0c;网络也…...

动态网站模板下载/网站百度关键词排名软件

今天分享一个杨氏太极拳的视频文件&#xff0c;让更多喜欢太极拳的朋友能够有好的帮助。。。 链接&#xff1a;https://pan.baidu.com/s/1skJaFXN 密码&#xff1a;sqwj Austin Liu 刘恒辉 Department Manager&#xff0c;Product Manager&#xff0c;Project Manager an…...

如何做网盟推广网站/yandex网站推广

环境&#xff1a;PLSQL Developer 7.1.5 Oracle 11.2.0Oracle中不像MYSQL和MSSQLServer中那样指定一个列为自动增长列的方式&#xff0c;不过在Oracle中可以通过SEQUENCE序列来实现自动增长字段。在Oracle中SEQUENCE被称为序列&#xff0c;每次取的时候它会自动增加&#xff0c…...