当前位置: 首页 > news >正文

【深度学习入门篇 ⑨】循环神经网络实战

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】

大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙·终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官、CSDN人工智能领域优质创作者 。


今天我们看一下用循环神经网络RNN的原理并且动手应用到案例。

3e012755cfd647aebdf70ff24536d38b.png 

循环神经网络

在普通的神经网络中,信息的传递是单向的,这种限制虽然使得网络变得更容易学习,但在一定程度上也减弱了神经网络模型的能力。特别是在很多现实任务中,网络的输出不仅和当前时刻的输入相关,也和其过去一段时间的输出相关。此外,普通网络难以处理时序数据,比如视频、语音、文本等,时序数据的长度一般是不固定的,而前馈神经网络要求输入和输出的维数都是固定的,不能任意改变。因此,当处理这一类和时序相关的问题时,就需要一种能力更强的模型。

循环神经网络 (RNN)是一类具有短期记忆能力的神经网络。在循环神经网络中,神经元不但可以接受其它神经元的信息,也可以接受自身的信息,形成具有环路的网络结构。  

ab119b30479c4d74bb10bf02ef0d9f34.png 

RNN比传统的神经网络多了一个循环圈,这个循环表示的就是在下一个时间步上会返回作为输入的一部分,我们把RNN在时间点上展开 :

6e2096802ad346c1836d1ede9370a9fe.png

在不同的时间步,RNN的输入都将与之前的时间状态有关 ,具体来说,每个时间步的RNN单元都会接收两个输入:当前时间步的外部输入和前一时间步(隐藏层)的输出状态。通过这种方式,RNN能够学习并理解数据中的长期依赖关系,使得它在处理文本生成、语音识别、时间序列预测等序列数据时表现尤为出色。

此外,RNN的隐藏状态(或称为内部状态)在每次迭代时都会更新,这种更新过程包含了当前输入和前一时间步状态的非线性组合,使得网络能够动态地调整其对序列中接下来内容的预测或理解。

d1ad2acff14b48458791021e8ce8eaa5.png

LSTM和GRU

传统的RNN在处理长序列数据时常常面临梯度消失或梯度爆炸的问题,这限制了其在处理长期依赖关系上的能力。为了克服这一局限性,LSTM(Long Short-Term Memory,长短期记忆网络)作为RNN的一种变体被引入。

LSTM是一种RNN特殊的类型,可以学习长期依赖信息。在很多问题上,LSTM都取得相当巨大的成功,并得到了广泛的应用。

48465d18371741739f23324e0f1f3e05.png

LSTM是通过一个叫做的结构实现,门可以选择让信息通过或者不通过。 这个门主要是通过sigmoid和点乘实现的 ;sigmoid 的取值范围是在(0,1)之间,如果接近0表示不让任何信息通过,如果接近1表示所有的信息都会通过。

  • 遗忘门通过sigmoid函数来决定哪些信息会被遗忘
  • 输入门决定哪些新的信息会被保留。

例如:

我昨天吃了拉面,今天我想吃炒饭,在这个句子中,通过遗忘门可以遗忘拉面,同时更新新的主语为炒饭。

输出门

我们需要决定什么信息会被输出,也是一样这个输出经过变换之后会通过sigmoid函数的结果来决定那些细胞状态会被输出。

  1. 前一次的输出和当前时间步的输入的组合结果通过sigmoid函数进行处理得到O_t

  2. 更新后的细胞状态C_t会经过tanh层的处理,把数据转化到(-1,1)的区间

  3. tanh处理后的结果和O_t进行相乘,把结果输出同时传到下一个LSTM的单元

8ca0b205bcfa44e18c3af5b4f7271880.png 

GRU

GRU是一种LSTM的变形版本, 它将遗忘和输入门组合成一个“更新门”。它还合并了单元状态和隐藏状态,并进行了一些其他更改,由于他的模型比标准LSTM模型简单,所以越来越受欢迎。

664e50357e604f918c707643ca15bc9c.png

b429639b6a994ec099f87d8adf609263.png 

双向LSTM

单向的 RNN,是根据前面的信息推出后面的,但有时候只看前面的词是不够的, 可能需要预测的词语和后面的内容也相关,那么此时需要一种机制,能够让模型不仅能够从前往后的具有记忆,还需要从后往前需要记忆。此时双向LSTM就可以帮助我们解决这个问题

f990226c2e3a4c9da262cc74ff2201e4.png 

由于是双向LSTM,所以每个方向的LSTM都会有一个输出,最终的输出会有2部分,所以往往需要concat的操作。

96f81f98d8e74dadaa1f4925a3406007.pngRNN实现文本情感分类 

torch.nn.LSTM(input_size,hidden_size,num_layers,batch_first,dropout,bidirectional)
  1. input_size:输入数据的形状,即embedding_dim

  2. hidden_size:隐藏层神经元的数量,即每一层有多少个LSTM单元

  3. num_layer :即RNN的中LSTM单元的层数

  4. batch_first:默认值为False,输入的数据需要[seq_len,batch,feature],如果为True,则为[batch,seq_len,feature]

  5. dropout:dropout的比例,默认值为0。dropout是一种训练过程中让部分参数随机失活的一种方式,能够提高训练速度,同时能够解决过拟合的问题。

  6. bidirectional:是否使用双向LSTM,默认是False

实例化LSTM对象之后,不仅需要传入数据,还需要前一次的h_0(前一次的隐藏状态)和c_0

LSTM的默认输出为output, (h_n, c_n)  

  1. output(seq_len, batch, num_directions * hidden_size)--->batch_first=False

  2. h_n:(num_layers * num_directions, batch, hidden_size)

  3. c_n: (num_layers * num_directions, batch, hidden_size)

 4b9843ea2e35484f86a90641afd0fff6.png

LSTM和GRU的使用注意点

  1. 第一次调用之前,需要初始化隐藏状态,如果不初始化,默认创建全为0的隐藏状态

  2. 往往会使用LSTM or GRU 的输出的最后一维的结果,来代表LSTM、GRU对文本处理的结果,其形状为[batch, num_directions*hidden_size]

使用LSTM完成文本情感分类

class IMDBLstmmodel(nn.Module):def __init__(self):super(IMDBLstmmodel,self).__init__()self.hidden_size = 64self.embedding_dim = 200self.num_layer = 2self.bidriectional = Trueself.bi_num = 2 if self.bidriectional else 1self.dropout = 0.5self.embedding = nn.Embedding(len(ws),self.embedding_dim,padding_idx=ws.PAD) #[N,300]self.lstm = nn.LSTM(self.embedding_dim,self.hidden_size,self.num_layer,bidirectional=True,dropout=self.dropout)self.fc = nn.Linear(self.hidden_size*self.bi_num,20)self.fc2 = nn.Linear(20,2)def forward(self, x):x = self.embedding(x)x = x.permute(1,0,2) h_0,c_0 = self.init_hidden_state(x.size(1))_,(h_n,c_n) = self.lstm(x,(h_0,c_0))out = torch.cat([h_n[-2, :, :], h_n[-1, :, :]], dim=-1)out = self.fc(out)out = F.relu(out)out = self.fc2(out)return F.log_softmax(out,dim=-1)def init_hidden_state(self,batch_size):h_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)c_0 = torch.rand(self.num_layer * self.bi_num, batch_size, self.hidden_size).to(device)return h_0,c_0

为了提高程序的运行速度,可以考虑把模型放在GPU上运行:

  1. device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

  2. model.to(device)

train_batch_size = 64
test_batch_size = 5000
imdb_model = IMDBLstmmodel().to(device) 
optimizer = optim.Adam(imdb_model.parameters())
criterion = nn.CrossEntropyLoss()def train(epoch):mode = Trueimdb_model.train(mode)train_dataloader =get_dataloader(mode,train_batch_size)for idx,(target,input,input_lenght) in enumerate(train_dataloader):target = target.to(device)input = input.to(device)optimizer.zero_grad()output = imdb_model(input)loss = F.nll_loss(output,target) loss.backward()optimizer.step()if idx %10 == 0:pred = torch.max(output, dim=-1, keepdim=False)[-1]acc = pred.eq(target.data).cpu().numpy().mean()*100.print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\t ACC: {:.6f}'.format(epoch, idx * len(input), len(train_dataloader.dataset),100. * idx / len(train_dataloader), loss.item(),acc))torch.save(imdb_model.state_dict(), "model/mnist_net.pkl")torch.save(optimizer.state_dict(), 'model/mnist_optimizer.pkl')def test():mode = Falseimdb_model.eval()test_dataloader = get_dataloader(mode, test_batch_size)with torch.no_grad():for idx,(target, input, input_lenght) in enumerate(test_dataloader):target = target.to(device)input = input.to(device)output = imdb_model(input)test_loss  = F.nll_loss(output, target,reduction="mean")pred = torch.max(output,dim=-1,keepdim=False)[-1]correct = pred.eq(target.data).sum()acc = 100. * pred.eq(target.data).cpu().numpy().mean()print('idx: {} Test set: Avg. loss: {:.4f}, Accuracy: {}/{} ({:.2f}%)\n'.format(idx,test_loss, correct, target.size(0),acc))if __name__ == "__main__":test()for i in range(10):train(i)test()

然后由大家写代码得到模型训练的最终输出,大家可以改变模型来观察不同的结果。

 

相关文章:

【深度学习入门篇 ⑨】循环神经网络实战

【🍊易编橙:一个帮助编程小伙伴少走弯路的终身成长社群🍊】 大家好,我是小森( ﹡ˆoˆ﹡ ) ! 易编橙终身成长社群创始团队嘉宾,橙似锦计划领衔成员、阿里云专家博主、腾讯云内容共创官…...

宝塔安装RabbitMq教程

需要放开15672端口,默认账号密码为guest/guest...

韦东山嵌入式linux系列-驱动进化之路:设备树的引入及简明教程

1 设备树的引入与作用 以 LED 驱动为例,如果你要更换LED所用的GPIO引脚,需要修改驱动程序源码、重新编译驱动、重新加载驱动。 在内核中,使用同一个芯片的板子,它们所用的外设资源不一样,比如A板用 GPIO A&#xff0c…...

长轮询(Long Polling)实现原理和java代码示例

长轮询(Long Polling)背景 长轮询是一种在Web开发中常用的技术,用于实现服务器与客户端之间的即时通信或近乎实时的数据交换。在传统的轮询(Polling)中,客户端会定期向服务器发送请求以检查是否有新数据。…...

OWASP 移动应用 2024 十大安全风险

1. OWASP 移动应用 2024 十大安全风险 开放全球应用程序安全项目 (OWASP) 是一个非营利性基金会,致力于提高软件的安全性。自 2014、2016 年两次发布了移动应用的十大风险后,今年再次发布2024版。这对移动应用软件的检查工具有着…...

Qt界面假死原因

创建一个播放器类,继承QLabel,在播放器类中起一个线程用ffmpeg取流解码,将解码后的图像保存到队列,在gui线程中调用update()刷新显示。 当ffmpeg打开视频流失败后调用update()将qlabel刷新为黑色,有一定概率会使得qla…...

python调用MATLAB出错matlab.engine.MatlabExecutionError无法调用MATLAB函数报错

python调用MATLAB出错matlab.engine.MatlabExecutionError无法调用MATLAB函数报错 说明(废话)解决方案MATLAB异常乱码python矩阵转MATLAB矩阵matlab.engine.MatlabExecutionError 说明(废话) python调用MATLAB,调用m文件中的函数,刚开始都没有问题&…...

[GXYCTF2019]Ping Ping Ping1

打开靶机 结合题目名称,考虑是命令注入,试试ls 结果应该就在flag.php。尝试构造命令注入载荷。 cat flag.php 可以看到过滤了空格,用 $IFS$1替换空格 还过滤了flag,我们用字符拼接的方式看能否绕过,ag;cat$IFS$1fla$a.php。注意这里用分号间隔…...

成为git砖家(1): author 和 committer 的区别

大家好,我是白鱼。一直对 git author 和 committer 不太了解, 今天通过 cherry-pick 的例子搞清楚了区别。 原理 例如我克隆了著名开源项目 spdlog 的源码, 根据某个历史 commit A 创建了分支, 然后 cherry-pick 了这个 commit …...

Lianwei 安全周报|2024.07.15

新的一周又开始了,以下是本周「Lianwei周报」,我们总结推荐了本周的政策/标准/指南最新动态、热点资讯和安全事件,保证大家不错过本周的每一个重点! 政策/标准/指南最新动态 01 《人工智能全球治理上海宣言》发布 我们强调共同促…...

Linux - 基础开发工具(yum、vim、gcc、g++、make/Makefile、git、gdb)

目录 Linux软件包管理器 - yum Linux下安装软件的方式 认识yum 查找软件包 安装软件 如何实现本地机器和云服务器之间的文件互传 卸载软件 Linux编辑器 - vim vim的基本概念 vim下各模式的切换 vim命令模式各命令汇总 vim底行模式各命令汇总 vim的简单配置 Linux编译器 - gc…...

Git使用介绍教程

Git使用介绍教程 小白第一次写博客,内容写的可能不是很详细,仅供参考,大家一起努力 gitee网址:https://gitee.com 大部分的开发团队都以 Git 作为自己的版本控制工具,需要对 Git 的使用非常的熟悉。这篇文章中本人整理了自己在开发过程中经常使用到的 Git 命令,方便在偶…...

STM32的TIM1之PWM互补输出_死区时间和刹车配置

STM32的TIM1之PWM互补输出_死区时间和刹车配置 1、定时器1的PWM输出通道 STM32高级定时器TIM1在用作PWM互补输出时,共有4个输出通道,其中有3个是互补输出通道,如下: 通道1:TIM1_CH1对应PA8引脚,TIM1_CH1N对应PB13引…...

C++复习的长文指南

C复习的长文指南 一、入门语法知识1.预备1.1 main函数1.2 注释1.3 变量1.3 常量1.4 关键字1.5 标识符明明规则 2. 数据类型2.1 整型2.1.1 sizeof关键字 2.2 实型(浮点型)2.3 字符型2.4 转义字符2.5 字符串型2.6 布尔类型bool2.7 数据的输入 3. 运算符3.1…...

深入了解MySQL文件排序

数据准备 CREATE TABLE user_info (id bigint(20) NOT NULL AUTO_INCREMENT COMMENT ID,name varchar(20) NOT NULL COMMENT 用户名,age tinyint(4) NOT NULL DEFAULT 0 COMMENT 年龄,sex tinyint(2) NOT NULL DEFAULT 0 COMMENT 状态 0:男 1: 女,creat…...

【JAVA基础】反射

编译期和运行期 首先大家应该先了解两个概念,编译期和运行期,编译期就是编译器帮你把源代码翻译成机器能识别的代码,比如编译器把java代码编译成jvm识别的字节码文件,而运行期指的是将可执行文件交给操作系统去执行, …...

贪心算法(2024/7/16)

1合并区间 以数组 intervals 表示若干个区间的集合,其中单个区间为 intervals[i] [starti, endi] 。请你合并所有重叠的区间,并返回 一个不重叠的区间数组,该数组需恰好覆盖输入中的所有区间 。 示例 1: 输入:inter…...

Python 在Word表格中插入、删除行或列

Word文档中的表格可以用于组织和展示数据。在实际应用过程中,有时为了调整表格的结构或适应不同的数据展示需求,我们可能会需要插入、删除行或列。以下提供了几种使用Python在Word表格中插入或删除行、列的方法供参考: 文章目录 Python 在Wo…...

Java二十三种设计模式-单例模式(1/23)

引言 在软件开发中,设计模式是一套被反复使用的、大家公认的、经过分类编目的代码设计经验的总结。单例模式作为其中一种创建型模式,确保一个类只有一个实例,并提供一个全局访问点。本文将深入探讨单例模式的概念、实现方式、使用场景以及潜…...

Unity动画系统(3)---融合树

6.1 动画系统基础2-6_哔哩哔哩_bilibili Animator类 using System.Collections; using System.Collections.Generic; using UnityEngine; public class EthanController : MonoBehaviour { private Animator ani; private void Awake() { ani GetComponen…...

sqlalchemy.orm中validates对两个字段进行联合校验

版本 sqlalchemy1.4.37 需求说明 有个场景,需要在orm中对两个字段进行联合校验,当 col1 xxx’时,对 col2的长度进行检查,超过限制(500)时,进行截断。 网上找了很久,没找到类似的…...

【ROS2】高级:解锁 Fast DDS 中间件的潜力 [社区贡献]

目标:本教程将展示如何在 ROS 2 中使用 Fast DDS 的扩展配置功能。 教程级别:高级 时间:20 分钟 目录 背景 先决条件在同一个节点中混合同步和异步发布 创建具有发布者的节点创建包含配置文件的 XML 文件执行发布者节点创建一个包含订阅者的节…...

VirtualBox虚拟机与主机互传文件的方法

建立共享文件夹 1.点击设置,点击共享文件夹,添加共享文件夹路径,保存 2.启动虚拟机,点击设备,点击安装增强功能,界面会出现一个光碟图标,点击光碟图标 3.打开光碟图标,出现一个目…...

访问控制系列

目录 一、基本概念 1.客体与主体 2.引用监控器与引用验证机制 3.安全策略与安全模型 4.安全内核 5.可信计算基 二、访问矩阵 三、访问控制策略 1.主体属性 2.客体属性 3.授权者组成 4.访问控制粒度 5.主体、客体状态 6.历史记录和上下文环境 7.数据内容 8.决策…...

【BUG】已解决:ModuleNotFoundError: No module named ‘cv2’

已解决:ModuleNotFoundError: No module named ‘cv2’ 欢迎来到英杰社区https://bbs.csdn.net/topics/617804998 欢迎来到我的主页,我是博主英杰,211科班出身,就职于医疗科技公司,热衷分享知识,武汉城市开…...

成都亚恒丰创教育科技有限公司 【插画猴子:笔尖下的灵动世界】

在浩瀚的艺术海洋中,每一种创作形式都是人类情感与想象力的独特表达。而插画,作为这一广阔领域中的璀璨明珠,以其独特的视觉语言和丰富的叙事能力,构建了一个又一个令人遐想连篇的梦幻空间。成都亚恒丰创教育科技有限公司 在众多插…...

gite+picgo+typora打造个人免费笔记软件

文章目录 1️⃣个人笔记软件2️⃣ 配置教程2.1 使用软件2.2 node 环境配置2.3 软件安装2.4 gite仓库设置2.5 配置picgo2.6 测试检验2.7 github教程 🎡 完结撒花 1️⃣个人笔记软件 最近换了环境,没有之前的生产环境舒适,写笔记也没有劲头&…...

只用 CSS 能玩出什么花样?

在前端开发领域,CSS 不仅仅是一种样式语言,它更像是一位多才多艺的艺术家,能够创造出令人惊叹的视觉效果。本文将带你探索 CSS 的无限可能,从基本形状到动态动画,从几何艺术到仿生设计,只用 CSS 就能玩出令…...

Linux C++ 056-设计模式之迭代器模式

Linux C 056-设计模式之迭代器模式 本节关键字:Linux、C、设计模式、迭代器模式 相关库函数: 概念 迭代器模式(Iterator Pattern)是一种常用的设计模式。迭代器模式提供一种方法顺序访问一个聚合对象中的各个元素,而…...

【Elasticsearch7.11】reindex问题

参考博文链接 问题:reindex 时出现如下问题 原因:数据量大,kibana的问题 解决方法: 将DSL命令转化成CURL命令在服务上执行 CURL命令 自动转化 curl -XPOST "http://IP:PORT/_reindex" -H Content-Type: application…...

网站开发需要的人员/手机优化什么意思

计算机应用基础与操作(中国铁道出版社2014年出版图书)语音编辑锁定讨论上传视频本词条缺少概述图,补充相关内容使词条更完整,还能快速升级,赶紧来编辑吧!《计算机应用基础与操作》由宋岐主编。本书主要介绍了计算机基础知识和操作…...

wordpress移动端导航菜单/互联网推广的好处

今天想说的是不要让自己企业的信息化穿上“皇帝的新装”,从这几年的工作来看,以安全为例,安全对于大家来说越来越受到重视,从国家政策及舆论层面每日剧增,信息安全或者说网络安全没有绝对安全%100安全,企业…...

购物网站成功案例/打开百度搜索

随着项目工程的增大,花在编译的时间会越来越长。为了提高编译效率,我们可以开启多线程来提高编译速度,充分利用多核机器的性能来优化编译。 1.windows下。目前windows下我们使用vs2012编译工程。vs可以通过以下方法打开多核编译,如…...

武汉专业制作网站/兰州网站seo

出现如下问题, 解决方法: 更换gensim的版本 pip install gensim 3.0.0...

wordpress文章 代码块/百度教育官网

Oracle 数据库中有 ROWNUM 这个功能,查询 list 后生成序号,很是方便,但 MySQL 是模拟了 Oracle 和 SQL Server 中的大部分功能,可自动生成序号却没有现成的函数或伪序列,很多情况下最后只能在后端代码或者前端代码中实…...

adsense用什么网站做/如何设计企业网站

题目链接:http://www.spoj.com/problems/PROOT/ 题目大意:给出一个整数p,p为素数,给出n个数x,判断x是否为p的原根。 解题思路:参考:http://www.apfloat.org/prim.html 大致思想:如果…...