【深度学习入门篇 ⑨】循环神经网络实战
【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】
大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。
今天我们看一下用循环神经网络RNN的原理并且动手应用到案例。
循环神经网络
在普通的神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力。特别是在很多现实任务中,网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关。此外,普通网络难以处理时序数据,比如视频、语音、文本等,时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变。因此,当处理这一类和时序相关的问题时,就需要一种能力更强的模型。
循环神经网络 (RNN)是一类具有短期记忆能力的神经网络。在循环神经网络中,神经元不但可以接受其它神经元的信息,也可以接受自身的信息,形成具有环路的网络结构。
RNN比传统的神经网络多了一个循环圈,这个循环表示的就是在下一个时间步上会返回作为输入的一部分,我们把RNN在时间点上展开 :
在不同的时间步,RNN的输入都将与之前的时间状态有关 ,具体来说,每个时间步的RNN单元都会接收两个输入:当前时间步的外部输入和前一时间步(隐藏层)的输出状态。通过这种方式,RNN能够学习并理解数据中的长期依赖关系,使得它在处理文本生成、语音识别、时间序列预测等序列数据时表现尤为出色。
此外,RNN的隐藏状态(或称为内部状态)在每次迭代时都会更新,这种更新过程包含了当前输入和前一时间步状态的非线性组合,使得网络能够动态地调整其对序列中接下来内容的预测或理解。
LSTM和GRU
传统的RNN在处理长序列数据时常常面临梯度消失或梯度爆炸的问题,这限制了其在处理长期依赖关系上的能力。为了克服这一局限性,LSTM(Long Short-Term Memory,长短期记忆网络)作为RNN的一种变体被引入。
LSTM是一种RNN特殊的类型,可以学习长期依赖信息。在很多问题上,LSTM都取得相当巨大的成功,并得到了广泛的应用。
LSTM是通过一个叫做门
的结构实现,门可以选择让信息通过或者不通过。 这个门主要是通过sigmoid和点乘实现的 ;sigmoid 的取值范围是在(0,1)之间,如果接近0表示不让任何信息通过,如果接近1表示所有的信息都会通过。
- 遗忘门通过sigmoid函数来决定哪些信息会被遗忘
-
输入门决定哪些新的信息会被保留。
例如:
我昨天吃了拉面,今天我想吃炒饭
,在这个句子中,通过遗忘门可以遗忘拉面
,同时更新新的主语为炒饭。
输出门
我们需要决定什么信息会被输出,也是一样这个输出经过变换之后会通过sigmoid函数的结果来决定那些细胞状态会被输出。
-
前一次的输出和当前时间步的输入的组合结果通过sigmoid函数进行处理得到O_t
-
更新后的细胞状态C_t会经过tanh层的处理,把数据转化到(-1,1)的区间
-
tanh处理后的结果和O_t进行相乘,把结果输出同时传到下一个LSTM的单元
GRU
GRU是一种LSTM的变形版本, 它将遗忘和输入门组合成一个“更新门”。它还合并了单元状态和隐藏状态,并进行了一些其他更改,由于他的模型比标准LSTM模型简单,所以越来越受欢迎。
双向LSTM
单向的 RNN,是根据前面的信息推出后面的,但有时候只看前面的词是不够的, 可能需要预测的词语和后面的内容也相关,那么此时需要一种机制,能够让模型不仅能够从前往后的具有记忆,还需要从后往前需要记忆。此时双向LSTM就可以帮助我们解决这个问题
由于是双向LSTM,所以每个方向的LSTM都会有一个输出,最终的输出会有2部分,所以往往需要concat的操作。
RNN实现文本情感分类
torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first,dropout,bidirectional)
-
input_size
:输入数据的形状,即embedding_dim -
hidden_size
:隐藏层神经元的数量,即每一层有多少个LSTM单元 -
num_layer
:即RNN的中LSTM单元的层数 -
batch_first
:默认值为False,输入的数据需要[seq_len,batch,feature]
,如果为True,则为[batch,seq_len,feature]
-
dropout
:dropout的比例,默认值为0。dropout是一种训练过程中让部分参数随机失活的一种方式,能够提高训练速度,同时能够解决过拟合的问题。 -
bidirectional
:是否使用双向LSTM,默认是False
实例化LSTM对象之后,不仅需要传入数据,还需要前一次的h_0(前一次的隐藏状态)和c_0
LSTM的默认输出为output, (h_n, c_n)
-
output
:(seq_len, batch, num_directions * hidden_size)
--->batch_first=False -
h_n
:(num_layers * num_directions, batch, hidden_size)
-
c_n
:(num_layers * num_directions, batch, hidden_size)
LSTM和GRU的使用注意点
-
第一次调用之前,需要初始化隐藏状态,如果不初始化,默认创建全为0的隐藏状态
-
往往会使用LSTM or GRU 的输出的最后一维的结果,来代表LSTM、GRU对文本处理的结果,其形状为
[batch, num_directions*hidden_size]
。
使用LSTM完成文本情感分类
class IMDBLstmmodel(nn.Module):def __init__(self):super(IMDBLstmmodel,self).__init__()self.hidden_size = 64self.embedding_dim = 200self.num_layer = 2self.bidriectional = Trueself.bi_num = 2 if self.bidriectional else 1self.dropout = 0.5self.embedding = nn.Embedding(len(ws),self.embedding_dim,padding_idx=ws.PAD) #[N,300]self.lstm = nn.LSTM(self.embedding_dim,self.hidden_size,self.num_layer,bidirectional=True,dropout=self.dropout)self.fc = nn.Linear(self.hidden_size*self.bi_num,20)self.fc2 = nn.Linear(20,2)def forward(self, x):x = self.embedding(x)x = x.permute(1,0,2) h_0,c_0 = self.init_hidden_state(x.size(1))_,(h_n,c_n) = self.lstm(x,(h_0,c_0))out = torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=-1)out = self.fc(out)out = F.relu(out)out = self.fc2(out)return F.log_softmax(out,dim=-1)def init_hidden_state(self,batch_size):h_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)c_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)return h_0,c_0
为了提高程序的运行速度,可以考虑把模型放在GPU上运行:
-
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-
model.to(device)
train_batch_size = 64
test_batch_size = 5000
imdb_model = IMDBLstmmodel().to(device)
optimizer = optim.Adam(imdb_model.parameters())
criterion = nn.CrossEntropyLoss()def train(epoch):mode = Trueimdb_model.train(mode)train_dataloader =get_dataloader(mode,train_batch_size)for idx,(target,input,input_lenght) in enumerate(train_dataloader):target = target.to(device)input = input.to(device)optimizer.zero_grad()output = imdb_model(input)loss = F.nll_loss(output,target) loss.backward()optimizer.step()if idx %10 == 0:pred = torch.max(output, dim=-1, keepdim=False)[-1]acc = pred.eq(target.data).cpu().numpy().mean()*100.print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t ACC: {:.6f}'.format(epoch, idx * len(input), len(train_dataloader.dataset),100. * idx / len(train_dataloader), loss.item(),acc))torch.save(imdb_model.state_dict(), "model/mnist_net.pkl")torch.save(optimizer.state_dict(), 'model/mnist_optimizer.pkl')def test():mode = Falseimdb_model.eval()test_dataloader = get_dataloader(mode, test_batch_size)with torch.no_grad():for idx,(target, input, input_lenght) in enumerate(test_dataloader):target = target.to(device)input = input.to(device)output = imdb_model(input)test_loss = F.nll_loss(output, target,reduction="mean")pred = torch.max(output,dim=-1,keepdim=False)[-1]correct = pred.eq(target.data).sum()acc = 100. * pred.eq(target.data).cpu().numpy().mean()print('idx: {} Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(idx,test_loss, correct, target.size(0),acc))if __name__ == "__main__":test()for i in range(10):train(i)test()
然后由大家写代码得到模型训练的最终输出,大家可以改变模型来观察不同的结果。
相关文章:
【深度学习入门篇 ⑨】循环神经网络实战
【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】 大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官…...
宝塔安装RabbitMq教程
需要放开15672端口,默认账号密码为guest/guest...
韦东山嵌入式linux系列-驱动进化之路:设备树的引入及简明教程
1 设备树的引入与作用 以 LED 驱动为例,如果你要更换LED所用的GPIO引脚,需要修改驱动程序源码、重新编译驱动、重新加载驱动。 在内核中,使用同一个芯片的板子,它们所用的外设资源不一样,比如A板用 GPIO A,…...
长轮询(Long Polling)实现原理和java代码示例
长轮询(Long Polling)背景 长轮询是一种在Web开发中常用的技术,用于实现服务器与客户端之间的即时通信或近乎实时的数据交换。在传统的轮询(Polling)中,客户端会定期向服务器发送请求以检查是否有新数据。…...
OWASP 移动应用 2024 十大安全风险
1. OWASP 移动应用 2024 十大安全风险 开放全球应用程序安全项目 (OWASP) 是一个非营利性基金会,致力于提高软件的安全性。自 2014、2016 年两次发布了移动应用的十大风险后,今年再次发布2024版。这对移动应用软件的检查工具有着…...
Qt界面假死原因
创建一个播放器类,继承QLabel,在播放器类中起一个线程用ffmpeg取流解码,将解码后的图像保存到队列,在gui线程中调用update()刷新显示。 当ffmpeg打开视频流失败后调用update()将qlabel刷新为黑色,有一定概率会使得qla…...
python调用MATLAB出错matlab.engine.MatlabExecutionError无法调用MATLAB函数报错
python调用MATLAB出错matlab.engine.MatlabExecutionError无法调用MATLAB函数报错 说明(废话)解决方案MATLAB异常乱码python矩阵转MATLAB矩阵matlab.engine.MatlabExecutionError 说明(废话) python调用MATLAB,调用m文件中的函数,刚开始都没有问题&…...
[GXYCTF2019]Ping Ping Ping1
打开靶机 结合题目名称,考虑是命令注入,试试ls 结果应该就在flag.php。尝试构造命令注入载荷。 cat flag.php 可以看到过滤了空格,用 $IFS$1替换空格 还过滤了flag,我们用字符拼接的方式看能否绕过,ag;cat$IFS$1fla$a.php。注意这里用分号间隔…...
成为git砖家(1): author 和 committer 的区别
大家好,我是白鱼。一直对 git author 和 committer 不太了解, 今天通过 cherry-pick 的例子搞清楚了区别。 原理 例如我克隆了著名开源项目 spdlog 的源码, 根据某个历史 commit A 创建了分支, 然后 cherry-pick 了这个 commit …...
Lianwei 安全周报|2024.07.15
新的一周又开始了,以下是本周「Lianwei周报」,我们总结推荐了本周的政策/标准/指南最新动态、热点资讯和安全事件,保证大家不错过本周的每一个重点! 政策/标准/指南最新动态 01 《人工智能全球治理上海宣言》发布 我们强调共同促…...
Linux - 基础开发工具(yum、vim、gcc、g++、make/Makefile、git、gdb)
目录 Linux软件包管理器 - yum Linux下安装软件的方式 认识yum 查找软件包 安装软件 如何实现本地机器和云服务器之间的文件互传 卸载软件 Linux编辑器 - vim vim的基本概念 vim下各模式的切换 vim命令模式各命令汇总 vim底行模式各命令汇总 vim的简单配置 Linux编译器 - gc…...
Git使用介绍教程
Git使用介绍教程 小白第一次写博客,内容写的可能不是很详细,仅供参考,大家一起努力 gitee网址:https://gitee.com 大部分的开发团队都以 Git 作为自己的版本控制工具,需要对 Git 的使用非常的熟悉。这篇文章中本人整理了自己在开发过程中经常使用到的 Git 命令,方便在偶…...
STM32的TIM1之PWM互补输出_死区时间和刹车配置
STM32的TIM1之PWM互补输出_死区时间和刹车配置 1、定时器1的PWM输出通道 STM32高级定时器TIM1在用作PWM互补输出时,共有4个输出通道,其中有3个是互补输出通道,如下: 通道1:TIM1_CH1对应PA8引脚,TIM1_CH1N对应PB13引…...
C++复习的长文指南
C复习的长文指南 一、入门语法知识1.预备1.1 main函数1.2 注释1.3 变量1.3 常量1.4 关键字1.5 标识符明明规则 2. 数据类型2.1 整型2.1.1 sizeof关键字 2.2 实型(浮点型)2.3 字符型2.4 转义字符2.5 字符串型2.6 布尔类型bool2.7 数据的输入 3. 运算符3.1…...
深入了解MySQL文件排序
数据准备 CREATE TABLE user_info (id bigint(20) NOT NULL AUTO_INCREMENT COMMENT ID,name varchar(20) NOT NULL COMMENT 用户名,age tinyint(4) NOT NULL DEFAULT 0 COMMENT 年龄,sex tinyint(2) NOT NULL DEFAULT 0 COMMENT 状态 0:男 1: 女,creat…...
【JAVA基础】反射
编译期和运行期 首先大家应该先了解两个概念,编译期和运行期,编译期就是编译器帮你把源代码翻译成机器能识别的代码,比如编译器把java代码编译成jvm识别的字节码文件,而运行期指的是将可执行文件交给操作系统去执行, …...
贪心算法(2024/7/16)
1合并区间 以数组 intervals 表示若干个区间的集合,其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间,并返回 一个不重叠的区间数组,该数组需恰好覆盖输入中的所有区间 。 示例 1: 输入:inter…...
Python 在Word表格中插入、删除行或列
Word文档中的表格可以用于组织和展示数据。在实际应用过程中,有时为了调整表格的结构或适应不同的数据展示需求,我们可能会需要插入、删除行或列。以下提供了几种使用Python在Word表格中插入或删除行、列的方法供参考: 文章目录 Python 在Wo…...
Java二十三种设计模式-单例模式(1/23)
引言 在软件开发中,设计模式是一套被反复使用的、大家公认的、经过分类编目的代码设计经验的总结。单例模式作为其中一种创建型模式,确保一个类只有一个实例,并提供一个全局访问点。本文将深入探讨单例模式的概念、实现方式、使用场景以及潜…...
Unity动画系统(3)---融合树
6.1 动画系统基础2-6_哔哩哔哩_bilibili Animator类 using System.Collections; using System.Collections.Generic; using UnityEngine; public class EthanController : MonoBehaviour { private Animator ani; private void Awake() { ani GetComponen…...
sqlalchemy.orm中validates对两个字段进行联合校验
版本 sqlalchemy1.4.37 需求说明 有个场景,需要在orm中对两个字段进行联合校验,当 col1 xxx’时,对 col2的长度进行检查,超过限制(500)时,进行截断。 网上找了很久,没找到类似的…...
【ROS2】高级:解锁 Fast DDS 中间件的潜力 [社区贡献]
目标:本教程将展示如何在 ROS 2 中使用 Fast DDS 的扩展配置功能。 教程级别:高级 时间:20 分钟 目录 背景 先决条件在同一个节点中混合同步和异步发布 创建具有发布者的节点创建包含配置文件的 XML 文件执行发布者节点创建一个包含订阅者的节…...
VirtualBox虚拟机与主机互传文件的方法
建立共享文件夹 1.点击设置,点击共享文件夹,添加共享文件夹路径,保存 2.启动虚拟机,点击设备,点击安装增强功能,界面会出现一个光碟图标,点击光碟图标 3.打开光碟图标,出现一个目…...
访问控制系列
目录 一、基本概念 1.客体与主体 2.引用监控器与引用验证机制 3.安全策略与安全模型 4.安全内核 5.可信计算基 二、访问矩阵 三、访问控制策略 1.主体属性 2.客体属性 3.授权者组成 4.访问控制粒度 5.主体、客体状态 6.历史记录和上下文环境 7.数据内容 8.决策…...
【BUG】已解决:ModuleNotFoundError: No module named ‘cv2’
已解决:ModuleNotFoundError: No module named ‘cv2’ 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市开…...
成都亚恒丰创教育科技有限公司 【插画猴子:笔尖下的灵动世界】
在浩瀚的艺术海洋中,每一种创作形式都是人类情感与想象力的独特表达。而插画,作为这一广阔领域中的璀璨明珠,以其独特的视觉语言和丰富的叙事能力,构建了一个又一个令人遐想连篇的梦幻空间。成都亚恒丰创教育科技有限公司 在众多插…...
gite+picgo+typora打造个人免费笔记软件
文章目录 1️⃣个人笔记软件2️⃣ 配置教程2.1 使用软件2.2 node 环境配置2.3 软件安装2.4 gite仓库设置2.5 配置picgo2.6 测试检验2.7 github教程 🎡 完结撒花 1️⃣个人笔记软件 最近换了环境,没有之前的生产环境舒适,写笔记也没有劲头&…...
只用 CSS 能玩出什么花样?
在前端开发领域,CSS 不仅仅是一种样式语言,它更像是一位多才多艺的艺术家,能够创造出令人惊叹的视觉效果。本文将带你探索 CSS 的无限可能,从基本形状到动态动画,从几何艺术到仿生设计,只用 CSS 就能玩出令…...
Linux C++ 056-设计模式之迭代器模式
Linux C 056-设计模式之迭代器模式 本节关键字:Linux、C、设计模式、迭代器模式 相关库函数: 概念 迭代器模式(Iterator Pattern)是一种常用的设计模式。迭代器模式提供一种方法顺序访问一个聚合对象中的各个元素,而…...
【Elasticsearch7.11】reindex问题
参考博文链接 问题:reindex 时出现如下问题 原因:数据量大,kibana的问题 解决方法: 将DSL命令转化成CURL命令在服务上执行 CURL命令 自动转化 curl -XPOST "http://IP:PORT/_reindex" -H Content-Type: application…...
网站开发需要的人员/手机优化什么意思
计算机应用基础与操作(中国铁道出版社2014年出版图书)语音编辑锁定讨论上传视频本词条缺少概述图,补充相关内容使词条更完整,还能快速升级,赶紧来编辑吧!《计算机应用基础与操作》由宋岐主编。本书主要介绍了计算机基础知识和操作…...
wordpress移动端导航菜单/互联网推广的好处
今天想说的是不要让自己企业的信息化穿上“皇帝的新装”,从这几年的工作来看,以安全为例,安全对于大家来说越来越受到重视,从国家政策及舆论层面每日剧增,信息安全或者说网络安全没有绝对安全%100安全,企业…...
购物网站成功案例/打开百度搜索
随着项目工程的增大,花在编译的时间会越来越长。为了提高编译效率,我们可以开启多线程来提高编译速度,充分利用多核机器的性能来优化编译。 1.windows下。目前windows下我们使用vs2012编译工程。vs可以通过以下方法打开多核编译,如…...
武汉专业制作网站/兰州网站seo
出现如下问题, 解决方法: 更换gensim的版本 pip install gensim 3.0.0...
wordpress文章 代码块/百度教育官网
Oracle 数据库中有 ROWNUM 这个功能,查询 list 后生成序号,很是方便,但 MySQL 是模拟了 Oracle 和 SQL Server 中的大部分功能,可自动生成序号却没有现成的函数或伪序列,很多情况下最后只能在后端代码或者前端代码中实…...
adsense用什么网站做/如何设计企业网站
题目链接:http://www.spoj.com/problems/PROOT/ 题目大意:给出一个整数p,p为素数,给出n个数x,判断x是否为p的原根。 解题思路:参考:http://www.apfloat.org/prim.html 大致思想:如果…...