当前位置: 首页 > news >正文

动手学强化学习 第 11 章 TRPO 算法(TRPOContinuous) 训练代码

基于 Hands-on-RL/第11章-TRPO算法.ipynb at main · boyu-ai/Hands-on-RL · GitHub

理论 TRPO 算法

修改了警告和报错

运行环境

Debian GNU/Linux 12
Python 3.9.19
torch 2.0.1
gym 0.26.2

运行代码

TRPOContinuous.py

#!/usr/bin/env pythonimport torch
import numpy as np
import gym
import matplotlib.pyplot as plt
import torch.nn.functional as F
import rl_utils
import copyclass ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PolicyNetContinuous(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNetContinuous, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc_mu = torch.nn.Linear(hidden_dim, action_dim)self.fc_std = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))mu = 2.0 * torch.tanh(self.fc_mu(x))std = F.softplus(self.fc_std(x))return mu, std  # 高斯分布的均值和标准差class TRPOContinuous:""" 处理连续动作的TRPO算法 """def __init__(self, hidden_dim, state_space, action_space, lmbda,kl_constraint, alpha, critic_lr, gamma, device):state_dim = state_space.shape[0]action_dim = action_space.shape[0]self.actor = PolicyNetContinuous(state_dim, hidden_dim,action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.kl_constraint = kl_constraintself.alpha = alphaself.device = devicedef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)mu, std = self.actor(state)action_dist = torch.distributions.Normal(mu, std)action = action_dist.sample()return [action.item()]def hessian_matrix_vector_product(self,states,old_action_dists,vector,damping=0.1):mu, std = self.actor(states)new_action_dists = torch.distributions.Normal(mu, std)kl = torch.mean(torch.distributions.kl.kl_divergence(old_action_dists,new_action_dists))kl_grad = torch.autograd.grad(kl,self.actor.parameters(),create_graph=True)kl_grad_vector = torch.cat([grad.view(-1) for grad in kl_grad])kl_grad_vector_product = torch.dot(kl_grad_vector, vector)grad2 = torch.autograd.grad(kl_grad_vector_product,self.actor.parameters())grad2_vector = torch.cat([grad.contiguous().view(-1) for grad in grad2])return grad2_vector + damping * vectordef conjugate_gradient(self, grad, states, old_action_dists):x = torch.zeros_like(grad)r = grad.clone()p = grad.clone()rdotr = torch.dot(r, r)for i in range(10):Hp = self.hessian_matrix_vector_product(states, old_action_dists,p)alpha = rdotr / torch.dot(p, Hp)x += alpha * pr -= alpha * Hpnew_rdotr = torch.dot(r, r)if new_rdotr < 1e-10:breakbeta = new_rdotr / rdotrp = r + beta * prdotr = new_rdotrreturn xdef compute_surrogate_obj(self, states, actions, advantage, old_log_probs,actor):mu, std = actor(states)action_dists = torch.distributions.Normal(mu, std)log_probs = action_dists.log_prob(actions)ratio = torch.exp(log_probs - old_log_probs)return torch.mean(ratio * advantage)def line_search(self, states, actions, advantage, old_log_probs,old_action_dists, max_vec):old_para = torch.nn.utils.convert_parameters.parameters_to_vector(self.actor.parameters())old_obj = self.compute_surrogate_obj(states, actions, advantage,old_log_probs, self.actor)for i in range(15):coef = self.alpha ** inew_para = old_para + coef * max_vecnew_actor = copy.deepcopy(self.actor)torch.nn.utils.convert_parameters.vector_to_parameters(new_para, new_actor.parameters())mu, std = new_actor(states)new_action_dists = torch.distributions.Normal(mu, std)kl_div = torch.mean(torch.distributions.kl.kl_divergence(old_action_dists,new_action_dists))new_obj = self.compute_surrogate_obj(states, actions, advantage,old_log_probs, new_actor)if new_obj > old_obj and kl_div < self.kl_constraint:return new_parareturn old_paradef policy_learn(self, states, actions, old_action_dists, old_log_probs,advantage):surrogate_obj = self.compute_surrogate_obj(states, actions, advantage,old_log_probs, self.actor)grads = torch.autograd.grad(surrogate_obj, self.actor.parameters())obj_grad = torch.cat([grad.view(-1) for grad in grads]).detach()descent_direction = self.conjugate_gradient(obj_grad, states,old_action_dists)Hd = self.hessian_matrix_vector_product(states, old_action_dists,descent_direction)max_coef = torch.sqrt(2 * self.kl_constraint /(torch.dot(descent_direction, Hd) + 1e-8))new_para = self.line_search(states, actions, advantage, old_log_probs,old_action_dists,descent_direction * max_coef)torch.nn.utils.convert_parameters.vector_to_parameters(new_para, self.actor.parameters())def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']),dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions'],dtype=torch.float).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']),dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)rewards = (rewards + 8.0) / 8.0  # 对奖励进行修改,方便训练td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)td_delta = td_target - self.critic(states)advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,td_delta.cpu()).to(self.device)mu, std = self.actor(states)old_action_dists = torch.distributions.Normal(mu.detach(),std.detach())old_log_probs = old_action_dists.log_prob(actions)critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.critic_optimizer.zero_grad()critic_loss.backward()self.critic_optimizer.step()self.policy_learn(states, actions, old_action_dists, old_log_probs,advantage)num_episodes = 2000
hidden_dim = 128
gamma = 0.9
lmbda = 0.9
critic_lr = 1e-2
kl_constraint = 0.00005
alpha = 0.5
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'Pendulum-v1'
env = gym.make(env_name)
env.reset(seed=0)
torch.manual_seed(0)
agent = TRPOContinuous(hidden_dim, env.observation_space, env.action_space,lmbda, kl_constraint, alpha, critic_lr, gamma, device)
return_list = rl_utils.train_on_policy_agent(env, agent, num_episodes)episodes_list = list(range(len(return_list)))
plt.plot(episodes_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('TRPO on {}'.format(env_name))
plt.show()mv_return = rl_utils.moving_average(return_list, 9)
plt.plot(episodes_list, mv_return)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('TRPO on {}'.format(env_name))
plt.show()

rl_utils.py 参考

动手学强化学习 第 11 章 TRPO 算法 训练代码-CSDN博客

相关文章:

动手学强化学习 第 11 章 TRPO 算法(TRPOContinuous) 训练代码

基于 Hands-on-RL/第11章-TRPO算法.ipynb at main boyu-ai/Hands-on-RL GitHub 理论 TRPO 算法 修改了警告和报错 运行环境 Debian GNU/Linux 12 Python 3.9.19 torch 2.0.1 gym 0.26.2 运行代码 TRPOContinuous.py #!/usr/bin/env pythonimport torch import numpy a…...

数量关系模块

三年后指的不是现在 选A注意单位 注意单位换算 A 正方形减去扇形 256-X5y 那么小李拿的一定是末尾是1或者是6&#xff0c;所以小李拿的是26&#xff0c;那么y46&#xff0c;那么小王或者小周拿的是92&#xff0c;所以选择三个数之和等于92的&#xff0c;所以选择D 分数 百分数 …...

滑模面、趋近律设计过程详解(滑模控制)

目录 1. 确定系统的状态变量和目标2. 定义滑模面3. 选择滑模面的参数4. 设计控制律5. 验证滑模面设计6. 总结 设计滑模面&#xff08;Sliding Surface&#xff09;是滑模控制&#xff08;Sliding Mode Control&#xff0c;SMC&#xff09;中的关键步骤。滑模控制是一种鲁棒控制…...

SQL Server 端口配置

目录 默认端口 更改端口 示例&#xff1a;更改 TCP 端口 示例&#xff1a;验证端口设置 远程连接测试 示例&#xff1a;使用 telnet 测试连接 配置防火墙 示例&#xff1a;Windows 防火墙设置 远程连接测试 示例&#xff1a;使用 telnet 测试连接 默认端口 TCP/IP: …...

同一窗口还是新窗口打开链接更利于SEO优化

&#x1f49d;&#x1f49d;&#x1f49d;欢迎莅临我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里可以感受到一份轻松愉快的氛围&#xff0c;不仅可以获得有趣的内容和知识&#xff0c;也可以畅所欲言、分享您的想法和见解。 推荐&#xff1a;「storm…...

kafka 安装

docker安装kafka(KRaft 模式) KRaft模式不再对Zookeeper依赖。 docker run -d --name kafka-kraft \-p 9092:9092 -p 9093:9093 \-e KAFKA_PROCESS_ROLESbroker,controller \-e KAFKA_NODE_ID1 \-e KAFKA_CONTROLLER_QUORUM_VOTERS1127.0.0.1:9093 \-e KAFKA_LISTENERSPLAINTEX…...

消息队列中间件 - Kafka:高效数据流处理的引擎

作者&#xff1a;逍遥Sean 简介&#xff1a;一个主修Java的Web网站\游戏服务器后端开发者 主页&#xff1a;https://blog.csdn.net/Ureliable 觉得博主文章不错的话&#xff0c;可以三连支持一下~ 如有疑问和建议&#xff0c;请私信或评论留言&#xff01; 前言 在现代大数据和…...

el-table表格动态合并相同数据单元格(可指定列+自定义合并)

el-table表格动态合并相同数据单元格(可指定列自定义合并)_el-table 合并单元格动态-CSDN博客 vue2elementUI表格实现实现多列动态合并_element table动态合并列-CSDN博客...

复习Nginx

1.关于Nginx Nginx的关键特性 1.支持高并发 2.内存资源消耗低 3.高扩展性&#xff08;模块化设计&#xff09; 4.高可用性&#xff08;master-worker&#xff09; Nginx运行架构 注意 默认情况下&#xff0c;Nginx会创建和服务器cpu核心数量相等的worker进程 worker进程之间…...

nvm:Node.js 版本管理工具

nvm&#xff08;Node Version Manager&#xff09;是一个用于管理多个 Node.js 版本的工具&#xff0c;它允许你在同一个系统上安装和使用不同版本的 Node.js。这对于开发者来说非常有用&#xff0c;特别是当不同的项目需要不同版本的 Node.js 时。 以下是 nvm 的一些主要特性…...

springboot校园商店配送系统-计算机毕业设计源码68448

摘要 本文详细阐述了基于Spring Boot框架的校园商店配送系统的设计与实现过程。该系统针对校园内的用户需求&#xff0c;整合了用户注册与登录、商品浏览与购买、订单管理、配送追踪、用户反馈收集以及后台管理等功能&#xff0c;为校园内的普通用户、商家、配送员和管理员提供…...

【Redis 初阶】客户端(C++ 使用样例列表)

一、编写 helloworld 需要先使用 redis-plus-plus 连接一下 Redis 服务器&#xff0c;再使用 ping 命令检测连通性。 1、Makefile Redis 库最多可以支持到 C17 版本。&#xff08;如果是用 Centos&#xff0c;需要注意 gcc/g 的版本&#xff0c;看是否支持 C17。不支持的话&a…...

【STM32】STM32单片机入门

个人主页~ 这是一个新的系列&#xff0c;stm32单片机系列&#xff0c;资料都是从网上找的&#xff0c;主要参考江协科技还有正点原子以及csdn博客等资料&#xff0c;以一个一点没有接触过单片机但有一点编程基础的小白视角开始stm32单片机的学习&#xff0c;希望能对也没有学过…...

学生信息管理系统(Python+PySimpleGUI+MySQL)

吐槽一下 经过一段时间学习pymysql的经历&#xff0c;我深刻的体会到了pymysql的不靠谱之处&#xff1b; 就是在使用int型传参&#xff0c;我写的sql语句中格式化%d了之后&#xff0c;我在要传入的数据传递的每一步的去强制转换了&#xff0c;但是他还是会报错&#xff0c;说我…...

Java8.0标准之重要特性及用法实例(十九)

简介&#xff1a; CSDN博客专家&#xff0c;专注Android/Linux系统&#xff0c;分享多mic语音方案、音视频、编解码等技术&#xff0c;与大家一起成长&#xff01; 新书发布&#xff1a;《Android系统多媒体进阶实战》&#x1f680; 优质专栏&#xff1a; Audio工程师进阶系列…...

Linux系统中,`buffer`和`cache` 区别

在Linux系统中&#xff0c;buffer和cache都是操作系统用来提高磁盘I/O性能的机制&#xff0c;它们通过将数据暂存于内存中来减少对磁盘的直接访问。尽管它们的目的相似&#xff0c;但它们在实现和用途上有所不同。 Buffer 定义&#xff1a;buffer主要用于存储即将被写入磁盘的…...

python创建进度条的两个手搓方法

# 使用\b 回删进行手搓 import sys,time for i in range(1, 101):# 这里的10代表你的进度: 一个汉字2字节print(你的进度:,str(i)\b*(i10),flushTrue,end)time.sleep(0.5) # 利用\r手搓 import sys,time for i in range(1, 101):# \r光标回到开头print("\r", end&qu…...

JAVA—面向对象编程基础

面向对象是java编程的套路。更符合人类思维习惯&#xff0c;编程更直观。面向对象有三大特征&#xff1a;封装&#xff0c;继承&#xff0c;多态。 目录 1.理解面向对象 2.对象在计算机中的执行原理 3.类和对象的一些注意事项 4.类与对象的一些语法知识 &#xff08;1&am…...

【计算机视觉学习之CV2图像操作实战:车道识别1】

车道识别 步骤 区域感兴趣高斯模糊图片灰度化边缘提取膨胀腐蚀中值滤波霍夫圆环检测直线绘制车道 import cv2 import numpy as npdef create_roi_mask(frame):height, width frame.shape[:2]# 三角形的顶点top_vertex [int(width / 2 30), int(height * 0.5 30)]bottom_l…...

动态之美:Laravel动态路由参数的实现艺术

动态之美&#xff1a;Laravel动态路由参数的实现艺术 在Web开发中&#xff0c;路由是应用程序的神经系统&#xff0c;它负责将请求映射到相应的处理逻辑。Laravel框架提供了一种强大而灵活的路由系统&#xff0c;允许开发者定义动态路由参数&#xff0c;从而创建更具动态性和可…...

Python练手小项目

计算器 创建一个简单的计算器&#xff0c;能够进行加、减、乘、除四种基本运算。 # 定义加法函数 def add(x, y):return x y# 定义减法函数 def subtract(x, y):return x - y# 定义乘法函数 def multiply(x, y):return x * y# 定义除法函数 def divide(x, y):if y 0:return…...

苹果手机通讯录恢复教程?3招速成指南

随着科技的不断进步&#xff0c;手机丢失、系统崩溃等意外情况也时有发生&#xff0c;一旦这些情况发生&#xff0c;我们宝贵的通讯录资料很可能会付诸东流。对此&#xff0c;本文为广大苹果手机用户提供一份简洁明了的通讯录恢复教程&#xff0c;让你轻松掌握苹果手机通讯录恢…...

python爬虫入门(五)之Re解析

一、什么是Re解析 “Re解析”是指使用正则表达式&#xff08;regular expression&#xff0c;简称regex&#xff09;进行文本解析或匹配的过程。 解析网页内容的三种方式&#xff1a; 1、bs4解析&#xff08;最简单&#xff09; 2、re解析&#xff08;解析速度最快&#xf…...

可靠的图纸加密软件,七款图纸加密软件推荐

大家好啊,我是小固,今天跟大家聊聊图纸加密软件。 作为一名设计师,我深知保护自己的知识产权有多重要。曾经就因为图纸泄露,差点血本无归,那个教训可真是惨痛啊!所以我今天就给大家推荐几款靠谱的图纸加密软件,希望能帮到你们。 固信软件https://www.gooxion.com/ 首先要隆重…...

【每日一题】【最短路】【BFS】小红走矩阵 “葡萄城杯”牛客周赛 Round 53 F题 C++

“葡萄城杯”牛客周赛 Round 53 F题 小红走矩阵 题目背景 “葡萄城杯”牛客周赛 Round 53 题目描述 n m n\times m nm的矩阵由障碍和空地组成&#xff0c;初始时小红位于起点 ( 1 , 1 ) (1,1) (1,1)&#xff0c;她想要前往终点 ( n , m ) (n,m) (n,m)。小红每一步可以往上…...

无线磁吸充电宝哪个牌子值得入手?什么牌子磁吸充电宝性价比高?

在当下科技日新月异的时期&#xff0c;无线磁吸充电宝成为了众多电子设备用户的得力助手。然而&#xff0c;面对市场上众多品牌和型号的无线磁吸充电宝&#xff0c;消费者常常陷入选择的困境&#xff1a;到底哪个牌子值得入手&#xff1f;什么牌子的磁吸充电宝性价比高&#xf…...

互联网摸鱼日报(2024-08-01)

互联网摸鱼日报(2024-08-01) 36氪新闻 氪星晚报 | Uber与比亚迪合作&#xff0c;将在平台上增加10万辆电动汽车&#xff1b;维维股份将收购大窑汽水&#xff1f;公司回应&#xff1a;消息不实&#xff1b;我国科学家取得全固态锂电池研究新突破 《死侍与金刚狼》&#xff0c;…...

Alpla003经典的价量背离的因子在可转债列表里的因子分析(附python代码)

原创文章第605篇&#xff0c;专注“AI量化投资、世界运行的规律、个人成长与财富自由"。 遗传算法给出的因子五花八门&#xff0c;可解释性不高。 强化学习原理不同&#xff0c;但结果类似。 大模型之前咱们尝试过&#xff0c;Quantlab3.9代码&#xff1a;内置大模型LL…...

进阶理解——typeof 、instanceof

typeof 、instance of 先聊聊JavaScript基本类型数据类型5种含值数据类型2种不含值类型 6种类型的*对象* typeofinstanceof总结进一步扩展一下具体讨论一下typeof局限性扩展判断方法 很多时候&#xff0c;回头望&#xff0c;理解会更深刻&#xff0c;也希望能帮助一些初学的同学…...

不同类型的生物反应器在支架成熟过程中具有哪些特点和应用?

3D Bioprinting of Human Tissues: Biofabrication, Bioinks, and Bioreactors是发表于《International Journal of Molecular Sciences》的一篇综述&#xff0c;详细介绍了3D生物打印人体组织的相关技术进展&#xff0c;包括数据处理、生物打印技术、生物墨水配方、生物反应器…...

网站建设销售培训/百度网站链接

调整前 调整后例如上图&#xff0c;我们需要在顶部显示分类汇总的结果&#xff0c;那么如何操作呢&#xff1f;步骤单击数据透视表中任意单元格→数据透视表工具→设计→报表布局→以大纲…...

湖南做网站公司/安卓优化大师破解版

一、String对象的存储请看这样两个语句&#xff1a;String x "abc";String y new String("abcd");现在来分析一下内存的分配情况。如图&#xff1a;可以看出&#xff0c;x与y存在栈中&#xff0c;它们保存了相应对象的引用。第一条语句没有在堆中分配内存…...

网站开发工作总结论文/seo发帖论坛

1.class not found ,不能加载某个配置文件. 具体错误原因找不到了,在我们的工程中,主要是因为lib中的包有冲突,这个只能作为个人日志了,好像和大家分享不了.不好意思哈~ 2.action中处理两个或者两个以上bo时,应注意,尽量由一个bo方法来实现,即在一个bo中注入两个dao,而不要在…...

有投标功能的网站怎么做/网页设计代做

关键部分代码&#xff1a;public partial class frmStandCalculator : Form{public float number1 0;//存储前一个操作数;public bool flag false;//标记单击了操作符没有&#xff0c;false为还未击操作符&#xff0c;ture为单击了操作符public char doflag ;//初始化操作符…...

网站开发php是什么意思/域名注册需要多少钱

一。CxImage类库简介 这只是翻译了CxImage开源项目主页上的部分简介及简单使用。 CxImage类库是一个优秀的图像操作类库。它可以快捷地存取、显示、转换各种图像。有的读者可能说&#xff0c;有那么多优秀的图形库&#xff0c;如OpenIL,FreeImage,PaintLib等等&#xff0c;它们…...

南京维露斯网站建设/seo关键词排名技术

下面介绍标准批输入对象的步骤&#xff1a; 1、创建批输入对象 2、维护对象属性 开始录屏并修改需要的栏位。 3、定义资源结构 4、定义字段 5、定义关系结构 其实就是将上面创建的结构和SAP标准结构匹配&#xff0c;RELATIONSHIP来创建表间关系。 6、字段mapping &#xff0…...