当前位置: 首页 > news >正文

【时间序列预测】基于PyTorch实现CNN_BiLSTM算法

文章目录

  • 1. CNN与BiLSTM
  • 2. 完整代码实现
  • 3. 代码结构解读
    • 3.1 CNN Layer
    • 3.2 BiLSTM Layer
    • 3.3 Output Layer
    • 3.4 forward Layer
  • 4. 应用场景
  • 5. 总结

  本文将详细介绍如何使用Pytorch实现一个结合卷积神经网络(CNN)双向长短期记忆网络(BiLSTM)的混合模型—CNN_BiLSTM。这种网络架构结合了CNN在提取局部特征方面的优势和BiLSTM在建模序列数据时的长期依赖关系的能力,特别适用于时序数据的预测任务,如时间序列分析、风速预测、股票预测等。

1. CNN与BiLSTM

  • CNN主要通过卷积操作对输入数据进行特征提取,适合于处理局部结构化的特征(如图像数据、时间序列数据中的局部模式)。
  • BiLSTM则是基于LSTM的变种,它通过双向遍历序列,可以同时捕捉过去和未来的信息,使其在处理时间序列数据时非常有效。
  • 在本例中,CNN负责提取时间序列数据的局部特征,而BiLSTM则进一步捕捉数据中的时序依赖关系,最终通过全连接层输出预测结果。

2. 完整代码实现

"""
CNN_BiLSTM Network
"""
from torch import nnclass CNN_BiLSTM(nn.Module):r"""CNN_BiLSTMArgs:cnn_in_channels : CNN输入通道数, if in.shape=[64,7,18] value=7bilstm_input_size : bilstm输入大小, if in.shape=[64,7,18] value=18output_size :  期望网络输出大小cnn_out_channels:  CNN层输出通道数cnn_kernal_size :  CNN层卷积核大小maxpool_kernal_size:  MaxPool Layer kernal_sizebilstm_hidden_size: BiLSTM Layer hidden_dimbilstm_num_layers: BiLSTM Layer num_layersdropout:  dropout防止过拟合, 取值(0,1)bilstm_proj_size: BiLSTM Layer proj_sizeExample:>>> import torch>>> input = torch.randn([64,7,18])>>> model = CNN_BiLSTM(7, 18,18)>>> out = model(input)"""def __init__(self,cnn_in_channels,bilstm_input_size,output_size,cnn_out_channels=32,cnn_kernal_size=3,maxpool_kernal_size=3,bilstm_hidden_size=128,bilstm_num_layers=4,dropout = 0.05,bilstm_proj_size = 0):super().__init__()# CNN Layerself.conv1d = nn.Conv1d(in_channels=cnn_in_channels, out_channels=cnn_out_channels, kernel_size=cnn_kernal_size, padding="same")self.relu = nn.ReLU()self.maxpool = nn.MaxPool1d(kernel_size= maxpool_kernal_size)# BiLSTM Layerself.bilstm = nn.LSTM(input_size = int(bilstm_input_size/maxpool_kernal_size),hidden_size = bilstm_hidden_size,num_layers = bilstm_num_layers,batch_first = True,dropout = dropout,bidirectional = True,proj_size = bilstm_proj_size)# output Layerself.fc = nn.Linear(bilstm_hidden_size*2,output_size)def forward(self, x):x = self.conv1d(x)x = self.relu(x)x = self.maxpool(x)bilstm_out,_ = self.bilstm(x)x = self.fc(bilstm_out[:, -1, :])return x      

3. 代码结构解读

3.1 CNN Layer

卷积层(Conv1d)用于提取局部特征,通常用于处理时间序列数据中的局部模式。它的输入是具有多个特征(例如风速、气压、湿度等)的时序数据。

相关代码:

# CNN Layer
self.conv1d = nn.Conv1d(in_channels=cnn_in_channels, out_channels=cnn_out_channels, kernel_size=cnn_kernal_size, padding="same")
self.relu = nn.ReLU()
self.maxpool = nn.MaxPool1d(kernel_size= maxpool_kernal_size)
  • cnn_in_channels : 表示输入通道数
  • cnn_out_channels : 表示卷积层的输出通道数
  • cnn_kernal_size : 为卷积核大小
  • padding : "same"表示特征输入大小和输出大小一致
  • maxpool_kernal_size : 为池化操作的核大小

3.2 BiLSTM Layer

双向长短期记忆网络(BiLSTM)用于捕捉时序数据中的长程依赖关系。

相关代码:

# BiLSTM Layer
self.bilstm = nn.LSTM(input_size = int(bilstm_input_size/maxpool_kernal_size),hidden_size = bilstm_hidden_size,num_layers = bilstm_num_layers,batch_first = True,dropout = dropout,bidirectional = True,proj_size = bilstm_proj_size
)
  • bilstm_input_size : 表示输入的特征维度
  • bilstm_hidden_size : 表示LSTM隐藏状态的维度
  • bilstm_num_layers : 是LSTM的层数
  • dropout : 用于防止过拟合
  • bilstm_proj_size : 是LSTM的投影层大小(如果需要)

3.3 Output Layer

全连接层(fc)将BiLSTM的输出映射到最终的预测结果。输出的维度为output_size,通常是我们需要预测的目标维度(例如未来的功率值)。

相关代码:

# output Layer
self.fc = nn.Linear(bilstm_hidden_size*2, output_size)

output_size : 输出维度大小

3.4 forward Layer

相关代码:

def forward(self, x):x = self.conv1d(x)x = self.relu(x)x = self.maxpool(x)bilstm_out,_ = self.bilstm(x)x = self.fc(bilstm_out[:, -1, :])return x      
  • 输入: 是一个三维张量,形状为 [batch_size, input_channels, seq_len],其中input_channels是输入数据的特征数(例如风速、湿度等),seq_len是时间步数(即输入序列的长度)。
  • CNN部分:首先通过卷积层提取局部特征,然后应用ReLU激活函数引入非线性,最后通过最大池化(MaxPool1d)对特征进行降维,减少计算量。
  • BiLSTM部分: 接着,将经过CNN处理后的特征传递给BiLSTM,捕捉时间序列中的长期依赖关系。BiLSTM的双向性使得模型能够同时考虑过去和未来的上下文信息。
  • 输出: 最终,模型通过全连接层(fc)将BiLSTM的最后一个时间步的输出映射为期望的输出大小。

4. 应用场景

  这个模型适合用于处理时间序列数据的预测任务,特别是在风力发电预测、气象预测、股市预测等领域。CNN用于从输入数据中提取局部特征,而BiLSTM则能够捕捉输入数据的长期时序依赖关系。因此,模型既能有效地处理局部特征,又能关注到长时间范围内的依赖关系,从而提高预测的准确性。

5. 总结

  本文详细介绍了如何使用Pytorch实现一个基于CNN和BiLSTM的混合模型(CNN_BiLSTM)。该模型结合了CNN在局部特征提取上的优势和BiLSTM在序列建模上的长程依赖能力,适用于时序数据的预测任务。在实际应用中,可以根据任务的不同调整CNN和LSTM的层数、通道数和隐藏状态维度等超参数,以提高模型的预测精度。

相关文章:

【时间序列预测】基于PyTorch实现CNN_BiLSTM算法

文章目录 1. CNN与BiLSTM2. 完整代码实现3. 代码结构解读3.1 CNN Layer3.2 BiLSTM Layer3.3 Output Layer3.4 forward Layer 4. 应用场景5. 总结 本文将详细介绍如何使用Pytorch实现一个结合卷积神经网络(CNN)和双向长短期记忆网络(BiLSTM&am…...

联想Y7000 2024版本笔记本 RTX4060安装ubuntu22.04双系统及深度学习环境配置

目录 1..制作启动盘 2.Windows 磁盘分区,删除原来ubuntu的启动项 3.四个设置 4.安装ubuntu 5.ubuntu系统配置 1..制作启动盘 先下载镜像文件,注意版本对应。Rufus - 轻松创建 USB 启动盘 用rufus制作时,需要注意选择正确的分区类型和系统类型。不然安装的系统会有问题…...

VuePress学习

1.介绍 VuePress 由两部分组成:第一部分是一个极简静态网站生成器 (opens new window),它包含由 Vue 驱动的主题系统和插件 API,另一个部分是为书写技术文档而优化的默认主题,它的诞生初衷是为了支持 Vue 及其子项目的文档需求。…...

一次“okhttp访问间隔60秒,提示unexpected end of stream“的问题排查过程

一、现象 okhttp调用某个服务,如果第二次访问间隔上一次访问时间超过60s,返回错误:"unexpected end of stream"。 二、最终定位原因: 空闲连接如果超过60秒,服务端会主动关闭连接。此时客户端恰巧访问了这…...

SQL最佳实践:避免使用COUNT=0

如果你遇到类似下面的 SQL 查询: SELECT * FROM customer c WHERE 0 (SELECT COUNT(*)FROM orders oWHERE o.customer_id c.customer_id);意味着有人没有遵循 SQL 最佳实践。该语句的作用是查找没有下过订单的客户,其中子查询使用了 COUNT 函数统计客…...

PG与ORACLE的差距

首先必须是XID 64,一个在极端环境下会FREEZE的数据库无论如何都无法承担关键业务系统的重任的,我们可以通过各种配置,提升硬件的性能,通过各种IT管控措施来尽可能避免在核心系统上面临FREEZE的风险,不过并不是每个企业…...

树莓派3B+驱动开发(2)- LED驱动(传统模式)

github主页:https://github.com/snqx-lqh 本项目github地址:https://github.com/snqx-lqh/RaspberryPiDriver 本项目硬件地址:https://oshwhub.com/from_zero/shu-mei-pai-kuo-zhan-ban 欢迎交流 笔记说明 如我在驱动开发总览中说的那样&…...

超详细搭建PhpStorm+PhpStudy开发环境

刚开始接触PHP开发,搭建开发环境是第一步,网上下载PhpStorm和PhpStudy软件,怎样安装和激活就不详细说了,我们重点来看一看怎样搭配这两个开发环境。 前提:现在假设你已经安装完PhpStorm和PhpStudy软件。 我的PhpStor…...

分析比对vuex和store模式

在 Vue 中,Vuex 和 store 模式 是两个不同的概念,它们紧密相关,主要用于管理应用的状态。下面我会详细介绍这两个概念,并通过例子帮助你更好地理解。 1. Vuex 是什么? Vuex 是 Vue.js 的一个状态管理库,用…...

C# 网络编程--基础核心内容

在现今软件开发中,网络编程是非常重要的一部分,本文简要介绍下网络编程的概念和实践。 C#网络编程的主要内容包括以下几个方面‌: : 上图引用大佬的图,大家也关注一下,有技术有品质,有国有家,情…...

【C++游戏程序】easyX图形库还原游戏《贪吃蛇大作战》(三)

承接上一篇文章:【C游戏程序】easyX图形库还原游戏《贪吃蛇大作战》(二),我们这次来补充一些游戏细节,以及增加吃食物加长角色长度等设定玩法,也是本游戏的最后一篇文章。 一.玩家边界检测 首先是用来检测…...

uni-app H5端使用注意事项 【跨端开发系列】

🔗 uniapp 跨端开发系列文章:🎀🎀🎀 uni-app 组成和跨端原理 【跨端开发系列】 uni-app 各端差异注意事项 【跨端开发系列】uni-app 离线本地存储方案 【跨端开发系列】uni-app UI库、框架、组件选型指南 【跨端开…...

SpringBoot中的@Configuration注解

在Spring Boot中,Configuration注解扮演着非常重要的角色,它是Spring框架中用于定义配置类的一个核心注解。以下是Configuration注解的主要作用: 定义配置类: 使用Configuration注解的类表示这是一个配置类,Spring容器…...

十二、路由、生命周期函数

router路由 页面路由指的是在应用程序中实现不同页面之间的跳转,以及数据传递。通过 Router 模块就可以实现这个功能 2.1创建页面 之前是创建的文件,使用路由的时候需要创建页面,步骤略有不同 方法 1:直接右键新建Page(常用)方法 2:单独添加页面并配置2.1.1直接右键新建…...

【蓝桥杯每日一题】X 进制减法

X 进制减法 2024-12-6 蓝桥杯每日一题 X 进制减法 贪心 进制转换 题目大意 进制规定了数字在数位上逢几进一。 XX 进制是一种很神奇的进制, 因为其每一数位的进制并不固定!例如说某 种 XX 进制数, 最低数位为二进制, 第二数位为十进制, 第三数位为八进制, 则 XX 进制…...

《蓝桥杯比赛规划》

大家好啊!我是NiJiMingCheng 我的博客:NiJiMingCheng 这节课我们来分享蓝桥杯比赛规划,好的规划会给我们的学习带来良好的收益,废话少说接下来就让我们进入学习规划吧,加油哦!!! 一、…...

C++算法练习day70——53.最大子序和

题目来源:. - 力扣(LeetCode) 题目思路分析 题目:寻找最大子数组和(也称为最大子序和)。 给定一个整数数组 nums,找到一个具有最大和的连续子数组(子数组最少包含一个元素&#x…...

import是如何“占领满屏“

import是如何“占领满屏“的? 《拒绝使用模块重导(Re-export)》 模块重导是一种通用的技术。在腾讯、字节、阿里等各大厂的组件库中都有大量使用。 如:字节的arco-design组件库中的组件:github.com/arco-design… …...

ceph /etc/ceph-csi-config/config.json: no such file or directory

环境 rook-ceph 部署的 ceph。 问题 kubectl describe pod dragonfly-redis-master-0Warning FailedMount 7m59s (x20 over 46m) kubelet MountVolume.MountDevice failed for volume "pvc-c63e159a-c940-4001-bf0d-e6141634cc55" : rpc error: cod…...

C语言——验证“哥德巴赫猜想”

问题描述: 验证"哥德巴赫猜想" 任何一个大于2的偶数都可以表示为两个质数之和。例如,4可以表示为22,6可以表示为33,8可以表示为35等 //验证"哥德巴赫猜想" //任何一个大于2的偶数都可以表示为两个质数之和…...

Flourish笔记:柱状图(Column chart (grouped))

文章目录 样式设定Chart Type:图表类型Controls & Filters:展示方式Colors:颜色bars:柱子的调整labels:柱子数字标注X axis:横坐标标签Y axis:纵坐标标签Plot BackgroundNumber FormatingLe…...

深度学习案例:DenseNet + SE-Net

本文为为🔗365天深度学习训练营内部文章 原作者:K同学啊 一 回顾DenseNet算法 DenseNet(Densely Connected Convolutional Networks)是一种深度卷积神经网络架构,提出的核心思想是通过在每一层与前面所有层进行直接连接…...

excel文件合并,每个excel名称插入excel列

import pandas as pd import os # 设置文件夹路径 folder_path rC:\test # 替换为您的下载文件夹路径 output_file os.path.join(folder_path, BOM材料.xlsx) # 创建一个空的 DataFrame 用于存储合并的数据 combined_data pd.DataFrame() # 遍历文件夹中的所有文件 for …...

Linux 如何设置特殊权限?

简介 通过使用 setuid、setgid 、sticky,它们是 Linux 中的特殊权限,可以对文件和目录的访问和执行方式提供额外的控制。 命令八进制数字功能setuid4当执行文件时,它以文件所有者的权限运行,而不是执行它的用户的权限运行。setg…...

零基础如何使用ChatGPT快速学习Python

引言 AI编程时代来临,没有编程基础可以快速上车享受时代的红利吗?答案是肯定的。本文旨在介绍零基础如何利用ChatGPT快速学习Python编程语言,开启AI编程之路。解决的问题包括:传统学习方式效率低、缺乏互动性以及学习资源质量参差…...

【开源】一款基于SpringBoot 的全开源充电桩平台

一、下载项目文件 下载源码项目文件口令:动作璆璜量子屏多好/~d1b8356ox2~:/复制口令后,进入夸克网盘app即可保存(如果复制到夸克app没有跳转资源,可以复制粘贴口令到夸克app的搜索框也可以打开(不用点搜索按钮&#…...

AI - RAG中的状态化管理聊天记录

AI - RAG中的状态化管理聊天记录 大家好,今天我们来聊聊LangChain和LLM中一个重要的话题——状态化管理聊天记录。在使用大语言模型(LLM)的时候,聊天记录(History)和状态(State)管理是非常关键的。那我们先…...

JAVA安全—SpringBoot框架MyBatis注入Thymeleaf模板注入

前言 之前我们讲了JAVA的一些组件安全,比如Log4j,fastjson。今天讲一下框架安全,就是这个也是比较常见的SpringBoot框架。 SpringBoot框架 Spring Boot是由Pivotal团队提供的一套开源框架,可以简化spring应用的创建及部署。它提…...

【STM32系列】提升ADC采样精度的方法

资料地址 兆易创新GigaDevice-资料下载兆易创新GD32 MCU ADC简介 ADC转换包括采样、保持、量化、编码四个步骤。的采样电容上,即在采样开关 SW 关闭的过程中,外部输入信号通过外部的输入电阻 RAIN 和以及 ADC 采样电阻 RADC 对采样电容 CADC 充电。采样…...

前端面试如何出彩

1、原型链和作用域链说不太清,主要表现在寄生组合继承和extends继承的区别和new做了什么。2、推荐我的两篇文章:若川:面试官问:能否模拟实现JS的new操作符、若川:面试官问:JS的继承 3、数组构造函数上有哪些…...

重庆农村网站建设/旺道seo营销软件

前言这个是我花49c币在csdn上下载的,稍微做了些修改,给大家分享一下。觉得还不错的点个赞吧。先上执行完毕后的控制台输出截图:package com.dq.utils;import java.math.BigDecimal;import java.math.BigInteger;import java.math.RoundingMod…...

排名优化软件/seo sem

前言 利用Python实现抖音字符视频。废话不多说。 让我们愉快地开始吧~ 开发工具 Python版本:3.6.4 相关模块: cv2模块; PIL模块; numpy模块; 以及一些Python自带的模块。 环境搭建 安装Python并添加到环境变…...

建设网站设计公司/他达拉非片的作用及功效副作用

美国网络安全公司FireEye研究人员发现,最近有不法分子向中东多家银行发送大量包含恶意代码的垃圾邮件,一旦用户打开附件,其中的恶意程序就会盗取银行的网络信息,为后续顺利实施盗窃做准备。 这一消息来自于香港《文汇报》&#xf…...

长春电商网站建设公司排名/百度网站推广一年多少钱

本人为C初学者,使用的是Linux C一站式学习,打算把课后习题中自己写的代码都记下来,本打算在论坛里盖楼的,但是发现论坛太麻烦了,不能连续回复三次,看到max_min_ 同学的回复决定来博客记录吧,这里…...

wordpress 买主题/安卓优化神器

皮一下Git操作流程第一部分 命令行1、分支操作1. git branch 创建分支2. git checkout -b 创建并切换到新建的分支上3. git checkout 切换分支4. git branch 查看分支列表5. git branch -v 查看所有分支的最后一次操作6. git branch -vv 查看当前分支7. git brabch -b 分支名 o…...

wordpress全功能主题/邮件营销

点击写配置按钮,直接把编辑框1的内容读取出来,然后加密数据,写到配置项里面,我选择的是DES加密。 读配置是读取配置文件,解密数据,写内容到编辑框1,但是这时解密失败了,在网上百度了…...