当前位置: 首页 > news >正文

Pytorch自定义算子反向传播

文章目录

      • 自定义一个线性函数算子
      • 如何实现反向传播

有关 自定义算子的实现前面已经提到,可以参考。本文讲述自定义算子如何前向推理+反向传播进行模型训练。

自定义一个线性函数算子

线性函数 Y = X W T + B Y = XW^T + B Y=XWT+B 定义输入M 个X变量,输出N个Y变量的线性方程组。
X X X 为一个 1 x M 矩阵, W W W为 N x M 矩阵, B B B 为 1xN 矩阵,根据公式,输出 Y Y Y为1xN 矩阵。其中 W 和 B 为算子权重参数,保存在模型中。
在训练时刻,模型输入 X X X , 和监督值 Y Y Y,根据 算子forward()计算的 Y p Y^p Yp ,计算Loss = criterion( Y Y Y, Y p Y^p Yp ),然后根据backward()链式求导反向传播计算梯度值。最后根据梯度更新W 和 B 参数。

class LinearF(torch.autograd.Function):@staticmethoddef symbolic(g, input, weight, bias):return g.op("MYLINEAR", input, weight, bias)@staticmethoddef forward(ctx, input:Tensor, weight: Tensor, bias: Tensor) -> Tensor:output = input @ weight.T + bias[None, ...]ctx.save_for_backward(input, weight)return output@staticmethoddef backward(ctx, grad_output:Tensor)->Tuple[Tensor, Tensor, Tensor]:# grad_output -- [B, N] = d(Loss) / d(Y)input, weight = ctx.saved_tensorsgrad_input = grad_output @ weightgrad_weight = grad_output.T @ inputgrad_bias = grad_output.sum(0)# print("grad_input: ", grad_input)# print("grad_weight: ", grad_weight)# print("grad_bias: ", grad_bias)return grad_input, grad_weight, grad_bias

如何实现反向传播

在这里插入图片描述
前向推理比较简单,就根据公式来既可以。反向传播backward() 怎么写呢?
反向传播有两个输入参数,第一个为ctx,第二个grad_output,grad_output就是对forward() 输出output 的求导,如果是最后的节点,那就是loss对输出的求导,否则就是下一层对输出求导,输出grad_input, grad_weight, grad_bias则分别对应了forward的输入input、weight、bias的梯度。这很容易理解,因为是在做链式求导,LinearFunction是这条链上的某个节点,输入输出的数量和含义刚好相反且对应。
根据公式:
Y = X W T + B Y = XW^T + B Y=XWT+B
Loss = criterion( Y t Y^t_{} Yt, Y Y_{} Y ), 假设我们选择判别函数为L2范数,Loss = ∑ j = 0 N 0.5 ∗ ( Y j t − Y j ) 2 \sum_{j=0}^N0.5 * (Y^t_{j}-Y_{j} )^2 j=0N0.5(YjtYj)2

grad_output(j) = d ( L o s s ) d ( Y j ) \frac{d(Loss) }{d(Y_{j})} d(Yj)d(Loss) = Y j t − Y j Y^t_{j} - Y_{j} YjtYj

其中 Y j t Y^t_{j} Yjt为监督值, Y j Y_{j} Yj为模型输出值。

根据链式求导法则, 对输入 X i X_{i} Xi 的求导为

grad_input[i] = ∑ j = 0 N d ( L o s s ) d ( Y j ) ∗ d ( Y j ) d ( X i ) \sum_{j=0}^N\frac{d(Loss) }{d(Y_{j})}*\frac{d(Y_{j}) }{d(X_{i})} j=0Nd(Yj)d(Loss)d(Xi)d(Yj)= ∑ j = 0 N g r a d _ o u t p u t [ j ] ∗ d ( Y j ) d ( X i ) \sum_{j=0}^N{grad\_output}[j] *\frac{d(Y_{j}) }{d(X_{i})} j=0Ngrad_output[j]d(Xi)d(Yj)

d ( Y j ) d ( X i ) \frac{d(Y_{j}) }{d(X_{i})} d(Xi)d(Yj) 即为 W i j T = W j i W^T_{ij} = W_{ji} WijT=Wji

其中i 对应X维度, j对应输出Y维度。

最后整理成矩阵形式:

g r a d _ i n p u t = g r a d _ o u t p u t ∗ W {grad\_input}={grad\_output} * W grad_input=grad_outputW

同理:
g r a d _ w e i g h t = g r a d _ o u t p u t T ∗ X {grad\_weight}={grad\_output}^T * X grad_weight=grad_outputTX

g r a d _ b i a s = ∑ q = 0 N g r a d _ o u t p u t {grad\_bias}=\sum_{q=0}^N{grad\_output} grad_bias=q=0Ngrad_output

最后根据公式形式得到backward()函数。

反向传播的梯度求解还是不容易的,一不小心可能算错了,所以务必在模型训练以前检查梯度计算的正确性。pytorch提供了torch.autograd.gradcheck方法来检验梯度计算的正确性。

其他参考文献:pytorch自定义算子实现详解及反向传播梯度推导

最后根据自定义算子,搭建模型,训练模型参数W,B。并导出onnx。参考代码如下:

import torch
from torch import Tensor
from typing import Tuple
import numpy as np
class LinearF(torch.autograd.Function):@staticmethoddef symbolic(g, input, weight, bias):return g.op("MYLINEAR", input, weight, bias)@staticmethoddef forward(ctx, input:Tensor, weight: Tensor, bias: Tensor) -> Tensor:output = input @ weight.T + bias[None, ...]ctx.save_for_backward(input, weight)return output@staticmethoddef backward(ctx, grad_output:Tensor)->Tuple[Tensor, Tensor, Tensor]:print("grad_output: ", grad_output)# grad_output -- [B, N] = d(Loss) / d(Y)input, weight = ctx.saved_tensorsgrad_input = grad_output @ weightgrad_weight = grad_output.T @ inputgrad_bias = grad_output.sum(0)return grad_input, grad_weight, grad_bias#对LinearFunction进行封装
class MyLinear(torch.nn.Module):def __init__(self, in_features: int, out_features: int, dtype:torch.dtype) -> None:super().__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = torch.nn.Parameter(torch.empty((out_features, in_features), dtype=dtype))self.bias = torch.nn.Parameter(torch.empty((out_features,), dtype=dtype))self.reset_parameters()# self.weight = torch.nn.Parameter(torch.Tensor([2.0, 3.0]))# self.bias = torch.nn.Parameter(torch.Tensor([4.0]))#y = 2 * x1 + 3 * x2 + 4def reset_parameters(self) -> None:torch.nn.init.uniform_(self.weight)torch.nn.init.uniform_(self.bias)def forward(self, input: Tensor) -> Tensor:# for name, pa in self.named_parameters():#     print(name, pa)return LinearF.apply(input, self.weight, self.bias)  # 在此处使用if __name__ == "__main__":device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')print(device.type)model = MyLinear(2, 1, dtype=torch.float64).to(device)# torch.Tensor 默认类型为float32,使用gpu时,输入数据类型与W权重类型一致,否则报错# torch.Tensor([3.0, 2.0].double() 转换为float64#input = torch.Tensor([3.0, 2.0], ).requires_grad_(True).unsqueeze(0).double()#input = input.to(device)#assert torch.autograd.gradcheck(model, input)import torch.optim as optim#定义优化策略和判别函数optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0.9, weight_decay=5e-4)criterion = torch.nn.MSELoss()for epoch in range(300):print("************** epoch: ", epoch , " ************************************* ")inputx = torch.Tensor(np.random.rand(2)).unsqueeze(0).double().to(device)lable = torch.Tensor(2 * inputx[:, 0] + 3 * inputx[:, 1] + 4).double().to(device)print("outlable", lable)optimizer.zero_grad()  # 梯度清零prob = model(inputx)print("prob", prob)loss = criterion(lable, prob)print("loss: ", loss)loss.backward()  #反向传播optimizer.step() #更新参数# 完成训练model.cpu().eval()input = torch.tensor([[3.0, 2.0]], dtype=torch.float64)output = model(input)torch.onnx.export(model,  # 这里的args,是指输入给model的参数,需要传递tuple,因此用括号(input,),"linear.onnx",  # 储存的文件路径verbose=True,  # 打印详细信息input_names=["x"],  #为输入和输出节点指定名称,方便后面查看或者操作output_names=["y"],opset_version=11,  #这里的opset,指,各类算子以何种方式导出,对应于symbolic_opset11dynamic_axes={"image": {0: "batch"},"output": {0: "batch"},},operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)

相关文章:

Pytorch自定义算子反向传播

文章目录 自定义一个线性函数算子如何实现反向传播 有关 自定义算子的实现前面已经提到,可以参考。本文讲述自定义算子如何前向推理反向传播进行模型训练。 自定义一个线性函数算子 线性函数 Y X W T B Y XW^T B YXWTB 定义输入M 个X变量,输出N个…...

aws服务(二)机密数据存储

在AWS(Amazon Web Services)中存储机密数据时,安全性和合规性是最重要的考虑因素。AWS 提供了多个服务和工具,帮助用户确保数据的安全性、机密性以及合规性。以下是一些推荐的存储机密数据的AWS服务和最佳实践: 一、A…...

VMware Workstation 17.6.1

概述 目前 VMware Workstation Pro 发布了最新版 v17.6.1: 本月11号官宣:针对所有人免费提供,包括商业、教育和个人用户。 使用说明 软件安装 获取安装包后,双击默认安装即可: 一路单击下一步按钮: 等待…...

高校企业数据挖掘平台推荐

TipDM数据挖掘建模平台是由广东泰迪智能科技股份有限公司自主研发打造的可视化、一站式、高性能的数据挖掘与人工智能建模服务平台,致力于为使用者打通从数据接入、数据预处理、模型开发训练、模型评估比较、模型应用部署到模型任务调度的全链路。平台内置丰富的机器…...

Vue项目开发 formatData 函数有哪些常用的场景?

formatData 不是 JavaScript 中的内建函数,它通常是一个自定义函数,用来格式化数据。不同的开发环境和框架中可能有不同的 formatData 实现方式。如果你指的是某个特定框架或者库中的 formatData,请提供更多的上下文信息。不过,以…...

【AI知识】两类最主流AI应用(文生图、ChatGPT)中的目标函数

之前写过一篇 【AI知识】了解两类最主流AI任务中的目标函数,介绍了AI最常见的两类任务【分类、回归】的基础损失函数【交叉熵、均方差】,以初步了解AI的训练目标。 本篇更进一步,聊一聊流行的“文生图”、“聊天机器人ChatGPT”模型中的目标函…...

【单片机基础】定时器/计数器的工作原理

单片机中的定时器/计数器(Timer/Counter)是用于时间测量和事件计数的重要模块。它们可以用来生成精确的延时、测量外部信号的频率或周期、捕获外部事件的时间戳等。理解定时器/计数器的工作原理对于单片机编程和系统设计非常重要。以下是定时器/计数器的…...

ModuleNotFoundError: No module named ‘distutils.msvccompiler‘ 报错的解决

报错 在conda 环境安装 numpy 时,出现报错 ModuleNotFoundError: No module named distutils.msvccompiler 解决 Python 版本过高导致的,降低版本到 Python 3.8 conda install python3.8即可解决。...

HCIA笔记2--ARP+ICMP+VRP基础

1. ARP ARP: 地址解析协议(address resolve protocol)。 网络数据包在通信的时候一般是使用 I P IP IP地址进行通信。 但是在封装数据链路层的时候是需要目标 m a c mac mac地址的。 而 A R P ARP ARP协议实现的功能就是根据 I P IP IP地址来获得 m a c mac mac地址。 1.1 a…...

SpringBoot与MongoDB深度整合及应用案例

SpringBoot与MongoDB深度整合及应用案例 在当今快速发展的软件开发领域,NoSQL数据库因其灵活性和可扩展性而变得越来越流行。MongoDB,作为一款领先的NoSQL数据库,以其文档导向的存储模型和强大的查询能力脱颖而出。本文将为您提供一个全方位…...

Redis模拟延时队列 实现日程提醒

使用Redis模拟延时队列 实际上通过MQ实现延时队列更加方便,只是在实际业务中种种原因导致最终选择使用redis作为该业务实现的中间件,顺便记录一下。 该业务是用于日程短信提醒,用户添加日程后,就会被放入redis队列中等待被执行发…...

vue项目中富文本编辑器的实现

文章目录 vue前端实现富文本编辑器的功能需要用到第三方库1. 安装包2.全局引入注册3.组件内使用4.图片缩放功能实现①安装包②注册并添加配置项③报错解决 vue前端实现富文本编辑器的功能需要用到第三方库 vue2使用vue-quill-editor,vue3使用vueup/vue-quill&#…...

nginx 配置lua执行shell脚本

1.需要nginx安装lua_nginx_module模块,这一步安装时,遇到一个坑,nginx执行configure时,一直提示./configure: error: unsupported LuaJIT version; ngx_http_lua_module requires LuaJIT 2.x。 网上一堆方法都试了,都…...

Keil+VSCode优化开发体验

目录 一、引言 二、详细步骤 1、编译器准备 2、安装相应插件 2.1 安装C/C插件 2.2 安装Keil相关插件 3、添加keil环境变量 4、加载keil工程文件 5、VSCode中成功添加工程文件后可能出现的问题 5.1 编码不一致问题 6、在VSCode中进行编译工程以及烧录程序 7、效果展示…...

vue2中引入cesium全步骤

1.npm 下载cesium建议指定版本下载,最新版本有兼容性问题 npm install cesium1.95.0 2.在node_models中找到cesium将此文件下的Cesium文件复制出来放在项目的静态资源public中或者static中,获取去github上去下载zip包放在本地也可以 3.在index.html中引…...

工程师 - 智能家居方案介绍

1. 智能家居硬件方案概述 智能家居硬件方案是实现家庭自动化的重要组件,通过集成各种设备来提升生活的便利性、安全性和效率。这些方案通常结合了物联网技术,为用户提供智能化、自动化的生活体验。硬件方案的选择直接影响到智能家居系统的性能、兼容性、…...

中小企业人事管理:SpringBoot框架高级应用

摘 要 随着科学技术的飞速发展,社会的方方面面、各行各业都在努力与现代的先进技术接轨,通过科技手段来提高自身的优势,中小企业人事管理系统当然也不能排除在外。中小企业人事管理系统是以实际运用为开发背景,运用软件工程原理和…...

嵌入式Linux驱动开发日记

目录 让我们从环境配置开始 目标平台 从Ubuntu开始 从交叉编译器继续 arm-linux-gnueabihf-gcc vscode 没学过ARM汇编 正文开始——速度体验一把 写一个链接脚本 写一个简单的Makefile脚本 使用正点原子的imxdownload下载到自己的SD卡上 更进一步的笔记和说明 从IM…...

迪杰特斯拉算法(Dijkstra‘s)

迪杰斯特拉算法(Dijkstras algorithm)是由荷兰计算机科学家艾兹格迪科斯彻(Edsger W. Dijkstra)在1956年提出的,用于在加权图中找到单个源点到所有其他顶点的最短路径的算法。这个算法广泛应用于网络路由、地图导航等领…...

reids基础

数据结构类型 String setnx //设置key不存在,则添加成功 setex name 10 jack // key 10s失效,自动删除 hash hset hget list 按添加数据排序 lpush //左侧插入 rpush //右侧插入 set 不重复 sadd //添加…...

私有化部署视频平台EasyCVR宇视设备视频平台如何构建视频联网平台及升级视频转码业务?

在当今数字化、网络化的时代背景下,视频监控技术已广泛应用于各行各业,成为保障安全、提升效率的重要工具。然而,面对复杂多变的监控需求和跨区域、网络化的管理挑战,传统的视频监控解决方案往往显得力不从心。 EasyCVR视频融合云…...

SparkContext讲解

SparkContext讲解 什么是 SparkContext? SparkContext 是 Spark 应用程序的入口点,是 Spark 的核心组件之一。每个 Spark 应用程序启动时,都会创建一个 SparkContext 对象,它负责与集群管理器(如 YARN、Mesos 或 Spa…...

MODBUS TCP转CANOpen网关

Modbus TCP转CANopen网关 型号:SG-TCP-COE-210 产品用途 本网关可以实现将CANOpen接口设备连接到MODBUS TCP网络中;并且用户不需要了解具体的CANOpen和Modbus TCP 协议即可实现将CANOpen设备挂载到MODBUS TCP接口的 PLC上,并和CANOpen设备…...

渗透测试---shell(4)脚本与用户交互以及if条件判断

声明:学习素材来自b站up【泷羽Sec】,侵删,若阅读过程中有相关方面的不足,还请指正,本文只做相关技术分享,切莫从事违法等相关行为,本人一律不承担一切后果 目录 一、shell脚本与用户进行交互 使用 read 指…...

02_Spring_IoC实现

接下来先简单说一下关于IoC的一些要点,后面我们再详细一步一步讨论。 一、IoC控制反转 IoC控制反转它是一种思想,不是具体的实现控制反转的目的是为了降低程序的耦合度,提高程序的可扩展性,从而满足OCP原则和DIP原则控制反转,那到底反转是什么东西? 我们不再使用某个对象…...

使用Python3实现Gitee码云自动化发布

仓库信息 https://gitee.com/liumou_site/ip 实现代码 import osimport requests from loguru import loggerdef gitee(ver, message, prerelease: bool False):"""在 Gitee 上创建发布版本:param ver: 版本号:param message: 发布信息:param prerelease: 是…...

Ubuntu24.04下的docker问题

按官网提示是可以安装成功的,但是curl无法使用https下载,会造成下述语句执行失败 # Add Dockers official GPG key: sudo apt-get update sudo apt-get install ca-certificates curl sudo install -m 0755 -d /etc/apt/keyrings sudo curl -fsSL https…...

PAT (Basic Level) Practice (中文)1002 写出这个数

读入一个正整数 n&#xff0c;计算其各位数字之和&#xff0c;用汉语拼音写出和的每一位数字。 #include<bits/stdc.h> using namespace std; string a; int sum0; int f0; int n[10005]; int main(){ cin>>a; int c0; int laa.size(); for(int i…...

C07.L07.STL之映射.应用2.统计数字

题目描述 某次科研调查时得到了 n 个自然数&#xff0c;每个数均不超过 1500000000 (1.5*10^9 )。已知不相同的数不超过 10000 个&#xff0c;现在需要统计这些自然数各自出现的次数&#xff0c;并按照自然数从小到大的顺序输出统计结果。 输入格式 包含 2 行&#xff1a; 第…...

微信小程序组件详解:text 和 rich-text 组件的基本用法

微信小程序组件详解:text 和 rich-text 组件的基本用法 引言 在微信小程序的开发中,文本展示是用户界面设计中不可或缺的一部分。无论是简单的文本信息,还是复杂的富文本内容,text 和 rich-text 组件都能够帮助我们实现这些需求。本文将详细介绍这两个组件的基本用法,包…...

网络工作网站/外贸网络营销推广

前言关于异或值怎么计算代码第二种方式&#xff0c;适合不懂怎么计算&#xff0c;想直接用的代码前言偶然看到有可以解密微信dat的文档&#xff0c;上网查了查&#xff0c;找到了一篇可以用的文章&#xff0c;不过转换过程代码是有问题的&#xff0c;在这里改了下发布上来。提取…...

做金属探测门批发网站/今日国际新闻最新消息事件

菜单是许多应用类型中常见的用户界面组件。要提供熟悉而一致的用户体验&#xff0c;您应使用 Menu API 呈现 Activity 中的用户操作和其他选项。 从 Android 3.0&#xff08;API 级别 11&#xff09;开始&#xff0c;采用 Android 技术的设备不必再提供一个专用“菜单”按钮。…...

网站建设灬金手指下拉/揭阳市seo上词外包

我现在的公司是自己和几个朋友创办的(大学生自主创业).人数不是很多&#xff0c;每个人负责一个版块.我们的经理本来开始也是专注技术的,后来因为公司的发展,他的经理基本上转到管理方面去了,因此在web技术方面就我孤单一个人了.现在觉得前途太迷茫了.本人是从学习asp开始的.然…...

杭州网站建设服务/seo如何优化网站推广

在 Spring Boot 中&#xff0c;我们可以使用多种方式来批量插入数据到数据库。下面介绍几种常用的方案。 使用JdbcTemplate批量插入 JdbcTemplate 是 Spring Framework 提供的一个基于 JDBC 的数据库访问工具&#xff0c;可以方便地进行数据库操作。我们可以使用 JdbcTempla…...

做网站去除视频广告/网络营销的八种方式

原题目及解答见第4周-任务4-设计工资类(Salary)&#xff1a; 【拓展1】使用salary[50]有限制&#xff0c;实际人数少浪费空间&#xff0c;人数多时无法完成任务。程序执行中先输入职工人数&#xff0c;然后利用教材P217所讲的动态分配内存的运算符new&#xff0c;开辟一个大小…...

做网站挣钱不/seo搜索引擎优化招聘

引言&#xff1a; 我在回龙观买了1000平方dp的房子 这种用法正确吗&#xff1f;如果正确&#xff0c;那么现在回龙观的房价是多少1平方dp呢&#xff1f; dp是物理单位吗&#xff1f; dp&#xff0c;全称是Density-independent Pixels, 设备独立像素, 在大多的android开发书籍里…...