当前位置: 首页 > news >正文

Lecture3 梯度下降(Gradient Descent)

目录

1 问题背景

2 批量梯度下降 (Batch Gradient Descent)

3 鞍点(Saddle Point)

3 随机梯度下降 (Stochastic Gradient Descent)

4 小批量梯度下降 (Mini-batch Gradient Descent)


1 问题背景

图1 上节课讲述的穷举法求最优权重值

  在Lecture2中,介绍了使用穷举法来确定最优\omega值,然而当遇到\omega范围较大,或者数量过多等情况时,穷举法的时间复杂度过大。因此,我们需要优化该算法。

2 批量梯度下降 (Batch Gradient Descent)

  在这次课中,介绍了一种寻找\omega最优值的算法——批量梯度下降 (Batch Gradient Descent, BGD)

简单介绍下该算法。首先对于下图:

图2 训练过程中权重初始值与最优值的位置

  假设我们目前的起始\omega位于上图红色点,为了找到最优\omega点(位于绿点),那么我们需要向左边移动,这样才能到达最优\omega点。

图3 我们需要计算梯度以向左移动权值点

 

   如何让权值点向左还是向右移动呢?此时我们需要计算当前点的梯度(Gradient),也就是用成本函数对权重进行求导,如果梯度<0,则向函数值递减方向移动;梯度>0,则向函数值递增方向移动。

  因为要移动起来,所以我们每移动一步,就要更新一下\omega值。更新函数如下图Update处:

图4 更新权重值的函数公式

  在这个更新函数中, α代表学习率(Learning Rate),学习率是机器学习中常用的一个超参数,它定义了每次更新参数时步长的大小,即每次更新参数时参数值变化的幅度。如果学习率设置得过大,所求结果可能会在最优解的附近来回震荡,而无法找到全局最优解。如果学习率设置得过小,那么模型的训练将会非常缓慢,甚至找不到最优解。

  这个式子中,梯度前面用了减号,是为了朝函数值递减方向,也就是往最优\omega所在的点移动,所以在梯度前面加负号。 就这样持续一步步地更新\omega,直到找到最优\omega

下面我们来具体讲讲如何去计算更新函数中的\frac{\partial cost}{\partial \omega} :

图5 更新函数

计算过程中,需要用到上节课总结的两个公式:

图6 均方误差MSE

图7 预测值y_hat

  接着把上述两个公式代入原式:

图8 求导过程

   蓝色处,因为cost=MSE,所以直接代入上一节课的MSE公式,然后对\omega求导。

  绿色处,由有理运算法则,和的导数等于导数的和,所以这里可以把\frac{\partial}{\partial \omega}移入求和式子中,对里面先进行求导后,再求和相加。

  黄色处,根据复合导数的链式求导法进行求导。

代码实现

from matplotlib import pyplot as pltx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0  # 初始权重,由这个权重开始迭代'''线性模型,算出预测值y_hat'''
def forward(x):return x * w'''均方误差MSE'''
def cost(xs, ys):cost = 0for x, y in zip(xs, ys):y_pred = forward(x)  # 算出y_hatcost += (y_pred - y) ** 2  # (y_hat - y)²return cost / len(xs)  # 除以样本总数求均值'''梯度下降公式'''
def gradient(xs, ys):grad = 0for x, y in zip(xs, ys):grad += 2 * x * (x * w - y)return grad / len(xs)print('Predict (before training)', 4, forward(4))  # 训练前,模型对输入的4的最终预测结果cost_list = [] # 保存每轮迭代后的cost值
epoch_list = [] # 保存每轮的迭代后的epoch值
for epoch in range(100):  # 进行100轮训练cost_val = cost(x_data, y_data)grad_val = gradient(x_data, y_data)w -= 0.01 * grad_val # 使用梯度下降法更新权重,0.01表示学习率print('Epoch:', epoch, 'w=%.2f' % w, 'loss=%.2f' % cost_val)cost_list.append(cost_val)epoch_list.append(epoch)
print('Predict (after training)', 4, forward(4))  # 训练后,模型对输入的4的最终预测结果'''绘图'''
plt.plot(epoch_list, cost_list)
plt.ylabel('Cost')
plt.xlabel('Epoch')
plt.grid()
plt.show()
图9 输出结果图像

  将MSE公式和Linear Model公式代入整合,的最终更新函数:

图10 最终的更新函数

补充

训练后的结果一般来说,cost会趋于收敛情况

图11 通常训练后cost图像会趋于收敛

 如果发生如下情况,说明训练失败,原因有很多,其中之一可能是学习率取得太大:

图12 训练失败

 

  这就是批量梯度下降算法,本质上是一个贪心算法(Greedy Algorithm)。不过该算法有局限性,比如当前的预测值\omega正好位于下图绿线处,因为再往右移动会梯度会发生变化,使得程序直接终止,于是误将红的点作为最优\omega值,而忽略了处于蓝色点的最优\omega值:

图13 局部最优和全局最优示意图

   我们把上图中的红点称为局部最优点(Local Optimum),蓝色点称为全局最优点(Global Optimum)。因此对于该梯度下降算法,很可能会找到局部最优点,而忽略了全局最优点。不过这种现象不必担心,因为在实际训练中,往往很难陷入局部最优点。

3 鞍点(Saddle Point)

  在实际训练中,往往很难陷入局部最优点,而最需要解决的问题是鞍点(Saddle Point),鞍点是机器学习和数学中的一个概念,它指的是一个特殊的局部极小值,在某些方向上是极小值,但在其他方向上是极大值。在一元函数中,梯度=0的点就是鞍点。比如下图中,红色小球所处的位置就在鞍点,此时梯度为零,会导致更新函数无法更新(因为梯度=0,\omega=\omega-α*0相当于没有发生更新):

图14 鞍点示意图

 

 从多维角度来分析,比如下图红球处于马鞍面(Saddle Surface),从一个切面看可以处于最小值,从另一个切面看又处于最大值:

图15 位于马鞍面的鞍点

 

  在优化问题中,鞍点是一种特殊的局部最优解,是一个难以优化的点,因为优化算法可能很难从鞍点附近找到全局最优解。这是因为,如果优化算法在鞍点附近搜索,它可能会被误导到其他附近的局部最优解,而不是真正的全局最优解。所以在深度学习中,需要克服的最大问题就是鞍点而非局部最优问题。

3 随机梯度下降 (Stochastic Gradient Descent)

  随机梯度下降 (Stochastic Gradient Descent, SGD)在深度学习中很常用,和BGD算法的区别是,BGD使用所有的样本的均值的平均损失来作为\omega的更新依据,而SGD是从所有样本中随机选择单个样本的损失值来对\omega进行更新。

  随机梯度下降的优点是,每次仅使用一个数据点的梯度,因此在每次迭代时都有可能沿着非0梯度的方向更新参数,这样就避免陷入到鞍点导致无法更新参数。

图16 BGD到SGD公式上的改变

代码实现

import randomx_data = [1.0, 2.0, 3.0]
y_data = [2.0, 4.0, 6.0]
w = 1.0def forward(x):return x * wdef loss(x, y):y_pred = forward(x)return (y_pred - y) ** 2def gradient(x, y):return 2 * x * (x * w - y)print('Predict (before training)', 4, forward(4))
for epoch in range(100):t = random.randrange(0, 3) # 随机得到一个样本x = x_data[t]y = y_data[t]grad = gradient(x, y)w = w - 0.01 * gradprint("\tgrad: ", x, y, '%.2f' % grad)l = loss(x, y)print("progress:", epoch, "w=%.2f" % w, "loss=%.2f" % l)
print('Predict (after training)', 4, forward(4))

部分输出结果

Predict (before training) 4 4.0
    grad:  3.0 6.0 -18.00
progress: 0 w=1.18 loss=6.05
    grad:  2.0 4.0 -6.56
progress: 1 w=1.25 loss=2.28
    grad:  3.0 6.0 -13.58
progress: 2 w=1.38 loss=3.44
    grad:  1.0 2.0 -1.24
progress: 3 w=1.39 loss=0.37
    grad:  2.0 4.0 -4.85
progress: 4 w=1.44 loss=1.24

···

    grad:  1.0 2.0 -0.00
progress: 97 w=2.00 loss=0.00
    grad:  1.0 2.0 -0.00
progress: 98 w=2.00 loss=0.00
    grad:  2.0 4.0 -0.00
progress: 99 w=2.00 loss=0.00
Predict (after training) 4 7.999910864525451

4 小批量梯度下降 (Mini-batch Gradient Descent)

  SGD算法虽然可以在一定程度上避免陷入局部最优以及鞍点问题,但是运算所需时间复杂度过高,每次仅使用一个数据点的梯度,因此它的收敛速度通常比较慢。

  因此有一个折中的办法,就是使用小批量梯度下降 (Mini-batch Gradient Descent) 算法。简单来说,小批量梯度下降是一种介于批量梯度下降和随机梯度下降之间的优化算法。结合了这两种方法,通过使用小的随机选择的训练数据子集(称为mini-batch)计算损失函数关于参数的梯度的平均值来更新模型参数。

  总之,小批量梯度下降算法实现了BGD的高计算效率和SGD的良好收敛性之间的平衡。

相关文章:

Lecture3 梯度下降(Gradient Descent)

目录 1 问题背景 2 批量梯度下降 (Batch Gradient Descent) 3 鞍点(Saddle Point) 3 随机梯度下降 (Stochastic Gradient Descent) 4 小批量梯度下降 (Mini-batch Gradient Descent) 1 问题背景 图1 上节课讲述的穷举法求最优权重值在Lecture2中&#xff0c;介绍了使用穷举…...

深入了解DSP

一、时钟和电源 问&#xff1a;DSP的电源设计和时钟设计应该特别注意哪些方面&#xff1f;外接晶振选用有源的好还是无源的好&#xff1f; 答&#xff1a;时钟一般使用晶体&#xff0c;电源可用TI的配套电源。外接晶振用无源的好。 问&#xff1a;TMS320LF2407的A/D转换精度保证…...

Flink反压如何排查

Flink反压利用了网络传输和动态限流。Flink的任务的组成由流和算子组成&#xff0c;那么流中的数据在算子之间转换的时候&#xff0c;会放入分布式的阻塞队列中。当消费者的阻塞队列满的时候&#xff0c;则会降低生产者的处理速度。 如上图所示&#xff0c;当Task C 的数据处…...

windows无法访问指定设备路径或文件怎么办?2个解决方案

有时候Win10电脑打不开程序或文件&#xff0c;windows无法访问指定设备路径或文件该怎么办&#xff1f;原因是什么呢&#xff1f;一般导致这种情况的出现&#xff0c;大多是因为我们的电脑缺乏相应的查看权限&#xff0c;我们只需要通过赋予权限就可以解决这个难题了。 操作环境…...

冷知识|鹤顶红还能用来修长城?

大家好&#xff0c;我是建模助手。 在上篇浅浅地蹭了波热点之后&#xff0c;我灵机一动&#xff0c;倒不如也搞一搞建筑方面的冷知识&#xff1f;冷热搭配&#xff0c;事半功倍... 问问大家&#xff0c;如果谈起古建筑&#xff0c;关键词都有什么&#xff1f;是庄严、震撼、壮…...

【GD32F427开发板试用】在IAR环境中移植RTX5

本篇文章来自极术社区与兆易创新组织的GD32F427开发板评测活动&#xff0c;更多开发板试用活动请关注极术社区网站。作者&#xff1a;吴金刚 0.前言 首先感谢极术社区和兆易创新给了这次试用GD32F427开发板的机会。 板子做的虽然简约&#xff0c;但是自带了GD-link所以一根USB…...

MySQl学习(从入门到精通15)

MySQl学习&#xff08;从入门到精通15&#xff09;第 18 章_MySQL 8 其它新特性1. MySQL 8 新特性概述1. 1 MySQL 8. 0 新增特性1. 2 MySQL 8. 0 移除的旧特性2. 新特性 1 &#xff1a;窗口函数2. 1 使用窗口函数前后对比2. 2 窗口函数分类2. 3 语法结构2. 4 分类讲解1. 序号函…...

前端构建工具 Vite

文章目录参考环境构建工具构建工具的主要功能目前主流的前端构建工具Vite为什么使用 Vite冷启动WebpackVite热更新优化热更新优化预构建依赖Webpack VS ViteVite 的缺点首屏性能懒加载与 Vite 相关的基本操作获取create-vite创建项目Project nameSelect a frameworkSelect a va…...

若依框架---PageHelper分页(十)

在前几天的文章中&#xff0c;我们介绍了PageHelper的分页方法&#xff0c;研读代码定位到了ExecutorUtil.pageQuery(...)方法&#xff0c;并阅读到了其中的部分代码。 今天我们将看到重要的SQL修改代码。 getPageSql 我们接着看代码&#xff1a; if (!dialect.beforePage(…...

苹果手机专用蓝牙耳机有哪些?与iphone兼容性好的蓝牙耳机

蓝牙耳机摆脱了线缆的束缚&#xff0c;在地以各种方式轻松通话。自从蓝牙耳机问世以来&#xff0c;一直是行动商务族提升效率的好工具&#xff0c;苹果产品一直都是受欢迎的数码产品&#xff0c;下面推荐几款与iphone兼容性好的蓝牙耳机。 第一款&#xff1a;南卡小音舱蓝牙耳…...

CS-TPGS;壳聚糖修饰维生素E;Chitosan-g-TPGS

Chitosan-g-TPGS,CS-TPGS壳聚糖修饰维生素E聚乙二醇1000琥珀酸酯外观呈现白色固体或者粘稠液体。长期保存需要在-20℃,避光,干燥条件下存放&#xff0c;注意取用一定要干燥,避免频繁溶冻。 维生素E聚乙二醇琥珀酸酯(简称TPGS)是维生素E的水溶性衍生物,由维生素E琥珀酸酯的羧基与…...

easyx的基本使用(万字解析)

easyx的基本使用一.基本框架1.创建文件2.创建窗体-initgraph,closegraph,getchar二.简单的绘制1.圆形-circle2.坐标系统-setorigin,setaspectratio三.简单图形1.绘制点-putpixel2.简单的直线-line3.矩形-rectangle4.椭圆-ellipse5.圆角矩形-roundrect6.扇形-pie7.圆弧-arc四.多…...

基于OpenCV 的车牌识别

基于OpenCV 的车牌识别 车牌识别是一种图像处理技术&#xff0c;用于识别不同车辆。这项技术被广泛用于各种安全检测中。现在让我一起基于 OpenCV 编写 Python 代码来完成这一任务。 车牌识别的相关步骤 1. 车牌检测&#xff1a;第一步是从汽车上检测车牌所在位置。我们将使用…...

C#【必备技能篇】Winform跨线程更新进度条的实例

文章目录实例一&#xff1a;【方便理解&#xff0c;常用&#xff01;】源码&#xff1a;运行效果&#xff1a;实例二&#xff1a;【重在理解代码本身】源码&#xff1a;运行效果&#xff1a;参考&#xff1a;实例一&#xff1a;【方便理解&#xff0c;常用&#xff01;】 跨线…...

(1分钟速通面试) 矩阵分解相关内容

矩阵分解算法--总结QR分解 LU分解本篇博客总结一下QR分解和LU分解&#xff0c;这些都是矩阵加速的操作&#xff0c;在slam里面还算是比较常用的内容&#xff0c;这个地方在isam的部分出现过。(当然isam也是一个坑&#xff0c;想要出点创新成果的话 可能是不太现实的 短期来讲 哈…...

this指向

&#xff08;1&#xff09;在全局环境中的this——window 无论是否在严格模式下&#xff0c;在全局执行环境中&#xff08;在任何函数体外部&#xff09;this 都指向全局对象。 "use strict"console.log(this); //windowconsole.log(thiswindow);//true &#xff08…...

安卓小游戏:小板弹球

安卓小游戏&#xff1a;小板弹球 前言 这个是通过自定义View实现小游戏的第三篇&#xff0c;是小时候玩的那种五块钱的游戏机上的&#xff0c;和俄罗斯方块很像&#xff0c;小时候觉得很有意思&#xff0c;就模仿了一下。 需求 这里的逻辑就是板能把球弹起来&#xff0c;球…...

7、单行函数

文章目录1 函数的理解1.1 什么是函数1.2 不同DBMS函数的差异1.3 MySQL的内置函数及分类2 数值函数2.1 基本函数2.2 角度与弧度互换函数2.3 三角函数2.4 指数与对数2.5 进制间的转换3 字符串函数4 日期和时间函数4.1 获取日期、时间4.2 日期与时间戳的转换4.3 获取月份、星期、星…...

华为机试题:HJ56 完全数计算(python)

文章目录博主精品专栏导航知识点详解1、input()&#xff1a;获取控制台&#xff08;任意形式&#xff09;的输入。输出均为字符串类型。1.1、input() 与 list(input()) 的区别、及其相互转换方法2、print() &#xff1a;打印输出。3、整型int() &#xff1a;将指定进制&#xf…...

opencv——傅里叶变换、低通与高通滤波及直方图等操作

1、傅里叶变换a、傅里叶变换原理时域分析&#xff1a;以时间为参照进行分析。频域分析&#xff1a;相当于上帝视角一样&#xff0c;看事物层次更高&#xff0c;时域的运动在频域来看就是静止的。eg&#xff1a;投球——时域分析&#xff1a;第1分钟投了3分&#xff0c;第2分钟投…...

AI-调查研究-01-正念冥想有用吗?对健康的影响及科学指南

点一下关注吧&#xff01;&#xff01;&#xff01;非常感谢&#xff01;&#xff01;持续更新&#xff01;&#xff01;&#xff01; &#x1f680; AI篇持续更新中&#xff01;&#xff08;长期更新&#xff09; 目前2025年06月05日更新到&#xff1a; AI炼丹日志-28 - Aud…...

7.4.分块查找

一.分块查找的算法思想&#xff1a; 1.实例&#xff1a; 以上述图片的顺序表为例&#xff0c; 该顺序表的数据元素从整体来看是乱序的&#xff0c;但如果把这些数据元素分成一块一块的小区间&#xff0c; 第一个区间[0,1]索引上的数据元素都是小于等于10的&#xff0c; 第二…...

【根据当天日期输出明天的日期(需对闰年做判定)。】2022-5-15

缘由根据当天日期输出明天的日期(需对闰年做判定)。日期类型结构体如下&#xff1a; struct data{ int year; int month; int day;};-编程语言-CSDN问答 struct mdata{ int year; int month; int day; }mdata; int 天数(int year, int month) {switch (month){case 1: case 3:…...

页面渲染流程与性能优化

页面渲染流程与性能优化详解&#xff08;完整版&#xff09; 一、现代浏览器渲染流程&#xff08;详细说明&#xff09; 1. 构建DOM树 浏览器接收到HTML文档后&#xff0c;会逐步解析并构建DOM&#xff08;Document Object Model&#xff09;树。具体过程如下&#xff1a; (…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

(转)什么是DockerCompose?它有什么作用?

一、什么是DockerCompose? DockerCompose可以基于Compose文件帮我们快速的部署分布式应用&#xff0c;而无需手动一个个创建和运行容器。 Compose文件是一个文本文件&#xff0c;通过指令定义集群中的每个容器如何运行。 DockerCompose就是把DockerFile转换成指令去运行。 …...

【JVM】Java虚拟机(二)——垃圾回收

目录 一、如何判断对象可以回收 &#xff08;一&#xff09;引用计数法 &#xff08;二&#xff09;可达性分析算法 二、垃圾回收算法 &#xff08;一&#xff09;标记清除 &#xff08;二&#xff09;标记整理 &#xff08;三&#xff09;复制 &#xff08;四&#xff…...

【 java 虚拟机知识 第一篇 】

目录 1.内存模型 1.1.JVM内存模型的介绍 1.2.堆和栈的区别 1.3.栈的存储细节 1.4.堆的部分 1.5.程序计数器的作用 1.6.方法区的内容 1.7.字符串池 1.8.引用类型 1.9.内存泄漏与内存溢出 1.10.会出现内存溢出的结构 1.内存模型 1.1.JVM内存模型的介绍 内存模型主要分…...

学习一下用鸿蒙​​DevEco Studio HarmonyOS5实现百度地图

在鸿蒙&#xff08;HarmonyOS5&#xff09;中集成百度地图&#xff0c;可以通过以下步骤和技术方案实现。结合鸿蒙的分布式能力和百度地图的API&#xff0c;可以构建跨设备的定位、导航和地图展示功能。 ​​1. 鸿蒙环境准备​​ ​​开发工具​​&#xff1a;下载安装 ​​De…...

pycharm 设置环境出错

pycharm 设置环境出错 pycharm 新建项目&#xff0c;设置虚拟环境&#xff0c;出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...