动手学强化学习 第 15 章 模仿学习 训练代码
基于 https://github.com/boyu-ai/Hands-on-RL/blob/main/%E7%AC%AC15%E7%AB%A0-%E6%A8%A1%E4%BB%BF%E5%AD%A6%E4%B9%A0.ipynb
理论 模仿学习
修改了警告和报错
运行环境
Debian GNU/Linux 12
Python 3.9.19
torch 2.0.1
gym 0.26.2
运行代码
#!/usr/bin/env pythonimport gym
import torch
import torch.nn.functional as F
import torch.nn as nn
import numpy as np
import matplotlib.pyplot as plt
from tqdm import tqdm
import random
import rl_utilsclass PolicyNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(PolicyNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, action_dim)def forward(self, x):x = F.relu(self.fc1(x))return F.softmax(self.fc2(x), dim=1)class ValueNet(torch.nn.Module):def __init__(self, state_dim, hidden_dim):super(ValueNet, self).__init__()self.fc1 = torch.nn.Linear(state_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x):x = F.relu(self.fc1(x))return self.fc2(x)class PPO:''' PPO算法,采用截断方式 '''def __init__(self, state_dim, hidden_dim, action_dim, actor_lr, critic_lr,lmbda, epochs, eps, gamma, device):self.actor = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.critic = ValueNet(state_dim, hidden_dim).to(device)self.actor_optimizer = torch.optim.Adam(self.actor.parameters(),lr=actor_lr)self.critic_optimizer = torch.optim.Adam(self.critic.parameters(),lr=critic_lr)self.gamma = gammaself.lmbda = lmbdaself.epochs = epochs # 一条序列的数据用于训练轮数self.eps = eps # PPO中截断范围的参数self.device = devicedef take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(self.device)probs = self.actor(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def update(self, transition_dict):states = torch.tensor(np.array(transition_dict['states']),dtype=torch.float).to(self.device)actions = torch.tensor(transition_dict['actions']).view(-1, 1).to(self.device)rewards = torch.tensor(transition_dict['rewards'],dtype=torch.float).view(-1, 1).to(self.device)next_states = torch.tensor(np.array(transition_dict['next_states']),dtype=torch.float).to(self.device)dones = torch.tensor(transition_dict['dones'],dtype=torch.float).view(-1, 1).to(self.device)td_target = rewards + self.gamma * self.critic(next_states) * (1 -dones)td_delta = td_target - self.critic(states)advantage = rl_utils.compute_advantage(self.gamma, self.lmbda,td_delta.cpu()).to(self.device)old_log_probs = torch.log(self.actor(states).gather(1,actions)).detach()for _ in range(self.epochs):log_probs = torch.log(self.actor(states).gather(1, actions))ratio = torch.exp(log_probs - old_log_probs)surr1 = ratio * advantagesurr2 = torch.clamp(ratio, 1 - self.eps,1 + self.eps) * advantage # 截断actor_loss = torch.mean(-torch.min(surr1, surr2)) # PPO损失函数critic_loss = torch.mean(F.mse_loss(self.critic(states), td_target.detach()))self.actor_optimizer.zero_grad()self.critic_optimizer.zero_grad()actor_loss.backward()critic_loss.backward()self.actor_optimizer.step()self.critic_optimizer.step()actor_lr = 1e-3
critic_lr = 1e-2
num_episodes = 250
hidden_dim = 128
gamma = 0.98
lmbda = 0.95
epochs = 10
eps = 0.2
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")env_name = 'CartPole-v1'
env = gym.make(env_name)
env.reset(seed=0)
torch.manual_seed(0)
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.n
ppo_agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,epochs, eps, gamma, device)return_list = rl_utils.train_on_policy_agent(env, ppo_agent, num_episodes)def sample_expert_data(n_episode):states = []actions = []for episode in range(n_episode):state = env.reset()[0]done = Falsewhile not done and len(states) < 10000:action = ppo_agent.take_action(state)states.append(state)actions.append(action)next_state, reward, done, _, __ = env.step(action)state = next_statereturn np.array(states), np.array(actions)env.reset(seed=0)
torch.manual_seed(0)
random.seed(0)
n_episode = 1
expert_s, expert_a = sample_expert_data(n_episode)n_samples = 30 # 采样30个数据
random_index = random.sample(range(expert_s.shape[0]), n_samples)
expert_s = expert_s[random_index]
expert_a = expert_a[random_index]class BehaviorClone:def __init__(self, state_dim, hidden_dim, action_dim, lr):self.policy = PolicyNet(state_dim, hidden_dim, action_dim).to(device)self.optimizer = torch.optim.Adam(self.policy.parameters(), lr=lr)def learn(self, states, actions):states = torch.tensor(states, dtype=torch.float).to(device)actions = torch.tensor(actions).view(-1, 1).to(device)log_probs = torch.log(self.policy(states).gather(1, actions))bc_loss = torch.mean(-log_probs) # 最大似然估计self.optimizer.zero_grad()bc_loss.backward()self.optimizer.step()def take_action(self, state):state = torch.tensor(np.array([state]), dtype=torch.float).to(device)probs = self.policy(state)action_dist = torch.distributions.Categorical(probs)action = action_dist.sample()return action.item()def test_agent(agent, env, n_episode):return_list = []for episode in range(n_episode):episode_return = 0state = env.reset()[0]done = Falsewhile not done:action = agent.take_action(state)next_state, reward, done, _, __ = env.step(action)state = next_stateepisode_return += rewardreturn_list.append(episode_return)return np.mean(return_list)env.reset(seed=0)
torch.manual_seed(0)
np.random.seed(0)lr = 1e-3
bc_agent = BehaviorClone(state_dim, hidden_dim, action_dim, lr)
n_iterations = 1000
batch_size = 64
test_returns = []with tqdm(total=n_iterations, desc="进度条") as pbar:for i in range(n_iterations):sample_indices = np.random.randint(low=0,high=expert_s.shape[0],size=batch_size)bc_agent.learn(expert_s[sample_indices], expert_a[sample_indices])current_return = test_agent(bc_agent, env, 5)test_returns.append(current_return)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(test_returns[-10:])})pbar.update(1)iteration_list = list(range(len(test_returns)))
plt.plot(iteration_list, test_returns)
plt.xlabel('Iterations')
plt.ylabel('Returns')
plt.title('BC on {}'.format(env_name))
plt.show()class Discriminator(nn.Module):def __init__(self, state_dim, hidden_dim, action_dim):super(Discriminator, self).__init__()self.fc1 = torch.nn.Linear(state_dim + action_dim, hidden_dim)self.fc2 = torch.nn.Linear(hidden_dim, 1)def forward(self, x, a):cat = torch.cat([x, a], dim=1)x = F.relu(self.fc1(cat))return torch.sigmoid(self.fc2(x))class GAIL:def __init__(self, agent, state_dim, action_dim, hidden_dim, lr_d):self.discriminator = Discriminator(state_dim, hidden_dim,action_dim).to(device)self.discriminator_optimizer = torch.optim.Adam(self.discriminator.parameters(), lr=lr_d)self.agent = agentdef learn(self, expert_s, expert_a, agent_s, agent_a, next_s, dones):expert_states = torch.tensor(expert_s, dtype=torch.float).to(device)expert_actions = torch.tensor(expert_a).to(device)agent_states = torch.tensor(np.array(agent_s), dtype=torch.float).to(device)agent_actions = torch.tensor(agent_a).to(device)expert_actions = F.one_hot(expert_actions, num_classes=2).float()agent_actions = F.one_hot(agent_actions, num_classes=2).float()expert_prob = self.discriminator(expert_states, expert_actions)agent_prob = self.discriminator(agent_states, agent_actions)discriminator_loss = nn.BCELoss()(agent_prob, torch.ones_like(agent_prob)) + nn.BCELoss()(expert_prob, torch.zeros_like(expert_prob))self.discriminator_optimizer.zero_grad()discriminator_loss.backward()self.discriminator_optimizer.step()rewards = -torch.log(agent_prob).detach().cpu().numpy()transition_dict = {'states': agent_s,'actions': agent_a,'rewards': rewards,'next_states': next_s,'dones': dones}self.agent.update(transition_dict)env.reset(seed=0)
torch.manual_seed(0)
lr_d = 1e-3
agent = PPO(state_dim, hidden_dim, action_dim, actor_lr, critic_lr, lmbda,epochs, eps, gamma, device)
gail = GAIL(agent, state_dim, action_dim, hidden_dim, lr_d)
n_episode = 500
return_list = []with tqdm(total=n_episode, desc="进度条") as pbar:for i in range(n_episode):episode_return = 0state = env.reset()[0]done = Falsestate_list = []action_list = []next_state_list = []done_list = []while not done and len(state_list) < 10000:action = agent.take_action(state)next_state, reward, done, _, __ = env.step(action)state_list.append(state)action_list.append(action)next_state_list.append(next_state)done_list.append(done)state = next_stateepisode_return += rewardreturn_list.append(episode_return)gail.learn(expert_s, expert_a, state_list, action_list,next_state_list, done_list)if (i + 1) % 10 == 0:pbar.set_postfix({'return': '%.3f' % np.mean(return_list[-10:])})pbar.update(1)iteration_list = list(range(len(return_list)))
plt.plot(iteration_list, return_list)
plt.xlabel('Episodes')
plt.ylabel('Returns')
plt.title('GAIL on {}'.format(env_name))
plt.show()
rl_utils.py
参考 动手学强化学习 第 14 章 SAC 算法 训练代码-CSDN博客
相关文章:
动手学强化学习 第 15 章 模仿学习 训练代码
基于 https://github.com/boyu-ai/Hands-on-RL/blob/main/%E7%AC%AC15%E7%AB%A0-%E6%A8%A1%E4%BB%BF%E5%AD%A6%E4%B9%A0.ipynb 理论 模仿学习 修改了警告和报错 运行环境 Debian GNU/Linux 12 Python 3.9.19 torch 2.0.1 gym 0.26.2 运行代码 #!/usr/bin/env pythonimpor…...

第一阶段面试问题(前半部分)
1. 进程和线程的概念、区别以及什么时候用线程、什么时候用进程? (1)线程 线程是CPU任务调度的最小单元、是一个轻量级的进程 (2)进程 进程是操作系统资源分配的最小单元 进程是一个程序动态执行的过程,包…...

《数学教学通讯》是一本怎样的刊物?投稿难吗?
《数学教学通讯》是一本怎样的刊物?投稿难吗? 《数学教学通讯》是一本具有较高学术价值的教育类刊物。它创刊于 1979 年,由西南大学主管,西南大学数学与统计学院、重庆市数学学会主办,出版周期为旬刊。该刊物在国内外…...
<机器学习> K-means
K-means定义 K-means 是一种广泛使用的聚类算法,旨在将数据集中的点分组为 K 个簇(cluster),使得每个簇内的点尽可能相似,而不同簇的点尽可能不同。K-means 算法通过迭代的方式,逐步优化簇的分配和簇的中心…...

我们如何优化 Elasticsearch Serverless 中的刷新成本
作者:来自 Elastic Francisco Fernndez Castao, Henning Andersen 最近,我们推出了 Elastic Cloud Serverless 产品,旨在提供在云中运行搜索工作负载的无缝体验。为了推出该产品,我们重新设计了 Elasticsearch,将存储与…...
MySQL半同步复制
1.MySQL主从复制模式 1.1异步复制 异步复制为 MySQL 默认的复制模式,指主库写 binlog、从库 I/O 线程读 binlog 并写入 relaylog、从库 SQL 线程重放事务这三步之间是异步的。 异步复制的主库不需要关心备库的状态,主库不保证事务被传输到从库…...
[一本通提高数位动态规划]数字游戏:取模数题解
[一本通提高数位动态规划]数字游戏:取模数题解 1前言2问题3状态的设置4数位dp-part1预处理5数位dp-part2利用状态求解6代码7后记 1前言 本文为数字游戏:取模数的题解 需要读者对数位dp有基础的了解,建议先阅读 论数位dp–胎教级教学 B3883 […...
[Day 39] 區塊鏈與人工智能的聯動應用:理論、技術與實踐
區塊鏈的安全性分析 區塊鏈技術已經成為現代數字經濟的一個重要組成部分,提供了去中心化、透明和不可篡改的數據存儲與交易系統。然而,隨著區塊鏈技術的廣泛應用,其安全性問題也日益受到關注。本篇文章將詳細探討區塊鏈技術的安全性…...

OpenStack入门体验
一、云计算概述 1.1什么是云计算 云计算(cloud computing)是一种基于网络的超级计算模式,基于用户的不同需求,提供所需的资源,包括计算资源、存储资源、网络资源等。云计算服务运行在若干台高性能物理服务器之上,提供每秒 10万亿次的运算能力…...

预测未来 | MATLAB实现RF随机森林多变量时间序列预测未来-预测新数据
预测未来 | MATLAB实现RF随机森林多变量时间序列预测未来-预测新数据 预测效果 基本介绍 随机森林属于 集成学习 中的 Bagging(Bootstrap AGgregation 的简称) 方法。如果用图来表示他们之间的关系如下: 随机森林是由很多决策树构成的,不同决策树之间没有关联。当我们进行…...

iOS 系统提供的媒体资源选择器(UIImagePickerController)
简介 图片或者视频的选择功能几乎是每个APP必不可少的,UIImagePickerController 是 iOS 系统提供的一个方便的媒体选择器,允许用户从照片库中选择图片或视频,或者使用相机拍摄新照片和视频。 它的页面简单易用,代码稳定可靠&…...

电脑如何扩展硬盘分区?告别空间不足困扰
在数字化时代,电脑硬盘的存储空间显得愈发重要。随着个人文件、应用程序和系统更新的不断累积,原有的硬盘分区可能很快就会被填满。为了解决这个问题,扩展硬盘分区成为了一个非常实用的方法。那么,电脑如何扩展硬盘分区呢…...

论文阅读:Mammoth: Building math generalist models through hybrid instruction tuning
Mammoth: Building math generalist models through hybrid instruction tuning https://arxiv.org/pdf/2309.05653 MAmmoTH:通过混合指令调优构建数学通才模型 摘要 我们介绍了MAmmoTH,一系列特别为通用数学问题解决而设计的开源大型语言模型&#…...
什么样的双筒式防爆器把煤矿吸引?
什么样的双筒式防爆器把煤矿吸引?要有好的服务和态度,要用心去聆听客户的需求,去解决客户的疑虑,用诚信去赢得客户的信任。 150产品的技术特点 双筒式防爆器采用双罐结构,其水封水位观测直观、能够快速有效排污、操作…...

如何保证冰河AL0 400G 100W 的稳定运行?
要保证冰河 AL0 400G 100w 的稳定运行,可以考虑以下几点: 1. 适宜的工作环境:确保设备放置在通风良好、温度适宜的环境中。良好的散热条件有助于防止设备过热,因为过热可能会导致性能下降或故障。该设备采用纯铝合金外壳…...

剪画小程序:巴黎奥运会,从画面到声音!
在巴黎奥运会的赛场上,每一个瞬间都伴随着独特的声音。那是观众的欢呼,是运动员冲刺的呐喊,是国歌奏响的激昂旋律。 如今,通过剪画音频提取,我们能够将这些珍贵的声音从精彩的画面中分离出来,单独珍藏。 想…...

【leetcode详解】心算挑战: 一题搞懂涉及奇偶数问题的 “万金油” 思路(思路详解)
前记: 做了几日的leetcode每日一题,几乎全是十分钟结束战斗的【中等】题,今日杀出来个【简单】题,反倒开始难以想出很清楚的解题思路,反复调试修改才将题目逐渐考虑全面,看到了原本思路的漏洞,…...

【资料集】数据库设计说明书(Word原件提供)
2 数据库环境说明 3 数据库的命名规则 4 逻辑设计 5 物理设计 5.1 表汇总 5.2 表结构设计 6 数据规划 6.1 表空间设计 6.2 数据文件设计 6.3 表、索引分区设计 6.4 优化方法 7 安全性设计 7.1 防止用户直接操作数据库 7.2 用户帐号加密处理 7.3 角色与权限控制 8 数据库管理与维…...
MySQL 常用查询语句精粹
引言 MySQL 是一种广泛使用的开源关系型数据库管理系统,其强大的查询语言为用户提供了丰富的数据处理能力。掌握 MySQL 的常用查询语句对于数据库管理和数据分析至关重要。本文将介绍一些 MySQL 中的常用查询语句,并提供实际的示例。 基础查询 1. 选择…...
hive的内部表(MANAGED_TABLE)和外部表(EXTERNAL_TABLE)的区别
1.hive的表类型分为外部表和内部表 内部表和外部表的主要区别在于数据的存储方式。 外部表:外部表的存储在hdfs中,是我们指定的文件目录,当我们删除数据或者删除分区的时候不会将元数据删除,数据还会在hdfs目录中,我们…...
【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15
缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下: struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...
pam_env.so模块配置解析
在PAM(Pluggable Authentication Modules)配置中, /etc/pam.d/su 文件相关配置含义如下: 配置解析 auth required pam_env.so1. 字段分解 字段值说明模块类型auth认证类模块,负责验证用户身份&am…...

令牌桶 滑动窗口->限流 分布式信号量->限并发的原理 lua脚本分析介绍
文章目录 前言限流限制并发的实际理解限流令牌桶代码实现结果分析令牌桶lua的模拟实现原理总结: 滑动窗口代码实现结果分析lua脚本原理解析 限并发分布式信号量代码实现结果分析lua脚本实现原理 双注解去实现限流 并发结果分析: 实际业务去理解体会统一注…...
C++中string流知识详解和示例
一、概览与类体系 C 提供三种基于内存字符串的流,定义在 <sstream> 中: std::istringstream:输入流,从已有字符串中读取并解析。std::ostringstream:输出流,向内部缓冲区写入内容,最终取…...

JUC笔记(上)-复习 涉及死锁 volatile synchronized CAS 原子操作
一、上下文切换 即使单核CPU也可以进行多线程执行代码,CPU会给每个线程分配CPU时间片来实现这个机制。时间片非常短,所以CPU会不断地切换线程执行,从而让我们感觉多个线程是同时执行的。时间片一般是十几毫秒(ms)。通过时间片分配算法执行。…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
Java 二维码
Java 二维码 **技术:**谷歌 ZXing 实现 首先添加依赖 <!-- 二维码依赖 --><dependency><groupId>com.google.zxing</groupId><artifactId>core</artifactId><version>3.5.1</version></dependency><de…...
管理学院权限管理系统开发总结
文章目录 🎓 管理学院权限管理系统开发总结 - 现代化Web应用实践之路📝 项目概述🏗️ 技术架构设计后端技术栈前端技术栈 💡 核心功能特性1. 用户管理模块2. 权限管理系统3. 统计报表功能4. 用户体验优化 🗄️ 数据库设…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合
在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...

Linux 内存管理实战精讲:核心原理与面试常考点全解析
Linux 内存管理实战精讲:核心原理与面试常考点全解析 Linux 内核内存管理是系统设计中最复杂但也最核心的模块之一。它不仅支撑着虚拟内存机制、物理内存分配、进程隔离与资源复用,还直接决定系统运行的性能与稳定性。无论你是嵌入式开发者、内核调试工…...