从0开始深度学习(32)——循环神经网络的从零开始实现
本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级)
首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()
函数进行引用,用于导入数据:
train_iter, vocab = load_corpus_time_machine()
train_iter
运行结果:
train_iter中的每个数字都表示在vocab中的索引,将这些索引直接输入神经网络可能会使学习变得困难,我们通常将每个词元表示为更具表现力的特征向量,即one-hot编码。
1 one-hot编码
假设词表中不同词元的数目为 N N N,也就是len(vocab)
,所以词元索引的范围为 0 N − 1 0~N-1 0 N−1。如果词元的索引是整数 i i i, 那么我们将创建一个长度为 N N N的全 0 0 0向量, 并将第 i i i处的元素设置为 1 1 1。例如索引为 0 0 0和 2 2 2的独热向量如下所示:
from torch.nn import functional as F
import torchF.one_hot(torch.tensor([0, 2]), len(vocab))
运行结果:
tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0],# 第0处设置为1,表示索引1[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0]])# 第2处设置为1,表示索引2
由于我们每次采样的小批量数据形状是二维张量:(批量大小,时间步数),one_hot函数会将这样一个小批量数据转换成三维张量, 张量的最后一个维度等于词表大小len(vocab)
。我们经常转换输入的维度,以便获得形状为 (时间步数,批量大小,词表大小)
的输出。 这将使我们能够更方便地通过最外层的维度, 一步一步地更新小批量数据的隐状态。
转化这一步将在后面进行
2 初始化模型参数
初始化循环神经网络模型的模型参数, 隐藏单元数num_hiddens
是一个可调的超参数。 当训练语言模型时,输入和输出来自相同的词表,因此,它们具有相同的维度,即词表的大小:
def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params
3 循环神经网络模型
为了定义循环神经网络模型, 我们首先需要一个init_rnn_state函数在初始化时返回隐状态,这个函数的返回是一个张量,形状为(批量大小,隐藏单元数)
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )
下面的rnn函数定义了如何在一个时间步内计算隐状态和输出。
def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state # 提取隐状态outputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) # 使用激活函数Y = torch.mm(H, W_hq) + b_q # 做矩阵乘法,即 (隐变量*权重+偏置)outputs.append(Y)return torch.cat(outputs, dim=0), (H,)# cat()函数将所有时间步的输出拼接成一个张量,形状为 (时间步数量 * 批量大小, 输出大小)# (H,)为最后一个时间步的隐藏状态
定义了所有需要的函数之后,接下来我们创建一个类来包装这些函数, 并存储从零开始实现的循环神经网络模型的参数:
class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 进行one-hot编码return self.forward_fn(X, state, self.params)# 调用前向传播函数,传入编码后的输入数据、初始隐藏状态和模型参数def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)# 初始化的隐藏状态
我们可以做一个测试:
num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
state = net.begin_state(X.shape[0], d2l.try_gpu())
Y, new_state = net(X.to(d2l.try_gpu()), state)
Y.shape, len(new_state), new_state[0].shape
# 输出形状,隐状态形状
运行结果:
(torch.Size([10, 28]), 1, torch.Size([2, 512]))
输出形状是(时间步数*批量大小,词表大小), 隐状态形状是(批量大小,隐藏单元数),符合要求
4 梯度裁剪
在编写训练函数之前,要引入一个方法——梯度裁剪,用于防止梯度爆炸问题。
对于长度为 T T T的序列,在迭代时要计算 T T T个时间步上的梯度,于是会在反向传播中产生长度为 O ( T ) O(T) O(T)的矩阵乘法链,当 T T T过大时,有可能导致梯度爆炸问题,所以循环神经网络需要额外的方式来支持稳定的训练,下面不讲解原理,直接给出一种流行的方法。
不过注意:梯度裁剪提供了一个快速修复梯度爆炸的方法, 虽然它并不能完全解决问题,但它是众多有效的技术之一。
def grad_clipping(net, theta): #@save"""裁剪梯度"""if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm
5 训练
与线性神经网络的训练有三个不同之处:
- 序列数据的不同采样方法(随机采样和顺序分区)将导致隐状态初始化的差异。
- 在更新模型参数之前裁剪梯度,这样即使训练过程中某个点上发生了梯度爆炸,也能保证模型不会发散。
- 使用困惑度来评价模型
对于第一点做一些解释
- 随机采样: 由于每次抽取的数据点是独立的,模型在处理每个样本时通常需要重新初始化隐状态。通常情况下,模型会在每个新的随机样本开始时,使用初始的隐状态。
- 顺序分区: 由于数据点是按时间顺序排列的,模型在处理每个子序列时可以利用前一个子序列的隐状态作为当前子序列的初始隐状态。
def predict_ch8(prefix, num_preds, net, vocab, device): #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]: # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds): # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练一个迭代周期"""state, timer = None, 0metric = [0, 0] # 累积损失,总词元数量for X, Y in train_iter:if X.shape[0] == 0 or Y.shape[0] == 0: # 跳过空批次continueif state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(state, tuple):state = tuple(s.detach() for s in state)else:state = state.detach()X, Y = X.to(device), Y.T.reshape(-1).to(device)y_hat, state = net(X, state)l = loss(y_hat, Y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)updater(batch_size=1)metric[0] += l.item() * Y.numel() # 使用 l.item() 累积标量损失metric[1] += Y.numel() # 累计总词元数量return metric[0] / max(1, metric[1]) # 避免除以零def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型"""loss = nn.CrossEntropyLoss()updater = torch.optim.SGD(net.params, lr)for epoch in range(num_epochs):avg_loss = train_epoch_ch8(net, train_iter, loss, updater,device, use_random_iter)ppl = torch.exp(torch.tensor(avg_loss)) # 转换为 Tensor 以使用 torch.expprint(f'epoch {epoch + 1}, perplexity {ppl:.1f}')print(predict_ch8('time traveller', 50, net, vocab, device))# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 模型初始化
num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_rnn_state, rnn)# 训练模型
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, device)
相关文章:

从0开始深度学习(32)——循环神经网络的从零开始实现
本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级) 首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据: train_iter…...

GitLab使用操作v1.0
1.前置条件 Gitlab 项目地址:http://******/req Gitlab账户信息:例如 001/******自己的分支名称:例如 001-master(注:master只有项目创建者有权限更新,我们只能更新自己分支,然后创建合并请求&…...

cuda conda yolov11 环境搭建
优雅的 yolo v11 标注工具 AutoLabel Conda环境直接识别训练 nvidia-smi 检查CUDA版本 下载nvidia cudnn对应的版本 将cuDNN压缩包内对应的文件复制到本地bin、include、lib的文件夹中 C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6 miniConda快速开始-安装 执行…...

解决SpringBoot连接Websocket报:请求路径 404 No static resource websocket.
问题发现 最近在工作中用到了WebSocket进行前后端的消息通信,后端代码编写完后,测试一下是否连接成功,发现报No static resource websocket.,看这个错貌似将接口变成了静态资源来访问了,第一时间觉得是端点没有注册成…...

element-plus的组件数据配置化封装 - table
目录 一、封装的table、table-column组件以及相关ts类型的定义 1、ATable组件的封装 - index.ts 2、ATableColumn组件的封装 - ATableColumn.ts 3、ATable、ATableColumn类型 - interface.ts 二、ATable、ATableColumn组件的使用 三、相关属性、方法的使用以及相关说明 1. C…...

【二维动态规划:交错字符串】
介绍 编程语言:Java 本篇介绍一道比较经典的二维动态规划题。 交错字符串 主要说明几点: 为什么双指针解不了?为什么是二维动态规划?根据题意分析处转移方程。严格位置依赖和空间压缩优化。 题目介绍 题意有点抽象,…...

goframe开发一个企业网站 MongoDB 完整工具包18
1. MongoDB 工具包完整实现 (mongodb.go) package mongodbimport ("context""fmt""time""github.com/gogf/gf/v2/frame/g""go.mongodb.org/mongo-driver/mongo""go.mongodb.org/mongo-driver/mongo/options" )va…...

在vue中,根据后端接口返回的文件流实现word文件弹窗预览
需求 弹窗预览word文件,因浏览器无法直接根据blob路径直接预览word文件,所以需要利用插件实现。 解决方案 利用docx-preview实现word文件弹窗预览,以node版本16.21.3和docx-preview版本0.1.8为例 具体实现步骤 1、安装docx-preview插件 …...

动态规划之背包问题
0/1背包问题 1.二维数组解法 题目描述:有一个容量为m的背包,还有n个物品,他们的重量分别为w1、w2、w3.....wn,他们的价值分别为v1、v2、v3......vn。每个物品只能使用一次,求可以放进背包物品的最大价值。 输入样例…...

【Python】 深入理解Python的单元测试:用unittest和pytest进行测试驱动开发
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 单元测试是现代软件开发中的重要组成部分,通过验证代码的功能性、准确性和稳定性,提升代码质量和开发效率。本文章深入介绍Python中两种主流单元测试框架:unittest和pytest,并结合测试驱动开发(TDD)…...

Java集合1.0
1.什么是集合? 集合就是一个存放数据的容器,准确的说是放数据对象引用的容器。 集合和数组的区别 数组是固定长度,集合是可变长度。数组可以存储基本数据类型,也可以存储引用数据类型,集合只能存储引用数据类型&…...

Leetcode 336 回文对
示例 1: 输入:words ["abcd","dcba","lls","s","sssll"] 输出:[[0,1],[1,0],[3,2],[2,4]] 解释:可拼接成的回文串为 ["dcbaabcd","abcddcba","sl…...

实现一个可配置的TCP设备模拟器,支持交互和解析配置
前言 诸位在做IOT开发的时候是否有遇到一个问题,那就是模拟一个设备来联调测试,虽然说现在的物联网通信主要是用mqtt通信,但还是有很多设备使用TCP这种协议交互,例如充电桩,还有一些工业设备,TCP这类报文交…...

算法的空间复杂度
空间复杂度 空间复杂度主要是衡量一个算法运行所需要的额外空间,在计算机发展早期,计算机的储存容量很小,所以空间复杂度是很重要的。但是经过计算机行业的迅速发展,计算机的容量已经不再是问题了,所以如今已经不再需…...

自定义协议
1. 问题引入 问题:TCP是面向字节流的(TCP不关心发送的数据是消息、文件还是其他任何类型的数据。它简单地将所有数据视为一个字节序列,即字节流。这意味着TCP不会对发送的数据进行任何特定的边界划分,它只是确保数据的顺序和完整…...

在 Taro 中实现系统主题适配:亮/暗模式
目录 背景实现方案方案一:CSS 变量 prefers-color-scheme 媒体查询什么是 prefers-color-scheme?代码示例 方案二:通过 JavaScript 监听系统主题切换 背景 用Taro开发的微信小程序,需求是页面的UI主题想要跟随手机系统的主题适配…...

autogen框架中使用chatglm4模型实现react
本文将介绍如何使用使用chatglm4实现react,利用环境变量、Tavily API和ReAct代理模式来回答用户提出的问题。 环境变量 首先,我们需要加载环境变量。这可以通过使用dotenv库来实现。 from dotenv import load_dotenv_ load_dotenv()注意.env文件处于…...

读《Effective Java》笔记 - 条目9
条目9:与try-finally 相比,首选 try -with -resource 什么是 try-finally? try-finally 是 Java 中传统的资源管理方式,通常用于确保资源(如文件流、数据库连接等)被正确关闭。 BufferedReader reader n…...

【软件入门】Git快速入门
Git快速入门 文章目录 Git快速入门0.前言1.安装和配置2.新建版本库2.1.本地创建2.2.云端下载 3.版本管理3.1.添加和提交文件3.2.回退版本3.2.1.soft模式3.2.2.mixed模式3.2.3.hard模式3.2.4.使用场景 3.3.查看版本差异3.4.忽略文件 4.云端配置4.1.Github4.1.1.SSH配置4.1.2.关联…...

nextjs window is not defined
问题产生的原因 在 Next.js 中,“window is not defined” 错误通常出现在服务器端渲染(Server - Side Rendering,SSR)的代码中。这是因为window对象是浏览器环境中的全局对象,在服务器端没有window这个概念。例如&am…...

C语言实现冒泡排序:从基础到优化全解析
一、什么是冒泡排序? 冒泡排序(Bubble Sort)是一种经典的排序算法,其工作原理非常直观:通过多次比较和交换相邻元素,将较大的元素“冒泡”到数组的末尾。经过多轮迭代,整个数组会变得有序。 二…...

windows11下git与 openssl要注意的问题
看了一下自己贴文的历史,有一条重要的忘了写了。 当时帮有位同事配置gitlab,众说周知gitlab是不太好操作。 但我还是自认自己git还是相当熟的。 解决了一系列问题,如配置代理,sshkey,私有库,等等࿰…...

lua除法bug
故事背景,新来了一个数值,要改公式。神奇的一幕出现了,公式算出一个非常大的数。排查是lua有一个除法bug,1除以大数得到一个非常大的数。 function div(a, b)return tonumber(string.format("%.2f", a/b)) end print(1/73003) pri…...

Ubuntu下Docker容器java服务往mysql插入中文数据乱码
一、问题描述 1、java服务部署在ubuntu下的docker容器内,但是会出现部分插入中文数据显示乱码,如图所示: 二、解决方案 1、查看mysql是否支持utf8,登录进入Mysql 输入命令: mysql -u root -pshow variables like c…...

C语言根据字符串变量获取/设置结构体成员值
一、背景 在项目中需要根据从数据库中获取的字段与对应的键值付给对应结构体成员上,而c语言中没有类似的反射机制,所以需要实现类似功能。例,从表中查到a 10,在结构体t中,需要将 t.a 10。 二、实现 感谢ChatGPT&…...

Selenium 自动化测试demo
场景描述: 模拟用户登录页面操作,包括输入用户名、密码、验证码。验证码为算数运算,如下: 使用到的工具和依赖: 1. Selenium:pip install selenium 2. 需要安装浏览器驱动:这里使用的是Edge 3…...

LeetCode 111.二叉树的最小深度
题目: 给定一个二叉树,找出其最小深度。 最小深度是从根节点到最近叶子节点的最短路径上的节点数量。 说明:叶子节点是指没有子节点的节点。 思路:自底向上(归)/自顶向下(递) DF…...

大工C语言作业答案
前言 这里是大连理工大学新版C语言课程MOOC作业的答案。 后期我会把全部的作业答案开源出来,希望对大家有帮助。 第九周第一题 #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> int B(int i) {int sum 1;while (i > 0){sum i * sum;i--;}return su…...

【Unity踩坑】Unity中父对象是非均匀缩放时出现倾斜或剪切现象
The game object is deformed when the parent object is in non-uniform scaling. 先来看一下现象 有两个Cube, Cube1(Scale2,1,1),Cube2(Scale1,1,1) 将Cube2拖拽为Cube2的子对象。并且将position设置为(-0.6,1,0&a…...

QT 跨平台实现 SSDP通信 支持多网卡
一.多网卡场景 在做SSDP通信的时候,客户端发出M-search命令后, 主机没有捕捉到SSDP的消息,你可以查看下,是不是局域网下,既打开了wifi,又连接了本地网络,mac os下很容易出现这种场景。此时,我们发送消息时,需要遍历所有网卡并发送M-search命令。 二.QT相关接口介绍 1…...