【深度强化学习】(2) Double DQN 模型解析,附Pytorch完整代码
大家好,今天和大家分享一个深度强化学习算法 DQN 的改进版 Double DQN,并基于 OpenAI 的 gym 环境库完成一个小游戏,完整代码可以从我的 GitHub 中获得:
https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model
1. 算法原理
1.1 DQN 原理回顾
DQN 算法的原理是指导机器人不断与环境交互,理解最佳的行为方式,最终学习到最优的行为策略,机器人与环境的交互过程如下图所示。
机器人与环境的交互过程是机器人在 时刻,采取动作
并作用于环境,然后环境从
时刻状态
转变到
时刻状态
,同时奖励函数对
进行评价得到奖励值
。机器人根据
不断优化行为轨迹,最终学习到最优的行为策略。
整个强化学习过程可简化为马尔可夫决策过程(MDP)。其中所有状态均具备马尔可夫性。MDP 可用一个五元组 表示,各元素意义如下:
(1)S 是有限状态集合,即机器人在环境中探索到的所有可能状态。 表示机器人在 t 时刻的状态。
(2)A 是有限动作集合,即智能体根据 采取的所有可能动作集合。
表示机器人在t 时刻状态采取的行为
(3)P 是状态转移概率,定义如下:
(4)R 是奖励函数,即机器人基于 采取
后获得的期望奖励,定义如下:
(5)γ 代表折扣因子,值域 [0,1],即未来的期望奖励在当前时刻的价值比例。
MDP 中,价值函数包括状态价值函数和动作价值函数。动作价值函数(Q 函数)表示在策略 的指导下,根据状态
,采取行为
所获得的期望回报。策略
表示状态到行为的映射,相当于机器人的决策策略,并根据不同的状态选择不同的行为,即
。机器人在策略
的指导下,Q 函数的定义如下:
式中 表示折扣奖励,定义如下:
式中 随训练过程迭代减小,
越小表示未来的奖励对当前时刻的奖励影响越小。
Q 函数的贝尔曼方程表示如下:
式中, 表示机器人在
时采取
所获得的即时奖励,等式右侧第二项表示机
器人执行策略 产生的未来累计奖励的期望。
通过选取最大动作价值函数求解最优行为策略的公式表示如下:
DQN 算法通过行为的奖励值构造算法训练的标签,并且其中的经验回放(Experience Replay)和目标网络有效的解决了数据相关性和非静态分布的问题。DQN 算法的结构示意图如下。
DQN 的网络结构由目标网络和估计网络组成,这两个网络的结构相同但参数不同。估计网络具有最新的网络参数,计算当前状态-动作对的价值,并定期更新目标网络的参数,使其计算目标 Q 值。双网络结构打破了数据之间的相关性,使DQN 学习不同的数据分布。经验回放部分储存了智能体(机器人)的历史行为信息,其中包括多组行为序列对 ,即当前时刻状态 s ,行为 a ,奖励 r 以及下一时刻状态 s'。当 DQN 算法更新时,随机从经验池中抽取部分行为序列对进行经验回放,这种方式解决了经验池中数据相关性强和非静态分布导致的模型泛化能力差的问题。
DQN 算法通过贪婪法直接获得目标 Q 值,贪婪法通过最大化方式使 Q 值快速向可能的优化目标收敛,但易导致过估计Q 值的问题,使模型具有较大的偏差。DQN 算法过估计 Q 值的问题不适合机器人操作行为的研究,采用 Double DQN 算法解耦动作的选择和目标 Q 值的计算,以解决过估计 Q 值的问题。
1.2 Double DQN 原理
Double DQN 算法是 DQN 算法的改进版本,解决了 DQN 算法过估计行为价值的问题。DQN 算法中,某一时刻状态为非终止状态时,目标 Q 值的计算公式如下所示:
Double DQN 算法不直接通过最大化的方式选取目标网络计算的所有可能 Q 值,而是首先通过估计网络选取最大 Q 值对应的动作,公式表示如下:
然后目标网络根据 计算目标 Q 值,公式表示如下:
最后将上面两个公式结合,目标 Q 值的最终表示形式如下:
目标是最小化目标函数,即最小化估计 Q 值和目标 Q 值的差值,公式如下:
结合目标函数,损失函数定义如下:
Double DQN 的伪代码:
Double DQN 算法结构如下。在 Double DQN 框架中存在两个神经网络模型,分别是训练网络与目标网络。这两个神经网络模型的结构完全相同,但是权重参数不同;每训练一段之间后,训练网络的权重参数才会复制给目标网络。训练时,训练网络用于估计当前的 ,而目标网络用于估计
,这样就能保证真实值
的估计不会随着训练网络的不断自更新而变化过快。此外,DQN 还是一种支持离线学习的框架,即通过构建经验池的方式离线学习过去的经验。将均方误差
作为训练模型的损失函数,通过梯度下降法进行反向传播,对训练模型进行更新;若干轮经验池采样后,再将训练模型的权重赋给目标模型,以此进行 Double DQN 框架下的模型自学习。
2. 代码实现
模型构建部分的代码如下:
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
import collections # 队列
import random# ----------------------------------- #
#(1)经验回放池
# ----------------------------------- #class ReplayBuffer:def __init__(self, capacity):# 创建一个队列,先进先出,队列长度不变self.buffer = collections.deque(maxlen=capacity)# 填充经验池def add(self, state, action, reward, next_state, done):self.buffer.append((state, action, reward, next_state, done))# 随机采样batch组样本数据def sample(self, batch_size):transitions = random.sample(self.buffer, batch_size)# 分别取出这些数据,*获取list中的所有值state, action, reward, next_state, done = zip(*transitions)# 将state变成数组,后面方便计算return np.array(state), action, reward, np.array(next_state), done# 队列的长度def size(self):return len(self.buffer)# ----------------------------------- #
#(2)构造网络,训练网络和目标网络共用该结构
# ----------------------------------- #class Net(nn.Module):def __init__(self, n_states, n_hiddens, n_actions):super(Net, self).__init__()# 只有一个隐含层self.fc1 = nn.Linear(n_states, n_hiddens)self.fc2 = nn.Linear(n_hiddens, n_actions)# 前向传播def forward(self, x):x = self.fc1(x) # [b,n_states]-->[b,n_hiddens]x = self.fc2(x) # [b,n_hiddens]-->[b,n_actions]return x# ----------------------------------- #
#(3)模型构建
# ----------------------------------- #class Double_DQN:#(1)初始化def __init__(self, n_states, n_hiddens, n_actions,learning_rate, gamma, epsilon,target_update, device):# 属性分配self.n_states = n_statesself.n_hiddens = n_hiddensself.n_actions = n_actionsself.learning_rate = learning_rateself.gamma = gammaself.epsilon = epsilonself.target_update = target_updateself.device = device# 记录迭代次数self.count = 0# 实例化训练网络self.q_net = Net(self.n_states, self.n_hiddens, self.n_actions)# 实例化目标网络self.target_q_net = Net(self.n_states, self.n_hiddens, self.n_actions)# 优化器,更新训练网络的参数self.optimizer = torch.optim.Adam(self.q_net.parameters(), lr=self.learning_rate)#(2)动作选择def take_action(self, state):# numpy[n_states]-->[1, n_states]-->Tensorstate = torch.Tensor(state[np.newaxis, :])# print('--------------------------')# print(state.shape)# 如果小于贪婪系数就取最大值reward最大的动作if np.random.random() < self.epsilon:# 获取当前状态下采取各动作的rewardactions_value = self.q_net(state)# 获取reward最大值对应的动作索引action = actions_value.argmax().item()# 如果大于贪婪系数就随即探索else:action = np.random.randint(self.n_actions)return action#(3)获取每个状态对应的最大的state_valuedef max_q_value(self, state):# list-->tensor[3]-->[1,3]state = torch.tensor(state, dtype=torch.float).view(1,-1)# 当前状态对应的每个动作的reward的最大值 [1,3]-->[1,11]-->intmax_q = self.q_net(state).max().item()return max_q#(4)网络训练def update(self, transitions_dict):# 当前状态,array_shape=[b,4]states = torch.tensor(transitions_dict['states'], dtype=torch.float)# 当前状态的动作,tuple_shape=[b]==>[b,1]actions = torch.tensor(transitions_dict['actions'], dtype=torch.int64).view(-1,1)# 选择当前动作的奖励, tuple_shape=[b]==>[b,1]rewards = torch.tensor(transitions_dict['rewards'], dtype=torch.float).view(-1,1)# 下一个时刻的状态array_shape=[b,4]next_states = torch.tensor(transitions_dict['next_states'], dtype=torch.float)# 是否到达目标 tuple_shape=[b,1]dones = torch.tensor(transitions_dict['dones'], dtype=torch.float).view(-1,1)# 当前状态[b,4]-->当前状态采取的动作及其奖励[b,2]-->actions中是每个状态下的动作索引# -->当前状态s下采取动作a得到的state_valueq_values = self.q_net(states).gather(1, actions)# 获取动作索引# .max(1)输出tuple每个特征的最大state_value及其索引,[1]获取的每个特征的动作索引shape=[b]max_action = self.q_net(next_states).max(1)[1].view(-1,1)# 下个状态的state_value。下一时刻的状态输入到目标网络,得到每个动作对应的奖励,使用训练出来的action索引选取最优动作max_next_q_values = self.target_q_net(next_states).gather(1, max_action)# 目标网络计算出的,当前状态的state_valueq_targets = rewards + self.gamma * max_next_q_values * (1-dones)# 预测值和目标值的均方误差损失dqn_loss = torch.mean(F.mse_loss(q_values, q_targets))# 梯度清零self.optimizer.zero_grad()# 梯度反传dqn_loss.backward()# 更新训练网络的参数self.optimizer.step()# 更新目标网络参数if self.count % self.target_update == 0:self.target_q_net.load_state_dict(self.q_net.state_dict()) # 更新目标网络# 迭代计数+1self.count += 1
3. 案例实现
我们使用 OpenAI 中的重力摆来验证模型,动作是连续型,代表力矩;状态包含三个;目的是让杆子竖直。
环境交互和训练代码如下:
import torch
import numpy as np
import gym
from tqdm import tqdm
import matplotlib.pyplot as plt
from parsers import args
from RL_brain import ReplayBuffer, Double_DQN# GPU运算
device = torch.device("cuda") if torch.cuda.is_available() \else torch.device("cpu")# ------------------------------- #
#(1)加载环境
# ------------------------------- #env = gym.make("Pendulum-v1", render_mode="human")
n_states = env.observation_space.shape[0] # 状态数 3
act_low = env.action_space.low # 最小动作力矩 -2
act_high = env.action_space.high # 最大动作力矩 +2
n_actions = 11 # 动作是连续的[-2,2],将其离散成11个动作# 确定离散动作区间后,确定其连续动作
def dis_to_con(discrete_action, n_actions):# discrete_action代表动作索引return act_low + (act_high-act_low) * (discrete_action/(n_actions-1))# 实例化经验池
replay_buffer = ReplayBuffer(args.capacity)# 实例化 Double-DQN
agent = Double_DQN(n_states,args.n_hiddens,n_actions,args.lr,args.gamma,args.epsilon,args.target_update,device)# ------------------------------- #
#(2)模型训练
# ------------------------------- #return_list = [] # 记录每次迭代的return,即链上的reward之和
max_q_value = 0 # 最大state_value
max_q_value_list = [] # 保存所有最大的state_valuefor i in range(10): # 训练几个回合done = False # 初始,未到达终点state = env.reset()[0] # 重置环境episode_return = 0 # 记录每回合的returnwith tqdm(total=10, desc='Iteration %d' % i) as pbar:while True:# 状态state时做动作选择,返回索引action = agent.take_action(state)# 平滑处理最大state_valuemax_q_value = agent.max_q_value(state) * 0.005 + \max_q_value * 0.995# 保存每次迭代的最大state_valuemax_q_value_list.append(max_q_value)# 将action的离散索引连续化action_continuous = dis_to_con(action, n_actions)# 环境更新next_state, reward, done, _, _ = env.step(action_continuous)# 添加经验池replay_buffer.add(state, action, reward, next_state, done)# 更新状态state = next_state# 更新每回合的回报episode_return += reward# 如果经验池超数量过阈值时开始训练if replay_buffer.size() > args.min_size:# 在经验池中随机抽样batch组数据s, a, r, ns, d = replay_buffer.sample(args.batch_size)# 构造训练集transitions_dict = {'states': s,'actions': a,'next_states': ns,'rewards': r,'dones': d,}# 模型训练agent.update(transitions_dict)# 到达终点就停止if done is True: break# 保存每回合的returnreturn_list.append(episode_return)pbar.set_postfix({'step':agent.count,'return':'%.3f' % np.mean(return_list[-10:])})pbar.update(1)# ------------------------------- #
#(3)绘图
# ------------------------------- #plt.subplot(121)
plt.plot(return_list)
plt.title('return')
plt.subplot(122)
plt.plot(max_q_value_list)
plt.title('max_q_value')
plt.show()
相关文章:

【深度强化学习】(2) Double DQN 模型解析,附Pytorch完整代码
大家好,今天和大家分享一个深度强化学习算法 DQN 的改进版 Double DQN,并基于 OpenAI 的 gym 环境库完成一个小游戏,完整代码可以从我的 GitHub 中获得: https://github.com/LiSir-HIT/Reinforcement-Learning/tree/main/Model 1…...
【正则表达式】正则表达式语法规则
正则表达式语法规则1.普通字符 字符描述[ABC]匹配 […] 中的所有字符[^ABC]匹配除了 […] 中字符的所有字符[A-Z][A-Z] 表示一个区间,匹配所有大写字母,[a-z] 表示所有小写字母.匹配除换行符以外的任意字符[\s\S]匹配所有。\s 是匹配所有空白符…...

1636_isatty函数的功能
全部学习汇总: GreyZhang/g_unix: some basic learning about unix operating system. (github.com) 前面刚刚看完了一个函数和三个文件指针,一行代码懂了半行。但是继续分析我之前看到的代码还是遇到了困难,因为之前自己对于UNIX的一些基础知…...

基于Stackelberg博弈的光伏用户群优化定价模型(Matlab代码实现)
💥💥💞💞欢迎来到本博客❤️❤️💥💥 🏆博主优势:🌞🌞🌞博客内容尽量做到思维缜密,逻辑清晰,为了方便读者。 ⛳️座右铭&a…...

EXCEL职业版本(3)
Excel职业版本(3) 公式与函数 运算符 算数运算符 关系运算符 地址的引用 相对引用:你变它就变,如影随形 A2:A5 绝对引用:以不变应万变 $A$2 混合引用:识时务者为俊杰,根据时…...

查找Pycharm跑代码下载模型存放位置以及有关模型下载小技巧(model_name_or_path参数)
目录一、前言二、发现问题三、删除这些模型方法一:直接删除注意方法二:代码删除一、前言 当服务器连不上,只能在本地跑代码时需要使用***预训练语言模型进行处理 免不了需要把模型下载到本地 时间一长就会发现C盘容量不够 二、发现问题 正…...
JS学习笔记day04
今日内容 零、 复习昨日 一、事件 二、DOM操作 三、案例 零、 复习昨日 js 脚本语言,弱类型 引入方案: 3种 js的内容: 语法dombom 语法 变量 var 数据类型 引用类型 - 对象,JSON {key:value,key:value} 数组 var arr new Array();var arr [1,2];下标取值赋值pop() s…...
异步控制流程 遍历篇
文章目录基础方法onlyOnce 只执行一次,第二次报错once 只执行一次,第二次无效iteratorSymbol 判断是否具有迭代器并返回迭代器arrayEach 普通数组遍历baseEach 对象类型遍历symbolEach 具有迭代器类型遍历异步遍历each异步控制流程的目的: 对…...

ICASSP 2023论文模型开源|语音分离Mossformer
人类能在复杂的多人说话环境中轻易地分离干扰声音,选择性聆听感兴趣的主讲人说话。但这对机器却不容易,如何构建一个能够媲美人类听觉系统的自动化系统颇具挑战性。 本文将详细解读ICASSP2023本届会议收录的单通道语音分离模型Mossformer论文࿰…...
vs2019 更改工程项目名称
本地 解决方案所在的位置为:D:\Projcet 解决方案名称:hello.sln 位置:D:\Projcet\hello.sln 工程项目名称:test 位置:D:\Projcet\test (文件夹中包含头文件,源文件) 工程包含的文件: fun.h …...

FusionCompute安装和配置步骤
1. 先去华为官网下载FusionCompute的镜像 下载地址:https://support.huawei.com/enterprise/zh/distributed-storage/fusioncompute-pid-8576912/software/251713663?idAbsPathfixnode01%7C22658044%7C7919788%7C9856606%7C21462752%7C8576912 下载后放在D盘中&am…...
makefile 参数和基本使用
make 常用选项make[-f file] [options] [target]make 默认在当前目录中查找GUNmakefile、makefile 及 Makefile 文件作为make的输入文件-f 指定文件作为输入文件-v 显示版本号-n 只输出命令不执行, 一般作为测试-s 执行命令不显示命令,-w 显示执行前和执…...

golang 占位符还傻傻分不清?
xdm ,写 C/C 语言的时候有格式控制符,例如 %s , %d , %c , %p 等等 在写 golang 的时候,也是有对应的格式控制符,也叫做占位符,写这个占位符,需要有对应的数据与之对应,不能瞎搞 基本常见常用…...

manacher算法详解
例题 求一个字符串的最长回文子串的长度 O(N2)O(N^2)O(N2)的解法很容易想,就是从每个字符位置向左右同时拓展,然后检查当前是不是回文,更新长度,可以简单写一下代码 int solve(string &ss){int ans 0;int n ss.length();s…...

要做一个关于DDD的内部技术分享,记录下用到的资源,学习笔记(未完)
最后更新于2023年3月10日 14:28:08 问题建模》软件分层》具体结构,是层层递进的关系。有了问题建模,才能进行具体的软件分层的讨论,再有了分层,才能讨论在domain里面应该怎么实现具体结构。 1、问题建模:Domain、Mod…...

KDZD互感器二次负载测试仪
一、概述 电能计量综合误差过大是电能计量中普遍存在的一个关键问题。电压互感器二次回路压降引起的计量误差往往是影响电能计量综合误差的因素。所谓电压互感器二次压降引起的误差,就是指电压互感器二次端子和负载端子之间电压的幅值差相对于二次实际电压的百分数…...

在空投之后,Blur能否颠覆OpenSea的主导地位?
Mar. 2023, Daniel数据源: NFT Aggregators Overview & Aggregator Statistics Overview & Blur Airdrop一年前,通过聚合器进行的NFT交易量开始像滚雪球一样增长,有时甚至超过了直接通过市场平台的交易量。虽然聚合器的使用量从10月到…...

2023年新三板产品及服务研究报告
第一章 概述 全国中小企业股份转让系统(英语:National Equities Exchange and Quotations,缩写NEEQ),简称股转系统,是第三家全国性证券交易场所,因挂牌企业均为高科技企业而不同于原转让系统内…...
张力控制之开环模式
张力控制的相关知识也可以参看专栏的其它文章,链接如下: 张力闭环控制之传感器篇(精密调节气阀应用)_RXXW_Dor的博客-CSDN博客跳舞轮对应张力调节范围,我们可以通过改变气缸的气压方式间接改变,张力跳舞轮在收放卷闭环控制上的详细应用,可以参看下面的文章链接,这里我…...
python的django框架从入门到熟练【保姆式教学】第二篇
在上一篇博客中,我们介绍了Django的基础知识,并创建了一个简单的Web应用程序。在本篇教程中,我们将深入探讨Django的模型层(Model),它是Django应用程序的核心组件之一。 模型层 Django的模型层是一个对象…...
ubuntu搭建nfs服务centos挂载访问
在Ubuntu上设置NFS服务器 在Ubuntu上,你可以使用apt包管理器来安装NFS服务器。打开终端并运行: sudo apt update sudo apt install nfs-kernel-server创建共享目录 创建一个目录用于共享,例如/shared: sudo mkdir /shared sud…...
Java如何权衡是使用无序的数组还是有序的数组
在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

1.3 VSCode安装与环境配置
进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...
实现弹窗随键盘上移居中
实现弹窗随键盘上移的核心思路 在Android中,可以通过监听键盘的显示和隐藏事件,动态调整弹窗的位置。关键点在于获取键盘高度,并计算剩余屏幕空间以重新定位弹窗。 // 在Activity或Fragment中设置键盘监听 val rootView findViewById<V…...

QT: `long long` 类型转换为 `QString` 2025.6.5
在 Qt 中,将 long long 类型转换为 QString 可以通过以下两种常用方法实现: 方法 1:使用 QString::number() 直接调用 QString 的静态方法 number(),将数值转换为字符串: long long value 1234567890123456789LL; …...

Mysql中select查询语句的执行过程
目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析(Parser) 2.4、执行sql 1. 预处理(Preprocessor) 2. 查询优化器(Optimizer) 3. 执行器…...
省略号和可变参数模板
本文主要介绍如何展开可变参数的参数包 1.C语言的va_list展开可变参数 #include <iostream> #include <cstdarg>void printNumbers(int count, ...) {// 声明va_list类型的变量va_list args;// 使用va_start将可变参数写入变量argsva_start(args, count);for (in…...
jmeter聚合报告中参数详解
sample、average、min、max、90%line、95%line,99%line、Error错误率、吞吐量Thoughput、KB/sec每秒传输的数据量 sample(样本数) 表示测试中发送的请求数量,即测试执行了多少次请求。 单位,以个或者次数表示。 示例:…...
Python 训练营打卡 Day 47
注意力热力图可视化 在day 46代码的基础上,对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...
软件工程教学评价
王海林老师您好。 您的《软件工程》课程成功地将宏观的理论与具体的实践相结合。上半学期的理论教学中,您通过丰富的实例,将“高内聚低耦合”、SOLID原则等抽象概念解释得十分透彻,让这些理论不再是停留在纸面的名词,而是可以指导…...