当前位置: 首页 > news >正文

Stable-Baselines 3 部分源代码解读 2 on_policy_algorithm.py

Stable-Baselines 3 部分源代码解读 ./common/on_policy_algorithm.py

前言

阅读PPO相关的源码,了解一下标准库是如何建立PPO算法以及各种tricks的,以便于自己的复现。

在Pycharm里面一直跳转,可以看到PPO类是最终继承于基类,也就是这个py文件的内容。

所以阅读源码就先从这里开始。: )

import 包

import sys
import time
from typing import Any, Dict, List, Optional, Tuple, Type, TypeVar, Unionimport numpy as np
import torch as th
from gym import spacesfrom stable_baselines3.common.base_class import BaseAlgorithm
from stable_baselines3.common.buffers import DictRolloutBuffer, RolloutBuffer
from stable_baselines3.common.callbacks import BaseCallback
from stable_baselines3.common.policies import ActorCriticPolicy
from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule
from stable_baselines3.common.utils import obs_as_tensor, safe_mean
from stable_baselines3.common.vec_env import VecEnv

OnPolicyAlgorithm 类

这个类是PPO算法类的中间曾,夹在底层基类和上层PPO类的之间。

主要是同策略算法,例如:A2C和PPO算法。

policyenvlearning_rate三者与基类base-class.py的一致

n_steps表示每次更新前需要经过的时间步,作者在这里给出了n_steps * n_envs的例子,可能的意思是,如果环境是重复的多个,打算做并行训练的话,那么就是每个子环境的时间步乘以环境的数量

batch_size经验回放的最小批次信息

gammagae_lambdaclip_rangeclip_range_vf均是具有默认值的参数,分别代表“折扣因子”、“GAE奖励中平衡偏置和方差的参数”、“为网络参数而限制幅度的范围”、“为值函数网络参数而限制幅度的范围”

normalize_advantage标志是否需要归一化优势(advantage)

ent_coefvf_coef损失计算的熵系数

max_grad_norm最大的梯度长度,梯度下降的限幅

use_sdesde_sample_freq是状态独立性探索,只适用于连续环境,与基类base-class.py的一致

target_kl限制每次更新时KL散度不能太大,因为clipping限幅不能防止大量更新

monitor_wrapper标志是否需要Gym库提供的监视器包装器

_init_setup_model是否建立模型,也就是是否在创建这个实例过程中创建初始化模型

class OnPolicyAlgorithm(BaseAlgorithm):"""The base for On-Policy algorithms (ex: A2C/PPO).:param policy: The policy model to use (MlpPolicy, CnnPolicy, ...):param env: The environment to learn from (if registered in Gym, can be str):param learning_rate: The learning rate, it can be a functionof the current progress remaining (from 1 to 0):param n_steps: The number of steps to run for each environment per update(i.e. batch size is n_steps * n_env where n_env is number of environment copies running in parallel):param gamma: Discount factor:param gae_lambda: Factor for trade-off of bias vs variance for Generalized Advantage Estimator.Equivalent to classic advantage when set to 1.:param ent_coef: Entropy coefficient for the loss calculation:param vf_coef: Value function coefficient for the loss calculation:param max_grad_norm: The maximum value for the gradient clipping:param use_sde: Whether to use generalized State Dependent Exploration (gSDE)instead of action noise exploration (default: False):param sde_sample_freq: Sample a new noise matrix every n steps when using gSDEDefault: -1 (only sample at the beginning of the rollout):param tensorboard_log: the log location for tensorboard (if None, no logging):param monitor_wrapper: When creating an environment, whether to wrap itor not in a Monitor wrapper.:param policy_kwargs: additional arguments to be passed to the policy on creation:param verbose: Verbosity level: 0 for no output, 1 for info messages (such as device or wrappers used), 2 fordebug messages:param seed: Seed for the pseudo random generators:param device: Device (cpu, cuda, ...) on which the code should be run.Setting it to auto, the code will be run on the GPU if possible.:param _init_setup_model: Whether or not to build the network at the creation of the instance:param supported_action_spaces: The action spaces supported by the algorithm."""def __init__(self,policy: Union[str, Type[ActorCriticPolicy]],env: Union[GymEnv, str],learning_rate: Union[float, Schedule],n_steps: int,gamma: float,gae_lambda: float,ent_coef: float,vf_coef: float,max_grad_norm: float,use_sde: bool,sde_sample_freq: int,tensorboard_log: Optional[str] = None,monitor_wrapper: bool = True,policy_kwargs: Optional[Dict[str, Any]] = None,verbose: int = 0,seed: Optional[int] = None,device: Union[th.device, str] = "auto",_init_setup_model: bool = True,supported_action_spaces: Optional[Tuple[spaces.Space, ...]] = None,):super().__init__(policy=policy,env=env,learning_rate=learning_rate,policy_kwargs=policy_kwargs,verbose=verbose,device=device,use_sde=use_sde,sde_sample_freq=sde_sample_freq,support_multi_env=True,seed=seed,tensorboard_log=tensorboard_log,supported_action_spaces=supported_action_spaces,)self.n_steps = n_stepsself.gamma = gammaself.gae_lambda = gae_lambdaself.ent_coef = ent_coefself.vf_coef = vf_coefself.max_grad_norm = max_grad_normself.rollout_buffer = None# 调用基类的_setup_model()模型if _init_setup_model:self._setup_model()def _setup_model(self) -> None:# 初始化学习率,让他可以调用self._setup_lr_schedule()# 设置随机数种子self.set_random_seed(self.seed)# 设置经验池子的类,如果观测空间是spaces.Dict类那么就赋值DictRolloutBuffer# 如果观测空间不是spaces.Dict类那么就赋值RolloutBufferbuffer_cls = DictRolloutBuffer if isinstance(self.observation_space, spaces.Dict) else RolloutBuffer# 根据类初始化实例经验池子# 初始化经验池子的是时候将设备信息、折扣率、GAE超参数和环境的数量也传进去了self.rollout_buffer = buffer_cls(self.n_steps,self.observation_space,self.action_space,device=self.device,gamma=self.gamma,gae_lambda=self.gae_lambda,n_envs=self.n_envs,)# 初始化策略,直接输入状态空间、动作空间、可调用的学习率、是否使用状态独立性探索,以及自己制定策略# 的时候自己家的模型的参数和激活函数self.policy = self.policy_class(  # pytype:disable=not-instantiableself.observation_space,self.action_space,self.lr_schedule,use_sde=self.use_sde,**self.policy_kwargs  # pytype:disable=not-instantiable)# 将策略放到GPU/CPU中self.policy = self.policy.to(self.device)def collect_rollouts(self,env: VecEnv,callback: BaseCallback,rollout_buffer: RolloutBuffer,n_rollout_steps: int,) -> bool:# 收集环境交互数据# 这个方法使用当前的策略并将交互历史填充到RolloutBuffer经验池子中# rollout的意思是无模型的概念,而不是有模型的RL或规划里面的rollout的概念# env 用于训练的环境# callback 在每个时间步都会调用的回调函数# rollout_buffer 将收集的经验放置到rollout_buffer中# 在每个环境中需要收集的条数# 返回值是True:如果rollout_buffer收集了这么多的经验;返回值是False:如果回调函数提前终止了# 这个rollouts。"""Collect experiences using the current policy and fill a ``RolloutBuffer``.The term rollout here refers to the model-free notion and should notbe used with the concept of rollout used in model-based RL or planning.:param env: The training environment:param callback: Callback that will be called at each step(and at the beginning and end of the rollout):param rollout_buffer: Buffer to fill with rollouts:param n_rollout_steps: Number of experiences to collect per environment:return: True if function returned with at least `n_rollout_steps`collected, False if callback terminated rollout prematurely."""assert self._last_obs is not None, "No previous observation was provided"# 将策略转变到评估模式# Switch to eval mode (this affects batch norm / dropout)self.policy.set_training_mode(False)# 重置经验池子,如果使用状态独立性探索,那么就重置策略的噪声n_steps = 0rollout_buffer.reset()# Sample new weights for the state dependent explorationif self.use_sde:self.policy.reset_noise(env.num_envs)# 回调函数执行on_rollout_start()命令,跳转定义时候没有看到具体定义callback.on_rollout_start()while n_steps < n_rollout_steps:# 如果使用了状态独立性探索,并且达到了探索频率的节点,那么就重置策略的噪声if self.use_sde and self.sde_sample_freq > 0 and n_steps % self.sde_sample_freq == 0:# Sample a new noise matrixself.policy.reset_noise(env.num_envs)# 在断开梯度的情况下,转换观测数据到tensor张量内,然后输入到策略中输出动作、价值和对数概率# 最后再将动作数据转移到numpy中with th.no_grad():# Convert to pytorch tensor or to TensorDictobs_tensor = obs_as_tensor(self._last_obs, self.device)actions, values, log_probs = self.policy(obs_tensor)actions = actions.cpu().numpy()# Rescale and perform action# 归一化动作信息,限制在动作空间的上下界clipped_actions = actions# Clip the actions to avoid out of bound errorif isinstance(self.action_space, spaces.Box):clipped_actions = np.clip(actions, self.action_space.low, self.action_space.high)# 将动作信息输入到环境中,输出新的观测、奖励数值、是否完成以及其他信息。new_obs, rewards, dones, infos = env.step(clipped_actions)# 处理回调函数和更新经验池子self.num_timesteps += env.num_envs# Give access to local variablescallback.update_locals(locals())if callback.on_step() is False:return Falseself._update_info_buffer(infos)n_steps += 1# 如果动作空间是离散空间的话,那么就转变成一个列向量if isinstance(self.action_space, spaces.Discrete):# Reshape in case of discrete actionactions = actions.reshape(-1, 1)# 判断数据是否是终止的# 终止之后计算累计奖励# Handle timeout by bootstraping with value function# see GitHub issue #633for idx, done in enumerate(dones):if (doneand infos[idx].get("terminal_observation") is not Noneand infos[idx].get("TimeLimit.truncated", False)):terminal_obs = self.policy.obs_to_tensor(infos[idx]["terminal_observation"])[0]with th.no_grad():terminal_value = self.policy.predict_values(terminal_obs)[0]rewards[idx] += self.gamma * terminal_value# 经验池子输入的是上一个状态、动作、奖励、上一个回合的开始状态、价值列表以及对数概率rollout_buffer.add(self._last_obs, actions, rewards, self._last_episode_starts, values, log_probs)self._last_obs = new_obsself._last_episode_starts = dones# 计算下一个状态的价值with th.no_grad():# Compute value for the last timestepvalues = self.policy.predict_values(obs_as_tensor(new_obs, self.device))# 计算回报和优势rollout_buffer.compute_returns_and_advantage(last_values=values, dones=dones)callback.on_rollout_end()return Truedef train(self) -> None:# 这个是父类的方法# 在子类的实际PPO类中做了重写"""Consume current rollout data and update policy parameters.Implemented by individual algorithms."""raise NotImplementedErrordef learn(self: SelfOnPolicyAlgorithm,total_timesteps: int,callback: MaybeCallback = None,log_interval: int = 1,tb_log_name: str = "OnPolicyAlgorithm",reset_num_timesteps: bool = True,progress_bar: bool = False,) -> SelfOnPolicyAlgorithm:iteration = 0# 初始化模型total_timesteps, callback = self._setup_learn(total_timesteps,callback,reset_num_timesteps,tb_log_name,progress_bar,)callback.on_training_start(locals(), globals())while self.num_timesteps < total_timesteps:# 这里开始执行上面的函数,在环境中收集数据,收集完了就继续训练# 如果出了故障了,就在接下来跳出循环continue_training = self.collect_rollouts(self.env, callback, self.rollout_buffer, n_rollout_steps=self.n_steps)if continue_training is False:break# 跌带次数+1,并根据当前的训练次数更新学习率iteration += 1self._update_current_progress_remaining(self.num_timesteps, total_timesteps)# 在控制台按照预先定义的频率输出相关信息# Display training infosif log_interval is not None and iteration % log_interval == 0:time_elapsed = max((time.time_ns() - self.start_time) / 1e9, sys.float_info.epsilon)fps = int((self.num_timesteps - self._num_timesteps_at_start) / time_elapsed)self.logger.record("time/iterations", iteration, exclude="tensorboard")if len(self.ep_info_buffer) > 0 and len(self.ep_info_buffer[0]) > 0:self.logger.record("rollout/ep_rew_mean", safe_mean([ep_info["r"] for ep_info in self.ep_info_buffer]))self.logger.record("rollout/ep_len_mean", safe_mean([ep_info["l"] for ep_info in self.ep_info_buffer]))self.logger.record("time/fps", fps)self.logger.record("time/time_elapsed", int(time_elapsed), exclude="tensorboard")self.logger.record("time/total_timesteps", self.num_timesteps, exclude="tensorboard")self.logger.dump(step=self.num_timesteps)self.train()callback.on_training_end()return selfdef _get_torch_save_params(self) -> Tuple[List[str], List[str]]:state_dicts = ["policy", "policy.optimizer"]return state_dicts, []

相关文章:

Stable-Baselines 3 部分源代码解读 2 on_policy_algorithm.py

Stable-Baselines 3 部分源代码解读 ./common/on_policy_algorithm.py 前言 阅读PPO相关的源码&#xff0c;了解一下标准库是如何建立PPO算法以及各种tricks的&#xff0c;以便于自己的复现。 在Pycharm里面一直跳转&#xff0c;可以看到PPO类是最终继承于基类&#xff0c;也…...

15. Qt中OPenGL的参数传递问题

1. 说明 在OPenGL中&#xff0c;需要使用GLSL语言来编写着色器的函数&#xff0c;在顶点着色器和片段着色器之间需要参数值的传递&#xff0c;且在CPU中的数据也需要传递到顶点着色器中进行使用。本文简单介绍几种参数传递的方式&#xff1a; &#xff08;本文内容仅个人理解&…...

注意,这本2区SCI期刊最快18天录用,还差一步录用只因犯了这个错

发表案例分享&#xff1a; 2区医学综合类SCI&#xff0c;仅18天录用&#xff0c;录用后28天见刊 2023.02.10 | 见刊 2023.01.13 | Accepted 2023.01.11 | 提交返修稿 2022.12.26 | 提交论文至期刊部系统 录用截图来源&#xff1a;期刊部投稿系统 见刊截图来源&#xff1a…...

Could not find resource jdbc.properties问题的解决

以如下开头的内容&#xff1a; Exception in thread "main" org.apache.ibatis.exceptions.PersistenceException: ### Error building SqlSession. ### The error may exist in SQL Mapper Configuration 出现以上问题是没有在src/main/resources下创建jdbc.prop…...

【面试题】==与equals区别、Hashcode作用、hashcode相同equals()也一定为true吗?泛型特点与好处

文章目录1. 和 equals 的区别是什么&#xff1f;2.Hashcode的作用3. 两个对象的hashCode() 相同&#xff0c; 那么equals()也一定为 true吗&#xff1f;4.泛型常用特点5.使用泛型的好处&#xff1f;1. 和 equals 的区别是什么&#xff1f; “” 对于基本类型和引用类型 的作…...

Flex布局中的flex属性

1.flex-grow&#xff0c;flex-shrink&#xff0c;flex-basis取值含义 flex-grow&#xff1a; 延申性描述。在满足“延申条件”时&#xff0c;flex容器中的项目会按照设置的flex-grow值的比例来延申&#xff0c;占满容器剩余空间。 取值情况&#xff1a; 取负值无效。取0值表示不…...

SpringBoot + Ant Design Pro Vue实现动态路由和菜单的前后端分离框架

Ant Design Pro Vue默认路由和菜单配置是采用中心化的方式&#xff0c;在 router.config.js统一配置和管理&#xff0c;同时也提供了动态获取路由和菜单的解决方案&#xff0c;并将在2.0.3版本中提供&#xff0c;因到目前为止&#xff0c;官方发布的版本为2.0.2&#xff0c;所以…...

robotframework自动化测试环境搭建

环境说明 win10 python版本&#xff1a;3.8.3rc1 安装清单 安装配置 selenium安装 首先检查pip命令是否安装&#xff1a; C:\Users\name>pipUsage:pip <command> [options]Commands:install Install packages.download Do…...

尚硅谷《Redis7》(小白篇)

尚硅谷《Redis7 》&#xff08;小白篇&#xff09; 02 redis 是什么 官方网站&#xff1a; https://redis.io/ 作者 Git Hub https://github.com/antirez 03 04 05 能做什么 06 去哪下 Download https://redis.io/download/ redis中文文档 https://www.redis.com.cn/docu…...

并非从0开始的c++ day6

并非从0开始的c day6二级指针练习-文件读写位运算位逻辑运算符按位取反 ~位于&#xff08;AND&#xff09;&#xff1a;&位或&#xff08;OR&#xff09;&#xff1a; |位异或: ^移位运算符左移<<右移>>多维数组一维数组数组名一维数组名传入到函数参数中数组指…...

PMP考前冲刺2.22 | 2023新征程,一举拿证

承载2023新一年的好运让我们迈向PMP终点一起冲刺&#xff01;一起拿证&#xff01;每日5道PMP习题助大家上岸PMP&#xff01;&#xff01;&#xff01;题目1-2&#xff1a;1.在新产品开发过程中&#xff0c;项目经理关注到行业排名第一的公司刚刚发布同类型的产品。相比竞品&am…...

RxJava的订阅过程

要使用Rxjava首先要导入两个包&#xff0c;其中rxandroid是rxjava在android中的扩展 implementation io.reactivex:rxandroid:1.2.1implementation io.reactivex:rxjava:1.2.0首先从最基本的Observable的创建到订阅开始分析 Observable.create(new Observable.OnSubscribe<S…...

【2.22】MySQL、Redis、动态规划

认识Redis Redis是一种基于内存的数据库&#xff0c;对数据的读写操作都是在内存中完成的&#xff0c;因此读写速度非常快&#xff0c;常用于缓存&#xff0c;消息队列&#xff0c;分布式锁等场景。 Redis提供了多种数据类型来支持不同的业务场景&#xff0c;比如String(字符串…...

2年手动测试,裸辞后找不到工作怎么办?

我们可以从以下几个方面来具体分析下&#xff0c;想通了&#xff0c;理解透了&#xff0c;才能更好的利用资源提升自己。一、我会什么&#xff1f;先说第一个我会什么&#xff1f;第一反应&#xff1a;我只会功能测试&#xff0c;在之前的4年的中我只做了功能测试。内心存在一种…...

Leetcode6. N字形变换

一、题目描述&#xff1a; 将一个给定字符串 s 根据给定的行数 numRows &#xff0c;以从上往下、从左到右进行 Z 字形排列。 比如输入字符串为 “PAYPALISHIRING” 行数为 3 时&#xff0c;排列如下&#xff1a; 之后&#xff0c;你的输出需要从左往右逐行读取&#xff0c;产…...

将Nginx 核心知识点扒了个底朝天(十)

ngx_http_upstream_module的作用是什么? ngx_http_upstream_module用于定义可通过fastcgi传递、proxy传递、uwsgi传递、memcached传递和scgi传递指令来引用的服务器组。 什么是C10K问题? C10K问题是指无法同时处理大量客户端(10,000)的网络套接字。 Nginx是否支持将请求压…...

GPU显卡环境配置安装

前言 最近公司购买了一张RTX3090的显卡和一台新的服务器&#xff0c;然后对机器的GPU环境进行了安装和配置&#xff0c;然后简单记录一下 环境版本 操作系统&#xff1a;Centos7.8 显卡型号&#xff1a;RTX3090 Python版本&#xff1a;3.7.6 Tensorflow版本&#xff1a;2…...

CIMCAI super unmanned intelligent gate container damage detect

世界港航人工智能领军者企业CIMCAI中集飞瞳打造全球最先进超级智能闸口无人闸口ceaspectusG™视频流动态感知集装箱箱况残损检测箱况残损识别率99%以上&#xff0c;箱信息识别率99.95%以上World port shipping AI leader CIMCAIThe worlds most advanced super intelligent gat…...

web概念概述

软件架构&#xff1a;1. C/S: Client/Server 客户端/服务器端* 在用户本地有一个客户端程序&#xff0c;在远程有一个服务器端程序* 如&#xff1a;QQ&#xff0c;迅雷...* 优点&#xff1a;1. 用户体验好* 缺点&#xff1a;1. 开发、安装&#xff0c;部署&#xff0c;维护 麻烦…...

编译原理笔记(1)绪论

文章目录1.什么是编译2.编译系统的结构3.词法分析概述4.语法分析概述5.语义分析概述6.中间代码生成和后端概述1.什么是编译 编译的定义&#xff1a;将高级语言翻译成汇编语言或机器语言的过程。前者称为源语言&#xff0c;后者称为目标语言。 高级语言源程序的处理过程&#…...

MySQL(八)

服务器参数设置 general datadir/var/lib/mysql 数据文件存放的目录socket/var/lib/mysql/mysql.sock mysql.socket表示server和client在同一台服务器&#xff0c;并且使用localhost进行连接&#xff0c;就会使用socket进行连接pid_file/var/lib/mysql/mysql.pid 存储mysql的p…...

steam搬砖项目,小投入高回报,可放大操作,(内附教学资料)

我必须要说&#xff0c;steam搬砖项目就是全网门槛最低的副业&#xff0c;有手就行&#xff01; 本人90后底层员工一枚&#xff0c;新入csgo搬砖项目&#xff0c;轻松翻身 什么做抖音、海外问卷、直播卖货&#xff0c;电商等等对比我这个都是小钱。我这个方法是利用了大部分人…...

华为OD机试真题Python实现【最多提取子串数目】真题+解题思路+代码(20222023)

最多提取子串数目 题目 给定由 [a-z] 26 个英文小写字母组成的字符串 A 和 B,其中 A 中可能存在重复字母,B 中不会存在重复字母 现从字符串 A 中按规则挑选一些字母,可以组成字符串 B。 挑选规则如下: 1) 同一个位置的字母只能被挑选一次 2) 被挑选字母的相对先后顺序不…...

day32 多线程(上)

文章目录相关概念codeThreadTest01ThreadTest02 编写一个类&#xff0c;直接继承java.lang.Thread&#xff0c;重写run方法ThreadTest03 实现线程的第二种方法ThreadTest04 采用匿名内部类的方式ThreadTest05 获取线程名字ThreadTest06 sleep方法sleep面试题ThreadTest08 终止线…...

【flink】 各种join类型对比

表定义 动态表(dynamic table)&#xff1a;动态表是流的另一种表达方式&#xff0c;动态表作为一个逻辑的抽象概念&#xff0c;使我们更容易理解flink中将streaming发展到table这个层次的设计&#xff0c;本质都是对无边界、持续变更数据的表示形式&#xff0c;所以动态表与流之…...

常用正则表达式

一、校验数字的表达式 数字&#xff1a;^[0-9]*$ n位的数字&#xff1a;^\d{n}$ 至少n位的数字&#xff1a;^\d{n,}$ m-n位的数字&#xff1a;^\d{m,n}$ 零和非零开头的数字&#xff1a;^(0|[1-9][0-9]*)$ 非零开头的最多带两位小数的数字&#xff1a;^([1-9][0-9]*)(.[0…...

PMP考试有没有什么技巧可以介绍一下么?

一、试题形式 ——中英文对照 即每道题都是一遍英文&#xff0c;一遍翻译的中文&#xff0c;在审题的时候有一些小的技巧需要注意。首先如果你的英文水平足够好&#xff0c;建议直接阅读原文。PMP试题毕竟是美国人出的&#xff0c;语言的组织、思想的表达&#xff0c;肯定更符…...

2022-2023年营销报告(B站平台) | 5大行业势态、流量大盘全景洞察

一直以来&#xff0c;手持高活跃、高粘性用户群体的B站是行业用来观察年轻人消费习惯的重要平台。以至于用户群体的不断壮大带动了B站的商业价值。如今B站的商业舞台越来越大&#xff0c;不断地向外界招手&#xff0c;欢迎更多品牌积极加入到这个千万年轻人聚集的内容社区。为了…...

Python的异常与工具包

异常 当检测到一个错误时&#xff0c;python解释器就无法继续执行了&#xff0c;反而出现了一些错误的提示&#xff0c;这就是所谓的异常。 捕获异常 世界上没有完美的程序&#xff0c;任何程序在运行的过程中&#xff0c;都有可能出现异常&#xff0c;导致程序无法完美运行…...

基于SSM的婴幼儿商城

基于SSM的婴幼儿商城 ✌全网粉丝20W,csdn特邀作者、博客专家、CSDN新星计划导师、java领域优质创作者,博客之星、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ &#x1f345;文末获取项目下载方式&#x1f345; 一、项目背景介绍&#xff1a; …...