当前位置: 首页 > news >正文

从0开始基于transformer进行股价预测(pytorch版本)

目录

  • 数据阶段
    • 两个问题
    • 开始利用我们的代码进行切分
  • backbone网络
  • 训练
  • 效果 感觉还行,没有调参数。
  • 源码比较长,如果需要我后续会发(因为太长了!!)

数据阶段

!!!注意!!! , 本文不会讲原理,因为之前两篇文章已经讲过了,只会解释一些结构性问题,和思路问题。

所谓工欲善其事,必先利其器做量化分析的股价预测,完美必须要先把数据处理好。
那么本人的数据下载是在聚宽平台股票代码为601398的数据2014-3 到 2024-3年的默认数据。如何下载可以按照我的方式

在这里插入图片描述
进入研究环境后随便创建一个ipynb文件进行数据下载 ,运行以下代码就行

# 1.获取数据
data = get_price('601398.XSHG', start_date='2014-01-01', end_date='2024-01-01', frequency='daily', fields=None, skip_paused=False, fq='pre', panel=True)
# 2.保存数据
data.to_csv('data_沪深300/601398.XSHG(工商银行14-24).csv')

两个问题

1.为什么我们只需要用encoder部分去预测就行而不需要decoder部分?
答: 编码器用于将输入序列编码成一个上下文表示(contextual representation),然后解码器根据该上下文表示生成目标序列在时间序列预测任务中,我们不需要生成一个序列,而是预测单个或少量几个未来数据点。因此,编码器的上下文表示已经包含了足够的信息来进行预测,无需使用解码器。还有我觉得使用解码器的意思是,你用上一天的数据去预测下一天的数据,我感觉这样就没意思了,这和我们个人看有什么区别。而且对最后的结果也会造成不精准的效果。为什么这么说呢,你看解码器的mask编码部分应该可以理解了。
2.我们的维度为什么不是[batch, len, feature]? 因为这是pytorch要求,自己能实现的话,自己改吧。

开始利用我们的代码进行切分

我的思路用的是用五天的数据去预测下一天,数据集和测试及8/2分
但是我们要记住一点,就是我们必须要理解我们这么做的思路,就比如我们的特征有6列分别是,open,close,high,low,volume,money,我们可以通过训练得到我们想预测的某一特征。OK,我们这就开始。

说起数据分割里面的代码不难,最难的是
for i in range(len(X_CONVERT) - seq_length):
X_data.append(X_CONVERT[i:i+seq_length, :])
y_data.append(X_CONVERT[i+seq_length, 1])
你要知道我在干什么,就是用8成的数据集去预测得到我们所需要的train数据集和我们对应train数据集的label,举个例子就是,我们要炒菜,我们拿上原料后我们要知道炒的什么菜,那么菜单必须要知道。是吧,不然你炒完菜后说是红烧肉,但是没有菜单图片对比你怎么知道这是红烧肉?这也就是这一步的意义。

def split_data(batch_size,seq_length, pred_length, train_ratio):data_all = pd.read_csv(data_path)data_ha = []length = len(data_all)# 将数据转换为numpy数组,并添加到列表中for element in elements:data_element = data_all[element].values.astype(np.float32)data_element = data_element.reshape(length, 1)data_ha.append(data_element)X_hat = np.concatenate(data_ha, axis=1)X_CONVERT = torch.from_numpy(X_hat).float()X_CONVERT = X_CONVERT.flip(dims=[0])# 进行归一化min_val = np.min(X_hat, axis=0)max_val = np.max(X_hat, axis=0)X_normalized = (X_hat - min_val) / (max_val - min_val)X_CONVERT = torch.from_numpy(X_normalized).float()X_CONVERT = X_CONVERT.flip(dims=[0])#数据翻转# 划分训练集和验证集X_data = []y_data = []for i in range(len(X_CONVERT) - seq_length):#划分的时候是用8成的训练集去训练然后label是某##一列X_data.append(X_CONVERT[i:i+seq_length, :])y_data.append(X_CONVERT[i+seq_length, 1])X_data = torch.stack(X_data)y_data = torch.stack(y_data).squeeze(-1)print(X_data.shape, y_data.shape)dataset = TensorDataset(X_data, y_data)train_size = int(len(dataset) * train_ratio)val_size = len(dataset) - train_sizetrain_dataset, val_dataset = random_split(dataset, [train_size, val_size])train_loader = DataLoader(train_dataset, batch_size, shuffle=False)val_loader = DataLoader(val_dataset, batch_size, shuffle=False)return train_loader, val_loader,min_val, max_val

backbone网络

如其名,我们都知道这是这是transformer当然是用的transformer的结构。但是我们用,但是只用一部分,具体用什么部分开头说了,只用encoder

**但是具体操作起来的时候encoder里面的embadding部分我们需要修改,因为我们不是机器翻译,所以我们不需要把他变成词向量,我们时间序列数据,输入通常是连续的数值特征,使用线性层更直接地将这些数值特征映射到高维空间。并且我们的embadding嵌入层,适用于离散的输入,输出是固定维度的嵌入向量。而线性层,适用于连续的输入,可以灵活处理不同维度的输入特征,将其映射到高维表示。**具体看下面代码

class Encoder(nn.Module):def __init__(self):super(Encoder, self).__init__()self.src_emb = nn.Linear(feature, d_model)#这里替换了self.pos_emb = PositionalEncoding(d_model)self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])def forward(self, enc_inputs):enc_outputs = self.src_emb(enc_inputs)  # [batch_size, src_len, d_model]enc_outputs = self.pos_emb(enc_outputs)  # [batch_size, src_len, d_model]enc_self_attn_mask = get_attn_pad_mask(enc_inputs, enc_inputs)  # [batch_size, src_len, src_len]enc_self_attns = []for layer in self.layers:enc_outputs, enc_self_attn = layer(enc_outputs, enc_self_attn_mask)enc_self_attns.append(enc_self_attn)return enc_outputs, enc_self_attns

当然其他的部分和我上一篇的一样,但是就是decode不要了,当然也可以换成其他结果,或者加个注意力机制

讲下各个参数


d_model = 512   # linnerer的输入维度 也就是字embedding的维度
d_ff = 2048     # 前向传播隐藏层维度
d_k = d_v = 64  # K(=Q), V的维度
n_layers = 6    # 有多少个encoder和decoder
n_heads = 8     # Multi-Head Attention设置为8
feature=6       # 输入特征维度

当然主体还是要看一下的最重要的是通过encoder后的维度转换比较繁琐,要和我们之前split的数据集得到的y_train一致这样才能计算损失


class Transformer(nn.Module):def __init__(self):super(Transformer, self).__init__()self.Encoder = Encoder()self.projection = nn.Linear(d_model, 1, bias=False)def forward(self, enc_inputs):  # enc_inputs: [batch_size, src_len, feature]enc_outputs, enc_self_attns = self.Encoder(enc_inputs)  # enc_outputs: [batch_size, src_len, d_model]dec_logits = self.projection(enc_outputs)  # dec_logits: [batch_size, src_len, 1]dec_logits = dec_logits.mean(dim=1)  # 将每个时间步的预测结果取平均,得到 [batch_size, 1]return dec_logits.squeeze(-1), enc_self_attns  # 输出 [batch_size]

训练

先解释参数

batch_size=64#批处理大小
seq_length=7#时间序列长度 也就是通过seq_length天预测后面pred_length天
pred_length=1#预测长度
train_ratio=0.8#训练集比例
epochs = 50 # 训练轮数
lr= 0.001 # 学习率
png_save_path="diytransformers/12.24transformer/picture"#所有的图片保存的地方
loss_history = []# 存储每个 epoch 的损失

训练代码很长,挺简单的


# 训练模型
for epoch in range(epochs):epoch_loss = 0y_pre = []y_true = []# 训练阶段for X, y in train_loader:X = X.float()  # 确保输入数据类型为float32y = y.float()  # 确保目标数据类型为float32outputs, enc_self_attns = model(X)# 计算损失,确保形状一致loss = criterion(outputs, y)epoch_loss += loss.item()optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()#转换我们的label和训练后得到的训练集的预测值 y_pre.append(outputs.detach())y_true.append(y.detach())avg_loss = epoch_loss / len(train_loader)loss_history.append(avg_loss)#获得最好的lossif avg_loss < best_loss:best_loss = avg_lossbest_epoch = epochbest_model_wts = copy.deepcopy(model.state_dict())torch.save(best_model_wts, path_train)y_pre_concat = torch.cat(y_pre, dim=0)y_true_concat = torch.cat(y_true, dim=0)# 计算并打印评估指标metrics = evaluate(y_pre_concat, y_true_concat, min_val, max_val)print(f'Epoch {epoch + 1}, Loss: {avg_loss:.6f}')# 可视化结果ht(y_true_concat.detach().cpu().numpy(), y_pre_concat.detach().cpu().numpy(), min_val, max_val,png_save_path)

最后是看我们的一些指标效果如何 比如这里我计算的mae,rmse,pcc等

# 加载最佳模型权重
model.load_state_dict(torch.load(train_over_path))# 测试模型并计算评估指标
test_metrics = test_model(model, val_loader, min_val, max_val)print(f'Test Metrics: {test_metrics}')

效果 感觉还行,没有调参数。

在这里插入图片描述

源码比较长,如果需要我后续会发(因为太长了!!)

相关文章:

从0开始基于transformer进行股价预测(pytorch版本)

目录 数据阶段两个问题开始利用我们的代码进行切分 backbone网络训练效果 感觉还行&#xff0c;没有调参数。源码比较长&#xff0c;如果需要我后续会发&#xff08;因为太长了&#xff01;&#xff01;&#xff09; 数据阶段 &#xff01;&#xff01;&#xff01;注意&#…...

【多GPU训练方法】

一、数据并行 这是最常用的方法。整个模型复制到每个GPU上。训练数据被均匀分割&#xff0c;每个GPU处理一部分数据。所有GPU上的梯度被收集并求平均。通常使用NCCL&#xff08;NVIDIA Collective Communications Library&#xff09;等通信库实现。参数更新 使用同步后的梯度…...

2024年PMP考试备考经验分享

PMP是项目管理领域最重要的认证之一,本身是IT行业比较流行的证书&#xff0c;近几年在临床试验领域也渐渐流行起来&#xff0c;是我周围临床项PM几乎人手一个的证书。 考试时间&#xff1a;PMP认证考试形式为180道选择题&#xff0c;考试时间为3小时50分。 考试计划&#xff…...

MT3046 愤怒的象棚

思路&#xff1a; a[]存愤怒值&#xff1b;b[i]存以i结尾的&#xff0c;窗口里的最大值&#xff1b;c[i]存以i结尾的&#xff0c;窗口里面包含✳的最大值。 &#xff08;✳为新大象的位置&#xff09; 例&#xff1a;1 2 3 4 ✳ 5 6 7 8 9 则ans的计算公式b3b4c4c5c6b7b8b9…...

深入了解代理IP常见协议:区别与选择

代理服务器在网络使用中扮演着重要的角色&#xff0c;是您设备和互联网之间的中间层。它不仅可以增强网络访问的安全性和隐私保护&#xff0c;还可以提供许多灵活的应用。使用代理时&#xff0c;不同的协议类型对数据交换具有不同的规则和特征。常见的代理协议包括HTTP代理、HT…...

【Linux 线程】线程的基本概念、LWP的理解

文章目录 一、ps -L 指令&#x1f34e;二、线程控制 一、ps -L 指令&#x1f34e; &#x1f427; 使用 ps -L 命令查看轻量级进程信息&#xff1b;&#x1f427; pthread_self() 用于获取用户态线程的 tid&#xff0c;而并非轻量级进程ID&#xff1b;&#x1f427; getpid() 用…...

Dify中的工具

Dify中的工具分为内置工具&#xff08;硬编码&#xff09;和第三方工具&#xff08;OpenAPI Swagger/ChatGPT Plugin&#xff09;。工具可被Workflow&#xff08;工作流&#xff09;和Agent使用&#xff0c;当然Workflow也可被发布为工具&#xff0c;这样Workflow&#xff08;工…...

在Visutal Studio 2022中完成D3D12初始化

在Visutal Studio 2022中完成DirectX设备初始化 1 DirectX121.1 DirectX 简介1.2 DirectX SDK安装2 D3D12初始化2.1 创建Windwos桌面项目2.2 修改符合模式2.3 下载d3dx12.h文件2.4 创建一个异常类D3DException,定义抛出异常实例的宏ThrowIfFailed3 D3D12的初始化步骤3.1 初始化…...

MobaXterm工具

MobaXterm 是一个增强型的 Windows 终端。其为 Windows 桌面提供所有重要的远程网络终端工具&#xff08;如 SSH、X11、RDP、VNC、FTP、SFTP、Telnet、Serial、Mosh、WSL 等&#xff09;&#xff0c;和 Unix 命令&#xff08;如 bash、ls、cat、sed、grep、awk、rsync 等&#…...

二分图练习

对于二分图我们可以用染色法 #include<bits/stdc.h> using namespace std;#define int long long const int N 2e65; int e[N],ne[N],h[N],idx 0; int colo[N]; int num 0;void add(int x,int y){e[idx] y;ne[idx] h[x];h[x] idx; } void dfs(int nod,int c){colo…...

创新设计策略:提升大屏幕可视化设计效果的关键方法

随着科技的不断发展和数据量的快速增长&#xff0c;数据可视化大屏在各个行业中的应用越来越广泛&#xff0c;可以帮助人们更好地理解和分析数据&#xff0c;可视化大屏设计也因此成了众多企业的需求。但很多设计师对可视化大屏设计并不了解&#xff0c;也不知道如何制作可视化…...

论文 | Chain-of-Thought Prompting Elicits Reasoningin Large Language Models 思维链

这篇论文研究了如何通过生成一系列中间推理步骤&#xff08;即思维链&#xff09;来显著提高大型语言模型进行复杂推理的能力。论文展示了一种简单的方法&#xff0c;称为思维链提示&#xff0c;通过在提示中提供几个思维链示例来自然地激发这种推理能力。 主要发现&#xff1…...

[机器学习]-人工智能对程序员的深远影响——案例分析

机器学习和人工智能对未来程序员的深远影响 目录 机器学习和人工智能对未来程序员的深远影响1. **自动化编码任务**1.1 代码生成1.2 自动调试1.3 测试自动化 2. **提升开发效率**2.1 智能建议2.2 项目管理 3. **改变编程范式**3.1 数据驱动开发 4. **职业发展的新机遇**4.1 AI工…...

AI学习环境 没有更好的替代 - (Google)Drive + Colab

在开始正题前&#xff0c;请容许我做一番回顾&#xff0c;并夹带一点点私货&#xff08;谷歌扛旗的开源精神还没有死&#xff0c;并且会是未来的举足轻重的力量&#xff09; 卧龙凤雏&#xff0c;一时瑜亮。一切的缘起应该是世纪初的门户网站乱战。 彼时&#xff0c;谷歌是从…...

【观成科技】Websocket协议代理隧道加密流量分析与检测

Websocket协议代理隧道加密流量简介 攻防场景下&#xff0c;Websocket协议常被用于代理隧道的搭建&#xff0c;攻击者企图通过Websocket协议来绕过网络限制&#xff0c;搭建一个低延迟、双向实时数据传输的隧道。当前&#xff0c;主流的支持Websocket通信代理的工具有&#xf…...

DangerWind-RPC-framework---三、服务端下机

当一台机器下线时&#xff0c;面临很多问题&#xff1a;如何将其从注册中心下线&#xff1f;如何清理释放资源&#xff1f;客户端拉取服务列表时也使用了本地缓存&#xff0c;如何及时更新本地缓存&#xff1f; 服务端机器的优雅下线需要使用ShutdownHook&#xff0c;这相当于添…...

基于Make的c工程No compilation commands found报错

由于安装gcc时只安装了build-essential&#xff0c;没有将其添加到环境变量中&#xff0c;因此打开Make工程时&#xff0c;CLion会产生如下错误&#xff1a; 要解决这个问题&#xff0c;一个方法是将GCC添加到环境变量中&#xff0c;但是这个方法需要修改至少两个配置文件&…...

c++:面向对象的继承特性

什么是继承 (1)继承是C源生支持的一种语法特性&#xff0c;是C面向对象的一种表现 (2)继承特性可以让派生类“瞬间”拥有基类的所有&#xff08;当然还得考虑权限&#xff09;属性和方法 (3)继承特性本质上是为了代码复用 (4)类在C编译器的内部可以理解为结构体&#xff0c;派…...

skywalking-2-客户端-php的安装与使用

skywalking的客户端支持php&#xff0c;真的很棒。 官方安装文档&#xff1a;https://skywalking.apache.org/docs/skywalking-php/next/en/setup/service-agent/php-agent/readme/ 前置准备 本次使用的php版本是8.2.13: php -v PHP 8.2.13 (cli) (built: Nov 21 2023 09:5…...

图文讲解IDEA如何导入JDBC驱动包

前言 学习JDBC编程,势必要学会如何导入驱动包,这里笔者用图文的方式来介绍 视频版本在这里 50秒教你怎么导入驱动包然后进行JDBC编程的学习_哔哩哔哩_bilibili 忘记录音频了,大伙凑合着看 下载驱动包 https://mvnrepository.com/artifact/mysql/mysql-connector-java 去中…...

java.lang.NullPointerException: null cannot be cast to non-null type kotlin.Int

java.lang.NullPointerException: null cannot be cast to non-null type kotlin.Int fun main(args: Array<String>) {var any1: Any?any1 nullval n1 any1 as? Int ?: -2024println(n1)kotlin.runCatching {var any2: Any?any2 nullval n2 any2 as Intprintln(…...

scrapy写爬虫

Scrapy是一个用于爬取网站数据并提取结构化信息的Python框架 一、Scrapy介绍 1.引擎&#xff08;Engine&#xff09; – Scrapy的引擎是控制数据流和触发事件的核心。它管理着Spider发送的请求和接收的响应&#xff0c;以及处理Spider生成的Item。引擎是Scrapy运行的驱动力。…...

Mybatis study

一、Mybatis Plus mybatis-plus指定实体类字段不查询 加标签 TableField(exist false) Spring Data Jpa学习 干我们这行&#xff0c;啥时候懈怠&#xff0c;就意味着长进的停止&#xff0c;长进的停止就意味着被淘汰&#xff0c;只能往前冲&#xff0c;直到凤凰涅槃的一天&am…...

【论文速读】《面向深度学习的联合消息传递与自编码器》

这篇文章来自华为的渥太华无线先进系统能力中心和无线技术实验室&#xff0c;作者中有大名鼎鼎的童文。 一、自编码架构的全局收发机面临的主要问题 文章对我比较有启发的地方&#xff0c;是提到自编码架构的全局收发机面临的主要问题&#xff1a; 问题一&#xff1a;基于随…...

防御---001

一、实验拓扑二、要求 1&#xff0c;DMZ区内的服务器&#xff0c;办公区仅能在办公时间内(9:00 - 18:00)可以访问&#xff0c;生产区的的设备全天可以访问. 2&#xff0c;生产区不允许访问互联网&#xff0c;办公区和游客区允许访问互联网 3,办公区设备10.0.2.10不允许访问DMZ…...

DNS 杂谈

一、定义 DNS&#xff08;Domain Name System&#xff09;&#xff0c;域名系统&#xff0c;该系统记录域名和Ip地址的相互映射关系。用户访问互联网时&#xff0c;通过域名地址得到对应的IP地址&#xff0c;这个过程称为域名解析。DNS运行于UDP协议之上&#xff0c;使用的端口…...

docker笔记2

docker笔记2 一、阿里云镜像配置二、docker基本原理1.docker是如何启动一个容器的2.docker的底层原理 三、镜像命令总结 一、阿里云镜像配置 配置镜像的目的 由于Docker Hub等公共镜像仓库的服务器可能位于国外&#xff0c;直接从中拉取镜像时可能会遇到网络延迟或不稳定的问…...

数字统计

import java.util.Scanner;// 注意类名必须为 Main, 不要有任何 package xxx 信息 public class Main {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别// 注意 while 处理多个 caseint a in.nextInt();i…...

Git 使用问题

Git 使用问题 1, 网络问题 1, 网络问题 # 报错如下&#xff1a; fatal: unable to access https://github.com/xianglingliwei/HRNet.git/: Failed to connect to github.com port 443 after 21044 ms: Couldnt connect to server在不能正常访问Github的区域&#xff0c;需要设…...

JMH325【剑侠情缘3】第2版80级橙武网游单机更稳定亲测视频安装教学更新整合收集各类修改教学补丁兴趣可以慢慢探索

资源介绍&#xff1a; 是否需要虚拟机&#xff1a;是 文件大小&#xff1a;压缩包约14G 支持系统&#xff1a;win10、win11 硬件需求&#xff1a;运行内存8G 4核及以上CPU独立显卡 下载方式&#xff1a;百度网盘 任务修复&#xff1a; 1&#xff0c;掌门任务&#xff08…...

普陀网站建设/线上营销活动主要有哪些

因为嘉伟思杯里的一个脚本题目&#xff0c;16进制计算&#xff0c;python3正则还没学&#xff0c;所以没写出来。大佬跟我说也可以用BS4&#xff0c;从DOM上下手,直接爬下来直接一个eval就搞定了&#xff0c;eval可以像这样计算16进制,eval(0x2b0x37)。BUGKU已经写了很多了&…...

做网站维护师傅带要学多久/链接提交入口

php连接访问Oracle是用过oci函数&#xff0c;以下是整理的文档1.安装Apache和php包 yum install -y httpd php* 2.下载Oracle组件oracle-instantclient-basic-10.2.0.4-1.i386.rpmoracle-instantclient-sqlplus-10.2.0.4-1.i386.rpmoracle-instantclient-devel-10.2.0.4-1.i38…...

wordpress文章meta/自己怎么做引流推广

一个朋友是前阿里人&#xff0c;37岁&#xff0c;离职后就职美团。以前投一个面一个&#xff0c;今年想跳槽&#xff0c;但没想到投十个能有两个面试机会就不错了&#xff0c;最后索性又回了阿里做架构。 他在面试的时候&#xff0c;碰见比自己大的面试官&#xff0c;态度和善&…...

呼伦贝尔建设工程检测网站/百度关键词是怎么排名靠前

时间限制&#xff1a;C/C 1秒&#xff0c;其他语言2秒 空间限制&#xff1a;C/C 256M&#xff0c;其他语言512M 热度指数&#xff1a;25218 本题知识点&#xff1a; 链表 题目描述 输入一个链表&#xff0c;输出该链表中倒数第k个结点。 示例1 输入 {1,2,3,4,5},1 返回值 {5} …...

免费flash网站模板/武汉关键词排名工具

Requests 是一个 Python 的 HTTP 客户端库。Request支持HTTP连接保持和连接池&#xff0c;支持使用cookie保持会话&#xff0c;支持文件上传&#xff0c;支持自动响应内容的编码&#xff0c;支持国际化的URL和POST数据自动编码。在python内置模块的基础上进行了高度的封装&…...

2015年做那些网站能致富/电商网站平台搭建

课程介绍 IDEA是一款功能强悍、非常好用的Java开发工具&#xff0c;近几年编程开发人员对IDEA情有独钟。虽然IDEA功能很强大&#xff0c;但目前市面讲解的不细致、不系统&#xff0c;导致很多IDEA初学者要么无从下手&#xff0c;要么耗费太多时间。本套课程分知识点进行录制讲解…...