当前位置: 首页 > news >正文

PyTorch -- RNN 快速实践

  • RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first)

    • input_size: 输入的编码维度
    • hidden_size: 隐含层的维数
    • num_layers: 隐含层的层数
    • batch_first: ·True 指定输入的参数顺序为:
      • x:[batch, seq_len, input_size]
      • h0:[batch, num_layers, hidden_size]
  • RNN 的输入

    • x:[seq_len, batch, input_size]
      • seq_len: 输入的序列长度
      • batch: batch size 批大小
    • h0:[num_layers, batch, hidden_size]
  • RNN 的输出

    • y: [seq_len, batch, hidden_size]

在这里插入图片描述


  • 实战之预测 正弦曲线:以下会以此为例,演示 RNN 预测任务的部署
    在这里插入图片描述
    • 步骤一:确定 RNN Layer 相关参数值并基于此创建 Net

      import numpy as np
      from matplotlib import pyplot as pltimport torch
      import torch.nn as nn
      import torch.optim as optimseq_len     = 50
      batch       = 1
      num_time_steps = seq_leninput_size  = 1
      output_size = input_size
      hidden_size = 10  	
      num_layers = 1  	
      batch_first = True class Net(nn.Module):  ## model 定义def __init__(self):super(Net, self).__init__()self.rnn = nn.RNN(input_size=input_size,hidden_size=hidden_size,num_layers=num_layers,batch_first=batch_first)# for p in self.rnn.parameters():# 	nn.init.normal_(p, mean=0.0, std=0.001)self.linear = nn.Linear(hidden_size, output_size)def forward(self, x, hidden_prev):out, hidden_prev = self.rnn(x, hidden_prev)# out: [batch, seq_len, hidden_size]out = out.view(-1, hidden_size)  # [batch*seq_len, hidden_size]out = self.linear(out) 			 # [batch*seq_len, output_size]out = out.unsqueeze(dim=0)    # [1, batch*seq_len, output_size]return out, hidden_prev
      
    • 步骤二:确定 训练流程

      lr=0.01def tarin_RNN():model = Net()print('model:\n',model)criterion = nn.MSELoss()optimizer = optim.Adam(model.parameters(), lr)hidden_prev = torch.zeros(num_layers, batch, hidden_size)  #初始化hl = []for iter in range(100):  # 训练100次start = np.random.randint(10, size=1)[0]  ## 序列起点time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点output, hidden_prev = model(x, hidden_prev)hidden_prev = hidden_prev.detach()  ## 最后一层隐藏层的状态要 detachloss = criterion(output, y)model.zero_grad()loss.backward()optimizer.step()if iter % 100 == 0:print("Iteration: {} loss {}".format(iter, loss.item()))l.append(loss.item())#############################绘制损失函数#################################plt.plot(l,'r')plt.xlabel('训练次数')plt.ylabel('loss')plt.title('RNN LOSS')plt.savefig('RNN_LOSS.png')return hidden_prev,modelhidden_prev,model = tarin_RNN()
      
    • 步骤三:测试训练结果

      start = np.random.randint(3, size=1)[0]  ## 序列起点
      time_steps = np.linspace(start, start+10, num_time_steps)  ## 序列
      data = np.sin(time_steps).reshape(num_time_steps, 1)  ## 序列数据
      x = torch.tensor(data[:-1]).float().view(batch, seq_len-1, input_size)
      y = torch.tensor(data[1: ]).float().view(batch, seq_len-1, input_size)  # 目标为预测一个新的点    predictions = []  ## 预测结果
      input = x[:,0,:]
      for _ in range(x.shape[1]):input = input.view(1, 1, 1)pred, hidden_prev = model(input, hidden_prev)input = pred  ## 循环获得每个input点输入网络predictions.append(pred.detach().numpy()[0])
      x= x.data.numpy()
      y = y.data.numpy( )
      plt.scatter(time_steps[:-1], x.squeeze(), s=90)
      plt.plot(time_steps[:-1], x.squeeze())
      plt.scatter(time_steps[1:],predictions)  ## 黄色为预测
      plt.show()
      

      在这里插入图片描述


【高阶】上述例子比较简单,便于入门以推理到自己的目标任务,实际 RNN (长时间序列) 训练可能更有难度,可以添加

  • 对于梯度爆炸的解决:
    for p in model.parameters()"print(p.grad.norm())  ## 查阅梯度,看看是否爆炸torch.nn.utils.clip_grad_norm_(p, 10)  ## grad 限幅,其中的 norm 后面的_ 表示 in place 操作
    
  • 对于梯度消失的解决:-> LSTM

  • 另一个很好的实例关于飞行轨迹预测- - RNN-博客链接,可供学习参考
  • B站视频参考资料

相关文章:

PyTorch -- RNN 快速实践

RNN Layer torch.nn.RNN(input_size,hidden_size,num_layers,batch_first) input_size: 输入的编码维度hidden_size: 隐含层的维数num_layers: 隐含层的层数batch_first: True 指定输入的参数顺序为: x:[batch, seq_len, input_size]h0:[batc…...

SpringBoot 快速入门(保姆级详细教程)

目录 一、Springboot简介 二、SpringBoot 优点: 三、快速入门 1、新建工程 方式2:使用Spring Initializr创建项目 写在前面: SpringBoot 是 Spring家族中的一个全新框架,用来简化spring程序的创建和开发过程。SpringBoot化繁…...

【第18章】Vue实战篇之登录界面

文章目录 前言一、数据绑定1. 数据绑定2. 数据清空 二、表单校验1. 代码2. 展示 三、登录1.登录按钮2.user.js3. login 四、展示总结 前言 上一章完成用户注册&#xff0c;这一章主要做用户登录。 一、数据绑定 登录和注册使用相同的数据绑定 1. 数据绑定 <!-- 登录表单 -…...

[C++]使用C++部署yolov10目标检测的tensorrt模型支持图片视频推理windows测试通过

【测试通过环境】 vs2019 cmake3.24.3 cuda11.7.1cudnn8.8.0 tensorrt8.6.1.6 opencv4.8.0 【部署步骤】 获取pt模型&#xff1a;https://github.com/THU-MIG/yolov10训练自己的模型或者直接使用yolov10官方预训练模型 下载源码&#xff1a;https://github.com/laugh12321/yol…...

分享uniapp + Springboot3+vue3小程序项目实战

分享uniapp Springboot3vue3小程序项目实战 经过10天敲代码&#xff0c;终于从零到项目测试完成&#xff0c;一个前后端分离的小程序实战项目学习完毕 时间从6月12日 到6月22日&#xff0c;具有程序开发基础&#xff0c;第一次写uniapp,Springboot以前用过&#xff0c;VUE3也…...

Ubuntu 24.04安装zabbix7.0.0图形中文乱码

当zabbix安装完成后&#xff0c;设置中文界面时&#xff0c;打开图形&#xff0c;中文内容会显示方框乱码&#xff0c;是因为服务器字体中没有相关的中文字体&#xff0c;需要更换。 1、找到中文字体&#xff0c;可以在网络上下载《得意黑》开源字体&#xff0c;也可以在windo…...

MybatisPlus 调用 原生SQL

方式一 DemoMapper.java Mapper public interface DemoMapper extends BaseMapper<TableConfig> {Update("${sql}")int createTable(Param("sql") String sql); }测试代码 SpringBootTest class DemoMapperTest {Resourceprivate DemoMapper demo…...

1.SG90

目录 一.实物图 二.原理图 三.简介 四.工作原理 一.实物图 二.原理图 三.简介 舵机&#xff08;英文叫Servo&#xff09;&#xff0c;是伺服电机的一种&#xff0c;伺服电机就是带有反馈环节的电机&#xff0c;这种电机可以进行精确的位置控制或者输出较高的扭矩。舵机…...

【yolov8语义分割】跑通:下载yolov8+预测图片+预测视频

1、下载yolov8到autodl上 git clone https://github.com/ultralytics/ultralytics 下载到Yolov8文件夹下面 另外&#xff1a;现在yolov8支持像包一样导入&#xff0c;pip install就可以 2、yolov8 语义分割文档 看官方文档&#xff1a;主页 -Ultralytics YOLO 文档 还能切…...

基于STM8系列单片机驱动74HC595驱动两个3位一体的数码管

1&#xff09;单片机/ARM硬件设计小知识&#xff0c;分享给将要学习或者正在学习单片机/ARM开发的同学。 2&#xff09;内容属于原创&#xff0c;若转载&#xff0c;请说明出处。 3&#xff09;提供相关问题有偿答疑和支持。 为了节省单片机MCU的IO口资源驱动6个数码管&…...

Jlink下载固件到RAM区

Jlink下载固件到RAM区 准备批处理搜索exe批处理读取bin数据解析调用jlink批处理准备jlink脚本 调用执行 环境&#xff1a;J-Flash V7.96g 平台&#xff1a;arm cortex-m3 准备批处理 搜索exe批处理 find_file.bat echo off:: 自动识别脚本名和路径 set "SCRIPT_DIR%~dp…...

Kotlin基础——Typeclass

高阶类型 如在Iterable新增泛型方法时 interface Iterable<T> {fun filter(p: (T) -> Boolean): Iterable<T>fun remove(p: (T) -> Boolean): Iterable<T> filter { x -> !p(x) } }对应的List、Set实现上述方法时仍需要返回具体的类型 interfac…...

DC-DC 高压降压、非隔离AC-DC、提供强大的动力,选择优质电源芯片-(昱灿)

畅享长续航&#xff0c;尽在我们的充电芯片&#xff01; 无论是手机、平板还是智能设备&#xff0c;长时间使用后电量不足总是令人头疼。然而&#xff0c;我们的充电芯片将为您带来全新的充电体验&#xff01;采用先进的技术&#xff0c;我们的充电芯片能够提供快速而稳定的充电…...

GPT-4o的视觉识别能力,将绕过所有登陆的图形验证码

知识星球&#x1f517;除了包含技术干货&#xff1a;《Java代码审计》《Web安全》《应急响应》《护网资料库》《网安面试指南》还包含了安全中常见的售前护网案例、售前方案、ppt等&#xff0c;同时也有面向学生的网络安全面试、护网面试等。 我们来看一下市面上常见的图形验证…...

【LinuxC语言】进程间的通信——管道

文章目录 前言不同进程间通信的方式管道匿名管道和命名管道半双工与全双工管道相关函数创建管道总结前言 在Linux操作系统中,进程是执行中的程序的实例。每个进程都有自己的地址空间,数据栈以及其他用于跟踪进程执行的辅助数据。操作系统管理这些进程,并通过调度算法来分享…...

CompletableFuture 基本用法

一、 CompletableFuture简介 CompletableFuture 是 Java 8 引入的一个功能强大的类&#xff0c;用于异步编程和并发处理。它提供了丰富的 API 来处理异步任务的结果&#xff0c;支持函数式编程风格&#xff0c;并允许通过链式调用组合多个异步操作。 二、CompletableFuture中…...

网页如何发布到服务器上

将网页发布到服务器上的过程涉及多个步骤&#xff0c;包括准备阶段、选择托管提供商、发布网站等。12 准备阶段&#xff1a; 确保在本地开发环境中对网站进行了充分的测试&#xff0c;包括功能测试、性能测试和安全测试。 检查Web.config文件&#xff0c;确保所有的配置设置…...

Jenkins简要说明

Jenkins 是一个开源的持续集成和持续部署&#xff08;CI/CD&#xff09;工具&#xff0c;广泛用于自动化软件开发过程中的构建、测试和部署等任务。它是基于Java开发的&#xff0c;因此可以在任何支持Java的平台上运行&#xff0c;并且能够与各种操作系统、开发工具和插件无缝集…...

C# 比较基础知识:最佳实践和技巧

以下是一些在 C# 中进行比较的技巧和窍门的概述。 1. 比较原始类型 对于原始类型&#xff08;int、double、char 等&#xff09;&#xff0c;可以使用标准比较运算符。 int a 5; int b 10; bool isEqual (a b); // false bool isGreater (a > b); // false bool is…...

Ansible 自动化运维实践

随着 IT 基础设施的复杂性不断增加&#xff0c;手动运维已无法满足现代企业对高效、可靠的 IT 运维需求。Ansible 作为一款开源的自动化运维工具&#xff0c;通过简洁易用的 YAML 语法和无代理&#xff08;agentless&#xff09;架构&#xff0c;极大简化了系统配置管理、应用部…...

红队攻防渗透技术实战流程:中间件安全:IISNGINXAPACHETOMCAT

红队攻防渗透实战 1. 中间件安全1.1 中间件-IIS-短文件&解析&蓝屏等1.2 中间件-Nginx-文件解析&命令执行等1.2.1 后缀解析 文件名解析1.2.2 cve_2021_23017 无EXP有POC1.2.3 cve_2017_7529 意义不大1.3 中间件-Apache-RCE&目录遍历&文件解析等1.3.1 cve_20…...

如何卸载宝塔面板?

宝塔官方有提供宝塔面板的卸载命令&#xff0c;使用这个卸载命令&#xff0c;我们就能将宝塔面板卸载掉。 这里有一点需要注意的&#xff0c;如果卸载宝塔面板的同时&#xff0c;也希望将 Nginx、MySQL、PHP 等组件卸载掉&#xff0c;那么我们应该先在宝塔面板里面卸载掉以上软…...

python入门基础知识(错误和异常)

本文部分内容来自菜鸟教程Python 基础教程 | 菜鸟教程 (runoob.com) 本人负责概括总结代码实现。 以此达到快速复习目的 目录 语法错误 异常 异常处理 try/except try/except...else try-finally 语句 抛出异常 用户自定义异常 内置异常类型 常见的标准异常类型 语法…...

迈巴赫S480升级增强现实AR抬头显示hud比普通抬头显示HUD更好用吗

增强AR实景抬头显示HUD&#xff08;Augmented Reality Head-Up Display&#xff09;是一种更高级的驾驶辅助技术&#xff0c;相比于普通抬头显示HUD&#xff0c;它提供了更丰富、更具沉浸感的驾驶体验。以下是它比普通抬头显示HUD多的一些功能&#xff1a; • 信息呈现方式&am…...

vivado、vitis2022安装及其注意事项(省时、省空间)

1、下载 AMD官网-资源与支持-vivado ML开发者工具&#xff0c;或者vitis平台&#xff0c; 下载的时候有个官网推荐web安装&#xff0c;亲测这个耗时非常久&#xff0c;不建议使用&#xff0c;还是直接下载89G的安装包快。 注意&#xff1a;安装vitis平台会默认安装vivado&…...

【自动驾驶】ROS小车系统

文章目录 小车组成轮式运动底盘的组成轮式运动底盘的分类轮式机器人的控制方式感知传感器ROS决策主控ROS介绍ROS的坐标系ROS的单位机器人电气连接变压模块运动底盘的电气连接ROS主控与传感器的电气连接ROS主控和STM32控制器两种控制器的功能运动底盘基本组成电池电机控制器与驱…...

mysql学习——多表查询

多表查询 内连接外连接自连接自连接查询联合查询 子查询 学习黑马MySQL课程&#xff0c;记录笔记&#xff0c;用于复习。 添加外键 alter table emp add constraint fk_emp_dept_id foreign key (dept_id) references dept(id);多表查询 select * from emp , dept where emp…...

【Gradio】如何设置 Gradio 数据框的样式

简介 数据可视化是数据分析和机器学习的关键方面。Gradio DataFrame 组件是一种流行的方式&#xff0c;在网络应用程序中显示表格数据&#xff08;特别是以 pandas DataFrame 对象的形式&#xff09;。 本文将探讨 Gradio 的最新增强功能&#xff0c;这些功能允许用户整合 pand…...

【ThreeJS】Threejs +Vue3 开发基础

目前流行的前端3D框架以以Three.js、Babylon.js、A-Frame和ThingJS为例&#xff1a; 1.Three.js 功能&#xff1a; 提供了大量的3D功能&#xff0c;包括基本几何形状、材质、灯光、动画、特效等。 易用性&#xff1a; 功能强大且易于使用&#xff0c;抽象了复杂的底层细节&…...

cocos 如何使用九宫格图片,以及在微信小程序上失效。

1.在图片下方&#xff0c;点击edit。 2.拖动线条&#xff0c;使四角不被拉伸。 3.使用。 其他 在微信小程序上失效&#xff0c;需要将packable合图功能取消掉。...

外贸网站建设设计/站长工具备案查询

用iterator一直有问题&#xff0c;后来用for each循环就好了。 for循环遍历&#xff1a;for (String str : set) { System.out.println(str);} refurl: http://blog.sina.com.cn/s/blog_4f925fc3010182zi.html...

网站seo技术能不能赚钱/营销技巧和营销方法培训

GPIO是指通用输入输出接口&#xff08;general-purpose input/output&#xff09;&#xff0c;以前的板子是26针&#xff0c;4B型号是40针&#xff0c;每根针的含义从各种文档中找了几个图。 针的序号与GPIO的序号是不一样的。有些针是固定的含义&#xff0c;3.3V电压、5V电压…...

汕头网站公司/今日军事新闻头条

昨天做了什么 设计APP 今天打算做什么 修改错误 遇到的问题 APP中按钮关联错误转载于:https://www.cnblogs.com/liulala2017/p/8185604.html...

为什么做的网站在谷歌浏览器打不开/网络营销的主要特点有哪些

用打点计时器测速度1. 电火花计时器(1)构造&#xff1a;(2)原理&#xff1a; 脉冲电流经放电针、墨粉纸盘到纸盘轴&#xff0c;产生火花放电(3)工作电压&#xff1a;交流 220V2. 打点计时器的作用(1)测时间电源频率50Hz&#xff0c;每隔 0.02 s秒打一个点(2)测位移(3)研究…...

用心做的网站/班级优化大师官方免费下载

Collection├List│├LinkedList│├ArrayList│└Vector│ └Stack└SetMap├Hashtable├HashMap└WeakHashMapCollection接口Collection是最基本的集合接口&#xff0c;一个Collection代表一组Object&#xff0c;即Collection的元素(Elements)。Java SDK不提供直接继承自Col…...

网站建设公司好/长沙官网网站推广优化

#12 楼 akin520 为什么配不起呢? 我在做实验, 先配一个主从, ok, 没问题. 再反过来配置第二个主从的时候, 就有问题了.第一个主从管理配置:[mysqld]datadir/var/lib/mysqlsocket/var/lib/mysql/mysql.sockusermysqllog-binmysql-binserver-id25# Disabling symbolic-links is …...