当前位置: 首页 > news >正文

【Python】学习率调整策略详解和示例

学习率调整得当将有助于算法快速收敛和获取全局最优,以获得更好的性能。本文对学习率调度器进行示例介绍。

  • 学习率调整的意义
  • 基础示例
    • 无学习率调整方法
    • 学习率调整方法一
    • 多因子调度器
    • 余弦调度器
  • 结论

学习率调整的意义

首先,学习率的大小很重要。如果它太大,优化就会发散;如果它太小,训练就会需要过长时间,或者我们最终只能得到次优的结果(陷入局部最优)。我们之前看到问题的条件数很重要。直观地说,这是最不敏感与最敏感方向的变化量的比率。

其次,衰减速率同样很重要。如果学习率持续过高,我们可能最终会在最小值附近弹跳,从而无法达到最优解。 简而言之,我们希望速率衰减,但要比慢,这样能成为解决凸问题的不错选择

另一个同样重要的方面是初始化。这既涉及参数最初的设置方式,又关系到它们最初的演变方式。这被戏称为预热(warmup),即我们最初开始向着解决方案迈进的速度有多快。一开始的大步可能没有好处,特别是因为最初的参数集是随机的。最初的更新方向可能也是毫无意义的

鉴于管理学习率需要很多细节,因此大多数深度学习框架都有自动应对这个问题的工具。本文将梳理不同的调度策略对准确性的影响,并展示如何通过学习率调度器(learning rate scheduler)来有效管理。

基础示例

我们从一个简单的问题开始,这个问题可以轻松计算,但足以说明要义。 为此,我们选择了一个稍微现代化的LeNet版本(激活函数使用relu而不是sigmoid,汇聚层使用最大汇聚层而不是平均汇聚层),并应用于Fashion-MNIST数据集。 此外,我们混合网络以提高性能。

无学习率调整方法

import math
import torch
from torch import nn
from torch.optim import lr_scheduler, SGD
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torchvision.utils import make_grid
import matplotlib.pyplot as pltdef load_data_fashion_mnist(batch_size):# 定义数据预处理transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 加载训练集和测试集train_dataset = datasets.FashionMNIST(root='./data', train=True, download=True, transform=transform)test_dataset = datasets.FashionMNIST(root='./data', train=False, download=True, transform=transform)# 创建数据加载器train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)return train_loader, test_loader
def net_fn():model = nn.Sequential(nn.Conv2d(1, 6, kernel_size=5, padding=2), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Conv2d(6, 16, kernel_size=5), nn.ReLU(),nn.MaxPool2d(kernel_size=2, stride=2),nn.Flatten(),nn.Linear(16 * 5 * 5, 120), nn.ReLU(),nn.Linear(120, 84), nn.ReLU(),nn.Linear(84, 10))return modeldef train(net, train_loader, test_loader, num_epochs, loss, optimizer, device, scheduler=None):net.to(device)running_loss = 0.0train_losses = []test_losses = []test_accuracies = []for epoch in range(num_epochs):for i, (inputs, labels) in enumerate(train_loader):inputs, labels = inputs.to(device), labels.to(device)# Zero the parameter gradientsoptimizer.zero_grad()# Forward passoutputs = net(inputs)loss_value = loss(outputs, labels)# Backward and optimizeloss_value.backward()optimizer.step()# Print statisticsrunning_loss += loss_value.item()# if i % 200 == 199:  # print every 200 mini-batches#     print(f'[{epoch + 1}, {i + 1}] loss: {running_loss / 200}')#     running_loss = 0.0train_losses.append(running_loss / len(train_loader))# Evaluate the model on the test datasettest_loss, test_acc = evaluate(net, test_loader, device)test_losses.append(test_loss)test_accuracies.append(test_acc)print(f'Epoch {epoch+1}, Train Loss: {train_losses[-1]:.4f}, Test Loss: {test_losses[-1]:.4f}, Test Acc: {test_accuracies[-1]:.2f}')if scheduler:if scheduler.__module__ == lr_scheduler.__name__:scheduler.step()else:for param_group in  optimizer.param_groups:param_group['lr'] = scheduler(epoch)plt.figure(figsize=(10, 6))plt.plot(range(1, num_epochs + 1), train_losses, label='Training Loss')plt.plot(range(1, num_epochs + 1), test_losses, label='Test Loss')plt.plot(range(1, num_epochs + 1), test_accuracies, label='Test Accuracy')plt.title('Training, Test Losses and Test Accuracy')plt.xlabel('Epoch')plt.ylabel('Loss / Accuracy')plt.legend()plt.grid(True)plt.savefig("1.jpg")plt.show()def evaluate(model, data_loader, device):model.eval()test_loss = 0correct = 0with torch.no_grad():for inputs, labels in data_loader:inputs, labels = inputs.to(device), labels.to(device)outputs = model(inputs)test_loss += nn.CrossEntropyLoss(reduction='sum')(outputs, labels).item()_, predicted = torch.max(outputs.data, 1)correct += (predicted == labels).sum().item()test_loss /= len(data_loader.dataset)accuracy = correct / len(data_loader.dataset)#accuracy = 100. * correct / len(data_loader.dataset)return test_loss, accuracy# Device configuration
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# Define the model
model = net_fn()# Define the loss function
loss = nn.CrossEntropyLoss()# Define the optimizer
lr=0.3
optimizer = SGD(model.parameters(), lr=lr)# Load the dataset
batch_size=128
train_loader, test_loader=load_data_fashion_mnist(batch_size)
num_epochs=30
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device)

这里没有使用学习率调整策略。训练过程和结果如下图所示:

.
.
.
.
Epoch 23, Train Loss: 0.1247, Test Loss: 0.3939, Test Acc: 0.90
Epoch 24, Train Loss: 0.1236, Test Loss: 0.4370, Test Acc: 0.89
Epoch 25, Train Loss: 0.1167, Test Loss: 0.4117, Test Acc: 0.89
Epoch 26, Train Loss: 0.1169, Test Loss: 0.4440, Test Acc: 0.89
Epoch 27, Train Loss: 0.1163, Test Loss: 0.4336, Test Acc: 0.89
Epoch 28, Train Loss: 0.1055, Test Loss: 0.4312, Test Acc: 0.90
Epoch 29, Train Loss: 0.1065, Test Loss: 0.4942, Test Acc: 0.89
Epoch 30, Train Loss: 0.1051, Test Loss: 0.4763, Test Acc: 0.89

在这里插入图片描述

学习率调整方法一

设置在每个迭代轮数(甚至在每个小批量)之后向下调整学习率。 例如,以动态的方式来响应优化的进展情况。

在代码最后添加SquareRootScheduler类,并更新train()函数参数,其它内容不变。

class SquareRootScheduler:def __init__(self, lr=0.1):self.lr = lrdef __call__(self, num_update):return self.lr * pow(num_update + 1.0, -0.5)scheduler = SquareRootScheduler(lr=0.1)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行代码,可得相应参数值和变化过程,如下所示。

Epoch 23, Train Loss: 0.1823, Test Loss: 0.2811, Test Acc: 0.90
Epoch 24, Train Loss: 0.1801, Test Loss: 0.2800, Test Acc: 0.90
Epoch 25, Train Loss: 0.1767, Test Loss: 0.2819, Test Acc: 0.90
Epoch 26, Train Loss: 0.1747, Test Loss: 0.2800, Test Acc: 0.91
Epoch 27, Train Loss: 0.1720, Test Loss: 0.2818, Test Acc: 0.90
Epoch 28, Train Loss: 0.1689, Test Loss: 0.2856, Test Acc: 0.90
Epoch 29, Train Loss: 0.1669, Test Loss: 0.2907, Test Acc: 0.90
Epoch 30, Train Loss: 0.1641, Test Loss: 0.2813, Test Acc: 0.90

在这里插入图片描述
我们可以看出曲线比没有策略时平滑了很多,效果有所提升。

多因子调度器

多因子调度器。
在这里插入图片描述
在这里插入图片描述
代码部分修改:

scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)

运行结果为:
在这里插入图片描述
可见效果不理想,出现过拟合现象。

余弦调度器

余弦调度器是 (Loshchilov and Hutter, 2016)提出的一种启发式算法。 它所依据的观点是:我们可能不想在一开始就太大地降低学习率,而且可能希望最终能用非常小的学习率来“改进”解决方案。 这产生了一个类似于余弦的调度,函数形式如下所示,学习率的值在
之间。
在这里插入图片描述
代码中添加CosineScheduler类和修改scheduler。

class CosineScheduler:def __init__(self, max_update, base_lr=0.01, final_lr=0,warmup_steps=0, warmup_begin_lr=0):self.base_lr_orig = base_lrself.max_update = max_updateself.final_lr = final_lrself.warmup_steps = warmup_stepsself.warmup_begin_lr = warmup_begin_lrself.max_steps = self.max_update - self.warmup_stepsdef get_warmup_lr(self, epoch):increase = (self.base_lr_orig - self.warmup_begin_lr) \* float(epoch) / float(self.warmup_steps)return self.warmup_begin_lr + increasedef __call__(self, epoch):if epoch < self.warmup_steps:return self.get_warmup_lr(epoch)if epoch <= self.max_update:self.base_lr = self.final_lr + (self.base_lr_orig - self.final_lr) * (1 + math.cos(math.pi * (epoch - self.warmup_steps) / self.max_steps)) / 2return self.base_lr#scheduler = SquareRootScheduler(lr=0.1)
#scheduler =lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30], gamma=0.5)
scheduler = CosineScheduler(max_update=20, base_lr=0.3, final_lr=0.01)
train(model, train_loader, test_loader, num_epochs, loss, optimizer, device,scheduler)

运行结果如下:

在这里插入图片描述
过拟合现象消失,效果提升。

结论

在开发时应根据自己需要,选择合适的学习率调整策略。优化在深度学习中有多种用途。对于同样的训练误差而言,选择不同的优化算法和学习率调度,除了最大限度地减少训练时间,可以导致测试集上不同的泛化和过拟合量。

注:部分内容摘选子书籍《动手学深度学习》

相关文章:

【Python】学习率调整策略详解和示例

学习率调整得当将有助于算法快速收敛和获取全局最优&#xff0c;以获得更好的性能。本文对学习率调度器进行示例介绍。 学习率调整的意义基础示例无学习率调整方法学习率调整方法一多因子调度器余弦调度器 结论 学习率调整的意义 首先&#xff0c;学习率的大小很重要。如果它…...

【Linux实践室】Linux用户管理实战指南:用户密码管理操作详解

&#x1f308;个人主页&#xff1a;聆风吟_ &#x1f525;系列专栏&#xff1a;Linux实践室、网络奇遇记 &#x1f516;少年有梦不应止于心动&#xff0c;更要付诸行动。 文章目录 一. ⛳️任务描述二. ⛳️相关知识2.1 &#x1f514;用户密码存放地及方式2.2 &#x1f514;使用…...

UE5学习日记——蓝图节点前缀关键字整理

一、起因 节点如海&#xff0c;中英文翻译的时候还是有差别的&#xff0c;比如&#xff1a; 同一个中文&#xff0c;可能在英文里完全不同&#xff0c;连出现位置可能都不一样 附加 Attach Actor To Component&#xff08;将Actor附加到组件&#xff09;Append Array&#xf…...

浅析机器学习的常用方法

引言&#xff1a; 机器学习&#xff08;Machine Learning&#xff0c;ML&#xff09;是一种以计算机程序为基础&#xff0c;在不需要明确编程的情况下&#xff0c;对数据进行分析和处理的人工智能技术。与传统的计算机编程相比&#xff0c;机器学习的区别在于它通过数据建立模…...

大数据开发(日志离线分析项目)

大数据开发&#xff08;日志离线分析项目&#xff09; 一、项目需求1、使用jqueryecharts的方式调用程序后台提供的rest api接口&#xff0c;获取json数据&#xff0c;然后通过jquerycss的方式进行数据展示。工作流程如下&#xff1a;2、七大角度1、用户基本信息分析模块2、浏览…...

PostgreSQL技术大讲堂 - 第48讲:PG高可用实现keepalived

PostgreSQL从小白到专家&#xff0c;是从入门逐渐能力提升的一个系列教程&#xff0c;内容包括对PG基础的认知、包括安装使用、包括角色权限、包括维护管理、、等内容&#xff0c;希望对热爱PG、学习PG的同学们有帮助&#xff0c;欢迎持续关注CUUG PG技术大讲堂。 第48讲&#…...

【若依 SpringBoot 前后端分离版】修改加密传输后密码错误的解决方法(附排错过程)

目录 排错过程 报错信息 SysLoginController SysLoginService&#xff08;问题核心&#xff09; 太长不看版&#xff1a;解决方法 文章传送门&#xff1a;若依(RuoYi)SpringBoot框架密码加密传输(前后分离板)_若依密码加密方式-CSDN博客文章浏览阅读1.5w次&#xff0c;点赞…...

发送请求- header配置

请求头里是客户端的要求&#xff0c;把你的诉求告诉服务端&#xff0c;服务端按照你的要求返回数据 &#xff0c; 请求header需要严格全配置&#xff0c;把请求header全部传入&#xff0c;不能频繁访问&#xff0c;让后端知道它是正常请求 一般只配置User-Agent和Content Typ…...

C语言重难知识点

C语言重难知识点 if(a=1) 为真函数指针的调用(int)2.9 = 2逗号运算符,最右边表达式值作为整个逗号表达式的值。文件操作if(a=1) 为真 int a=0,b=0,c=0; if(a...

jMeter学习

一. JMeter介绍 1. 什么是JMeter&#xff1f; Apache JMeter™ 应用程序是开源软件&#xff0c;一个 100% 纯 Java 应用程序&#xff0c;旨在加载测试功能行为和测量性能 。它最初是为测试 Web 应用程序而设计的&#xff0c;但后来扩展到其他测试功能。 2. JMeter能做啥&#x…...

Nodejs运行vue项目时,报错:Error: error:0308010C:digital envelope routines::unsupported

前端项目使用( npm run dev ) 运行vue项目时&#xff0c;出现错误&#xff1a;Error: error:0308010C:digital envelope routines::unsupported 经过探索&#xff0c;发现问题所在&#xff0c;主要是nodeJs V17版本发布了OpenSSL3.0对算法和秘钥大小增加了更为严格的限制&#…...

华为汽车图谱

极狐 极狐&#xff08;ARCFOX&#xff09;是由北汽、华为、戴姆勒、麦格纳等联合打造。总部位于北京蓝谷。 问界 华为与赛力斯&#xff08;东风小康&#xff09;合作的成果。 阿维塔 阿维塔&#xff08;AVATR&#xff09;是由长安汽车、华为、宁德时代三方联合打造。公司总部位…...

鸿蒙操作系统-初识

HarmonyOS-初识 简述安装配置hello world1.创建项目2.目录解释3.构建页面4.真机运行 应用程序包共享包HARHSP 快速修复包 官方文档请参考&#xff1a;HarmonyOS 简述 1.定义&#xff1a;HarmonyOS是分布式操作系统&#xff0c;它旨在为不同类型的智能设备提供统一的操作系统&a…...

【ZZULIOJ】1003: 两个整数的四则运算(Java)

题目描述 输入两个整数num1和num2&#xff0c;请你设计一个程序&#xff0c;计算并输出它们的和、差、积、整数商及余数。 输入 输入只有两个正整数num1、num2。 输出 输出占一行&#xff0c;包括两个数的和、差、积、商及余数&#xff0c;数据之间用一个空格隔开。 样例…...

聊聊芯片原厂

芯片原厂是芯片的生产商,他们制造和设计芯片,并拥有产品的所有权原厂这个词是为了区分芯片代理商(厂)而创造的。 每一家芯片制造商都会通过自己忠诚的芯片代理商(厂)来销售自己的芯片,代理商(厂)也会打着芯片制造商的旗号来销售芯片,因此有时候为了强调自己的正统地…...

百人一岗,Android开发者的困境。。。。。

前言 在当前的Android开发领域&#xff0c;竞争的激烈程度已经达到了前所未有的水平&#xff0c;几乎到了100个开发者竞争1个岗位的地步。 这种“内卷”现象的背后&#xff0c;是技术的快速发展和市场对Android开发者技能要求的不断提升。随着移动应用的普及和多样化&#xf…...

若依分离版 —引入echart连接Springboot后端

1. vue引入echart &#xff08;1&#xff09;首先安装ECharts库。可以通过npm npm install echarts --save &#xff08;2&#xff09;在vue页面中添加一个容器元素来显示图表 <el-card class"mt20"><div id"ha" ref"main"><…...

Halcon深度学习项目实战

Halcon在机器视觉中的价值主要体现在提供高效、可扩展、灵活的机器视觉解决方案&#xff0c;帮助用户解决各种复杂的机器视觉问题&#xff0c;提高生产效率和产品质量。 缩短产品上市时间 Halcon的灵活架构使其能够快速开发出任何类型的机器视觉应用。其全球通用的集成开发环…...

子类中的方法去调用父类中的方法有几种形式?原生django如何向响应头写入数据

1 子类中的方法去调用父类中的方法有几种形式 2 原生django如何向响应头写入数据 1 子类中的方法去调用父类中的方法有几种形式&#xff1f; class Animal:def eat(self):print(self.name, 在吃饭)class Dog(Animal):def __init__(self, name):self.name namedef test(self):#…...

数据安全治理框架构建

一、引言 在数字化时代&#xff0c;数据已成为企业和社会发展的重要驱动力。然而&#xff0c;随着数据量的激增和数据应用场景的扩展&#xff0c;数据安全风险也日益凸显。数据安全治理作为确保数据安全、合规使用的关键手段&#xff0c;受到了广泛的关注。本文旨在探讨数据安…...

观成科技:隐蔽隧道工具Ligolo-ng加密流量分析

1.工具介绍 Ligolo-ng是一款由go编写的高效隧道工具&#xff0c;该工具基于TUN接口实现其功能&#xff0c;利用反向TCP/TLS连接建立一条隐蔽的通信信道&#xff0c;支持使用Let’s Encrypt自动生成证书。Ligolo-ng的通信隐蔽性体现在其支持多种连接方式&#xff0c;适应复杂网…...

盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来

一、破局&#xff1a;PCB行业的时代之问 在数字经济蓬勃发展的浪潮中&#xff0c;PCB&#xff08;印制电路板&#xff09;作为 “电子产品之母”&#xff0c;其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透&#xff0c;PCB行业面临着前所未有的挑战与机遇。产品迭代…...

ESP32 I2S音频总线学习笔记(四): INMP441采集音频并实时播放

简介 前面两期文章我们介绍了I2S的读取和写入&#xff0c;一个是通过INMP441麦克风模块采集音频&#xff0c;一个是通过PCM5102A模块播放音频&#xff0c;那如果我们将两者结合起来&#xff0c;将麦克风采集到的音频通过PCM5102A播放&#xff0c;是不是就可以做一个扩音器了呢…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

06 Deep learning神经网络编程基础 激活函数 --吴恩达

深度学习激活函数详解 一、核心作用 引入非线性:使神经网络可学习复杂模式控制输出范围:如Sigmoid将输出限制在(0,1)梯度传递:影响反向传播的稳定性二、常见类型及数学表达 Sigmoid σ ( x ) = 1 1 +...

中医有效性探讨

文章目录 西医是如何发展到以生物化学为药理基础的现代医学&#xff1f;传统医学奠基期&#xff08;远古 - 17 世纪&#xff09;近代医学转型期&#xff08;17 世纪 - 19 世纪末&#xff09;​现代医学成熟期&#xff08;20世纪至今&#xff09; 中医的源远流长和一脉相承远古至…...

让回归模型不再被异常值“带跑偏“,MSE和Cauchy损失函数在噪声数据环境下的实战对比

在机器学习的回归分析中&#xff0c;损失函数的选择对模型性能具有决定性影响。均方误差&#xff08;MSE&#xff09;作为经典的损失函数&#xff0c;在处理干净数据时表现优异&#xff0c;但在面对包含异常值的噪声数据时&#xff0c;其对大误差的二次惩罚机制往往导致模型参数…...

Java求职者面试指南:Spring、Spring Boot、MyBatis框架与计算机基础问题解析

Java求职者面试指南&#xff1a;Spring、Spring Boot、MyBatis框架与计算机基础问题解析 一、第一轮提问&#xff08;基础概念问题&#xff09; 1. 请解释Spring框架的核心容器是什么&#xff1f;它在Spring中起到什么作用&#xff1f; Spring框架的核心容器是IoC容器&#…...

土建施工员考试:建筑施工技术重点知识有哪些?

《管理实务》是土建施工员考试中侧重实操应用与管理能力的科目&#xff0c;核心考查施工组织、质量安全、进度成本等现场管理要点。以下是结合考试大纲与高频考点整理的重点内容&#xff0c;附学习方向和应试技巧&#xff1a; 一、施工组织与进度管理 核心目标&#xff1a; 规…...

初探用uniapp写微信小程序遇到的问题及解决(vue3+ts)

零、关于开发思路 (一)拿到工作任务,先理清楚需求 1.逻辑部分 不放过原型里说的每一句话,有疑惑的部分该问产品/测试/之前的开发就问 2.页面部分(含国际化) 整体看过需要开发页面的原型后,分类一下哪些组件/样式可以复用,直接提取出来使用 (时间充分的前提下,不…...