递归神经网络(RNN)及其预测和分类的Python和MATLAB实现
递归神经网络(Recurrent Neural Networks,RNN)是一种广泛应用于序列数据建模的深度学习模型。相比于传统的前馈神经网络,RNN具有记忆和上下文依赖性的能力,适用于处理具有时序关联性的数据,如文本、语音、时间序列等。RNN的应用领域包括语言建模、机器翻译、语音识别、生成文本等。
### RNN的原理
RNN的核心在于其递归结构,允许信息在网络内部进行循环传递。在传统前馈神经网络中,每一层的输出仅与当前输入有关,而RNN的隐藏层不仅接收输入数据,还接收上一个时间步的隐藏状态作为输入。这种设计使RNN可以保持对先前信息的记忆,并在处理序列数据时具有上下文依赖性。
具体来说,假设某时刻t的输入为$X_t$,隐藏状态为$H_t$,输出为$Y_t$,则RNN的计算公式可以表示为:
$$H_t = f(W_{hx}X_t + W_{hh}H_{t-1} + b_h)$$
$$Y_t = g(W_{hy}H_t + b_y)$$
其中,$f$和$g$为激活函数,$W_{hx}$、$W_{hh}$、$W_{hy}$分别为输入到隐藏层、隐藏层到隐藏层、隐藏层到输出层的权重矩阵,$b_h$、$b_y$为偏置。通过这种循环计算,RNN可以对不同时间步的输入进行处理,并保持记忆状态。
### RNN的训练
RNN的训练通常采用反向传播算法,通过最小化损失函数来更新网络参数。在序列分类任务中,可以使用交叉熵损失函数;在序列生成任务中,可以使用最大似然估计或强化学习方法。由于RNN存在梯度消失和梯度爆炸问题,常见的解决方法包括梯度裁剪、使用门控循环单元(GRU)和长短时记忆网络(LSTM)等结构。
### RNN的实现过程
1. 数据准备:准备序列数据,将其转换成适合RNN模型输入的格式。
2. 模型构建:定义RNN网络结构,包括输入层、隐藏层和输出层,并选择合适的激活函数。
3. 损失函数和优化器选择:选择适合任务的损失函数和优化器,如交叉熵损失函数和Adam优化器等。
4. 模型训练:使用训练数据对模型进行训练,通过反向传播算法更新参数,并监测模型在验证集上的性能。
5. 模型评估:使用测试数据评估模型性能,计算损失值和准确率等指标。
6. 模型应用:将训练好的RNN模型应用于实际任务中,如文本生成、情感分析等。
总之,RNN作为一种能够处理序列数据的深度学习模型,在自然语言处理、时间序列预测等领域发挥着重要作用。通过理解其原理和实现过程,可以更好地应用RNN解决实际问题。
以下是使用Python编写的递归神经网络(RNN)进行时间序列预测的示例代码:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
# 创建时间序列数据
def generate_time_series_data(num_data_points):
time = np.linspace(0, 30, num_data_points)
data = np.sin(time) + 0.1 * np.random.randn(num_data_points)
return data
data = generate_time_series_data(1000)
# 将时间序列数据转换为训练数据集
def create_dataset(data, time_steps):
X, y = [], []
for i in range(len(data) - time_steps):
X.append(data[i:i+time_steps])
y.append(data[i+time_steps])
return np.array(X), np.array(y)
X_train, y_train = create_dataset(data, time_steps=10)
# 构建RNN模型
model = tf.keras.Sequential([
tf.keras.layers.SimpleRNN(64, input_shape=(10, 1)),
tf.keras.layers.Dense(1)
])
# 编译模型
model.compile(optimizer='adam', loss='mean_squared_error')
# 拟合模型
model.fit(X_train, y_train, epochs=10, batch_size=32)
# 预测未来时间序列数据
future_data = data[-10:] # 最后10个数据点
for _ in range(30):
X_test = np.array([future_data[-10:]]) # 使用最后10个数据点进行预测
prediction = model.predict(X_test.reshape(1, 10, 1))
future_data = np.append(future_data, prediction)
# 可视化预测结果
plt.plot(np.arange(1000), data, label='Original Data')
plt.plot(np.arange(1000, 1030), future_data[10:], label='Predicted Data')
plt.legend()
plt.show()
以下是一个大致的MATLAB示例代码逻辑:
% 创建时间序列数据
time = linspace(0, 30, 1000);
data = sin(time) + 0.1 * randn(1, 1000);
% 创建训练数据集
XTrain = data(1:990);
YTrain = data(11:1000);
% 定义并训练RNN模型
layers = [sequenceInputLayer(10), lstmLayer(64), fullyConnectedLayer(1)];
options = trainingOptions('adam', 'MaxEpochs', 10, 'MiniBatchSize', 32);
net = trainNetwork(XTrain, YTrain, layers, options);
% 预测未来数据
future_data = data(end-9:end); % 最后10个数据点
for i = 1:30
XTest = future_data(end-9:end);
prediction = predict(net, XTest);
future_data = [future_data, prediction];
end
% 可视化结果
figure;
plot(1:1000, data, 'b', 'LineWidth', 1.5);
hold on;
plot(1001:1030, future_data(11:end), 'r', 'LineWidth', 1.5);
legend('Original Data', 'Predicted Data');
递归神经网络(RNN)进行分类任务的示例代码如下:
Python代码示例:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import SimpleRNN, Dense
# 加载MNIST数据集
(X_train, y_train), (X_test, y_test) = mnist.load_data()
# 数据预处理
X_train = X_train.reshape(-1, 28, 28) / 255.0
X_test = X_test.reshape(-1, 28, 28) / 255.0
# 构建RNN模型
model = Sequential([
SimpleRNN(64, input_shape=(28, 28)),
Dense(10, activation='softmax')
])
# 编译模型
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])
# 拟合模型
model.fit(X_train, y_train, epochs=5, batch_size=32)
# 评估模型
_, test_accuracy = model.evaluate(X_test, y_test)
print(f'Test accuracy: {test_accuracy}')
MATLAB代码示例:
% 加载MNIST数据集
[XTrain, YTrain] = digitTrainCellArrayData;
[XTest, YTest] = digitTestCellArrayData;
% 数据预处理
XTrain = reshape(XTrain, size(XTrain, 1), 1, size(XTrain, 2)) / 255.0;
XTest = reshape(XTest, size(XTest, 1), 1, size(XTest, 2)) / 255.0;
% 构建和训练RNN模型
layers = [sequenceInputLayer(1), lstmLayer(64), fullyConnectedLayer(10), classificationLayer];
options = trainingOptions('adam', 'MaxEpochs', 5, 'MiniBatchSize', 32);
net = trainNetwork(XTrain, categorical(YTrain), layers, options);
% 评估模型
YTest = classify(net, XTest);
accuracy = sum(YTest == YTest) / numel(YTest);
disp(['Test accuracy: ', num2str(accuracy)]);
相关文章:
递归神经网络(RNN)及其预测和分类的Python和MATLAB实现
递归神经网络(Recurrent Neural Networks,RNN)是一种广泛应用于序列数据建模的深度学习模型。相比于传统的前馈神经网络,RNN具有记忆和上下文依赖性的能力,适用于处理具有时序关联性的数据,如文本、语音、时…...
以flask为后端的博客项目——星云小窝
以flask为后端的博客项目——星云小窝 文章目录 以flask为后端的博客项目——星云小窝前言一、星云小窝项目——项目介绍(一)二、星云小窝项目——项目启动(二)三、星云小窝项目——项目结构(三)四、谈论一…...
CUDA编程02 - 数据并行介绍
一:概述 数据并行是指在数据集的不同部分上执行计算工作,这些计算工作彼此相互独立且可以并行执行。许多应用程序都具有丰富的数据并行性,使其能够改造成可并行执行的程序。因此,对于程序员来说,熟悉数据并行的概念以及使用并行编程语言来编写数据并行的代码是非常重要的。…...
Android 视频音量图标
attrs.xml <?xml version"1.0" encoding"utf-8"?> <resources><!--图标颜色--><attr name"ijkSolid" format"color|reference" /><!--喇叭底座宽度--><attr name"ijkCornerWidth" form…...
VScode 修改 Markdown Preview Enhanced 字体以及大纲编号
修改字体和背景颜色 按快捷键 Ctrl , 打开设置,搜索 markdown-preview-enhanced.previewTheme,选择一个黑色主题的css,如 github-dark.css. 修改自动编号和背景颜色 背景颜色 按 F1 或者 Ctrl Shift P,输入 Customize CSS…...
TCP的FIN报文可否携带数据
问题发现: 发现FTP-DATA数据传输完,TCP的挥手似乎只有两次 实际发现FTP-DATA报文中,TCP层flags中携带了FIN标志 piggyback FIN 问题转化为 TCP packet中如果有FIN flag,该报文还能携带data数据么? 答案是肯定的 RFC7…...
【GoF23种设计模式+简单工厂模式】
一、设计模式概述与类型 1.1、设计模式的一般定义: 设计模式(Design Pattern)是一套被反复使用、多数人知晓的、经过分类编目的、代码设计经验的总结,使用设计模式是为了可重用代码,让代码更容易被他人理解并且保证代…...
北醒单点激光雷达更改id和波特率以及Ubuntu20.04下CAN驱动
序言: 需要的硬件以及软件 1、USB-CAN分析仪使用顶配pro版本,带有支持ubuntu下的驱动包的,可以读取数据。 2、电源自备24V电源 3、单点激光雷达接线使用can线可以组网。 一、更改北醒单点激光雷达的id号和波特率 安装并运行USB-CAN分析仪自带…...
【线性代数】矩阵变换
一些特殊的矩阵 一,对角矩阵 1,什么是对角矩阵 表示将矩阵进行伸缩(反射)变换,仅沿坐标轴方向伸缩(反射)变换。 2,对角矩阵可分解为多个F1矩阵,如下: 二&a…...
聚焦智慧出行,TDengine 与路特斯科技再度携手
在全球汽车行业向电动化和智能化转型的过程中,智能驾驶技术正迅速成为行业的焦点。随着消费者对出行效率、安全性和便利性的需求不断提升,汽车制造商们需要在全球范围内实现低延迟、高质量的数据传输和处理,以提升用户体验。在此背景下&#…...
虚拟机迁移报错:虚拟机版本与主机“x.x.x.x”的版本不兼容
1.虚拟机在VCenter上从一个ESXi迁移到另一个ESXi上时报错:虚拟机版本与主机“x.x.x.x”的版本不兼容。 2.例如从10.0.128.13的ESXi上迁移到10.0.128.11的ESXi上。点击10.0.128.10上的任意一台虚拟机,查看虚拟机版本。 3.确认要迁移的虚拟机磁盘所在位…...
【教程】vscode添加powershell7终端
win10自带的 powershell 是1.0版本的,太老了,更换为powershell7后,在 vscode 的集成终端中没有显示本篇教程记录在vscode添加powershell7终端的过程 打开vscode终端配置 然后来到这个页面进行设置 查看 powershell7 的安装位置ÿ…...
如何乘上第四次工业革命的大船
如何乘上第四次工业革命的大船 第四次工业革命通常被认为是信息技术和数字化时代的到来,但具体影响哪些产业,以及它将如何演变和展开,仍然是一个广泛讨论的话题。 然而,已经可以看到一些领域可能受到第四次工业革命的深远影响,例如人工智能、物联网、大数据、生物技术、可…...
RKNN执行bash ./build-linux_RK3566_RK3568.sh 报错
目录 报错信息: 原因分析: 解决办法: 报错信息: CMake Error at /usr/share/cmake-3.22/Modules/CMakeDetermineCCompiler.cmake:49 (message): Could not find compiler set in environment variable CC: aarch64-linux-gnu-gcc. Call Stack (most recent call fir…...
Linux常用命令整理
本文将分享一些常用的Linux命令。根据功能的不同,大概分为以下几个方面,一是文件相关命令,二是进程相关命令,三是网络相关命令,四是磁盘相关命令,五是用户管理相关命令,六是系统命令。 1. 文件…...
python 闭包、装饰器
一、闭包: 1. 外部函数嵌套内部函数 2. 外部函数返回内部函数 3.内部函数可以访问外部函数局部变量 闭包(Closure)是指在一个函数内部定义的函数,并且内部函数可以访问外部函数的局部变量,即使外部函数已经执行…...
[pycharm]解决pycharm运行程序出现卡住scanning files to index索引的问题
有时候会出现索引问题,显示scanning files to index 解决方法: in pycharm, go to the "File" on the left top, then select "invalidate caches/restart...", and press "invalidate and restart". 然后等它自己重启…...
python每日学习11:numpy库的用法(下)
python每日学习11:numpy库的用法(下) 数组的拼接 名方法称说明concatenate连接沿现有轴的数组序列hstack水平堆叠序列中的数组(列方向)vstack竖直堆叠序列中的数组(行方向)concatenate函数用于沿指定轴连接相同形状的两…...
【Emacs有什么优点,用Emacs写程序真的比IDE更方便吗?】
🎥博主:程序员不想YY啊 💫CSDN优质创作者,CSDN实力新星,CSDN博客专家 🤗点赞🎈收藏⭐再看💫养成习惯 ✨希望本文对您有所裨益,如有不足之处,欢迎在评论区提出…...
6、基于Fabirc 2.X 通用电子存证系统部署
evidence 将GOPATH设置为/root/go,拉取项目: cd $GOPATH/src && git clone https://gitee.com/henan-minghua_0/evidence.git 在/etc/hosts中添加: 127.0.0.1 orderer.example.com 127.0.0.1 peer0.org1.example.com 127.0.0.1 peer1.org…...
uniapp 对接腾讯云IM群组成员管理(增删改查)
UniApp 实战:腾讯云IM群组成员管理(增删改查) 一、前言 在社交类App开发中,群组成员管理是核心功能之一。本文将基于UniApp框架,结合腾讯云IM SDK,详细讲解如何实现群组成员的增删改查全流程。 权限校验…...
树莓派超全系列教程文档--(61)树莓派摄像头高级使用方法
树莓派摄像头高级使用方法 配置通过调谐文件来调整相机行为 使用多个摄像头安装 libcam 和 rpicam-apps依赖关系开发包 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 配置 大多数用例自动工作,无需更改相机配置。但是,一…...
k8s从入门到放弃之Ingress七层负载
k8s从入门到放弃之Ingress七层负载 在Kubernetes(简称K8s)中,Ingress是一个API对象,它允许你定义如何从集群外部访问集群内部的服务。Ingress可以提供负载均衡、SSL终结和基于名称的虚拟主机等功能。通过Ingress,你可…...
剑指offer20_链表中环的入口节点
链表中环的入口节点 给定一个链表,若其中包含环,则输出环的入口节点。 若其中不包含环,则输出null。 数据范围 节点 val 值取值范围 [ 1 , 1000 ] [1,1000] [1,1000]。 节点 val 值各不相同。 链表长度 [ 0 , 500 ] [0,500] [0,500]。 …...
1.3 VSCode安装与环境配置
进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...
[10-3]软件I2C读写MPU6050 江协科技学习笔记(16个知识点)
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16...
ios苹果系统,js 滑动屏幕、锚定无效
现象:window.addEventListener监听touch无效,划不动屏幕,但是代码逻辑都有执行到。 scrollIntoView也无效。 原因:这是因为 iOS 的触摸事件处理机制和 touch-action: none 的设置有关。ios有太多得交互动作,从而会影响…...
大语言模型(LLM)中的KV缓存压缩与动态稀疏注意力机制设计
随着大语言模型(LLM)参数规模的增长,推理阶段的内存占用和计算复杂度成为核心挑战。传统注意力机制的计算复杂度随序列长度呈二次方增长,而KV缓存的内存消耗可能高达数十GB(例如Llama2-7B处理100K token时需50GB内存&a…...
laravel8+vue3.0+element-plus搭建方法
创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...
【笔记】WSL 中 Rust 安装与测试完整记录
#工作记录 WSL 中 Rust 安装与测试完整记录 1. 运行环境 系统:Ubuntu 24.04 LTS (WSL2)架构:x86_64 (GNU/Linux)Rust 版本:rustc 1.87.0 (2025-05-09)Cargo 版本:cargo 1.87.0 (2025-05-06) 2. 安装 Rust 2.1 使用 Rust 官方安…...
