当前位置: 首页 > news >正文

NeuralForecast 模型的参数 windows_batch的含义

NeuralForecast 模型的参数 windows_batch的含义

flyfish

import pandas as pd
import numpy as npAirPassengers = np.array([112.0, 118.0, 132.0, 129.0, 121.0, 135.0, 148.0, 148.0, 136.0, 119.0],dtype=np.float32,
)AirPassengersDF = pd.DataFrame({"unique_id": np.ones(len(AirPassengers)),"ds": pd.date_range(start="1949-01-01", periods=len(AirPassengers), freq=pd.offsets.MonthEnd()),"y": AirPassengers,}
)Y_df = AirPassengersDF
Y_df = Y_df.reset_index(drop=True)
Y_df.head()
#Model Trainingfrom neuralforecast.core import NeuralForecast
from neuralforecast.models import VanillaTransformerhorizon = 3
models = [VanillaTransformer(input_size=2 * horizon, h=horizon, max_steps=2)]nf = NeuralForecast(models=models, freq='M')for model in nf.models:print(f'Model: {model.__class__.__name__}')for param, value in model.__dict__.items():print(f'  {param}: {value}')nf.fit(df=Y_df)

输出

Seed set to 1
Model: VanillaTransformertraining: True_parameters: OrderedDict()_buffers: OrderedDict()_non_persistent_buffers_set: set()_backward_pre_hooks: OrderedDict()_backward_hooks: OrderedDict()_is_full_backward_hook: None_forward_hooks: OrderedDict()_forward_hooks_with_kwargs: OrderedDict()_forward_hooks_always_called: OrderedDict()_forward_pre_hooks: OrderedDict()_forward_pre_hooks_with_kwargs: OrderedDict()_state_dict_hooks: OrderedDict()_state_dict_pre_hooks: OrderedDict()_load_state_dict_pre_hooks: OrderedDict()_load_state_dict_post_hooks: OrderedDict()_modules: OrderedDict([('loss', MAE()), ('valid_loss', MAE()), ('padder_train', ConstantPad1d(padding=(0, 3), value=0)), ('scaler', TemporalNorm()), ('enc_embedding', DataEmbedding((value_embedding): TokenEmbedding((tokenConv): Conv1d(1, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False, padding_mode=circular))(position_embedding): PositionalEmbedding()(dropout): Dropout(p=0.05, inplace=False)
)), ('dec_embedding', DataEmbedding((value_embedding): TokenEmbedding((tokenConv): Conv1d(1, 128, kernel_size=(3,), stride=(1,), padding=(1,), bias=False, padding_mode=circular))(position_embedding): PositionalEmbedding()(dropout): Dropout(p=0.05, inplace=False)
)), ('encoder', TransEncoder((attn_layers): ModuleList((0-1): 2 x TransEncoderLayer((attention): AttentionLayer((inner_attention): FullAttention((dropout): Dropout(p=0.05, inplace=False))(query_projection): Linear(in_features=128, out_features=128, bias=True)(key_projection): Linear(in_features=128, out_features=128, bias=True)(value_projection): Linear(in_features=128, out_features=128, bias=True)(out_projection): Linear(in_features=128, out_features=128, bias=True))(conv1): Conv1d(128, 32, kernel_size=(1,), stride=(1,))(conv2): Conv1d(32, 128, kernel_size=(1,), stride=(1,))(norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(dropout): Dropout(p=0.05, inplace=False)))(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
)), ('decoder', TransDecoder((layers): ModuleList((0): TransDecoderLayer((self_attention): AttentionLayer((inner_attention): FullAttention((dropout): Dropout(p=0.05, inplace=False))(query_projection): Linear(in_features=128, out_features=128, bias=True)(key_projection): Linear(in_features=128, out_features=128, bias=True)(value_projection): Linear(in_features=128, out_features=128, bias=True)(out_projection): Linear(in_features=128, out_features=128, bias=True))(cross_attention): AttentionLayer((inner_attention): FullAttention((dropout): Dropout(p=0.05, inplace=False))(query_projection): Linear(in_features=128, out_features=128, bias=True)(key_projection): Linear(in_features=128, out_features=128, bias=True)(value_projection): Linear(in_features=128, out_features=128, bias=True)(out_projection): Linear(in_features=128, out_features=128, bias=True))(conv1): Conv1d(128, 32, kernel_size=(1,), stride=(1,))(conv2): Conv1d(32, 128, kernel_size=(1,), stride=(1,))(norm1): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(norm2): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(norm3): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(dropout): Dropout(p=0.05, inplace=False)))(norm): LayerNorm((128,), eps=1e-05, elementwise_affine=True)(projection): Linear(in_features=128, out_features=1, bias=True)
))])prepare_data_per_node: Trueallow_zero_length_dataloader_with_multiple_devices: False_log_hyperparams: True_dtype: torch.float32_device: cpu_trainer: None_example_input_array: None_automatic_optimization: True_strict_loading: None_current_fx_name: None_param_requires_grad_state: {}_metric_attributes: None_compiler_ctx: None_fabric: None_fabric_optimizers: []_hparams_name: kwargs_hparams: "activation":                    gelu
"alias":                         None
"batch_size":                    32
"conv_hidden_size":              32
"decoder_input_size_multiplier": 0.5
"decoder_layers":                1
"drop_last_loader":              False
"dropout":                       0.05
"early_stop_patience_steps":     -1
"encoder_layers":                2
"exclude_insample_y":            False
"futr_exog_list":                None
"h":                             3
"hidden_size":                   128
"hist_exog_list":                None
"inference_windows_batch_size":  1024
"input_size":                    6
"learning_rate":                 0.0001
"loss":                          MAE()
"lr_scheduler":                  None
"lr_scheduler_kwargs":           None
"max_steps":                     2
"n_head":                        4
"num_lr_decays":                 -1
"num_workers_loader":            0
"optimizer":                     None
"optimizer_kwargs":              None
"random_seed":                   1
"scaler_type":                   identity
"start_padding_enabled":         False
"stat_exog_list":                None
"step_size":                     1
"val_check_steps":               100
"valid_batch_size":              None
"valid_loss":                    None
"windows_batch_size":            1024_hparams_initial: "activation":                    gelu
"alias":                         None
"batch_size":                    32
"conv_hidden_size":              32
"decoder_input_size_multiplier": 0.5
"decoder_layers":                1
"drop_last_loader":              False
"dropout":                       0.05
"early_stop_patience_steps":     -1
"encoder_layers":                2
"exclude_insample_y":            False
"futr_exog_list":                None
"h":                             3
"hidden_size":                   128
"hist_exog_list":                None
"inference_windows_batch_size":  1024
"input_size":                    6
"learning_rate":                 0.0001
"loss":                          MAE()
"lr_scheduler":                  None
"lr_scheduler_kwargs":           None
"max_steps":                     2
"n_head":                        4
"num_lr_decays":                 -1
"num_workers_loader":            0
"optimizer":                     None
"optimizer_kwargs":              None
"random_seed":                   1
"scaler_type":                   identity
"start_padding_enabled":         False
"stat_exog_list":                None
"step_size":                     1
"val_check_steps":               100
"valid_batch_size":              None
"valid_loss":                    None
"windows_batch_size":            1024random_seed: 1train_trajectories: []valid_trajectories: []optimizer: Noneoptimizer_kwargs: {}lr_scheduler: Nonelr_scheduler_kwargs: {}futr_exog_list: []hist_exog_list: []stat_exog_list: []futr_exog_size: 0hist_exog_size: 0stat_exog_size: 0trainer_kwargs: {'max_steps': 2, 'enable_checkpointing': False}h: 3input_size: 6windows_batch_size: 1024start_padding_enabled: Falsebatch_size: 32valid_batch_size: 32inference_windows_batch_size: 1024learning_rate: 0.0001max_steps: 2num_lr_decays: -1lr_decay_steps: 100000000.0early_stop_patience_steps: -1val_check_steps: 100step_size: 1exclude_insample_y: Falseval_size: 0test_size: 0decompose_forecast: Falsenum_workers_loader: 0drop_last_loader: Falsevalidation_step_outputs: []alias: Nonelabel_len: 3c_out: 1output_attention: Falseenc_in: 1

举例说明 如何构建windows

import pandas as pd
import numpy as npAirPassengers = np.array([112.0, 118.0, 132.0, 129.0, 121.0, 135.0, 148.0, 148.0, 136.0, 119.0],dtype=np.float32,
)AirPassengersDF = pd.DataFrame({"unique_id": np.ones(len(AirPassengers)),"ds": pd.date_range(start="1949-01-01", periods=len(AirPassengers), freq=pd.offsets.MonthEnd()),"y": AirPassengers,}
)Y_df = AirPassengersDF
Y_df = Y_df.reset_index(drop=True)
Y_df.head()
#Model Trainingfrom neuralforecast.core import NeuralForecast
from neuralforecast.models import NBEATShorizon = 3
models = [NBEATS(input_size=2 * horizon, h=horizon, max_steps=2)]nf = NeuralForecast(models=models, freq='M')
nf.fit(df=Y_df)

window_size 是窗口的总大小,它由 input_size 和 h 决定。
9= input_size(6) +h(3)
可以与原数据集对比下,是一个一个的往下移

当移动到 132.0的时候,为了凑齐9行,剩余的用0填充

窗口的形状就是 windows1 shape: torch.Size([4, 9, 2])

 window1: tensor([[[112.,   1.],[118.,   1.],[132.,   1.],[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.]],[[118.,   1.],[132.,   1.],[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.]],[[132.,   1.],[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.],[  0.,   0.]],[[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.],[  0.,   0.],[  0.,   0.]]])

windows_batch_size

最后由 windows1 shape: torch.Size([4, 9, 2])变成了 indows2 shape: torch.Size([1024, 9, 2])

也就是我们的传参windows_batch_size = 1024

下列举出4个例子,实际是1024个
表示采样了 1024 个窗口,每个窗口大小为9,包含 2 个特征。

....[[118.,  1.],[132.,   1.],[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.]],[[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.],[  0.,   0.],[  0.,   0.]],[[118.,   1.],[132.,   1.],[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.]],[[118.,   1.],[132.,   1.],[129.,   1.],[121.,   1.],[135.,   1.],[148.,   1.],[148.,   1.],[136.,   1.],[119.,   1.]],

最终训练时,返回的数据

windows_batch: {'temporal': 1024 个窗口数据, 'temporal_cols': Index(['y', 'available_mask'], dtype='object'), 'static': None, 'static_cols': None}

相关文章:

NeuralForecast 模型的参数 windows_batch的含义

NeuralForecast 模型的参数 windows_batch的含义 flyfish import pandas as pd import numpy as npAirPassengers np.array([112.0, 118.0, 132.0, 129.0, 121.0, 135.0, 148.0, 148.0, 136.0, 119.0],dtypenp.float32, )AirPassengersDF pd.DataFrame({"unique_id&qu…...

【记录】打印|用浏览器生成证件照打印PDF,打印在任意尺寸的纸上(简单无损!)

以前我打印证件照的时候,我总是在网上找在线证件照转换或者别的什么。但是我今天突然就琢磨了一下,用 PDF 打印应该也可以直接打印出来,然后就琢磨出来了,这么一条路大家可以参考一下。我觉得比在线转换成一张 a4 纸要方便的多&am…...

【python实现】实时监测GPU,空闲时自动执行脚本

文章目录 代码 代码 # author: muzhan # contact: levio.pkugmail.com import os import sys import time cmd nohup python -u train_post_2d_aut.py > output1.log & # gpu空闲时,需要执行的脚本命令 def gpu_info():gpu_status os.popen(nvidia-smi…...

chrome 浏览器历史版本下载

最近做一个项目,要使用到chrome浏览器比较久远的版本,在网上查找资源时,发现chrome比较老的版本的安装包特别难找,几经寻找,总算找到,具体方法如下 打开百度,搜索关键字【chrome版本号‘浏览迷’】,例如“chrome41浏览迷”,找到“全平台”开头的链接&am…...

【设计模式】工厂模式(创建型)⭐⭐⭐

文章目录 1.概念1.1 什么是工厂模式1.2 优点与缺点 2.实现方式2.1 简单工厂模式(Simple Factory)2.2 简单工厂模式缺点2.3 抽象工厂模式(Abstract Factory Pattern) 3 Java 哪些地方用到了工厂模式4 Spring 哪些地方用到了工厂模式…...

Postman 连接数据库 利用node+xmysql

1、准备nodejs环境 如果没有安装,在网上找教程,安装好后,在控制台输入命令查看版本,如下就成功了 2、安装xmysql 在控制台输入 npm install -g xmysql 3、连接目标数据库 帮助如下: 示例: 目标数据库…...

挑战你的数据结构技能:复习题来袭【6】

1. (单选题)设无向图的顶点个数为n,则该图最多有()条边 A. n-1 B. n(n-1)/2 C. n(n1)/2 D. 0 答案:B 分析: 2. (单选题)含有n个顶点的连通无向图,其边的个数至少为()。 A. n-1 B. n C. n1 D. nlog2n 答案:A…...

如何反编译jar并修改后还原为jar

如何反编译jar并修改后还原为jar 目标:修改jar包中某个类的某个方法后还原为新的jar 1.新建android工程,把旧的jar添加为lib 2.用jadx-gui打开旧的jar并保存所有资源 3.找到保存的资源中想修改的.java类 4.复制类中的内容, 在android工程中新建一个同样路径的包,并在包下创建…...

统计信号处理基础 习题解答10-5

题目 通过令 并进行计算来重新推导MMSE估计量。提示:利用结果 解答 首先需要明确的是: 上式是关于观测值x 的函数 其次需要说明一下这个结果 和教材一样,我们用求期望,需要注意的是,在贝叶斯情况下,是个…...

Vue3实战笔记(60)—从零开始:一步步搭建Vue 3自定义插件

文章目录 前言一、自定义插件二、使用步骤总结 前言 在开发和学习中,经常使用一些好用的插件,那么如何创建一个自己的插件呢?在 Vue 3 中,你可以通过创建一个包含 install 方法的对象来定义自定义插件。install 方法接收两个参数…...

Java面向对象笔记

多态 一种类型的变量可以引用多种实际类型的对象 如 package ooplearn;public class Test {public static void main(String[] args) {Animal[] animals new Animal[2];animals[0] new Dog();animals[1] new Cat();for (Animal animal : animals){animal.eat();}} }class …...

如何通过PHP语言实现远程控制多路照明

如何通过PHP语言实现远程控制多路照明呢? 本文描述了使用PHP语言调用HTTP接口,实现控制多路照明,通过多路控制器,可独立远程控制多路照明。 可选用产品:可根据实际场景需求,选择对应的规格 序号设备名称厂…...

Capture One Pro 23:专业 Raw 图像处理的卓越之选

在当今的数字摄影时代,拥有一款强大的图像处理软件至关重要。而 Capture One Pro 23 for Mac/Win 无疑是其中的佼佼者,为摄影师和图像爱好者带来了前所未有的体验。 Capture One Pro 23 以其出色的 Raw 图像处理能力而闻名。它能够精准地解析和处理各种…...

【主题广泛|投稿优惠】2024年交通运输与信息科学国际会议(ICTIS 2024)

2024年交通运输与信息科学国际会议(ICTIS 2024) 2024 International Conference on Transportation and Information Science 【重要信息】 大会地点:青岛 大会官网:http://www.icictis.com 投稿邮箱:icictissub-conf.…...

表格误删数据保存关闭后如何恢复?5个恢复方法大公开!

“我在编辑表格的时候一不小心就删除了部分数据,现在真的不知道该怎么操作了。希望大家能帮帮我吧!” 在日常工作中,我们经常会使用到各种表格软件来处理和分析数据。然而,有时由于操作失误或其他原因,我们可能会误删表…...

Go 语言中的切片:灵活的数据结构

切片(slice)是 Go 语言中一种非常重要且灵活的数据结构,它提供了对数组子序列的动态窗口。这使得切片在 Go 中的使用非常频繁,特别是在处理动态数据集时。本文将探讨切片的概念、操作和与函数的交互,以及如何有效地使用…...

在鲲鹏服务器搭建k8s高可用集群分享

高可用架构 本文采用kubeadm方式搭建k8s高可用集群,k8s高可用集群主要是对apiserver、etcd、controller-manager、scheduler做的高可用;高可用形式只要是为: 1. apiserver利用haproxykeepalived做的负载,多apiserver节点同时工作…...

MySQL之数据库事务机制学习笔记(五)

事务机制 事务(Transaction)是数据库管理系统中的一个重要概念,它是一组数据库操作的逻辑单元,要么全部执行成功,要么全部执行失败,具有以下四个特性,通常缩写为 ACID: 原子性&…...

linux 系统被异地登录,cpu占用拉满100%

一般是kswapd0导致的cpu占用异常 按顺序执行以下操作 在控制台执行top命令,查看占用最高的是否kswapd0。基本100%占用。记下该进程ID 5081 执行查找命令 find / -name kswapd0 显示查找结果: /proc/3316/.X2c4-unix/.rsync/a/kswapd0 /root/.configrc…...

智慧校园应用平台的全面建设

在当今社会,随着科技的不断进步,智慧校园应用平台逐渐成为学校管理的必备工具。在实现智慧校园全面建设的过程中,学校需要运用先进的技术和创新的理念,为教育提供更好的服务和支持。这篇文章将为您介绍智慧校园应用平台的全面建设…...

图论第6天

提高效率!!!两道题看并查集 841.钥匙和房间 忘了把visited 加引用了&#xff1a;& class Solution { public:bool canVisitAllRooms(vector<vector<int>>& rooms) {vector<int>visited(rooms.size(),false);dfs(rooms,visited,0);for(int i 0;i …...

Redis教程(二十一):Redis怎么保证缓存一致性

传送门:Redis教程汇总篇,让你从入门到精通 Redis 的缓存一致性 Redis 的缓存一致性是指在使用 Redis 作为缓存层时,保证缓存中的数据与数据库中的数据保持一致的状态。在分布式系统中,数据一致性是一个重要的问题,因为可能存在多个客户端同时读写同一数据,或者数据在不同…...

android apk签名

android apk签名 命令&#xff1a; java -jar signapk.jar platform.x509.pem platform.pk8 **.apk ***.apk note&#xff1a; apk密钥为&#xff1a; platform.pk8和platform.x509.pem 路径&#xff1a; build\target\product\security apk签名工具&#xff1a;sign…...

flutter 解析json另类封装方式 List<bean>,哈哈哈

flutter 解析json另类封装方式&#xff0c;哈哈哈 日常学习&#xff0c;仅供参考&#xff0c;不喜 勿喷 http请求数据泛型解析封装&#xff0c;需要判断泛型数据类型再根据类型解析&#xff0c;本文只抽取了list演示 核心代码 import dart:convert;import package:webwsyn/h…...

哈希表(Hash table)

哈希表(Hash table),也称为散列表,是一种根据关键码值(Key value)直接进行访问的数据结构。它通过散列函数(Hash function)将关键码值映射到表中的一个位置,以此来访问记录,从而加快查找的速度。以下是关于哈希表的详细解释: 基本概念 散列函数:将关键码值映射到表…...

【c语言】自定义类型-结构体

结构体 结构体的声明与使用结构体的声明与初始化结构体的自引用 结构体的内存对齐对齐规则为什么存在内存对齐修改默认对齐数 结构体的传参结构体实现位段什么是位段位段的内存分配位段的跨平台问题位段使用的注意事项 结构体&#xff1a;是一个自定义的类型&#xff0c;成员可…...

2-链表-71-环形链表 II-LeetCode142

2-链表-71-环形链表 II-LeetCode142 参考&#xff1a;代码随想录 LeetCode: 题目序号142 更多内容欢迎关注我&#xff08;持续更新中&#xff0c;欢迎Star✨&#xff09; Github&#xff1a;CodeZeng1998/Java-Developer-Work-Note 技术公众号&#xff1a;CodeZeng1998&#…...

【UnityShader入门精要学习笔记】第十七章 表面着色器

本系列为作者学习UnityShader入门精要而作的笔记&#xff0c;内容将包括&#xff1a; 书本中句子照抄 个人批注项目源码一堆新手会犯的错误潜在的太监断更&#xff0c;有始无终 我的GitHub仓库 总之适用于同样开始学习Shader的同学们进行有取舍的参考。 文章目录 表面着色器…...

Python社会经济 | 怀特的异方差一致估计量

&#x1f3af;要点 &#x1f3af;算法​和模型底层数学及代码&#xff1a;&#x1f58a;线性代数应用&#xff08;主成分分析&#xff09;&#xff1a;降维、投影&#xff08;用于求解线性系统&#xff09;和二次形式&#xff08;用于优化&#xff09;| &#x1f58a;奇值分解…...

《被讨厌的勇气》笔记

自由就是被别人讨厌。对人而言&#xff0c;最大的不幸就是不喜欢自己。活在“如果怎样怎样”之类的假设之中&#xff0c;就根本无法改变。活在害怕关系破裂的恐惧之中&#xff0c;那是为他人而活的一种不自由的生活方式。人生是连续刹那&#xff0c;我们只能活在“此时此刻”。…...

网站开发部门工资入什么科目/seo关键词排名优化哪好

目标&#xff1a;在同一个类中&#xff0c;多个测试函数时候&#xff0c;测试固件如何写。首先&#xff0c;我们先看一下如果存在两个测试函数的时候&#xff0c;程序是怎么执行的test1.pyimport timeimport unittestfrom framework.browser_engine import BrowserEnginefrom p…...

毕业设计可以做自己的网站吗/灰色词排名上首页

翻阅古今 读写文件是最常见的IO操作。Python内置了读写文件的函数&#xff0c;用法和C是兼容的。读写文件前&#xff0c;我们先必须了解一下&#xff0c;在磁盘上读写文件的功能都是由操作系统提供的&#xff0c;现代操作系统不允许普通的程序直接操作磁盘&#xff0c;所以&…...

flickr wordpress/社群营销的具体方法

IP Messenger是一款局域网内部聊天、文件传输工具&#xff0c;具有很多优点&#xff0c;如数据通讯不需要建立服务器、直接在两台电脑间通信和数据传输&#xff0c;支持文件及文件目录的传输&#xff0c;安全快捷以及小巧方便等优异特点&#xff0c;因此很多公司都采用它作为部…...

微网站建站平台/沈阳沈河seo网站排名优化

本系列参考文献为光学时间拉伸成像原理及应用。 这篇文章先开个头&#xff1a; 光学时间拉伸成像是一种新兴的超快光学成像方法。它克服了传统成像方法中存在的限制&#xff0c;能够实现超高帧的连续图像采集。光学时间拉伸成像可以与放大、非线性处理、压缩感知、图形相关等多…...

如何做音乐分享类网站/百度推广客户端app

测试前先启动hadoop [hadoopmini-yum ~]$ start-dfs.sh [hadoopmini-yum ~]$ start-yarn.sh 1在一堆给定的文本文件中统计输出每一个单词出现的总次数 代码 package cn.feizhou.wcdemo;import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.fs.Path; impo…...

如何根据网址攻击网站/普通话手抄报文字内容

去商场淘打折商品时&#xff0c;计算打折以后的价钱是件颇费脑子的事情。例如原价 &#xffe5;988&#xff0c;标明打 7 折&#xff0c;则折扣价应该是 &#xffe5;988 x 70% &#xffe5;691.60。本题就请你写个程序替客户计算折扣价。输入格式&#xff1a;输入在一行中给出…...