当前位置: 首页 > news >正文

对抗式生成模仿学习(GAIL)

目录

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

1.1.2 损失函数

1.1.2.1 固定G,求解令损失函数最大的D

1.1.2.2 固定D,求解令损失函数最小的G

1.2 对抗式生成模仿学习特点

2 对抗式生成模仿学习(GAIL)详细说明

3 参考文献

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

在GAN生成对抗网络中,包含两个模型,一个生成模型,一个判别模型。

  • 生成模型:负责生成看起来真实自然,和原始数据相似的实例。
  • 判别模型:负责判断给出的实例是真实的还是人为伪造的。

生成模型努力去欺骗判别模型,判别模型努力不被欺骗,这样两种模型交替优化训练,都得到了提升。

对于辨别器,如果得到的是生成图片辨别器应该输出0,如果是真实的图片应该输出 1,得到误差梯度反向传播来更新参数。对于生成器,首先由生成器生成一张图片,然后输入给判别器判别并的到相应的误差梯度,然后反向传播这些图片梯度成为组成生成器的权重。直观上来说就是:辨别器不得不告诉生成器如何调整从而使它生成的图片变得更加真实。

1.1.2 损失函数

GAN模型的目标函数:

其中,参考GAN的架构图,字母 V是原始GAN论文中指定用来表示该交叉熵的字母,x 表示任意真实数据,z 表示与真实数据相同结构的任意随机数据,G(z)表示在生成器中基于 z 生成的假数据,而D(x)表示判别器在真实数据 x上判断出的结果,D(G(z))表示判别器在假数据 G(z)上判断出的结果,其中 D(x) 与D(G(z))都是样本为“真”的概率,即标签为1的概率。

上式,主要意思是先固定生成器G,从判别器D的角度令损失最大化,紧接着固定D,从生成器G的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。其中第一个期望\mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right]是所有x都是真实数据时(log(D(x)))的期望,第二个期望\mathbb{E}_{z \sim p(z)} \left[ \log (1 - D(G(z))) \right]是所有数据都是生成数据时log(1-D(G(z)))的期望。可以看出,在求解最优解的过程中存在两个过程:

  • 固定G,求解令损失函数最大的D
  • 固定D,求解令损失函数最小的G

判别网络是一个2分类,目标是分清真实数据和伪造数据,也就是希望D(x) 趋近于1,D(G(z))趋近于0,这也就体现了对抗的思想。G网络的loss是log(1-D(G(z))),D的loss是-(log(D(x)))+log(1-D(G(z)))。

1.1.2.1 固定G,求解令损失函数最大的D

判别器D的输入x有两部分:一部分是真实数据,设其分布为P_{\text{data}}(x);另一部分是生成器生成的数据,参考架构图,生成器接收的数据z服从分布P(z),A输入z经过生成器的计算生成的数据分布设为P_{G}(x)

这两部分这两部分都是判别器D的输入,不同的是,G的输出来自分布P_{G}(x),而真实数据来自分布P_{\text{data}}(x),经过一系列推导后的结果:

可以看出,固定G,将最优的D带入后,此时V(G,D*),实际上是在度量P_{\text{data}}(x)P_{G}(x)之间的JS散度,同KL散度一样,他们之间的分布差异越大,JS散度值也越大。换句话说:保持G不变,最大化V(G,D)就等价于计算JS散度。对于判别器来说,尽可能找出生成器生成的数据与真实数据分布之间的差异,这个差异就是JS散度。

1.1.2.2 固定D,求解令损失函数最小的G

对于生成器来说,让生成器生成的数据分布接近真实数据分布。现在第一步已经求出了最优解的D*,代入损失函数:

在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

1.2 对抗式生成模仿学习特点

逆强化学习(Inverse Reinforcement Learning, IRL)作为一种典型的模仿学习方法,顾名思义,逆强化学习的学习过程与正常的强化学习利用奖励函数学习策略相反,不利用现有的奖励函数,而是试图学出一个奖励函数,并以之指导基于奖励函数的强化学习过程。IRL可以归结为解决从观察到的最优行为中提取奖励函数( Reward Function)的问题,这些最优行为也可以表示为专家策略 。基于IRL的方法交替地在两个过程中交替:一个阶段是使用示范数据来推断一个隐藏的奖励(Reward)或代价( Cost)函数,另一个阶段是使用强化学习基于推断的奖励函数来学习一个模仿策略。IRL的基本准则是:IRL选择奖励函数来优化策略,并且使得任何不同于\Pi _{E}的动作决策尽可能产生更大损失。

对抗式生成模仿学习(Generative Adversarial Imitation Learning,GAIL)是逆强化学习的一种重要实现方法之一。逆强化学习旨在从专家示范的行为中推断环境的奖励函数或者价值函数,而GAIL是逆强化学习的一种实现方式,它利用了生成对抗网络(GAN)的概念来进行模仿学习。

GAIL的关键点在于:

1生成对抗网络: GAIL使用生成对抗网络的框架,其中包括生成器和判别器。

2生成器与判别器: 生成器尝试生成与专家示范行为相似的状态-动作对,而判别器则尝试区分专家行为和生成器生成的行为。

3对抗优化: GAIL使用对抗训练的思想,通过生成器和判别器的对抗优化来使得生成器的输出逼近专家的行为。

GAIL的工作方式使得它在逆强化学习中发挥着重要作用,因为它提供了一种有效的方式来从专家示范中学习环境的奖励结构,以指导智能体的学习行为。通过对抗式生成模仿学习,智能体可以学习并模仿专家的行为,而无需显式地使用环境的奖励信号。

因此,GAIL作为逆强化学习的一种方法,为从专家示范中学习环境的奖励函数或者价值函数提供了一种有效的框架和方法。

2 对抗式生成模仿学习(GAIL)详细说明

 

生成式对抗模仿学习的整体优化流程如图所示。通过 GAIL 方法,策略生成器通过生成类似专家示教样本的探索样本,泛化示教样本的概率分布, 逼近专家示范行为数据,进而实现模仿专家技能的目的。该过程直接优化采样样本的概率分布,计算代价较小且算法通用性更强,实际模仿效果也更好。 

伪代码:

# 初始化策略 π、判别器 D、专家示范数据 D_expert、策略缓冲区 D_policy函数 GAIL_Training():初始化策略 π 的参数初始化判别器 D 的参数循环 直到收敛 或 达到最大迭代次数:# 使用当前策略 π 生成轨迹并存储在策略缓冲区 D_policy 中生成 trajectories 使用 π 并存储在 D_policy 中# 判别器训练循环 discriminator_updates 次数:# 从策略缓冲区 D_policy 中采样数据采样 (s_policy, a_policy) 从 D_policy 中# 从专家示范数据 D_expert 中采样数据采样 (s_expert, a_expert) 从 D_expert 中# 更新判别器 D计算 L_D = -[log(D(s_expert, a_expert)) + log(1 - D(s_policy, a_policy))]使用梯度下降法更新判别器参数以最小化 L_D# 策略更新采样 (s, a, ...) 从 D_policy 中计算伪奖励 r = -log(1 - D(s, a))# 使用伪奖励 r 更新策略 π计算 L_π 使用 PPO 或 其他强化学习方法使用梯度下降法更新策略 π 的参数以最大化 L_π

能够表征GAIL流程的主程序如下: 

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adamfrom .ppo import PPO
from gail_airl_ppo.network import GAILDiscrimclass GAIL(PPO):def __init__(self, buffer_exp, state_shape, action_shape, device, seed,gamma=0.995, rollout_length=50000, mix_buffer=1,batch_size=64, lr_actor=3e-4, lr_critic=3e-4, lr_disc=3e-4,units_actor=(64, 64), units_critic=(64, 64),units_disc=(100, 100), epoch_ppo=50, epoch_disc=10,clip_eps=0.2, lambd=0.97, coef_ent=0.0, max_grad_norm=10.0):super().__init__(state_shape, action_shape, device, seed, gamma, rollout_length,mix_buffer, lr_actor, lr_critic, units_actor, units_critic,epoch_ppo, clip_eps, lambd, coef_ent, max_grad_norm)# Expert's buffer.self.buffer_exp = buffer_exp# Discriminator.self.disc = GAILDiscrim(state_shape=state_shape,action_shape=action_shape,hidden_units=units_disc,hidden_activation=nn.Tanh()).to(device)self.learning_steps_disc = 0self.optim_disc = Adam(self.disc.parameters(), lr=lr_disc)self.batch_size = batch_sizeself.epoch_disc = epoch_discdef update(self, writer):self.learning_steps += 1for _ in range(self.epoch_disc):self.learning_steps_disc += 1# Samples from current policy's trajectories.states, actions = self.buffer.sample(self.batch_size)[:2]# Samples from expert's demonstrations.states_exp, actions_exp = \self.buffer_exp.sample(self.batch_size)[:2]# Update discriminator.self.update_disc(states, actions, states_exp, actions_exp, writer)# We don't use reward signals here,states, actions, _, dones, log_pis, next_states = self.buffer.get()# Calculate rewards.rewards = self.disc.calculate_reward(states, actions)# Update PPO using estimated rewards.self.update_ppo(states, actions, rewards, dones, log_pis, next_states, writer)def update_disc(self, states, actions, states_exp, actions_exp, writer):# Output of discriminator is (-inf, inf), not [0, 1].logits_pi = self.disc(states, actions)logits_exp = self.disc(states_exp, actions_exp)# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].loss_pi = -F.logsigmoid(-logits_pi).mean()loss_exp = -F.logsigmoid(logits_exp).mean()loss_disc = loss_pi + loss_expself.optim_disc.zero_grad()loss_disc.backward()self.optim_disc.step()if self.learning_steps_disc % self.epoch_disc == 0:writer.add_scalar('loss/disc', loss_disc.item(), self.learning_steps)# Discriminator's accuracies.with torch.no_grad():acc_pi = (logits_pi < 0).float().mean().item()acc_exp = (logits_exp > 0).float().mean().item()writer.add_scalar('stats/acc_pi', acc_pi, self.learning_steps)writer.add_scalar('stats/acc_exp', acc_exp, self.learning_steps)

3 参考文献

https://zhuanlan.zhihu.com/p/628915533

【强化学习】GAIL_gail算法-CSDN博客

代码:https://github.com/toshikwa/gail-airl-ppo.pytorch.git

相关文章:

对抗式生成模仿学习(GAIL)

目录 1 预先基础知识 1.1 对抗生成网络&#xff08;GAN&#xff09; 1.1.1 基本概念 1.1.2 损失函数 1.1.2.1 固定G&#xff0c;求解令损失函数最大的D 1.1.2.2 固定D&#xff0c;求解令损失函数最小的G 1.2 对抗式生成模仿学习特点 2 对抗式生成模仿学习&#xff08;…...

信息系统项目管理师 | 新一代信息技术

关注WX&#xff1a;CodingTechWork 物联网 定义 The Internet of Things是指通过信息传感设备&#xff0c;按约定的协议&#xff0c;将任何物品与互联网连接&#xff0c;进行信息交互和通信&#xff0c;以实现智能化识别。定位、跟踪、监控和管理的一种网络。物联网主要解决…...

安全宣传咨询日活动向媒体投稿记住这个投稿好方法

在信息爆炸的时代,作为单位的信息宣传员,我肩负着将每一次重要活动,特别是像“安全宣传咨询日”这样的公益活动,有效传达给公众的重任。这份工作看似简单,实则充满了挑战,尤其是在我初涉此领域时,那段曲折而又难忘的投稿经历,至今记忆犹新。 初探投稿之海,遭遇重重困难 起初,我…...

第7章:系统架构设计基础知识-软件架构风格

由于历史原因&#xff0c;研究者和工程人员对Sofiware Architecture(简称SA)的翻译不尽相同&#xff0c;其软件的“体系结构”和“架构”具有相同的含义。 系统架构其实就是系统的结构&#xff0c;系统架构设计其实就是要给相关利益方说清楚通过什么样的结构来解决需求中功能和…...

自制调色小工具给图片加滤镜,修改图片红、绿、蓝通道及亮度,修改图片颜色

上篇&#xff1a; 上篇我们给地图添加了锐化、模糊等滤镜&#xff0c;这篇来写一个小工具给图片调色。 调色比锐化等滤镜要简单许多&#xff0c;直接拿到像素值修改即可。不需要用到卷积核。。。(*^▽^*) 核心原理就是图像结构&#xff0c;使用context.getImageData获取图像像…...

【Redis】java客户端(SpringData和jedis)

https://www.oz6.cn/articles/58 https://www.bilibili.com/video/BV1cr4y1671t/?p16 redis官网客户端介绍&#xff1a;https://redis.io/docs/latest/develop/connect/clients/ jedis maven引入依赖 <dependencies><!--引入Jedis依赖--><dependency><…...

大数据安全经典面试题及回答(上)

目录 一、大数据安全的主要挑战及应对策略 二、大数据安全中的“五个V”及其影响 三、在Hadoop集群中实施数据加密的步骤和注意事项 四、在大数据环境中实施访问控制和身份认证 五、大数据环境中数据备份和恢复的策略 六、大数据处理过程中保护用户隐私的策略 七、大数据…...

vi/vim使用命令

你是否在编辑文件时以为键盘坏了&#xff0c;为什么不能删除呢&#xff0c;为什么不能敲代码呢&#xff0c;当你初识vi&#xff0c;会觉得这个东西设计很难用&#xff0c;这篇教程带你熟练得用上这款经典的工具&#xff0c;当你熟练了这款工具就会真正体会到高效率打码 Vi 是在…...

webpack打包gz文件,nginx开启gzip压缩

wepback配置 webpack4配合"compression-webpack-plugin": "^6.1.2"打包压缩gz chain.plugin("compression").use(new CompressionPlugin({test: /\.js$|\.html$|\.css$/,threshold: 10240, // 超过10KB的压缩deleteOriginalAssets: false,// 保…...

微服务开发与实战Day11 - 微服务面试篇

一、分布式事务 1. CAP定理 1998年&#xff0c;加州大学的计算机科学及Eric Brewer提出&#xff0c;分布式系统有三个指标&#xff1a; Consistency&#xff08;一致性&#xff09;Availability&#xff08;可用性&#xff09;Partition tolerance&#xff08;分区容错性&am…...

基于Spring Boot+VUE职称评审管理系统

1管理员功能模块 管理员登录&#xff0c;通过填写注册时输入的用户名、密码、角色进行登录&#xff0c;如图1所示。 图1管理员登录界面图 管理员登录进入职称评审管理系统可以查看首页、个人中心、用户管理、评审员管理、省份管理、评审条件管理、职称申请管理、结果公布管理、…...

MySQL 基本语法讲解及示例(上)

第一节&#xff1a;MySQL的基本操作 1. 创建数据库 在 MySQL 中&#xff0c;创建数据库的步骤如下&#xff1a; 命令行操作 打开 MySQL 命令行客户端或连接到 MySQL 服务器。 输入以下命令创建一个数据库&#xff1a; CREATE DATABASE database_name;例如&#xff0c;创建一…...

6.18作业

完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密码不匹配&#xff…...

Excel文件转换为HTML文件

文章目录 前言安装python包python代码 前言 将一个Excel文件转换为HTML文件 安装python包 使用pandas和openpyxl库来实现这个功能 pip install pandas openpyxlpython代码 1、首先使用tkinter库中的filedialog模块弹出一个对话框来选择要转换的Excel文件 2、使用pandas库…...

MySQL数据库入门

1、MySQL概述 MySQL官方网站 https://www.mysql.com/downloads/ MySQL被Oracle公司收购了&#xff0c;作者又重新编写了一个开源的数据库管理系统&#xff0c;Mariadb 2、MySQL产品&版本 2、数据库在网站架构中的角色 LAMP LNMP网站架构 3、安装MySQL-基于yum 查…...

vue element-ui 下拉框 以及 input 限制输入,小数点后保留两位 界面设计案例 和 例子:支持mp4和m3u8视频播放

vue input 限制输入&#xff0c;小数点后保留两位 以及 图片垂直居中显示 和 分享 git 小技巧-CSDN博客文章浏览阅读430次&#xff0c;点赞5次&#xff0c;收藏4次。error:Your local changes to the following files would be overwritten by merge:_error: your local change…...

Python基础用法 之 运算符

1.算数运算符 符号作用说明举例加与“”相同 - 减与“-”相同*乘 与“ ”相同 9*218/除 与“ ”相同 9/24.5 、6/32.0//求商&#xff08;整数部分&#xff09; 两个数据做除法的 商 9//24%取余&#xff08;余数部分&#xff09; 是两个数据做除法的 余数 9%21**幂、次方2**…...

事务所管理系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;管理员管理&#xff0c;客户管理&#xff0c;评论管理&#xff0c;基础数据管理&#xff0c;公告信息管理 客户账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;律师管理&#xff0…...

airsim安装

继续进行&#xff0c;遇到下面的报错 Cannot find path HKEY_CLASSES_ROOT\Unreal.ProjectFile\shell\rungenproj 在Git地址的issue中&#xff0c;搜到下面的解决方法&#xff0c;根因是安装Unreal Engine之后未重启电脑&#xff0c;文件未关联导致&#xff0c;或者出现重定向…...

打造精致UI界面:字体设计的妙招

字体设计是UI设计的关键模块之一。字体设计是否有效可能直接实现或破坏整个UI界面。那么&#xff0c;界面设计的字体设计有哪些规范呢&#xff1f;如何设计细节字体&#xff1f;本文将解释字体设计规范的可读性、可读性和可用性&#xff0c;并介绍UI界面中的字体设计技巧。 如…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架&#xff0c;它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用&#xff0c;和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)

说明&#xff1a; 想象一下&#xff0c;你正在用eNSP搭建一个虚拟的网络世界&#xff0c;里面有虚拟的路由器、交换机、电脑&#xff08;PC&#xff09;等等。这些设备都在你的电脑里面“运行”&#xff0c;它们之间可以互相通信&#xff0c;就像一个封闭的小王国。 但是&#…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

中南大学无人机智能体的全面评估!BEDI:用于评估无人机上具身智能体的综合性基准测试

作者&#xff1a;Mingning Guo, Mengwei Wu, Jiarun He, Shaoxian Li, Haifeng Li, Chao Tao单位&#xff1a;中南大学地球科学与信息物理学院论文标题&#xff1a;BEDI: A Comprehensive Benchmark for Evaluating Embodied Agents on UAVs论文链接&#xff1a;https://arxiv.…...

ElasticSearch搜索引擎之倒排索引及其底层算法

文章目录 一、搜索引擎1、什么是搜索引擎?2、搜索引擎的分类3、常用的搜索引擎4、搜索引擎的特点二、倒排索引1、简介2、为什么倒排索引不用B+树1.创建时间长,文件大。2.其次,树深,IO次数可怕。3.索引可能会失效。4.精准度差。三. 倒排索引四、算法1、Term Index的算法2、 …...

QT: `long long` 类型转换为 `QString` 2025.6.5

在 Qt 中&#xff0c;将 long long 类型转换为 QString 可以通过以下两种常用方法实现&#xff1a; 方法 1&#xff1a;使用 QString::number() 直接调用 QString 的静态方法 number()&#xff0c;将数值转换为字符串&#xff1a; long long value 1234567890123456789LL; …...

基于Java Swing的电子通讯录设计与实现:附系统托盘功能代码详解

JAVASQL电子通讯录带系统托盘 一、系统概述 本电子通讯录系统采用Java Swing开发桌面应用&#xff0c;结合SQLite数据库实现联系人管理功能&#xff0c;并集成系统托盘功能提升用户体验。系统支持联系人的增删改查、分组管理、搜索过滤等功能&#xff0c;同时可以最小化到系统…...

基于TurtleBot3在Gazebo地图实现机器人远程控制

1. TurtleBot3环境配置 # 下载TurtleBot3核心包 mkdir -p ~/catkin_ws/src cd ~/catkin_ws/src git clone -b noetic-devel https://github.com/ROBOTIS-GIT/turtlebot3.git git clone -b noetic https://github.com/ROBOTIS-GIT/turtlebot3_msgs.git git clone -b noetic-dev…...

Java毕业设计:WML信息查询与后端信息发布系统开发

JAVAWML信息查询与后端信息发布系统实现 一、系统概述 本系统基于Java和WML(无线标记语言)技术开发&#xff0c;实现了移动设备上的信息查询与后端信息发布功能。系统采用B/S架构&#xff0c;服务器端使用Java Servlet处理请求&#xff0c;数据库采用MySQL存储信息&#xff0…...