当前位置: 首页 > news >正文

【深度强化学习】(3) Policy Gradients 模型解析,附Pytorch完整代码

大家好,今天和各位分享一下基于策略的深度强化学习方法,策略梯度法是对策略进行建模,然后通过梯度上升更新策略网络的参数。我们使用了 OpenAI 的 gym 库,基于策略梯度法完成了一个小游戏。完整代码可以从我的 GitHub 中获得:

https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model


1. 基于策略的深度强化学习

针对智能体在大规模离散动作下无法建模的难题,在基于值函数的深度强化学习中,利用神经网络对 Q 值函数近似估计,使深度学习与强化学习得到完美融合。

但是基于值函数的深度强化学习有一定的不足之处:

(1) 无法处理连续动作的任务。DQN 系列的算法可以较好地解决强化学习中大规模离散动作空间的任务,但在连续动作的任务中,难以实现利用深度神经网络对所有状态-动作的 Q 值函数近似表达。

(2) 无法处理环境中状态受到限制的问题。在基于值函数深度强化学习更新网络参数时,损失函数会依赖当前状态和下一个状态的值函数,当智能体在环境中观察的状态有限或建模受到限制时,就会导致实际环境中两个不同的状态有相同的价值函数进而导致损失函数为零,出现梯度消失的问题。

(3) 智能体在环境中的探索性能较低。基于值函数的深度强化学习方法中,目标值都是从动作空间中选取一个最大价值的动作,导致智能体训练后的策略具有确定性,而面对一些需要随机策略进行探索的问题时,该方法就无法较好地解决。 

由于基于值函数的深度强化学习存在上述的一些局限性,需要新的方法来解决这些问题,于是基于策略的深度强化学习被提出。该方法中将智能体当前的策略参数化,并且使用梯度的方法进行更新。 


2. 策略梯度法

强化学习中策略梯度算法是对策略进行建模,然后通过梯度上升更新策略网络的参数。Policy Gradients 中无法使用策略的误差来构建损失函数,因为参数更新的目标是最大化累积奖励的期望值,所以策略更新的依据是某一动作对累积奖励的影响,即增加使累积回报变大的动作的概率,减弱使累积回报变小的动作的概率。 

下图代表智能体在当前策略下,完成一个回合后构成的状态、动作序列 r=\left \{ s_1,a_1,s_2,a_2,s_2,...s_T,a_T \right \},其中,Actor 是策略网络每个回合结束后的累计回报为每个状态下采取的动作的奖励之和

R = \sum_{t=1}^{T}r_t

智能体在环境中执行策略 \pi_ \theta 后状态转移概率

P_{\theta }(\tau )=P(s_1)P(r_1,s_2|s_1,a_1)\cdots P_\theta (a_T,s_T)P(r_{T+1},s_T|s_T,a_T)

P_{\theta }(\tau ) = \prod_{t=1}^{T} ( P_\theta (a_t|s_t) P(r_t,s_{t+1} | s_t,a_t) )

回合累计回报的期望

\bar{R}_\theta = \sum _{\tau } R(\tau ) P_\theta (\tau ) = E_{\tau \sim P_\theta (\tau )} [R(\tau )]

通过微分公式可以得到累计回报的梯度为:

\bigtriangledown \bar{R} _{\theta } = \sum_{\tau} R(\tau ) P_{\theta }(\tau) \bigtriangledown log P_\theta (\tau ) \approx \frac{1}{N} \sum_{n=1}^{N}\sum_{t=1}^{T} R(\tau ^n) \bigtriangledown log P_\theta (a_t^n | s_t^n)

利用累计回报的梯度更新策略网络的参数:

\theta ^ {new} \leftarrow \theta ^{old} + \beta \bigtriangledown \bar{R}_{\theta ^{old}}

其中, \beta 为梯度系数。通过上式的策略迭代可得,如果智能体在某个状态下采取的动作使累积回报增加,网络参数就会呈梯度上升趋势,该动作的概率就会增加,反之,梯度为下降趋势,减小该动作的概率。为了防止环境中所有的奖励都是正值,实现对于一些不好动作有一个负反馈,可以在总回报处减去一个基线。


3. 代码实现

策略函数 \pi (a|s) 是一个概率密度函数,用于控制智能体的运动。\pi (a|s)输入状态 s,输出每个动作 a 的概率分布。策略网络是指通过训练一个神经网络来近似策略函数。策略网络的参数 \theta 可通过策略梯度算法进行更新,从而实现对策略网络 \pi (a|s;\theta ) 的训练。 

伪代码如下

模型构建部分代码如下:

# 基于策略的学习方法,用于数值连续的问题
import numpy as np
import torch
from torch import nn
from torch.nn import functional as F# ----------------------------------------------------- #
#(1)构建训练网络
# ----------------------------------------------------- #
class Net(nn.Module):def __init__(self, n_states, n_hiddens, n_actions):super(Net, self).__init__()# 只有一层隐含层的网络self.fc1 = nn.Linear(n_states, n_hiddens)self.fc2 = nn.Linear(n_hiddens, n_actions)# 前向传播def forward(self, x):x = self.fc1(x)  # [b, states]==>[b, n_hiddens]x = F.relu(x)x = self.fc2(x)  # [b, n_hiddens]==>[b, n_actions]# 对batch中的每一行样本计算softmax,q值越大,概率越大x = F.softmax(x, dim=1)  # [b, n_actions]==>[b, n_actions]return x# ----------------------------------------------------- #
#(2)强化学习模型
# ----------------------------------------------------- #
class PolicyGradient:def __init__(self, n_states, n_hiddens, n_actions, learning_rate, gamma):# 属性分配self.n_states = n_states  # 状态数self.n_hiddens = n_hiddensself.n_actions = n_actions  # 动作数self.learning_rate = learning_rate  # 衰减self.gamma = gamma  # 折扣因子self._build_net()  # 构建网络模型# 网络构建def _build_net(self):# 网络实例化self.policy_net = Net(self.n_states, self.n_hiddens, self.n_actions)# 优化器self.optimizer = torch.optim.Adam(self.policy_net.parameters(), lr=self.learning_rate)# 动作选择,根据概率分布随机采样def take_action(self, state):  # 传入某个人的状态# numpy[n_states]-->[1,n_states]-->tensorstate = torch.Tensor(state[np.newaxis, :])# 获取每个人的各动作对应的概率[1,n_states]-->[1,n_actions]probs = self.policy_net(state)# 创建以probs为标准类型的数据分布action_dist = torch.distributions.Categorical(probs)# 以该概率分布随机抽样 [1,n_actions]-->[1] 每个状态取一组动作action = action_dist.sample()# 将tensor数据变成一个数 intaction = action.item()return action# 获取每个状态最大的state_valuedef max_q_value(self, state):# 维度变换[n_states]-->[1,n_states]state = torch.tensor(state, dtype=torch.float).view(1,-1)# 获取状态对应的每个动作的reward的最大值 [1,n_states]-->[1,n_actions]-->[1]-->floatmax_q = self.policy_net(state).max().item()return max_q# 训练模型def learn(self, transitions_dict):  # 输入batch组状态[b,n_states]# 取出该回合中所有的链信息state_list = transitions_dict['states']action_list = transitions_dict['actions']reward_list = transitions_dict['rewards']G = 0  # 记录该条链的returnself.optimizer.zero_grad()  # 优化器清0# 梯度上升最大化目标函数for i in reversed(range(len(reward_list))):# 获取每一步的reward, floatreward = reward_list[i]# 获取每一步的状态 [n_states]-->[1,n_states]state = torch.tensor(state_list[i], dtype=torch.float).view(1,-1)# 获取每一步的动作 [1]-->[1,1]action = torch.tensor(action_list[i]).view(1,-1)# 当前状态下的各个动作价值函数 [1,2]q_value = self.policy_net(state)# 获取已action对应的概率 [1,1]log_prob = torch.log(q_value.gather(1, action))# 计算当前状态的state_value = 及时奖励 + 下一时刻的state_valueG = reward + self.gamma * G# 计算每一步的损失函数loss = -log_prob * G# 反向传播loss.backward()# 梯度下降self.optimizer.step()

4. 实例演示

下面基于OpenAI 中的 gym 库,完成一个移动小车使得杆子竖直的游戏。状态states共包含 4 个,动作action有2个,向左和向右移动小车。迭代50回合,绘制每回合的 return 以及平均最大动作价值 q_max

代码如下:

import gym
import numpy as np
import matplotlib.pyplot as plt
from RL_brain import PolicyGradient# ------------------------------- #
# 模型参数设置
# ------------------------------- #n_hiddens = 16  # 隐含层个数
learning_rate = 2e-3  # 学习率
gamma = 0.9  # 折扣因子
return_list = []  # 保存每回合的reward
max_q_value = 0  # 初始的动作价值函数
max_q_value_list = []  # 保存每一step的动作价值函数# ------------------------------- #
#(1)加载环境
# ------------------------------- ## 连续性动作
env = gym.make("CartPole-v1", render_mode="human")
n_states = env.observation_space.shape[0]  # 状态数 4
n_actions = env.action_space.n  # 动作数 2# ------------------------------- #
#(2)模型实例化
# ------------------------------- #agent = PolicyGradient(n_states=n_states,  # 4n_hiddens=n_hiddens,  # 16n_actions=n_actions,  # 2learning_rate=learning_rate,  # 学习率gamma=gamma)  # 折扣因子# ------------------------------- #
#(3)训练
# ------------------------------- #for i in range(100):  # 训练10回合# 记录每个回合的returnepisode_return = 0# 存放状态transition_dict = {'states': [],'actions': [],'next_states': [],'rewards': [],'dones': [],}# 获取初始状态state = env.reset()[0]# 结束的标记done = False# 开始迭代while not done:# 动作选择action = agent.take_action(state)  # 对某一状态采取动作# 动作价值函数,曲线平滑max_q_value = agent.max_q_value(state) * 0.005 + max_q_value * 0.995# 保存每一step的动作价值函数max_q_value_list.append(max_q_value)# 环境更新next_state, reward, done, _, _ = env.step(action)# 保存每个回合的所有信息transition_dict['states'].append(state)transition_dict['actions'].append(action)transition_dict['next_states'].append(next_state)transition_dict['rewards'].append(reward)transition_dict['dones'].append(done)# 状态更新state = next_state# 记录每个回合的returnepisode_return += reward# 保存每个回合的returnreturn_list.append(episode_return)# 一整个回合走完了再训练模型agent.learn(transition_dict)# 打印回合信息print(f'iter:{i}, return:{np.mean(return_list[-10:])}')# 关闭动画
env.close()# -------------------------------------- #
# 绘图
# -------------------------------------- #plt.subplot(121)
plt.plot(return_list)
plt.title('return')
plt.subplot(122)
plt.plot(max_q_value_list)
plt.title('max_q_value')
plt.show()

左图代表每回合的return,右图为平均最大动作价值

相关文章:

【深度强化学习】(3) Policy Gradients 模型解析,附Pytorch完整代码

大家好,今天和各位分享一下基于策略的深度强化学习方法,策略梯度法是对策略进行建模,然后通过梯度上升更新策略网络的参数。我们使用了 OpenAI 的 gym 库,基于策略梯度法完成了一个小游戏。完整代码可以从我的 GitHub 中获得&…...

Windows基于Nginx搭建RTMP流媒体服务器(附带所有组件下载地址及验证方法)

RTMP服务时常用于直播时提供拉流推流传输数据的一种服务。前段时间由于朋友想搭建一套直播时提供稳定数据传输的服务器,所以就研究了一下如何搭建及使用。 1、下载nginx 首先我们要知道一般nginx不能直接配置rtmp服务,在Windows系统上需要特殊nginx版本…...

交流电机驱动器中的隔离电压感应

汽车和工业终端设备,如电机驱动器、串式逆变器和机载充电器,在高电压下运行,不能安全地与人直接互动。隔离电压测量通过保护人类免受高压电路执行一个功能的影响,有助于优化操作和确保使用的安全性。 设计用于高性能,隔…...

爬取知乎问题答案

参考博客:基于Python知乎回答爬虫 jieba关键字统计可视化_知乎爬虫搜索关键词_菠萝柚王子的博客-CSDN博客 1、安装依赖包 import numpy import requests import certifi from PIL import Image from lxml import etree import jieba from wordcloud import WordClo…...

通用智能理论

将智能定义为解决矛盾的能力,用解决矛盾的概率提升来评估智能程度,以此为基础推导智能原理,建立一种新的通用智能理论。 1 前言 通用人工智能(Artificial General Intelligence)是人类长久以来的梦想。经历了一次次挫败…...

保姆级使用PyTorch训练与评估自己的MixMIM网络教程

文章目录前言0. 环境搭建&快速开始1. 数据集制作1.1 标签文件制作1.2 数据集划分1.3 数据集信息文件制作2. 修改参数文件3. 训练4. 评估5. 其他教程前言 项目地址:https://github.com/Fafa-DL/Awesome-Backbones 操作教程:https://www.bilibili.co…...

《百万在线 大型游戏服务端开发》前两章概念笔记

第1章 从角色走路说起 游戏网络通信的流程则是服务端先开启监听,等待客户端的连接,然后交互操作,最后断开。 套接字 每个Socket都包含网络连接中一端的信息。每个客户端需要一个Socket结构,服务端则需要N1个Socket结构&#xff…...

3BHE029110R0111 ABB

3BHE029110R0111 ABB变频器控制方式低压通用变频输出电压为380~650V,输出功率为0.75~400kW,工作频率为0~400Hz,它的主电路都采用交—直—交电路。其控制方式经历了以下四代。1U/fC的正弦脉宽调制&#xff0…...

实现防重复操作(JS与CSS)

实现防重复操作(JS与CSS) 一、前言 日常开发中我们经常会对按钮进行一个防重复点击的校验,这个通常使用节流函数来实现。在规定时间内只允许提交一次,可以有效的避免事件过于频繁的执行和重复提交操作,以及为服务器考…...

怎么合并或注销重复LinkedIn领英帐号?

您可能会发现您拥有多个领英帐户。如果您收到消息,提示您尝试使用的邮箱与另一个帐户已绑定,就表明您可能存在重复的领英帐户。如果您使用许多不同的邮箱地址,也可能会收到这样的提示。 领英精灵温馨提示: 目前,仅支持在 PC 端合并…...

Redis高频面试题汇总(中)

目录 1.什么是redis事务? 2.如何使用 Redis 事务? 3.Redis 事务为什么不支持原子性 4.Redis 事务支持持久性吗 5.Redis事务基于lua脚本的实现 6.Redis集群的主从复制模型是怎样的? 7.Redis集群中,主从复制的数据同步的步骤 …...

【Flutter从入门到入坑之三】Flutter 是如何工作的

【Flutter从入门到入坑之一】Flutter 介绍及安装使用 【Flutter从入门到入坑之二】Dart语言基础概述 【Flutter从入门到入坑之三】Flutter 是如何工作的 本文章主要以界面渲染过程为例,介绍一下 Flutter 是如何工作的。 页面中的各界面元素(Widget&…...

Web Components学习(2)-语法

一、Web Components 对 Vue 的影响 尤雨溪在创建 Vue 的时候大量参考了 Web Components 的语法&#xff0c;下面写个简单示例。 首先写个 Vue 组件 my-span.vue&#xff1a; <!-- my-span.vue --> <template><span>my-span</span> </template>…...

Lesson 9.2 随机森林回归器的参数

文章目录一、弱分类器的结构1. 分枝标准与特征重要性2. 调节树结构来控制过拟合二、弱分类器的数量三、弱分类器训练的数据1. 样本的随机抽样2. 特征的随机抽样3. 随机抽样的模式四、弱分类器的其他参数在开始学习之前&#xff0c;先导入我们需要的库。 import numpy as np im…...

Kubernetes Secret简介

Secret概述 前面文章中学习ConfigMap的时候&#xff0c;我们说ConfigMap这个资源对象是Kubernetes当中非常重要的一个对象&#xff0c;一般情况下ConfigMap是用来存储一些非安全的配置信息&#xff0c;如果涉及到一些安全相关的数据的话用ConfigMap就非常不妥了&#xff0c;因…...

Redis 哨兵(Sentinel)

文章目录1.概述2. 没有哨兵下主从效果3.搭建多哨兵3.1 新建目录3.2 复制redis3.3 复制配置文件3.4 修改配置文件3.5 启动主从3.6 启动三个哨兵3.7 查看日志3.8 测试宕机1.概述 在redis主从默认是只有主具备写的能力&#xff0c;而从只能读。如果主宕机&#xff0c;整个节点不具…...

精读笔记 - How to backdoor Federated Learning

文章目录 精读笔记 - How to backdoor Federated Learning1. 基本信息2. 系统概要3. 攻击模型3.1 问题形式化定义3.1.1 前提假设3.1.2 攻击目标3.2 创新点3.2.1 Semantic Backdoor3.2.2 攻击方法4. 实验验证4.1 图像分类4.2 实验操作4.2.1 超参数设置4.2.2 衡量标准4.3 结果分析…...

即时通讯系列-N-客户端如何在推拉结合的模式下保证消息的可靠性展示

结论先行 原则&#xff1a; server拉取的消息一定是连续的原则&#xff1a; 端侧记录的消息的连续段有两个作用&#xff1a; 1. 记录消息的连续性&#xff0c; 即起始中间没有断层&#xff0c; 2. 消息连续&#xff0c; 同时意味着消息是最新的&#xff0c;消息不是过期的。同…...

关于js数据类型的理解

目录标题一、js数据类型分为 基本数据类型和引用数据类型二、区别&#xff1a;传值和传址三、深浅拷贝传值四、数据类型的判断一、js数据类型分为 基本数据类型和引用数据类型 1、基本数据类型 Number、String、Boolean、Null、undefined、BigInt、Symbol 2、引用数据类型 像对…...

大一上计算机期末考试考点

RGB颜色模型也称为相加混色模型 采样频率大于或等于原始声音信号最高频率的两倍即可还原出原始信号. 声音数字化过程中&#xff0c;采样是把时间上连续的模拟信号在时间轴上离散化的过程。 量化的主要工作就是将幅度上连续取值的每一个样本转换为离散值表示。 图像数字化过…...

微搭问搭001-如何清空表单的数据

韩老师&#xff0c;我点关闭按钮后&#xff0c;弹窗从新打开&#xff0c;里面的数据还在&#xff0c;这个可以从新打开清除不&#xff1f; 点关闭的时候清掉 就是清楚不掉也&#xff1f;咋清掉 清掉表单内容有属性可以做到&#xff1f; $page.widgets.id**.value “” 就可以实…...

Windows7,10使用:Vagrant+VirtualBox 安装 centos7

一、Vagrant&#xff0c;VirtualBox 是什么二、版本说明1、win7下建议安装版本2、win10下建议安装版本三、Windows7下安装1、安装Vagrant2、安装VirtualBox3、打开VirtualBox&#xff0c;配置虚拟机默认安装地址四、windows7下载.box文件&#xff0c;安装centos 71、下载一个.b…...

基于JavaEE开发博客系统项目开发与设计(附源码)

文章目录1.项目介绍2.项目模块3.项目效果1.项目介绍 这是一个基于JavaEE开发的一个博客系统。实现了博客的基本功能&#xff0c;前台页面可以进行文章浏览&#xff0c;关键词搜索&#xff0c;登录注册&#xff1b;登陆后支持对文章进行感谢、评论&#xff1b;然后还可以对评论…...

Android Framework——zygote 启动 SystemServer

概述 在Android系统中&#xff0c;所有的应用程序进程以及系统服务进程SystemServer都是由Zygote进程孕育&#xff08;fork&#xff09;出来的&#xff0c;这也许就是为什么要把它称为Zygote&#xff08;受精卵&#xff09;的原因吧。由于Zygote进程在Android系统中有着如此重…...

在ubuntu上搭建SSH和FTP和NFS和TFTP

一、SSH服务搭建使用如下命令安装 SSH 服务&#xff1b;ssh 的配置文件为/etc/ssh/sshd_config&#xff0c;使用默认配置即可。sudo apt-get install openssh-server开启 SSH 服务以后我们就可以在 Windwos 下使用终端软件登陆到 Ubuntu&#xff0c;比如使用 Mobaxterm。二、FT…...

ThinkPHP 6.1 模板篇之文件加载

本文主要讲述模板中如何使用包含文件、引入css/js文件及路径优化。 包含文件 使用{include}标签来加载公用重复的文件&#xff0c;比如头部、尾部和导航部分 包含用法 1.创建公用文件 在模版 view 目录创建一个 common公共目录&#xff0c;分别创建 header、footer 和 nav …...

操作系统内核与安全分析课程笔记【1】链表、汇编与makefile

文章目录链表循环双向链表哈希链表其他链表汇编内联汇编扩展内联汇编makefile链表 链表是linux内核中关键的数据结构。在第二次课中&#xff0c;重点介绍了循环双向链表和哈希链表。这两种链表都在传统的双向链表的基础之上进行了针对效率的优化。(ps&#xff1a;这部分可以通…...

华为OD机试题 - 九宫格按键输入(JavaScript)| 机考必刷

更多题库,搜索引擎搜 梦想橡皮擦华为OD 👑👑👑 更多华为OD题库,搜 梦想橡皮擦 华为OD 👑👑👑 更多华为机考题库,搜 梦想橡皮擦华为OD 👑👑👑 华为OD机试题 最近更新的博客使用说明本篇题解:九宫格按键输入题目输入输出示例一输入输出说明示例二输入输出说…...

PMSM控制_foc 控制环路

整个系统的控制过程有以下部分&#xff0c;以无感FOC&#xff0c;双电阻电流采样&#xff0c;控制周期为 10KHz 为例&#xff1a; 1、在每隔一个 PWM 周期采样一次两相电流 2、进行 FOC 的计算 &#xff08;1&#xff09;clarke 变换&#xff0c;将电流变换至静止坐标系下的 Ia…...

Linux 练习七 (IPC 共享内存)

文章目录System V 共享内存机制&#xff1a;shmget shmat shmdt shmctl案例一&#xff1a;有亲缘关系的进程通信案例二&#xff1a;非亲缘关系的进程通信内存写端write1.c内存读端read1.c案例三&#xff1a;不同程序之间的进程通信程序一&#xff0c;写者shmwr.c程序二&#xf…...

做外贸主要在那些网站找单/火蝠电商代运营公司

问题&#xff1a;完整错误信息&#xff1a; ./mgw: error while loading shared libraries: libxxx.so: cannot open shared object file: No such file or directory 方法&#xff1a;直接在运行./mgw的.sh脚本中添加下面语句即可&#xff1a; export LD_LIBRARY_PATH$LD_LI…...

crm系统价格/百度竞价关键词优化

1、设备对象 引入uiautomator&#xff0c;获取设备对象<所谓设备对象可理解为&#xff1a;Android模拟器或者真机> 语法&#xff1a;from uiautomator import device as d d 即为设备对象 1.1、获取设备信息 语法&#xff1a;d.info 返回值&#xff1a; Python{ udisplay…...

论坛网站建设推广优化/重庆网站seo技术

最近在项目中遇到一个需求&#xff0c;需要把项目部署在客户内网服务器上&#xff08;无外网&#xff09;&#xff0c;内网服务器&#xff08;无外网&#xff09;需要访问公网IP&#xff0c;由于我们的项目包括外接谷歌地图接口&#xff0c;视频直播接口&#xff1b;项目经理就…...

迅驰互联网站建设网络推广怎么样/营销软文推广平台

调用函数 / 类型转换 / 切片/ 迭代1. 调用函数:abs(),max(),min()2. 数据类型转换:int(),float(),str(),tool(),aabs,3. 定义函数,如果没有return语句&#xff0c;函数执行完毕后也会返回结果&#xff0c;只是结果为None在Python中&#xff0c;定义一个函数要使用def语句&…...

wordpress调用多媒体/自己如何制作一个小程序

模拟真实业务的这么一个小型的项目&#xff0c;来全程贯穿&#xff0c;用这个项目中的业务场景去一个一个的讲解hystrix高可用的每个技术。 大背景&#xff1a;电商网站&#xff0c;首页&#xff0c;商品详情页&#xff0c;搜索结果页&#xff0c;广告页&#xff0c;促销活动&…...

大连sem网站建设/搜索优化指的是什么

大家好&#xff0c;这里是小虾虾科技屋——————————————————————有时候&#xff0c;某个软件随着更新&#xff0c;会带走许多历史的痕迹&#xff0c;于是有的人就想尝试复原那些软件&#xff0c;所以产生了今天为大家分享的工具。iTunes旧版软件抓包神器&a…...