当前位置: 首页 > news >正文

使用pytorch实现LSTM预测交通流

原始数据:

免费可下载原始参考数据

预测结果图:

根据测试数据test_data的真实值real_flow,与模型根据测试数据得到的输出结果pre_flow

完整源码:

#!/usr/bin/env python
# _*_ coding: utf-8 _*_import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data as Data
from torchsummary import summary
import math
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
import time
import matplotlib.pyplot as plt
from sklearn.metrics import mean_squared_error, mean_absolute_error
import math
from matplotlib.font_manager import FontProperties  # 画图时可以使用中文# 加载数据
f = pd.read_csv(r'C:\Users\14600\Desktop\AE86.csv')# 从新设置列标
def set_columns():columns = []for i in f.loc[2]:columns.append(i.strip())return columnsf.columns = set_columns()
f.drop([0, 1, 2], inplace=True)
# 读取数据
data = f['Total Carriageway Flow'].astype(np.float64).values[:, np.newaxis]class LoadData(Dataset):def __init__(self, data, time_step, divide_days, train_mode):self.train_mode = train_modeself.time_step = time_stepself.train_days = divide_days[0]self.test_days = divide_days[1]self.one_day_length = int(24 * 4)# flow_norm (max_data. min_data)self.flow_norm, self.flow_data = LoadData.pre_process_data(data)# 不进行标准化# self.flow_data = datadef __len__(self, ):if self.train_mode == "train":return self.train_days * self.one_day_length - self.time_stepelif self.train_mode == "test":return self.test_days * self.one_day_lengthelse:raise ValueError(" train mode error")def __getitem__(self, index):if self.train_mode == "train":index = indexelif self.train_mode == "test":index += self.train_days * self.one_day_lengthelse:raise ValueError(' train mode error')data_x, data_y = LoadData.slice_data(self.flow_data, self.time_step, index,self.train_mode)data_x = LoadData.to_tensor(data_x)data_y = LoadData.to_tensor(data_y)return {"flow_x": data_x, "flow_y": data_y}# 这一步就是划分数据@staticmethoddef slice_data(data, time_step, index, train_mode):if train_mode == "train":start_index = indexend_index = index + time_stepelif train_mode == "test":start_index = index - time_stepend_index = indexelse:raise ValueError("train mode error")data_x = data[start_index: end_index, :]data_y = data[end_index]return data_x, data_y# 数据与处理@staticmethoddef pre_process_data(data, ):norm_base = LoadData.normalized_base(data)normalized_data = LoadData.normalized_data(data, norm_base[0], norm_base[1])return norm_base, normalized_data# 生成原始数据中最大值与最小值@staticmethoddef normalized_base(data):max_data = np.max(data, keepdims=True)  # keepdims保持维度不变min_data = np.min(data, keepdims=True)# max_data.shape  --->(1, 1)return max_data, min_data# 对数据进行标准化@staticmethoddef normalized_data(data, max_data, min_data):data_base = max_data - min_datanormalized_data = (data - min_data) / data_basereturn normalized_data@staticmethod# 反标准化  在评价指标误差以及画图的使用使用def recoverd_data(data, max_data, min_data):data_base = max_data - min_datarecoverd_data = data * data_base - min_datareturn recoverd_data@staticmethoddef to_tensor(data):return torch.tensor(data, dtype=torch.float)# 划分数据
divide_days = [25, 5]
time_step = 5
batch_size = 48
train_data = LoadData(data, time_step, divide_days, "train")
test_data = LoadData(data, time_step, divide_days, "test")
train_loader = DataLoader(train_data, batch_size=batch_size, shuffle=True, )
test_loader = DataLoader(test_data, batch_size=batch_size, shuffle=False, )# LSTM构建网络
class LSTM(nn.Module):def __init__(self, input_num, hid_num, layers_num, out_num, batch_first=True):super().__init__()self.l1 = nn.LSTM(input_size=input_num,hidden_size=hid_num,num_layers=layers_num,batch_first=batch_first)self.out = nn.Linear(hid_num, out_num)def forward(self, data):flow_x = data['flow_x']  # B * T * Dl_out, (h_n, c_n) = self.l1(flow_x, None)  # None表示第一次 hidden_state是0#         print(l_out[:, -1, :].shape)out = self.out(l_out[:, -1, :])return out# 定义模型参数
input_num = 1  # 输入的特征维度
hid_num = 50  # 隐藏层个数
layers_num = 3  # LSTM层个数
out_num = 1
lstm = LSTM(input_num, hid_num, layers_num, out_num)
loss_func = nn.MSELoss()
optimizer = torch.optim.Adam(lstm.parameters())# 训练模型
lstm.train()
epoch_loss_change = []
for epoch in range(30):epoch_loss = 0.0start_time = time.time()for data_ in train_loader:lstm.zero_grad()predict = lstm(data_)loss = loss_func(predict, data_['flow_y'])epoch_loss += loss.item()loss.backward()optimizer.step()epoch_loss_change.append(1000 * epoch_loss / len(train_data))end_time = time.time()print("Epoch: {:04d}, Loss: {:02.4f}, Time: {:02.2f} mins".format(epoch, 1000 * epoch_loss / len(train_data),(end_time - start_time) / 60))
plt.plot(epoch_loss_change)# 评价模型
lstm.eval()
with torch.no_grad():  # 关闭梯度total_loss = 0.0pre_flow = np.zeros([batch_size, 1])  # [B, D],T=1 # 目标数据的维度,用0填充real_flow = np.zeros_like(pre_flow)for data_ in test_loader:pre_value = lstm(data_)loss = loss_func(pre_value, data_['flow_y'])total_loss += loss.item()# 反归一化pre_value = LoadData.recoverd_data(pre_value.detach().numpy(),test_data.flow_norm[0].squeeze(1),  # max_datatest_data.flow_norm[1].squeeze(1),  # min_data)target_value = LoadData.recoverd_data(data_['flow_y'].detach().numpy(),test_data.flow_norm[0].squeeze(1),test_data.flow_norm[1].squeeze(1),)pre_flow = np.concatenate([pre_flow, pre_value])real_flow = np.concatenate([real_flow, target_value])pre_flow = pre_flow[batch_size:]real_flow = real_flow[batch_size:]
#     # 计算误差
mse = mean_squared_error(pre_flow, real_flow)
rmse = math.sqrt(mean_squared_error(pre_flow, real_flow))
mae = mean_absolute_error(pre_flow, real_flow)
print('均方误差---', mse)
print('均方根误差---', rmse)
print('平均绝对误差--', mae)# 画出预测结果图
font_set = FontProperties(fname=r"C:\Windows\Fonts\simsun.ttc", size=15)  # 中文字体使用宋体,15号
plt.figure(figsize=(15, 10))
plt.plot(real_flow, label='Real_Flow', color='r', )
plt.plot(pre_flow, label='Pre_Flow')
plt.xlabel('测试序列', fontproperties=font_set)
plt.ylabel('交通流量/辆', fontproperties=font_set)
plt.legend()
# 预测储存图片
# plt.savefig('...\Desktop\123.jpg')plt.show()

相关文章:

使用pytorch实现LSTM预测交通流

原始数据: 免费可下载原始参考数据 预测结果图: 根据测试数据test_data的真实值real_flow,与模型根据测试数据得到的输出结果pre_flow 完整源码: #!/usr/bin/env python # _*_ coding: utf-8 _*_import pandas as pd import nu…...

C/C++(八)C++11

目录 一、C11的简介 二、万能引用与完美转发 1、万能引用:模板中的 && 引用 2、完美转发:保持万能引用左右值属性的解决方案 三、可变参数模板 1、可变参数模板的基本使用 2、push 系列和 emplace 系列的区别 四、lambda表达式&#xf…...

使用three.js 实现 自定义绘制平面的效果

使用three.js 实现 自定义绘制平面的效果 预览 import * as THREE from three import { OrbitControls } from three/examples/jsm/controls/OrbitControls.jsconst box document.getElementById(box)const scene new THREE.Scene()const camera new THREE.PerspectiveCam…...

玩转Docker | 使用Docker部署捕鱼网页小游戏

玩转Docker | 使用Docker部署捕鱼网页小游戏 一、项目介绍项目简介项目预览二、系统要求环境要求环境检查Docker版本检查检查操作系统版本三、部署捕鱼网页小游戏下载镜像创建容器检查容器状态下载项目内容查看服务监听端口安全设置四、访问捕鱼网页小游戏五、总结一、项目介绍…...

第2章 Android App开发基础

第 2 章 Android App开发基础 bilibili学习地址 github代码地址 本章介绍基于Android系统的App开发常识,包括以下几个方面:App开发与其他软件开发有什么不一 样,App工程是怎样的组织结构又是怎样配置的,App开发的前后端分离设计…...

通过 SYSENTER/SYSEXIT指令来学习系统调用

SYSENTER指令—快速系统调用 指令格式没有什么重要的内容,只有opcode ,没有后面的其他字段 指令的作用: 执行快速调用到特权级别0的系统过程或例程。SYSENTER是SYSEXIT的配套指令。该指令经过优化,能够为从运行在特权级别3的用户代码到特权级别0的操作系统或执行过程…...

Nginx开发实战——网络通信(一)

文章目录 Nginx开发框架信号处理函数的进一步完善(避免僵尸子进程)(续)ngx_signal.cxxngx_process_cycle.cxx 网络通信实战客户端和服务端1. 解析一个浏览器访问网页的过程2.客户端服务器角色规律总结 网络模型OSI 7层网络模型TCP/IP 4层模型3.TCP/IP的解释和比喻 最…...

w外链如何跳转微信小程序

要创建外链跳转微信小程序,主要有以下几种方法: 使用第三方工具生成跳转链接: 注册并登录第三方外链平台:例如 “W外链” 等工具。前往该平台的官方网站,使用手机号、邮箱等方式进行注册并登录账号。选择创建小程序外…...

获取平台Redis各项性能指标

业务场景 在XXXX项目中把A网的过车数据传到B网中,其中做了一个业务处理,就是如果因为网络或者其他原因导致把数据传到B网失败,就会把数据暂时先存到redis里,并且执行定时任务重新发送失败的。 问题 不过现场的情况比较不稳定。出…...

STM32 HAL 点灯

首先从点灯开始 完整函数如下: #include "led.h" #include "sys.h"//包含了stm32f1xx.h(包含各种寄存器定义、中断向量定义、常量定义等)//初始化GPIO口 void led_init(void) {GPIO_InitTypeDef gpio_initstruct;//打开…...

【http作业】

1.关闭防火墙 [rootlocalhost ~]# systemctl stop firewalld #关闭防火墙 [rootlocalhost ~]# setenforce 0 2.下载nginx包 [rootlocalhost ~]# mount /dev/sr0 /mnt #挂载目录 [rootlocalhost ~]# yum install nginx -y #下载nginx包 3.增加多条端口 [rootlocalhost ~]# n…...

WPF+MVVM案例实战(十一)- 环形进度条实现

文章目录 1、运行效果2、功能实现1、文件创建与代码实现2、角度转换器实现3、命名空间引用3、源代码下载1、运行效果 2、功能实现 1、文件创建与代码实现 打开 Wpf_Examples 项目,在Views 文件夹下创建 CircularProgressBar.xaml 窗体文件。 CircularProgressBar.xaml 代码实…...

简述MCU微控制器

目录 一、MCU 的主要特点: 二、常见 MCU 系列: 三、应用场景: MCU 是微控制器(Microcontroller Unit)的缩写,指的是一种小型计算机,专门用于嵌入式系统。它通常集成了中央处理器(…...

微服务的雪崩问题

微服务的雪崩问题: 微服务调用链路中的某个服务故障,引起整个链路种的所有微服务都不可用。这就是微服务的雪崩问题。(级联失败),具体表现出来就是微服务之间相互调用,服务的提供者出现阻塞或者故障&#x…...

Java基础(4)——构建字符串(干货)

今天聊Java构建字符串以及其内存原理 我们先来看一个小例子。一个是String,一个是StringBuilder. 通过结果对比,StringBuilder要远远快于String. String/StringBuilder/StringBuffer这三个构建字符串有什么区别? 拼接速度上,StringBuilder…...

logback日志脱敏后异步写入文件

大家项目中肯定都会用到日志打印,目的是为了以后线上排查问题方便,但是有些企业对输出的日志包含的敏感(比如:用户身份证号,银行卡号,手机号等)信息要进行脱敏处理。 哎!我们最近就遇到了日志脱敏的改造。可…...

电容的基本知识

1.电容的相关公式 2.电容并联和串联的好处 电容并联的好处: 增加总电容值: 并联连接的电容器可以增加总的电容值,这对于需要较大电容值来滤除高频噪声或储存更多电荷的应用非常有用。 改善频率响应: 并联不同的电容值可以设计一个滤波器,以在特定的频率范围内提供更好的滤…...

【Axure高保真原型】分级树筛选中继器表格

今天和大家分享分级树筛选中继器表格的原型模板,点击树的箭头可以展开或者收起子级内容,点击内容,可以筛选出该内容及子级内容下所有的表格数据。左侧的树和右侧的表格都是用中继器制作的,所以使用也很方便,只需要在中…...

STM32 I2C通信:硬件I2C与软件模拟I2C的区别

文章目录 STM32 I2C通信:硬件I2C与软件模拟I2C的区别。一、硬件I2C速度快:实现简单:稳定性好: 二、软件模拟I2C灵活性高:支持多路通信: 三、选择哪种方式? STM32 I2C通信:硬件I2C与软…...

服务器新建用户

文章目录 前言一、步骤二、问题三、赋予管理员权限总结 前言 环境: 一、步骤 创建用户需要管理员权限sudo sudo useradd tang为用户设置密码 sudo passwd tang设置密码后,可以尝试使用 su 切换到 tang 用户,确保该用户可以正常使用&#…...

鸿蒙开发融云demo发送图片消息

鸿蒙开发融云demo发送图片消息 融云鸿蒙版是不带UI的,得自己一步步搭建。 这次讲如何发送图片消息,选择图片,显示图片消息。 还是有点难度的,好好看,好好学。 一、思路: 选择图片用:photoVie…...

音视频入门基础:AAC专题(11)——AudioSpecificConfig简介

音视频入门基础:AAC专题系列文章: 音视频入门基础:AAC专题(1)——AAC官方文档下载 音视频入门基础:AAC专题(2)——使用FFmpeg命令生成AAC裸流文件 音视频入门基础:AAC…...

OpenCV基本操作(python开发)——(8)实现芯片瑕疵检测

OpenCV基本操作(python开发)——(1) 读取图像、保存图像 OpenCV基本操作(python开发)——(2)图像色彩操作 OpenCV基本操作(python开发)——(3&…...

聚水潭商品信息集成到MySQL的高效解决方案

聚水潭商品信息集成到MySQL的技术案例分享 在数据驱动的业务环境中,如何高效地实现不同系统之间的数据对接和集成,是每个企业面临的重要挑战。本文将聚焦于一个具体的系统对接集成案例:将聚水潭平台上的商品信息单集成到BI斯莱蒙的MySQL数据…...

# centos6.5 使用 yum list 报错Error Cannot find a valid baseurl for repo bas 解决方法

centos6.5 使用 yum list 报错Error Cannot find a valid baseurl for repo bas 解决方法 一、问题描述: centos6.5 使用 yum list 报错Error Cannot find a valid baseurl for repo bas 如下图: 二、问题分析: 官方已停止对CentOS 6的更…...

【专题】2023-2024中国保险数字化营销调研报告汇总PDF洞察(附原数据表)

原文链接: https://tecdat.cn/?p38063 在时代浪潮的推动下,中国保险行业正经历着一场波澜壮阔的变革之旅。 2023 年,中国经济迈向高质量发展阶段,保险公司纷纷聚焦队伍转型,专业化、职业化代理人成为行业新方向。回…...

““ 引用类型应用举例

#include <iostream> //使能cin(),cout(); #include <stdlib.h> //使能exit(); #include <iomanip> //使能setbase(),setfill(),setw(),setprecision(),setiosflags()和resetiosflags(); //setbase( char x )是设置输出数字的基数,如输出进制数则用se…...

数字图像处理 - 基于ubuntu20.04运行.NET6+OpenCVSharp项目

一、简述 上一篇Ubuntu20.04 更新Nvidia驱动 + 安装CUDA12.1 + cudnn8.9.7-CSDN博客,记录了Ubuntu20.04 更新Nvidia驱动 + 安装CUDA12.1 + cudnn8.9.7的过程,最终的目的是要这些服务器上运行.net6的程序,进行图像处理、onnxruntime推理等。 这里记录进行OpenCVSharp的安装和…...

git cherry-pick用法详解

git cherry-pick 是 Git 中一个非常有用的命令&#xff0c;它允许你选择一个特定的提交&#xff08;commit&#xff09;并将其变更应用到当前分支上。这个功能在你需要将某个分支上的某个或某些特定提交合并到另一个分支时特别有用&#xff0c;而不需要将整个分支合并过去。 基…...

HCIP-HarmonyOS Application Developer V1.0 笔记(一)

HarmonyOS的系统特性 硬件互助&#xff0c;资源共享;一次开发&#xff0c;多端部署;统一OS&#xff0c;弹性部署。 分布式软总线&#xff1a;分布式任务调度、分布式数据管理、分布式硬件虚拟化的基座 18N的独立设备 1个手机&#xff0c;8种设备&#xff08;车机&#xff0c…...

wordpress 文章幻灯片/网站建设平台哪家好

看到百威啤酒的客户端主界面的按钮&#xff0c;感觉比较新奇&#xff0c;先看下图片&#xff1a; 注意图中我画的箭头&#xff0c;当时鼠标点击的黑色圈圈的位置&#xff0c;然后按钮出现了按下的效果&#xff08;黄色的描边&#xff09; 刚开始看到这种效果很是好奇&#xff0…...

应届生招聘去哪个网站/北京官网seo

文章目录1.处理客户端删除状态请求1.1 InstanceResource.deleteStatusUpdate()1.2 PeerAwareInstanceRegistryImpl.deleteStatusOverride()1.3 AbstractInstanceRegistry.deleteStatusOverride()1.4 PeerAwareInstanceRegistryImpl.replicateToPeers()1.处理客户端删除状态请求…...

问答类网站怎么做啊/自制网页

除了最初发表的以"熔岩音乐"和"流利阅读"为范本&#xff0c;以及初学者入门点这里&#xff0c;内容太过于巨细靡遗&#xff0c;主要是为了给入门朋友详解相关的基本流程&#xff0c;但之后所有文章唯有对于实例的重点讲解&#xff0c;主要在于培养大家对于…...

河南商丘疫情最新政策/seo是什么意思 为什么要做seo

Symantec Endpoint Protection v11.0.5002.333 简体中文企业版 【现象描述】&#xff1a;为了内网安全&#xff0c;内网可以上外网的一台服务器上面部署了网络版symantec杀毒软件&#xff0c;版本是 v11.0.5002.333 简体中文企业版&#xff1b;由它生成的客户端&#xff0c;分发…...

南京seo公司哪家好/武汉seo工厂

多世界诠释有点类似科幻小说中最喜欢使用的“平行宇宙”概念。该诠释认为&#xff0c;波函数对电子位置的其他预测不但没有消失&#xff0c;而且还全部发生了&#xff0c;只不过它们都发生在彼此不相干的世界里而已。这听上去就像&#xff0c;如果你在这个现实里做了什么糟糕的…...

网站浏览图片怎么做/关键词查询的分析网站

Lazarus crack,功能丰富的 Delphi 开发环境 Lazarus为开发人员提供了一个功能丰富的 Delphi 开发环境&#xff0c;使他们能够创建功能齐全的跨平台应用程序&#xff0c;专为个人和商业用途而设计。 Lazarus 项目拥有超过 15 年的历史&#xff0c;包括一个用于 Free Pascal 的综…...