【强化学习】强化学习数学基础:值函数近似
值函数近似
- Value Function Approximation
- Motivating examples: curve fitting
- Algorithm for state value estimation
- Objective function
- Optimization algorithms
- Selection of function approximators
- Illustrative examples
- Summary of the story
- Theoretical analysis
- Sarsa with function appriximation
- Q-learning with function approximation
- Deep Q-learning
- 内容来源
Value Function Approximation
Motivating examples: curve fitting
到目前为止,我们都是使用tables表示state和action values。例如,下表是action value的表示:
- 优势:直观且容易分析
- 劣势:难以处理较大或者连续的state或者action空间。两个方面:1)存储;2)泛化能力。
举个例子:假定有一个one-dimensional states s1,...,s∣S∣s_1,...,s_{|S|}s1,...,s∣S∣,当π\piπ是给定策略的时候,它们的state values是vπ(s1),...,vπ(s∣S∣)v_\pi(s_1),...,v_\pi(s_{|S|})vπ(s1),...,vπ(s∣S∣)。假设∣S∣|S|∣S∣非常大,因此我们希望用一个简单的曲线近似它们的点以降低内存:
答案是可以的。
首先我们使用简单的straight line去拟合这些点。假设straight line的方程为
其中:
- www是参数向量(parameter vector)
- ϕ(s)\phi(s)ϕ(s)是s的特征向量(feature vector)
- v^(s,w)\hat{v}(s,w)v^(s,w)与www成线性关系(当然,也可以是非线性的)
这样表示的好处是:
- 表格形式需要存储∣S∣|S|∣S∣个state values,现在,只需要存储两个参数aaa和bbb
- 每次我们想要使用s的值,我们可以计算ϕT(s)w\phi^T(s)wϕT(s)w。
- 但是这个好处也不是免费的,它需要付出一些代价:state values不能被精确地表示,这也是为什么这个方法被称为value approximation。
既然直线不够准确,那么是否可以使用高阶的曲线呢?当然可以。第二,我们使用一个second-order curve去拟合这些点:
在这种情况下:
- www和ϕ(s)\phi(s)ϕ(s)的维数增加了,但是values可以被拟合的更加精确。
- 尽管v^(s,w)\hat{v}(s,w)v^(s,w)与sss是非线性的,但是它与www是线性的。这种非线性的性质包含在ϕ(s)\phi(s)ϕ(s)中。
当然,还可以继续增加阶数。第三,使用一个更加high-order polynomial curves(多项式曲线)或者其他复杂的曲线来拟合这些点:
- 好处是:更好的approximate
- 坏处是:需要更多的parameters
小结一下:
- Idea:value function approximation的idea是用一个函数v^(s,w)\hat{v}(s, w)v^(s,w)来拟合vπ(s)v_\pi(s)vπ(s),这个函数里边有参数www,所以被称为parameterized function,www就是parameter vector。
- 这样做的好处:
- 1)节省存储:www的维数远小于∣S∣|S|∣S∣
- 2)泛化能力:当一个state sss是visited,参数www是updated,这样某些其他unvisited states的values也可以被updated。按这种方式,the learned values可以泛化到unvisited states。
Algorithm for state value estimation
Objective function
首先,用一种更正式的方式:
- 令vπ(s)v_\pi(s)vπ(s)和v^(s,w)\hat{v}(s,w)v^(s,w)分别表示true state value和approximate函数.
- 我们的目标是找到一个最优的www,使得v^(s,w)\hat{v}(s,w)v^(s,w)对于每个sss达到最优的近似vπ(s)v_\pi(s)vπ(s)
- 这个问题就是一个policy evaluation问题,稍后我们将会把它推广到policy improvement。
- 为了找到最优的www,我们需要两步:
- 第一步定义一个目标函数(object function)
- 第二步是优化这个目标函数。
The objective function is:J(w)=E[(vπ(S)−v^(S,w))2]J(w)=\mathbb{E}[(v_\pi(S)-\hat{v}(S,w))^2]J(w)=E[(vπ(S)−v^(S,w))2]
- 我们的目标是找到最优的www,这样可以最小化J(w)J(w)J(w)
- The expectation is with respect to the random variable S∈SS\in \mathcal{S}S∈S。SSS的概率分布是什么?
- This is often confusing because we have not discussed the probability distribution of states so far
- There are several ways to define the probability distribution of SSS.
第一种方式是使用一个uniform distribution.
- 它对待每个states都是同等的重要性,通过将每个state的概率设置为1/∣S∣1/|\mathcal{S}|1/∣S∣
- 这种情况下,目标函数变为:J(w)=E[(vπ(S)−v^(S,w))2]=1∣S∣∑s∈S(vπ(s)−v^(s,w))2J(w)=\mathbb{E}[(v_\pi (S)-\hat{v}(S,w))^2]=\frac{1}{|\mathcal{S}|}\sum_{s\in \mathcal{S}}(v_\pi(s)-\hat{v}(s,w))^2J(w)=E[(vπ(S)−v^(S,w))2]=∣S∣1s∈S∑(vπ(s)−v^(s,w))2
- 虽然平均分布是非常直观的,但是有一个问题:这里假设所有状态都是平等的,但是实际上可能不是那么回事。例如,某些状态在一个策略下可能几乎不会访问到。因此这种方式没有考虑一个给定策略下Markov process的实际动态变化。
第二种方式是使用stationary distribution
- Stationary distribution is an important concept. 它描述了一个Markov process的long-run behavior。
- 令{dπ(s)}s∈S\{d_\pi(s)\}_{s\in \mathcal{S} }{dπ(s)}s∈S表示基于策略π\piπ的Markov process的stationary distribution。根据定义有,dπ(s)≥0d_\pi(s)\ge 0dπ(s)≥0且∑s∈Sdπ(s)=1\sum_{s\in \mathcal{S}}d_\pi(s)=1∑s∈Sdπ(s)=1
- 在这种情况下,目标函数被重写为:J(w)=E[(vπ(S)−v^(S,w))2]=∑s∈Sdπ(s)(vπ(s)−v^(s,w))2J(w)=\mathbb{E}[(v_\pi (S)-\hat{v}(S,w))^2]=\sum_{s\in \mathcal{S}}d_\pi (s)(v_\pi(s)-\hat{v}(s,w))^2J(w)=E[(vπ(S)−v^(S,w))2]=s∈S∑dπ(s)(vπ(s)−v^(s,w))2这里的dπ(s)d_\pi(s)dπ(s)就扮演了权重的意思,这个函数是一个weighted squared error。
- 由于更频繁地visited states,具有更高的dπ(s)d_\pi(s)dπ(s)值,它们在目标函数中的权重也比那些很少访问的states的权重高。
对于stationary distribution更多的介绍:
- Distribution:state的Distribution
- Stationary : Long-run behavior
- Summary: 智能体agent根据一个策略运行一个较长时间之后,the probability that the agent is at any state can be described by this distribution.
需要强调的是:
- Stationary distribution 也被称为steady-state distribution,或者limiting distribution
- 它在理解value functional approximation method方面是非常重要的
- 对于policy gradient method也是非常重要的。
举个例子:如图所示,给定一个探索性的策略。让agent从一个状态出发然后跑很多次,根据这个策略,然后看一下会发生什么事情。
- 令nπ(s)n_\pi(s)nπ(s)表示次数,sss has been visited in a very long episode generated by π\piπ。
- 然后,dπ(s)d_\pi(s)dπ(s)可以由下式估计:dπ(s)≈nπ(s)∑s′∈Snπ(s′)d_\pi(s)\approx \frac{n_\pi(s)}{\sum_{s'\in \mathcal{S}}n_\pi(s') }dπ(s)≈∑s′∈Snπ(s′)nπ(s)
The converged values can be predicted because they are the entries of dπd_\pidπ:dπT=dπTPπd_\pi^T=d_\pi^TP_\pidπT=dπTPπ
对于上面的例子,有PπP_\piPπ:Pπ=[0.30.10.600.10.300.60.100.30.600.10.10.8]P_\pi=\begin{bmatrix}0.3 & 0.1 & 0.6 & 0\\0.1 & 0.3 & 0 & 0.6\\0.1 & 0 & 0.3 & 0.6\\0 & 0.1 & 0.1 & 0.8\end{bmatrix}Pπ=0.30.10.100.10.300.10.600.30.100.60.60.8可以计算出来它左边对应于eigenvalue等于1的那个eigenvector:dπ=[0.0345,0.1084,0.1330,0.7241]Td_\pi=[0.0345, 0.1084, 0.1330, 0.7241]^Tdπ=[0.0345,0.1084,0.1330,0.7241]T
Optimization algorithms
当我们有了目标函数,下一步就是优化它。为了最小化目标函数J(w)J(w)J(w),我们可以使用gradient-descent算法:wk+1=wk−αk∇wJ(wk)w_{k+1}=w_k-\alpha_k\nabla_w J(w_k)wk+1=wk−αk∇wJ(wk)它的true gradient是:
这个true gradient需要计算一个expectation。我们可以使用stochastic gradient替代the true gradient:wt+1=wt+αt(vπ(st)−v^(st,wt))∇wv^(st,wt)w_{t+1}=w_t+\alpha_t (v_\pi(s_t)-\hat{v}(s_t,w_t))\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt(vπ(st)−v^(st,wt))∇wv^(st,wt)其中sts_tst是S\mathcal{S}S的一个采样。这里2αk2\alpha_k2αk合并到了αk\alpha_kαk。
- 这个算法在实际当中是不能使用的,因为它需要true state value vπv_\pivπ,这是未知的。
- 可以使用vπ(st)v_\pi(s_t)vπ(st)的一个估计来替代它,这样该算法就可以实现了
那么如何进行代替呢?有两种方法:
- 第一种,Monte Carlo learning with function approximation
令gtg_tgt表示在episode中从sts_tst开始的discounted return,然后使用gtg_tgt近似vπ(st)v_\pi(s_t)vπ(st)。该算法变为wt+1=wt+αt(gt−v^(st,wt))∇wv^(st,wt)w_{t+1}=w_t+\alpha_t (g_t-\hat{v}(s_t,w_t))\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt(gt−v^(st,wt))∇wv^(st,wt) - 第二种,TD learning with function approximate
By the spirit of TD learning, rt+1+γv^(st+1,wt)r_{t+1}+\gamma \hat{v}(s_{t+1}, w_t)rt+1+γv^(st+1,wt)可以视为vπ(st)v_\pi(s_t)vπ(st)的一个近似。因此,算法变为:wt+1=wt+αt[rt+1+γv^(st+1,wt)]∇wv^(st,wt)w_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma \hat{v}(s_{t+1}, w_t)]\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt[rt+1+γv^(st+1,wt)]∇wv^(st,wt)
TD learning with function approximation的伪代码:
该方法仅能估计在给定policy情况下的state values,但是对于后面的算法的理解是非常重要的。
Selection of function approximators
如何选取函数v^(s,w)\hat{v}(s,w)v^(s,w)?
- 第一种方法,也是之前被广泛使用的,就是linear functionv^(s,w)=ϕT(s)w\hat{v}(s,w)=\phi^T(s)wv^(s,w)=ϕT(s)w这里的ϕ(s)\phi(s)ϕ(s)是一个feature vector, 可以是polynomial basis,Fourier basis,…。
- 第二种方法是,现在广泛使用的,就是用一个神经网络作为一个非线性函数近似器。神经网络的输入是state,输出是v^(s,w)\hat{v}(s,w)v^(s,w),网络参数是www。
在线性的情况中v^(s,w)=ϕT(s)w\hat{v}(s,w)=\phi^T(s)wv^(s,w)=ϕT(s)w,我们有∇wv^(st,wt)=ϕ(s)\nabla_w \hat{v}(s_t, w_t)=\phi(s)∇wv^(st,wt)=ϕ(s)将这个带入到TD算法wt+1=wt+αt[rt+1+γv^(st+1,wt)−v^(st,wt)]∇wv^(st,wt)w_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma \hat{v}(s_{t+1}, w_t)-\hat{v}(s_t,w_t)]\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt[rt+1+γv^(st+1,wt)−v^(st,wt)]∇wv^(st,wt)就变成了wt+1=wt+αt[rt+1+γϕT(st+1)wt−ϕT(st)wt]ϕ(st)w_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma \phi^T(s_{t+1})w_t-\phi^T(s_t)w_t]\phi(s_t)wt+1=wt+αt[rt+1+γϕT(st+1)wt−ϕT(st)wt]ϕ(st)这个具有线性函数近似的TD learning算法称为TD-Linear
。
线性函数近似的劣势是:
- 难以去选择合适的feature vector.
线性函数近似的优势是: - TD算法在线性情况下的理论上的性质很容易理解和分析,与非线性情况相比
- 线性函数近似仍然在某些情况下使用:tabular representation是linear function approximation的一种少见的特殊情况。
那么为什么tabular representation是linear function approximation的一种少见的特殊情况?
- 首先,对于state sss,选择一个特殊的feature vectorϕ(s)=es∈R∣S∣\phi(s)=e_s\in \mathbb{R}^{|\mathcal{S}|}ϕ(s)=es∈R∣S∣其中ese_ses是一个vector,其中第sss个实体为1,其他为0.
- 在这种情况下v^(st,wt)=esTw=w(s)\hat{v}(s_t, w_t)=e_s^Tw=w(s)v^(st,wt)=esTw=w(s)其中w(s)w(s)w(s)是www的第s个实体。
回顾TD-Linear算法:wt+1=wt+αt[rt+1+γϕT(st+1)wt−ϕT(st)wt]ϕ(st)w_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma \phi^T(s_{t+1})w_t-\phi^T(s_t)w_t]\phi(s_t)wt+1=wt+αt[rt+1+γϕT(st+1)wt−ϕT(st)wt]ϕ(st)
- 当ϕ(st)=es\phi(s_t)=e_sϕ(st)=es,上面的算法变成了wt+1=wt+αt[rt+1+γwt(st+1)−wt(st)]estw_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma w_t(s_{t+1})-w_t(s_t)]e_{s_t}wt+1=wt+αt[rt+1+γwt(st+1)−wt(st)]est这是一个向量等式,仅仅更新wtw_twt的第sss个实体。
- 将上面式子两边乘以estTe_{s_t}^TestT,得到wt+1(st)=wt(st)+αt[rt+1+γwt(st+1)−wt(st)]w_{t+1}(s_t)=w_t(s_t)+\alpha_t[r_{t+1}+\gamma w_t(s_{t+1})-w_t(s_t)]wt+1(st)=wt(st)+αt[rt+1+γwt(st+1)−wt(st)]这就是基于表格形式的TD算法。
Illustrative examples
考虑一个5×5的网格世界示例:
- 给定一个策略:π(a∣s)=0.2\pi(a|s)=0.2π(a∣s)=0.2,对于任意的s,as,as,a
- 我们的目标是基于该策略,估计state values(策略评估问题)
- 总计有25种state values。
- 设置rforbidden=rboundary=−1,rtarget=1,γ=0.9r_{forbidden}=r_{boundary}=-1, r_{target}=1, \gamma=0.9rforbidden=rboundary=−1,rtarget=1,γ=0.9
Ground truth:
- true state values和3D可视化
Experience samples:
- 500 episodes were generated following the given policy
- Each episode has 500 steps and starts from a randomly selected state-action pair following a uniform distribution。
为了对比,首先给出表格形式的TD算法(TD-Table)的结果:
那么看一下TD-Linear是否也能很好估计出来state value呢?
第一步就是要建立feature vector。要建立一个函数,这个函数也对应一个曲面,这个曲面能很好地拟合真实的state value对应的曲面。那么函数对应的曲面最简单的情况是什么呢?就是平面,所以这时候选择feature vector等于ϕ(s)=[1xy]∈R3\phi(s)=\begin{bmatrix}1 \\x \\y\end{bmatrix}\in \mathbb{R}^3ϕ(s)=1xy∈R3在这种情况下,近似的state value是v^(s,w)=ϕT(s)w=[1,x,y][w1w2w3]=w1+w2x+w3y\hat{v}(s,w)=\phi^T(s)w=[1, x, y]\begin{bmatrix}w_1 \\w_2 \\w_3\end{bmatrix} =w_1+w_2x+w_3yv^(s,w)=ϕT(s)w=[1,x,y]w1w2w3=w1+w2x+w3y注意,ϕ(s)\phi(s)ϕ(s)也可以定义为ϕ(s)=[x,y,1]T\phi(s)=[x, y, 1]^Tϕ(s)=[x,y,1]T,其中这里边的顺序是不重要的。
将刚才的feature vector带入TD-Linear算法中,得到:
- 这里边的趋势是正确的,但是有一些错误,这是由于用平面拟合的本身方法的局限性。
- 我们尝试使用一个平面去近似一个非平面,这是非常困难的。
为了提高近似能力,可以使用high-order feature vectors,这样也就有更多的参数。
- 例如,我们考虑这样一个feature vector:ϕ(s)=[1,x,y,x2,y2,xy]T∈R6\phi(s)=[1, x, y, x^2, y^2, xy]^T\in \mathbb{R}^6ϕ(s)=[1,x,y,x2,y2,xy]T∈R6在这种情况下,有v^(s,w)=ϕT(s)w=w1+w2x+w3y+w4x2+w5y2+w6xy\hat{v}(s,w)=\phi^T(s)w=w_1+w_2x+w_3y+w_4x^2+w_5y^2+w_6xyv^(s,w)=ϕT(s)w=w1+w2x+w3y+w4x2+w5y2+w6xy这对应一个quadratic surface。
- 可以进一步增加feature vector的维度ϕ(s)=[1,x,y,x2,y2,xy,x3,y3,x2y,xy2]T∈R10\phi(s)=[1, x, y, x^2, y^2, xy, x^3, y^3, x^2y, xy^2]^T\in \mathbb{R}^10ϕ(s)=[1,x,y,x2,y2,xy,x3,y3,x2y,xy2]T∈R10
通过higher-order feature vectors的TD-Linear算法的结果:
Summary of the story
1)首先从一个objective function出发J(w)=E[(vπ(S)−v^(S,w))2]J(w)=\mathbb{E}[(v_\pi(S)-\hat{v}(S, w))^2]J(w)=E[(vπ(S)−v^(S,w))2]这个目标函数表明这是一个policy evaluation问题.
2)然后对这个objective function进行优化,优化方法使用gradient-descent algorithm:wt+1=wt+αt(vπ(st)−v^(st,wt))∇wv^(st,wt)w_{t+1}=w_t+\alpha_t (v_\pi(s_t)-\hat{v}(s_t,w_t))\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt(vπ(st)−v^(st,wt))∇wv^(st,wt)但是问题是里边有一个vπ(st)v_\pi(s_t)vπ(st)是不知道的。
3)第三,使用一个近似替代算法中的true value function vπ(st)v_\pi(s_t)vπ(st),得到下面算法:wt+1=wt+αt[rt+1+γv^(st+1,wt)−v^(st,wt)]∇wv^(st,wt)w_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma \hat{v}(s_{t+1}, w_t)-\hat{v}(s_t,w_t)]\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt[rt+1+γv^(st+1,wt)−v^(st,wt)]∇wv^(st,wt)
尽管上面的思路对于理解基本思想是非常有帮助的,但是它在数学上是不严谨的,因为做了替换操作。
Theoretical analysis
一个基本的结论,这个算法wt+1=wt+αt[rt+1+γv^(st+1,wt)−v^(st,wt)]∇wv^(st,wt)w_{t+1}=w_t+\alpha_t[r_{t+1}+\gamma \hat{v}(s_{t+1}, w_t)-\hat{v}(s_t,w_t)]\nabla_w \hat{v}(s_t, w_t)wt+1=wt+αt[rt+1+γv^(st+1,wt)−v^(st,wt)]∇wv^(st,wt)不是去minimize下面的objective function:J(w)=E[(vπ(S)−v^(S,w))2]J(w)=\mathbb{E}[(v_\pi(S)-\hat{v}(S, w))^2]J(w)=E[(vπ(S)−v^(S,w))2]
实际上,有多种objective functions:
- Objective function 1:True value errorJ(w)=E[(vπ(S)−v^(S,w))2]=∣∣v^(w)−vπ∣∣D2J(w)=\mathbb{E}[(v_\pi(S)-\hat{v}(S, w))^2]=||\hat{v}(w)-v_\pi||_D^2J(w)=E[(vπ(S)−v^(S,w))2]=∣∣v^(w)−vπ∣∣D2
- Objective function 2:Bellman errorJBE(w)=∣∣v^(w)−(rπ+γPπv^(w))∣∣D2≐∣∣v^(w)−Tπ(v^(w))∣∣D2J_{BE}(w)=||\hat{v}(w)-(r_\pi+\gamma P_{\pi}\hat{v}(w))||_D^2\doteq ||\hat{v}(w)-T_\pi(\hat{v}(w))||_D^2JBE(w)=∣∣v^(w)−(rπ+γPπv^(w))∣∣D2≐∣∣v^(w)−Tπ(v^(w))∣∣D2其中Tπ(x)≐rπ+γPπxT_\pi(x)\doteq r_\pi+\gamma P_\pi xTπ(x)≐rπ+γPπx
- Objective function 2:Projected Bellman errorJPBE(w)=∣∣v^(w)−MTπ(v^(w))∣∣D2J_{PBE}(w)=||\hat{v}(w)-MT_\pi(\hat{v}(w))||_D^2JPBE(w)=∣∣v^(w)−MTπ(v^(w))∣∣D2其中MMM是一个projection matrix(投影矩阵)
简而言之,上面提到的TD-Linear算法在最小化projected Bellman error。
Sarsa with function appriximation
到目前为止,我们仅仅是考虑state value estimation的问题,也就是我们希望v^≈vπ\hat{v}\approx v_\piv^≈vπ。为了搜索最优策略,我们需要估计action values。
The Sarsa algorithm with value function approximation是:
这个上一节介绍的TD算法是一样的,只不过将v^\hat{v}v^换成了q^\hat{q}q^
为了寻找最优策略,我们将policy evaluation(上面算法做的事儿)和policy improvement结合。下面给出Sarsa with function approximation的伪代码:
举个例子:
- Sarsa with linear function approximation。
- rforbidden=rboundary=−10,rtarget=1,γ=0.9,α=0.001,ϵ=0.1r_{forbidden}=r_{boundary}=-10, r_{target}=1, \gamma=0.9, \alpha=0.001, \epsilon=0.1rforbidden=rboundary=−10,rtarget=1,γ=0.9,α=0.001,ϵ=0.1
Q-learning with function approximation
类似地,tabular Q-learning也可以扩展到value function approximation的情况。
The q-value更新规则是:
这与上面的Sarsa算法相同,除了q^(st+1,at+1,wt)\hat{q}(s_{t+1}, a_{t+1}, w_t)q^(st+1,at+1,wt)被替换为maxa∈A(st+1)q^(st+1,a,wt)\max_{a\in \mathcal{A}(s_{t+1})}\hat{q}(s_{t+1}, a, w_t)maxa∈A(st+1)q^(st+1,a,wt)。
Q-learning with function approximation伪代码(on-policy version):
举个例子:
- Q-learning with linear function approximation
- rforbidden=rboundary=−10,rtarget=1,γ=0.9,α=0.001,ϵ=0.1r_{forbidden}=r_{boundary}=-10, r_{target}=1, \gamma=0.9, \alpha=0.001, \epsilon=0.1rforbidden=rboundary=−10,rtarget=1,γ=0.9,α=0.001,ϵ=0.1
Deep Q-learning
Deep Q-learning算法又被称为deep Q-network (DQN):
- 最早的一个和最成功的一个将深度神经网络算法引入到强化学习中
- 神经网络的角色是一个非线性函数approximator
- 与下面的算法不同,是由于训练一个网络的方式:
Deep Q-learning旨在最小化目标函数/损失函数:
其中(S,A,R,S′)(S,A,R,S')(S,A,R,S′)是随机变量。
那么如何最小化目标函数呢?使用Gradient-descent!但是如何计算目标函数的梯度还是有一些tricky。这是因为在目标函数中有两个位置有www:
也就是说参数w不仅仅只出现在q^(S,A,w)\hat{q}(S,A,w)q^(S,A,w)中,还出现在它的前面。这里用yyy表示:y≐R+γmaxa∈A(S′)q^(S′,a,w)y\doteq R+\gamma \max_{a\in \mathcal{A}(S')} \hat{q}(S',a,w)y≐R+γa∈A(S′)maxq^(S′,a,w)
为了简单起见,我们可以假设www在yyy中是固定的(至少一定时间内),当我们计算梯度的时候。为了这样做,我们引入两个network。
- 一个是main network,用以表示q^(s,a,w)\hat{q}(s,a,w)q^(s,a,w)
- 另一个是target network q^(s,a,wT)\hat{q}(s,a,w_T)q^(s,a,wT)
用这两个network吧上面目标函数中的两个q^\hat{q}q^区分开来,就得到了如下式子:
其中wTw_TwT是target network parameter。
当wTw_TwT是固定的,可以计算出来JJJ的梯度如下:
- 这就是Deep Q-learning的基本思想,使用gradient-descent算法最小化目标函数。
- 然而,这样的优化过程涉及许多重要的技巧。
第一个技巧:使用了两个网络,一个是main network,另一个是target network。
为什么要使用两个网络呢?在数学上来说因为计算梯度的时候会非常的复杂,所以先去固定一个,然后再去计算另一个,这样就需要两个网络来实现。
具体实现的细节:
- 令www和wTw_TwT分别表示mean network和target network的参数,它们初始化的时候是一样的。
- 在每个iteration中,从
replay buffer
中draw一个mini-batch样本{(s,a,r,s′)}\{(s,a,r,s')\}{(s,a,r,s′)} - 网络的输入包括state sss和action aaa,目标输出是yT≐r+γmaxa∈A(s′)q^(s′,a,wT)y_T\doteq r+\gamma \max_{a\in \mathcal{A}(s')} \hat{q}(s',a,w_T)yT≐r+γmaxa∈A(s′)q^(s′,a,wT)。然后我们直接基于the mini-batch {(s,a,r,s′)}\{(s,a,r,s')\}{(s,a,r,s′)}最小化TD error或者称为loss function (yT−q^(s,a,w))2(y_T-\hat{q}(s,a,w))^2(yT−q^(s,a,w))2。这样一段时间后,参数w发生变化,再将其赋给wTw_TwT,再用来训练www。
另一个技巧:Experience replay
(经验回放)
问题:什么是Experience replay?
回答:
- 我们收集一些experience samples之后,we do NOT use these samples in the order they were collected。
- Instead,我们将它们存储在一个set中,称为replay buffer B≐{(s,a,r,s′)}\mathcal{B}\doteq \{(s, a, r, s')\}B≐{(s,a,r,s′)}
- 每次我们训练neural network,我们可以从replay buffer中draw a mini-batch的random samples
- 取出的samples,称为experience replay,应当按照一个均匀分布的方式,即每个experience被replay的机会是相等的。
问题:为什么在deep Q-learning中要用experience replay?为什么replay必须要按照一个uniform distribution的方式?
回答:这个回答依赖于下面的objective function
- (S,A)∼d(S,A)\sim d(S,A)∼d:(S,A)(S,A)(S,A)是一个索引,并将其视为一个single random variable。
- R∼p(R∣S,A),S′∼p(S′∣S,A)R\sim p(R|S,A), S'\sim p(S'|S,A)R∼p(R∣S,A),S′∼p(S′∣S,A):RRR和SSS由system model确定
- state-action pair (S,A)(S,A)(S,A)的分布假定是uniform.
- 然而,样本采集不是按照均匀分布来的,因为它们是由某个policies按顺序生成的。
- 为了打破顺序采样样本的关联,我们才从replay buffer中按照uniformly方式drawing samples,也就是experience replay technique
- 这是在数学上为什么experience replay是必须的,以及为什么experience replay必须是uniform的原因。
回顾tabular的情况:
- 问题1:为什么tabular Q-learning没有要求experience replay?
- 回答:没有uniform distribution的需要
- 问题2:为什么Deep Q-learning 涉及distribution?
- 回答:因为在deep Q-learning的情况下,目标函数是一个在所有(S,A)(S,A)(S,A)之上的scale average。tabular case没有涉及SSS或者AAA的任何distribution。在tabular情况下算法旨在求解对于所有的(s,a)(s,a)(s,a)的一组方程(Bellman optimality equation)。
- 问题3:可以在tabular Q-learning中使用experience replay吗?
- 回答:可以,而且还会让sample更加高效,因为同一个sample可以用多次。
再次给出Deep Q-learning的伪代码(off-policy version):
需要澄清的几个问题:
- 为什么没有策略更新?因为这里是off-policy
- 为什么没有使用之前导出的梯度去更新策略?因为之前导出梯度的算法比较底层,它可以指导我们去生成现在的算法,但是要遵循神经网络批量训练的黑盒特性,然后更好地高效地训练神经网络
- 这里网络的input和output与DQN原文中的不一样。原文中是on-policy的,这里是off-policy的。
举个例子:目标是learn optimal action values for every state-action pair。一旦得到最优策略,最优greedy策略可以立即得到。
问题设置:
仿真结果:
如果我们仅仅使用100步的一个single episode将会发生什么?也就是数据不充分的情况
可以看出,好的算法是需要充分的数据才能体现效果的。
内容来源
- 《强化学习的数学原理》 西湖大学工学院赵世钰教授 主讲
- 《动手学强化学习》 俞勇 著
相关文章:
【强化学习】强化学习数学基础:值函数近似
值函数近似Value Function ApproximationMotivating examples: curve fittingAlgorithm for state value estimationObjective functionOptimization algorithmsSelection of function approximatorsIllustrative examplesSummary of the storyTheoretical analysisSarsa with …...
JVM系列——Java与线程,介绍线程原理和操作系统的关系
并发不一定要依赖多线程(如PHP中很常见的多进程并发)。 但是在Java里面谈论并发,基本上都与线程脱不开关系。因此我们讲一下从Java线程在虚拟机中的实现。 线程的实现 线程是比进程更轻量级的调度执行单位。 线程的引入,可以把一个进程的资源分配和执行调…...
C++打开文件夹对话框之BROWSEINFO
头文件 #include <shlobj.h> #include <windows.h> #include <stdio.h> using namespace std; 案例 string chooseFile(void) {//用户选择的路径,可以是TCHAR szBuffer[MAX_PATH] {0};然后再使用TCHAR 转char字符串,此处可以直接使…...
Nuxt项目配置、目录结构说明-实战教程基础-Day02
Nuxt项目配置、目录结构说明-实战教程基础-Day02一、Nuxt项目结构1.1资源目录1.2 组件目录1.3 布局目录1.4 中间件目录1.5 页面目录1.6 插件目录1.7 静态文件目录1.8 Store 目录1.9 nuxt.config.js 文件1.10 package.json 文件其他:别名二、项目配置2.1 build2.2 cs…...
单链表的头插,尾插,头删,尾删等操作
前言顺序表要求是具有连续的物理空间,并且数据的话是在这些空间当中是连续的存储。但这样会带来很多问题,比如说在头部或者说中间插入的话,效率不是很高;并且申请空间可能需要扩容,并且越往后一般来说都是异地扩容&…...
Qt扫盲-QProcess理论总结
QProcess理论使用总结一、概述二、使用三、通过 Channel 通道通信四、同步进程API五、注意事项1. 平台特性2. 不能实时读取一、概述 QProcess 其实更多的是与外面进程进行交互的一个工具类,通过这个类来启动外部进程,获取这个进程的标准输出,…...
JAVA进阶 —— Steam流
目录 一、 引言 二、 Stream流概述 三、Stream流的使用步骤 1. 获取Stream流 1.1 单列集合 1.2 双列集合 1.3 数组 1.4 零散数据 2. Stream流的中间方法 3. Stream流的终结方法 四、 练习 1. 数据过滤 2. 数据操作 - 按年龄筛选 3. 数据操作 - 演员信息要求…...
Ubuntu Protobuf 安装(测试有效)
安装流程 下载软件 下载自己要安装的版本:https://github.com/protocolbuffers/protobuf 下载源码编译: 系统环境:Ubuntu16(其它版本亦可),Protobuf-3.6.1 编译源码 cd protobuf# 当使用 git clone 下来的…...
驱动程序开发:FTP服务器和OpenSSH的移植与搭建、以及一些笔记
目录一、FTP服务器移植与搭建1、在ubuntu下安装vsftpd2、在window下安装FileZilla3、移植vsftpd到开发板上4、Filezilla 连接测试5、注意点二、开发板 OpenSSH 移植与使用1、移植 zlib 库2、移植 openssl 库3、移植 openssh 库4、openssh 使用测试三、关于u-boot上的操作及根文…...
优化改进YOLOv5算法之添加GIoU、DIoU、CIoU、EIoU、Wise-IoU模块(超详细)
目录 1、IoU 1.1 什么是IOU 1.2 IOU代码 2、GIOU 2.1 为什么提出GIOU 2.2 GIoU代码 3 DIoU 3.1 为什么提出DIOU 3.2 DIOU代码 4 CIOU 4.1 为什么提出CIOU 4.2 CIOU代码 5 EIOU 5.1 为什么提出EIOU 5.2 EIOU代码 6 Wise-IoU 7 YOLOv5中添加GIoU、DIoU、CIoU、…...
windows电脑pc如何使用svn获取文档和代码
一、安装svn 下载链接 也可通过其他方式下载 二、使用 2.1 随便找一个文件夹 2.2 点击右键,选择SVN Checkout 2.3输入网址 如当你在网页上访问时地址为https://10.197.78.78/!/#aaa/view/head/bbb 在这里不能直接填入,而是 https://10.197.78.78/sv…...
ROS1学习笔记:tf坐标系广播与监听的编程实现(ubuntu20.04)
参考B站古月居ROS入门21讲:tf坐标系广播与监听的编程实现 基于VMware Ubuntu 20.04 Noetic版本的环境 文章目录一、创建功能包二、创建代码2.1 以C为例2.1.1 配置代码编译规则2.1.2 编译整个工作空间2.1.2 配置环境变量2.1.4 执行代码2.2 以Python为例2.2.1 配置代码…...
力扣解法汇总1590. 使数组和能被 P 整除
目录链接: 力扣编程题-解法汇总_分享记录-CSDN博客 GitHub同步刷题项目: https://github.com/September26/java-algorithms 原题链接:力扣 描述: 给你一个正整数数组 nums,请你移除 最短 子数组(可以为 …...
Spring源码阅读(基础)
第一章:bean的元数据 1.bean的注入方式: 1.1 xml文件 1.2 注解 Component(自己写的类才能在上面加这些注解) 1.3配置类: Configuration 注入第三方数据源之类 1.4 import注解 (引用了Myselector类下…...
服务搭建篇(九) 使用GitLab+Jenkins搭建CI\CD执行环境 (上) 基础环境搭建
1.前言 每当我们程序员开发在本地完成开发之后 , 都要部署到正式环境去使用 , 在一些传统的运维体系中 , 开发与运维都是割裂的 , 开发人员不允许操作正式服务器 , 服务器只能通过运维团队来操作 , 这样可以极大的提高服务器的安全性 , 不经过安全保护的开放服务器 , 对于黑客…...
CDC 长沙站丨云原生技术研讨会:数字兴链,云化未来!
一、活动信息:活动主题:CDC 长沙站丨云原生技术研讨会活动时间:2023 年 3 月 14 日下午 14:30-17:30活动地点:长沙市岳麓区-拓维信息总部 1 楼多功能厅活动参与方式:免门票参与,戳此…...
A.特定领域知识图谱知识推理方案:知识图谱推理算法综述[二](DTransE/PairRE:基于表示学习的知识图谱链接预测算法)
推荐参考文章: A.特定领域知识图谱知识推理方案:知识图谱推理算法综述[一](基于距离的翻译模型:TransE、TransH、TransR、TransH、TransA、RotatE) A.特定领域知识图谱知识推理方案:知识图谱推理算法综述[二](DTransE/PairRE:基于表示学习的知识图谱链接预测算法) A.…...
香港酒店模拟分析项目报告--使用tableau、python、matlab
转载请标记本文出处 软件:tableau、pycharm、关系型数据库:MySQL 数据大量分析考虑电脑性能的情况。 文章目录前言一、爬虫是什么?二、使用tableau数据可视化1.引入数据1.1 制作直方图-各地区酒店数量条形图1.2 各地区酒店均价1.3 价格等级堆…...
第18天-商城业务(商品检索服务,基于Elastic Search完成商品检索)
1.构建商品检索页面 1.1.引入依赖 <!-- thymeleaf模板引擎 --><dependency><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-thymeleaf</artifactId></dependency><!-- 热更新 --><…...
5.2 对射式红外传感器旋转编码器计次
对射式红外传感器1.1 接线图VCC GND分别接电源的正负极DO数字输出端,随意选择一个GPIO口1.2 硬件原理当挡光片或者编码盘在对射式红外传感器中间经过时,DO就会输出电平变化信号,电平跳变信号触发STM32 PB14号口中断,在中断函数中执…...
【数据库概论】第九章 关系查询处理和查询优化
第九章 关系查询处理和查询优化 本章主要介绍关系数据库查询管理和查询优化,主要分为代数优化(又称逻辑优化)和物理优化(也称非代数优化)。 9.1 关系型数据库系统的查询处理 查询处理是关系型数据库管理系统执行查询…...
(WIP) my cloud test bed (by quqi99)
作者:张华 发表于:2023-03-10 版权声明:可以任意转载,转载时请务必以超链接形式标明文章原始出处和作者信息及本版权声明 问题 想创建一个local local test bed, 用来方便做各种云实验,如openstack, k8s, ovn, lxd等…...
git | git 2023 详细版
文章目录一、Git命令1.2 设计用户签名1.3 初始化本地库1.4 查看本地库状态1.5 添加至暂存区1.6 从暂存区删除1.7 将暂存区的文件提交到本地库1.8 查看版本信息二、Git分支2.1 查看分支2.2 创建分支2.3 切换分支2.4 合并分支三、GitHub3.1 代码克隆clone3.2 给库取别名3.3 推送本…...
camunda流程引擎基本使用(笔记)
文章目录一、camunda基础1.1 安装与部署流程引擎1.2 流程引擎结构1.3 流程引擎的基本使用1.3.1 创建一个BPMN Diagram1.3.2 实现一个外部工作者1.3.3 部署流程1.3.4 创建一个流程实例并消费1.3.5 向流程中添加用户任务1.3.6 添加网关1.3.7 业务规则二、Java 集成流程引擎2.1 为…...
JS之数据结构与算法
前言数据结构是计算机存储、组织数据的方式,算法是系统描述解决问题的策略。了解基本的数据结构和算法可以提高代码的性能和质量。也是程序猿进阶的一个重要技能。手撸代码实现栈,队列,链表,字典,二叉树,动态规划和贪心算法1.数据结构篇1.1 栈栈的特点:先进后出clas…...
CnOpenData·A股上市企业数字化转型指数数据
一、数据简介 企业数字化转型是近年来中国社会各界重点关注的领域,但基础数据的不完善在很大程度上制约了相关科学研究的开展。构建合理、科学的数字化转型指标体系有利于学者定量地研究企业数字化的相关问题,也有利于衡量企业的数字化水平。广东金融学院…...
VMware16pro虚拟机安装全过程
很多时候需要用到Linux系统,简单的一种方式可以是:Windows系统运行Linux(Windows Subsystem for Linux)不过有些时候还是需要虚拟机来运行Linux,也更方便点,比如在做嵌入式系统的烧录等操作都需要Linux环境…...
阿里云第六代云服务器最新价格表(计算型c6、通用型g6和内存型r6)
目前阿里云第六代云服务器有计算型c6、通用型g6和内存型r6实例。计算型c6实例有2核4G、4核8G、8核16G配置可选,主要适用于网站应用、批量计算、视频编码等场景。通用型g6实例有2核8G、4核16G、8核32G配置可选,适用于各种类型的企业级应用,网站…...
微小目标识别研究(2)——基于K近邻的白酒杂质检测算法实现
文章目录实现思路配置opencv位置剪裁实现代码自适应中值滤波实现代码动态范围增强实现代码形态学处理实现代码图片预处理效果计算帧差连续帧帧差法原理和实现代码实现代码K近邻实现基本介绍实现代码这部分是手动实现的,并没有直接调用相关的库完整的代码——调用ope…...
2022-06-14至2022-08-11 关于复现MKP算法的总结与反思
Prerequisite 自2022年6月14日至2022年8月11日的时间内,我致力于完成A Hybrid Approach for the 0–1 Multidimensional Knapsack problem 论文的复现工作,此次是我第一次进行组合优化方向的学习工作,下面介绍该工作内容发展过程以及该工作结…...
建筑网站制作/搜索引擎优化案例分析
线程只要分为:主线程和子线程,主线程主要处理和界面相关的事情,而子线程则往往用于执行耗时的操作,由于Android的特性,如果在主线程中执行耗时操作那么就会导致程序无法及时响应,因此耗时操作必须方法子线程…...
用ps做网站方法/seo与sem的区别和联系
系列文章目录(springboot整合activiti5) 排他网关类似于程序开发中的if操作,只有判断条件为true的时候才会执行的,在Activiti中排他网关的xml代码的基本格式为 <exclusiveGateway id"exclusivegateway1" name"…...
上海网站开发工程师招聘网/长春网站建设公司哪家好
IDispatch error #3092 表示sql执行语句有语法错误,一般容易检查。下面是一个该错误的例子: sql1_T("SELECT *,DATEPART(yyyy,Student_InDate) as Student_InYearDate FROM Table_Student order by Student_ID"); //正确 sql2_T("SEL…...
微信小程序怎么创建店铺/seo查询seo
1.前期准备 1.1首先先从官网下载安装包 https://dev.mysql.com/downloads/mysql/ 1.2 创建软件目录,解压迁移软件 [rootdb03 opt]# tar -xvf mysql-8.0.20-linux-glibc2.12-x86_64.tar.xz mysql [rootdb03 opt]# ls mysql1.3 创建 mysql 用户 [rootdb03 opt]# …...
如何建设个人网站和博客/留手机号广告
随时随地阅读更多技术实战干货,获取项目源码、学习资料,请关注源代码社区公众号(ydmsq666) 在Android也可以实现点击网址、email等自动连接,在输入框中输入数据,在下面列表中列出所有匹配的数据,点击实现自动补全的功能…...
企业做网站的注意事项/东莞精准网络营销推广
websphere存档日期:2019年5月13日 | 首次发布:2009年3月11日 本文介绍了使用WebSphere Transformation Extender,其WebSphere Design Studio和WebSphere DataPower SOA Appliance的数据集成方案。 遗留系统的一个常见问题是与使用XML而不是其…...