从0开始深度学习(32)——循环神经网络的从零开始实现
本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级)
首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据:
train_iter, vocab = load_corpus_time_machine()
train_iter
运行结果:

train_iter中的每个数字都表示在vocab中的索引,将这些索引直接输入神经网络可能会使学习变得困难,我们通常将每个词元表示为更具表现力的特征向量,即one-hot编码。
1 one-hot编码
假设词表中不同词元的数目为 N N N,也就是len(vocab),所以词元索引的范围为 0 N − 1 0~N-1 0 N−1。如果词元的索引是整数 i i i, 那么我们将创建一个长度为 N N N的全 0 0 0向量, 并将第 i i i处的元素设置为 1 1 1。例如索引为 0 0 0和 2 2 2的独热向量如下所示:
from torch.nn import functional as F
import torchF.one_hot(torch.tensor([0, 2]), len(vocab))
运行结果:
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0],# 第0处设置为1,表示索引1[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0]])# 第2处设置为1,表示索引2
由于我们每次采样的小批量数据形状是二维张量:(批量大小,时间步数),one_hot函数会将这样一个小批量数据转换成三维张量, 张量的最后一个维度等于词表大小len(vocab)。我们经常转换输入的维度,以便获得形状为 (时间步数,批量大小,词表大小)的输出。 这将使我们能够更方便地通过最外层的维度, 一步一步地更新小批量数据的隐状态。
转化这一步将在后面进行
2 初始化模型参数
初始化循环神经网络模型的模型参数, 隐藏单元数num_hiddens是一个可调的超参数。 当训练语言模型时,输入和输出来自相同的词表,因此,它们具有相同的维度,即词表的大小:
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params
3 循环神经网络模型
为了定义循环神经网络模型, 我们首先需要一个init_rnn_state函数在初始化时返回隐状态,这个函数的返回是一个张量,形状为(批量大小,隐藏单元数)
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )
下面的rnn函数定义了如何在一个时间步内计算隐状态和输出。
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state # 提取隐状态outputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) # 使用激活函数Y = torch.mm(H, W_hq) + b_q # 做矩阵乘法,即 (隐变量*权重+偏置)outputs.append(Y)return torch.cat(outputs, dim=0), (H,)# cat()函数将所有时间步的输出拼接成一个张量,形状为 (时间步数量 * 批量大小, 输出大小)# (H,)为最后一个时间步的隐藏状态
定义了所有需要的函数之后,接下来我们创建一个类来包装这些函数, 并存储从零开始实现的循环神经网络模型的参数:
class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 进行one-hot编码return self.forward_fn(X, state, self.params)# 调用前向传播函数,传入编码后的输入数据、初始隐藏状态和模型参数def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)# 初始化的隐藏状态
我们可以做一个测试:
num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
state = net.begin_state(X.shape[0], d2l.try_gpu())
Y, new_state = net(X.to(d2l.try_gpu()), state)
Y.shape, len(new_state), new_state[0].shape
# 输出形状,隐状态形状
运行结果:
(torch.Size([10, 28]), 1, torch.Size([2, 512]))
输出形状是(时间步数*批量大小,词表大小), 隐状态形状是(批量大小,隐藏单元数),符合要求
4 梯度裁剪
在编写训练函数之前,要引入一个方法——梯度裁剪,用于防止梯度爆炸问题。
对于长度为 T T T的序列,在迭代时要计算 T T T个时间步上的梯度,于是会在反向传播中产生长度为 O ( T ) O(T) O(T)的矩阵乘法链,当 T T T过大时,有可能导致梯度爆炸问题,所以循环神经网络需要额外的方式来支持稳定的训练,下面不讲解原理,直接给出一种流行的方法。
不过注意:梯度裁剪提供了一个快速修复梯度爆炸的方法, 虽然它并不能完全解决问题,但它是众多有效的技术之一。
def grad_clipping(net, theta): #@save"""裁剪梯度"""if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm
5 训练
与线性神经网络的训练有三个不同之处:
- 序列数据的不同采样方法(随机采样和顺序分区)将导致隐状态初始化的差异。
- 在更新模型参数之前裁剪梯度,这样即使训练过程中某个点上发生了梯度爆炸,也能保证模型不会发散。
- 使用困惑度来评价模型
对于第一点做一些解释
- 随机采样: 由于每次抽取的数据点是独立的,模型在处理每个样本时通常需要重新初始化隐状态。通常情况下,模型会在每个新的随机样本开始时,使用初始的隐状态。
- 顺序分区: 由于数据点是按时间顺序排列的,模型在处理每个子序列时可以利用前一个子序列的隐状态作为当前子序列的初始隐状态。
def predict_ch8(prefix, num_preds, net, vocab, device): #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]: # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds): # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练一个迭代周期"""state, timer = None, 0metric = [0, 0] # 累积损失,总词元数量for X, Y in train_iter:if X.shape[0] == 0 or Y.shape[0] == 0: # 跳过空批次continueif state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(state, tuple):state = tuple(s.detach() for s in state)else:state = state.detach()X, Y = X.to(device), Y.T.reshape(-1).to(device)y_hat, state = net(X, state)l = loss(y_hat, Y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)updater(batch_size=1)metric[0] += l.item() * Y.numel() # 使用 l.item() 累积标量损失metric[1] += Y.numel() # 累计总词元数量return metric[0] / max(1, metric[1]) # 避免除以零def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型"""loss = nn.CrossEntropyLoss()updater = torch.optim.SGD(net.params, lr)for epoch in range(num_epochs):avg_loss = train_epoch_ch8(net, train_iter, loss, updater,device, use_random_iter)ppl = torch.exp(torch.tensor(avg_loss)) # 转换为 Tensor 以使用 torch.expprint(f'epoch {epoch + 1}, perplexity {ppl:.1f}')print(predict_ch8('time traveller', 50, net, vocab, device))# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 模型初始化
num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_rnn_state, rnn)# 训练模型
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, device)相关文章:
从0开始深度学习(32)——循环神经网络的从零开始实现
本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级) 首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据: train_iter…...
GitLab使用操作v1.0
1.前置条件 Gitlab 项目地址:http://******/req Gitlab账户信息:例如 001/******自己的分支名称:例如 001-master(注:master只有项目创建者有权限更新,我们只能更新自己分支,然后创建合并请求&…...
cuda conda yolov11 环境搭建
优雅的 yolo v11 标注工具 AutoLabel Conda环境直接识别训练 nvidia-smi 检查CUDA版本 下载nvidia cudnn对应的版本 将cuDNN压缩包内对应的文件复制到本地bin、include、lib的文件夹中 C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6 miniConda快速开始-安装 执行…...
解决SpringBoot连接Websocket报:请求路径 404 No static resource websocket.
问题发现 最近在工作中用到了WebSocket进行前后端的消息通信,后端代码编写完后,测试一下是否连接成功,发现报No static resource websocket.,看这个错貌似将接口变成了静态资源来访问了,第一时间觉得是端点没有注册成…...
element-plus的组件数据配置化封装 - table
目录 一、封装的table、table-column组件以及相关ts类型的定义 1、ATable组件的封装 - index.ts 2、ATableColumn组件的封装 - ATableColumn.ts 3、ATable、ATableColumn类型 - interface.ts 二、ATable、ATableColumn组件的使用 三、相关属性、方法的使用以及相关说明 1. C…...
【二维动态规划:交错字符串】
介绍 编程语言:Java 本篇介绍一道比较经典的二维动态规划题。 交错字符串 主要说明几点: 为什么双指针解不了?为什么是二维动态规划?根据题意分析处转移方程。严格位置依赖和空间压缩优化。 题目介绍 题意有点抽象,…...
goframe开发一个企业网站 MongoDB 完整工具包18
1. MongoDB 工具包完整实现 (mongodb.go) package mongodbimport ("context""fmt""time""github.com/gogf/gf/v2/frame/g""go.mongodb.org/mongo-driver/mongo""go.mongodb.org/mongo-driver/mongo/options" )va…...
在vue中,根据后端接口返回的文件流实现word文件弹窗预览
需求 弹窗预览word文件,因浏览器无法直接根据blob路径直接预览word文件,所以需要利用插件实现。 解决方案 利用docx-preview实现word文件弹窗预览,以node版本16.21.3和docx-preview版本0.1.8为例 具体实现步骤 1、安装docx-preview插件 …...
动态规划之背包问题
0/1背包问题 1.二维数组解法 题目描述:有一个容量为m的背包,还有n个物品,他们的重量分别为w1、w2、w3.....wn,他们的价值分别为v1、v2、v3......vn。每个物品只能使用一次,求可以放进背包物品的最大价值。 输入样例…...
【Python】 深入理解Python的单元测试:用unittest和pytest进行测试驱动开发
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 单元测试是现代软件开发中的重要组成部分,通过验证代码的功能性、准确性和稳定性,提升代码质量和开发效率。本文章深入介绍Python中两种主流单元测试框架:unittest和pytest,并结合测试驱动开发(TDD)…...
Java集合1.0
1.什么是集合? 集合就是一个存放数据的容器,准确的说是放数据对象引用的容器。 集合和数组的区别 数组是固定长度,集合是可变长度。数组可以存储基本数据类型,也可以存储引用数据类型,集合只能存储引用数据类型&…...
Leetcode 336 回文对
示例 1: 输入:words ["abcd","dcba","lls","s","sssll"] 输出:[[0,1],[1,0],[3,2],[2,4]] 解释:可拼接成的回文串为 ["dcbaabcd","abcddcba","sl…...
实现一个可配置的TCP设备模拟器,支持交互和解析配置
前言 诸位在做IOT开发的时候是否有遇到一个问题,那就是模拟一个设备来联调测试,虽然说现在的物联网通信主要是用mqtt通信,但还是有很多设备使用TCP这种协议交互,例如充电桩,还有一些工业设备,TCP这类报文交…...
算法的空间复杂度
空间复杂度 空间复杂度主要是衡量一个算法运行所需要的额外空间,在计算机发展早期,计算机的储存容量很小,所以空间复杂度是很重要的。但是经过计算机行业的迅速发展,计算机的容量已经不再是问题了,所以如今已经不再需…...
自定义协议
1. 问题引入 问题:TCP是面向字节流的(TCP不关心发送的数据是消息、文件还是其他任何类型的数据。它简单地将所有数据视为一个字节序列,即字节流。这意味着TCP不会对发送的数据进行任何特定的边界划分,它只是确保数据的顺序和完整…...
在 Taro 中实现系统主题适配:亮/暗模式
目录 背景实现方案方案一:CSS 变量 prefers-color-scheme 媒体查询什么是 prefers-color-scheme?代码示例 方案二:通过 JavaScript 监听系统主题切换 背景 用Taro开发的微信小程序,需求是页面的UI主题想要跟随手机系统的主题适配…...
autogen框架中使用chatglm4模型实现react
本文将介绍如何使用使用chatglm4实现react,利用环境变量、Tavily API和ReAct代理模式来回答用户提出的问题。 环境变量 首先,我们需要加载环境变量。这可以通过使用dotenv库来实现。 from dotenv import load_dotenv_ load_dotenv()注意.env文件处于…...
读《Effective Java》笔记 - 条目9
条目9:与try-finally 相比,首选 try -with -resource 什么是 try-finally? try-finally 是 Java 中传统的资源管理方式,通常用于确保资源(如文件流、数据库连接等)被正确关闭。 BufferedReader reader n…...
【软件入门】Git快速入门
Git快速入门 文章目录 Git快速入门0.前言1.安装和配置2.新建版本库2.1.本地创建2.2.云端下载 3.版本管理3.1.添加和提交文件3.2.回退版本3.2.1.soft模式3.2.2.mixed模式3.2.3.hard模式3.2.4.使用场景 3.3.查看版本差异3.4.忽略文件 4.云端配置4.1.Github4.1.1.SSH配置4.1.2.关联…...
nextjs window is not defined
问题产生的原因 在 Next.js 中,“window is not defined” 错误通常出现在服务器端渲染(Server - Side Rendering,SSR)的代码中。这是因为window对象是浏览器环境中的全局对象,在服务器端没有window这个概念。例如&am…...
挑战杯推荐项目
“人工智能”创意赛 - 智能艺术创作助手:借助大模型技术,开发能根据用户输入的主题、风格等要求,生成绘画、音乐、文学作品等多种形式艺术创作灵感或初稿的应用,帮助艺术家和创意爱好者激发创意、提高创作效率。 - 个性化梦境…...
java_网络服务相关_gateway_nacos_feign区别联系
1. spring-cloud-starter-gateway 作用:作为微服务架构的网关,统一入口,处理所有外部请求。 核心能力: 路由转发(基于路径、服务名等)过滤器(鉴权、限流、日志、Header 处理)支持负…...
基于距离变化能量开销动态调整的WSN低功耗拓扑控制开销算法matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.算法仿真参数 5.算法理论概述 6.参考文献 7.完整程序 1.程序功能描述 通过动态调整节点通信的能量开销,平衡网络负载,延长WSN生命周期。具体通过建立基于距离的能量消耗模型&am…...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...
Java-41 深入浅出 Spring - 声明式事务的支持 事务配置 XML模式 XML+注解模式
点一下关注吧!!!非常感谢!!持续更新!!! 🚀 AI篇持续更新中!(长期更新) 目前2025年06月05日更新到: AI炼丹日志-28 - Aud…...
C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)
名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...
comfyui 工作流中 图生视频 如何增加视频的长度到5秒
comfyUI 工作流怎么可以生成更长的视频。除了硬件显存要求之外还有别的方法吗? 在ComfyUI中实现图生视频并延长到5秒,需要结合多个扩展和技巧。以下是完整解决方案: 核心工作流配置(24fps下5秒120帧) #mermaid-svg-yP…...
如何配置一个sql server使得其它用户可以通过excel odbc获取数据
要让其他用户通过 Excel 使用 ODBC 连接到 SQL Server 获取数据,你需要完成以下配置步骤: ✅ 一、在 SQL Server 端配置(服务器设置) 1. 启用 TCP/IP 协议 打开 “SQL Server 配置管理器”。导航到:SQL Server 网络配…...
聚六亚甲基单胍盐酸盐市场深度解析:现状、挑战与机遇
根据 QYResearch 发布的市场报告显示,全球市场规模预计在 2031 年达到 9848 万美元,2025 - 2031 年期间年复合增长率(CAGR)为 3.7%。在竞争格局上,市场集中度较高,2024 年全球前十强厂商占据约 74.0% 的市场…...
边缘计算网关提升水产养殖尾水处理的远程运维效率
一、项目背景 随着水产养殖行业的快速发展,养殖尾水的处理成为了一个亟待解决的环保问题。传统的尾水处理方式不仅效率低下,而且难以实现精准监控和管理。为了提升尾水处理的效果和效率,同时降低人力成本,某大型水产养殖企业决定…...
