当前位置: 首页 > news >正文

动手学强化学习 第 15 章 模仿学习 训练代码

基于 https://github.com/boyu-ai/Hands-on-RL/blob/main/%E7%AC%AC15%E7%AB%A0-%E6%A8%A1%E4%BB%BF%E5%AD%A6%E4%B9%A0.ipynb

理论 模仿学习

修改了警告和报错

运行环境

Debian GNU/Linux 12
Python 3.9.19
torch 2.0.1
gym 0.26.2

运行代码

#!/usr/bin/env pythonimport gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:''' PPO算法,采用截断方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs  # 一条序列的数据用于训练轮数self.eps = eps  # PPO中截断范围的参数self.device = devicedef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']),dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']),dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)td_delta = td_target - self.critic(states)advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,td_delta.cpu()).to(self.device)old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach()for _ in range(self.epochs):log_probs = torch.log(self.actor(states).gather(1, actions))ratio = torch.exp(log_probs - old_log_probs)surr1 = ratio * advantagesurr2 = torch.clamp(ratio, 1 - self.eps,1 + self.eps) * advantage  # 截断actor_loss = torch.mean(-torch.min(surr1, surr2))  # PPO损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v1'
env = gym.make(env_name)
env.reset(seed=0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,epochs, eps, gamma, device)return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)def sample_expert_data(n_episode):states = []actions = []for episode in range(n_episode):state = env.reset()[0]done = Falsewhile not done and len(states) < 10000:action = ppo_agent.take_action(state)states.append(state)actions.append(action)next_state, reward, done, _, __ = env.step(action)state = next_statereturn np.array(states), np.array(actions)env.reset(seed=0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)n_samples = 30  # 采样30个数据
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]class BehaviorClone:def __init__(self, state_dim, hidden_dim, action_dim, lr):self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)def learn(self, states, actions):states = torch.tensor(states, dtype=torch.float).to(device)actions = torch.tensor(actions).view(-1, 1).to(device)log_probs = torch.log(self.policy(states).gather(1, actions))bc_loss = torch.mean(-log_probs)  # 最大似然估计self.optimizer.zero_grad()bc_loss.backward()self.optimizer.step()def take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(device)probs = self.policy(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def test_agent(agent, env, n_episode):return_list = []for episode in range(n_episode):episode_return = 0state = env.reset()[0]done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _, __ = env.step(action)state = next_stateepisode_return += rewardreturn_list.append(episode_return)return np.mean(return_list)env.reset(seed=0)
torch.manual_seed(0)
np.random.seed(0)lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []with tqdm(total=n_iterations, desc="进度条") as pbar:for i in range(n_iterations):sample_indices = np.random.randint(low=0,high=expert_s.shape[0],size=batch_size)bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])current_return = test_agent(bc_agent, env, 5)test_returns.append(current_return)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})pbar.update(1)iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()class Discriminator(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(Discriminator, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))return torch.sigmoid(self.fc2(x))class GAIL:def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):self.discriminator = Discriminator(state_dim, hidden_dim,action_dim).to(device)self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)self.agent = agentdef learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)expert_actions = torch.tensor(expert_a).to(device)agent_states = torch.tensor(np.array(agent_s), dtype=torch.float).to(device)agent_actions = torch.tensor(agent_a).to(device)expert_actions = F.one_hot(expert_actions, num_classes=2).float()agent_actions = F.one_hot(agent_actions, num_classes=2).float()expert_prob = self.discriminator(expert_states, expert_actions)agent_prob = self.discriminator(agent_states, agent_actions)discriminator_loss = nn.BCELoss()(agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(expert_prob, torch.zeros_like(expert_prob))self.discriminator_optimizer.zero_grad()discriminator_loss.backward()self.discriminator_optimizer.step()rewards = -torch.log(agent_prob).detach().cpu().numpy()transition_dict = {'states': agent_s,'actions': agent_a,'rewards': rewards,'next_states': next_s,'dones': dones}self.agent.update(transition_dict)env.reset(seed=0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []with tqdm(total=n_episode, desc="进度条") as pbar:for i in range(n_episode):episode_return = 0state = env.reset()[0]done = Falsestate_list = []action_list = []next_state_list = []done_list = []while not done and len(state_list) < 10000:action = agent.take_action(state)next_state, reward, done, _, __ = env.step(action)state_list.append(state)action_list.append(action)next_state_list.append(next_state)done_list.append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)gail.learn(expert_s, expert_a, state_list, action_list,next_state_list, done_list)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()

rl_utils.py

参考 动手学强化学习 第 14 章 SAC 算法 训练代码-CSDN博客

相关文章:

动手学强化学习 第 15 章 模仿学习 训练代码

基于 https://github.com/boyu-ai/Hands-on-RL/blob/main/%E7%AC%AC15%E7%AB%A0-%E6%A8%A1%E4%BB%BF%E5%AD%A6%E4%B9%A0.ipynb 理论 模仿学习 修改了警告和报错 运行环境 Debian GNU/Linux 12 Python 3.9.19 torch 2.0.1 gym 0.26.2 运行代码 #!/usr/bin/env pythonimpor…...

第一阶段面试问题(前半部分)

1. 进程和线程的概念、区别以及什么时候用线程、什么时候用进程&#xff1f; &#xff08;1&#xff09;线程 线程是CPU任务调度的最小单元、是一个轻量级的进程 &#xff08;2&#xff09;进程 进程是操作系统资源分配的最小单元 进程是一个程序动态执行的过程&#xff0c;包…...

《数学教学通讯》是一本怎样的刊物?投稿难吗?

《数学教学通讯》是一本怎样的刊物&#xff1f;投稿难吗&#xff1f; 《数学教学通讯》是一本具有较高学术价值的教育类刊物。它创刊于 1979 年&#xff0c;由西南大学主管&#xff0c;西南大学数学与统计学院、重庆市数学学会主办&#xff0c;出版周期为旬刊。该刊物在国内外…...

<机器学习> K-means

K-means定义 K-means 是一种广泛使用的聚类算法&#xff0c;旨在将数据集中的点分组为 K 个簇&#xff08;cluster&#xff09;&#xff0c;使得每个簇内的点尽可能相似&#xff0c;而不同簇的点尽可能不同。K-means 算法通过迭代的方式&#xff0c;逐步优化簇的分配和簇的中心…...

我们如何优化 Elasticsearch Serverless 中的刷新成本

作者&#xff1a;来自 Elastic Francisco Fernndez Castao, Henning Andersen 最近&#xff0c;我们推出了 Elastic Cloud Serverless 产品&#xff0c;旨在提供在云中运行搜索工作负载的无缝体验。为了推出该产品&#xff0c;我们重新设计了 Elasticsearch&#xff0c;将存储与…...

MySQL半同步复制

1.MySQL主从复制模式 1.1异步复制 异步复制为 MySQL 默认的复制模式&#xff0c;指主库写 binlog、从库 I/O 线程读 binlog 并写入 relaylog、从库 SQL 线程重放事务这三步之间是异步的。 异步复制的主库不需要关心备库的状态&#xff0c;主库不保证事务被传输到从库&#xf…...

[一本通提高数位动态规划]数字游戏:取模数题解

[一本通提高数位动态规划]数字游戏&#xff1a;取模数题解 1前言2问题3状态的设置4数位dp-part1预处理5数位dp-part2利用状态求解6代码7后记 1前言 本文为数字游戏&#xff1a;取模数的题解 需要读者对数位dp有基础的了解&#xff0c;建议先阅读 论数位dp–胎教级教学 B3883 […...

[Day 39] 區塊鏈與人工智能的聯動應用:理論、技術與實踐

區塊鏈的安全性分析 區塊鏈技術已經成為現代數字經濟的一個重要組成部分&#xff0c;提供了去中心化、透明和不可篡改的數據存儲與交易系統。然而&#xff0c;隨著區塊鏈技術的廣泛應用&#xff0c;其安全性問題也日益受到關注。本篇文章將詳細探討區塊鏈技術的安全性&#xf…...

OpenStack入门体验

一、云计算概述 1.1什么是云计算 云计算(cloud computing)是一种基于网络的超级计算模式,基于用户的不同需求&#xff0c;提供所需的资源&#xff0c;包括计算资源、存储资源、网络资源等。云计算服务运行在若干台高性能物理服务器之上&#xff0c;提供每秒 10万亿次的运算能力…...

预测未来 | MATLAB实现RF随机森林多变量时间序列预测未来-预测新数据

预测未来 | MATLAB实现RF随机森林多变量时间序列预测未来-预测新数据 预测效果 基本介绍 随机森林属于 集成学习 中的 Bagging(Bootstrap AGgregation 的简称) 方法。如果用图来表示他们之间的关系如下: 随机森林是由很多决策树构成的,不同决策树之间没有关联。当我们进行…...

iOS 系统提供的媒体资源选择器(UIImagePickerController)

简介 图片或者视频的选择功能几乎是每个APP必不可少的&#xff0c;UIImagePickerController 是 iOS 系统提供的一个方便的媒体选择器&#xff0c;允许用户从照片库中选择图片或视频&#xff0c;或者使用相机拍摄新照片和视频。 它的页面简单易用&#xff0c;代码稳定可靠&…...

电脑如何扩展硬盘分区?告别空间不足困扰

在数字化时代&#xff0c;电脑硬盘的存储空间显得愈发重要。随着个人文件、应用程序和系统更新的不断累积&#xff0c;原有的硬盘分区可能很快就会被填满。为了解决这个问题&#xff0c;扩展硬盘分区成为了一个非常实用的方法。那么&#xff0c;电脑如何扩展硬盘分区呢&#xf…...

论文阅读:Mammoth: Building math generalist models through hybrid instruction tuning

Mammoth: Building math generalist models through hybrid instruction tuning https://arxiv.org/pdf/2309.05653 MAmmoTH&#xff1a;通过混合指令调优构建数学通才模型 摘要 我们介绍了MAmmoTH&#xff0c;一系列特别为通用数学问题解决而设计的开源大型语言模型&#…...

什么样的双筒式防爆器把煤矿吸引?

什么样的双筒式防爆器把煤矿吸引&#xff1f;要有好的服务和态度&#xff0c;要用心去聆听客户的需求&#xff0c;去解决客户的疑虑&#xff0c;用诚信去赢得客户的信任。 150产品的技术特点 双筒式防爆器采用双罐结构&#xff0c;其水封水位观测直观、能够快速有效排污、操作…...

如何保证冰河AL0 400G 100W 的稳定运行?

要保证冰河 AL0 400G 100w 的稳定运行&#xff0c;可以考虑以下几点&#xff1a; 1. 适宜的工作环境&#xff1a;确保设备放置在通风良好、温度适宜的环境中。良好的散热条件有助于防止设备过热&#xff0c;因为过热可能会导致性能下降或故障。该设备采用纯铝合金外壳&#xf…...

剪画小程序:巴黎奥运会,从画面到声音!

在巴黎奥运会的赛场上&#xff0c;每一个瞬间都伴随着独特的声音。那是观众的欢呼&#xff0c;是运动员冲刺的呐喊&#xff0c;是国歌奏响的激昂旋律。 如今&#xff0c;通过剪画音频提取&#xff0c;我们能够将这些珍贵的声音从精彩的画面中分离出来&#xff0c;单独珍藏。 想…...

【leetcode详解】心算挑战: 一题搞懂涉及奇偶数问题的 “万金油” 思路(思路详解)

前记&#xff1a; 做了几日的leetcode每日一题&#xff0c;几乎全是十分钟结束战斗的【中等】题&#xff0c;今日杀出来个【简单】题&#xff0c;反倒开始难以想出很清楚的解题思路&#xff0c;反复调试修改才将题目逐渐考虑全面&#xff0c;看到了原本思路的漏洞&#xff0c…...

【资料集】数据库设计说明书(Word原件提供)

2 数据库环境说明 3 数据库的命名规则 4 逻辑设计 5 物理设计 5.1 表汇总 5.2 表结构设计 6 数据规划 6.1 表空间设计 6.2 数据文件设计 6.3 表、索引分区设计 6.4 优化方法 7 安全性设计 7.1 防止用户直接操作数据库 7.2 用户帐号加密处理 7.3 角色与权限控制 8 数据库管理与维…...

MySQL 常用查询语句精粹

引言 MySQL 是一种广泛使用的开源关系型数据库管理系统&#xff0c;其强大的查询语言为用户提供了丰富的数据处理能力。掌握 MySQL 的常用查询语句对于数据库管理和数据分析至关重要。本文将介绍一些 MySQL 中的常用查询语句&#xff0c;并提供实际的示例。 基础查询 1. 选择…...

hive的内部表(MANAGED_TABLE)和外部表(EXTERNAL_TABLE)的区别

1.hive的表类型分为外部表和内部表 内部表和外部表的主要区别在于数据的存储方式。 外部表&#xff1a;外部表的存储在hdfs中&#xff0c;是我们指定的文件目录&#xff0c;当我们删除数据或者删除分区的时候不会将元数据删除&#xff0c;数据还会在hdfs目录中&#xff0c;我们…...

【AutoSar网络管理】验证ecu能够从RepeatMessage状态切换到ReadySleep

本专栏将为您提供: Autosar网络管理介绍,包括:状态迁移、状态行为、状态表现、切换条件、时间参数、消息类型等。DUT模拟节点介绍,包括:设计思路、代码展示、编写须知等。测试用例介绍,包括:测试内容、测试步骤、期望结果等。测试脚本介绍,包括:编写思路、代码展示、脚…...

js逻辑或(||)和且()

重点&#xff1a; JavaScript 中的逻辑运算符按照布尔逻辑进行计算&#xff0c;并且返回值是操作数本身 || ||:逻辑或&#xff0c;只要有一个表达式为真&#xff08;truthy&#xff09;&#xff0c;整个表达式就为真 逻辑或 (||) 的行为&#xff1a; ||运算符可以用来连接两个…...

ElasticSearch入门(六)SpringBoot2

private String author; Field(name “word_count”, type FieldType.Integer) private Integer wordCount; /** Jackson日期时间序列化问题&#xff1a; Cannot deserialize value of type java.time.LocalDateTime from String “2020-06-04 15:07:54”: Failed to des…...

vue项目Nginx部署启动

1.vue打包 &#xff08;1&#xff09;package.json增加打包命令 "scripts": {"dev": "webpack-dev-server --inline --progress --config build/webpack.dev.conf.js --host 10.16.14.110","start": "npm run dev","un…...

Duplicate class kotlin.collections.jdk8.CollectionsJDK8Kt found in modules。Android studio纯java代码报错

我使用java代码 构建项目&#xff0c;初始代码运行就会报错。我使用的是Android Studio Giraffe&#xff08;Adroid-studio-2022.3.1.18-windows&#xff09;。我在网上找的解决办法是删除重复的类&#xff0c;但这操作起来真的太麻烦了。 这是全部报错代码&#xff1a; Dupli…...

filebeat

1、作用 1、可以在本机收集日志2、也可以远程收集日志3、轻量级的日志收集系统&#xff0c;可以在非java环境运行。logstash是在jmv环境中运行&#xff0c;资源消耗很大&#xff0c;启动一个logstash要消耗500M左右的内存&#xff0c;filebeat只消耗10M左右的内存。收集nginx的…...

matlab y=sin(x) - 2/π*(x)函数绘制

[TOC](matlab ysin(x) - 2/π*(x)函数绘制) ysin(x) - 2/π*(x) clc; clear; close all; x_axis_length 10; y_axis_length 10; % 创建 x 值向量 x_positive linspace(0.1, 10, 1000); % 正半轴上的 x 值 x_negative linspace(-10, -0.1, 1000); % 负半轴上的 x 值% 计算…...

HyperDiffusion阅读

ICCV 2023 创新点 HyperDiffusion&#xff1a;一种用隐式神经场无条件生成建模的新方法。 HyperDiffusion直接对MLP权重进行操作&#xff0c;并生成新的神经隐式场。 HyperDiffusion是与维度无关的生成模型。可以对不同维度的数据用相同的训练方法来合成高保真示例。 局限性…...

分治思想 排序数组

题目 这是一道经典的关于分治思想的算法题&#xff0c;适合刚接触分治的小白。 . - 力扣&#xff08;LeetCode&#xff09; 思路 采用递归分治的思想&#xff0c;也就是快速排序的模拟&#xff0c;这里先确定每趟递归的作用&#xff1a; 在一个规定的区间内&#xff0c;随机…...

通用前端分页插件

/*** >>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>>* 分页组件* >>>>>>>>>>>>>>>>>>>…...

怎样做恶搞网站/上海推广外包

1. collections模块collections模块主要封装了⼀些关于集合类的相关操作. 比如, 我们学过的Iterable,Iterator等等. 除了这些以外, collections还提供了⼀些除了基本数据类型以外的数据集合类型. Counter, deque, OrderedDict, defaultdict以及namedtuple class Animal:passfr…...

高古楼网站找活做/网络科技有限公司

Android应用开发--MP3音乐播放器界面设计&#xff08;1&#xff09; 近期突然想自己开发一款MP3播放器&#xff0c;所以就有了上面的界面&#xff0c;小巫不才&#xff0c;为了记录自己的开发过程&#xff0c;写成博客&#xff0c;个人开发可能有很多地方不规范&#xff0c;但学…...

无法定位 wordpress 根目录./seo策略工具

一直以来&#xff0c;我们在网页上都是使用图片作为表情。但是图片表情使用起来麻烦&#xff0c;同时衍生了css sprite技术&#xff08;雪碧图&#xff0c;就是很多个小图片拼接成一张大图片&#xff09;。但其实unicode编码中就定义了大量的表情符号&#xff0c;而且你还可以使…...

织梦网站修改教程/搜索引擎收录提交入口

随着WiFi的普及&#xff0c;越来越多的人开始使用无线路由器进行无线上网&#xff0c;比如笔记本电脑&#xff0c;智能手机都可以共享一个路由器进行上网&#xff0c;有时候&#xff0c;我们在用笔记本进行无线WiFi上网的时候&#xff0c;电脑右下角的无线连接图标明明显示连接…...

做pc端网站适配/网站运营主要做什么工作

分享一下我老师大神的人工智能教程&#xff01;零基础&#xff0c;通俗易懂&#xff01;http://blog.csdn.net/jiangjunshow也欢迎大家转载本篇文章。分享知识&#xff0c;造福人民&#xff0c;实现我们中华民族伟大复兴&#xff01;1&#xff0e; Kafka整体结构图Kafka名词解释…...

深圳模板建站多少钱/竞价托管运营哪家好

github传送门 目录 前言效果图快速上手CollapsingToolbarLayout折叠模式AppBarLayout滚动方式CoordinatorLayout配合Snackbar自定义伸缩头部最后前言 之前也是写了RecyclerView的内容, 这次再补充伸缩头部的实现. 港真, 伸缩头部是那种看到第一眼就会爱上的视图效果, 好看又简洁…...