当前位置: 首页 > news >正文

【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。

一、引言

单站点多变量输入多变量输出单步预测问题----基于LSTM实现。

多输入就是输入多个特征变量

多输出就是同时预测出多个标签的结果

单步就是利用过去N天预测未来1天的结果

二、实现过程

2.1 读取数据集

df=pd.read_csv("data.csv", parse_dates=["Date"], index_col=[0])
print(df.shape)
print(df.head())
fea_num = len(df.columns)

df:

图片

2.2 划分数据集

# 拆分数据集为训练集和测试集
test_split=round(len(df)*0.20)
df_for_training=df[:-test_split]
df_for_testing=df[-test_split:]# 绘制训练集和测试集的折线图
plt.figure(figsize=(10, 6))
plt.plot(train_data, label='Training Data')
plt.plot(test_data, label='Testing Data')
plt.xlabel('Year')
plt.ylabel('Passenger Count')
plt.title('International Airline Passengers - Training and Testing Data')
plt.legend()
plt.show()

共5203条数据,8:2划分:训练集4162,测试集1041。

训练集和测试集:

图片

2.3 归一化

# 将数据归一化到 0~1 范围(整体一起做归一化)
scaler = MinMaxScaler(feature_range=(0,1))
df_for_training_scaled = scaler.fit_transform(df_for_training)
df_for_testing_scaled=scaler.transform(df_for_testing)

2.4 构造LSTM数据集(时序-->监督学习)

def createXY(data, win_size, target_feature_idxs):passwin_size = 12 # 时间窗口
target_feature_idxs = [0, 1, 2, 3, 4] # 指定待预测特征列索引
trainX, trainY = createXY(df_for_training_scaled, win_size, target_feature_idxs)
testX, testY = createXY(df_for_testing_scaled, win_size, target_feature_idxs)
print("训练集形状:", trainX.shape, trainY.shape)
print("测试集形状:", testX.shape, testY.shape)# 将数据集转换为 LSTM 模型所需的形状(样本数,时间步长,特征数)
trainX = np.reshape(trainX, (trainX.shape[0], win_size, fea_num))
testX = np.reshape(testX, (testX.shape[0], win_size, fea_num))print("trainX Shape-- ",trainX.shape)
print("trainY Shape-- ",trainY.shape)
print("testX Shape-- ",testX.shape)
print("testY Shape-- ",testY.shape)

滑动窗口设置为12:

取出df_for_training_scaled第【1-12】行第【1-5】列的12条数据作为trainX[0],取出df_for_training_scaled第【13】行第【1-5】列的1条数据作为trainY[0];依此类推。最终构造出的训练集数量(4150)比划分时候的训练集数量(4162)少一个滑动窗口(12)。

图片

trainX是一个(4150,12,5)的三维数组,三个维度分布表示(样本数量,步长,特征数),每一个样本比如trainX[0]是一个(12,5)二维数组表示(步长,特征数),这也是LSTM模型每一步的输入。

图片

trainY是一个(4150,5)的二维数组,二个维度分布表示(样本数量,标签数),每一个样本比如trainY[0]是一个(5,)一维数组表示(标签数,),这也是LSTM模型每一步的输出。

图片

2.5 建立模拟合模型

# 输入维度
input_shape = Input(shape=(trainX.shape[1], trainX.shape[2]))
# LSTM层
lstm_layer = LSTM(128, activation='relu')(input_shape)
# 全连接层
dense_1 = Dense(64, activation='relu')(lstm_layer)
dense_2 = Dense(32, activation='relu')(dense_1)
# 输出层
output_1 = Dense(1, name='Open')(dense_2)
output_2 = Dense(1, name='High')(dense_2)
output_3 = Dense(1, name='Low')(dense_2)
output_4 = Dense(1, name='Close')(dense_2)
output_5 = Dense(1, name='AdjClose')(dense_2)
model = Model(inputs = input_shape, outputs = [output_1, output_2, output_3, output_4, output_5])
model.compile(loss='mse', optimizer='adam')
model.summary()

这是一个多输入多输出的 LSTM 模型,接受包含12个时间步长和5个特征的输入序列,在经过一层128个神经元的 LSTM 层和5个全连接层后,输出5个单独的预测结果,分别是 Open、High、 Low、Close和 AdjClose。

图片

进行训练,这里[trainY[:,i] for i in range(trainY.shape[1])]把原来的trainY做了转置,是一个(5,4150)的二维数组,分别表示(标签数,样本数)。相当于建立了5个通道,每个通道是(4150,)的一维数组。

history = model.fit(trainX, [trainY[:,i] for i in range(trainY.shape[1])], epochs=20, batch_size=32)

2.6 进行预测

进行预测,上面我们分析过模型每一步的输入是一个(12,5)二维数组表示(步长,特征数),模型每一步的输出是是一个(5,)一维数组表示(标签数,)

prediction_test = model.predict(testX)

如果直接model.predict(testX),testX的形状是(1029,12,5),是一个批量预测,输出prediction_test是一个(5,1029,1)的三维数组,prediction_test[0]就是第一个标签的预测结果,prediction_test[1]就是第二个标签的预测结果...多输出就是同时预测出多个标签的结果

图片

2.7 预测效果展示

分析一下第一个变量open的效果,i=0:

prediction_train = model.predict(trainX)
prediction_train0=model.predict(trainX)[i]
prediction_train_copies_array = ...
pred_train=...
original_train_copies_array = trainY
original_train=...
print("train Pred Values-- ", pred_train)
print("\ntrain Original Values-- ", original_train)
plt.plot(df_for_training.index[win_size:,], original_train, color = 'red', label = '真实值')
plt.plot(df_for_training.index[win_size:,], pred_train, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

训练集真实值与预测值:

图片

prediction_test = model.predict(testX)
prediction_test0=model.predict(testX)[i]
prediction_test_copies_array = ...
pred_test=...
original_test_copies_array = testY
original_test=...
print("\ntest Original Values-- ", original_test)
plt.plot(df_for_testing.index[win_size:,], original_test, color = 'red', label = '真实值')
plt.plot(df_for_testing.index[win_size:,], pred_test, color = 'blue', label = '预测值')
plt.title('Stock Price Prediction')
plt.xlabel('Time')
plt.xticks(rotation=45)
plt.ylabel('Stock Price')
plt.legend()
plt.show()

测试集真实值与预测值:

图片

2.8 评估指标

图片

作者简介:

读研期间发表6篇SCI数据挖掘相关论文,现在某研究院从事数据算法相关科研工作,结合自身科研实践经历不定期分享关于Python、机器学习、深度学习、人工智能系列基础知识与应用案例。致力于只做原创,以最简单的方式理解和学习,关注我一起交流成长。需要数据集和源码的小伙伴可以关注底部公众号添加作者微信。

相关文章:

【Python时序预测系列】基于LSTM实现多输入多输出单步预测(案例+源码)

这是我的第312篇原创文章。 一、引言 单站点多变量输入多变量输出单步预测问题----基于LSTM实现。 多输入就是输入多个特征变量 多输出就是同时预测出多个标签的结果 单步就是利用过去N天预测未来1天的结果 二、实现过程 2.1 读取数据集 dfpd.read_csv("data.csv&qu…...

git客户端工具之Github,适用于windows和mac

对于我本人,我已经习惯了使用Github Desktop,不同的公司使用的代码管理平台不一样,就好奇Github Desktop是不是也适用于其他平台,结果是可以的。 一、克隆代码 File --> Clone repository… 选择第三种URL方式,输入url &…...

ai除安卓手机版APP软件一键操作自动渲染去擦消稀缺资源下载

安卓手机版:点击下载 苹果手机版:点击下载 电脑版(支持Mac和Windows):点击下载 一款全新的AI除安卓手机版APP,一键操作,轻松实现自动渲染和去擦消效果,稀缺资源下载 1、一键操作&…...

Unity获取剪切板内容粘贴板图片文件文字

最近做了一个发送消息的unity项目,需要访问剪切板里面的图片文字文件等,翻遍了网上的东西,看了不是需要导入System.Windows.Forms(关键导入了unity还不好用,只能用在纯c#项目中),所以我看了下py…...

利用谷歌云serverless代码托管服务Cloud Functions构建Gemini Pro API

谷歌在2024年4月发布了全新一代的多模态模型Gemini 1.5 Pro,Gemini 1.5 Pro不仅能够生成创意文本和代码,还能理解、总结上传的图片、视频和音频内容,并且支持高达100万tokens的上下文。在多个基准测试中表现优异,性能超越了ChatGP…...

极狐GitLab 17.0 重磅发布,100+ DevSecOps功能更新来啦~【一】

GitLab 是一个全球知名的一体化 DevOps 平台,很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab :https://gitlab.cn/install?channelcontent&utm_sourcecsdn 是 GitLab 在中国的发行版,专门为中国程序员服务。可以一键式部署…...

python实现符文加、解密

在历史悠久的加密技术中,恺撒密码以其简单却有效的原理闻名。通过固定的字母位移,明文可以被转换成密文,而解密则是逆向操作。这种技术不仅适用于英文字母,还可以扩展到其他语言的字符体系,如日语的平假名或汉语的拼音…...

【解释】i.MX6ULL_IO_电气属性说明

【解释】i.MX6ULL_IO_电气属性说明 文章目录 1 Hyst1.1 迟滞(Hysteresis)是什么?1.2 GPIO的Hyst. Enable Field 参数1.3 应用场景 2 Pull / Keep Select Field2.1 PUE_0_Keeper — Keeper2.2 PUE_1_Pull — Pull2.3 选择Keeper还是Pull 3 Dr…...

02-《石莲》

石 莲 石莲(学名:Sinocrassula indica A.Berger),别名因地卡,为二年生草本植物,全株无毛,具须根。花茎高15-60厘米,直立,常被微乳头状突起。茎生叶互生,宽倒披…...

MySQL之聚簇索引和非聚簇索引

1、什么是聚簇索引和非聚簇索引? 聚簇索引,通常也叫聚集索引。 非聚簇索引,指的是二级索引。 下面看一下它们的含义: 1.1、聚集索引选取规则 如果存在主键,主键索引就是聚集索引。如果不存在主键,将使…...

Web后端开发之前后端交互

http协议 http ● 超文本传输协议 (HyperText Transfer Protocol)服务器传输超文本到本地浏览器的传送协议 是互联网上应用最为流行的一种网络协议,用于定义客户端浏览器和服务器之间交换数据的过程。 HTTP是一个基于TCP/IP通信协议来传递数据. HTT…...

520. 检测大写字母 Easy

我们定义,在以下情况时,单词的大写用法是正确的: 全部字母都是大写,比如 "USA" 。 单词中所有字母都不是大写,比如 "leetcode" 。 如果单词不只含有一个字母,只有首字母大写&#xff0…...

vue的跳转传参

1、接收参数使用route,route包含路由信息,接收参数有两种方式,params和query path跳转只能使用query传参,name跳转都可以 params:获取来自动态路由的参数 query:获取来自search部分的参数 写法 path跳,query传 传参数 import { useRout…...

docker配置镜像源

1)打开 docker配置文件 sudo nano /etc/docker/daemon.json 2)添加 国内镜像源 {"registry-mirrors": ["https://akchsmlh.mirror.aliyuncs.com","https://registry.docker-cn.com","https://docker.mirrors.ustc…...

MySQL高级-SQL优化-insert优化-批量插入-手动提交事务-主键顺序插入

文章目录 1、批量插入1.1、大批量插入数据1.2、启动Linux中的mysql服务1.3、客户端连接到mysql数据库,加上参数 --local-infile1.4、查询当前会话中 local_infile 系统变量的值。1.5、开启从本地文件加载数据到服务器的功能1.6、创建表 tb_user 结构1.7、上传文件到…...

认识100种电路之振荡电路

在电子电路领域,振荡是一项至关重要的功能。那么,为什么电路中需要振荡?其背后的原理是什么?让我们一同深入探究。 【为什么需要振荡电路】 简单来说,振荡电路的存在是为了产生周期性的信号。在众多电子设备中&#…...

SSH 无密登录配置流程

一、免密登录原理 非对称加密: 由于对称加密的存在弊端,就产生了非对称加密,非对称加密中有两个密钥:公钥和私钥。公钥由私钥产生,但却无法推算出私钥;公钥加密后的密文,只能通过对应的私钥来解…...

Python自动化运维 系统基础信息模块

1.系统信息的收集 系统信息的收集,对于服务质量的把控,服务的监控等来说是非常重要的组成部分,甚至是核心的基础支撑部分。我们可以通过大量的核心指标数据,结合对应的检测体系,快速的发现异常现象的苗头,进…...

如何安装和配置Monit

前些天发现了一个巨牛的人工智能学习网站,通俗易懂,风趣幽默,忍不住分享一下给大家。点击跳转到网站。 关于 Monit Monit 是一个有用的程序,可以自动监控和管理服务器程序,以确保它们不仅保持在线,而且文…...

【redis】redis分片集群基础知识

1、基本概念 1.1定义 分片:数据按照某种规则(比如哈希)被分割成多个片段(或分片),每个片段被称为一个槽(slot)。槽是Redis分片集群中数据的基本单元。节点:Redis分片集…...

深入浅出Asp.Net Core MVC应用开发系列-AspNetCore中的日志记录

ASP.NET Core 是一个跨平台的开源框架,用于在 Windows、macOS 或 Linux 上生成基于云的新式 Web 应用。 ASP.NET Core 中的日志记录 .NET 通过 ILogger API 支持高性能结构化日志记录,以帮助监视应用程序行为和诊断问题。 可以通过配置不同的记录提供程…...

如何在看板中体现优先级变化

在看板中有效体现优先级变化的关键措施包括:采用颜色或标签标识优先级、设置任务排序规则、使用独立的优先级列或泳道、结合自动化规则同步优先级变化、建立定期的优先级审查流程。其中,设置任务排序规则尤其重要,因为它让看板视觉上直观地体…...

理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端

🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中,新增了一个本地验证码接口 /code,使用函数式路由(RouterFunction)和 Hutool 的 Circle…...

2023赣州旅游投资集团

单选题 1.“不登高山,不知天之高也;不临深溪,不知地之厚也。”这句话说明_____。 A、人的意识具有创造性 B、人的认识是独立于实践之外的 C、实践在认识过程中具有决定作用 D、人的一切知识都是从直接经验中获得的 参考答案: C 本题解…...

Go 语言并发编程基础:无缓冲与有缓冲通道

在上一章节中,我们了解了 Channel 的基本用法。本章将重点分析 Go 中通道的两种类型 —— 无缓冲通道与有缓冲通道,它们在并发编程中各具特点和应用场景。 一、通道的基本分类 类型定义形式特点无缓冲通道make(chan T)发送和接收都必须准备好&#xff0…...

人工智能(大型语言模型 LLMs)对不同学科的影响以及由此产生的新学习方式

今天是关于AI如何在教学中增强学生的学习体验,我把重要信息标红了。人文学科的价值被低估了 ⬇️ 转型与必要性 人工智能正在深刻地改变教育,这并非炒作,而是已经发生的巨大变革。教育机构和教育者不能忽视它,试图简单地禁止学生使…...

RabbitMQ入门4.1.0版本(基于java、SpringBoot操作)

RabbitMQ 一、RabbitMQ概述 RabbitMQ RabbitMQ最初由LShift和CohesiveFT于2007年开发,后来由Pivotal Software Inc.(现为VMware子公司)接管。RabbitMQ 是一个开源的消息代理和队列服务器,用 Erlang 语言编写。广泛应用于各种分布…...

QT开发技术【ffmpeg + QAudioOutput】音乐播放器

一、 介绍 使用ffmpeg 4.2.2 在数字化浪潮席卷全球的当下,音视频内容犹如璀璨繁星,点亮了人们的生活与工作。从短视频平台上令人捧腹的搞笑视频,到在线课堂中知识渊博的专家授课,再到影视平台上扣人心弦的高清大片,音…...

Matlab实现任意伪彩色图像可视化显示

Matlab实现任意伪彩色图像可视化显示 1、灰度原始图像2、RGB彩色原始图像 在科研研究中,如何展示好看的实验结果图像非常重要!!! 1、灰度原始图像 灰度图像每个像素点只有一个数值,代表该点的​​亮度(或…...