深度学习——权重衰减(weight_decay)
深度学习——权重衰减(weight_decay)
文章目录
- 前言
- 一、权重衰减
- 1.1. 范数与权重衰减
- 1.2. 高维线性回归
- 1.3. 从零开始实现
- 1.3.1.初始化模型参数
- 1.3.2. 定义L₂范数惩罚
- 1.3.3. 定义训练代码实现
- 1.3.4. 不管正则化直接训练
- 1.3.5. 使用权重衰减
- 1.4. 简洁实现
- 总结
前言
上一章描述了过拟合的问题,本章我们将介绍一些正则化模型的技术。如权重衰减
参考书:
《动手学深度学习》
一、权重衰减
1.1. 范数与权重衰减
在训练参数化机器学习模型时,权重衰减(weight decay)是最广泛使用的正则化的技术之一, 它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0),在某种意义上是最简单的。
但是我们应该如何精确地测量一个函数和零之间的距离呢?
一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=w⊤x 中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 ∥w∥2。
要保证权重向量比较小,最常用方法是将其范数作为惩罚项加到最小化损失的问题中:
即将原来的训练目标最小化训练标签上的预测损失,调整为最小化预测损失和惩罚项之和。
现在,如果我们的权重向量增长的太大,我们的学习算法可能会更集中于最小化权重范数 ∥ w ∥ 2 \| \mathbf{w} \|^2 ∥w∥2。这正是我们想要的。
我们的损失由下式给出:
L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1∑n21(w⊤x(i)+b−y(i))2.
为了惩罚权重向量的大小,我们必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 ∥w∥2
但是模型应该如何平衡这个新的额外惩罚的损失?
实际上,我们通过正则化常数 λ \lambda λ来描述这种权衡,这是一个非负超参数,我们使用验证数据拟合:
L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λ∥w∥2,
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| ∥w∥的大小。
为什么在这里我们使用平方范数而不是标准范数(即欧几里得距离)?
我们这样做是为了便于计算。通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。
此外,为什么我们首先使用 L 2 L_2 L2范数,而不是 L 1 L_1 L1范数。
L 2 L_2 L2正则化线性模型构成经典的岭回归(ridge regression)算法,
L 1 L_1 L1正则化线性回归是统计学中类似的基本模型,通常被称为套索回归(lasso regression)。
使用 L 2 L_2 L2范数的一个原因是它对权重向量的大分量施加了巨大的惩罚。这使得我们的学习算法偏向于在大量特征上均匀分布权重的模型。在实践中,这可能使它们对单个变量中的观测误差更为稳定。
相比之下, L 1 L_1 L1惩罚会导致模型将权重集中在一小部分特征上,
而将其他权重清除为零。这称为特征选择(feature selection),可能是其他场景下需要的。
L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:
w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w←(1−ηλ)w−∣B∣ηi∈B∑x(i)(w⊤x(i)+b−y(i)).
我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w。然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减。我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。 较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w,而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。
是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。通常,网络输出层的偏置项不会被正则化。
1.2. 高维线性回归
我们通过一个简单的例子来演示权重衰减。
首先,我们像以前一样生成一些数据,生成公式如下:
y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1∑d0.01xi+ϵ where ϵ∼N(0,0.012).
我们选择标签是关于输入的线性函数。标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200,
并使用一个只包含20个样本的小训练集。
import torch
from d2l import torch as d2l
from torch import nnn_train,n_test,num_inputs,batch_size = 20,100,200,5
true_w,true_b = torch.ones((num_inputs,1))*0.01,0.05"""
使用d2l.synthetic_data函数生成了训练数据和测试数据,并使用d2l.load_array函数将数据加载为迭代器。
"""
train_data = d2l.synthetic_data(true_w,true_b,n_train)
train_iter = d2l.load_array(train_data,batch_size)test_data = d2l.synthetic_data(true_w,true_b,n_test)
test_iter = d2l.load_array(test_data,batch_size,is_train= False)
#这里设置is_train=False表示测试数据不用于模型训练,只用于评估模型的性能。
1.3. 从零开始实现
下面我们将从头开始实现权重衰减,只需将 L 2 L_2 L2的平方惩罚添加到原始目标函数中。
1.3.1.初始化模型参数
#初始化模型参数
#我们将定义一个函数来随机初始化模型参数
def init_params():w = torch.normal(0,1,size=(num_inputs,1),requires_grad= True)b = torch.zeros(1,requires_grad=True)return [w,b]
1.3.2. 定义L₂范数惩罚
#定义L2范数惩罚(实现这一惩罚最方便的方法是对所有项求平方后并将它们求和)
def l2_penalty(w):return torch.sum(w.pow(2))/2 #将权重w的平方和除以2,除以2是为了方便计算梯度
1.3.3. 定义训练代码实现
#定义训练代码实现
def train(lambd):w,b = init_params()net,loss = lambda x: d2l.linreg(x,w,b),d2l.squared_lossnum_epochs,lr = 100,0.003animator = d2l.Animator(xlabel="epochs",ylabel="loss",yscale="log",xlim= [5,num_epochs],legend=["train","test"])for epoch in range(num_epochs):for x,y in train_iter:#增加了L2范数惩罚项#广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(x),y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w,b],lr,batch_size)if (epoch+1)%5 ==0:animator.add(epoch+1,(d2l.evaluate_loss(net,train_iter,loss),d2l.evaluate_loss(net,test_iter,loss)))print("w的L2范数是:",torch.norm(w).item())
1.3.4. 不管正则化直接训练
#现在用`lambd = 0`禁用权重衰减后运行这个代码。
#注意,这里训练误差有了减少,但测试误差没有减少,这意味着出现了严重的过拟合。train(lambd= 0)#结果:
w的L2范数是: 13.981727600097656
1.3.5. 使用权重衰减
#使用权重衰减来运行代码。
#注意,在这里训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。train(lambd= 3)#结果:
w的L2范数是: 0.3319331705570221d2l.plt.show()
1.4. 简洁实现
深度学习框架为了便于我们使用权重衰减,将权重衰减集成到优化算法中,以便与任何损失函数结合使用。
#在下面的代码中,我们在实例化优化器时直接通过`weight_decay`指定weight decay超参数。
#默认情况下,PyTorch同时衰减权重和偏移。
#这里我们只为权重设置了`weight_decay`,所以偏置参数$b$不会衰减。def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs,1))for param in net.parameters():param.data.normal_() #使用正态分布随机初始化参数loss = nn.MSELoss(reduction="none") #定义损失函数为均方误差损失num_epochs,lr = 100,0.003#偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,"weight_decay":wd},{"params":net[0].bias}],lr = lr) #net[0].weight表示模型的权重参数,net[0].bias表示模型的偏置参数。weight_decay参数用于设置权重衰减的强度。animator = d2l.Animator(xlabel="epochs",ylabel="loss",yscale="log",xlim=[5,num_epochs],legend=["train","test"])for epoch in range(num_epochs):for x,y in train_iter:trainer.zero_grad() #清零梯度,以防止梯度累积l = loss(net(x),y)l.mean().backward() #计算损失的平均值,并进行反向传播,计算梯度trainer.step() #更新模型的参数,执行一步优化器的更新if (epoch+1)%5 == 0:animator.add(epoch+1,(d2l.evaluate_loss(net,train_iter,loss),d2l.evaluate_loss(net,test_iter,loss)))print("w的L2范数:", net[0].weight.norm().item()) #打印模型权重的L2范数,用于评估模型的复杂度。train_concise(0)
train_concise(3)
d2l.plt.show()#结果:
w的L2范数: 13.411089897155762
w的L2范数: 0.3319282829761505
总结
为了有效防止模型的过拟合,降低模型的复杂度,提高泛化能力,本章简单记录了一种常见的正则化技术:权重衰减。简单来说权重衰减是通过在损失函数中添加一个正则化项来实现的。这个正则化项通常是模型参数的L2范数(平方和)或L1范数(绝对值和),通过限制模型参数的大小来防止过拟合。
我独泊兮其未兆,如婴儿之未孩,傫傫(lèi lèi)兮,若无所归。
–2023-10-2 进阶篇
相关文章:
深度学习——权重衰减(weight_decay)
深度学习——权重衰减(weight_decay) 文章目录 前言一、权重衰减1.1. 范数与权重衰减1.2. 高维线性回归1.3. 从零开始实现1.3.1.初始化模型参数1.3.2. 定义L₂范数惩罚1.3.3. 定义训练代码实现1.3.4. 不管正则化直接训练1.3.5. 使用权重衰减 1.4. 简洁实现 总结 前言…...
nignx如何部署让前端不用清缓存就可以部署
在Nginx中,可以使用以下方法来部署前端应用程序,使前端用户无需清空缓存即可进行部署: 1、使用版本号:在前端应用程序的构建过程中,可以添加一个独特的版本号到应用程序的名称中。每次部署时,将版本号更新…...
CSS3实现动画加载效果
<!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta name"viewport" content"widthdevice-width, initial-scale1.0" /><title>加载效果</title><link rel"style…...
springboot定时任务Scheduled使用和弊端分析
1.springboot定时任务Scheduled使用说明: (1)创建定时任务类 import com.one.utils.DateUtil; import org.springframework.beans.factory.annotation.Autowired; import...
openGauss学习笔记-93 openGauss 数据库管理-访问外部数据库-oracle_fdw
文章目录 openGauss学习笔记-93 openGauss 数据库管理-访问外部数据库-oracle_fdw93.1 编译oracle_fdw93.2 使用oracle_fdw93.3 常见问题93.4 注意事项 openGauss学习笔记-93 openGauss 数据库管理-访问外部数据库-oracle_fdw openGauss的fdw实现的功能是各个openGauss数据库及…...
【Git】Git下载安装环境配置 下载速度慢的解决方案
这里写自定义目录标题 介绍一、下载官网下载镜像站 二、安装安装成功 三、Git三种界面介绍Git cmd界面展示git bash界面展示git GUI界面展示 四、环境配置配置流程1、打开环境变量界面2、添加环境变量 /删除环境变量3、在变量中找到Git\cmd的值就表示配置成功4、没有找到点击新…...
常见源协议介绍
开源协议(Open Source License)是一种法律文档,用于规定如何使用、修改和分发开源软件和其他开源项目的规则和条件。这些协议允许创作者或组织将其创造的代码或作品以开放源代码的形式共享给他人,以促进协作、创新和知识共享。常见…...
大数据概述(林子雨慕课课程)
文章目录 1. 大数据概述1.1 大数据概念和影响1.2 大数据的应用1.3 大数据的关键技术1.4 大数据与云计算和物联网的关系云计算物联网 1. 大数据概述 大数据的四大特点:大量化、快速化、多样化、价值密度低 1.1 大数据概念和影响 大数据摩尔定律 大数据由结构化和非…...
ES6 class类关键字super
super关键字 在 JavaSCript 中,能通过 extends 关键字去继承父类 super 关键字在子类中有以下用法: 当成函数调用 super() 作为 "属性查询" super.prop 和 super[expr] super() super 作为函数调用时,代表父类的构造函数。 ES6 要求…...
C++并发与多线程(4) | 传递临时对象作为线程参数的一些问题Ⅰ
一、陷阱1 写一个传递临时对象作为线程参数的示例: #include <iostream> #include <vector> #include <thread> using namespace std;void myprint(const int& i, char* pmybuf) {cout << i << endl;cout << pmybuf << endl;r…...
CentOS Integration SIG 正式成立
导读CentOS 董事会已批准成立 CentOS Integration Special Interest Group (SIG)。该小组旨在帮助那些在 Red Hat Enterprise Linux (RHEL) 或特别是其上游 CentOS Stream 上构建产品和服务的人员,验证其能否在未来版本中继续运行。 红帽 RHEL CI 工程师 Aleksandr…...
智能AI系统源码ChatGPT系统源码+详细搭建部署教程+AI绘画系统+已支持OpenAI GPT全模型+国内AI全模型
一、AI创作系统 SparkAi创作系统是基于OpenAI很火的ChatGPT进行开发的Ai智能问答系统,支持OpenAI GPT全模型国内AI全模型。本期针对源码系统整体测试下来非常完美,可以说SparkAi是目前国内一款的ChatGPT对接OpenAI软件系统。那么如何搭建部署AI创作Chat…...
软考程序员考试大纲(2023)
文章目录 前言一、考试说明1.考试目标2.考试要求3.考试科目设置 二、考试范围考试科目1:计算机与软件工程基本知识1.计算机科学基础2.计算机系统基础知识3.系统开发和运行知识4.网络与信息安全基础知识5&am…...
【重拾C语言】七、指针(一)指针与变量、指针操作、指向指针的指针
目录 前言 七、指针 7.1 指针与变量 7.1.1 指针类型和指针变量 7.1.2 指针所指变量 7.1.3 空指针、无效指针 7.2 指针操作 7.2.1 指针的算术运算 7.2.2 指针的比较 7.2.3 指针的递增和递减 7.3 指向指针的指针 前言 指针是C语言中一个重要的概念正确灵活运用指针 可…...
Kafka源码简要分析
目录 一、生产者的初始化流程 二、生产者到缓冲队列的流程 三、Sender拉取数据到Kafka流程 四、消费者初始化 五、主题订阅原理 六、消费者抓取数据原理 七、消费者组初始化 八、消费者组消费流程 九、提交offset原理 一、生产者的初始化流程 首先获取事务id和客户端…...
react 按住ctrl键,点击时会出现菜单的问题修复
问题描述:我需要按住crtl键,然后鼠标点击后做一些逻辑操作,但是出现如下问题 问题一:按住ctrl键后,点击时不触发click事件,只触发 mousedown和mouseup事件。 问题二:按住ctrl键点击时出现菜单…...
【虚拟机栈】
文章目录 1. 虚拟机栈概述2. 局部变量表(Local Variables)3. 操作数栈4. 动态链接4.1 方法的调用:解析与分配 5. 方法返回地址6. 栈的相关面试题 1. 虚拟机栈概述 每个线程在创建时都会创建一个虚拟机栈,其内部保存一个个的栈帧(Stack Frame…...
Linux系列讲解 —— 【fsck】检查并修复Linux文件系统
当文件系统出现损坏时,例如文件无法查看,删除等,可以使用 fsck(File System Consistency Check)进行修复。但是需要注意fsck在修复时,如果检查出某个文件有问题,可能会向用户请求删除。所以&…...
gitlab突然提示我要输入密码了。
用了很长时间的一个gitlab库,今天提交代码的时候突然提示我输入密码了,并且用户还是gitxx.xx.xx.xx的,瞬间懵逼。 想想原因,可能是因为我不久前设置了本地对另外一个git库的远程访问,用的是ssh,操作过程中可…...
业务测试常见问题(一)
如何多维度的分析一个需求? 功能维度:需求中所描述的功能是否实现,与用户的需求是否一致,是否完整符合用户的需求等。 安全性维度:是否有安全漏洞,是否存在未授权访问漏洞等,以保证系统的安全性…...
IntelliJ IDEA失焦自动重启服务的解决方法
IDEA 热部署特性 热部署,即应用正属于运行状态时,我们对应用源码进行了修改更新,在不重新启动应用的情况下,可以能够自动的把更新的内容重新进行编译并部署到服务器上,使修改立即生效。 现象 在使用 IntelliJ IDEA运…...
终端准入控制系统,保障企业内网安全的关键防线
随着网络技术的不断发展,企业面临的安全威胁也越来越多。终端作为承载企业业务的媒介,对内网资产安全有着重要影响。确保内网终端(如PC、BYOD、IoT等)能够得到统一管理,对保护内网安全很有必要。终端准入控制作为一种有…...
mysql-执行计划
1. 执行计划表概述 id相同表示加载表的顺序是从上到下。 id不同id值越大,优先级越高,越先被执行。id有相同,也有不同,同时存在。 id相同的可以认为是一组,从上往下顺序执行;在所有的组中,id的值…...
金蝶云星空和旺店通·企业奇门接口打通对接实战
金蝶云星空和旺店通企业奇门接口打通对接实战 接入系统:金蝶云星空 金蝶K/3Cloud(金蝶云星空)是移动互联网时代的新型ERP,是基于WEB2.0与云技术的新时代企业管理服务平台。金蝶K/3Cloud围绕着“生态、人人、体验”,旨在…...
在服务器上使用nginx改变前端项目请求的url
location /app-dev {rewrite ^/app-dev/(.*) /$1 break;proxy_pass http://152.136.36.251:9999;proxy_set_header Host $host;proxy_set_header X-Real-IP $remote_addr; } location /请求后缀 { rewrite ^/app-dev/(.*) /$1 break; proxy_pass 想要的请求后端的url; …...
【学习笔记】莫比乌斯反演
退役OIer回来受虐啦 一些定义 μ ( x ) { 1 x > 1 ( − 1 ) n x ∏ i 1 n P i 0 o t h e r w i s e \mu(x) \begin{cases} 1 & x > 1 \\ (-1)^n & x \prod _ {i1} ^ {n} P_{i}\\ 0 & otherwise \end{cases} μ(x)⎩ ⎨ ⎧1(−1)n0x>1x∏i1nPi…...
一款构建Python命令行应用的开源库
1 简介 当我们编写 Python 程序时,我们经常需要与用户进行交互,接收输入并输出结果。Python 提供了许多方法来实现这一点,其中一个非常方便的方法是使用 typer 库。typer 是一个用于构建命令行应用程序的 Python 库,它使得创建命令…...
10-Node.js模块化
01.模块化简介 目标 了解模块化概念和好处,以及 CommonJS 标准语法导出和导入 讲解 在 Node.js 中每个文件都被当做是一个独立的模块,模块内定义的变量和函数都是独立作用域的,因为 Node.js 在执行模块代码时,将使用如下所示的…...
数字IC前端学习笔记:数字乘法器的优化设计(Dadda Tree乘法器)
相关阅读 数字IC前端https://blog.csdn.net/weixin_45791458/category_12173698.html?spm1001.2014.3001.5482 华莱士树仍然是一种比较规则的结构(这使得可以方便地生成树的结构),这导致了它所使用的全加器和半加器个数不是最少的ÿ…...
计算机专业毕业设计项目推荐14-文档编辑平台(SpringBoot+Vue+Mysql)
文档编辑平台(SpringBootVueMysql) **介绍****各部分模块实现** 介绍 本系列(后期可能博主会统一为专栏)博文献给即将毕业的计算机专业同学们,因为博主自身本科和硕士也是科班出生,所以也比较了解计算机专业的毕业设计流程以及模式,在编写的…...
外贸网站开发莆田/陕西网站建设制作
每天一习题,提升Python不是问题!!有更简洁的写法请评论告知我! https://www.cnblogs.com/poloyy/category/1676599.html 题目 打印99乘法表 解题思路 外层循环,获取被乘数内层循环,获取乘数 答案 for i in …...
广州建设工程网站/东莞关键词优化软件
神农氏 神农氏据说长得像牛魔王——“牛首人身”,不过他看上去并不像牛魔王那样粗暴。事实上他极具仁慈爱心。这位优秀青年,最大的爱好就是拎了一根棍子,在西部的黄土高坡上考察野生植物,是个十足内向的家伙。他到处收集植物样…...
爱网度假/网站排名优化软件
旧的时间字符串-->simpledataformat1.parse(该字符串) 获得date类型 -->simpledataformat2.format(date) simpledateformat1的pattern的格式和旧的字符串相同,simpledateformat2的pattern格式和希望的相同。 比如 旧的字符串格式为 yyyy-MM-dd,希望…...
前端做兼职网站/搜索引擎seo推广
http://acm.timus.ru/problem.aspx?space1&num1806 只要算法对 ural 一般不会卡时间的 这个题是一个简单的最短路 spfa 关键在于找边 找边的方法是 对于每一个点 枚举它的所有可能的变化 搜索是否有和变化后的字符串一样的 搜索的时候既可以用 map 也可以 自己写字典树 m…...
wordpress建立移动m站/免费推广网站入口
目录 一、StreamTokenizer中的基本方法 二、StreamTokenizer的构造方法 2.1 指定单词要素 2.2 指定分隔符 三、算法题用法 3.1 普通用法 3.2 多组输入 一、StreamTokenizer中的基本方法 commenChar(int ch) - 指定某个字符为注释字符,此字符之后直到行结尾都被stre…...
武汉平价做网站/合肥seo推广公司哪家好
对于你在这里所做的事情,使用反射似乎不是一个好的设计.最好使用Map< String,Integer>例如:static final Map VALUES_BY_NAME;static {final Map valuesByName new HashMap<>();valuesByName.put("width", 5);valuesByName.put("potato…...