【时间序列预测】基于PyTorch实现CNN_BiLSTM算法
文章目录
- 1. CNN与BiLSTM
- 2. 完整代码实现
- 3. 代码结构解读
- 3.1 CNN Layer
- 3.2 BiLSTM Layer
- 3.3 Output Layer
- 3.4 forward Layer
- 4. 应用场景
- 5. 总结
本文将详细介绍如何使用Pytorch实现一个结合
卷积神经网络(CNN)
和双向长短期记忆网络(BiLSTM)
的混合模型—CNN_BiLSTM
。这种网络架构结合了CNN在提取局部特征方面的优势和BiLSTM在建模序列数据时的长期依赖关系的能力,特别适用于时序数据的预测任务,如时间序列分析、风速预测、股票预测等。
1. CNN与BiLSTM
CNN
主要通过卷积操作对输入数据进行特征提取,适合于处理局部结构化的特征(如图像数据、时间序列数据中的局部模式)。BiLSTM
则是基于LSTM的变种,它通过双向遍历序列,可以同时捕捉过去和未来的信息,使其在处理时间序列数据时非常有效。- 在本例中,CNN负责提取时间序列数据的局部特征,而BiLSTM则进一步捕捉数据中的时序依赖关系,最终通过全连接层输出预测结果。
2. 完整代码实现
"""
CNN_BiLSTM Network
"""
from torch import nnclass CNN_BiLSTM(nn.Module):r"""CNN_BiLSTMArgs:cnn_in_channels : CNN输入通道数, if in.shape=[64,7,18] value=7bilstm_input_size : bilstm输入大小, if in.shape=[64,7,18] value=18output_size : 期望网络输出大小cnn_out_channels: CNN层输出通道数cnn_kernal_size : CNN层卷积核大小maxpool_kernal_size: MaxPool Layer kernal_sizebilstm_hidden_size: BiLSTM Layer hidden_dimbilstm_num_layers: BiLSTM Layer num_layersdropout: dropout防止过拟合, 取值(0,1)bilstm_proj_size: BiLSTM Layer proj_sizeExample:>>> import torch>>> input = torch.randn([64,7,18])>>> model = CNN_BiLSTM(7, 18,18)>>> out = model(input)"""def __init__(self,cnn_in_channels,bilstm_input_size,output_size,cnn_out_channels=32,cnn_kernal_size=3,maxpool_kernal_size=3,bilstm_hidden_size=128,bilstm_num_layers=4,dropout = 0.05,bilstm_proj_size = 0):super().__init__()# CNN Layerself.conv1d = nn.Conv1d(in_channels=cnn_in_channels, out_channels=cnn_out_channels, kernel_size=cnn_kernal_size, padding="same")self.relu = nn.ReLU()self.maxpool = nn.MaxPool1d(kernel_size= maxpool_kernal_size)# BiLSTM Layerself.bilstm = nn.LSTM(input_size = int(bilstm_input_size/maxpool_kernal_size),hidden_size = bilstm_hidden_size,num_layers = bilstm_num_layers,batch_first = True,dropout = dropout,bidirectional = True,proj_size = bilstm_proj_size)# output Layerself.fc = nn.Linear(bilstm_hidden_size*2,output_size)def forward(self, x):x = self.conv1d(x)x = self.relu(x)x = self.maxpool(x)bilstm_out,_ = self.bilstm(x)x = self.fc(bilstm_out[:, -1, :])return x
3. 代码结构解读
3.1 CNN Layer
卷积层(Conv1d)用于提取局部特征,通常用于处理时间序列数据中的局部模式。它的输入是具有多个特征(例如风速、气压、湿度等)的时序数据。
相关代码:
# CNN Layer
self.conv1d = nn.Conv1d(in_channels=cnn_in_channels, out_channels=cnn_out_channels, kernel_size=cnn_kernal_size, padding="same")
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool1d(kernel_size= maxpool_kernal_size)
cnn_in_channels
: 表示输入通道数cnn_out_channels
: 表示卷积层的输出通道数cnn_kernal_size
: 为卷积核大小padding
: "same"表示特征输入大小和输出大小一致maxpool_kernal_size
: 为池化操作的核大小
3.2 BiLSTM Layer
双向长短期记忆网络(BiLSTM)用于捕捉时序数据中的长程依赖关系。
相关代码:
# BiLSTM Layer
self.bilstm = nn.LSTM(input_size = int(bilstm_input_size/maxpool_kernal_size),hidden_size = bilstm_hidden_size,num_layers = bilstm_num_layers,batch_first = True,dropout = dropout,bidirectional = True,proj_size = bilstm_proj_size
)
bilstm_input_size
: 表示输入的特征维度bilstm_hidden_size
: 表示LSTM隐藏状态的维度bilstm_num_layers
: 是LSTM的层数dropout
: 用于防止过拟合bilstm_proj_size
: 是LSTM的投影层大小(如果需要)
3.3 Output Layer
全连接层(fc)将BiLSTM的输出映射到最终的预测结果。输出的维度为output_size,通常是我们需要预测的目标维度(例如未来的功率值)。
相关代码:
# output Layer
self.fc = nn.Linear(bilstm_hidden_size*2, output_size)
output_size
: 输出维度大小
3.4 forward Layer
相关代码:
def forward(self, x):x = self.conv1d(x)x = self.relu(x)x = self.maxpool(x)bilstm_out,_ = self.bilstm(x)x = self.fc(bilstm_out[:, -1, :])return x
输入:
是一个三维张量,形状为[batch_size, input_channels, seq_len]
,其中input_channels
是输入数据的特征数(例如风速、湿度等),seq_len
是时间步数(即输入序列的长度)。CNN部分:
首先通过卷积层提取局部特征,然后应用ReLU激活函数引入非线性,最后通过最大池化(MaxPool1d)对特征进行降维,减少计算量。BiLSTM部分:
接着,将经过CNN处理后的特征传递给BiLSTM,捕捉时间序列中的长期依赖关系。BiLSTM的双向性使得模型能够同时考虑过去和未来的上下文信息。输出:
最终,模型通过全连接层(fc)将BiLSTM的最后一个时间步的输出映射为期望的输出大小。
4. 应用场景
这个模型适合用于处理时间序列数据的预测任务,特别是在风力发电预测、气象预测、股市预测等领域。CNN用于从输入数据中提取局部特征,而BiLSTM则能够捕捉输入数据的长期时序依赖关系。因此,模型既能有效地处理局部特征,又能关注到长时间范围内的依赖关系,从而提高预测的准确性。
5. 总结
本文详细介绍了如何使用Pytorch实现一个基于CNN和BiLSTM的混合模型
(CNN_BiLSTM)
。该模型结合了CNN在局部特征提取上的优势和BiLSTM在序列建模上的长程依赖能力,适用于时序数据的预测任务。在实际应用中,可以根据任务的不同调整CNN和LSTM的层数、通道数和隐藏状态维度等超参数,以提高模型的预测精度。
相关文章:
【时间序列预测】基于PyTorch实现CNN_BiLSTM算法
文章目录 1. CNN与BiLSTM2. 完整代码实现3. 代码结构解读3.1 CNN Layer3.2 BiLSTM Layer3.3 Output Layer3.4 forward Layer 4. 应用场景5. 总结 本文将详细介绍如何使用Pytorch实现一个结合卷积神经网络(CNN)和双向长短期记忆网络(BiLSTM&am…...

联想Y7000 2024版本笔记本 RTX4060安装ubuntu22.04双系统及深度学习环境配置
目录 1..制作启动盘 2.Windows 磁盘分区,删除原来ubuntu的启动项 3.四个设置 4.安装ubuntu 5.ubuntu系统配置 1..制作启动盘 先下载镜像文件,注意版本对应。Rufus - 轻松创建 USB 启动盘 用rufus制作时,需要注意选择正确的分区类型和系统类型。不然安装的系统会有问题…...

VuePress学习
1.介绍 VuePress 由两部分组成:第一部分是一个极简静态网站生成器 (opens new window),它包含由 Vue 驱动的主题系统和插件 API,另一个部分是为书写技术文档而优化的默认主题,它的诞生初衷是为了支持 Vue 及其子项目的文档需求。…...

一次“okhttp访问间隔60秒,提示unexpected end of stream“的问题排查过程
一、现象 okhttp调用某个服务,如果第二次访问间隔上一次访问时间超过60s,返回错误:"unexpected end of stream"。 二、最终定位原因: 空闲连接如果超过60秒,服务端会主动关闭连接。此时客户端恰巧访问了这…...
SQL最佳实践:避免使用COUNT=0
如果你遇到类似下面的 SQL 查询: SELECT * FROM customer c WHERE 0 (SELECT COUNT(*)FROM orders oWHERE o.customer_id c.customer_id);意味着有人没有遵循 SQL 最佳实践。该语句的作用是查找没有下过订单的客户,其中子查询使用了 COUNT 函数统计客…...
PG与ORACLE的差距
首先必须是XID 64,一个在极端环境下会FREEZE的数据库无论如何都无法承担关键业务系统的重任的,我们可以通过各种配置,提升硬件的性能,通过各种IT管控措施来尽可能避免在核心系统上面临FREEZE的风险,不过并不是每个企业…...
树莓派3B+驱动开发(2)- LED驱动(传统模式)
github主页:https://github.com/snqx-lqh 本项目github地址:https://github.com/snqx-lqh/RaspberryPiDriver 本项目硬件地址:https://oshwhub.com/from_zero/shu-mei-pai-kuo-zhan-ban 欢迎交流 笔记说明 如我在驱动开发总览中说的那样&…...
超详细搭建PhpStorm+PhpStudy开发环境
刚开始接触PHP开发,搭建开发环境是第一步,网上下载PhpStorm和PhpStudy软件,怎样安装和激活就不详细说了,我们重点来看一看怎样搭配这两个开发环境。 前提:现在假设你已经安装完PhpStorm和PhpStudy软件。 我的PhpStor…...
分析比对vuex和store模式
在 Vue 中,Vuex 和 store 模式 是两个不同的概念,它们紧密相关,主要用于管理应用的状态。下面我会详细介绍这两个概念,并通过例子帮助你更好地理解。 1. Vuex 是什么? Vuex 是 Vue.js 的一个状态管理库,用…...
C# 网络编程--基础核心内容
在现今软件开发中,网络编程是非常重要的一部分,本文简要介绍下网络编程的概念和实践。 C#网络编程的主要内容包括以下几个方面: : 上图引用大佬的图,大家也关注一下,有技术有品质,有国有家,情…...
【C++游戏程序】easyX图形库还原游戏《贪吃蛇大作战》(三)
承接上一篇文章:【C游戏程序】easyX图形库还原游戏《贪吃蛇大作战》(二),我们这次来补充一些游戏细节,以及增加吃食物加长角色长度等设定玩法,也是本游戏的最后一篇文章。 一.玩家边界检测 首先是用来检测…...

uni-app H5端使用注意事项 【跨端开发系列】
🔗 uniapp 跨端开发系列文章:🎀🎀🎀 uni-app 组成和跨端原理 【跨端开发系列】 uni-app 各端差异注意事项 【跨端开发系列】uni-app 离线本地存储方案 【跨端开发系列】uni-app UI库、框架、组件选型指南 【跨端开…...
SpringBoot中的@Configuration注解
在Spring Boot中,Configuration注解扮演着非常重要的角色,它是Spring框架中用于定义配置类的一个核心注解。以下是Configuration注解的主要作用: 定义配置类: 使用Configuration注解的类表示这是一个配置类,Spring容器…...

十二、路由、生命周期函数
router路由 页面路由指的是在应用程序中实现不同页面之间的跳转,以及数据传递。通过 Router 模块就可以实现这个功能 2.1创建页面 之前是创建的文件,使用路由的时候需要创建页面,步骤略有不同 方法 1:直接右键新建Page(常用)方法 2:单独添加页面并配置2.1.1直接右键新建…...

【蓝桥杯每日一题】X 进制减法
X 进制减法 2024-12-6 蓝桥杯每日一题 X 进制减法 贪心 进制转换 题目大意 进制规定了数字在数位上逢几进一。 XX 进制是一种很神奇的进制, 因为其每一数位的进制并不固定!例如说某 种 XX 进制数, 最低数位为二进制, 第二数位为十进制, 第三数位为八进制, 则 XX 进制…...

《蓝桥杯比赛规划》
大家好啊!我是NiJiMingCheng 我的博客:NiJiMingCheng 这节课我们来分享蓝桥杯比赛规划,好的规划会给我们的学习带来良好的收益,废话少说接下来就让我们进入学习规划吧,加油哦!!! 一、…...
C++算法练习day70——53.最大子序和
题目来源:. - 力扣(LeetCode) 题目思路分析 题目:寻找最大子数组和(也称为最大子序和)。 给定一个整数数组 nums,找到一个具有最大和的连续子数组(子数组最少包含一个元素&#x…...

import是如何“占领满屏“
import是如何“占领满屏“的? 《拒绝使用模块重导(Re-export)》 模块重导是一种通用的技术。在腾讯、字节、阿里等各大厂的组件库中都有大量使用。 如:字节的arco-design组件库中的组件:github.com/arco-design… …...
ceph /etc/ceph-csi-config/config.json: no such file or directory
环境 rook-ceph 部署的 ceph。 问题 kubectl describe pod dragonfly-redis-master-0Warning FailedMount 7m59s (x20 over 46m) kubelet MountVolume.MountDevice failed for volume "pvc-c63e159a-c940-4001-bf0d-e6141634cc55" : rpc error: cod…...

C语言——验证“哥德巴赫猜想”
问题描述: 验证"哥德巴赫猜想" 任何一个大于2的偶数都可以表示为两个质数之和。例如,4可以表示为22,6可以表示为33,8可以表示为35等 //验证"哥德巴赫猜想" //任何一个大于2的偶数都可以表示为两个质数之和…...

超短脉冲激光自聚焦效应
前言与目录 强激光引起自聚焦效应机理 超短脉冲激光在脆性材料内部加工时引起的自聚焦效应,这是一种非线性光学现象,主要涉及光学克尔效应和材料的非线性光学特性。 自聚焦效应可以产生局部的强光场,对材料产生非线性响应,可能…...
python如何将word的doc另存为docx
将 DOCX 文件另存为 DOCX 格式(Python 实现) 在 Python 中,你可以使用 python-docx 库来操作 Word 文档。不过需要注意的是,.doc 是旧的 Word 格式,而 .docx 是新的基于 XML 的格式。python-docx 只能处理 .docx 格式…...
Swagger和OpenApi的前世今生
Swagger与OpenAPI的关系演进是API标准化进程中的重要篇章,二者共同塑造了现代RESTful API的开发范式。 本期就扒一扒其技术演进的关键节点与核心逻辑: 🔄 一、起源与初创期:Swagger的诞生(2010-2014) 核心…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配
目录 一、C 内存的基本概念 1.1 内存的物理与逻辑结构 1.2 C 程序的内存区域划分 二、栈内存分配 2.1 栈内存的特点 2.2 栈内存分配示例 三、堆内存分配 3.1 new和delete操作符 4.2 内存泄漏与悬空指针问题 4.3 new和delete的重载 四、智能指针…...
【Android】Android 开发 ADB 常用指令
查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...
tomcat指定使用的jdk版本
说明 有时候需要对tomcat配置指定的jdk版本号,此时,我们可以通过以下方式进行配置 设置方式 找到tomcat的bin目录中的setclasspath.bat。如果是linux系统则是setclasspath.sh set JAVA_HOMEC:\Program Files\Java\jdk8 set JRE_HOMEC:\Program Files…...
区块链技术概述
区块链技术是一种去中心化、分布式账本技术,通过密码学、共识机制和智能合约等核心组件,实现数据不可篡改、透明可追溯的系统。 一、核心技术 1. 去中心化 特点:数据存储在网络中的多个节点(计算机),而非…...
2.2.2 ASPICE的需求分析
ASPICE的需求分析是汽车软件开发过程中至关重要的一环,它涉及到对需求进行详细分析、验证和确认,以确保软件产品能够满足客户和用户的需求。在ASPICE中,需求分析的关键步骤包括: 需求细化:将从需求收集阶段获得的高层需…...
LTR-381RGB-01RGB+环境光检测应用场景及客户类型主要有哪些?
RGB环境光检测 功能,在应用场景及客户类型: 1. 可应用的儿童玩具类型 (1) 智能互动玩具 功能:通过检测环境光或物体颜色触发互动(如颜色识别积木、光感音乐盒)。 客户参考: LEGO(乐高&#x…...

本地部署drawDB结合内网穿透技术实现数据库远程管控方案
文章目录 前言1. Windows本地部署DrawDB2. 安装Cpolar内网穿透3. 实现公网访问DrawDB4. 固定DrawDB公网地址 前言 在数字化浪潮席卷全球的背景下,数据治理能力正日益成为构建现代企业核心竞争力的关键因素。无论是全球500强企业的数据中枢系统,还是初创…...