当前位置: 首页 > news >正文

对抗式生成模仿学习(GAIL)

目录

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

1.1.2 损失函数

1.1.2.1 固定G,求解令损失函数最大的D

1.1.2.2 固定D,求解令损失函数最小的G

1.2 对抗式生成模仿学习特点

2 对抗式生成模仿学习(GAIL)详细说明

3 参考文献

1 预先基础知识 

1.1 对抗生成网络(GAN)

1.1.1 基本概念

在GAN生成对抗网络中,包含两个模型,一个生成模型,一个判别模型。

  • 生成模型:负责生成看起来真实自然,和原始数据相似的实例。
  • 判别模型:负责判断给出的实例是真实的还是人为伪造的。

生成模型努力去欺骗判别模型,判别模型努力不被欺骗,这样两种模型交替优化训练,都得到了提升。

对于辨别器,如果得到的是生成图片辨别器应该输出0,如果是真实的图片应该输出 1,得到误差梯度反向传播来更新参数。对于生成器,首先由生成器生成一张图片,然后输入给判别器判别并的到相应的误差梯度,然后反向传播这些图片梯度成为组成生成器的权重。直观上来说就是:辨别器不得不告诉生成器如何调整从而使它生成的图片变得更加真实。

1.1.2 损失函数

GAN模型的目标函数:

其中,参考GAN的架构图,字母 V是原始GAN论文中指定用来表示该交叉熵的字母,x 表示任意真实数据,z 表示与真实数据相同结构的任意随机数据,G(z)表示在生成器中基于 z 生成的假数据,而D(x)表示判别器在真实数据 x上判断出的结果,D(G(z))表示判别器在假数据 G(z)上判断出的结果,其中 D(x) 与D(G(z))都是样本为“真”的概率,即标签为1的概率。

上式,主要意思是先固定生成器G,从判别器D的角度令损失最大化,紧接着固定D,从生成器G的角度令损失最小化,即可让判别器和生成器在共享损失的情况下实现对抗。其中第一个期望\mathbb{E}_{x \sim p_{\text{data}}(x)} \left[ \log D(x) \right]是所有x都是真实数据时(log(D(x)))的期望,第二个期望\mathbb{E}_{z \sim p(z)} \left[ \log (1 - D(G(z))) \right]是所有数据都是生成数据时log(1-D(G(z)))的期望。可以看出,在求解最优解的过程中存在两个过程:

  • 固定G,求解令损失函数最大的D
  • 固定D,求解令损失函数最小的G

判别网络是一个2分类,目标是分清真实数据和伪造数据,也就是希望D(x) 趋近于1,D(G(z))趋近于0,这也就体现了对抗的思想。G网络的loss是log(1-D(G(z))),D的loss是-(log(D(x)))+log(1-D(G(z)))。

1.1.2.1 固定G,求解令损失函数最大的D

判别器D的输入x有两部分:一部分是真实数据,设其分布为P_{\text{data}}(x);另一部分是生成器生成的数据,参考架构图,生成器接收的数据z服从分布P(z),A输入z经过生成器的计算生成的数据分布设为P_{G}(x)

这两部分这两部分都是判别器D的输入,不同的是,G的输出来自分布P_{G}(x),而真实数据来自分布P_{\text{data}}(x),经过一系列推导后的结果:

可以看出,固定G,将最优的D带入后,此时V(G,D*),实际上是在度量P_{\text{data}}(x)P_{G}(x)之间的JS散度,同KL散度一样,他们之间的分布差异越大,JS散度值也越大。换句话说:保持G不变,最大化V(G,D)就等价于计算JS散度。对于判别器来说,尽可能找出生成器生成的数据与真实数据分布之间的差异,这个差异就是JS散度。

1.1.2.2 固定D,求解令损失函数最小的G

对于生成器来说,让生成器生成的数据分布接近真实数据分布。现在第一步已经求出了最优解的D*,代入损失函数:

在最小化JS散度,JS散度越小,分部之间的差异越小,正好印证了第二个原则。

1.2 对抗式生成模仿学习特点

逆强化学习(Inverse Reinforcement Learning, IRL)作为一种典型的模仿学习方法,顾名思义,逆强化学习的学习过程与正常的强化学习利用奖励函数学习策略相反,不利用现有的奖励函数,而是试图学出一个奖励函数,并以之指导基于奖励函数的强化学习过程。IRL可以归结为解决从观察到的最优行为中提取奖励函数( Reward Function)的问题,这些最优行为也可以表示为专家策略 。基于IRL的方法交替地在两个过程中交替:一个阶段是使用示范数据来推断一个隐藏的奖励(Reward)或代价( Cost)函数,另一个阶段是使用强化学习基于推断的奖励函数来学习一个模仿策略。IRL的基本准则是:IRL选择奖励函数来优化策略,并且使得任何不同于\Pi _{E}的动作决策尽可能产生更大损失。

对抗式生成模仿学习(Generative Adversarial Imitation Learning,GAIL)是逆强化学习的一种重要实现方法之一。逆强化学习旨在从专家示范的行为中推断环境的奖励函数或者价值函数,而GAIL是逆强化学习的一种实现方式,它利用了生成对抗网络(GAN)的概念来进行模仿学习。

GAIL的关键点在于:

1生成对抗网络: GAIL使用生成对抗网络的框架,其中包括生成器和判别器。

2生成器与判别器: 生成器尝试生成与专家示范行为相似的状态-动作对,而判别器则尝试区分专家行为和生成器生成的行为。

3对抗优化: GAIL使用对抗训练的思想,通过生成器和判别器的对抗优化来使得生成器的输出逼近专家的行为。

GAIL的工作方式使得它在逆强化学习中发挥着重要作用,因为它提供了一种有效的方式来从专家示范中学习环境的奖励结构,以指导智能体的学习行为。通过对抗式生成模仿学习,智能体可以学习并模仿专家的行为,而无需显式地使用环境的奖励信号。

因此,GAIL作为逆强化学习的一种方法,为从专家示范中学习环境的奖励函数或者价值函数提供了一种有效的框架和方法。

2 对抗式生成模仿学习(GAIL)详细说明

 

生成式对抗模仿学习的整体优化流程如图所示。通过 GAIL 方法,策略生成器通过生成类似专家示教样本的探索样本,泛化示教样本的概率分布, 逼近专家示范行为数据,进而实现模仿专家技能的目的。该过程直接优化采样样本的概率分布,计算代价较小且算法通用性更强,实际模仿效果也更好。 

伪代码:

# 初始化策略 π、判别器 D、专家示范数据 D_expert、策略缓冲区 D_policy函数 GAIL_Training():初始化策略 π 的参数初始化判别器 D 的参数循环 直到收敛 或 达到最大迭代次数:# 使用当前策略 π 生成轨迹并存储在策略缓冲区 D_policy 中生成 trajectories 使用 π 并存储在 D_policy 中# 判别器训练循环 discriminator_updates 次数:# 从策略缓冲区 D_policy 中采样数据采样 (s_policy, a_policy) 从 D_policy 中# 从专家示范数据 D_expert 中采样数据采样 (s_expert, a_expert) 从 D_expert 中# 更新判别器 D计算 L_D = -[log(D(s_expert, a_expert)) + log(1 - D(s_policy, a_policy))]使用梯度下降法更新判别器参数以最小化 L_D# 策略更新采样 (s, a, ...) 从 D_policy 中计算伪奖励 r = -log(1 - D(s, a))# 使用伪奖励 r 更新策略 π计算 L_π 使用 PPO 或 其他强化学习方法使用梯度下降法更新策略 π 的参数以最大化 L_π

能够表征GAIL流程的主程序如下: 

import torch
from torch import nn
import torch.nn.functional as F
from torch.optim import Adamfrom .ppo import PPO
from gail_airl_ppo.network import GAILDiscrimclass GAIL(PPO):def __init__(self, buffer_exp, state_shape, action_shape, device, seed,gamma=0.995, rollout_length=50000, mix_buffer=1,batch_size=64, lr_actor=3e-4, lr_critic=3e-4, lr_disc=3e-4,units_actor=(64, 64), units_critic=(64, 64),units_disc=(100, 100), epoch_ppo=50, epoch_disc=10,clip_eps=0.2, lambd=0.97, coef_ent=0.0, max_grad_norm=10.0):super().__init__(state_shape, action_shape, device, seed, gamma, rollout_length,mix_buffer, lr_actor, lr_critic, units_actor, units_critic,epoch_ppo, clip_eps, lambd, coef_ent, max_grad_norm)# Expert's buffer.self.buffer_exp = buffer_exp# Discriminator.self.disc = GAILDiscrim(state_shape=state_shape,action_shape=action_shape,hidden_units=units_disc,hidden_activation=nn.Tanh()).to(device)self.learning_steps_disc = 0self.optim_disc = Adam(self.disc.parameters(), lr=lr_disc)self.batch_size = batch_sizeself.epoch_disc = epoch_discdef update(self, writer):self.learning_steps += 1for _ in range(self.epoch_disc):self.learning_steps_disc += 1# Samples from current policy's trajectories.states, actions = self.buffer.sample(self.batch_size)[:2]# Samples from expert's demonstrations.states_exp, actions_exp = \self.buffer_exp.sample(self.batch_size)[:2]# Update discriminator.self.update_disc(states, actions, states_exp, actions_exp, writer)# We don't use reward signals here,states, actions, _, dones, log_pis, next_states = self.buffer.get()# Calculate rewards.rewards = self.disc.calculate_reward(states, actions)# Update PPO using estimated rewards.self.update_ppo(states, actions, rewards, dones, log_pis, next_states, writer)def update_disc(self, states, actions, states_exp, actions_exp, writer):# Output of discriminator is (-inf, inf), not [0, 1].logits_pi = self.disc(states, actions)logits_exp = self.disc(states_exp, actions_exp)# Discriminator is to maximize E_{\pi} [log(1 - D)] + E_{exp} [log(D)].loss_pi = -F.logsigmoid(-logits_pi).mean()loss_exp = -F.logsigmoid(logits_exp).mean()loss_disc = loss_pi + loss_expself.optim_disc.zero_grad()loss_disc.backward()self.optim_disc.step()if self.learning_steps_disc % self.epoch_disc == 0:writer.add_scalar('loss/disc', loss_disc.item(), self.learning_steps)# Discriminator's accuracies.with torch.no_grad():acc_pi = (logits_pi < 0).float().mean().item()acc_exp = (logits_exp > 0).float().mean().item()writer.add_scalar('stats/acc_pi', acc_pi, self.learning_steps)writer.add_scalar('stats/acc_exp', acc_exp, self.learning_steps)

3 参考文献

https://zhuanlan.zhihu.com/p/628915533

【强化学习】GAIL_gail算法-CSDN博客

代码:https://github.com/toshikwa/gail-airl-ppo.pytorch.git

相关文章:

对抗式生成模仿学习(GAIL)

目录 1 预先基础知识 1.1 对抗生成网络&#xff08;GAN&#xff09; 1.1.1 基本概念 1.1.2 损失函数 1.1.2.1 固定G&#xff0c;求解令损失函数最大的D 1.1.2.2 固定D&#xff0c;求解令损失函数最小的G 1.2 对抗式生成模仿学习特点 2 对抗式生成模仿学习&#xff08;…...

信息系统项目管理师 | 新一代信息技术

关注WX&#xff1a;CodingTechWork 物联网 定义 The Internet of Things是指通过信息传感设备&#xff0c;按约定的协议&#xff0c;将任何物品与互联网连接&#xff0c;进行信息交互和通信&#xff0c;以实现智能化识别。定位、跟踪、监控和管理的一种网络。物联网主要解决…...

安全宣传咨询日活动向媒体投稿记住这个投稿好方法

在信息爆炸的时代,作为单位的信息宣传员,我肩负着将每一次重要活动,特别是像“安全宣传咨询日”这样的公益活动,有效传达给公众的重任。这份工作看似简单,实则充满了挑战,尤其是在我初涉此领域时,那段曲折而又难忘的投稿经历,至今记忆犹新。 初探投稿之海,遭遇重重困难 起初,我…...

第7章:系统架构设计基础知识-软件架构风格

由于历史原因&#xff0c;研究者和工程人员对Sofiware Architecture(简称SA)的翻译不尽相同&#xff0c;其软件的“体系结构”和“架构”具有相同的含义。 系统架构其实就是系统的结构&#xff0c;系统架构设计其实就是要给相关利益方说清楚通过什么样的结构来解决需求中功能和…...

自制调色小工具给图片加滤镜,修改图片红、绿、蓝通道及亮度,修改图片颜色

上篇&#xff1a; 上篇我们给地图添加了锐化、模糊等滤镜&#xff0c;这篇来写一个小工具给图片调色。 调色比锐化等滤镜要简单许多&#xff0c;直接拿到像素值修改即可。不需要用到卷积核。。。(*^▽^*) 核心原理就是图像结构&#xff0c;使用context.getImageData获取图像像…...

【Redis】java客户端(SpringData和jedis)

https://www.oz6.cn/articles/58 https://www.bilibili.com/video/BV1cr4y1671t/?p16 redis官网客户端介绍&#xff1a;https://redis.io/docs/latest/develop/connect/clients/ jedis maven引入依赖 <dependencies><!--引入Jedis依赖--><dependency><…...

大数据安全经典面试题及回答(上)

目录 一、大数据安全的主要挑战及应对策略 二、大数据安全中的“五个V”及其影响 三、在Hadoop集群中实施数据加密的步骤和注意事项 四、在大数据环境中实施访问控制和身份认证 五、大数据环境中数据备份和恢复的策略 六、大数据处理过程中保护用户隐私的策略 七、大数据…...

vi/vim使用命令

你是否在编辑文件时以为键盘坏了&#xff0c;为什么不能删除呢&#xff0c;为什么不能敲代码呢&#xff0c;当你初识vi&#xff0c;会觉得这个东西设计很难用&#xff0c;这篇教程带你熟练得用上这款经典的工具&#xff0c;当你熟练了这款工具就会真正体会到高效率打码 Vi 是在…...

webpack打包gz文件,nginx开启gzip压缩

wepback配置 webpack4配合"compression-webpack-plugin": "^6.1.2"打包压缩gz chain.plugin("compression").use(new CompressionPlugin({test: /\.js$|\.html$|\.css$/,threshold: 10240, // 超过10KB的压缩deleteOriginalAssets: false,// 保…...

微服务开发与实战Day11 - 微服务面试篇

一、分布式事务 1. CAP定理 1998年&#xff0c;加州大学的计算机科学及Eric Brewer提出&#xff0c;分布式系统有三个指标&#xff1a; Consistency&#xff08;一致性&#xff09;Availability&#xff08;可用性&#xff09;Partition tolerance&#xff08;分区容错性&am…...

基于Spring Boot+VUE职称评审管理系统

1管理员功能模块 管理员登录&#xff0c;通过填写注册时输入的用户名、密码、角色进行登录&#xff0c;如图1所示。 图1管理员登录界面图 管理员登录进入职称评审管理系统可以查看首页、个人中心、用户管理、评审员管理、省份管理、评审条件管理、职称申请管理、结果公布管理、…...

MySQL 基本语法讲解及示例(上)

第一节&#xff1a;MySQL的基本操作 1. 创建数据库 在 MySQL 中&#xff0c;创建数据库的步骤如下&#xff1a; 命令行操作 打开 MySQL 命令行客户端或连接到 MySQL 服务器。 输入以下命令创建一个数据库&#xff1a; CREATE DATABASE database_name;例如&#xff0c;创建一…...

6.18作业

完善对话框&#xff0c;点击登录对话框&#xff0c;如果账号和密码匹配&#xff0c;则弹出信息对话框&#xff0c;给出提示”登录成功“&#xff0c;提供一个Ok按钮&#xff0c;用户点击Ok后&#xff0c;关闭登录界面&#xff0c;跳转到其他界面 如果账号和密码不匹配&#xff…...

Excel文件转换为HTML文件

文章目录 前言安装python包python代码 前言 将一个Excel文件转换为HTML文件 安装python包 使用pandas和openpyxl库来实现这个功能 pip install pandas openpyxlpython代码 1、首先使用tkinter库中的filedialog模块弹出一个对话框来选择要转换的Excel文件 2、使用pandas库…...

MySQL数据库入门

1、MySQL概述 MySQL官方网站 https://www.mysql.com/downloads/ MySQL被Oracle公司收购了&#xff0c;作者又重新编写了一个开源的数据库管理系统&#xff0c;Mariadb 2、MySQL产品&版本 2、数据库在网站架构中的角色 LAMP LNMP网站架构 3、安装MySQL-基于yum 查…...

vue element-ui 下拉框 以及 input 限制输入,小数点后保留两位 界面设计案例 和 例子:支持mp4和m3u8视频播放

vue input 限制输入&#xff0c;小数点后保留两位 以及 图片垂直居中显示 和 分享 git 小技巧-CSDN博客文章浏览阅读430次&#xff0c;点赞5次&#xff0c;收藏4次。error:Your local changes to the following files would be overwritten by merge:_error: your local change…...

Python基础用法 之 运算符

1.算数运算符 符号作用说明举例加与“”相同 - 减与“-”相同*乘 与“ ”相同 9*218/除 与“ ”相同 9/24.5 、6/32.0//求商&#xff08;整数部分&#xff09; 两个数据做除法的 商 9//24%取余&#xff08;余数部分&#xff09; 是两个数据做除法的 余数 9%21**幂、次方2**…...

事务所管理系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;管理员管理&#xff0c;客户管理&#xff0c;评论管理&#xff0c;基础数据管理&#xff0c;公告信息管理 客户账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;律师管理&#xff0…...

airsim安装

继续进行&#xff0c;遇到下面的报错 Cannot find path HKEY_CLASSES_ROOT\Unreal.ProjectFile\shell\rungenproj 在Git地址的issue中&#xff0c;搜到下面的解决方法&#xff0c;根因是安装Unreal Engine之后未重启电脑&#xff0c;文件未关联导致&#xff0c;或者出现重定向…...

打造精致UI界面:字体设计的妙招

字体设计是UI设计的关键模块之一。字体设计是否有效可能直接实现或破坏整个UI界面。那么&#xff0c;界面设计的字体设计有哪些规范呢&#xff1f;如何设计细节字体&#xff1f;本文将解释字体设计规范的可读性、可读性和可用性&#xff0c;并介绍UI界面中的字体设计技巧。 如…...

[BJDCTF2020]ZJCTF,不过如此1

打开题目可以看到一段php文件包含&#xff0c;源码如下 <?phperror_reporting(0); $text $_GET["text"]; $file $_GET["file"]; if(isset($text)&&(file_get_contents($text,r)"I have a dream")){echo "<br><h1>…...

全网最全 Kimi 使用手册,看完 Kimi 效率提升 80%

在当前AI文字大模型领域&#xff0c;ChatGPT4.0无疑是最强大。然而&#xff0c;最近最火爆的大模型非国产Kimi莫属。 相较于其它大模型&#xff0c;Kimi 最大的优势在于&#xff0c;超长文本输入&#xff0c;支持200万汉字&#xff0c;是全球范围内罕见的超长文本处理工具&…...

“Redis中的持久化:深入理解RDB与AOF机制“

目录 # 概念 1. RDB持久化 1.1 备份是如何执行的&#xff08;RDB过程&#xff09; 1.2 配置文件信息 1.3 RDB持久化操作 1.4 RDB优势 1.5 RDB劣势 1.6 RDB做备份 2. AOF持久化 2.1 AOF开启及使用 2.2 异常恢复 2.3 配置文件操作 2.4 AOF持久化流程 2.5 优点 2.6…...

PHP框架详解:Symfony框架讲解

PHP作为一种流行的服务器端编程语言&#xff0c;拥有众多框架&#xff0c;其中Symfony是备受开发者推崇的一个强大框架。本文将详细讲解Symfony框架的特点、优势及其主要组件和用法。 一、Symfony简介 Symfony是由Fabien Potencier于2005年创建的一个开源PHP框架。它基于MVC&…...

PR软件视频抠图换背景

1 新建项目 2 新建序列 在项目的右下角有个图标&#xff0c;新建 序列 序列是视频的制作尺寸&#xff0c;根据自己的需要选择 3 新建颜色遮罩 在项目的右下角--新建颜色遮罩--选择黑色--确定 4 导入视频 把要导入视频的文件夹打开&#xff0c;把视频拖到 项目 里 把黑色遮罩拖…...

下载依赖有问题(只有自己有问题)

有缓存&#xff01; 删除node_modules 命令&#xff1a;npm run clean 前提是该项目支持这个命令&#xff1a;package.json > scripts 内有 clean 例如下面这个就没有clean&#xff0c;则直接手动删除 清除缓存 npm cache clean --force pnpm store prune删除lock文件 …...

vscode-关闭ts与js语义校验

1.ts与js语义校验 TypeScript&#xff08;TS&#xff09;和JavaScript&#xff08;JS&#xff09;在语义校验方面有很大的不同。TypeScript是一种静态类型检查的编程语言&#xff0c;它是JavaScript的一个超集&#xff0c;为JavaScript添加了类型系统和其他一些特性。而JavaScr…...

风控中的文本相似方法之余弦定理

一、余弦相似 一、 余弦相似概述 余弦相似性通过测量两个向量的夹角的余弦值来度量它们之间的相似性。0度角的余弦值是1&#xff0c;而其他任何角度的余弦值都不大于1&#xff1b;并且其最小值是-1。 从而两个向量之间的角度的余弦值确定两个向量是否大致指向相同的方向。结…...

Spring Boot定时任务编程指南:如何创建和配置周期性任务

&#x1f341; 作者&#xff1a;知识浅谈&#xff0c;CSDN签约讲师&#xff0c;CSDN博客专家&#xff0c;华为云云享专家&#xff0c;阿里云专家博主 &#x1f4cc; 擅长领域&#xff1a;全栈工程师、爬虫、ACM算法 &#x1f525; 微信&#xff1a;zsqtcyw 联系我领取学习资料 …...

Java 获取客户端 IP 地址【工具类】

Java 获取客户端 IP 地址 import javax.servlet.http.HttpServletRequest; import java.net.InetAddress;/*** 网络工具类*/ public class NetUtils {/*** 获取客户端 IP 地址** param request 请求* return {link String}*/public static String getIpAddress(HttpServletReq…...

自己注册个公司做网站怎么样/阿里巴巴指数查询

简介 django为用户实现防止跨站请求伪造的功能&#xff0c;通过中间件 django.middleware.csrf.CsrfViewMiddleware 来完成。而对于django中设置防跨站请求伪造功能有分为全局和局部。 全局&#xff1a; 中间件 django.middleware.csrf.CsrfViewMiddleware 局部&#xff1a; cs…...

网站开发用什么架构/天天外链官网

来自公众号&#xff1a;孤独烟引言大家应该知道烟哥最近要(tiao 咳咳咳)&#xff0c;嗯&#xff0c;不可描述&#xff01;随手讲其中一部分知识&#xff0c;都是一些烟哥自己平时工作的总结以及经验。大家看完&#xff0c;其实能避开很多坑。而且很多问题&#xff0c;都是面试中…...

深圳电商网站制作公司/sem账户托管外包

配置rsync下行同步时 文章目录配置rsync下行同步时一、执行同步命令时1、报错如下2、报错导致原因1&#xff09;排除第一个&#xff0c;虽无法连接到主机&#xff0c;但是我master还是开机状态&#xff0c;暂时排除第一项原因2&#xff09;防火墙阻挡&#xff08;firewalld&…...

做企业网站开发哪家好/百度广告推广怎么做

参考出处&#xff1a; http://www.cnblogs.com/mq0036/p/3382732.html http://www.cnblogs.com/hongcha717/archive/2010/10/24/1859780.html 出处中判断哪个是数组指针和指针数组&#xff1f; A int*p1[10] B int(*p2)[10] 首先看看他们的类型&#xff0c;在 VS C中使用sizeof…...

免费建站网站一级123456/苏州做网站的专业公司

【题目描述】 有两堆石子,两个人轮流去取。每次取的时候,只能从较多的那堆石子里取,并且取的数目必须是较少的那堆石子数目的整数倍&#xff0c;最后谁能够把一堆石子取空谁就算赢。 比如初始的时候两堆石子的数目是25和7。 25 7 --> 11 7 --> 4 7 --> 4 3 --> 1…...

昆明制作企业网站/网站优化排名推荐

题目链接&#xff1a;http://acm.hdu.edu.cn/showproblem.php?pid3404 题意&#xff1a;一个n*m的格子里全是灯。每次选出一个矩形&#xff0c;改变四个角灯的状态&#xff0c;而且右下角的灯初始必须是开的。 思路&#xff1a;NIM积模板。没明白怎么推导的式子。 const int I…...