Pytorch intermediate(三) RNN分类
使用RNN对MNIST手写数字进行分类。RNN和LSTM模型结构
pytorch中的LSTM的使用让人有点头晕,这里讲述的是LSTM的模型参数的意义。
1、加载数据集
import torch
import torchvision
import torch.nn as nn
import torchvision.transforms as transforms
import torch.utils.data as Data device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')sequence_length = 28
input_size = 28
hidden_size = 128
num_layers = 2
num_classes = 10
batch_size = 128
num_epochs = 2
learning_rate = 0.01 train_dataset = torchvision.datasets.MNIST(root='./data/',train=True,transform=transforms.ToTensor(),download=True)
test_dataset = torchvision.datasets.MNIST(root='./data/',train=False,transform=transforms.ToTensor())train_loader = Data.DataLoader(dataset=train_dataset,batch_size=batch_size,shuffle=True)
test_loader = Data.DataLoader(dataset=test_dataset,batch_size=batch_size)
2、构建RNN模型
-
input_size – 输入的特征维度
-
hidden_size – 隐状态的特征维度
-
num_layers – 层数(和时序展开要区分开)
-
bias – 如果为
False
,那么LSTM
将不会使用,默认为True
。 -
batch_first – 如果为
True
,那么输入和输出Tensor
的形状为(batch, seq, feature)
-
dropout – 如果非零的话,将会在
RNN
的输出上加个dropout
,最后一层除外。 -
bidirectional – 如果为
True
,将会变成一个双向RNN
,默认为False
1、上面的参数来自于文档,最基本的参数是input_size, hidden_size, num_layer三个。input_size:输入数据向量维度,在这里为28;hidden_size:隐藏层特征维度,也是输出的特征维度,这里是128;num_layers:lstm模块个数,这里是2。
2、h0和c0的初始化维度为(num_layer,batch_size, hidden_size)
3、lstm的输出有out和(hn,cn),其中out.shape = torch.Size([128, 28, 128]),对应(batch_size,时序数,隐藏特征维度),也就是保存了28个时序的输出特征,因为做的分类,所以只需要最后的输出特征。所以取出最后的输出特征,进行全连接计算,全连接计算的输出维度为10(10分类)。
4、batch_first这个参数比较特殊:如果为true,那么输入数据的维度为(batch, seq, feature),否则为(seq, batch, feature)
5、num_layers:lstm模块个数,如果有两个,那么第一个模块的输出会变成第二个模块的输入。
总结:构建一个LSTM模型要用到的参数,(输入数据的特征维度,隐藏层的特征维度,lstm模块个数);时序的个数体现在X中, X.shape = (batch_size, 时序长度, 数据向量维度)。
可以理解为LSTM可以根据我们的输入来实现自动的时序匹配,从而达到输入长短不同的功能。
class RNN(nn.Module):def __init__(self, input_size,hidden_size,num_layers, num_classes):super(RNN, self).__init__()self.hidden_size = hidden_sizeself.num_layers = num_layers#input_size - 输入特征维度#hidden_size - 隐藏状态特征维度#num_layers - 层数(和时序展开要区分开),lstm模块的个数#batch_first为true,输入和输出的形状为(batch, seq, feature),true意为将batch_size放在第一维度,否则放在第二维度self.lstm = nn.LSTM(input_size,hidden_size,num_layers,batch_first = True) self.fc = nn.Linear(hidden_size, num_classes)def forward(self,x):#参数:LSTM单元个数, batch_size, 隐藏层单元个数 h0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device) #h0.shape = (2, 128, 128)c0 = torch.zeros(self.num_layers, x.size(0), self.hidden_size).to(device)#输出output : (seq_len, batch, hidden_size * num_directions)#(h_n, c_n):最后一个时间步的隐藏状态和细胞状态#对out的理解:维度batch, eq_len, hidden_size,其中保存着每个时序对应的输出,所以全连接部分只取最后一个时序的#out第一维batch_size,第二维时序的个数,第三维隐藏层个数,所以和lstm单元的个数是无关的out,_ = self.lstm(x, (h0, c0)) #shape = torch.Size([128, 28, 128])out = self.fc(out[:,-1,:]) #因为batch_first = true,所以维度顺序batch, eq_len, hidden_sizereturn out
训练部分
model = RNN(input_size,hidden_size, num_layers, num_classes).to(device)
print(model)#RNN(
# (lstm): LSTM(28, 128, num_layers=2, batch_first=True)
# (fc): Linear(in_features=128, out_features=10, bias=True)
#)criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)total_step = len(train_loader)
for epoch in range(num_epochs):for i,(images, labels) in enumerate(train_loader):#batch_size = -1, 序列长度 = 28, 数据向量维度 = 28images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)# Forward passoutputs = model(images)loss = criterion(outputs, labels)# Backward and optimizeoptimizer.zero_grad()loss.backward() optimizer.step()if (i+1) % 100 == 0:print(outputs.shape)print ('Epoch [{}/{}], Step [{}/{}], Loss: {:.4f}' .format(epoch+1, num_epochs, i+1, total_step, loss.item()))
# Test the model
with torch.no_grad():correct = 0total = 0for images, labels in test_loader:images = images.reshape(-1, sequence_length, input_size).to(device)labels = labels.to(device)outputs = model(images)_, predicted = torch.max(outputs.data, 1)total += labels.size(0)correct += (predicted == labels).sum().item()print('Test Accuracy of the model on the 10000 test images: {} %'.format(100 * correct / total))
相关文章:
Pytorch intermediate(三) RNN分类
使用RNN对MNIST手写数字进行分类。RNN和LSTM模型结构 pytorch中的LSTM的使用让人有点头晕,这里讲述的是LSTM的模型参数的意义。 1、加载数据集 import torch import torchvision import torch.nn as nn import torchvision.transforms as transforms import torc…...
vue2+webpack升级vue3+vite,修改插件兼容性bug
同学们可以私信我加入学习群! 前言 在前面使用electronvue3的过程中,已经验证了历史vue2代码vue3混合开发的模式。 本次旧项目vue框架整体升级中,同事已经完成了vue3、pinia、router等基础框架工具的升级。所以我此次记录的主要是vite打包工…...
案例实战-Spring boot Web
准备工作 需求&环境搭建 需求: 部门管理: 查询部门列表 删除部门 新增部门 修改部门 员工管理 查询员工列表(分页、条件) 删除员工 新增员工 修改员工 环境搭建 准备数据库表(dept、emp) -- 部门管理…...
Spring6.1之RestClient分析
文章目录 1 RestClient1.1 介绍1.2 准备项目1.2.1 pom.xml1.2.2 创建全局 RestClient1.2.3 Get接收数据 retrieve1.2.4 结果转换 Bean1.2.5 Post发布数据1.2.6 Delete删除数据1.2.7 处理错误1.2.8 Exchange 方法 1 RestClient 1.1 介绍 Spring 框架一直提供了两种不同的客户端…...
冒泡排序、选择排序、插入排序、希尔排序
冒泡排序 基本思想 代码实现 # 冒泡排序 def bubble_sort(arr):length len(arr) - 1for i in range(length):flag Truefor j in range(length - i):if arr[j] > arr[j 1]:temp arr[j]arr[j] arr[j 1]arr[j 1] tempflag Falseprint(f第{i 1}趟的排序结果为&#…...
OpenCV(二十三):中值滤波
1.中值滤波的原理 中值滤波(Median Filter)是一种常用的非线性图像滤波方法,用于去除图像中的椒盐噪声等离群点。它的原理是基于邻域像素值的排序,并将中间值作为当前像素的新值。 2.中值滤波函数 medianBlur() void cv::medianBl…...
Prompt Tuning训练过程
目录 0. 入门 0.1. NLP发展的四个阶段: Prompt工程如此强大,我们还需要模型训练吗? - 知乎 Prompt learning系列之prompt engineering(二) 离散型prompt自动构建 Prompt learning系列之训练策略篇 - 知乎 ptuning v2 的 chatglm垂直领域训练记…...
装备制造企业是否要转型智能装备后服务型公司?
一、从制造到服务:装备制造企业的转型之路 装备制造企业作为国家经济发展的重要支柱,面临着日益激烈的市场竞争。在这样的背景下,越来越多的装备制造企业开始意识到,通过转型为智能装备后服务型公司,可以更好地满足客…...
day-49 代码随想录算法训练营(19) 动态规划 part 10
121.买卖股票的最佳时机 思路一:贪心 不断更新最小买入值不断更新当前值和最小买入值的差值最大值 思路二:动态规划(今天自己写出来了哈哈哈哈哈哈哈) 1.dp存储:dp[i][0] 表示当前持有 dp[i][1]表示当前不持有2.状…...
检查文件名是否含不可打印字符的C++代码源码
本篇文章属于《518抽奖软件开发日志》系列文章的一部分。 我在开发《518抽奖软件》(www.518cj.net)的时候,有时候需要检查输入的是否是合法的文件名,文件名是否含不可打印字符等。代码如下: //----------------------…...
学习笔记-正则表达式
https://www.runoob.com/regexp/regexp-tutorial.html 正则表达式re(Regular Expression)是一种文本模式,包括普通字符(例如,a 到 z 之间的字母)和特殊字符(称为"元字符"),可以用来描…...
Wireshark TS | 网络路径不一致传输丢包问题
问题背景 网络路径不一致,或者说是网络路径来回不一致,再专业点可以说是网络路径不对称,以上种种说法,做网络方向的工程师肯定会更清楚些,用简单的描述就是: A 与 B 通讯场景,C 和 D 代表中间…...
CMake高级用法实例分析(学习paddle官方的CMakeLists)
cmake基础学习教程 https://juejin.cn/post/6844903557183832078 官方完整CMakeLists cmake_minimum_required(VERSION 3.0) project(PaddleObjectDetector CXX C)option(WITH_MKL "Compile demo with MKL/OpenBlas support,defaultuseMKL." ON) o…...
数据采集: selenium 自动翻页接口调用时的验证码处理
写在前面 工作中遇到,简单整理理解不足小伙伴帮忙指正 对每个人而言,真正的职责只有一个:找到自我。然后在心中坚守其一生,全心全意,永不停息。所有其它的路都是不完整的,是人的逃避方式,是对大…...
IDEA安装翻译插件
IDEA安装翻译插件 File->Settings->Plugins 在Marketplace中,找到Translation,点击Install 更换翻译引擎 勾选自动翻译文档 翻译 鼠标右击->点击Translate...
DBeaver使用
一、导出表结构 二、导出数据CSV 导出数据时DBeaver并没有导出表结构,所以表结构需要额外保存; 导入数据CSV 导入数据时会因外键、字段长度导致失败;...
Nougat:一种用于科学文档OCR的Transformer 模型
随着人工智能领域的不断进步,其子领域,包括自然语言处理,自然语言生成,计算机视觉等,由于其广泛的用例而迅速获得了大量的普及。光学字符识别(OCR)是计算机视觉中一个成熟且被广泛研究的领域。它有许多用途,…...
redis八股1
参考Redis连环60问(八股文背诵版) - 知乎 (zhihu.com) 1.是什么 本质上是一个key-val数据库,把整个数据库加载到内存中操作,定期通过异步操作把数据flush到硬盘持久化。因为纯内存操作,所以性能很出色,每秒可以超过10…...
人工智能基础-趋势-架构
在过去的几周里,我花了一些时间来了解生成式人工智能基础设施的前景。在这篇文章中,我的目标是清晰概述关键组成部分、新兴趋势,并重点介绍推动创新的早期行业参与者。我将解释基础模型、计算、框架、计算、编排和矢量数据库、微调、标签、合…...
Date日期工具类(数据库日期区间问题)
文章目录 前言DateUtils日期工具类总结 前言 在我们日常开发过程中,当涉及到处理日期和时间的操作时,字符串与Date日期类往往要经过相互转换,且在SQL语句的动态查询中,往往月份的格式不正确,SQL语句执行的效果是不同的…...
为什么需要 TIME_WAIT 状态
还是用一下上一篇文章画的图 TCP 的 11 个状态,每一个状态都缺一不可,自然 TIME_WAIT 状态被赋予的意义也是相当重要,咱们直接结论先行 上文我们提到 tcp 中,主动关闭的一边会进入 TIME_WAIT 状态, 另外 Tcp 中的有 …...
Linux——(第七章)文件权限管理
目录 一、基本介绍 二、文件/目录的所有者 1.查看文件的所有者 2.修改文件所有者 三、文件/目录的所在组 1.修改文件/目录所在组 2.修改用户所在组 四、权限的基本介绍 五、rwx权限详解 1.rwx作用到文件 2.rwx作用到目录 六、修改权限 一、基本介绍 在Linux中&…...
Scala在大数据领域的崛起:当前趋势和未来前景
文章首发地址 Scala在大数据领域有着广阔的前景和现状。以下是一些关键点: Scala是一种具有强大静态类型系统的多范式编程语言,它结合了面向对象编程和函数式编程的特性。这使得Scala非常适合处理大数据,因为它能够处理并发、高吞吐量和复杂…...
前端面试经典题--页面布局
题目 假设高度已知,请写出三栏布局,其中左、右栏宽度各为300px,中间自适应。 五种解决方式代码 浮动解决方式 绝对定位解决方式 flexbox解决方式 表格布局 网格布局 源代码 <!DOCTYPE html> <html lang"en"> <…...
【webrtc】接收/发送的rtp包、编解码的VCM包、CopyOnWriteBuffer
收到的rtp包RtpPacketReceived 经过RtpDepacketizer 解析后变为ParsedPayloadRtpPacketReceived 分配内存,执行memcpy拷贝:然后把 RtpPacketReceived 给到OnRtpPacket 传递:uint8_t* media_payload = media_packet.AllocatePayload(rtx_payload.size());RTC...
Bash常见快捷键
生活在 Bash Shell 中,熟记以下快捷键,将极大的提高你的命令行操作效率。 编辑命令 Ctrl a :移到命令行首Ctrl e :移到命令行尾Ctrl f :按字符前移(右向)Ctrl b :按字符后移&a…...
软件验收测试
1. 服务流程 验收测试 2. 服务内容 测试过程中,根据合同要求制定测试方案,验证工程项目是否满足用户需求,软件质量特性是否达到系统的要求。 3. 周期 10-15个工作日 4. 报告用途 可作为进行地方、省级、国家、部委项目的验收࿰…...
Java 与零拷贝
零拷贝是由操作系统实现的,使用 Java 中的零拷贝抽象类库在支持零拷贝的操作系统上运行才会实现零拷贝,如果在不支持零拷贝的操作系统上运行,并不会提供零拷贝的功能。 简述内核态和用户态 Linux 的体系结构分为内核态(内核空间…...
AI性能指标解析:误触率与错误率
简介:随着人工智能(AI)技术的不断发展,它越来越多地渗透到我们日常生活的各个方面。从个人助手到自动驾驶,从语音识别到图像识别,AI正不断地改变我们与世界的互动方式。但你有没有想过,如何准确…...
count(*) 和 count(1) 有什么区别?哪个性能最好?
哪种 count 性能最好? count() 是什么? count() 是一个聚合函数,函数的参数不仅可以是字段名,也可以是其他任意表达式,该函数的作用是统计符合查询条件的记录中,函数指定的参数不为 NULL 的记录由多少条。…...
网站更新步骤/搜索引擎优化教材答案
条款1根据ParamType的特征,即一般形式的函数模板中param的型别饰词,将模板型别推导分为三种情况、而在采用auto进行变量声明中,型别饰词取代了ParamType,所以也存在三种情况: 情况1,型别饰词是指针或引用,…...
北京建设教育网站/百度贴吧首页
1062. 最简分数(20) 时间限制400 ms内存限制65536 kB乙级练习题解目录一个分数一般写成两个整数相除的形式:N/M,其中M不为0。最简分数是指分子和分母没有公约数的分数表示形式。 现给定两个不相等的正分数 N1/M1 和 N2/M2,要求你按从小到大的…...
中国建设在线平台官网/谷歌seo最好的公司
在下午学习JavaScript数组的过程中,多次用到了比值函数 比值函数function( a, b )是和JavaScript里的sort( )函数一起使用的,比值函数嵌套在sort( )函数的圆括号里 为什么要用比值函数? sort() 以字母顺序对数组进行升序排序而数字顺序sor…...
go语言视频网站开发/重庆关键词优化
Smartbi Mining平台是一个注重于实际生产应用的数据分析预测平台,它旨在为个人、团队和企业所做的决策提供预测。该平台不仅可为用户提供直观的流式建模、拖拽式操作和流程化、可视化的建模界面,还提供了大量的数据预处理操作。此外,它内置了…...
html5网站开发语言的有点/百度推广关键词质量度
接下来的连续几篇,我们要演练作一个描述通用的二维表,并演示该二维表的继承通途。先定义数据行的状态1/**//// <summary> 2 /// 数据行的状态 3 /// </summary>4publicenumRowState5{ 6 /**//// <summary> 7 /// …...
广告设计论文/什么是seo网站优化
智能电视除了安装常规软件之外,还可以安装一些我们比较常用的社交软件你知道吗?今天小智就跟大家分享一下智能电视如何安装微信。安装微信的第一步就是要在电视或盒子上安装第三方应用当贝市场,大家需要准备一个U盘,将U盘插入电脑…...