梯度下降法以及 Python 实现
文章目录
- 1. 引言
- 2. 梯度法
- 3. 例子
- 4. 代码实现
- 5. 讨论 — 学习率 η \eta η
- 参考
1. 引言
梯度下降法,可以根据微分求出的斜率计算函数的最小值。
在人工智能中,经常被应用于学习算法。
2. 梯度法
梯度法 是根据函数的微分值(斜率)搜索最小值的算法。
梯度下降法也是一种梯度法,它通过向最陡方向下降来查找最小值。
给定一个多变量函数:
f ( x ) = f ( x 1 , x 2 , … , x i , … , x n ) . f(x) = f(x_1, x_2, \dots, x_i, \dots, x_n). f(x)=f(x1,x2,…,xi,…,xn).
首先为 x x x 赋予一个合适的初始值,然后通过下面的表达式进行更新:
x i t + 1 = x i t − η ∂ f ( x ) ∂ x i . x^{t+1}_i = x^{t}_i - \eta \frac{\partial f(x)}{\partial x_i}. xit+1=xit−η∂xi∂f(x).
其中, ∂ f ( x ) ∂ x i \displaystyle \frac{\partial f(x)}{\partial x_i} ∂xi∂f(x) 表示函数 f ( x ) f(x) f(x) 对变量 x i x_i xi 的偏导数。 x i t x^{t}_i xit 表示第 t t t 次迭代时变量 x i x_i xi 的取值, x i t + 1 x^{t+1}_i xit+1 表示第 t + 1 t+1 t+1 次迭代时变量 x i x_i xi 的取值。需要说明的是, t t t 是一个非负整数,也即是 t ∈ N t \in \mathbb{N} t∈N。
η \eta η 是一个重要的参数,被称为学习系数或学习率的常数。 η \eta η 决定了 x i x_i xi 的更新速度。可以理解为,一个人 P 要从 A 点走到 B 点,, η \eta η 就是 P 走路时每一步的跨步大小,也称为步长。
根据该表达式, ∂ f ( x ) ∂ x i \displaystyle \frac{\partial f(x)}{\partial x_i} ∂xi∂f(x) 越大,也即是坡度越陡, x i x_i xi 值的变化就越大。
重复此操作,直到 f ( x ) f(x) f(x) 停止变化,那么此时 f ( x ) f(x) f(x) 的值就是 min f ( x ) \min f(x) minf(x)。
3. 例子
给定一个单变量函数 f ( x ) f(x) f(x):
f ( x ) = x 2 − 2 x . f(x)= x^2 - 2x. f(x)=x2−2x.
求 f ( x ) f(x) f(x) 的最小值。
解:函数 f ( x ) f(x) f(x) 的导数记为 f ′ ( x ) f'(x) f′(x):
f ′ ( x ) = d f ( x ) d x = 2 x − 2. f'(x)=\frac{\mathrm{d} f(x)}{\mathrm{d} x}=2x-2. f′(x)=dxdf(x)=2x−2.
令 f ′ ( x ) = 0 f'(x)=0 f′(x)=0,则
f ′ ( x ) = 0 ⇒ 2 x − 2 = 0 x = 1. \begin{aligned} f'(x) =0 \Rightarrow 2x-2 & = 0 \\ x & = 1. \\ \end{aligned} f′(x)=0⇒2x−2x=0=1.
即当 x = 1 x=1 x=1 处, f ( x ) f(x) f(x) 的导数 f ′ ( x ) f'(x) f′(x) 为 0。
将 x = 1 x=1 x=1 带入到 f ( x ) f(x) f(x) 中,得到:
f min ( x ) = f ( x = 1 ) = 1 2 − 2 ∗ 1 = − 1. f_{\min}(x)=f(x=1)=1^2-2*1=-1. fmin(x)=f(x=1)=12−2∗1=−1.
即 f ( x ) f(x) f(x) 的最小值在 x = 1 x=1 x=1 处取得,最小值为 -1。
下面通过模拟梯度下降法来求解。
假设 x x x 的初始值为 2,即 x 0 = 2 x^0=2 x0=2,令学习率 η = 0.1 \eta=0.1 η=0.1。
次数 t t t | 变量 x t x^t xt | 导数 f ′ ( x t ) = 2 x t − 2 f'(x^t)=2x^t-2 f′(xt)=2xt−2 | 函数 f ( x t ) = ( x t ) 2 − 2 x t f(x^t)=(x^t)^2-2x^t f(xt)=(xt)2−2xt | 更新 x t + 1 x^{t+1} xt+1 |
---|---|---|---|---|
0 | x 0 = 2 x^0=2 x0=2 | f ′ ( x 0 ) = 2 ∗ 2 − 2 = 2 f'(x^0)=2*2-2=2 f′(x0)=2∗2−2=2 | f ( x 0 ) = 2 2 − 2 ∗ 2 = 0 f(x^0)=2^2-2*2=0 f(x0)=22−2∗2=0 | x 1 = 2 − 0.1 ∗ 2 = 1.8 x^1=2-0.1*2=1.8 x1=2−0.1∗2=1.8 |
1 | x 1 = 1.8 x^1=1.8 x1=1.8 | f ′ ( x 1 ) = 2 ∗ 1.8 − 2 = 1.6 f'(x^1)=2*1.8-2=1.6 f′(x1)=2∗1.8−2=1.6 | f ( x 1 ) = 1. 6 2 − 2 ∗ 1.6 = − 0.64 f(x^1)=1.6^2-2*1.6=-0.64 f(x1)=1.62−2∗1.6=−0.64 | x 2 = 1.8 − 0.1 ∗ 1.6 = 1.64 x^2=1.8-0.1*1.6=1.64 x2=1.8−0.1∗1.6=1.64 |
2 | x 2 = 1.64 x^2=1.64 x2=1.64 | f ′ ( x 2 ) = 2 ∗ 1.64 − 2 = 1.28 f'(x^2)=2*1.64-2=1.28 f′(x2)=2∗1.64−2=1.28 | f ( x 2 ) = 1.6 4 2 − 2 ∗ 1.64 = − 0.5904 f(x^2)=1.64^2-2*1.64=-0.5904 f(x2)=1.642−2∗1.64=−0.5904 | x 3 = 1.64 − 0.1 ∗ 1.28 = 1.512 x^3=1.64-0.1*1.28=1.512 x3=1.64−0.1∗1.28=1.512 |
3 | x 3 = 1.512 x^3=1.512 x3=1.512 | f ′ ( x 3 ) = 2 ∗ 1.512 − 2 = 1.024 f'(x^3)=2*1.512-2=1.024 f′(x3)=2∗1.512−2=1.024 | f ( x 3 ) = 1.51 2 2 − 2 ∗ 1.512 = − 0.7379 f(x^3)=1.512^2-2*1.512=-0.7379 f(x3)=1.5122−2∗1.512=−0.7379 | x 4 = 1.512 − 0.1 ∗ 1.024 = 1.4096 x^4=1.512-0.1*1.024=1.4096 x4=1.512−0.1∗1.024=1.4096 |
4 | x 4 = 1.4096 x^4=1.4096 x4=1.4096 | … \dots … | … \dots … | … \dots … |
根据梯度下降法的公式进行计算,可以得到上面的表格。可以观察到,导数 f ′ ( x ) f'(x) f′(x) 的值越来越小。继续计算上面的表, x x x 的值会越来越小,逐渐逼近 1。当 f ′ ( x ) = 0 f'(x)=0 f′(x)=0 时, x = 1 x=1 x=1,此时 f ( x ) = − 1 f(x)=-1 f(x)=−1。
4. 代码实现
我们利用 Python 代码可以模拟上面的梯度下降过程。
定义一个函数,表示 f ( x ) f(x) f(x):
def my_func(x):"""$y = x^2 - 2x$:param x: 变量:return: 函数值"""return x**2 - 2*x
变量 x 对应于 x x x,my_func() 的结果(返回值)对应于 f ( x ) f(x) f(x)。
再定义一个函数,表示 f ′ ( x ) f'(x) f′(x):
def grad_func(x):"""函数 $y = x^2 - 2x$ 的导数:param x: 变量:return: 导数值"""return 2*x - 2
变量 x 对应于 x x x,grad_func() 的结果(返回值)对应于 f ′ ( x ) f'(x) f′(x)。
给定一个学习率 η \eta η,给定一个 x x x 的初始值
eta = 0.1
x = 4.0
那么就可以开始模拟梯度下降法求解最小值。
import numpy as np
import matplotlib.pyplot as pltdef my_func(x):"""$y = x^2 - 2x$:param x: 变量:return: 函数值"""return x**2 - 2*xdef grad_func(x):"""函数 $y = x^2 - 2x$ 的导数:param x: 变量:return: 导数值"""return 2*x - 2eta = 0.1
x = 4.0
record_x = []
record_y = []for i in range(20):y = my_func(x)record_x.append(x)record_y.append(y)x -= eta * grad_func(x)print(np.round(record_x, 4))
print(np.round(record_y, 4))x_f = np.linspace(-2, 4)
y_f = my_func(x_f)plt.plot(x_f, y_f, linestyle='--', color='red')
plt.scatter(record_x, record_y)plt.xlabel('x', size=14)
plt.ylabel('y', size=14)
plt.grid()
plt.show()
x x x 的变化过程为:
[4. 3.4 2.92 2.536 2.2288 1.983 1.7864 1.6291 1.5033 1.4027 1.3221 1.2577 1.2062 1.1649 1.1319 1.1056 1.0844 1.0676 1.054 1.0432]
f ( x ) f(x) f(x) 的变化过程为:
[ 8. 4.76 2.6864 1.3593 0.5099 -0.0336 -0.3815 -0.6042 -0.7467 -0.8379 -0.8962 -0.9336 -0.9575 -0.9728 -0.9826 -0.9889 -0.9929 -0.9954 -0.9971 -0.9981]
我们使用了 matplotlib 可视化函数 f ( x ) f(x) f(x) 的图像,以及梯度下降法求解的过程。
红色虚线是函数 f ( x ) f(x) f(x) 的图像。
蓝色点表示梯度下降法求解过程中 f ( x ) f(x) f(x) 的值。
5. 讨论 — 学习率 η \eta η
学习率( η \eta η)是一个非常重要的参数。有多重要呢?请接着看……
在上面的例子中,我们设置学习率为 0.1,即 η = 0.1 \eta = 0.1 η=0.1。同样的以上面的例子为例,我们修改学习率。
5.1 当 η \eta η 设置过大
设置 η = 1 \eta = 1 η=1, x x x 的初始值保持一致,仍取值 4.0。
eta = 1
x = 4.0
那么,再一次利用梯度下降法求解,
x x x 的变化过程为:
[ 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2. 4. -2.]
f ( x ) f(x) f(x) 的变化过程为:
[8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8. 8.]
可视化结果为:
上面的输出结果和图像都可以看出, x x x 和 f ( x ) f(x) f(x) 的结果在循环,始终无法得到正确结果,进入了死循环。
5.2 当 η \eta η 设置过小
设置 η = 0.01 \eta = 0.01 η=0.01, x x x 的初始值保持一致,仍取值 4.0。
eta = 0.01
x = 4.0
那么,再一次利用梯度下降法求解,
x x x 的变化过程为:
[4. 3.94 3.8812 3.8236 3.7671 3.7118 3.6575 3.6044 3.5523 3.5012 3.4512 3.4022 3.3542 3.3071 3.2609 3.2157 3.1714 3.128 3.0854 3.0437]
f ( x ) f(x) f(x) 的变化过程为:
[8. 7.6436 7.3013 6.9726 6.6569 6.3537 6.0625 5.7828 5.5142 5.2562 5.0085 4.7705 4.542 4.3226 4.1118 3.9094 3.7149 3.5282 3.3489 3.1767]
可视化结果为:
上面的输出结果和图像都可以看出,梯度下降法在正确工作。但是求解过程很缓慢,离最小值还有一段距离。此时需要增加循环轮次,消耗更多的资源。
总结:需要设置合理的学习率 η \eta η,过大或过小都不好。
参考
-《用Python编程和实践!数学教科书》
相关文章:

梯度下降法以及 Python 实现
文章目录 1. 引言2. 梯度法3. 例子4. 代码实现5. 讨论 — 学习率 η \eta η5.1 当 η \eta η 设置过大5.2 当 η \eta η 设置过小 参考 1. 引言 梯度下降法,可以根据微分求出的斜率计算函数的最小值。 在人工智能中,经常被应用于学习算法。 2. 梯…...

Postman cURL命令导入导出
你是否曾为在Postman和终端之间切换、整理请求而抓狂?其实,Postman支持与cURL命令的无缝互通,通过导入导出,极大提升效率。用好这个功能,分分钟让接口测试更高效! Postman如何快速导入cURL命令?…...

Java 在Json对象字符串中查找和提取特定的数据
1、在处理JSON数据时,需要提出个别字段的值,通过正则表达式提取特定的数据 public static void main(String[] args) {//定义多个JSON对象字符串类型,假设每个对象有a,b,c 字段String strJson "{\"a\":1.23,\"b\"…...

synchronized的特性
1.互斥 对于synchronized修饰的方法及代码块不同线程想同时进行访问就会互斥。 就比如synchronized修饰代码块时,一个线程进入该代码块就会进行“加锁”。 退出代码块时会进行“解锁”。 当其他线程想要访问被加锁的代码块时,就会阻塞等待。 阻塞等待…...

领域泛化与领域自适应
领域泛化(Domain Generalization)和领域适应(Domain Adaptation)是机器学习领域中处理不同数据分布场景下模型训练与应用的两种策略,领域泛化在泛化到目标领域时不需要进行调整,而领域自适应在适应到目标领…...

使用aspx,完成一个转发http的post请求功能的api接口,url中增加目标地址参数,传递自定义header参数
使用aspx,完成一个转发http的post请求功能的api接口,url中增加目标地址参数,传递自定义header参数 首先,简单实现一下,如何在ASPX页面中实现这个功能实现代码说明:注意事项: 然后进阶࿰…...

实际车辆行驶轨迹与预设路线偏离检测的Java实现
准备工作 本项目依赖于两个关键库:JTS Topology Suite(简称JTS),用于几何对象创建和空间分析;以及GeoTools,用于处理坐标转换和其他地理信息任务。确保开发环境中已经包含了这两个库,并且正确配…...

从excel数据导入到sqlsever遇到的问题
1、格式问题时间格式,excel中将日期列改为日期未生效,改完后,必须手动单击这个单元格才能生效,那不可能一个一个去双击。解决方案如下 2、导入之后表字段格式问题,数据类型的用navicat导入之后默认是nvarchar类型的&a…...

Linux操作系统——Linux的磁盘管理系统、文件inode及软硬链接
目录 前言 一、磁盘 1、物理结构 2、存储结构 3、磁盘的逻辑结构 二、文件系统 1、基本概念 2、组的概念 1)Data Blaocks 2)inode Table 3)inode Bitmap 4)Blocks Bitmap 5)Group Descriptor Table 6)Sup…...

算法刷题Day11: BM33 二叉树的镜像
点击题目链接 思路 转换为子问题:左右子树相反转。遍历手法:后序遍历 代码 class Solution:def Transverse(self,root: TreeNode):if root None:return rootnewleft self.Transverse(root.left)newright self.Transverse(root.right)# 对root节点…...

WPF+MVVM案例实战与特效(三十五)- 掌握 Windows 屏幕键盘控制的艺术(TouchKeyBoardHelper 类)
文章目录 1、概述2、TouchKeyBoardHelper 类1、代码实现2、代码解释3、实际应用1、帮助类库与文件创建2、项目引用运行效果3、答疑解惑1、概述 在WPF应用程序开发中,有时需要提供启动或关闭屏幕键盘(On-Screen Keyboard, OSK)的功能。为了实现这一需求,我们创建了一个名为…...

Python+OpenCV系列:绘制中文的方法
绘制中文的方法 方法一:使用Pillow(PIL)与OpenCV结合方法二:使用Matplotlib与OpenCV结合方法三:结合第三方库OpenCV-ZH注意事项 在Python中,使用OpenCV绘制中文需要处理字体加载问题,因为OpenCV…...

精品推荐 | StarLighter 1×dsDNA HS Assay Kit
关键词:核酸浓度测定,核酸定量检测试剂盒,dsDNA浓度测定,dsDNA定量检测 产品简介 StarLighter 1dsDNA HS Assay Kit是一种快速简便的双链DNA(dsDNA)荧光定量检测试剂盒,具有极高的检测灵敏度&…...

挑战用React封装100个组件【010】
Hello,大家好,今天我挑战的组件是这样的! 今天这个组件是一个打卡成功,或者获得徽章后的组件。点击按钮后,会弹出礼花。项目中的勋章是我通过AI生成的,还是很厉害的哈!稍微抠图直接使用。最后面…...

burp suite 5
声明! 学习视频来自B站up主 **泷羽sec** 有兴趣的师傅可以关注一下,如涉及侵权马上删除文章,笔记只是方便各位师傅的学习和探讨,文章所提到的网站以及内容,只做学习交流,其他均与本人以及泷羽sec团队无关&a…...

锐捷Web认证
文章目录 Web认证二代 Web 认证配置 🏡作者主页:点击! 🤖Datacom专栏:点击! ⏰️创作时间:2024年12月6日11点40分 Web认证 Portal 认证、Web认证 Web认证的介绍 Web 认证使用浏览器进行身份验…...

【开源免费】基于Vue和SpringBoot的服装生产管理系统(附论文)
博主说明:本文项目编号 T 066 ,文末自助获取源码 \color{red}{T066,文末自助获取源码} T066,文末自助获取源码 目录 一、系统介绍二、演示录屏三、启动教程四、功能截图五、文案资料5.1 选题背景5.2 国内外研究现状5.3 可行性分析…...

每日速记10道MySQL面试题16
其他资料 每日速记10道java面试题01-CSDN博客 每日速记10道java面试题02-CSDN博客 每日速记10道java面试题03-CSDN博客 每日速记10道java面试题04-CSDN博客 每日速记10道java面试题05-CSDN博客 每日速记10道java面试题06-CSDN博客 每日速记10道java面试题07-CSDN博客 每…...

云计算考试题
1、与SaaS不同的,这种“云”计算形式把开发环境或者运行平台也作为一种服务给用户提供。(B) A、软件即服务 B、基于平台服务 C、基于WEB服务 D、基于管理服务 2、云计算是对(D)技术的发展与运用 A、并行计算 B、网格计算 C、分布式计算 D、三个选项都是 3、Amazon.com公司…...

无人机理论考试合格证书获取
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 轻型民用无人驾驶航空器安全操控理论培训合格证明 前言无人机特性和应用场景 前言 无人机(Drone)是一种非常受欢迎的技术产品,广泛应用于…...

AcWing 3496. 特殊年份
文章目录 前言代码思路 前言 写简单题没啥。反正都是要写的,先把能拿到的分数拿了,之后有机会再去啃一啃硬骨头。啃不下来就算了。 代码 #include<bits/stdc.h> using namespace std; char a1[10],a2[10],a3[10],a4[10],a5[10]; int main(){cin…...

YOLOv8模型改进 第二十讲 添加三重注意力机制Triplet Attention 提升小目标/遮挡目标
本文这次分享的是三重注意力机制Triplet Attention。现在注意力机制在计算机视觉任务中被广泛研究和应用,如 Squeeze-and-Excitation Networks (SENet)、Convolutional Block Attention Module (CBAM) 等。然而,这些方法存在一些局限性,例如需…...

Linux絮絮叨(三) Ubuntu桌面版添加中文拼音输入法
步骤很详细,直接上教程 一. 配置安装简体拼音输入法 #安装相应的平台支持包 sudo apt install ibus-gtk ibus-gtk3# 安装简体拼音输入法 sudo apt install ibus-pinyin安装完成如果下面的步骤找不到对应输入法可以重启一下,一般不需要 二. 添加简体拼音…...

Ungoogled Chromium127编译指南 Windows篇 - 安装Visual Studio 2022(六)
1. 引言 在编译Ungoogled Chromium之前,正确安装和配置Visual Studio 2022是至关重要的一步。作为主要的开发环境,Visual Studio不仅提供了必要的编译工具,还包含了大量构建过程中需要的组件和库。本文将详细介绍如何在Windows系统上安装和配…...

Kubernetes(K8s)
头条:参考资料 Kubernetes 入门指南:从基础到实践_kubernetes 从入门到实践-CSDN博客 Kubernetes(k8s)与docker的区别 Docker、Kubernetes之间的区别_docker和kubernetes区别-CSDN博客 Docker部署SpringBoot项目(镜…...

证明切平面过定点的曲面是锥面
目录 证明:切平面过定点的曲面是锥面. 证明:切平面过定点的曲面是锥面. 证明: 方法一: 设曲面 S : r r ( u , v ) S:\mathbf{r}\mathbf{r}(u,v) S:rr(u,v)的切平面过定点 P 0 P_0 P0,其位置向量为 p 0 . \mathbf{p}_0. p0…...

python中数组怎么转换为字符串
1、数组转字符串 #方法1 arr [a,b] str1 .join(arr)#方法2 arr [1,2,3] #str .join(str(i) for i in arr)#此处str命名与str函数冲突! str2 .join(str(i) for i in arr) 2、字符串转数组 #方法一 str_x avfg st_list list(str_x) #使用list()#方法二 list_s…...

Linux 查看运行了哪些服务
1、service --status-all service --status-all输出: ● fdfs_storaged.service - LSB: FastDFS storage serverLoaded: loaded (/etc/rc.d/init.d/fdfs_storaged; bad; vendor preset: disabled)Active: active (running) since Thu 2019-03-28 09:53:35 CST; 5 years 8 mon…...

WPS EXCEL 使用 WPS宏编辑器 写32位十六进制数据转换为浮点小数的公式。
新建EXCLE文件 另存为xlsm格式的文件 先打开WPS的开发工具中的宏编辑器 宏编辑器编译环境 在工作区添加函数并编译,如果有错误会有弹窗提示,如果没有错误则不会弹 函数名字 ”HEXTOFLOAT“ 可以自己修改。 function HEXTOFLOAT(hex) { // 将十六…...

SpringMVC ——(1)
1.SpringMVC请求流程 1.1 SpringMVC请求处理流程分析 Spring MVC框架也是⼀个基于请求驱动的Web框架,并且使⽤了前端控制器模式(是⽤来提供⼀个集中的请求处理机制,所有的请求都将由⼀个单⼀的处理程序处理来进⾏设计,再根据请求…...