transformers进行学习率调整lr_scheduler(warmup)
一、get_scheduler实现warmup
1、warmup基本思想
Warmup(预热)是深度学习训练中的一种技巧,旨在逐步增加学习率以稳定训练过程,特别是在训练的早期阶段。它主要用于防止在训练初期因学习率过大导致的模型参数剧烈波动或不稳定。预热阶段通常是指在训练开始时,通过多个步长逐步将学习率从一个较低的值增加到目标值(通常是预定义的最大学习率)。
2、warmup基本实现
from transformers import get_schedulerscheduler = get_scheduler(name="cosine", # 可以选择 'linear', 'cosine', 'polynomial', 'constant', 'constant_with_warmup'optimizer=optimizer,num_warmup_steps=100, # 预热步数num_training_steps=num_training_steps # 总的训练步数
)#linear:线性学习率下降
#cosine:余弦退火
#polynomial:多项式衰减
#constant:常数学习率
#constant_with_warmup:预热后保持常数# 上述代码等价于
from transformers import get_cosine_scheduler_with_warmupscheduler = get_cosine_scheduler_with_warmup(optimizer=optimizer,num_warmup_steps=100, # 预热步数num_training_steps=num_training_steps # 总的训练步数
)# 同理等价于linear, polynomial, constant分别等价于
from transformers import (get_constant_schedule, get_polynomial_decay_schedule_with_warmup, get_linear_schedule_with_warmup)
二、各种warmup策略学习率变化规律
1、get_constant_schedule学习率变化规律

2、get_cosine_schedule_with_warmup学习率变化规律

3、get_cosine_with_hard_restarts_schedule_with_warmup学习率变化规律

4、get_linear_schedule_with_warmup学习率变化规律

5、get_polynomial_decay_schedule_with_warmup学习率变化规律(power=2, power=1类似于linear)

6、注意事项
- 如果网络中不同框架采用不同的学习率,上述的warmup策略仍然有效(如图二、5中所示)
- 给schduler设置的number_training_steps一定要和训练过程相匹配,如下所示。
7、可视化学习率过程
import matplotlib.pyplot as plt
from transformers import get_scheduler
from torch.optim import AdamW
import torch
import math# 定义一些超参数learning_rate = 1e-3 # 初始学习率# 假设有一个模型
model = torch.nn.Linear(10, 2)# 获得训练总的步数
epochs = 50
batch_size = 32
#train_loader = ***
#num_train_loader = len(train_loader)
num_train_loader = 1235num_training_steps = epochs * math.ceil(num_train_loader/batch_size) # 总的训练步数# 定义优化器
optimizer = AdamW(model.parameters(), lr=learning_rate)# 创建学习率调度器
scheduler = get_scheduler(name="cosine", # 可以选择 'linear', 'cosine', 'polynomial', 'constant', 'constant_with_warmup'optimizer=optimizer,num_warmup_steps=100, # 预热步数num_training_steps=num_training_steps # 总的训练步数
)# 存储每一步的学习率
learning_rates = []# for step in range(num_training_steps):
# optimizer.step()
# scheduler.step()
# learning_rates.append(optimizer.param_groups[0]['lr'])for epoch in range(epochs):# for batch in train_loader:for step in range(0, num_train_loader, batch_size):optimizer.zero_grad()# loss.backward()optimizer.step()scheduler.step()learning_rates.append(optimizer.param_groups[0]['lr'])# 绘制学习率曲线
plt.plot(learning_rates)
plt.xlabel("Training Steps")
plt.ylabel("Learning Rate")
plt.title("Learning Rate Schedule")
plt.show()
实验结果:

相关文章:
transformers进行学习率调整lr_scheduler(warmup)
一、get_scheduler实现warmup 1、warmup基本思想 Warmup(预热)是深度学习训练中的一种技巧,旨在逐步增加学习率以稳定训练过程,特别是在训练的早期阶段。它主要用于防止在训练初期因学习率过大导致的模型参数剧烈波动或不稳定。…...
智能优化算法之灰狼优化算法(GWO)
智能优化算法是一类基于自然界中生物、物理或社会现象的优化技术。这些算法通过模拟自然界中的一些智能行为,如遗传学、蚁群觅食、粒子群体运动等,来解决复杂的优化问题。智能优化算法广泛应用于各种工程和科学领域,因其具有全局搜索能力、鲁…...
昇思25天学习打卡营第17天|计算机视觉
昇思25天学习打卡营第17天 文章目录 昇思25天学习打卡营第17天ShuffleNet图像分类ShuffleNet网络介绍模型架构Pointwise Group ConvolutionChannel ShuffleShuffleNet模块构建ShuffleNet网络 模型训练和评估训练集准备与加载模型训练模型评估模型预测 打卡记录 ShuffleNet图像分…...
Windows图形界面(GUI)-MFC-C/C++ - 键鼠操作
公开视频 -> 链接点击跳转公开课程博客首页 -> 链接点击跳转博客主页 目录 MFC鼠标 派发流程 鼠标消息(客户区) 鼠标消息(非客户) 坐标处理 客户区 非客户 坐标转换 示例代码 MFC键盘 击键消息 虚拟键代码 键状态 MFC鼠标 派发流程 消息捕获&#…...
Angular 18.2.0 的新功能增强和创新
一.Angular 增强功能 Angular 是一个以支持开发强大的 Web 应用程序而闻名的平台,最近发布了 18.2.0 版本。此更新带来了许多新功能和改进,进一步增强了其功能和开发人员体验。在本文中,我们将深入探讨 Angular 18.2.0 为开发人员社区提供的…...
matlab 小数取余 rem 和 mod有 bug
目录 前言Matlab取余函数1 mod 函数1.1 命令行输入1.2 命令行输出 2 rem 函数2.1 命令行输入2.2 命令行输出 分析原因注意 前言 在 Matlab 代码中mod(0.11, 0.1) < 0.01 判断为真,mod(1.11, 0.1) < 0.01判断为假,导致出现意料外的结果。 结果发现…...
Avalonia中的数据模板
文章目录 1. 介绍和概述什么是数据模板:数据模板的用途:2. 定义数据模板在XAML中定义数据模板:在代码中定义数据模板:3. 使用数据模板在控件中使用数据模板:数据模板选择器:定义数据模板选择器:在XAML中使用数据模板选择器:4. 复杂数据模板使用嵌套数据模板:使用模板绑…...
Sqlmap中文使用手册 - Techniques模块参数使用
目录 1. Techniques模块的帮助文档2. 各个参数的介绍2.1 --techniqueTECH2.2 --time-secTIMESEC2.3 --union-colsUCOLS2.4 --union-charUCHAR2.5 --union-fromUFROM2.6 --dns-domainDNS2.7 --second-urlSEC2.8 --second-reqSEC 1. Techniques模块的帮助文档 Techniques:These o…...
科普文:kubernets原理
kubernetes 已经成为容器编排领域的王者,它是基于容器的集群编排引擎,具备扩展集群、滚动升级回滚、弹性伸缩、自动治愈、服务发现等多种特性能力。 本文将带着大家快速了解 kubernetes ,了解我们谈论 kubernetes 都是在谈论什么。 一、背…...
GO-学习-02-常量
常量是不变的 const package main import "fmt"func main() {//常量定义时必须赋值const pi 3.1415926const e 2.718//一次声明多个常量const(a 1b 2c "ihan")const(n1 100n2n3)//n2,n3也是100 同时声明多个常量时,如果省略了值则表示和…...
Vue系列面试题
大家好,我是有用就扩散,有用就点赞。 1.Vue中组件间有哪些通信方式? 父子组件通信: (1)props | $emit (接收父组件数据 | 传数据给父组件) (2)ref | $refs&a…...
等级保护 总结2
网络安全等级保护解决方案的主打产品: HiSec Insight安全态势感知系统、 FireHunter6000沙箱、 SecoManager安全控制器、 HiSecEngine USG系列防火墙和HiSecEngine AntiDDoS防御系统。 华为HiSec Insight安全态势感知系统是基于商用大数据平台FusionInsight的A…...
关于Redis(热点数据缓存,分布式锁,缓存安全(穿透,击穿,雪崩));
热点数据缓存: 为了把一些经常访问的数据,放入缓存中以减少对数据库的访问频率。从而减少数据库的压力,提高程序的性能。【内存中存储】成为缓存; 缓存适合存放的数据: 查询频率高且修改频率低 数据安全性低 作为缓存的组件: redis组件 memory组件 e…...
【北京迅为】《i.MX8MM嵌入式Linux开发指南》-第三篇 嵌入式Linux驱动开发篇-第四十七章 字符设备和杂项设备总结回顾
i.MX8MM处理器采用了先进的14LPCFinFET工艺,提供更快的速度和更高的电源效率;四核Cortex-A53,单核Cortex-M4,多达五个内核 ,主频高达1.8GHz,2G DDR4内存、8G EMMC存储。千兆工业级以太网、MIPI-DSI、USB HOST、WIFI/BT…...
C#初级——枚举
枚举 枚举是一组命名整型常量。 enum 枚举名字 { 常量1, 常量2, …… 常量n }; 枚举的常量是由 , 分隔的列表。并且,在这个整型常量列表中,通常默认第一位枚举符号的值为0,此后的枚举符号的值都比前一位大1。 在将枚举赋值给 int 类型的…...
Linux 动静态库
一、动静态库 1、库的理解 库其实是给我们提供方法的实现,如上面的对于printf函数的实现就是在库中实现的,而这个库也就是c标准库,本质也是文件,也有对应的路径 2、区别 静态库是指编译链接时,把库文件的代码全部加入…...
微信小游戏之 三消(一)
首先设定一下 单个 方块 cell 类: 类定义和属性 init 方法 用于初始化方块,接收游戏实例、数据、宽度、道具类型和位置。 onWarning 方法 设置警告精灵的帧,并播放闪烁动作,用于显示方块的警告状态。 grow 方法 根据传入的方向…...
软件测试---Linux
Linux命令使用:为了将来工作中与服务器设备进行交互而准备的技能(远程连接/命令的使用)数据库的使用:MySQL,除了查询动作需要重点掌握以外,其他操作了解即可什么是虚拟机 通过虚拟化技术,在电脑…...
数据库之数据表基本操作
目录 一、创建数据表 1.创建表的语法形式 2.使用SQL语句设置约束条件 1.设置主键约束 2.设置自增约束 3.设置非空约束 4.设置唯一性约束 5.设置无符号约束 6.设置默认约束 7.设置外键约束 8.设置表的存储引擎 二、查看表结构 1.查看表基本结构 2.查看建表语句 三…...
利用OSMnx求路网最短路径并可视化(二)
书接上回,为了增加多路径的可视化效果和坐标匹配最近点来实现最短路可视化,我们使用图形化工具matplotlib结合OSMnx的绘图功能来展示整个路网图,并特别高亮显示计算出的最短路径。 多起终点最短路路径并计算距离和时间 完整代码#运行环境 P…...
智慧医疗能源事业线深度画像分析(上)
引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...
Lombok 的 @Data 注解失效,未生成 getter/setter 方法引发的HTTP 406 错误
HTTP 状态码 406 (Not Acceptable) 和 500 (Internal Server Error) 是两类完全不同的错误,它们的含义、原因和解决方法都有显著区别。以下是详细对比: 1. HTTP 406 (Not Acceptable) 含义: 客户端请求的内容类型与服务器支持的内容类型不匹…...
树莓派超全系列教程文档--(62)使用rpicam-app通过网络流式传输视频
使用rpicam-app通过网络流式传输视频 使用 rpicam-app 通过网络流式传输视频UDPTCPRTSPlibavGStreamerRTPlibcamerasrc GStreamer 元素 文章来源: http://raspberry.dns8844.cn/documentation 原文网址 使用 rpicam-app 通过网络流式传输视频 本节介绍来自 rpica…...
【学习笔记】深入理解Java虚拟机学习笔记——第4章 虚拟机性能监控,故障处理工具
第2章 虚拟机性能监控,故障处理工具 4.1 概述 略 4.2 基础故障处理工具 4.2.1 jps:虚拟机进程状况工具 命令:jps [options] [hostid] 功能:本地虚拟机进程显示进程ID(与ps相同),可同时显示主类&#x…...
Device Mapper 机制
Device Mapper 机制详解 Device Mapper(简称 DM)是 Linux 内核中的一套通用块设备映射框架,为 LVM、加密磁盘、RAID 等提供底层支持。本文将详细介绍 Device Mapper 的原理、实现、内核配置、常用工具、操作测试流程,并配以详细的…...
深度学习习题2
1.如果增加神经网络的宽度,精确度会增加到一个特定阈值后,便开始降低。造成这一现象的可能原因是什么? A、即使增加卷积核的数量,只有少部分的核会被用作预测 B、当卷积核数量增加时,神经网络的预测能力会降低 C、当卷…...
sipsak:SIP瑞士军刀!全参数详细教程!Kali Linux教程!
简介 sipsak 是一个面向会话初始协议 (SIP) 应用程序开发人员和管理员的小型命令行工具。它可以用于对 SIP 应用程序和设备进行一些简单的测试。 sipsak 是一款 SIP 压力和诊断实用程序。它通过 sip-uri 向服务器发送 SIP 请求,并检查收到的响应。它以以下模式之一…...
C++.OpenGL (20/64)混合(Blending)
混合(Blending) 透明效果核心原理 #mermaid-svg-SWG0UzVfJms7Sm3e {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-icon{fill:#552222;}#mermaid-svg-SWG0UzVfJms7Sm3e .error-text{fill…...
《信号与系统》第 6 章 信号与系统的时域和频域特性
目录 6.0 引言 6.1 傅里叶变换的模和相位表示 6.2 线性时不变系统频率响应的模和相位表示 6.2.1 线性与非线性相位 6.2.2 群时延 6.2.3 对数模和相位图 6.3 理想频率选择性滤波器的时域特性 6.4 非理想滤波器的时域和频域特性讨论 6.5 一阶与二阶连续时间系统 6.5.1 …...
JDK 17 序列化是怎么回事
如何序列化?其实很简单,就是根据每个类型,用工厂类调用。逐个完成。 没什么漂亮的代码,只有有效、稳定的代码。 代码中调用toJson toJson 代码 mapper.writeValueAsString ObjectMapper DefaultSerializerProvider 一堆实…...
