当前位置: 首页 > news >正文

【深度学习笔记】09 权重衰减

09 权重衰减

    • 范数和权重衰减
    • 利用高维线性回归实现权重衰减
    • 权重衰减的简洁实现

范数和权重衰减

在训练参数化机器学习模型时,权重衰减(decay weight)是最广泛应用的正则化技术之一,它通常也被称为 L 2 L_2 L2正则化。这项技术通过函数与零的距离来衡量函数的复杂度,
因为在所有函数 f f f中,函数 f = 0 f = 0 f=0(所有输入都得到值 0 0 0
在某种意义上是最简单的。

一种简单的方法是通过线性函数
f ( x ) = w ⊤ x f(\mathbf{x}) = \mathbf{w}^\top \mathbf{x} f(x)=wx
中的权重向量的某个范数来度量其复杂性,
例如 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
要保证权重向量比较小,
最常用方法是将其范数作为惩罚项加到最小化损失的问题中。
将原来的训练目标最小化训练标签上的预测损失,
调整为最小化预测损失和惩罚项之和。

损失由下式给出:

L ( w , b ) = 1 n ∑ i = 1 n 1 2 ( w ⊤ x ( i ) + b − y ( i ) ) 2 . L(\mathbf{w}, b) = \frac{1}{n}\sum_{i=1}^n \frac{1}{2}\left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right)^2. L(w,b)=n1i=1n21(wx(i)+by(i))2.

x ( i ) \mathbf{x}^{(i)} x(i)是样本 i i i的特征,
y ( i ) y^{(i)} y(i)是样本 i i i的标签,
( w , b ) (\mathbf{w}, b) (w,b)是权重和偏置参数。

为了惩罚权重向量的大小,
必须以某种方式在损失函数中添加 ∥ w ∥ 2 \| \mathbf{w} \|^2 w2
我们通过正则化常数 λ \lambda λ来描述这种权衡,
这是一个非负超参数,我们使用验证数据拟合:

L ( w , b ) + λ 2 ∥ w ∥ 2 , L(\mathbf{w}, b) + \frac{\lambda}{2} \|\mathbf{w}\|^2, L(w,b)+2λw2,

对于 λ = 0 \lambda = 0 λ=0,我们恢复了原来的损失函数。
对于 λ > 0 \lambda > 0 λ>0,我们限制 ∥ w ∥ \| \mathbf{w} \| w的大小。
这里我们仍然除以 2 2 2:当我们取一个二次函数的导数时,
2 2 2 1 / 2 1/2 1/2会抵消。

通过平方 L 2 L_2 L2范数,我们去掉平方根,留下权重向量每个分量的平方和。
这使得惩罚的导数很容易计算:导数的和等于和的导数。

L 2 L_2 L2正则化回归的小批量随机梯度下降更新如下式:

w ← ( 1 − η λ ) w − η ∣ B ∣ ∑ i ∈ B x ( i ) ( w ⊤ x ( i ) + b − y ( i ) ) . \begin{aligned} \mathbf{w} & \leftarrow \left(1- \eta\lambda \right) \mathbf{w} - \frac{\eta}{|\mathcal{B}|} \sum_{i \in \mathcal{B}} \mathbf{x}^{(i)} \left(\mathbf{w}^\top \mathbf{x}^{(i)} + b - y^{(i)}\right). \end{aligned} w(1ηλ)wBηiBx(i)(wx(i)+by(i)).

我们根据估计值与观测值之间的差异来更新 w \mathbf{w} w
然而,我们同时也在试图将 w \mathbf{w} w的大小缩小到零。
这就是为什么这种方法有时被称为权重衰减
我们仅考虑惩罚项,优化算法在训练的每一步衰减权重。
与特征选择相比,权重衰减为我们提供了一种连续的机制来调整函数的复杂度。
较小的 λ \lambda λ值对应较少约束的 w \mathbf{w} w
而较大的 λ \lambda λ值对 w \mathbf{w} w的约束更大。

是否对相应的偏置 b 2 b^2 b2进行惩罚在不同的实践中会有所不同,
在神经网络的不同层中也会有所不同。
通常,网络输出层的偏置项不会被正则化。

利用高维线性回归实现权重衰减

%matplotlib inline
import torch
from torch import nn
from d2l import torch as d2l

首先生成数据,生成公式如下:

y = 0.05 + ∑ i = 1 d 0.01 x i + ϵ where  ϵ ∼ N ( 0 , 0.0 1 2 ) . y = 0.05 + \sum_{i = 1}^d 0.01 x_i + \epsilon \text{ where } \epsilon \sim \mathcal{N}(0, 0.01^2). y=0.05+i=1d0.01xi+ϵ where ϵN(0,0.012).

选择标签是关于输入的线性函数。
标签同时被均值为0,标准差为0.01高斯噪声破坏。
为了使过拟合的效果更加明显,我们可以将问题的维数增加到 d = 200 d = 200 d=200
并使用一个只包含20个样本的小训练集。

n_train, n_test, num_inputs, batch_size = 20, 100, 200, 5
true_w, true_b = torch.ones((num_inputs, 1)) * 0.01, 0.05
train_data = d2l.synthetic_data(true_w, true_b, n_train)
train_iter = d2l.load_array(train_data, batch_size)
test_data = d2l.synthetic_data(true_w, true_b, n_test)
test_iter = d2l.load_array(test_data, batch_size, is_train=False)

初始化模型参数

定义一个函数来随机初始化模型参数

def init_params():w = torch.normal(0, 1, size = (num_inputs, 1), requires_grad = True)b = torch.zeros(1, requires_grad = True)return [w, b]

定义 L 2 L_2 L2范数惩罚

def l2_penalty(w):return torch.sum(w.pow(2)) / 2

定义训练代码实现

下面的代码将模型拟合训练数据集,并在测试数据集上进行评估。

def train(lambd):w, b = init_params()net, loss = lambda X: d2l.linreg(X, w, b), d2l.squared_lossnum_epochs, lr = 100, 0.003animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:# 增加了L2范数惩罚项,# 广播机制使l2_penalty(w)成为一个长度为batch_size的向量l = loss(net(X), y) + lambd * l2_penalty(w)l.sum().backward()d2l.sgd([w, b], lr, batch_size)if (epoch + 1) % 5 == 0:animator.add(epoch + 1, (d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数是:', torch.norm(w).item())

忽略正则化直接训练

用lamdb=0禁用权重衰减后运行代码。此时训练误差有所减少,但测试误差没有减少,这意味着出现了严重的过拟合。

train(lambd = 0)
w的L2范数是: 14.971677780151367

在这里插入图片描述

使用权重衰减

使用权重衰减来运行代码。此时训练误差增大,但测试误差减小。这正是我们期望从正则化中得到的效果。

train(lambd = 3)
w的L2范数是: 0.34405317902565

在这里插入图片描述

权重衰减的简洁实现

在实例化优化器时直接通过weight_decay指定weight decay超参数。默认情况下,PyTorch同时衰减权重和便宜。这里只为权重设置了weight_decay,所以偏置参数 b b b不会衰减。

def train_concise(wd):net = nn.Sequential(nn.Linear(num_inputs, 1))for param in net.parameters():param.data.normal_()loss = nn.MSELoss(reduction='none')num_epochs, lr = 100, 0.003# 偏置参数没有衰减trainer = torch.optim.SGD([{"params":net[0].weight,'weight_decay': wd},{"params":net[0].bias}], lr=lr)animator = d2l.Animator(xlabel='epochs', ylabel='loss', yscale='log',xlim=[5, num_epochs], legend=['train', 'test'])for epoch in range(num_epochs):for X, y in train_iter:trainer.zero_grad()l = loss(net(X), y)l.mean().backward()trainer.step()if (epoch + 1) % 5 == 0:animator.add(epoch + 1,(d2l.evaluate_loss(net, train_iter, loss),d2l.evaluate_loss(net, test_iter, loss)))print('w的L2范数:', net[0].weight.norm().item())
train_concise(0)
w的L2范数: 13.416662216186523

在这里插入图片描述

train_concise(3)
w的L2范数: 0.39273694157600403

在这里插入图片描述

相关文章:

【深度学习笔记】09 权重衰减

09 权重衰减 范数和权重衰减利用高维线性回归实现权重衰减初始化模型参数定义 L 2 L_2 L2​范数惩罚定义训练代码实现忽略正则化直接训练使用权重衰减 权重衰减的简洁实现 范数和权重衰减 在训练参数化机器学习模型时,权重衰减(decay weight&#xff09…...

三大兼容 | 人大金仓兼容+优化MySQL用户变量特性

目前,KingbaseES对MySQL的兼容性,已从功能兼容阶段过渡到强性能兼容、生态全面兼容阶段,针对客户常常遇到的用户变量问题,KingbaseES在兼容MySQL用户变量功能的基础上,优化了MySQL用户变量的一些原生问题,使…...

Git介绍与安装使用

目录 1.Git初识 1.1提出问题 1.2如何解决--版本控制器 1.3注意事项 2.Git安装 2.1Linux-centos安装 2.2Linux-ubuntu安装 2.3Windows安装 3.Git基本操作 3.1创建Git本地仓库 3.2配置Git 4.认识⼯作区、暂存区、版本库 1.Git初识 1.1提出问题 不知道你工作或学习时…...

理解DuLinkList L中的“”引用符号

在C中,DuLinkList &L 这种形式的参数表示 L 是一个 DuLinkList 类型的引用。这里的 & 符号表示引用。 引用是C的一个特性,它提供了一种方式来访问已存在的变量的别名。当你对引用进行操作时,实际上是在操作它所引用的变量。如果你在…...

前端并发多个请求并失败重发

const MAX_RETRIES 3;// 模拟请求 function makeRequest(url) {return new Promise((resolve, reject) > {setTimeout(() > {Math.random() < 0.75 ? resolve(${url} 成功) : reject(${url} 失败); // 随机决定请求是否成功}, Math.random() * 2000); // 随机延时执…...

【Qt开发流程】之对象模型2:属性系统

描述 Qt提供了一个复杂的属性系统&#xff0c;类似于一些编译器供应商提供的属性系统。然而&#xff0c;作为一个独立于编译器和平台的库&#xff0c;Qt不依赖于非标准的编译器特性&#xff0c;如__property或[property]。 Qt解决方案适用于Qt支持的所有平台上的任何标准c编译…...

PHP之curl详细讲解

cURL&#xff08;全称为Client for URLs&#xff09;是一个功能强大的开源库&#xff0c;用于在多种协议上进行数据传输、发送HTTP请求和获取响应。它支持多种协议&#xff0c;包括HTTP、HTTPS、FTP、SMTP等&#xff0c;并且能够与各种服务器进行通信。 cURL库可以通过命令行工…...

R语言30分钟上手

文章目录 1. 环境&安装1.1. rstudio保存工作空间 2. 创建数据集2.1. 数据集概念2.2. 向量、矩阵2.3. 数据框2.3.1. 创建数据框2.3.2. 创建新变量2.3.3. 变量的重编码2.3.4. 列重命名2.3.5. 缺失值2.3.6. 日期值2.3.7. 数据框排序2.3.8. 数据框合并(合并沪深300和中证500收盘…...

上下拉电阻会增强驱动能力吗?

最近看到一个关于上下拉电阻的问题&#xff0c;发现不少人认为上下拉电阻能够增强驱动能力。随后跟几个朋友讨论了一下&#xff0c;大家一致认为不存在上下拉电阻增强驱动能力这回事&#xff0c;因为除了OC输出这类特殊结构外&#xff0c;上下拉电阻就是负载&#xff0c;只会减…...

题目:小明的彩灯(蓝桥OJ 1276)

题目描述&#xff1a; 解题思路&#xff1a; 一段连续区间加减&#xff0c;采用差分。最终每个元素结果与0比较大小&#xff0c;比0小即负数输出0。 题解&#xff1a; #include<bits/stdc.h> using namespace std;using ll long long; const int N 1e5 10; ll a[N],…...

换元法求不定积分

1.一般步骤&#xff1a;选取换元对象&#xff08;不一定是式子中的值&#xff0c;也可以是式子中的最小公倍数或者最大公因数&#xff09;&#xff0c;然后将dx换为dt*t的导数&#xff0c;再用t将原式表示&#xff0c;化简计算即可 2. 3. 4. 5. 6....

在Docker容器中启用SSH服务,实现外部访问的详细教程

目录 步骤 1: 安装 SSH 服务器 步骤 2: 配置 SSH 服务器 步骤 3: 设置 SSH 用户 步骤 4: 重启 SSH 服务器 步骤 5: 映射容器端口 步骤 6: 使用 SSH 连接到容器 要在Docker容器中启用SSH服务&#xff0c;以便从外部访问&#xff0c;您需要执行以下步骤&#xff1a; 步骤 …...

Go 模块系统最小版本选择法 MVS 详解

目录 Golang 模块系统简介 包版本管理 最小版本选择&#xff08;MVS&#xff09;原理 MVS 的优点 MVS的缺点 实际使用MVS 小结 参考资料 Golang 模块系统简介 Golang 模块系统是 Go 1.11 版本引入的一个新特性&#xff0c;主要目的是解决 Go 项目中的依赖管理问题。在模…...

ifstream读取txt中的中文数据转成QString出现乱码

使用ifstream从txt文本中读取中文数据到string&#xff0c;再将string转成QString输出时出现了乱码。 分析&#xff1a;如果ifstream能成功从txt文本中读出中文数据&#xff0c;那大概率txt用的编码是ANSI编码&#xff08;GBK就是ANSI的一种&#xff09;&#xff0c;那么在转成…...

UE4 双屏分辨率设置

背景&#xff1a; 做了一个UI 应用&#xff0c;需要在双屏上进行显示。 分辨率如下&#xff1a;3840*1080&#xff1b; 各种折腾&#xff0c;其实很简单&#xff1a; 主要是在全屏模式的时候 一开始没有选对&#xff0c;双屏总是不稳定。 全屏模式改成&#xff1a;Windows 之…...

$sformat在仿真中打印文本名的使用

在仿真中&#xff0c;定义队列&#xff0c;使用任务进行函数传递&#xff0c;并传递文件名&#xff0c;传递队列&#xff0c;进行打印 $sformat(filename, “./data_log/%0d_%0d_%0d_0.txt”, f_num, lane_num,dt); 使用此函数可以自定义字符串&#xff0c;在仿真的时候进行文件…...

【Rust】结构体与枚举

文章目录 结构体struct基础用法使用字段初始化简写语法使用没有命名字段的元组结构体来创建不同的类型没有任何字段的类单元结构体方法语法关联函数多个 impl 块 枚举枚举值Option 结构体struct 基础用法 一个存储用户账号信息的结构体&#xff1a; struct User {active: bo…...

CentOS7 防火墙常用命令

以下是在 CentOS 7 上使用 firewall-cmd 命令管理防火墙时的一些常用命令&#xff1a; 检查防火墙状态&#xff1a; sudo firewall-cmd --state 启动防火墙&#xff1a; sudo systemctl start firewalld 停止防火墙&#xff1a; sudo systemctl stop firewalld 重启防火墙&…...

【无标题】什么是UL9540测试,UL9540:2023版本增加哪些测试项目

什么是UL9540测试&#xff0c;UL9540:2023版本增加哪些测试项目 UL 9540是美国安全实验室&#xff08;Underwriters Laboratories&#xff09;发布的标准&#xff0c;名称为"UL 9540: Energy Storage Systems and Equipment"&#xff0c;翻译为中文为"能量存储…...

springcloud整合Oauth2自定义登录/登出接口

我使用的是password模式&#xff0c;并配置了token模式 一、登录 (这里我使用的示例是用户名密码认证方式) 1. Oath2提供默认登录授权接口 org.springframework.security.oauth2.provider.endpoint.postAccess; Tokenpublic ResponseEntity<OAuth2AccessToken> pos…...

Oracle常见内置程序包的使用Package

Oracle常见内置程序包的使用 点击此处可跳转至&#xff1a;Oracle的程序包(Package)&#xff0c;对包的基础进行学习常见内置程序包的使用Package1、DBMS_OUTPUT包2、DBMS_XMLQUERY包3、DBMS_RANDOM包4、UTL_FILE包5、DBMS_JOB包6、DBMS_LOB包7、DBMS_SQL包8、DBMS_LOCK包9、DB…...

Flutter:视频下载案例

前言 最近在研究视频下载&#xff0c;因此打算一边研究一边记录一下。方便以后使用时查看。 使用到的库有&#xff1a; permission_handler 11.1.0 &#xff1a;权限请求 flutter_downloader 1.11.5&#xff1a;文件下载器 path_provider 2.1.1&#xff1a;路径处理 视频…...

要求CHATGPT高质量回答的艺术:提示工程技术的完整指南

要求CHATGPT高质量回答的艺术&#xff1a;提示工程技术的完整指南 第2章&#xff1a;指令提示技术 现在&#xff0c;让我们开始探索“指令提示技术”&#xff0c;以及如何使用它从ChatGPT生成高质量的文本。 指令提示技术是一种通过为模型提供特定指令来指导ChatGPT输出的方…...

JDK 历史版本下载以及指定版本应用

参考&#xff1a; 官网下载JAVA的JDK11版本&#xff08;下载、安装、配置环境变量&#xff09;_java11下载-CSDN博客 Gradle&#xff1a;执行命令时指定 JDK 版本 - 微酷网 下载 打开官网地址 Java Downloads | Oracle 当前版本在这里&#xff0c;但是我们要下载历史版本 选…...

Linux基础项目开发1:量产工具——UI系统(五)

前言&#xff1a; 前面我们已经把显示系统、输入系统、文字系统搭建好了&#xff0c;现在我们就要给它实现按钮操作了&#xff0c;也就是搭建UI系统&#xff0c;下面让我们一起实现UI系统的搭建吧 目录 一、按钮数据结构抽象 ui.h 二、按键编程 1.button.c 2.disp_manager…...

面试就是这么简单,offer拿到手软(四)—— 常见java152道基础面试题

面试就是这么简单&#xff0c;offer拿到手软&#xff08;一&#xff09;—— 常见非技术问题回答思路 面试就是这么简单&#xff0c;offer拿到手软&#xff08;二&#xff09;—— 常见65道非技术面试问题 面试就是这么简单&#xff0c;offer拿到手软&#xff08;三&#xff…...

深入理解Redis分片策略:提升系统性能的关键一步

目录 引言 1. 一致性哈希算法 2. 范围分片 3. 哈希槽分片 实战经验分享 结论 引言 Redis作为一款高性能的键值存储系统&#xff0c;为了应对大规模数据和高并发的访问&#xff0c;引入了分片策略&#xff0c;使得数据能够分布存储在多个节点上&#xff0c;实现系统的横向…...

【数据结构(七)】查找算法

文章目录 查找算法介绍1. 线性查找算法2. 二分查找算法2.1. 思路分析2.2. 代码实现2.3. 功能拓展 3. 插值查找算法3.1. 前言3.2. 相关概念3.3. 实例应用 4. 斐波那契(黄金分割法)查找算法4.1. 斐波那契(黄金分割法)原理4.2. 实例应用 查找算法介绍 在 java 中&#xff0c;我们…...

Android画布Canvas绘制drawBitmap基于源Rect和目的Rect,Kotlin

Android画布Canvas绘制drawBitmap基于源Rect和目的Rect&#xff0c;Kotlin <?xml version"1.0" encoding"utf-8"?> <androidx.appcompat.widget.LinearLayoutCompat xmlns:android"http://schemas.android.com/apk/res/android"xmlns…...

深度优先搜索LeetCode979. 在二叉树中分配硬币

给你一个有 n 个结点的二叉树的根结点 root &#xff0c;其中树中每个结点 node 都对应有 node.val 枚硬币。整棵树上一共有 n 枚硬币。 在一次移动中&#xff0c;我们可以选择两个相邻的结点&#xff0c;然后将一枚硬币从其中一个结点移动到另一个结点。移动可以是从父结点到…...

免费入驻的网站设计平台/百度竞价推广点击器

看了标题&#xff0c;可能很多人会心生疑问&#xff0c;比如……DAX语言是什么&#xff1f;答&#xff1a;……说来话长&#xff0c;简而言之&#xff0c;DAX&#xff0c;即数据分析表达式语言&#xff0c;是PowerPivot和SQL Server分析服务表格式的语言&#xff0c;具有强悍而…...

常州专业网站建设费用/seo优化方式包括

一、前言 该技术博客是关于我在B站自学计算机组成原理知识的笔记&#xff0c;该技术博客是根据 2019 王道考研 计算机组成原理 课程总结而成。该技术博客关于记录视频课程内容&#xff0c;如果对你有帮助&#xff0c;方便大家之后的学习&#xff0c;该系列博客会持续更新至全套…...

华为网站建站/网上代写文章一般多少钱

之前编译的参数没有添加mysqli支持&#xff0c;因代码需要&#xff0c;必须添加上去。这次尝试一下扩展编译。 由于是源代码安装的&#xff0c;所以在php的目录bin下面有相关的命令。 到解压的源代码的ext目录下面&#xff0c;进入mysqli目录&#xff0c;执行/home/php/bin/php…...

网页版梦幻西游谛听怎么获得/杭州网络排名优化

使用的环境&#xff1a;Xcode V8.3.3 学习OpenGL的过程中&#xff0c;会使用到gltools,glew,glfw3,glut等库文件&#xff0c;glut包Mac自带&#xff0c;故不需要考虑。主要考虑的是另三个包文件怎样安装&#xff0c;配置。本文主要讲两大部分&#xff1a; 1. glew&#xff0c;…...

室内设计师资格证书/seo技术教程

1.在Java中&#xff0c;如果父类中的某些方法不包含任何逻辑&#xff0c;并且需要有子类重写&#xff0c;应该使用&#xff08;c&#xff09;关键字来申明父类的这些方法。 a. Finalc b. Static c. Abstract d. Void 2.给定两个java程序&#xff0c;如下&#xff1a; public…...

什么网站容易做流量/营销模式方案

Linux SCP和SSH命令平时做Oracle实验、经常会在多个主机间传数据或者登入、这两个命令经常用到这里以最简单的实例介绍一下、以免自己忘了㈠ SCP www.2cto.comscp是在两台机器间复制传输数据的命令、其实质相当于利用SSH协议来传输数据的cp命令复制远程服务器的文件到本地&am…...