PyTorch 中使用自动求导计算梯度
使用 PyTorch 进行自动求导和梯度计算
在 PyTorch 中,张量的 requires_grad 属性决定了是否需要计算该张量的梯度。设置为 True 的张量会在计算过程中记录操作,以便在调用 .backward() 方法时自动计算梯度。通过构建计算图,PyTorch 能够有效地追踪和计算梯度。
1、梯度的定义
在数学中,梯度是一个向量,表示函数在某一点的变化率。在深度学习中,我们通常关心的是损失函数相对于模型参数的梯度。具体来说,假设我们有一个输出 out,我们计算的是损失函数对模型参数(如权重和偏置)的梯度,而不是直接对输出的梯度。
2、 简单例子
在我们接下来的例子中,我们将计算 out 相对于输入变量 x x x 和 y y y的梯度,通常表示为 ( d out d x ) ( \frac{d \text{out}}{dx}) (dxdout)和 ( d out d y ) ( \frac{d \text{out}}{dy}) (dydout)
import torch# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True) # 输入变量 x
y = torch.tensor(3.0, requires_grad=True) # 输入变量 y# 2. 定义第一个函数 f(z) = z^2
def f(z):return z**2# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):z = x + y # 中间变量 zz_no_grad = z.detach() # 创建不需要梯度的副本return f(z_no_grad) + y**3 # 输出 out = f(z_no_grad) + y^3# 4. 计算输出
out = g(x, y) # 计算输出# 5. 反向传播以计算梯度
out.backward() # 计算梯度# 6. 打印梯度
print(f"dz/dx: {x.grad}") # 输出 x 的梯度
print(f"dz/dy: {y.grad}") # 输出 y 的梯度
dout/dx: None
dout/dy: 27.0
import torch# 1. 创建张量并设置 requires_grad=True
x = torch.tensor(2.0, requires_grad=True) # 输入变量 x
y = torch.tensor(3.0, requires_grad=True) # 输入变量 y# 2. 定义第一个函数 f(z) = z^2
def f(z):return z ** 2# 3. 定义第二个函数 g(x, y) = f(z) + y^3
def g(x, y):z = x + y # 中间变量 zreturn f(z) + y ** 3 # 输出 out = f(z_no_grad) + y^3# 4. 计算输出
out = g(x, y) # 计算输出# 5. 反向传播以计算梯度
out.backward() # 计算梯度# 6. 打印梯度
print(f"dout/dx: {x.grad}") # 输出 x 的梯度
print(f"dout/dy: {y.grad}") # 输出 y 的梯度
dout/dx: 10.0
dout/dy: 37.0
在这两个代码示例中,dout/dx 和 dout/dy 的值存在显著差异,主要原因在于如何处理中间变量 ( z ) 以及其对最终输出 out 的影响。
结果分析
-
第一部分代码:
-
在
g(x, y)函数中,使用了 z . detach ( ) z.\text{detach}() z.detach() 创建了一个不需要梯度的副本 z no_grad z_{\text{no\_grad}} zno_grad。这意味着在计算 f ( z no_grad ) f(z_{\text{no\_grad}}) f(zno_grad) 时,PyTorch 不会将 z z z 的变化记录进计算图中。 -
因此, z z z 对 out \text{out} out 的影响被切断,导致
d out d x = None \frac{d \text{out}}{d x} = \text{None} dxdout=None
因为 x x x 的变化不会影响到 out \text{out} out 的计算。 -
对于 y y y,计算得到的梯度为
d out d y = 27.0 \frac{d \text{out}}{d y} = 27.0 dydout=27.0
这是通过以下步骤得到的: -
输出为
out = f ( z no_grad ) + y 3 \text{out} = f(z_{\text{no\_grad}}) + y^3 out=f(zno_grad)+y3 -
使用链式法则:
d out d y = 0 + 3 y 2 = 3 ( 3 2 ) = 27 \frac{d \text{out}}{d y} = 0 + 3y^2 = 3(3^2) = 27 dydout=0+3y2=3(32)=27
-
-
第二部分代码:
- 在
g(x, y)函数中,直接使用了 z z z 而没有使用 z . detach ( ) z.\text{detach}() z.detach()。这使得 z z z 的变化会被记录在计算图中。 - 计算
d out d x \frac{d \text{out}}{d x} dxdout
时, z = x + y z = x + y z=x+y 的变化会影响到 out \text{out} out,因此计算得到的梯度为
d out d x = 10.0 \frac{d \text{out}}{d x} = 10.0 dxdout=10.0
这是因为: - f ( z ) = z 2 f(z) = z^2 f(z)=z2 的导数为
d f ( z ) d z = 2 z \frac{d f(z)}{d z} = 2z dzdf(z)=2z
当 z = 5 z = 5 z=5(当 x = 2 , y = 3 x=2, y=3 x=2,y=3 时),所以
2 z = 10 2z = 10 2z=10 - 对于 y y y,计算得到的梯度为
d out d y = 37.0 \frac{d \text{out}}{d y} = 37.0 dydout=37.0
这是因为
d out d y = d ( f ( z ) + y 3 ) d y = 2 z ⋅ d z d y + 3 y 2 = 2 ( 5 ) ( 1 ) + 3 ( 3 2 ) = 10 + 27 = 37 \frac{d \text{out}}{d y} = \frac{d (f(z) + y^3)}{d y} = 2z \cdot \frac{d z}{d y} + 3y^2 = 2(5)(1) + 3(3^2) = 10 + 27 = 37 dydout=dyd(f(z)+y3)=2z⋅dydz+3y2=2(5)(1)+3(32)=10+27=37
- 在
3、线性拟合及梯度计算
在深度学习中,线性回归是最基本的模型之一。通过线性回归,我们可以找到输入特征与输出之间的线性关系。在本文中,我们将使用 PyTorch 实现一个简单的线性拟合模型,定义模型为 y = a x + b x + c + d y = ax + bx + c + d y=ax+bx+c+d,并展示如何计算梯度,同时控制某些参数(如 b b b 和 d d d)不更新梯度。
在这个模型中,我们将定义以下参数:
- a a a:斜率,表示输入 x x x 对输出 y y y 的影响。
- b b b:另一个斜率,表示输入 x x x 对输出 y y y 的影响,但在训练过程中不更新。
- c c c:截距,表示当 x = 0 x=0 x=0 时的输出值。
- d d d:一个常数项,在训练过程中不更新。
3.1、完整代码
下面是实现线性拟合的完整代码:
import torch
import torch.nn as nn
import torch.optim as optim
import matplotlib.pyplot as plt# 1. 创建数据
# 假设我们有一些样本数据
x_data = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])
y_data = torch.tensor([3.0, 5.0, 7.0, 9.0, 11.0]) # 目标值# 2. 定义线性模型
class LinearModel(nn.Module):def __init__(self):super(LinearModel, self).__init__()self.a = nn.Parameter(torch.tensor(1.0)) # 需要更新的参数self.b = nn.Parameter(torch.tensor(0.5), requires_grad=False) # 不需要更新的参数self.c = nn.Parameter(torch.tensor(0.0)) # 需要更新的参数self.d = nn.Parameter(torch.tensor(0.5), requires_grad=False) # 不需要更新的参数def forward(self, x):return self.a * x + self.b * x + self.c + self.d# 3. 实例化模型
model = LinearModel()# 4. 定义损失函数和优化器
criterion = nn.MSELoss() # 均方误差损失
optimizer = optim.SGD(model.parameters(), lr=0.005) # 随机梯度下降优化器# 5. 训练模型
for epoch in range(5000):model.train() # 设置模型为训练模式# 计算模型输出y_pred = model(x_data)# 计算损失loss = criterion(y_pred, y_data)# 反向传播optimizer.zero_grad() # 清零梯度loss.backward() # 计算梯度optimizer.step() # 更新参数# 每10个epoch打印一次loss和参数值if (epoch + 1) % 500 == 0:print(f'Epoch [{epoch + 1}/100], Loss: {loss.item():.4f}, a: {model.a.item():.4f}, b: {model.b.item():.4f}, c: {model.c.item():.4f}, d: {model.d.item():.4f}')# 6. 打印最终参数
print(f'Final parameters: a = {model.a.item()}, b = {model.b.item()}, c = {model.c.item()}, d = {model.d.item()}')# 7. 绘制拟合结果
with torch.no_grad():# 生成用于绘图的 x 值x_fit = torch.linspace(0, 6, 100) # 从 0 到 6 生成 100 个点y_fit = model(x_fit) # 计算对应的 y 值# 绘制真实数据点
plt.scatter(x_data.numpy(), y_data.numpy(), color='red', label='True Data')
# 绘制拟合曲线
plt.plot(x_fit.numpy(), y_fit.numpy(), color='blue', label='Fitted Curve')
plt.xlabel('X')
plt.ylabel('Y')
plt.title('Linear Fit Result')
plt.legend()
plt.grid()
plt.show()
3.2、梯度计算过程
在这个例子中,我们使用了 PyTorch 的自动求导功能来计算梯度。以下是对每个参数的梯度计算过程的解释:
-
参数定义:
- a a a 和 c c c 是需要更新的参数,因此它们的
requires_grad属性默认为True。 - b b b 和 d d d 是不需要更新的参数,设置了
requires_grad=False,因此它们的梯度不会被计算。
- a a a 和 c c c 是需要更新的参数,因此它们的
-
损失计算:
- 在每个训练周期中,我们计算模型的预测值 y pred y_{\text{pred}} ypred,并与真实值 y data y_{\text{data}} ydata 计算均方误差损失:
loss = 1 n ∑ i = 1 n ( y pred , i − y i ) 2 \text{loss} = \frac{1}{n} \sum_{i=1}^{n} (y_{\text{pred},i} - y_{i})^2 loss=n1i=1∑n(ypred,i−yi)2
- 在每个训练周期中,我们计算模型的预测值 y pred y_{\text{pred}} ypred,并与真实值 y data y_{\text{data}} ydata 计算均方误差损失:
-
反向传播:
- 调用
loss.backward()计算所有参数的梯度。由于 b b b 和 d d d 的requires_grad被设置为False,因此它们的梯度不会被计算和更新。
- 调用
-
参数更新:
- 使用优化器
optimizer.step()更新参数。只有 a a a 和 c c c 会被更新。
- 使用优化器
Epoch [500/100], Loss: 0.0038, a: 1.5399, b: 0.5000, c: 0.3559, d: 0.5000
Epoch [1000/100], Loss: 0.0007, a: 1.5171, b: 0.5000, c: 0.4382, d: 0.5000
Epoch [1500/100], Loss: 0.0001, a: 1.5073, b: 0.5000, c: 0.4735, d: 0.5000
Epoch [2000/100], Loss: 0.0000, a: 1.5032, b: 0.5000, c: 0.4886, d: 0.5000
Epoch [2500/100], Loss: 0.0000, a: 1.5014, b: 0.5000, c: 0.4951, d: 0.5000
Epoch [3000/100], Loss: 0.0000, a: 1.5006, b: 0.5000, c: 0.4979, d: 0.5000
Epoch [3500/100], Loss: 0.0000, a: 1.5002, b: 0.5000, c: 0.4991, d: 0.5000
Epoch [4000/100], Loss: 0.0000, a: 1.5001, b: 0.5000, c: 0.4996, d: 0.5000
Epoch [4500/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4998, d: 0.5000
Epoch [5000/100], Loss: 0.0000, a: 1.5000, b: 0.5000, c: 0.4999, d: 0.5000
Final parameters: a = 1.5000202655792236, b = 0.5, c = 0.4999275505542755, d = 0.5

相关文章:
PyTorch 中使用自动求导计算梯度
使用 PyTorch 进行自动求导和梯度计算 在 PyTorch 中,张量的 requires_grad 属性决定了是否需要计算该张量的梯度。设置为 True 的张量会在计算过程中记录操作,以便在调用 .backward() 方法时自动计算梯度。通过构建计算图,PyTorch 能够有效…...
Oracle Instant Client 23.5安装配置完整教程
Oracle Instant Client 23.5安装配置完整教程 简介环境要求安装步骤1. 准备工作目录2. 下载Oracle Instant Client3. 解压Instant Client4. 安装依赖包5. 配置系统环境5.1 配置库文件路径5.2 配置环境变量 6. 配置Oracle钱包(可选) 验证安装常见问题解决…...
【jvm】方法区的理解
目录 1. 说明2. 方法区的演进3. 内部结构4. 作用5.内存管理 1. 说明 1.方法区用于存储已被虚拟机加载的类信息、常量、静态变量、即时编译器编译后的代码缓存等数据。它是各个线程共享的内存区域。2.尽管《Java虚拟机规范》中把方法区描述为堆的一个逻辑部分,但它却…...
ES-针对某个字段去重后-获取某个字段值的所有值
针对上面表的数据,现在想根据age分组,并获取每个分组后的name有哪些(去重后)。 select age, GROUP_CONCAT(DISTINCT(name)) from testtable group by age ; 结果: 如果想要增加排序: SELECT age, GROUP_CONCAT(DISTINCT name)…...
百度 2025届秋招提前批 文心一言大模型算法工程师
文章目录 个人情况一面/技术面 1h二面/技术面 1h三面/技术面 40min 个人情况 先说一下个人情况: 学校情况:211本中9硕,本硕学校都一般,本硕都是计算机科班,但研究方向并不是NLP,而是图表示学习论文情况&a…...
sglang 部署Qwen2VL7B,大模型部署,速度测试,深度学习
sglang 项目github仓库: https://github.com/sgl-project/sglang 项目说明书: https://sgl-project.github.io/start/install.html 资讯: https://github.com/sgl-project/sgl-learning-materials?tabreadme-ov-file#the-first-sglang…...
fastadmin操作数据库字段为json、查询遍历each、多级下拉、union、php密码设置、common常用函数的使用小技巧
数据库中遇到的操作 查询字段是json的某个值 //获取数据库中某个字段是json中得某个值,进行查询,goods是表中字段,brand_id是json中要查詢的字段。//数据类型一定要对应要不然查询不出来。$map[json_extract(goods, "$.brand_id")]…...
UniApp在Vue3的setup语法糖下自定义组件插槽详解
UniApp在 Vue3的 setup 语法糖下自定义组件插槽详解 UniApp 是一个基于 Vue.js 的跨平台开发框架,可以用来开发微信小程序、H5、App 等多种平台的应用。Vue 3 引入了 <script setup> 语法糖,使得组件的编写更加简洁和直观。本文将详细介绍如何在 …...
springboot上传下载文件
RequestMapping(“bigJson”) RestController Slf4j public class TestBigJsonController { Resource private BigjsonService bigjsonService;PostMapping("uploadJsonFile") public ResponseResult<Long> uploadJsonFile(RequestParam("file")Mul…...
Python学习从0到1 day29 Python 高阶技巧 ⑦ 正则表达式
目录 一、正则表达式 二、正则表达式的三个基础方法 1.match 从头匹配 2.search(匹配规则,被匹配字符串) 3.findall(匹配规则,被匹配字符串) 三、元字符匹配 单字符匹配: 注: 示例&a…...
机器学习-web scraping
Web Scraping,通常称为网络抓取或数据抓取,是一种通过自动化程序从网页中提取数据的技术。以下是对Web Scraping的详细解释: 一、定义与原理 Web Scraping是指采用技术手段从大量网页中提取结构化和非结构化信息,并按照一定的规…...
移远通信5G RedCap模组RG255C-CN通过中国电信5G Inside终端生态认证
近日,移远通信5G RedCap模组RG255C-CN荣获中国电信颁发的5G Inside终端生态认证证书。这表明,该产品在5G基本性能、网络兼容性、安全特性等方面已经过严格评测且表现优异,将进一步加速推动5G行业终端规模化应用。 中国电信5G Inside终端生态认…...
Javaweb梳理17——HTMLCSS简介
Javaweb梳理17——HTML&CSS简介 17 HTML&CSS简介17.1 HTML介绍17.2 快速入门17.3 基础标签17.3 .1 标题标签17.3.2 hr标签17.3.3 字体标签17.3.4 换行17.3.8 案例17.3.9 图片、音频、视频标签17.3.10 超链接标签17.3.11 列表标签17.3.12 表格标签17.3.11 布局标签17.3.…...
【Android、IOS、Flutter、鸿蒙、ReactNative 】自定义View
Android Java 自定义View 步骤 创建一个新的Java类,继承自View、ViewGroup或其他任何一个视图类。 如果需要,重写构造函数以支持不同的初始化方式。 重写onMeasure方法以提供正确的测量逻辑。 重写onDraw方法以实现绘制逻辑。 根据需要重写其他方法&…...
win11跳过联网激活步骤
win11跳过联网激活步骤 win11跳过联网激活步骤方法一:使用Shift F10快捷键(推荐)1. 启动Windows 112. 选择键盘布局或输入法3. 是否想要添加第二种键盘布局4. 让我们为你连接到网络5. 调出管理员模式CMD6. 耐心等待自动重启7. 启动Windows 1…...
利用c语言详细介绍下冒泡排序
软件开发过程中,排序算法是常规且使用众多的方法之一,而冒泡算法又是排序算法中最常规且基本的算法。今天我们利用c语言,图文详细介绍下冒泡算法。 一、图文介绍 我们输入一个数组,数组为【10,5,3…...
C# 面向对象
C# 面向对象编程 面向过程:一件事情分成多个步骤来完成。 把大象装进冰箱 (面向过程化设计思想)。走一步看一步。 1、打开冰箱门 2、把大象放进冰箱 3、关闭冰箱门 面向对象:以对象作为主体 把大象装进冰箱 1、抽取对象 大象 冰箱 门 ࿰…...
android wifi扫描的capability
混合型加密android11 8155与普通linux设备扫描到的安全字段差别 android应用拿到关于wifi安全的字段: systembar-WifiBroadcastReceiver---- scanResult SSID: Redmi_697B, BSSID: a4:39:b3:70:8c:20, capabilities: [WPA-PSK-TKIPCCMP][WPA2-PSK-TKIPCCMP][RSN-PSK…...
datawhale 2411组队学习:模型压缩4 模型量化理论(数据类型、int8量化方法、PTQ和QWT)
文章目录 一、数据类型1.1 整型1.2 定点数1.3 浮点数1.3.1 正规浮点数(fp32)1.3.2 非正规浮点数(fp32)1.3.3 其它数据类型1.3.4 浮点数误差1.3.5 浮点数导致的模型训练问题 二、量化基本方法2.1 int8量化2.1.1 k-means 量化2.1.2 …...
数据分析-48-时间序列变点检测之在线实时数据的CPD
文章目录 1 时间序列结构1.1 变化点的定义1.2 结构变化的类型1.2.1 水平变化1.2.2 方差变化1.3 变点检测1.3.1 离线数据检测方法1.3.2 实时数据检测方法2 模拟数据2.1 模拟恒定方差数据2.2 模拟变化方差数据3 实时数据CPD3.1 SDAR学习算法3.2 Changefinder模块3.3 恒定方差CPD3…...
手游刚开服就被攻击怎么办?如何防御DDoS?
开服初期是手游最脆弱的阶段,极易成为DDoS攻击的目标。一旦遭遇攻击,可能导致服务器瘫痪、玩家流失,甚至造成巨大经济损失。本文为开发者提供一套简洁有效的应急与防御方案,帮助快速应对并构建长期防护体系。 一、遭遇攻击的紧急应…...
golang循环变量捕获问题
在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - 循环变量捕获问题。让我详细解释一下: 问题背景 看这个代码片段: fo…...
【SpringBoot】100、SpringBoot中使用自定义注解+AOP实现参数自动解密
在实际项目中,用户注册、登录、修改密码等操作,都涉及到参数传输安全问题。所以我们需要在前端对账户、密码等敏感信息加密传输,在后端接收到数据后能自动解密。 1、引入依赖 <dependency><groupId>org.springframework.boot</groupId><artifactId...
理解 MCP 工作流:使用 Ollama 和 LangChain 构建本地 MCP 客户端
🌟 什么是 MCP? 模型控制协议 (MCP) 是一种创新的协议,旨在无缝连接 AI 模型与应用程序。 MCP 是一个开源协议,它标准化了我们的 LLM 应用程序连接所需工具和数据源并与之协作的方式。 可以把它想象成你的 AI 模型 和想要使用它…...
关于iview组件中使用 table , 绑定序号分页后序号从1开始的解决方案
问题描述:iview使用table 中type: "index",分页之后 ,索引还是从1开始,试过绑定后台返回数据的id, 这种方法可行,就是后台返回数据的每个页面id都不完全是按照从1开始的升序,因此百度了下,找到了…...
Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决
Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中,新增了一个本地验证码接口 /code,使用函数式路由(RouterFunction)和 Hutool 的 Circle…...
Mobile ALOHA全身模仿学习
一、题目 Mobile ALOHA:通过低成本全身远程操作学习双手移动操作 传统模仿学习(Imitation Learning)缺点:聚焦与桌面操作,缺乏通用任务所需的移动性和灵活性 本论文优点:(1)在ALOHA…...
中医有效性探讨
文章目录 西医是如何发展到以生物化学为药理基础的现代医学?传统医学奠基期(远古 - 17 世纪)近代医学转型期(17 世纪 - 19 世纪末)现代医学成熟期(20世纪至今) 中医的源远流长和一脉相承远古至…...
《C++ 模板》
目录 函数模板 类模板 非类型模板参数 模板特化 函数模板特化 类模板的特化 模板,就像一个模具,里面可以将不同类型的材料做成一个形状,其分为函数模板和类模板。 函数模板 函数模板可以简化函数重载的代码。格式:templa…...
使用Spring AI和MCP协议构建图片搜索服务
目录 使用Spring AI和MCP协议构建图片搜索服务 引言 技术栈概览 项目架构设计 架构图 服务端开发 1. 创建Spring Boot项目 2. 实现图片搜索工具 3. 配置传输模式 Stdio模式(本地调用) SSE模式(远程调用) 4. 注册工具提…...
