当前位置: 首页 > news >正文

pytorch实现门控循环单元 (GRU)

 人工智能例子汇总:AI常见的算法和例子-CSDN博客  

特性GRULSTM
计算效率更快,参数更少相对较慢,参数更多
结构复杂度只有两个门(更新门和重置门)三个门(输入门、遗忘门、输出门)
处理长时依赖一般适用于中等长度依赖更适合处理超长时序依赖
训练速度训练更快,梯度更稳定训练较慢,占用更多内存

例子:

import torch
import torch.nn as nn
import torch.optim as optim
import random
import matplotlib.pyplot as plt# 🏁 迷宫环境(5×5)
class MazeEnv:def __init__(self, size=5):self.size = sizeself.state = (0, 0)  # 起点self.goal = (size-1, size-1)  # 终点self.actions = [(0,1), (0,-1), (1,0), (-1,0)]  # 右、左、下、上def reset(self):self.state = (0, 0)  # 重置起点return self.statedef step(self, action):dx, dy = self.actions[action]x, y = self.statenx, ny = max(0, min(self.size-1, x+dx)), max(0, min(self.size-1, y+dy))reward = 1 if (nx, ny) == self.goal else -0.1done = (nx, ny) == self.goalself.state = (nx, ny)return (nx, ny), reward, done# 🤖 GRU 策略网络
class GRUPolicy(nn.Module):def __init__(self, input_size, hidden_size, output_size):super(GRUPolicy, self).__init__()self.gru = nn.GRU(input_size, hidden_size, batch_first=True)self.fc = nn.Linear(hidden_size, output_size)def forward(self, x, hidden):out, hidden = self.gru(x, hidden)out = self.fc(out[:, -1, :])  # 只取最后时间步return out, hidden# 🎯 训练参数
env = MazeEnv(size=5)
policy = GRUPolicy(input_size=2, hidden_size=16, output_size=4)
optimizer = optim.Adam(policy.parameters(), lr=0.01)
loss_fn = nn.CrossEntropyLoss()# 🎓 训练
num_episodes = 500
epsilon = 1.0  # 初始的ε值,控制探索的概率
epsilon_min = 0.01  # 最小ε值
epsilon_decay = 0.995  # ε衰减率
best_path = []  # 用于存储最佳路径for episode in range(num_episodes):state = env.reset()hidden = torch.zeros(1, 1, 16)  # GRU 初始状态states, actions, rewards = [], [], []logits_list = []  for _ in range(20):  # 最多 20 步state_tensor = torch.tensor([[state[0], state[1]]], dtype=torch.float32).unsqueeze(0)logits, hidden = policy(state_tensor, hidden)logits_list.append(logits)# ε-greedy 策略if random.random() < epsilon:action = random.choice(range(4))  # 随机选择动作else:action = torch.argmax(logits, dim=1).item()  # 选择最大值对应的动作next_state, reward, done = env.step(action)states.append(state)actions.append(action)rewards.append(reward)if done:print(f"Episode {episode} - Reached Goal!")# 找到最优路径best_path = states + [next_state]  # 当前 episode 的路径breakstate = next_state# 计算损失logits = torch.cat(logits_list, dim=0)  # (T, 4)action_tensor = torch.tensor(actions, dtype=torch.long)  # (T,)loss = loss_fn(logits, action_tensor)  optimizer.zero_grad()loss.backward()optimizer.step()# 衰减 εepsilon = max(epsilon_min, epsilon * epsilon_decay)if episode % 100 == 0:print(f"Episode {episode}, Loss: {loss.item():.4f}, Epsilon: {epsilon:.4f}")# 🧐 确保 best_path 已经记录
if len(best_path) == 0:print("No path found during training.")
else:print(f"Best path: {best_path}")# 🚀 测试路径(只绘制最佳路径)
fig, ax = plt.subplots(figsize=(6,6))# 初始化迷宫图
maze = [[0 for _ in range(5)] for _ in range(5)]  # 5×5 迷宫
ax.imshow(maze, cmap="coolwarm", origin="upper")# 画网格
ax.set_xticks(range(5))
ax.set_yticks(range(5))
ax.grid(True, color="black", linewidth=0.5)# 画出最佳路径(红色)
for (x, y) in best_path:ax.add_patch(plt.Rectangle((y, x), 1, 1, color="red", alpha=0.8))# 画起点和终点
ax.text(0, 0, "S", ha="center", va="center", fontsize=14, color="white", fontweight="bold")
ax.text(4, 4, "G", ha="center", va="center", fontsize=14, color="white", fontweight="bold")plt.title("GRU RL Agent - Best Path")
plt.show()

相关文章:

pytorch实现门控循环单元 (GRU)

人工智能例子汇总&#xff1a;AI常见的算法和例子-CSDN博客 特性GRULSTM计算效率更快&#xff0c;参数更少相对较慢&#xff0c;参数更多结构复杂度只有两个门&#xff08;更新门和重置门&#xff09;三个门&#xff08;输入门、遗忘门、输出门&#xff09;处理长时依赖一般适…...

Word List 2

词汇颜色标识解释 词汇表中的生词 词汇表中的词组成的搭配、派生词 例句中的生词 我自己写的生词&#xff08;用于区分易混淆的词&#xff0c;无颜色标识&#xff09; 不认识的单词或句式 单词的主要汉语意思 不太理解的句子语法和结构 Word List 2 英文音标中文regi…...

机器学习常用包numpy篇(四)函数运算

目录 前言 一、三角函数 二、双曲函数 三、数值修约 四、 求和、求积与差分 五、 指数与对数 六、算术运算 七、 矩阵与向量运算 八、代数运算 九、 其他数学工具 总结 前言 Python 的原生运算符可实现基础数学运算&#xff08;加减乘除、取余、取整、幂运算&#…...

CSS in JS

css in js css in js 的核心思想是&#xff1a;用一个 JS 对象来描述样式&#xff0c;而不是 css 样式表。 例如下面的对象就是一个用于描述样式的对象&#xff1a; const styles {backgroundColor: "#f40",color: "#fff",width: "400px",he…...

TCP 丢包恢复策略:代价权衡与优化迷局

网络物理层丢包是一种需要偿还的债务&#xff0c;可以容忍低劣的传输质量&#xff0c;这为 UDP 类服务提供了空间&#xff0c;而对于 TCP 类服务&#xff0c;可以用另外两类代价来支付&#xff1a; 主机端采用轻率的 GBN 策略恢复丢包&#xff0c;节省 CPU 资源&#xff0c;但…...

面经--C语言——内存泄漏、malloc和new的区别 .c文件怎么转换为可执行程序 uart和usart的区别 继承的访问权限总结

文章目录 内存泄漏预防内存泄漏的方法&#xff1a; malloc和new的区别.c文件怎么转换为可执行程序uart和usart的区别继承的访问权限总结访问控制符总结1. **public**:2. **protected**:3. **private**:继承类型&#xff1a; 内存泄漏 内存泄漏是指程序在运行时动态分配内存后&…...

Denavit-Hartenberg DH MDH坐标系

Denavit-Hartenberg坐标系及其规则详解 6轴协作机器人的MDH模型详细图_6轴mdh-CSDN博客 N轴机械臂的MDH正向建模&#xff0c;及python算法_mdh建模-CSDN博客 运动学3-----正向运动学 | 鱼香ROS 机器人学&#xff1a;MDH建模 - 哆啦美 - 博客园 机械臂学习——标准DH法和改进MDH…...

力扣动态规划-20【算法学习day.114】

前言 ###我做这类文章一个重要的目的还是记录自己的学习过程&#xff0c;我的解析也不会做的非常详细&#xff0c;只会提供思路和一些关键点&#xff0c;力扣上的大佬们的题解质量是非常非常高滴&#xff01;&#xff01;&#xff01; 习题 1.网格中的最小路径代价 题目链接…...

计算机视觉-边缘检测

一、边缘 1.1 边缘的类型 ①实体上的边缘 ②深度上的边缘 ③符号的边缘 ④阴影产生的边缘 不同任务关注的边缘不一样 1.2 提取边缘 突变-求导&#xff08;求导也是一种卷积&#xff09; 近似&#xff0c;1&#xff08;右边的一个值-自己可以用卷积做&#xff09; 该点f(x,y)…...

文字加持:让 OpenCV 轻松在图像中插上文字

前言 在很多图像处理任务中,我们不仅需要提取图像信息,还希望在图像上加上一些文字,或是标注,或是动态展示。正如在一幅画上添加一个标语,或者在一个视频上加上动态字幕,cv2.putText 就是这个“文字魔术师”,它能让我们的图像从“沉默寡言”变得生动有趣。 今天,我们…...

掌握 HTML5 多媒体标签:如何在所有浏览器中顺利嵌入视频与音频

系列文章目录 01-从零开始学 HTML&#xff1a;构建网页的基本框架与技巧 02-HTML常见文本标签解析&#xff1a;从基础到进阶的全面指南 03-HTML从入门到精通&#xff1a;链接与图像标签全解析 04-HTML 列表标签全解析&#xff1a;无序与有序列表的深度应用 05-HTML表格标签全面…...

在Mac mini M4上部署DeepSeek R1本地大模型

在Mac mini M4上部署DeepSeek R1本地大模型 安装ollama 本地部署&#xff0c;我们可以通过Ollama来进行安装 Ollama 官方版&#xff1a;【点击前往】 Web UI 控制端【点击安装】 如何在MacOS上更换Ollama的模型位置 默认安装时&#xff0c;OLLAMA_MODELS 位置在"~/.o…...

【电脑系统】电脑突然(蓝屏)卡死发出刺耳声音

文章目录 前言问题描述软件解决方案尝试硬件解决方案尝试参考文献 前言 在 更换硬盘 时遇到的问题&#xff0c;有时候只有卡死没有蓝屏 问题描述 更换硬盘后&#xff0c;电脑用一会就卡死&#xff0c;蓝屏&#xff0c;显示蓝屏代码 UNEXPECTED_STORE_EXCEPTION 软件解决方案…...

Docker使用指南(二)——容器相关操作详解(实战案例教学,创建/使用/停止/删除)

目录 1.容器操作相关命令​编辑 案例一&#xff1a; 案例二&#xff1a; 容器常用命令总结&#xff1a; 1.查看容器状态&#xff1a; 2.删除容器&#xff1a; 3.进入容器&#xff1a; 二、Docker基本操作——容器篇 1.容器操作相关命令 下面我们用两个案例来具体实操一…...

Java中的常见对象类型解析

在Java开发中&#xff0c;数据的组织和传递是一个重要的概念。为了确保代码的清晰性、可维护性和可扩展性&#xff0c;我们通常会根据不同的用途&#xff0c;设计和使用不同类型的对象。这些对象的作用各不相同&#xff0c;但它们共同为构建高效、模块化的软件架构提供支持。 …...

Dijkstra算法解析

Dijkstra算法&#xff0c;用于求解图中从一个起点到其他所有节点的最短路径。解决单源最短路径问题的有效方法。 条件 有向 带权路径 时间复杂度 O&#xff08;n平方&#xff09; 方法步骤 1 把图上的点分为两个集合 要求的起点 和除了起点之外的点 。能直达的写上权值 不…...

C++ Primer 多维数组

欢迎阅读我的 【CPrimer】专栏 专栏简介&#xff1a;本专栏主要面向C初学者&#xff0c;解释C的一些基本概念和基础语言特性&#xff0c;涉及C标准库的用法&#xff0c;面向对象特性&#xff0c;泛型特性高级用法。通过使用标准库中定义的抽象设施&#xff0c;使你更加适应高级…...

maven mysql jdk nvm node npm 环境安装

安装JDK 1.8 11 环境 maven环境安装 打开网站 下载 下载zip格式 解压 自己创建一个maven库 以后在idea 使用maven时候重新设置一下 这三个地方分别设置 这时候maven才算设置好 nvm 管理 npm nodejs nvm下载 安装 Releases coreybutler/nvm-windows GitHub 一键安装且若有…...

SQL Server中RANK()函数:处理并列排名与自然跳号

RANK()是SQL Server的窗口函数&#xff0c;为结果集中的行生成排名。当出现相同值时&#xff0c;后续排名会跳过被占用的名次&#xff0c;形成自然间隔。与DENSE_RANK()的关键区别在于是否允许排名值连续。 语法&#xff1a; RANK() OVER ([PARTITION BY 分组列]ORDER BY 排序…...

如何运行Composer安装PHP包 安装JWT库

1. 使用Composer Composer是PHP的依赖管理工具&#xff0c;它允许你轻松地安装和管理PHP包。对于JWT&#xff0c;你可以使用firebase/php-jwt这个库&#xff0c;这是由Firebase提供的官方库。 安装Composer&#xff08;如果你还没有安装的话&#xff09;&#xff1a; 访问Co…...

反向工程与模型迁移:打造未来商品详情API的可持续创新体系

在电商行业蓬勃发展的当下&#xff0c;商品详情API作为连接电商平台与开发者、商家及用户的关键纽带&#xff0c;其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息&#xff08;如名称、价格、库存等&#xff09;的获取与展示&#xff0c;已难以满足市场对个性化、智能…...

系统设计 --- MongoDB亿级数据查询优化策略

系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log&#xff0c;共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题&#xff0c;不能使用ELK只能使用…...

Matlab | matlab常用命令总结

常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...

Rust 异步编程

Rust 异步编程 引言 Rust 是一种系统编程语言,以其高性能、安全性以及零成本抽象而著称。在多核处理器成为主流的今天,异步编程成为了一种提高应用性能、优化资源利用的有效手段。本文将深入探讨 Rust 异步编程的核心概念、常用库以及最佳实践。 异步编程基础 什么是异步…...

重启Eureka集群中的节点,对已经注册的服务有什么影响

先看答案&#xff0c;如果正确地操作&#xff0c;重启Eureka集群中的节点&#xff0c;对已经注册的服务影响非常小&#xff0c;甚至可以做到无感知。 但如果操作不当&#xff0c;可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下&#xff0c;推客小程序系统凭借其裂变传播、精准营销等特性&#xff0c;成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径&#xff0c;助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...

MySQL 主从同步异常处理

阅读原文&#xff1a;https://www.xiaozaoshu.top/articles/mysql-m-s-update-pk MySQL 做双主&#xff0c;遇到的这个错误&#xff1a; Could not execute Update_rows event on table ... Error_code: 1032是 MySQL 主从复制时的经典错误之一&#xff0c;通常表示&#xff…...

【C++】纯虚函数类外可以写实现吗?

1. 答案 先说答案&#xff0c;可以。 2.代码测试 .h头文件 #include <iostream> #include <string>// 抽象基类 class AbstractBase { public:AbstractBase() default;virtual ~AbstractBase() default; // 默认析构函数public:virtual int PureVirtualFunct…...