当前位置: 首页 > news >正文

PyTorch求导相关

PyTorch是动态图,即计算图的搭建和运算是同时的,随时可以输出结果;而TensorFlow是静态图。

在pytorch的计算图里只有两种元素:数据(tensor)和 运算(operation)

运算包括了:加减乘除、开方、幂指对、三角函数等可求导运算

数据可分为:叶子节点(leaf node)和非叶子节点;叶子节点是用户创建的节点,不依赖其它节点;它们表现出来的区别在于反向传播结束之后,非叶子节点的梯度会被释放掉,只保留叶子节点的梯度,这样就节省了内存。如果想要保留非叶子节点的梯度,可以使用retain_grad()方法。

torch.tensor 具有如下属性:

  • 查看 是否可以求导 requires_grad
  • 查看 运算名称 grad_fn
  • 查看 是否为叶子节点 is_leaf
  • 查看 导数值 grad

针对requires_grad属性,自己定义的叶子节点默认为False,而非叶子节点默认为True,神经网络中的权重默认为True。判断哪些节点是True/False的一个原则就是从你需要求导的叶子节点到loss节点之间是一条可求导的通路。

当我们想要对某个Tensor变量求梯度时,需要先指定requires_grad属性为True,指定方式主要有两种:

x = torch.tensor(1.).requires_grad_() # 第一种x = torch.tensor(1., requires_grad=True) # 第二种

PyTorch提供两种求梯度的方法:backward() and torch.autograd.grad() ,他们的区别在于前者是给叶子节点填充.grad字段,而后者是直接返回梯度给你,我会在后面举例说明。还需要知道y.backward()其实等同于torch.autograd.backward(y)

一个简单的求导例子是:y=(x+1)∗(x+2) ,计算 ∂y/∂x ,假设给定 x=2
先画出计算图

手算:∂y/∂x=(x+2)*1+(x+1)*1->7

使用backward()

x = torch.tensor(2., requires_grad=True)a = torch.add(x, 1)
b = torch.add(x, 2)
y = torch.mul(a, b)y.backward()
print(x.grad)
>>>tensor(7.)

看一下这几个tensor的属性

print("requires_grad: ", x.requires_grad, a.requires_grad, b.requires_grad, y.requires_grad)
print("is_leaf: ", x.is_leaf, a.is_leaf, b.is_leaf, y.is_leaf)
print("grad: ", x.grad, a.grad, b.grad, y.grad)>>>requires_grad:  True True True True
>>>is_leaf:  True False False False
>>>grad:  tensor(7.) None None None

使用backward()函数反向传播计算tensor的梯度时,并不计算所有tensor的梯度,而是只计算满足这几个条件的tensor的梯度:1.类型为叶子节点、2.requires_grad=True、3.依赖该tensor的所有tensor的requires_grad=True。所有满足条件的变量梯度会自动保存到对应的grad属性里。

使用autograd.grad()

x = torch.tensor(2., requires_grad=True)a = torch.add(x, 1)
b = torch.add(x, 2)
y = torch.mul(a, b)grad = torch.autograd.grad(outputs=y, inputs=x)
print(grad[0])
>>>tensor(7.)

因为指定了输出y,输入x,所以返回值就是 ∂x/∂y 这一梯度,完整的返回值其实是一个元组,保留第一个元素就行,后面元素是

二阶求导

求一阶导可以用backward()

x = torch.tensor(2., requires_grad=True)
y = torch.tensor(3., requires_grad=True)z = x * x * yz.backward()
print(x.grad, y.grad)
>>>tensor(12.) tensor(4.)

也可以用autograd.grad()

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * ygrad_x = torch.autograd.grad(outputs=z, inputs=x)
print(grad_x[0])
>>>tensor(12.)

为什么不在这里面同时也求对y的导数呢?因为无论是backward还是autograd.grad在计算一次梯度后图就被释放了,如果想要保留,需要添加retain_graph=True

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * ygrad_x = torch.autograd.grad(outputs=z, inputs=x, retain_graph=True)
grad_y = torch.autograd.grad(outputs=z, inputs=y)print(grad_x[0], grad_y[0])
>>>tensor(12.) tensor(4.) 

再来看如何求高阶导,理论上其实是上面的grad_x再对x求梯度,试一下看

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * ygrad_x = torch.autograd.grad(outputs=z, inputs=x, retain_graph=True)
grad_xx = torch.autograd.grad(outputs=grad_x, inputs=x)print(grad_xx[0])
>>>RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

报错了,虽然retain_graph=True保留了计算图和中间变量梯度, 但没有保存grad_x的运算方式,需要使用creat_graph=True在保留原图的基础上再建立额外的求导计算图,也就是会把 ∂z/∂x=2xy 这样的运算存下来

# autograd.grad() + autograd.grad()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * ygrad_x = torch.autograd.grad(outputs=z, inputs=x, create_graph=True)
grad_xx = torch.autograd.grad(outputs=grad_x, inputs=x)print(grad_xx[0])
>>>tensor(6.)

grad_xx这里也可以直接用backward(),相当于直接从 ∂z/∂x=2xy 开始回传

# autograd.grad() + backward()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * ygrad = torch.autograd.grad(outputs=z, inputs=x, create_graph=True)
grad[0].backward()print(x.grad)
>>>tensor(6.)

 也可以先用backward()然后对x.grad这个一阶导继续求导

# backward() + autograd.grad()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * yz.backward(create_graph=True)
grad_xx = torch.autograd.grad(outputs=x.grad, inputs=x)print(grad_xx[0])
>>>tensor(6.)

那是不是也可以直接用两次backward()呢?第二次直接x.grad从开始回传,我们试一下

# backward() + backward()
x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * yz.backward(create_graph=True) # x.grad = 12
x.grad.backward()print(x.grad)
>>>tensor(18., grad_fn=<CopyBackwards>)

发现了问题,结果不是6,而是18,发现第一次回传时输出x梯度是12。这是因为PyTorch使用backward()时默认会累加梯度,需要手动把前一次的梯度清零

x = torch.tensor(2.).requires_grad_()
y = torch.tensor(3.).requires_grad_()z = x * x * yz.backward(create_graph=True)
x.grad.data.zero_()
x.grad.backward()print(x.grad)
>>>tensor(6., grad_fn=<CopyBackwards>)

向量求导

有没有发现前面都是对标量求导,如果不是标量会怎么样呢?

x = torch.tensor([1., 2.]).requires_grad_()
y = x + 1y.backward()
print(x.grad)
>>>RuntimeError: grad can be implicitly created only for scalar outputs

x = torch.tensor([1., 2.]).requires_grad_()
y = x * xy.sum().backward()
print(x.grad)
>>>tensor([2., 4.])

相关文章:

PyTorch求导相关

PyTorch是动态图&#xff0c;即计算图的搭建和运算是同时的&#xff0c;随时可以输出结果&#xff1b;而TensorFlow是静态图。 在pytorch的计算图里只有两种元素&#xff1a;数据&#xff08;tensor&#xff09;和 运算&#xff08;operation&#xff09; 运算包括了&#xf…...

Halcon基础-瓶盖带角度的OCR批量识别

Halcon基础-OCR识别 1、OCR识别素材2、创建路径文件3、Halcon代码实现4、运行效果5、资源获取 1、OCR识别素材 这里我准备了7张不同角度的OCR图片&#xff0c;如下所示&#xff1a; 2、创建路径文件 按照下图所示创建全部文件夹和文件&#xff1a; 01用来存放OCR识别原图 c…...

php语法学习

启动php 进入软件 打开文件&#xff1a;编写代码 $php true; $java false; var_dump($php);//输出变量细节 var_dump($java) 字符串 注意可以使用双引号也可以使用单引号 测试 $php "最好学web语言"; $java 脱胎于c语言; var_dump($php);//输出变量细节 var…...

JavaWeb合集22-Apache POI

二十二、Apache POI Apache POI是一个处理Miscrosoft Office各种文件格式的开源项目。简单来说就是&#xff0c;我们可以使用POI在Java 序中对Miscrosoft Office各种文件进行读写操作。一般情况下&#xff0c;POI都是用于操作Excel文件。 使用场景&#xff1a;银行网银系统导出…...

DDD重构-实体与限界上下文重构

DDD重构-实体与限界上下文重构 概述 DDD 方法需要不同类型的类元素&#xff0c;例如实体或值对象&#xff0c;并且几乎所有这些类元素都可以看作是常规的 Java 类。它们的总体结构是 Name: 类的唯一名称 Properties&#xff1a;属性 Methods: 控制变量的变化和添加行为 一…...

MATLAB Simulink (二)高速跳频通信系统

MATLAB & Simulink &#xff08;二&#xff09;高速跳频通信系统 写在前面1 系统原理1.1 扩频通信系统理论基础1.1.1 基本原理1.1.2 扩频通信系统处理增益和干扰容限1.1.3 各种干扰模式下抗干扰性能 1.2 高速跳频通信系统理论基础1.2.1 基本原理1.2.2 物理模型 2 方案设计2…...

智能合约分享

智能合约练习 一、solidity初学者经典示例代码&#xff1a; 1.存储和检索数据&#xff1a; // SPDX-License-Identifier: MIT pragma solidity ^0.8.0; // 声明 Solidity 编译器版本// 定义一个名为 SimpleStorage 的合约 contract SimpleStorage {// 声明一个公共状态变量 d…...

【MR开发】在Pico设备上接入MRTK3(二)——在Unity中配置Pico SDK

上一篇文档介绍了 【MR开发】在Pico设备上接入MRTK3&#xff08;一&#xff09;在Unity中导入MRTK3依赖 下面将介绍在Unity中导入Pcio SDK的具体步骤 在Unity中导入Pico SDK 当前Pico SDK版本 Unity交互SDK git仓库&#xff1a; https://github.com/Pico-Developer/PICO-Un…...

【Java】探秘正则表达式:深度解析与精妙运用

目录 引言 一、基本概念 1.1 元字符 1.2 预定义字符类 1.3 边界匹配符 1.4 数量标识符 1.5 捕获与非捕获分组 二、Java中的正则表达式支持 三、正则表达式的使用示例 3.1 匹配字符串 3.2 替换字符串 3.3 分割字符串 3.4 使用Pattern和Matcher 3.5 捕获组和后向…...

2.6.ReactOS系统中从内核中发起系统调用

2.6.ReactOS系统中从内核中发起系统调用 2.6.ReactOS系统中从内核中发起系统调用 文章目录 2.6.ReactOS系统中从内核中发起系统调用前言 前言 上面我们已经可以看到用户空间&#xff08;R3&#xff09;进行系统调用的全过程即两种方法的具体实现。 系统调用一般时从R3发起的…...

chat_gpt回答:python获取当前utc时间,将xml里时间tag里的值修改为当前时间

你可以使用 lxml 库来读取、修改 XML 文件中的某个标签的值&#xff0c;并将其保存为新的 XML 文件。以下是一个示例代码&#xff0c;展示如何获取当前的 UTC 时间&#xff0c;并将 XML 文件中的某个时间标签修改为当前时间。 示例代码&#xff1a; from lxml import etree f…...

机器学习-语言分析

机器学习 1.1人工智能概述 1.2.1 机器学习与人工智能&#xff0c;深度学习 深度学习->机器学习->人工智能&#xff1b; 人工智能&#xff1a;1950&#xff0c;实现自动下棋&#xff0c;人机对弈&#xff0c;达特茅斯会议->人工智能的起点&#xff0c;1956年8月。克劳…...

Oracle 常见索引扫描方式概述,哪种索引扫描最快!

一.常见的索引扫描方式 INDEX RANGE SCANINDEX FAST FULL SCANINDEX FULL SCAN(MIN/MAX)INDEX FULL SCAN 二.分别模拟使用这些索引的场景 1.INDEX RANGE SCAN create table t1 as select rownum as id, rownum/2 as id2 from dual connect by level<500000; create inde…...

字符串(3)_二进制求和_高精度加法

个人主页&#xff1a;C忠实粉丝 欢迎 点赞&#x1f44d; 收藏✨ 留言✉ 加关注&#x1f493;本文由 C忠实粉丝 原创 字符串(3)_二进制求和_高精度加法 收录于专栏【经典算法练习】 本专栏旨在分享学习算法的一点学习笔记&#xff0c;欢迎大家在评论区交流讨论&#x1f48c; 目…...

《神经网络:智能时代的核心技术》

《神经网络&#xff1a;智能时代的核心技术》 一、神经网络的诞生与发展二、神经网络的结构与工作原理&#xff08;一&#xff09;神经元模型&#xff08;二&#xff09;神经网络训练过程 三、神经网络的应用领域&#xff08;一&#xff09;信息领域&#xff08;二&#xff09;…...

pdf内容三张以上转图片,使用spire.pdf.free

一、依赖 <spire.pdf.free.version>9.13.0</spire.pdf.free.version><itextpdf.version>5.5.13</itextpdf.version><dependency><groupId>e-iceblue</groupId><artifactId>spire.pdf.free</artifactId><version>$…...

游戏、软件、开源项目和资讯

游戏 标题链接【白嫖正版游戏】IT之家喜加一website 软件 标题链接【白嫖正版软件】反斗限免website 开源项目 标题链接【Luxirty Search】基于Google搜索结果&#xff0c;屏蔽内容农场Github【Video2X】图片/视频超分工具Github 新闻资讯 标题链接分享10个 Claude 3.5 …...

Acrel-1000变电站综合自动化系统及微机在化工企业中的应用方案

文&#xff1a;安科瑞郑桐 摘要&#xff1a;大型化工企业供配电具有的集约型特点&#xff0c;化工企业内35kV变电站和10kV变电所数量大、分布广&#xff0c;对于老的大多大型及中型化工企业而言&#xff0c;其变电站或变电所内高压电气设备为旧式继电保护装置&#xff0c;可靠…...

[Linux] CentOS7替换yum源为阿里云并安装gcc详细过程(附下载链接)

前言 CentOS7替换yum源为阿里云 yum是CentOS中的一种软件管理器&#xff0c;通过yum安装软件&#xff0c;可以自动解决包依赖的问题&#xff0c;免去手工安装依赖包的麻烦。 yum使用了一个中心仓库来记录和管理软件的依赖关系&#xff0c;默认为mirrorlist.centos.org&#xf…...

在Java中创建多线程的三种方式

多线程的创建和启动方式 在Java中&#xff0c;创建多线程主要有以下三种方式&#xff1a; 继承Thread类实现Runnable接口使用Callable接口与Future 下面是这三种方式的简单示例&#xff0c;以及如何在主类中启动它们。 1. 继承Thread类 class MyThread extends Thread {Ov…...

洛谷 AT_abc374_c [ABC374C] Separated Lunch 题解

题目大意 KEYENCE 总部有 N N N 个部门&#xff0c;第 i i i 个部门有 K i K_i Ki​ 个人。 现在要把所有部门分为 AB 两组&#xff0c;求这两组中人数多的那一组的人数最少为多少。 题目分析 设这些部门共有 x x x 个人&#xff0c;则较多的组的人数肯定大于等于 ⌈ …...

力扣2528.最大化城市的最小电量

力扣2528.最大化城市的最小电量 题目解析及思路 题目要求找到所有城市电量最小值的最大 电量为给城市供电的发电站数量 因此每座城市的电量可以用一段区间和表示&#xff0c;即前缀和 二分最低电量时 如果当前城市电量不够,贪心的想发电站建立的位置&#xff0c;应该是在mi…...

【zookeeper】集群配置

zookeeper 数据结构 zookeeper数据模型结构&#xff0c;就和Linux的文件系统类型&#xff0c;看起来是一颗树&#xff0c;每个节点称为一个znode.每一个Znode默认的存储1MB的数据&#xff0c;每个Znode都有唯一标识&#xff0c;可以通过命令显示节点的信息每当节点有数据变化…...

YOLO11 目标检测 | 导出ONNX模型 | ONNX模型推理

本文分享YOLO11中&#xff0c;从xxx.pt权重文件转为.onnx文件&#xff0c;然后使用.onnx文件&#xff0c;进行目标检测任务的模型推理。 用ONNX模型推理&#xff0c;便于算法到开发板或芯片的部署。 备注&#xff1a;本文是使用Python&#xff0c;编写ONNX模型推理代码的 目…...

PostgreSQL DBA月度检查列表

为了确保数据库系统能够稳定高效运行&#xff0c;DBA 需要定期对数据库进行检查和维护&#xff0c;这是一项非常具有挑战性的工作。 本文给大家推荐一个 PostgreSQL DBA 月度性能检查列表&#xff0c;遵循以下指导原则可以帮助我们实现一个高可用、高性能、低成本、可扩展的数…...

驱动开发系列12 - Linux 编译内核模块的Makefile解释

一:内核模块Makefile #这一行定义了要编译的内核模块目标文件。obj-m表示目标模块对象文件(.o文件), #并指定了两个模块源文件:helloworld-params.c 和 helloworld.c。最终会生成这 #这两个.c文件的.o对象文件。 obj-m := helloworld-params.o helloworld.o#这行定义了内核…...

用js+css实现圆环型的进度条——js+css基础积累

如果用jscss实现圆环型的进度条&#xff1a; 直接上代码&#xff1a; <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8" /><meta http-equiv"X-UA-Compatible" content"IEedge" /><met…...

TDengine 与北微传感达成合作,解决传统数据库性能瓶颈

在当今物联网&#xff08;IoT&#xff09;快速发展的背景下&#xff0c;传感器技术已成为各个行业数字化转型的关键组成部分。随着设备数量的激增和数据生成速度的加快&#xff0c;如何高效地管理和分析这些数据&#xff0c;成为企业实现智能化运营的重要挑战。尤其是在惯性传感…...

通过Python爬虫获取商品销量数据,轻松掌握市场动态

为什么选择Python爬虫&#xff1f; 简洁易用&#xff1a;Python语言具有简洁的语法和丰富的库&#xff0c;使得编写爬虫变得简单高效。强大的库支持&#xff1a;Python拥有强大的爬虫框架&#xff08;如Scrapy、BeautifulSoup、Requests等&#xff09;&#xff0c;可以快速实现…...

学习虚幻C++开发日志——TSet

TSet 官方文档&#xff1a;虚幻引擎中的Set容器 | 虚幻引擎 5.5 文档 | Epic Developer Community (epicgames.com) TSet 是通过对元素求值的可覆盖函数&#xff0c;使用数据值本身作为键&#xff0c;而不是将数据值与独立的键相关联。 默认情况下&#xff0c;TSet 不支持重…...

营销型网站单页面/电商平台怎么搭建

Vue组件调试遇到的坑,触发断点,但没有进入对应的文件 今天遇到这样一个问题 我再一个index.vue组件里调试,写下一个debugger,在运行时,也确实触发了断点,但显示的文件却不是我打断点的那个文件 而是在index.vue上级的一个index.vue 一句话描述就是:在vue组件里打断点,没有…...

政府门户网站建设总结/新产品推广方案怎么写

在论坛上经常会有人问&#xff0c;到底是使用Trie算法保存路由表还是用Hash算法。那么我首先要明白&#xff0c;你要保存多大的路由表。简单的答案如下&#xff1a;少量&#xff1a;Hash算法大量&#xff1a;Trie算法但是&#xff0c;仅仅这么回答会显得很业余&#xff0c;真的…...

电影购票网站开发背景/怎样做好网络营销推广

import pandas as pddata pd.DataFrame() series pd.Series({"x":1,"y":2},name"a") data data.append(series) print(data) data data.append(series) print(data)x y a 1.0 2.0x y a 1.0 2.0 a 1.0 2.0...

建设工程合同约定的质量目标/绍兴网站快速排名优化

shirio的功能 Shiro可以非常容易的开发出足够好的应用&#xff0c;其不仅可以用在JavaSE环境&#xff0c;也可以用在JavaEE环境。Shiro可以帮助我们完成&#xff1a;认证、授权、加密、会话管理、与Web集成、缓存等。这不就是我们想要的嘛&#xff0c;而且Shiro的API也是非常简…...

威海市网站建设/网店推广方法

文章目录[Spring MVC]未开启矩阵变量&#xff08;MatrixVariable&#xff09;的支持[Spring MVC]缺少依赖&#xff08;与jsp有关&#xff09;[Spring MVC]缺少依赖[commons-fileupload][Spring MVC]无法访问html文件静态资源缓存配置失败依赖冲突redirect-view-controller不起作…...

手机怎么建网站/杭州优化公司在线留言

The following are some english patterns I’ve learned today. Iwrote down them in order to restudy later. 1. have a bone in one’s throat 难以启齿 2. Don’t take it to heart. 别往心里去。 3. I just couldn’t help it. 我只是控制不住。 4. Let’s face it. 面对…...