机器学习优化算法:从梯度下降到Adam及其变种
机器学习优化算法:从梯度下降到Adam及其变种
引言
最近deepseek的爆火已然说明,在机器学习领域,优化算法是模型训练的核心驱动力。无论是简单的线性回归还是复杂的深度神经网络,优化算法的选择直接影响模型的收敛速度、泛化性能和计算效率。通过本文,你可以系统性地介绍从经典的梯度下降法到当前主流的自适应优化算法(如Adam),分析其数学原理、优缺点及适用场景,并探讨未来发展趋势。
一、优化算法基础
1.1 梯度下降法(Gradient Descent)
数学原理:
介绍如下:
梯度下降可以通过计算损失函数 J ( θ ) J(\theta) J(θ)对参数 θ \theta θ的梯度 ∇ θ J ( θ ) \nabla_\theta J(\theta) ∇θJ(θ),沿负梯度方向更新参数:
θ t + 1 = θ t − η ∇ θ J ( θ t ) \theta_{t+1} = \theta_t - \eta \nabla_\theta J(\theta_t) θt+1=θt−η∇θJ(θt)
其中 η \eta η为学习率。
三种变体:
- 批量梯度下降(BGD):使用全量数据计算梯度,收敛稳定但计算成本高。
- 随机梯度下降(SGD):每次随机选取单个样本更新参数,计算快但噪声大。
- 小批量梯度下降(Mini-batch SGD):平衡BGD与SGD,采用小批量数据,兼顾效率与稳定性。
二、动量法与自适应学习率
2.1 动量法(Momentum)
原理:引入动量项模拟物理惯性,减少振荡,加速收敛。
更新公式:
v t = γ v t − 1 + η ∇ θ J ( θ t ) v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta_t) vt=γvt−1+η∇θJ(θt)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_t - v_t θt+1=θt−vt
其中 γ \gamma γ为动量因子(通常0.9),累积历史梯度方向。
2.2 Nesterov加速梯度(NAG)
改进动量法,先根据动量项预估下一步位置,再计算梯度:
v t = γ v t − 1 + η ∇ θ J ( θ t − γ v t − 1 ) v_t = \gamma v_{t-1} + \eta \nabla_\theta J(\theta_t - \gamma v_{t-1}) vt=γvt−1+η∇θJ(θt−γvt−1)
θ t + 1 = θ t − v t \theta_{t+1} = \theta_t - v_t θt+1=θt−vt
NAG在凸优化中具有理论收敛优势。
2.3 自适应学习率算法
Adagrad
为每个参数分配独立的学习率,适应稀疏数据:
g t , i = ∇ θ J ( θ t , i ) g_{t,i} = \nabla_\theta J(\theta_{t,i}) gt,i=∇θJ(θt,i)
G t , i = G t − 1 , i + g t , i 2 G_{t,i} = G_{t-1,i} + g_{t,i}^2 Gt,i=Gt−1,i+gt,i2
θ t + 1 , i = θ t , i − η G t , i + ϵ g t , i \theta_{t+1,i} = \theta_{t,i} - \frac{\eta}{\sqrt{G_{t,i} + \epsilon}} g_{t,i} θt+1,i=θt,i−Gt,i+ϵηgt,i
缺陷: G t G_t Gt累积导致学习率过早衰减。
RMSprop
改进Adagrad,引入指数移动平均:
E [ g 2 ] t = β E [ g 2 ] t − 1 + ( 1 − β ) g t 2 E[g^2]_t = \beta E[g^2]_{t-1} + (1-\beta)g_t^2 E[g2]t=βE[g2]t−1+(1−β)gt2
θ t + 1 = θ t − η E [ g 2 ] t + ϵ g t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{E[g^2]_t + \epsilon}} g_t θt+1=θt−E[g2]t+ϵηgt
缓解学习率下降问题,适合非平稳目标。
三、Adam算法详解
3.1 Adam的核心思想
结合动量法与自适应学习率,引入一阶矩估计(均值)和二阶矩估计(方差)。
3.2 算法步骤
- 计算梯度: g t = ∇ θ J ( θ t ) g_t = \nabla_\theta J(\theta_t) gt=∇θJ(θt)
- 更新一阶矩: m t = β 1 m t − 1 + ( 1 − β 1 ) g t m_t = \beta_1 m_{t-1} + (1-\beta_1)g_t mt=β1mt−1+(1−β1)gt
- 更新二阶矩: v t = β 2 v t − 1 + ( 1 − β 2 ) g t 2 v_t = \beta_2 v_{t-1} + (1-\beta_2)g_t^2 vt=β2vt−1+(1−β2)gt2
- 偏差校正(因初始零偏差):
m ^ t = m t 1 − β 1 t , v ^ t = v t 1 − β 2 t \hat{m}_t = \frac{m_t}{1-\beta_1^t}, \quad \hat{v}_t = \frac{v_t}{1-\beta_2^t} m^t=1−β1tmt,v^t=1−β2tvt - 参数更新:
θ t + 1 = θ t − η v ^ t + ϵ m ^ t \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t} + \epsilon} \hat{m}_t θt+1=θt−v^t+ϵηm^t
超参数建议: β 1 = 0.9 \beta_1=0.9 β1=0.9, β 2 = 0.999 \beta_2=0.999 β2=0.999, ϵ = 1 0 − 8 \epsilon=10^{-8} ϵ=10−8。
3.3 优势与局限性
- 优点:自适应学习率、内存效率高、适合大规模数据与参数。
- 缺点:可能陷入局部最优、泛化性能在某些任务中不如SGD。
四、Adam的改进与变种
4.1 Nadam
融合NAG与Adam,公式改变为:
θ t + 1 = θ t − η v ^ t + ϵ ( β 1 m ^ t + ( 1 − β 1 ) g t 1 − β 1 t ) \theta_{t+1} = \theta_t - \frac{\eta}{\sqrt{\hat{v}_t}+\epsilon} (\beta_1 \hat{m}_t + \frac{(1-\beta_1)g_t}{1-\beta_1^t}) θt+1=θt−v^t+ϵη(β1m^t+1−β1t(1−β1)gt)
这样能够加速收敛并提升稳定性。
4.2 AMSGrad
解决Adam二阶矩估计可能导致的收敛问题:
v t = max ( β 2 v t − 1 , v t ) v_t = \max(\beta_2 v_{t-1}, v_t) vt=max(β2vt−1,vt)
保证学习率单调递减,符合收敛理论。
五、算法对比与选择指南
算法 | 收敛速度 | 内存消耗 | 适用场景 |
---|---|---|---|
SGD | 慢 | 低 | 凸优化、精细调参 |
Momentum | 中等 | 低 | 高维、非平稳目标 |
Adam | 快 | 中 | 默认选择、复杂模型 |
AMSGrad | 中等 | 中 | 理论保障强的任务 |
实践建议:
- 首选Adam作为基准,尤其在资源受限时。
- 对泛化要求高时尝试SGD + Momentum。
- 使用学习率预热(Warmup)或周期性调整(如Cosine退火)提升效果。
六、未来研究方向
- 理论分析:非凸优化中的收敛性证明。
- 自动化调参:基于元学习的优化器设计。
- 异构计算优化:适应GPU/TPU等硬件特性。
- 生态整合:与深度学习框架(如PyTorch、TensorFlow)深度融合。
结论
从梯度下降到Adam,优化算法的演进体现了机器学习对高效、自适应方法的追求。理解不同算法的内在机制,结合实际任务灵活选择,是提升模型性能的关键。未来,随着理论突破与计算硬件的进步,优化算法将继续推动机器学习技术的边界。
全文约10,000字,涵盖基础概念、数学推导、对比分析及实践指导,可作为入门学习与工程实践的参考指南。
相关文章:

机器学习优化算法:从梯度下降到Adam及其变种
机器学习优化算法:从梯度下降到Adam及其变种 引言 最近deepseek的爆火已然说明,在机器学习领域,优化算法是模型训练的核心驱动力。无论是简单的线性回归还是复杂的深度神经网络,优化算法的选择直接影响模型的收敛速度、泛化性能…...

[SAP ABAP] 静态断点的使用
在 ABAP 编程环境中,静态断点通过关键字BREAK-POINT实现,当程序执行到这一语句时,会触发调试器中断程序的运行,允许开发人员检查当前状态并逐步跟踪后续代码逻辑 通常情况下,在代码的关键位置插入静态断点可以帮助开发…...

129.求根节点到叶节点数字之和(遍历思想)
Problem: 129.求根节点到叶节点数字之和 文章目录 题目描述思路复杂度Code 题目描述 思路 遍历思想(利用二叉树的先序遍历) 直接利用二叉树的先序遍历,将遍历过程中的节点值先利用字符串拼接起来遇到根节点时再转为数字并累加起来,在归的过程中…...

NCCL、HCCL、通信、优化
文章目录 从硬件PCIE、NVLINK、RDMA原理到通信NCCL、MPI原理!通信实现方式:机器内通信、机器间通信通信实现方式:通讯协调通信实现方式:机器内通信:PCIe通信实现方式:机器内通信:NVLink通信实现…...

unity学习21:Application类与文件存储的位置
目录 1 unity是一个跨平台的引擎 1.1 使用 Application类,去读写文件 1.2 路径特点 1.2.1 相对位置/相对路径: 1.2.2 固定位置/绝对路径: 1.3 测试方法,仍然挂一个C#脚本在gb上 2 游戏数据文件夹路径(只读&…...

17 一个高并发的系统架构如何设计
高并发系统的理解 第一:我们设计高并发系统的前提是该系统要高可用,起码整体上的高可用。 第二:高并发系统需要面对很大的流量冲击,包括瞬时的流量和黑客攻击等 第三:高并发系统常见的需要考虑的问题,如内存不足的问题,服务抖动的…...

Spring Boot 实例解析:配置文件
SpringBoot 的热部署: Spring 为开发者提供了一个名为 spring-boot-devtools 的模块来使用 SpringBoot 应用支持热部署,提高开发者的效率,无需手动重启 SpringBoot 应用引入依赖: <dependency> <groupId>org.springfr…...

pytorch图神经网络处理图结构数据
人工智能例子汇总:AI常见的算法和例子-CSDN博客 图神经网络(Graph Neural Networks,GNNs)是一类能够处理图结构数据的深度学习模型。图结构数据由节点(vertices)和边(edges)组成&a…...

计算机网络一点事(23)
传输层 端口作用:标识主机特定进程,TCP,UDP协议 端口号分类:服务器:0-1023,熟知 1024-49151 登记 客户端:49152-65535 功能:实现端到端,进程到进程的通信,…...

(9)下:学习与验证 linux 里的 epoll 对象里的 EPOLLIN、 EPOLLHUP 与 EPOLLRDHUP 的不同。小例子的实验
(4)本实验代码的蓝本,是伊圣雨老师里的课本里的代码,略加改动而来的。 以下是 服务器端的代码: 每当收到客户端的报文时,就测试一下对应的 epoll 事件里的事件标志,不读取报文内容,…...

DeepSeek-R1模型1.5b、7b、8b、14b、32b、70b和671b有啥区别?
deepseek-r1的1.5b、7b、8b、14b、32b、70b和671b有啥区别?码笔记mabiji.com分享:1.5B、7B、8B、14B、32B、70B是蒸馏后的小模型,671B是基础大模型,它们的区别主要体现在参数规模、模型容量、性能表现、准确性、训练成本、推理成本…...

一、html笔记
(一)前端概述 1、定义 前端是Web应用程序的前台部分,运行在PC端、移动端等浏览器上,展现给用户浏览的网页。通过HTML、CSS、JavaScript等技术实现,是用户能够直接看到和操作的界面部分。上网就是下载html文档,浏览器是一个解释器,运行从服务器下载的html文件,解析html、…...

AI大模型开发原理篇-2:语言模型雏形之词袋模型
基本概念 词袋模型(Bag of Words,简称 BOW)是自然语言处理和信息检索等领域中一种简单而常用的文本表示方法,它将文本看作是一组单词的集合,并忽略文本中的语法、词序等信息,仅关注每个词的出现频率。 文本…...

基于微信小程序的实习记录系统设计与实现(LW+源码+讲解)
专注于大学生项目实战开发,讲解,毕业答疑辅导,欢迎高校老师/同行前辈交流合作✌。 技术范围:SpringBoot、Vue、SSM、HLMT、小程序、Jsp、PHP、Nodejs、Python、爬虫、数据可视化、安卓app、大数据、物联网、机器学习等设计与开发。 主要内容:…...

【LLM】DeepSeek-R1-Distill-Qwen-7B部署和open webui
note DeepSeek-R1-Distill-Qwen-7B 的测试效果很惊艳,CoT 过程可圈可点,25 年应该值得探索更多端侧的硬件机会。 文章目录 note一、下载 Ollama二、下载 Docker三、下载模型四、部署 open webui 一、下载 Ollama 访问 Ollama 的官方网站 https://ollam…...

【Elasticsearch】 Intervals Query
Elasticsearch Intervals Query 返回基于匹配术语的顺序和接近度的文档。 intervals 查询使用 匹配规则,这些规则由一小组定义构建而成。这些规则然后应用于指定 field 中的术语。 这些定义生成覆盖文本中术语的最小间隔序列。这些间隔可以进一步由父源组合和过滤…...

DeepSeek技术深度解析:从不同技术角度的全面探讨
DeepSeek技术深度解析:从不同技术角度的全面探讨 引言 DeepSeek是一个集成了多种先进技术的平台,旨在通过深度学习和其他前沿技术来解决复杂的问题。本文将从算法、架构、数据处理以及应用等不同技术角度对DeepSeek进行详细分析。 一、算法层面 深度学…...

Docker 部署 Starrocks 教程
Docker 部署 Starrocks 教程 StarRocks 是一款高性能的分布式分析型数据库,主要用于 OLAP(在线分析处理)场景。它最初是由百度的开源团队开发的,旨在为大数据分析提供一个高效、低延迟的解决方案。StarRocks 支持实时数据分析&am…...

【LLM-agent】(task6)构建教程编写智能体
note 构建教程编写智能体 文章目录 note一、功能需求二、相关代码(1)定义生成教程的目录 Action 类(2)定义生成教程内容的 Action 类(3)定义教程编写智能体(4)交互式操作调用教程编…...

29.Word:公司本财年的年度报告【13】
目录 NO1.2.3.4 NO5.6.7 NO8.9.10 NO1.2.3.4 另存为F12:考生文件夹:Word.docx选中绿色标记的标题文本→样式对话框→单击右键→点击样式对话框→单击右键→修改→所有脚本→颜色/字体/名称→边框:0.5磅、黑色、单线条:点…...

14 2D矩形模块( rect.rs)
一、 rect.rs源码 // Copyright 2013 The Servo Project Developers. See the COPYRIGHT // file at the top-level directory of this distribution. // // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or // http://www.apache.org/licenses/LICENS…...

【Unity3D】实现2D角色/怪物死亡消散粒子效果
核心:这是一个Unity粒子系统自带的一种功能,可将粒子生成控制在一个Texture图片网格范围内,并且粒子颜色会自动采样图片的像素点颜色,之后则是粒子编辑出消散效果。 Particle System1物体(爆发式随机速度扩散10000个粒…...

Linux - 进程间通信(3)
目录 3、解决遗留BUG -- 边关闭信道边回收进程 1)解决方案 2)两种方法相比较 4、命名管道 1)理解命名管道 2)创建命名管道 a. 命令行指令 b. 系统调用方法 3)代码实现命名管道 构建类进行封装命名管道&#…...

3、C#基于.net framework的应用开发实战编程 - 实现(三、三) - 编程手把手系列文章...
三、 实现; 三.三、编写应用程序; 此文主要是实现应用的主要编码工作。 1、 分层; 此例子主要分为UI、Helper、DAL等层。UI负责便签的界面显示;Helper主要是链接UI和数据库操作的中间层;DAL为对数据库的操…...

C++编程语言:抽象机制:泛型编程(Bjarne Stroustrup)
泛型编程(Generic Programming) 目录 24.1 引言(Introduction) 24.2 算法和(通用性的)提升(Algorithms and Lifting) 24.3 概念(此指模板参数的插件)(Concepts) 24.3.1 发现插件集(Discovering a Concept) 24.3.2 概念与约束(Concepts and Constraints) 24.4 具体化…...

Python面试宝典13 | Python 变量作用域,从入门到精通
今天,我们来深入探讨一下 Python 中一个非常重要的概念——变量作用域。理解变量作用域对于编写清晰、可维护、无 bug 的代码至关重要。 什么是变量作用域? 简单来说,变量作用域就是指一个变量在程序中可以被访问的范围。Python 中有四种作…...

基于最近邻数据进行分类
人工智能例子汇总:AI常见的算法和例子-CSDN博客 完整代码: import torch import numpy as np from sklearn.neighbors import KNeighborsClassifier from sklearn.metrics import accuracy_score import matplotlib.pyplot as plt# 生成一个简单的数据…...

DeepSeek V3 vs R1:大模型技术路径的“瑞士军刀“与“手术刀“进化
DeepSeek V3 vs R1:——大模型技术路径的"瑞士军刀"与"手术刀"进化 大模型分水岭:从通用智能到垂直突破 2023年,GPT-4 Turbo的发布标志着通用大模型进入性能瓶颈期。当模型参数量突破万亿级门槛后,研究者们开…...

一、TensorFlow的建模流程
1. 数据准备与预处理: 加载数据:使用内置数据集或自定义数据。 预处理:归一化、调整维度、数据增强。 划分数据集:训练集、验证集、测试集。 转换为Dataset对象:利用tf.data优化数据流水线。 import tensorflow a…...

指导初学者使用Anaconda运行GitHub上One - DM项目的步骤
以下是指导初学者使用Anaconda运行GitHub上One - DM项目的步骤: 1. 安装Anaconda 下载Anaconda: 让初学者访问Anaconda官网(https://www.anaconda.com/products/distribution),根据其操作系统(Windows、M…...