跟着问题学15——GRU网络结构详解及代码实战
1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies)
前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的状态作为当前状态的输入。

但也有一些情况,我们需要更“长期”的上下文信息。比如预测最后一个单词“我在中国长大……我说一口流利的**。”“短期”的信息显示,下一个单词很可能是一种语言的名字,但如果我们想缩小范围,我们需要更长期语境——“我在中国长大”,但这个相关信息与需要它的点之间的距离完全有可能变得非常大。
不幸的是,随着这种距离的扩大,RNN无法学会连接这些信息。
从理论上讲,RNN绝对有能力处理这种“长期依赖性”。人们可以为他们精心选择参数,以解决这种形式的问题。遗憾的是,在实践中,RNN似乎无法学习它们。
幸运的是,GRU也没有这个问题!
2、GRU
什么是GRU
GRU(Gate Recurrent Unit)是循环神经网络(Recurrent Neural Network, RNN)的一种。和LSTM(Long-Short Term Memory)一样,也是为了解决长期记忆和反向传播中的梯度等问题而提出来的。
GRU和LSTM在很多情况下实际表现上相差无几,那么为什么我们要使用新人GRU(2014年提出)而不是相对经受了更多考验的LSTM(1997提出)呢。
用论文中的话说,相比LSTM,使用GRU能够达到相当的效果,并且相比之下更容易进行训练,能够很大程度上提高训练效率,因此很多时候会更倾向于使用GRU。
2.1总体结构框架
前面我们讲到,神经网络的各种结构都是为了挖掘变换数据特征的,所以下面我们也将结合数据特征的维度来对比介绍一下RNN&&LSTM的网络结构。
多层感知机(线性连接层)结构

从特征角度考虑:
输入特征:是n*1的单维向量(这也是为什么卷积神经网络在linear层前要把所有特征层展平),
隐藏层:然后根据隐藏层神经元的数量m将前层输入的特征用m*1的单维向量进行表示(对特征进行了提取变换,隐藏层的数据特征),单个隐藏层的神经元数量就代表网络参数,可以设置多个隐藏层;
输出特征:最终根据输出层的神经元数量y输出y*1的单维向量。
卷积神经网络结构

从特征角度考虑:
输入特征:是(batch)*channel*width*height的张量,
卷积层(等):然后根据输入通道channel的数量c_in和输出通道channel的数量c_out会有c_out*c_in*k*k个卷积核将前层输入的特征进行卷积(对特征进行了提取变换,k为卷积核尺寸),卷积核的大小和数量c_out*c_in*k*k就代表网络参数,可以设置多个卷积层;每一个channel都代表提取某方面的一种特征,该特征用width*height的二维张量表示,不同特征层之间是相互独立的(可以进行融合)。
输出特征:根据场景的需要设置后面的输出,可以是多分类的单维向量等等。
循环神经网络RNN系列结构

从特征角度考虑:
输入特征:是(batch)*T_seq*feature_size的张量(T_seq代表序列长度,注意不是batch_size).
我们来详细对比一下卷积神经网络的输入特征,
(batch)*T_seq*feature_size
(batch)*channel*width*height,
逐个进行分析,RNN系列的基础输入特征表示是feature_size*1的单维向量,比如一个单词的词向量,比如一个股票价格的影响因素向量,而CNN系列的基础输入特征是width*height的二维张量;
再来看一下序列T_seq和通道channel,RNN系列的序列T_seq是指一个连续的输入,比如一句话,一周的股票信息,而且这个序列是有时间先后顺序且互相关联的,而CNN系列的通道channel则是指不同角度的特征,比如彩色图像的RGB三色通道,过程中每个通道代表提取了每个方面的特征,不同通道之间是没有强相关性的,不过也可以进行融合。
最后就是batch,两者都有,在RNN系列,batch就是有多个句子,在CNN系列,就是有多张图片(每个图片可以有多个通道)
隐藏层:明确了输入特征之后,我们再来看看隐藏层代表着什么。隐藏层有T_seq个隐状态H_t(和输入序列长度相同),每个隐状态H_t类似于一个channel,对应着T_seq中的t时刻的输入特征;而每个隐状态H_t是用hidden_size*1的单维向量表示的,所以一个隐含层是T_seq*hidden_size的张量;对应时刻t的输入特征由feature_size*1变为hidden_size*1的向量。如图中所示,同一个隐含层不同时刻的参数W_ih和W_hh是共享的;隐藏层可以有num_layers个(图中只有1个)
以t时刻具体阐述一下:
X_t是t时刻的输入,是一个feature_size*1的向量
W_ih是输入层到隐藏层的权重矩阵
H_t是t时刻的隐藏层的值,是一个hidden_size*1的向量
W_hh是上一时刻的隐藏层的值传入到下一时刻的隐藏层时的权重矩阵
Ot是t时刻RNN网络的输出
从上右图中可以看出这个RNN网络在t时刻接受了输入Xt之后,隐藏层的值是St,输出的值是Ot。但是从结构图中我们可以发现St并不单单只是由Xt决定,还与t-1时刻的隐藏层的值St-1有关。
2.2 GRU的输入输出结构
GRU的输入输出结构与普通的RNN是一样的。有一个当前的输入xt,和上一个节点传递下来的隐状态(hidden state)ht-1 ,这个隐状态包含了之前节点的相关信息。结合xt和 ht-1,GRU会得到当前隐藏节点的输出yt 和传递给下一个节点的隐状态 ht。

图 GRU的输入输出结构
那么,GRU到底有什么特别之处呢?下面来对它的内部结构进行分析!
2.3 GRU的内部结构
不同于LSTM有3个门控,GRU仅有2个门控,
第一个是“重置门”(reset gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,用于将上一时刻隐状态ht-1重置为ht-1’,即ht-1’=ht-1*r。

再将ht-1’与输入xt进行拼接,再通过一个tanh激活函数来将数据放缩到-1~1的范围内。即得到如下图2-3所示的h’。

第一个是“更新门”(update gate),其根据当前时刻的输入xt和上一时刻的隐状态ht-1变换后经sigmoid函数输出介于0和1之间的数字,

最终的隐状态ht的更新表达式即为:

再次强调一下,门控信号(这里的z)的范围为0~1。门控信号越接近1,代表”记忆“下来的数据越多;而越接近0则代表”遗忘“的越多。
2.4 小结
GRU很聪明的一点就在于,使用了同一个门控z就同时可以进行遗忘和选择记忆(LSTM则要使用多个门控)。与LSTM相比,GRU内部少了一个”门控“,参数比LSTM少,但是却也能够达到与LSTM相当的功能。考虑到硬件的计算能力和时间成本,因而很多时候我们也就会选择更加”实用“的GRU。

3代码
import torch
import torch.nn as nndef my_gru(input,initial_states,w_ih,w_hh,b_ih,b_hh):h_prev=initial_statesbatch_size,T_seq,feature_size=input.shapehidden_size=w_ih.shape[0]//3batch_w_ih=w_ih.unsqueeze(0).tile(batch_size,1,1)batch_w_hh=w_hh.unsqueeze(0).tile(batch_size,1,1)output=torch.zeros(batch_size,T_seq,hidden_size)for t in range(T_seq):x=input[:,t,:]w_times_x=torch.bmm(batch_w_ih,x.unsqueeze(-1))w_times_x=w_times_x.squeeze(-1)# print(batch_w_hh.shape,h_prev.shape)# 计算两个tensor的矩阵乘法,torch.bmm(a,b),tensor a 的size为(b,h,w),tensor b的size为(b,w,m)# 也就是说两个tensor的第一维是相等的,然后第一个数组的第三维和第二个数组的第二维度要求一样,# 对于剩下的则不做要求,输出维度 (b,h,m)# batch_w_hh=batch_size*(3*hidden_size)*hidden_size# h_prev=batch_size*hidden_size*1# w_times_x=batch_size*hidden_size*1##squeeze,在给定维度(维度值必须为1)上压缩维度,负数代表从后开始数w_times_h_prev=torch.bmm(batch_w_hh,h_prev.unsqueeze(-1))w_times_h_prev=w_times_h_prev.squeeze(-1)r_t=torch.sigmoid(w_times_x[:,:hidden_size]+w_times_h_prev[:,:hidden_size]+b_ih[:hidden_size]+b_hh[:hidden_size])z_t=torch.sigmoid(w_times_x[:,hidden_size:2*hidden_size]+w_times_h_prev[:,hidden_size:2*hidden_size]+b_ih[hidden_size:2*hidden_size]+b_hh[hidden_size:2*hidden_size])n_t=torch.tanh(w_times_x[:,2*hidden_size:3*hidden_size]+w_times_h_prev[:,2*hidden_size:3*hidden_size]+b_ih[2*hidden_size:3*hidden_size]+b_hh[2*hidden_size:3*hidden_size])h_prev=(1-z_t)*n_t+z_t*h_prevoutput[:,t,:]=h_prevreturn output,h_previf __name__=="__main__":fc=nn.Linear(12,6)batch_size=2T_seq=5feature_size=4hidden_size=3# output_feature_size=3input=torch.randn(batch_size,T_seq,feature_size)h_prev=torch.randn(batch_size,hidden_size)gru_layer=nn.GRU(feature_size,hidden_size,batch_first=True)output,h_final=gru_layer(input,h_prev.unsqueeze(0))# for k,v in gru_layer.named_parameters():# print(k,v.shape)# print(output,h_final)my_output, my_h_final=my_gru(input,h_prev,gru_layer.weight_ih_l0,gru_layer.weight_hh_l0,gru_layer.bias_ih_l0,gru_layer.bias_hh_l0)# print(my_output, my_h_final)# print(torch.allclose(output,my_output))
参考资料
https://zhuanlan.zhihu.com/p/32481747
https://speech.ee.ntu.edu.tw/~tlkagk/courses/MLDS_2018/Lecture/Seq%20(v2).pdf
https://www.bilibili.com/video/BV1jm4y1Q7uh/?spm_id_from=333.788&vd_source=cf7630d31a6ad93edecfb6c5d361c659
相关文章:
跟着问题学15——GRU网络结构详解及代码实战
1 RNN的缺陷——长期依赖的问题 (The Problem of Long-Term Dependencies) 前面一节我们学习了RNN神经网络,它可以用来处理序列型的数据,比如一段文字,视频等等。RNN网络的基本单元如下图所示,可以将前面的…...
【uniapp】swiper切换时,v-for重新渲染页面导致文字在视觉上的拉扯问题
问题描述 先用v-for渲染了几个列表,但这几个列表是占同一个位置的,只是通过切换swiper来显示哪个列表显示,也就是为了优化页面切换时候,没有根据swiper的current再更新v-for的数据,但现在就有个问题,怎么隐…...
【Android】Compose初识
文章目录 1.Compose是什么2.Compose优势3.可组合函数4.布局5.配置布局6.Material Design7.列表与动画8.声明式UI9.组合10.重组 1.Compose是什么 Jetpack Compose是谷歌开发的一个现代的、声明式的UI工具包,用于构建原生的Android应用程序界面。它简化了创建复杂用户…...
前端工程化面试题(二)
前端模块化标准 CJS、ESM 和 UMD 的区别 CJS(CommonJS)、ESM(ESModule)和UMD(Universal Module Definition)是前端模块化标准的三种主要形式,它们各自有不同的特点和使用场景: CJS&…...
以攻击者的视角进行软件安全防护
1. 前言 孙子曰:知彼知己者,百战不殆;不知彼而知己,一胜一负,不知彼,不知己,每战必殆。 摘自《 孙子兵法谋攻篇 》在2500 年前的那个波澜壮阔的春秋战国时代,孙子兵法的这段话&…...
008.精读《Apache Paimon Docs - Table w/o PK》
文章目录 1. 引言2. 基本概念2.1 定义2.2 使用场景 3. 流式处理3.1 自动小文件合并3.2 流式查询 4. 数据更新4.1 查询4.2 更新4.3 分桶附加表 5 总结 1. 引言 通过本文,上篇我们了解了Apache Paimon 主键表,本期我们将继续学习附加表(Append…...
C#实时监控指定文件夹中的动态,并将文件夹中生成的新图片显示在界面上(相机采图,并且从本地拿图)
结果展示 此类原理适用于文件夹中自动生成图片,并提取最新生成的图片将其显示, 如果你是相机采图将其保存到本地,可以用这中方法可视化,并将检测的结果和图片匹配 理论上任何文件都是可以监视并显示的,我这里只是做了…...
使用SQLark分析达梦慢SQL执行计划的一次实践
最近刚参加完达梦的 DCP 培训与考试,正好业务系统有个 sql 查询较慢,就想着练练手。 在深入了解达梦的过程中,发现达梦新出了一款叫 SQLark 百灵连接的工具。 我首先去官网大致浏览了下。虽然 SQLark 在功能深度上不如 DM Manager 和 PL/SQ…...
【人工智能】用Python构建高效的自动化数据标注工具:从理论到实现
《Python OpenCV从菜鸟到高手》带你进入图像处理与计算机视觉的大门! 数据标注是构建高质量机器学习模型的关键环节,但其耗时耗力常成为制约因素。本篇文章将介绍如何用Python构建一个自动化数据标注工具,结合机器学习和NLP技术,帮助加速数据标注过程。我们将从需求分析入…...
Java --- 注解(Annotation)
一.什么是注解? 在Java中,注解(Annotation)是一种元数据(metadata),它为程序中的类、方法、字段等提供额外的描述信息。注解本身不直接改变程序的行为,但可以被编译器、开发工具、框…...
nodejs作为provider接入nacos
需求:公司产品一直是nodejs的后台,采用的eggjs框架,也不是最新版本,现有有需求需求将这些应用集成到微服务的注册中心,领导要求用java。 思路:用spring cloud gateway将需要暴露的接口url转发,…...
SpringBoot3+Micormeter监控应用指标
监控内容简介 SpringBoot3项目监控服务 ,可以使用Micormeter度量指标库,帮助我们监控应用程序的度量指标,并将其发送到Prometheus中并用Grafana展示。监控指标有系统负载、内存使用情况、应用程序的响应时间、吞吐量、错误率等。 micromete…...
Mybatis-plus 简单使用,mybatis-plus 分页模糊查询报500 的错
一、mybtis-plus配置下载 MyBatis-Plus 是一个 Mybatis 增强版工具,在 MyBatis 上扩充了其他功能没有改变其基本功能,为了简化开发提交效率而存在。 具体的介绍请参见官方文档。 官网文档地址:mybatis-plus 添加mybatis-plus依赖 <depe…...
2022 年 12 月青少年软编等考 C 语言三级真题解析
目录 T1. 鸡兔同笼思路分析T2. 猴子吃桃思路分析T3. 括号匹配问题T4. 上台阶思路分析T5. 田忌赛马T1. 鸡兔同笼 一个笼子里面关了鸡和兔子(鸡有 2 2 2 只脚,兔子有 4 4 4 只脚,没有例外)。已经知道了笼子里面脚的总数 a a a,问笼子里面至少有多少只动物,至多有多少只…...
webpack 题目
文章目录 webpack 中 chunkHash 和 contentHash 的区别loader和plugin的区别?webpack 处理 image 是用哪个 loader,限制 image 大小的是...;webpack 如何优化打包速度 webpack 中 chunkHash 和 contentHash 的区别 主要从四方面来讲一下区别&…...
【MySQL】视图详解
视图详解 一、视图的概念二、视图的常用操作2.1创建视图2.2查询视图2.3修改视图2.4 删除视图2.5向视图中插入数据 三、视图的检查选项3.1 cascaded(级联 )3.2 local(本地) 四、视图的作用 一、视图的概念 视图(View)是一种虚拟存…...
第一节:ORIN NX介绍与基于sdkmanager的镜像烧录(包含ubuntu文件系统/CUDA/OpenCV/cudnn/TensorRT)
ORIN NX技术参数 Orin NX版本对比 如上图所示,ORIN NX官方发布的版本有两个版本一个版本是70TOPS算力,DDR为8GB的版本低配版本,一个是100TOPS算法,DDR为16GB的高配版本。 Orin NX的外设框图 两个版本除了GPU和DDR的差距外,外设基本上没有区别,丰富的外设接口,后续开发…...
2024-12-04OpenCV视频处理基础
OpenCV视频处理基础 OpenCV的视频教学:https://www.bilibili.com/video/BV14P411D7MH 1-OpenCV视频捕获 在 OpenCV 中,cv2.VideoCapture() 是一个用于捕获视频流的类。它可以用来从摄像头捕获实时视频,或者从视频文件中读取帧。以下是如何使用…...
D89【python 接口自动化学习】- pytest基础用法
day89 pytest的setup,setdown详解 学习日期:20241205 学习目标:pytest基础用法 -- pytest的setup,setdown详解 学习笔记: setup、teardown详解 模块级 setup_module/teardown_module 开始于模块始末,生…...
七、docker registry
七、docker registry 7.1 了解Docker Registry 7.1.1 介绍 registry 用于保存docker 镜像,包括镜像的层次结构和元数据。启动容器时,docker daemon会试图从本地获取相关的镜像;本地镜像不存在时,其将从registry中下载该镜像并保…...
Python|GIF 解析与构建(5):手搓截屏和帧率控制
目录 Python|GIF 解析与构建(5):手搓截屏和帧率控制 一、引言 二、技术实现:手搓截屏模块 2.1 核心原理 2.2 代码解析:ScreenshotData类 2.2.1 截图函数:capture_screen 三、技术实现&…...
springboot 百货中心供应链管理系统小程序
一、前言 随着我国经济迅速发展,人们对手机的需求越来越大,各种手机软件也都在被广泛应用,但是对于手机进行数据信息管理,对于手机的各种软件也是备受用户的喜爱,百货中心供应链管理系统被用户普遍使用,为方…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据
微软PowerBI考试 PL300-在 Power BI 中清理、转换和加载数据 Power Query 具有大量专门帮助您清理和准备数据以供分析的功能。 您将了解如何简化复杂模型、更改数据类型、重命名对象和透视数据。 您还将了解如何分析列,以便知晓哪些列包含有价值的数据,…...
如何应对敏捷转型中的团队阻力
应对敏捷转型中的团队阻力需要明确沟通敏捷转型目的、提升团队参与感、提供充分的培训与支持、逐步推进敏捷实践、建立清晰的奖励和反馈机制。其中,明确沟通敏捷转型目的尤为关键,团队成员只有清晰理解转型背后的原因和利益,才能降低对变化的…...
WPF八大法则:告别模态窗口卡顿
⚙️ 核心问题:阻塞式模态窗口的缺陷 原始代码中ShowDialog()会阻塞UI线程,导致后续逻辑无法执行: var result modalWindow.ShowDialog(); // 线程阻塞 ProcessResult(result); // 必须等待窗口关闭根本问题:…...
0x-3-Oracle 23 ai-sqlcl 25.1 集成安装-配置和优化
是不是受够了安装了oracle database之后sqlplus的简陋,无法删除无法上下翻页的苦恼。 可以安装readline和rlwrap插件的话,配置.bahs_profile后也能解决上下翻页这些,但是很多生产环境无法安装rpm包。 oracle提供了sqlcl免费许可,…...
Windows 下端口占用排查与释放全攻略
Windows 下端口占用排查与释放全攻略 在开发和运维过程中,经常会遇到端口被占用的问题(如 8080、3306 等常用端口)。本文将详细介绍如何通过命令行和图形化界面快速定位并释放被占用的端口,帮助你高效解决此类问题。 一、准…...
C# WPF 左右布局实现学习笔记(1)
开发流程视频: https://www.youtube.com/watch?vCkHyDYeImjY&ab_channelC%23DesignPro Git源码: GitHub - CSharpDesignPro/Page-Navigation-using-MVVM: WPF - Page Navigation using MVVM 1. 新建工程 新建WPF应用(.NET Framework) 2.…...
