当前位置: 首页 > news >正文

pytorch 3 计算图

计算图结构

**加粗样式**

分析:

  1. 起始节点 a
  2. b = 5 - 3a
  3. c = 2b + 3
  4. d = 5b + 6
  5. e = 7c + d^2
  6. f = 2e
  7. 最终输出 g = 3f - o(其中 o 是另一个输入)

前向传播

前向传播按照上述顺序计算每个节点的值。

反向传播过程

反向传播的目标是计算损失函数(这里假设为 g)对每个中间变量和输入的偏导数。从右向左进行计算:

  1. ∂g/∂o = -1
  2. ∂g/∂f = 3
  3. ∂f/∂e = 2
  4. ∂e/∂c = 7
  5. ∂e/∂d = 2d
  6. ∂d/∂b = 5
  7. ∂c/∂b = 2
  8. ∂b/∂a = -3

链式法则应用

使用链式法则计算出 g 对每个变量的全导数:

  1. dg/df = ∂g/∂f = 3
  2. dg/de = (∂g/∂f) * (∂f/∂e) = 3 * 2 = 6
  3. dg/dc = (dg/de) * (∂e/∂c) = 6 * 7 = 42
  4. dg/dd = (dg/de) * (∂e/∂d) = 6 * 2d
  5. dg/db = (dg/dc) * (∂c/∂b) + (dg/dd) * (∂d/∂b)
    = 42 * 2 + 6 * 2d * 5
    = 84 + 60d
  6. dg/da = (dg/db) * (∂b/∂a)
    = (84 + 60d) * (-3)
    = -252 - 180d

最终梯度

最终得到 g 对输入 a 和 o 的梯度:

  • dg/da = -252 - 180d
  • dg/do = -1

代码实现

静态图

import mathclass Node:"""表示计算图中的一个节点。每个节点都可以存储一个值、梯度,并且知道如何计算前向传播和反向传播。"""def __init__(self, value=None):self.value = value  # 节点的值self.gradient = 0   # 节点的梯度self.parents = []   # 父节点列表self.forward_fn = lambda: None  # 前向传播函数self.backward_fn = lambda: None  # 反向传播函数def __add__(self, other):"""加法操作"""return self._create_binary_operation(other, lambda x, y: x + y, lambda: (1, 1))def __mul__(self, other):"""乘法操作"""return self._create_binary_operation(other, lambda x, y: x * y, lambda: (other.value, self.value))def __sub__(self, other):"""减法操作"""return self._create_binary_operation(other, lambda x, y: x - y, lambda: (1, -1))def __pow__(self, power):"""幂运算"""result = Node()result.parents = [self]def forward():result.value = math.pow(self.value, power)def backward():self.gradient += power * math.pow(self.value, power-1) * result.gradientresult.forward_fn = forwardresult.backward_fn = backwardreturn resultdef _create_binary_operation(self, other, forward_op, gradient_op):"""创建二元操作的辅助方法。用于简化加法、乘法和减法的实现。"""result = Node()result.parents = [self, other]def forward():result.value = forward_op(self.value, other.value)def backward():grads = gradient_op()self.gradient += grads[0] * result.gradientother.gradient += grads[1] * result.gradientresult.forward_fn = forwardresult.backward_fn = backwardreturn resultdef topological_sort(node):"""对计算图进行拓扑排序。确保在前向和反向传播中按正确的顺序处理节点。"""visited = set()topo_order = []def dfs(n):if n not in visited:visited.add(n)for parent in n.parents:dfs(parent)topo_order.append(n)dfs(node)return topo_order# 构建计算图
a = Node(2)  # 假设a的初始值为2
o = Node(1)  # 假设o的初始值为1# 按照给定的数学表达式构建计算图
b = Node(5) - a * Node(3)
c = b * Node(2) + Node(3)
d = b * Node(5) + Node(6)
e = c * Node(7) + d ** 2
f = e * Node(2)
g = f * Node(3) - o# 前向传播
sorted_nodes = topological_sort(g)
for node in sorted_nodes:node.forward_fn()# 反向传播
g.gradient = 1  # 设置输出节点的梯度为1
for node in reversed(sorted_nodes):node.backward_fn()# 打印结果
print(f"g = {g.value}")
print(f"dg/da = {a.gradient}")
print(f"dg/do = {o.gradient}")# 验证手动计算的结果
d_value = 5 * b.value + 6
expected_dg_da = -252 - 180 * d_value
print(f"Expected dg/da = {expected_dg_da}")
print(f"Difference: {abs(a.gradient - expected_dg_da)}")

动态图

import mathclass Node:"""表示计算图中的一个节点。实现了动态计算图的核心功能,包括前向计算和反向传播。"""def __init__(self, value, children=(), op=''):self.value = value  # 节点的值self.grad = 0       # 节点的梯度self._backward = lambda: None  # 反向传播函数,默认为空操作self._prev = set(children)  # 前驱节点集合self._op = op  # 操作符,用于调试def __add__(self, other):"""加法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value + other.value, (self, other), '+')def _backward():self.grad += result.gradother.grad += result.gradresult._backward = _backwardreturn resultdef __mul__(self, other):"""乘法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value * other.value, (self, other), '*')def _backward():self.grad += other.value * result.gradother.grad += self.value * result.gradresult._backward = _backwardreturn resultdef __pow__(self, other):"""幂运算"""assert isinstance(other, (int, float)), "only supporting int/float powers for now"result = Node(self.value ** other, (self,), f'**{other}')def _backward():self.grad += (other * self.value**(other-1)) * result.gradresult._backward = _backwardreturn resultdef __neg__(self):"""取反操作"""return self * -1def __sub__(self, other):"""减法操作"""return self + (-other)def __truediv__(self, other):"""除法操作"""return self * other**-1def __radd__(self, other):"""反向加法"""return self + otherdef __rmul__(self, other):"""反向乘法"""return self * otherdef __rtruediv__(self, other):"""反向除法"""return other * self**-1def tanh(self):"""双曲正切函数"""x = self.valuet = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)result = Node(t, (self,), 'tanh')def _backward():self.grad += (1 - t**2) * result.gradresult._backward = _backwardreturn resultdef backward(self):"""执行反向传播,计算梯度。使用拓扑排序确保正确的反向传播顺序。"""topo = []visited = set()def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)self.grad = 1  # 设置输出节点的梯度为1for node in reversed(topo):node._backward()  # 对每个节点执行反向传播def main():"""主函数,用于测试自动微分系统。构建一个计算图,执行反向传播,并验证结果。"""# 构建计算图a = Node(2)o = Node(1)b = Node(5) - a * 3c = b * 2 + 3d = b * 5 + 6e = c * 7 + d ** 2f = e * 2g = f * 3 - o# 反向传播g.backward()# 打印结果print(f"g = {g.value}")print(f"dg/da = {a.grad}")print(f"dg/do = {o.grad}")# 验证手动计算的结果d_value = 5 * b.value + 6expected_dg_da = -252 - 180 * d_valueprint(f"Expected dg/da = {expected_dg_da}")print(f"Difference: {abs(a.grad - expected_dg_da)}")if __name__ == "__main__":main()

解释:

  1. Node 类代表计算图中的一个节点,包含值、梯度、父节点以及前向和反向传播函数。
  2. 重载的数学运算符 (__add__, __mul__, __sub__, __pow__) 允许直观地构建计算图。
  3. _create_binary_operation 方法用于创建二元操作,简化了加法、乘法和减法的实现。
  4. topological_sort 函数对计算图进行拓扑排序,确保正确的计算顺序。
import mathclass Node:"""表示计算图中的一个节点。实现了动态计算图的核心功能,包括前向计算和反向传播。"""def __init__(self, value, children=(), op=''):self.value = value  # 节点的值self.grad = 0       # 节点的梯度self._backward = lambda: None  # 反向传播函数,默认为空操作self._prev = set(children)  # 前驱节点集合self._op = op  # 操作符,用于调试def __add__(self, other):"""加法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value + other.value, (self, other), '+')def _backward():self.grad += result.gradother.grad += result.gradresult._backward = _backwardreturn resultdef __mul__(self, other):"""乘法操作"""other = other if isinstance(other, Node) else Node(other)result = Node(self.value * other.value, (self, other), '*')def _backward():self.grad += other.value * result.gradother.grad += self.value * result.gradresult._backward = _backwardreturn resultdef __pow__(self, other):"""幂运算"""assert isinstance(other, (int, float)), "only supporting int/float powers for now"result = Node(self.value ** other, (self,), f'**{other}')def _backward():self.grad += (other * self.value**(other-1)) * result.gradresult._backward = _backwardreturn resultdef __neg__(self):"""取反操作"""return self * -1def __sub__(self, other):"""减法操作"""return self + (-other)def __truediv__(self, other):"""除法操作"""return self * other**-1def __radd__(self, other):"""反向加法"""return self + otherdef __rmul__(self, other):"""反向乘法"""return self * otherdef __rtruediv__(self, other):"""反向除法"""return other * self**-1def tanh(self):"""双曲正切函数"""x = self.valuet = (math.exp(2*x) - 1)/(math.exp(2*x) + 1)result = Node(t, (self,), 'tanh')def _backward():self.grad += (1 - t**2) * result.gradresult._backward = _backwardreturn resultdef backward(self):"""执行反向传播,计算梯度。使用拓扑排序确保正确的反向传播顺序。"""topo = []visited = set()def build_topo(v):if v not in visited:visited.add(v)for child in v._prev:build_topo(child)topo.append(v)build_topo(self)self.grad = 1  # 设置输出节点的梯度为1for node in reversed(topo):node._backward()  # 对每个节点执行反向传播def main():"""主函数,用于测试自动微分系统。构建一个计算图,执行反向传播,并验证结果。"""# 构建计算图a = Node(2)o = Node(1)b = Node(5) - a * 3c = b * 2 + 3d = b * 5 + 6e = c * 7 + d ** 2f = e * 2g = f * 3 - o# 反向传播g.backward()# 打印结果print(f"g = {g.value}")print(f"dg/da = {a.grad}")print(f"dg/do = {o.grad}")# 验证手动计算的结果d_value = 5 * b.value + 6expected_dg_da = -252 - 180 * d_valueprint(f"Expected dg/da = {expected_dg_da}")print(f"Difference: {abs(a.grad - expected_dg_da)}")if __name__ == "__main__":main()

解释:

  1. Node 类是核心,它代表计算图中的一个节点,并实现了各种数学运算。

  2. 每个数学运算(如 __add__, __mul__ 等)都创建一个新的 Node,并定义了相应的反向传播函数。

  3. backward 方法实现了反向传播算法,使用拓扑排序确保正确的计算顺序。

相关文章:

pytorch 3 计算图

计算图结构 分析: 起始节点 ab 5 - 3ac 2b 3d 5b 6e 7c d^2f 2e最终输出 g 3f - o(其中 o 是另一个输入) 前向传播 前向传播按照上述顺序计算每个节点的值。 反向传播过程 反向传播的目标是计算损失函数(这里假设为…...

一文吃透:暗水印是什么?企业防泄密可以加暗水印吗?

设计部主管:昨天下班的时候我在办公室捡到一张文件,上面可是我们最新产品的设计草稿,严禁打印的,到底是谁干的? 员工:办公室没有监控,似乎很难查到哦。 网络部经理:不用担心&#…...

Ajax-02.Axios

Axios入门 1.引入Axios的js文件 <script src"js/axios-0.18.0.js"></script> Axios 请求方式别名: axios.get(url[,config]) axios.delete(url[,config]) axios.post(url[,data[,config]]) axios.put(url[,data[,config]]) 发送GET/POST请求 axios.get…...

NodeJS的核心配置文件package.json和package.lock.json详解

package.json 文件 package.json 文件是 Node.js 项目的核心配置文件&#xff0c;它包含了项目的基本信息、依赖关系以及一些脚本命令等。以下是 package.json 文件的主要字段说明&#xff1a; name&#xff1a;项目的名称&#xff0c;必须是小写&#xff0c;可以包含字母、数…...

开源数据采集和跟踪系统:助力营销决策的关键工具

开源数据采集和跟踪系统&#xff1a;助力营销决策的关键工具 在现代营销中&#xff0c;数据是最重要的资产之一。了解用户行为、优化广告效果、提升转化率&#xff0c;这一切都离不开精准的数据分析。为了帮助商家更好地掌握这些数据&#xff0c;市场上出现了许多开源的数据采…...

Luminar Neo for Mac/Win:创新AI图像编辑软件的强大功能

Luminar Neo&#xff0c;这款由Skylum公司倾力打造的图像编辑软件&#xff0c;为Mac和Windows用户带来了前所未有的创作体验与编辑便利。作为一款融合了先进AI技术的图像处理工具&#xff0c;Luminar Neo以其独特的功能和高效的操作流程&#xff0c;成为了摄影师、设计师及摄影…...

Mac平台M1PRO芯片MiniCPM-V-2.6网页部署跑通

Mac平台M1PRO芯片MiniCPM-V-2.6网页部署跑通 契机 ⚙ 2.6的小钢炮可以输入视频了&#xff0c;我必须拉到本地跑跑。主要解决2.6版本默认绑定flash_atten问题&#xff0c;pip install flash_attn也无法安装&#xff0c;因为强制依赖cuda。主要解决的就是这个问题&#xff0c;还…...

MyBatis:Maven,Git,TortoiseGit,Gradle

1&#xff0c;Maven Maven是一个非常优秀的项目管理工具&#xff0c;采用一种“约定优于配置&#xff08;CoC&#xff09;”的策略来管理项目。使用Maven不仅可以把源代码构建成可发布的项目&#xff08;包括编译、打包、测试和分发&#xff09;&#xff0c;还可以生成报告、生…...

获取链表中间位置的两种方法方法

方法一&#xff1a; 我们可以计算链表节点的数量&#xff0c;然后遍历链表找到前半部分的尾节点。 方法二: 我们也可以使用快慢指针在一次遍历中找到&#xff1a;慢指针一次走一步&#xff0c;快指针一次走两步&#xff0c;快慢指针同时出发。当快指针移动到链表的末尾时&am…...

第二十天的学习(2024.8.8)Vue拓展

昨天的笔记中&#xff0c;我们进行的项目已经可以在网页上显示查询到数据库中的数据&#xff0c;今天的笔记中将会完成在网页上进行增删改查的操作 1.删除表中数据 现在网页上只能呈现出数据库中的数据&#xff0c;我们首先添加一个删除按钮&#xff0c;使其可以对数据库数据…...

微信小程序教程011:全局配置:Window

文章目录 1、window1.1、`window`-小程序窗口的组成部分1.2、了解 window 节点常用的配置项1.3、设置导航栏的标题1.4、设置导航栏的背景色1.5、设置导航栏的标题颜色1.6、全局开启下拉刷新功能1.7、设置下拉刷新时窗口的背景色1.8、设置下拉刷新时 loading 的样式1.9、设置上拉…...

Tomcat服务器和Web项目的部署

目录 一、概述和作用 二、安装 1.进入官网 2.Download下面选择想要下载的版本 3.点击Which version查看版本所需要的JRE版本 4.返回上一页下载和电脑和操作系统匹配的Tomcat 5. 安装完成后&#xff0c;点击bin目录下的startup.bat&#xff08;linux系统下就运行startup.sh&…...

PCIe学习笔记(22)

Transaction Ordering Transaction Ordering Rules 表2-40定义了PCI Express Transactions的排序要求。该表中定义的规则统一适用于PCI Express上所有类型的事务&#xff0c;包括内存、I/O、配置和消息。该表中定义的排序规则适用于单个流量类(TC)。不同TC标签的事务之间没有…...

Vue3 依赖注入Provide / Inject

在实际开发中&#xff0c;我们经常需要从父组件向子组件传递数据&#xff0c;一般情况下&#xff0c;我们使用 props。但有时候会遇到深度嵌套的组件&#xff0c;而深层的子组件只需要父组件的部分内容。在这种情况下&#xff0c;如果仍然将 prop 沿着组件链逐级传递下去&#…...

Python | Leetcode Python题解之第332题重新安排行程

题目&#xff1a; 题解&#xff1a; class Solution:def findItinerary(self, tickets: List[List[str]]) -> List[str]:def dfs(curr: str):while vec[curr]:tmp heapq.heappop(vec[curr])dfs(tmp)stack.append(curr)vec collections.defaultdict(list)for depart, arri…...

React状态管理:react-redux和redux-saga(适合由vue转到react的同学)

注意&#xff1a;本文不会把所有知识点都写一遍&#xff0c;并不适合纯新手阅读 首先Redux是一种状态管理方案&#xff0c;本身和react并没有什么联系&#xff0c;redux也可以结合其他框架来用。 react-redux是基于react的一种状态管理实现&#xff0c;他不像vuex那样直接内置在…...

刷题技巧:双指针法的核心思想总结+例题整合+力扣接雨水双指针c++实现

双指针法的核心思想是通过同时操作两个指针来遍历数据结构&#xff0c;通常是数组或链表&#xff0c;以达到优化算法性能的目的。具体来说&#xff0c;双指针法能够减少时间复杂度、空间复杂度&#xff0c;或者简化逻辑结构。以下是双指针法的几个核心思想&#xff1a; ps 下面…...

什么是前端微服务,有何优势

随着互联网技术的发展&#xff0c;传统的单体应用架构已经无法满足复杂业务场景的需求。微服务架构的兴起为后端应用的开发和部署提供了灵活性和可扩展性。与此同时&#xff0c;前端开发也经历了类似的演变&#xff0c;前端微服务作为一种新兴的架构模式应运而生。 一、前端微服…...

小论文写作——02:编故事

一篇论文&#xff0c;可以发水刊&#xff0c;也可以发顶刊顶会&#xff0c;这两者的区别就是一个故事编的好不好。 你的论文ABC&#xff0c;但不能之说有ABC。创新就是看你故事编的怎么样&#xff1f;创新是编出来的。 我们要说&#xff1a;我发现了问题&#xff0c;然后准备…...

GIT企业开发使用介绍

0.认识git git就是一个版本控制器&#xff0c;记录每次的修改以及版本迭代的一个管理系统 至于为什么会有git的出现&#xff0c;主要是为了解决一份代码改了又改&#xff0c;但最后还是要第一版的情况 git 可以控制电脑上所有格式的文档 1.安装git sudo yum install git -y…...

文件上传-前端验证

查看源代码&#xff08;找验证代码&#xff09; 1、源代码直接找到验证代码 示例&#xff1a; function checkFileExt(filename){var flag false; //状态var arr ["jpg","png","gif"]; //允许上传的文件//取出上传文件的扩展名var index f…...

ROT加密算法login-RESERVE

ROT算法(字母轮换加密) 也称为Caesar加密&#xff0c;是一种简单的字母替换加密算法。它通过将字母表中的每个字母向后&#xff08;或向前&#xff09;移动固定的位置来加密文本。 加密步骤&#xff1a; 选择一个固定的偏移量&#xff08;通常是1到25之间的整数&#xff09;&…...

C++ 新特性 | C++20 常用新特性介绍

目录 1、模块(Modules) 2、协程(Coroutines) 3、概念(Concepts) 4、范围(Ranges) 5、三向比较符&#xff08;three-way comparison&#xff09; C软件异常排查从入门到精通系列教程&#xff08;专栏文章列表&#xff0c;欢迎订阅&#xff0c;持续更新...&#xff09;https…...

Java设计模式之策略模式实践

1、策略接口 /*** 策略接口*/ public interface DemoStrategy {Result execute(); } 2、策略工厂 /*** 策略工厂*/ Component public class DemoFactory {Resourceprivate final Map<String, DemoStrategy> demoStrategy new ConcurrentHashMap<>();public Demo…...

C语言——结构体数组、结构体指针、结构体函数与二级指针

C语言中的结构体&#xff08;struct&#xff09;是一种用户自定义的数据类型&#xff0c;它允许你将不同类型的数据项组合成一个单一的类型。结构体数组则是一种特殊的数组&#xff0c;其元素为结构体类型。这意味着你可以在一个数组中存储多个具有相同结构的记录。 定义结构体…...

【4】策略模式

如上图所示&#xff0c;如果要加入一个新的货币&#xff0c;那么就需要对类中的Calculate函数进行修改&#xff0c;这违背了封闭开放原则。 上图中的方式更加合适&#xff0c;搞一个抽象类&#xff08;方法中可以用多态调用&#xff09;&#xff0c;然后每个货币自己是一个类&a…...

BGP 反射器联邦实验

要求&#xff1a; 1.如图连接网络&#xff0c;合理规划IP地址&#xff0c;AS 200内IGP协议为OSPF 2.R1属于AS 100&#xff1b;R2-R3-R4小AS 234 R5-R6-R7小AS 567&#xff0c;同时声明大AS 200&#xff0c;R8属于AS 300 3.R2-R5 R4-R7 之间为联邦EBGP邻居关系 4.R1-R8之…...

stm32入门学习13-时钟RTC

&#xff08;一&#xff09;时钟RTC stm32内部集成了一个秒计数器RTC&#xff0c;用于显示我们日常的时间&#xff0c;如日期年月日&#xff0c;时分秒等&#xff0c;RTC的主要原理就是进行每秒自增&#xff0c;如果我们知道开始记秒的开始时间&#xff0c;就可以计算现在的日…...

vuex properties of undefined (reading ‘getters‘)

前言&#xff1a; 最近打算用vue 写个音乐播放器&#xff0c;在搞 vuex 的时候遇到一个很神奇报错&#xff1b;vuex 姿势练了千百次了&#xff0c;刚开始的时候我一直以为是代码问题&#xff0c;反复检查了带了&#xff0c;依旧报错。 Error in mounted hook: "TypeError:…...

再谈表的约束

文章目录 自增长唯一键外键 自增长 auto_increment&#xff1a;当对应的字段&#xff0c;不给值&#xff0c;会自动的被系统触发&#xff0c;系统会从当前字段中已经有的最大值1操作&#xff0c;得到一个新的不同的值。通常和主键搭配使用&#xff0c;作为逻辑主键。 自增长的…...

手机网站建设推广/杭州关键词排名工具

下面是我最近总结的一点点东西而已&#xff0c;以后还会更多1、.时间linux系统在时间上有比较多的东西。在游戏里&#xff0c;时间是一个非常重要的一个变量&#xff0c;涉及到前后端时间同步&#xff0c;游戏业务的倒计时&#xff0c;心跳等等的一系列功能点等等&#xff0c;如…...

怎样把自己的网站做推广/企业网站代运营

前置知识 #define pi acos(-1.0) 是因为 acos为cos的反函数 cos&#xff08;pi&#xff09;-1 使用三角函数都要换为弧度制&#xff0c;角度制*pi/180弧度制 C1. Simple Polygon Embedding 题目大意 给定一个边长为 1 的正 2n 边形&#xff0c;求外接正方形的最小面积,n为…...

西安政府做网站/seo官网

福利来袭>>>团圆佳节&#xff0c;你送祝福我送福利&#xff01;9月13日晚上十点之前在下方评论区留言&#xff0c;说出你的中秋祝福or小长假安排&#xff0c;就有机会获得爱奇艺VIP卡&#xff01;&#xff01;赶快行动吧&#xff01;end往期精选1.爱奇艺ZoomAI技术 助…...

亿唐网不做网站做品牌考试题/河南今日头条新闻

2021牛客暑期多校训练营3 Kuriyama Mirai and Exclusive Or 题目链接 题意 给定一个长度为n的数组a。 有q次操作&#xff0c;每次操作有两种类型&#xff1a; 给定一个区间[al,ar][a_l,a_r][al​,ar​]​​&#xff0c;对​​区间内的数ai⊕x,i∈[l,r]a_i\oplus x,i\in […...

地方门户网站建设要求/怎么样才能引流客人进店

单元格的高度自适应原理就是通过内部label的高度变化来增加和减少单元格的高度。 - (UILabel *)label { if(_label nil) { _label [[UILabel alloc] init]; [self.contentView addSubview:_label]; _label.numberOfLines 0; [_label mas_makeConstraints:^(MASConstraintMak…...

整个网站都在下雪特效怎么做/今日武汉最新消息

2019独角兽企业重金招聘Python工程师标准>>> 一、什么是webpack? 他有什么优点&#xff1f; 首先对于很多刚接触webpack人来说&#xff0c;肯定会问webpack是什么&#xff1f;它有什么优点&#xff1f;我们为什么要使用它&#xff1f;带着这些问题&#xff0c;我们…...