当前位置: 首页 > news >正文

SOFTS: 时间序列预测的最新模型以及Python使用示例

近年来,深度学习一直在时间序列预测中追赶着提升树模型,其中新的架构已经逐渐为最先进的性能设定了新的标准。

这一切都始于2020年的N-BEATS,然后是2022年的NHITS。2023年,PatchTST和TSMixer被提出,最近的iTransformer进一步提高了深度学习预测模型的性能。

这是2024年4月《SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion》中提出的新模型,采用集中策略来学习不同序列之间的交互,从而在多变量预测任务中获得最先进的性能。

在本文中,我们详细探讨了SOFTS的体系结构,并介绍新的STar聚合调度(STAD)模块,该模块负责学习时间序列之间的交互。然后,我们测试将该模型应用于单变量和多变量预测场景,并与其他模型作为对比。

SOFTS介绍

SOFTS是 Series-cOre Fused Time Series的缩写,背后的动机来自于长期多元预测对决策至关重要的认识:

首先我们一直研究Transformer的模型,它们试图通过使用补丁嵌入和通道独立等技术(如PatchTST)来降低Transformer的复杂性。但是由于通道独立性,消除了每个序列之间的相互作用,因此可能会忽略预测信息。

iTransformer 通过嵌入整个序列部分地解决了这个问题,并通过注意机制处理它们。但是基于transformer的模型在计算上是复杂的,并且需要更多的时间来训练非常大的数据集。

另一方面有一些基于mlp的模型。这些模型通常很快,并产生非常强的结果,但当存在许多序列时,它们的性能往往会下降。

所以出现了SOFTS:研究人员建议使用基于mlp的STAD模块。由于是基于MLP的,所以训练速度很快。并且STAD模块,它允许学习每个序列之间的关系,就像注意力机制一样,但计算效率更高。

SOFTS架构

在上图中可以看到每个序列都是单独嵌入的,就像在iTransformer 中一样。

然后将嵌入发送到STAD模块。每个序列之间的交互都是集中学习的,然后再分配到各个系列并融合在一起。

最后再通过线性层产生预测。

这个体系结构中有很多东西需要分析,我们下面更详细地研究每个组件。

1、归一化与嵌入

首先使用归一化来校准输入序列的分布。使用了可逆实例的归一化(RevIn)。它将数据以单位方差的平均值为中心。然后每个系列分别进行嵌入,就像在iTransformer 模型。

在上图中我们可以看到,嵌入整个序列就像应用补丁嵌入,其中补丁长度等于输入序列的长度。

这样,嵌入就包含了整个序列在所有时间步长的信息。

然后将嵌入式系列发送到STAD模块。

2、STar Aggregate-Dispatch (STAD)

STAD模块是soft模型与其他预测方法的真正区别。使用集中式策略来查找所有时间序列之间的相互作用。

嵌入的序列首先通过MLP和池化层,然后将这个学习到的表示连接起来形成核(上图中的黄色块表示)。

核构建好了以后就进入了“重复”和“连接”的步骤,在这个步骤中,核表示被分派给每个系列。

MLP和池化层未捕获的信息还可以通过残差连接添加到核表示中。然后在融合(fuse)操作的过程中,核表示及其对应系列的残差都通过MLP层发送。最后的线性层采用STAD模块的输出来生成每个序列的最终预测。

与其他捕获通道交互的方法(如注意力机制)相比,STAD模块的主要优点之一是它降低了复杂性。

因为STAD模块具有线性复杂度,而注意力机制具有二次复杂度,这意味着STAD在技术上可以更有效地处理具有多个序列的大型数据集。

下面我们来实际使用SOFTS进行单变量和多变量场景的测试。

使用SOFTS预测

这里,我们使用 Electricity Transformer dataset 数据集。

这个数据集跟踪了中国某省两个地区的变压器油温。每小时和每15分钟采样一个数据集,总共有四个数据集。

我门使用neuralforecast库中的SOFTS实现,这是官方认可的库,并且这样我们可以直接使用和测试不同预测模型的进行对比。

在撰写本文时,SOFTS还没有集成在的neuralforecast版本中,所以我们需要使用源代码进行安装。

 pip install git+https://github.com/Nixtla/neuralforecast.git

然后就是从导入包开始。使用datasetsforecast以所需格式加载数据集,以便使用neuralforecast训练模型,并使用utilsforecast评估模型的性能。这就是我们使用neuralforecast的原因,因为他都是一套的

 import pandas as pdimport numpy as npimport matplotlib.pyplot as pltfrom datasetsforecast.long_horizon import LongHorizonfrom neuralforecast.core import NeuralForecastfrom neuralforecast.losses.pytorch import MAE, MSEfrom neuralforecast.models import SOFTS, PatchTST, TSMixer, iTransformerfrom utilsforecast.losses import mae, msefrom utilsforecast.evaluation import evaluate

编写一个函数来帮助加载数据集,以及它们的标准测试大小、验证大小和频率。

 def load_data(name):if name == "ettm1":Y_df, *_ = LongHorizon.load(directory='./', group='ETTm1')Y_df = Y_df[Y_df['unique_id'] == 'OT'] # univariate datasetY_df['ds'] = pd.to_datetime(Y_df['ds'])val_size = 11520test_size = 11520freq = '15T'elif name == "ettm2":Y_df, *_ = LongHorizon.load(directory='./', group='ETTm2')Y_df['ds'] = pd.to_datetime(Y_df['ds']) val_size = 11520test_size = 11520freq = '15T'return Y_df, val_size, test_size, freq

然后就可以对ETTm1数据集进行单变量预测。

1、单变量预测

加载ETTm1数据集,将预测范围设置为96个时间步长。

可以测试更多的预测长度,但我们这里只使用96。

 Y_df, val_size, test_size, freq = load_data('ettm1')horizon = 96

然后初始化不同的模型,我们将soft与TSMixer, iTransformer和PatchTST进行比较。

所有模型都使用的默认配置将最大训练步数设置为1000,如果三次后验证损失没有改善,则停止训练。

 models = [SOFTS(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),TSMixer(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),iTransformer(h=horizon, input_size=3*horizon, n_series=1, max_steps=1000, early_stop_patience_steps=3),PatchTST(h=horizon, input_size=3*horizon, max_steps=1000, early_stop_patience_steps=3)]

然后初始化NeuralForecast对象训练模型。并使用交叉验证来获得多个预测窗口,更好地评估每个模型的性能。

 nf = NeuralForecast(models=models, freq=freq)nf_preds = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)nf_preds = nf_preds.reset_index()

评估计算了每个模型的平均绝对误差(MAE)和均方误差(MSE)。因为之前的数据是缩放的,因此报告的指标也是缩放的。

 ettm1_evaluation = evaluate(df=nf_preds, metrics=[mae, mse], models=['SOFTS', 'TSMixer', 'iTransformer', 'PatchTST'])

从上图可以看出,PatchTST的MAE最低,而softts、TSMixer和PatchTST的MSE是一样的。在这种特殊情况下,PatchTST仍然是总体上最好的模型。

这并不奇怪,因为PatchTST在这个数据集中是出了名的好,特别是对于单变量任务。下面我们开始测试多变量场景。

2、多变量预测

使用相同的load_data函数,我们现在为这个多变量场景使用ETTm2数据集。

 Y_df, val_size, test_size, freq = load_data('ettm2')horizon = 96

然后简单地初始化每个模型。我们只使用多变量模型来学习序列之间的相互作用,所以不会使用PatchTST,因为它应用通道独立性(意味着每个序列被单独处理)。

然后保留了与单变量场景中相同的超参数。只将n_series更改为7,因为有7个时间序列相互作用。

 models = [SOFTS(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE()),TSMixer(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE()),iTransformer(h=horizon, input_size=3*horizon, n_series=7, max_steps=1000, early_stop_patience_steps=3, scaler_type='identity', valid_loss=MAE())]

训练所有的模型并进行预测。

 nf = NeuralForecast(models=models, freq='15min')nf_preds = nf.cross_validation(df=Y_df, val_size=val_size, test_size=test_size, n_windows=None)nf_preds = nf_preds.reset_index()

最后使用MAE和MSE来评估每个模型的性能。

 ettm2_evaluation = evaluate(df=nf_preds, metrics=[mae, mse], models=['SOFTS', 'TSMixer', 'iTransformer'])

上图中可以看到到当在96的水平上预测时,TSMixer large在ETTm2数据集上的表现优于iTransformer和soft。

虽然这与soft论文的结果相矛盾,这是因为我们没有进行超参数优化,并且使用了96个时间步长的固定范围。

这个实验的结果可能不太令人印象深刻,我们只在固定预测范围的单个数据集上进行了测试,所以这不是SOFTS性能的稳健基准,同时也说明了SOFTS在使用时可能需要更多的时间来进行超参数的优化。

总结

SOFTS是一个很有前途的基于mlp的多元预测模型,STAD模块是一种集中式方法,用于学习时间序列之间的相互作用,其计算强度低于注意力机制。这使得模型能够有效地处理具有许多并发时间序列的大型数据集。

虽然在我们的实验中,SOFTS的性能可能看起来有点平淡无奇,但请记住,这并不代表其性能的稳健基准,因为我们只在固定视界的单个数据集上进行了测试。

但是SOFTS的思路还是非常好的,比如使用集中式学习时间序列之间的相互作用,并且使用低强度的计算来保证数据计算的效率,这都是值得我们学习的地方。

并且每个问题都需要其独特的解决方案,所以将SOFTS作为特定场景的一个测试选项是一个明智的选择。

SOFTS: Efficient Multivariate Time Series Forecasting with Series-Core Fusion

https://avoid.overfit.cn/post/6254097fd18d479ba7fd85efcc49abac

相关文章:

SOFTS: 时间序列预测的最新模型以及Python使用示例

近年来,深度学习一直在时间序列预测中追赶着提升树模型,其中新的架构已经逐渐为最先进的性能设定了新的标准。 这一切都始于2020年的N-BEATS,然后是2022年的NHITS。2023年,PatchTST和TSMixer被提出,最近的iTransforme…...

C++ 取近似值

描述 写出一个程序,接受一个正浮点数值,输出该数值的近似整数值。如果小数点后数值大于等于 0.5 ,向上取整;小于 0.5 ,则向下取整。 数据范围:保证输入的数字在 32 位浮点数范围内 输入描述: 输入一个正…...

云原生系列之Docker常用命令

🌹作者主页:青花锁 🌹简介:Java领域优质创作者🏆、Java微服务架构公号作者😄 🌹简历模板、学习资料、面试题库、技术互助 🌹文末获取联系方式 📝 系列文章目录 云原生之…...

opencv_GUI

图像入门 import numpy as np import cv2 as cv # 用灰度模式加载图像 img cv.imread(C:/Users/HP/Downloads/basketball.png, 0)# 即使图像路径错误,它也不会抛出任何错误,但是打印 img会给你Nonecv.imshow(image, img) cv.waitKey(5000) # 一个键盘绑…...

FlowUs轻量化AI:趁这波升级专业版,全年无限AI助力笔记产出与二次编写

在数字时代,信息管理与知识产出的效率直接影响个人的生产力。FlowUs作为一款集笔记、文档、多维表、文件夹于一体的新一代知识管理平台,其轻量化AI的加入更是如虎添翼。特别是在活动期间,升级专业版将带来全年无限AI使用次数,让每…...

Day 22:2786. 访问数组中的位置使分数最大

Leetcode 2786. 访问数组中的位置使分数最大 给你一个下标从 0 开始的整数数组 nums 和一个正整数 x 。 你 一开始 在数组的位置 0 处&#xff0c;你可以按照下述规则访问数组中的其他位置&#xff1a; 如果你当前在位置 i &#xff0c;那么你可以移动到满足 i < j 的 任意 …...

理解Es的DSL语法(二):聚合

前一篇已经系统介绍过查询语法&#xff0c;详细可直接看上一篇文章&#xff08;理解DSL语法&#xff08;一&#xff09;&#xff09;&#xff0c;本篇主要介绍DSL中的另一部分&#xff1a;聚合 理解Es中的聚合 虽然Elasticsearch 是一个基于 Lucene 的搜索引擎&#xff0c;但…...

matlab-2-simulink-小白教程-如何绘制电路图进行电路仿真

以上述电路图为例&#xff1a;包含D触发器&#xff0c;时钟CLK,与非门 一、启动simulink的三种方式 方式1 在MATLAB的命令行窗口输入“Simulink”命令。 方式2 在MATLAB主窗口的“主页”选项卡中&#xff0c;单击“SIMULINK”命令组中的Simulink命令按钮。 方式3 从MATLAB…...

CSS从入门到精通——背景样式

目录 背景颜色 任务描述 相关知识 背景色 编程要求 背景图片 任务描述 相关知识 背景图片 设置背景图片 平铺背景图像 任务要求 背景定位与背景关联 任务描述 相关知识 背景定位 背景关联 简写背景 编程要求 背景颜色 任务描述 本关任务&#xff1a;在本关…...

网络编程---Java飞机大战联机

解析服务器端代码 代码是放在app/lib下的src下的main/java&#xff0c;而与之前放在app/src/main下路径不同 Main函数 Main函数里只放着创建MyServer类的一行 public static void main(String args[]){new MyServer();} MyServer构造函数 1.获取本机IP地址 //获取本机IP地…...

一个简单的Oracle函数

CREATE OR REPLACE FUNCTION getyj_zhibiao_value(p_name IN varchar2, p_index IN varchar2) RETURN NUMBER IS -- 定义返回的指标值变量 v_result NUMBER; -- 定义临时变量来存储查询到的指标值 v_index1 VARCHAR2(50); v_index2 VARCHAR2(50); …...

word中根据上级设置下级编号

如上级是3.13.4&#xff0c;如下图 现在想设置下级编码跟随上级逐级显示成3.13.4.1 则在标题功能说明这点击顶部菜单栏的编号按钮&#xff0c;如下图 然后&#xff0c;选择自定义编号-自定义列表-自定义按钮 然后重点是编号格式这一栏&#xff0c;需要手动填写下前三级的编号&…...

【康复学习--LeetCode每日一题】2786. 访问数组中的位置使分数最大

题目描述&#xff1a; 给你一个下标从 0 开始的整数数组 nums 和一个正整数 x 。 你一开始 在数组的位置 0 处&#xff0c;你可以按照下述规则访问数组中的其他位置&#xff1a; 如果你当前在位置 i &#xff0c;那么你可以移动到满足 i < j 的 任意 位置 j 。 对于你访问的…...

bash和sh区别

bash 和 sh 是两种常用的 Unix Shell&#xff0c;它们有一些区别&#xff0c;特别是在功能和兼容性方面。以下是一些主要的区别&#xff1a; 1. **历史与实现**&#xff1a; - sh&#xff08;Bourne Shell&#xff09;是第一个 Unix Shell&#xff0c;最初由 Stephen Bourn…...

Git 代码管理规范 !

分支命名 master 分支 master 为主分支&#xff0c;也是用于部署生产环境的分支&#xff0c;需要确保master分支稳定性。master 分支一般由 release 以及 hotfix 分支合并&#xff0c;任何时间都不能直接修改代码。 develop 分支 develop 为开发环境分支&#xff0c;始终保持最…...

MGRS坐标

一 概述 MGRS坐标系统&#xff0c;即军事格网参考系统&#xff0c;是北约(NATO)军事组织使用的标准坐标系统。它基于UTM&#xff08;通用横向墨卡托&#xff09;系统&#xff0c;并将每个UTM区域进一步划分为100km100km的小方块。这些方块通过两个相连的字母标识&#xff0c;其…...

FreeRTOS简单内核实现4 临界段

文章目录 0、思考与回答0.1、思考一0.2、思考二0.3、思考三 1、关中断1.1、带返回值1.2、不带返回值 2、开中断3、临界段4、应用 0、思考与回答 0.1、思考一 为什么需要临界段&#xff1f; 有时候我们需要部分代码一旦这开始执行&#xff0c;则不允许任何中断打断&#xff0…...

Scala的字符串插值

Scala的字符串插值 期待您的关注 ☀Scala学习笔记 目录 Scala的字符串插值 1. s插值器&#xff1a; 2. f插值器&#xff1a; 3. raw插值器&#xff1a; 在Scala中&#xff0c;字符串插值是一种方便的方式&#xff0c;可以在字符串中插入变量或表达式的值。Scala支持三种类型…...

EasyGBS服务器和终端配置

服务器配置 修改easygbs.ini sip/host为本机IP&#xff0c;否则终端能登录&#xff0c;无法视频。 [sip] host192.168.3.190 终端用于登录的用户名和密码 default_usertest default_passwordtest1234 default_guest_userguest default_guest_passwordtest1234终端配置 关…...

git配置2-不同的代码托管平台配置不同的ssh key

1. 配置单个ssh key 1.1. 原理1.2. 生成 ssh key1.3. 代码托管平台配置公钥 2. 配置多个ssh key 2.1. 应用场景2.2. 生成两个不同的key2.3. 修改config文件2.4. 配置代码托管平台2.5. 测试是否成功 1. 配置单个ssh key 1.1. 原理 使用ssh命令行工具&#xff08;git安装成功…...

【CT】LeetCode手撕—102. 二叉树的层序遍历

目录 题目1-思路2- 实现⭐102. 二叉树的层序遍历——题解思路 3- ACM实现3-1 二叉树构造3-2 整体实现 题目 原题连接&#xff1a;102. 二叉树的层序遍历 1-思路 1.借助队列 Queue &#xff0c;每次利用 ①while 循环遍历当前层结点&#xff0c;②将当前层结点的下层结点放入 …...

Flink 命令行提交、展示和取消作业

Apache Flink 是一个流处理和批处理的开源框架&#xff0c;用于在分布式环境中执行无边界和有边界的数据流。你可以使用 Flink 的命令行界面&#xff08;CLI&#xff09;来提交、展示和取消作业。 提交作业 使用 Flink CLI 提交作业的命令格式通常如下&#xff1a; ./bin/fl…...

STM32单片机选型方法

一.STM32单片机选型方法 1.首先要确定需求&#xff1a; 性能需求&#xff1a;根据应用的复杂度和性能要求&#xff0c;选择合适的CPU性能和主频。 内存需求&#xff1a;确定所需的内存大小&#xff0c;包括RAM和Flash存储空间。 外设需求&#xff1a;根据应用所需的功能&…...

gsap动画库的实践

先看效果&#xff1a; gsap动画库 安装插件&#xff1a;npm install gsap <template><div><h1 style"text-align: left">gsap的用法</h1><h1 style"text-align: left">https://gsap.com/resources/get-started</h1>&…...

LeetCode | 387.字符串中的第一个唯一字符

这道题可以用字典解决&#xff0c;只需要2次遍历字符串&#xff0c;第一次遍历字符串&#xff0c;记录每个字符出现的次数&#xff0c;第二次返回第一个出现次数为1的字符的下标&#xff0c;若找不到则返回-1 class Solution(object):def firstUniqChar(self, s):""…...

textarea 中的内容在word中显示换行不起作用

js文本换行在word显示 在JavaScript中&#xff0c;处理文本换行以确保它在Word中正确显示&#xff0c;通常需要将文本中的换行符转换为Word可识别的格式。在HTML中&#xff0c;换行通常是通过<br>标签来实现的&#xff0c;而在Word中&#xff0c;换行通常由段落标签<…...

Python 测试用例

在Python中编写测试用例通常使用unittest模块&#xff0c;这是Python标准库的一部分&#xff0c;专门用于编写和运行测试。下面是一个简单的测试用例的例子&#xff0c;展示了如何使用unittest模块来测试一个函数。 假设我们有一个简单的函数&#xff0c;用于计算两个数的和&a…...

树莓派等Linux开发板上使用 SSD1306 OLED 屏幕,bullseye系统 ubuntu,debian

Raspberry Pi OS Bullseye 最近发布了,随之而来的是许多改进,但其中大部分都在引擎盖下。没有那么多视觉差异,最明显的可能是新的默认桌面背景,现在是大坝或湖泊上的日落。https://www.the-diy-life.com/add-an-oled-stats-display-to-raspberry-pi-os-bullseye/ 通过这次操…...

SpringBoot3 整合 Mybatis 完整版

本文记录一下完整的 SpringBoot3 整合 Mybatis 的步骤。 只要按照本步骤来操作&#xff0c;整合完成后就可以正常使用。1. 添加数据库驱动依赖 以 MySQL 为例。 当不指定 依赖版本的时候&#xff0c;会 由 springboot 自动管理。 <dependency><groupId>com.mysql&l…...

图解Transformer学习笔记

教程是来自https://github.com/datawhalechina/learn-nlp-with-transformers/blob/main/docs/ 图解Transformer Attention为RNN带来了优点&#xff0c;那么有没有一种神经网络结构直接基于Attention构造&#xff0c;而不再依赖RNN、LSTM或者CNN的结构&#xff0c;这就是Trans…...