当前位置: 首页 > news >正文

从0开始深度学习(32)——循环神经网络的从零开始实现

本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级)

首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据:

train_iter, vocab = load_corpus_time_machine()
train_iter

运行结果:
在这里插入图片描述

train_iter中的每个数字都表示在vocab中的索引,将这些索引直接输入神经网络可能会使学习变得困难,我们通常将每个词元表示为更具表现力的特征向量,即one-hot编码

1 one-hot编码

假设词表中不同词元的数目为 N N N,也就是len(vocab),所以词元索引的范围为 0 N − 1 0~N-1 0 N1。如果词元的索引是整数 i i i, 那么我们将创建一个长度为 N N N的全 0 0 0向量, 并将第 i i i处的元素设置为 1 1 1。例如索引为 0 0 0 2 2 2的独热向量如下所示:

from torch.nn import functional as F
import torchF.one_hot(torch.tensor([0, 2]), len(vocab))

运行结果:

tensor([[1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0],# 第0处设置为1,表示索引1[0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,0, 0, 0, 0]])# 第2处设置为1,表示索引2

由于我们每次采样的小批量数据形状是二维张量:(批量大小,时间步数),one_hot函数会将这样一个小批量数据转换成三维张量, 张量的最后一个维度等于词表大小len(vocab)。我们经常转换输入的维度,以便获得形状为 (时间步数,批量大小,词表大小)的输出。 这将使我们能够更方便地通过最外层的维度, 一步一步地更新小批量数据的隐状态。

转化这一步将在后面进行

2 初始化模型参数

初始化循环神经网络模型的模型参数, 隐藏单元数num_hiddens是一个可调的超参数。 当训练语言模型时,输入和输出来自相同的词表,因此,它们具有相同的维度,即词表的大小:

def get_params(vocab_size, num_hiddens, device):num_inputs = num_outputs = vocab_sizedef normal(shape):return torch.randn(size=shape, device=device) * 0.01# 隐藏层参数W_xh = normal((num_inputs, num_hiddens))W_hh = normal((num_hiddens, num_hiddens))b_h = torch.zeros(num_hiddens, device=device)# 输出层参数W_hq = normal((num_hiddens, num_outputs))b_q = torch.zeros(num_outputs, device=device)# 附加梯度params = [W_xh, W_hh, b_h, W_hq, b_q]for param in params:param.requires_grad_(True)return params

3 循环神经网络模型

为了定义循环神经网络模型, 我们首先需要一个init_rnn_state函数在初始化时返回隐状态,这个函数的返回是一个张量,形状为(批量大小,隐藏单元数)

def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )

下面的rnn函数定义了如何在一个时间步内计算隐状态输出

def init_rnn_state(batch_size, num_hiddens, device):return (torch.zeros((batch_size, num_hiddens), device=device), )def rnn(inputs, state, params):# inputs的形状:(时间步数量,批量大小,词表大小)W_xh, W_hh, b_h, W_hq, b_q = paramsH, = state # 提取隐状态outputs = []# X的形状:(批量大小,词表大小)for X in inputs:H = torch.tanh(torch.mm(X, W_xh) + torch.mm(H, W_hh) + b_h) # 使用激活函数Y = torch.mm(H, W_hq) + b_q # 做矩阵乘法,即 (隐变量*权重+偏置)outputs.append(Y)return torch.cat(outputs, dim=0), (H,)# cat()函数将所有时间步的输出拼接成一个张量,形状为 (时间步数量 * 批量大小, 输出大小)# (H,)为最后一个时间步的隐藏状态

定义了所有需要的函数之后,接下来我们创建一个类来包装这些函数, 并存储从零开始实现的循环神经网络模型的参数:

class RNNModelScratch: #@save"""从零开始实现的循环神经网络模型"""def __init__(self, vocab_size, num_hiddens, device,get_params, init_state, forward_fn):self.vocab_size, self.num_hiddens = vocab_size, num_hiddensself.params = get_params(vocab_size, num_hiddens, device)self.init_state, self.forward_fn = init_state, forward_fndef __call__(self, X, state):X = F.one_hot(X.T, self.vocab_size).type(torch.float32)# 进行one-hot编码return self.forward_fn(X, state, self.params)# 调用前向传播函数,传入编码后的输入数据、初始隐藏状态和模型参数def begin_state(self, batch_size, device):return self.init_state(batch_size, self.num_hiddens, device)# 初始化的隐藏状态

我们可以做一个测试:

num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, d2l.try_gpu(), get_params,init_rnn_state, rnn)
state = net.begin_state(X.shape[0], d2l.try_gpu())
Y, new_state = net(X.to(d2l.try_gpu()), state)
Y.shape, len(new_state), new_state[0].shape
# 输出形状,隐状态形状

运行结果:

(torch.Size([10, 28]), 1, torch.Size([2, 512]))

输出形状是(时间步数*批量大小,词表大小), 隐状态形状是(批量大小,隐藏单元数),符合要求

4 梯度裁剪

在编写训练函数之前,要引入一个方法——梯度裁剪,用于防止梯度爆炸问题。

对于长度为 T T T的序列,在迭代时要计算 T T T个时间步上的梯度,于是会在反向传播中产生长度为 O ( T ) O(T) O(T)的矩阵乘法链,当 T T T过大时,有可能导致梯度爆炸问题,所以循环神经网络需要额外的方式来支持稳定的训练,下面不讲解原理,直接给出一种流行的方法。

不过注意:梯度裁剪提供了一个快速修复梯度爆炸的方法, 虽然它并不能完全解决问题,但它是众多有效的技术之一。

def grad_clipping(net, theta):  #@save"""裁剪梯度"""if isinstance(net, nn.Module):params = [p for p in net.parameters() if p.requires_grad]else:params = net.paramsnorm = torch.sqrt(sum(torch.sum((p.grad ** 2)) for p in params))if norm > theta:for param in params:param.grad[:] *= theta / norm

5 训练

与线性神经网络的训练有三个不同之处:

  1. 序列数据的不同采样方法(随机采样和顺序分区)将导致隐状态初始化的差异。
  2. 在更新模型参数之前裁剪梯度,这样即使训练过程中某个点上发生了梯度爆炸,也能保证模型不会发散。
  3. 使用困惑度来评价模型

对于第一点做一些解释

  • 随机采样: 由于每次抽取的数据点是独立的,模型在处理每个样本时通常需要重新初始化隐状态。通常情况下,模型会在每个新的随机样本开始时,使用初始的隐状态。
  • 顺序分区: 由于数据点是按时间顺序排列的,模型在处理每个子序列时可以利用前一个子序列的隐状态作为当前子序列的初始隐状态
def predict_ch8(prefix, num_preds, net, vocab, device):  #@save"""在prefix后面生成新字符"""state = net.begin_state(batch_size=1, device=device)outputs = [vocab[prefix[0]]]get_input = lambda: torch.tensor([outputs[-1]], device=device).reshape((1, 1))for y in prefix[1:]:  # 预热期_, state = net(get_input(), state)outputs.append(vocab[y])for _ in range(num_preds):  # 预测num_preds步y, state = net(get_input(), state)outputs.append(int(y.argmax(dim=1).reshape(1)))return ''.join([vocab.idx_to_token[i] for i in outputs])
def train_epoch_ch8(net, train_iter, loss, updater, device, use_random_iter):"""训练一个迭代周期"""state, timer = None, 0metric = [0, 0]  # 累积损失,总词元数量for X, Y in train_iter:if X.shape[0] == 0 or Y.shape[0] == 0:  # 跳过空批次continueif state is None or use_random_iter:state = net.begin_state(batch_size=X.shape[0], device=device)else:if isinstance(state, tuple):state = tuple(s.detach() for s in state)else:state = state.detach()X, Y = X.to(device), Y.T.reshape(-1).to(device)y_hat, state = net(X, state)l = loss(y_hat, Y.long()).mean()if isinstance(updater, torch.optim.Optimizer):updater.zero_grad()l.backward()grad_clipping(net, 1)updater.step()else:l.backward()grad_clipping(net, 1)updater(batch_size=1)metric[0] += l.item() * Y.numel()  # 使用 l.item() 累积标量损失metric[1] += Y.numel()  # 累计总词元数量return metric[0] / max(1, metric[1])  # 避免除以零def train_ch8(net, train_iter, vocab, lr, num_epochs, device,use_random_iter=False):"""训练模型"""loss = nn.CrossEntropyLoss()updater = torch.optim.SGD(net.params, lr)for epoch in range(num_epochs):avg_loss = train_epoch_ch8(net, train_iter, loss, updater,device, use_random_iter)ppl = torch.exp(torch.tensor(avg_loss))  # 转换为 Tensor 以使用 torch.expprint(f'epoch {epoch + 1}, perplexity {ppl:.1f}')print(predict_ch8('time traveller', 50, net, vocab, device))# 设置设备
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 模型初始化
num_hiddens = 512
net = RNNModelScratch(len(vocab), num_hiddens, device, get_params, init_rnn_state, rnn)# 训练模型
num_epochs, lr = 500, 1
train_ch8(net, train_iter, vocab, lr, num_epochs, device)

相关文章:

从0开始深度学习(32)——循环神经网络的从零开始实现

本章将从零开始,基于循环神经网络实现字符级语言模型(不是单词级) 首先我们把从0开始深度学习(30)——语言模型和数据集中的load_corpus_time_machine()函数进行引用,用于导入数据: train_iter…...

GitLab使用操作v1.0

1.前置条件 Gitlab 项目地址:http://******/req Gitlab账户信息:例如 001/******自己的分支名称:例如 001-master(注:master只有项目创建者有权限更新,我们只能更新自己分支,然后创建合并请求&…...

cuda conda yolov11 环境搭建

优雅的 yolo v11 标注工具 AutoLabel Conda环境直接识别训练 nvidia-smi 检查CUDA版本 下载nvidia cudnn对应的版本 将cuDNN压缩包内对应的文件复制到本地bin、include、lib的文件夹中 C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v12.6 miniConda快速开始-安装 执行…...

解决SpringBoot连接Websocket报:请求路径 404 No static resource websocket.

问题发现 最近在工作中用到了WebSocket进行前后端的消息通信,后端代码编写完后,测试一下是否连接成功,发现报No static resource websocket.,看这个错貌似将接口变成了静态资源来访问了,第一时间觉得是端点没有注册成…...

element-plus的组件数据配置化封装 - table

目录 一、封装的table、table-column组件以及相关ts类型的定义 1、ATable组件的封装 - index.ts 2、ATableColumn组件的封装 - ATableColumn.ts 3、ATable、ATableColumn类型 - interface.ts 二、ATable、ATableColumn组件的使用 三、相关属性、方法的使用以及相关说明 1. C…...

【二维动态规划:交错字符串】

介绍 编程语言:Java 本篇介绍一道比较经典的二维动态规划题。 交错字符串 主要说明几点: 为什么双指针解不了?为什么是二维动态规划?根据题意分析处转移方程。严格位置依赖和空间压缩优化。 题目介绍 题意有点抽象&#xff0c…...

goframe开发一个企业网站 MongoDB 完整工具包18

1. MongoDB 工具包完整实现 (mongodb.go) package mongodbimport ("context""fmt""time""github.com/gogf/gf/v2/frame/g""go.mongodb.org/mongo-driver/mongo""go.mongodb.org/mongo-driver/mongo/options" )va…...

在vue中,根据后端接口返回的文件流实现word文件弹窗预览

需求 弹窗预览word文件,因浏览器无法直接根据blob路径直接预览word文件,所以需要利用插件实现。 解决方案 利用docx-preview实现word文件弹窗预览,以node版本16.21.3和docx-preview版本0.1.8为例 具体实现步骤 1、安装docx-preview插件 …...

动态规划之背包问题

0/1背包问题 1.二维数组解法 题目描述:有一个容量为m的背包,还有n个物品,他们的重量分别为w1、w2、w3.....wn,他们的价值分别为v1、v2、v3......vn。每个物品只能使用一次,求可以放进背包物品的最大价值。 输入样例…...

【Python】 深入理解Python的单元测试:用unittest和pytest进行测试驱动开发

《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 单元测试是现代软件开发中的重要组成部分,通过验证代码的功能性、准确性和稳定性,提升代码质量和开发效率。本文章深入介绍Python中两种主流单元测试框架:unittest和pytest,并结合测试驱动开发(TDD)…...

Java集合1.0

1.什么是集合? 集合就是一个存放数据的容器,准确的说是放数据对象引用的容器。 集合和数组的区别 数组是固定长度,集合是可变长度。数组可以存储基本数据类型,也可以存储引用数据类型,集合只能存储引用数据类型&…...

Leetcode 336 回文对

示例 1: 输入:words ["abcd","dcba","lls","s","sssll"] 输出:[[0,1],[1,0],[3,2],[2,4]] 解释:可拼接成的回文串为 ["dcbaabcd","abcddcba","sl…...

实现一个可配置的TCP设备模拟器,支持交互和解析配置

前言 诸位在做IOT开发的时候是否有遇到一个问题,那就是模拟一个设备来联调测试,虽然说现在的物联网通信主要是用mqtt通信,但还是有很多设备使用TCP这种协议交互,例如充电桩,还有一些工业设备,TCP这类报文交…...

算法的空间复杂度

空间复杂度 空间复杂度主要是衡量一个算法运行所需要的额外空间,在计算机发展早期,计算机的储存容量很小,所以空间复杂度是很重要的。但是经过计算机行业的迅速发展,计算机的容量已经不再是问题了,所以如今已经不再需…...

自定义协议

1. 问题引入 问题:TCP是面向字节流的(TCP不关心发送的数据是消息、文件还是其他任何类型的数据。它简单地将所有数据视为一个字节序列,即字节流。这意味着TCP不会对发送的数据进行任何特定的边界划分,它只是确保数据的顺序和完整…...

在 Taro 中实现系统主题适配:亮/暗模式

目录 背景实现方案方案一:CSS 变量 prefers-color-scheme 媒体查询什么是 prefers-color-scheme?代码示例 方案二:通过 JavaScript 监听系统主题切换 背景 用Taro开发的微信小程序,需求是页面的UI主题想要跟随手机系统的主题适配…...

autogen框架中使用chatglm4模型实现react

本文将介绍如何使用使用chatglm4实现react,利用环境变量、Tavily API和ReAct代理模式来回答用户提出的问题。 环境变量 首先,我们需要加载环境变量。这可以通过使用dotenv库来实现。 from dotenv import load_dotenv_ load_dotenv()注意.env文件处于…...

读《Effective Java》笔记 - 条目9

条目9:与try-finally 相比,首选 try -with -resource 什么是 try-finally? try-finally 是 Java 中传统的资源管理方式,通常用于确保资源(如文件流、数据库连接等)被正确关闭。 BufferedReader reader n…...

【软件入门】Git快速入门

Git快速入门 文章目录 Git快速入门0.前言1.安装和配置2.新建版本库2.1.本地创建2.2.云端下载 3.版本管理3.1.添加和提交文件3.2.回退版本3.2.1.soft模式3.2.2.mixed模式3.2.3.hard模式3.2.4.使用场景 3.3.查看版本差异3.4.忽略文件 4.云端配置4.1.Github4.1.1.SSH配置4.1.2.关联…...

nextjs window is not defined

问题产生的原因 在 Next.js 中,“window is not defined” 错误通常出现在服务器端渲染(Server - Side Rendering,SSR)的代码中。这是因为window对象是浏览器环境中的全局对象,在服务器端没有window这个概念。例如&am…...

微信小程序云开发平台MySQL的连接方式

注:微信小程序云开发平台指的是腾讯云开发 先给结论:微信小程序云开发平台的MySQL,无法通过获取数据库连接信息的方式进行连接,连接只能通过云开发的SDK连接,具体要参考官方文档: 为什么? 因为…...

智能仓储的未来:自动化、AI与数据分析如何重塑物流中心

当仓库学会“思考”,物流的终极形态正在诞生 想象这样的场景: 凌晨3点,某物流中心灯火通明却空无一人。AGV机器人集群根据实时订单动态规划路径;AI视觉系统在0.1秒内扫描包裹信息;数字孪生平台正模拟次日峰值流量压力…...

Spring AI与Spring Modulith核心技术解析

Spring AI核心架构解析 Spring AI(https://spring.io/projects/spring-ai)作为Spring生态中的AI集成框架,其核心设计理念是通过模块化架构降低AI应用的开发复杂度。与Python生态中的LangChain/LlamaIndex等工具类似,但特别为多语…...

有限自动机到正规文法转换器v1.0

1 项目简介 这是一个功能强大的有限自动机(Finite Automaton, FA)到正规文法(Regular Grammar)转换器,它配备了一个直观且完整的图形用户界面,使用户能够轻松地进行操作和观察。该程序基于编译原理中的经典…...

IP如何挑?2025年海外专线IP如何购买?

你花了时间和预算买了IP,结果IP质量不佳,项目效率低下不说,还可能带来莫名的网络问题,是不是太闹心了?尤其是在面对海外专线IP时,到底怎么才能买到适合自己的呢?所以,挑IP绝对是个技…...

Selenium常用函数介绍

目录 一,元素定位 1.1 cssSeector 1.2 xpath 二,操作测试对象 三,窗口 3.1 案例 3.2 窗口切换 3.3 窗口大小 3.4 屏幕截图 3.5 关闭窗口 四,弹窗 五,等待 六,导航 七,文件上传 …...

Go语言多线程问题

打印零与奇偶数(leetcode 1116) 方法1:使用互斥锁和条件变量 package mainimport ("fmt""sync" )type ZeroEvenOdd struct {n intzeroMutex sync.MutexevenMutex sync.MutexoddMutex sync.Mutexcurrent int…...

Git常用命令完全指南:从入门到精通

Git常用命令完全指南:从入门到精通 一、基础配置命令 1. 用户信息配置 # 设置全局用户名 git config --global user.name "你的名字"# 设置全局邮箱 git config --global user.email "你的邮箱example.com"# 查看所有配置 git config --list…...

手机平板能效生态设计指令EU 2023/1670标准解读

手机平板能效生态设计指令EU 2023/1670标准解读 以下是针对欧盟《手机和平板电脑生态设计法规》(EU) 2023/1670 的核心解读,综合法规核心要求、最新修正及企业合规要点: 一、法规背景与目标 生效与强制时间 发布于2023年8月31日(OJ公报&…...

WPF八大法则:告别模态窗口卡顿

⚙️ 核心问题:阻塞式模态窗口的缺陷 原始代码中ShowDialog()会阻塞UI线程,导致后续逻辑无法执行: var result modalWindow.ShowDialog(); // 线程阻塞 ProcessResult(result); // 必须等待窗口关闭根本问题&#xff1a…...