optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用
optimizer.zero_grad,loss.backward,optimizer.step
- 用法介绍
- optimizer.zero_grad():
- loss.backward():
- optimizer.step():
用法介绍
这三个函数的作用是将梯度归零(optimizer.zero_grad()),然后反向传播计算得到每个参数的梯度值(loss.backward()),最后通过梯度下降执行一步参数更新(optimizer.step())。
简单的说就是进来一个batch的数据,先将梯度归零,计算一次梯度,更新一次网络。
model = MyModel()
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001, momentum=0.9, weight_decay=1e-4)for epoch in range(1, epochs):for i, (inputs, labels) in enumerate(train_loader):inputs= inputs.to(device=device)labels= labels.to(device=device)# forwardoutput= model(inputs)loss = criterion(output, labels)# backwardoptimizer.zero_grad()loss.backward()# gardient descent or adam stepoptimizer.step()
另外一种:将optimizer.zero_grad() 放在 optimizer.step() 后面,即梯度累加。
- 获取loss:输入图像和标签,通过infer计算得到预测值,计算损失函数;
- loss.backward() 反向传播,计算当前梯度;
- 多次循环步骤1-2,不清空梯度,使梯度累加在已有梯度上;
- 梯度累加了一定次数后,先optimizer.step() 根据累计的梯度更新网络参数,然后optimizer.zero_grad() 清空过往梯度,为下一波梯度累加做准备;
总结来说:梯度累加就是,每次获取1个batch的数据,计算1次梯度,梯度不清空,不断累加,累加一定次数后,根据累加的梯度更新网络参数,然后清空梯度,进行下一次循环。
一定条件下,batchsize越大训练效果越好,梯度累加则实现了batchsize的变相扩大,如果accumulation_steps为8,则batchsize ‘变相’ 扩大了8倍,是我们这种乞丐实验室解决显存受限的一个不错的trick,使用时需要注意,学习率也要适当放大。
参考链接:https://blog.csdn.net/weixin_36670529/article/details/108630740
optimizer.zero_grad():
param_groups:Optimizer类在实例化时会在构造函数中创建一个param_groups列表,列表中有num_groups个长度为6的param_group字典(num_groups取决于你定义optimizer时传入了几组参数),每个param_group包含了 [‘params’, ‘lr’, ‘momentum’, ‘dampening’, ‘weight_decay’, ‘nesterov’] 这6组键值对。
param_group[‘params’]:由传入的模型参数组成的列表,即实例化Optimizer类时传入该group的参数,如果参数没有分组,则为整个模型的参数model.parameters(),每个参数是一个torch.nn.parameter.Parameter对象。
def zero_grad(self):r"""Clears the gradients of all optimized :class:`torch.Tensor` s."""for group in self.param_groups:for p in group['params']:if p.grad is not None:p.grad.detach_()p.grad.zero_()
optimizer.zero_grad()函数会遍历模型的所有参数,通过p.grad.detach_()方法截断反向传播的梯度流,再通过p.grad.zero_()函数将每个参数的梯度值设为0,即上一次的梯度记录被清空。
因为训练的过程通常使用mini-batch方法,所以如果不将梯度清零的话,梯度会与上一个batch的数据相关,因此该函数要写在反向传播和梯度下降之前。
loss.backward():
PyTorch的反向传播(即tensor.backward())是通过autograd包来实现的,autograd包会根据tensor进行过的数学运算来自动计算其对应的梯度。
具体来说,torch.tensor是autograd包的基础类,如果你设置tensor的requires_grads为True,就会开始跟踪这个tensor上面的所有运算,如果你做完运算后使用tensor.backward(),所有的梯度就会自动运算,tensor的梯度将会累加到它的.grad属性里面去。
更具体地说,损失函数loss是由模型的所有权重w经过一系列运算得到的,若某个w的requires_grads为True,则w的所有上层参数(后面层的权重w)的.grad_fn属性中就保存了对应的运算,然后在使用loss.backward()后,会一层层的反向传播计算每个w的梯度值,并保存到该w的.grad属性中。
如果没有进行tensor.backward()的话,梯度值将会是None,因此loss.backward()要写在optimizer.step()之前。
optimizer.step():
以SGD为例,torch.optim.SGD().step()源码如下:
def step(self, closure=None):"""Performs a single optimization step.Arguments:closure (callable, optional): A closure that reevaluates the modeland returns the loss."""loss = Noneif closure is not None:loss = closure()for group in self.param_groups:weight_decay = group['weight_decay']momentum = group['momentum']dampening = group['dampening']nesterov = group['nesterov']for p in group['params']:if p.grad is None:continued_p = p.grad.dataif weight_decay != 0:d_p.add_(weight_decay, p.data)if momentum != 0:param_state = self.state[p]if 'momentum_buffer' not in param_state:buf = param_state['momentum_buffer'] = torch.clone(d_p).detach()else:buf = param_state['momentum_buffer']buf.mul_(momentum).add_(1 - dampening, d_p)if nesterov:d_p = d_p.add(momentum, buf)else:d_p = bufp.data.add_(-group['lr'], d_p)return loss
step()函数的作用是执行一次优化步骤,通过梯度下降法来更新参数的值。因为梯度下降是基于梯度的,所以在执行optimizer.step()函数前应先执行loss.backward()函数来计算梯度。
注意:optimizer只负责通过梯度下降进行优化,而不负责产生梯度,梯度是tensor.backward()方法产生的。
相关文章:
optimizer.zero_grad(), loss.backward(), optimizer.step()的理解及使用
optimizer.zero_grad,loss.backward,optimizer.step用法介绍optimizer.zero_grad():loss.backward():optimizer.step():用法介绍 这三个函数的作用是将梯度归零(optimizer.zero_grad())&#x…...
融资、量产和一栈式布局,这家Tier 1如此备战高阶智驾决赛圈
作者 | Bruce 编辑 | 于婷从早期的ADAS,到高速/城市NOA,智能驾驶的竞争正逐渐升级,这对于车企和供应商的核心技术和产品布局都是一个重要的考验。 部分智驾供应商已经在囤积粮草,响应变化。 2023刚一开年,智能驾驶领域…...
centos7.8安装oralce11g
文章目录环境安装文件准备添加用户操作系统环境配置解压安装问题解决创建用户远程连接为了熟悉rman备份操作,参照大神的博客在centos中安装了一套oracle11g,将安装步骤记录如下环境安装文件准备 这里准备一台centos7.8 虚拟机 配置ip 192.168.18.100 主…...
【蓝桥杯集训·每日一题】AcWing 3956. 截断数组
文章目录一、题目1、原题链接2、题目描述二、解题报告1、思路分析2、时间复杂度3、代码详解三、知识风暴一维前缀和一、题目 1、原题链接 3956. 截断数组 2、题目描述 给定一个长度为 n 的数组 a1,a2,…,an。 现在,要将该数组从中间截断,得到三个非空子…...
万丈高楼平地起:Linux常用命令
目录 系统管理命令 man命令 ls命令 cd命令 useradd命令 passwd命令 free命令 whoami命令 ps命令 date命令 pwd命令 shutdown命令 文件目录管理命令 touch命令 cat命令 mkdir命令 rm命令 cp命令 mv命令 find命令 more指令 less指令 head指令 tail指令 …...
Linux(Linux的连接使用)
连接Linux我们一般使用CRT或者Xshell工具进行连接使用。 如CRT使用SSH的方式 输出主机,账户,密码那些就可以连接上了。 Linux系统是一个文件型操作系统,有一句话说Linux的一切皆是文件。Linux系统的启动大致有下面几个步骤 Linux系统有7个运…...
Unity中画2D图表(2)——用XChart包绘制散点分布图 + 一条直线方程
散点图用于显示关系。 对于 【相关性】 ,散点图有助于显示两个变量之间线性关系的强度。 对于 【回归】 ,散点图常常会添加拟合线。 举例1:你可以展示【年降雨量】与【玉米亩产量】的关系 举例2:你也可以分析各个【节假日】与【大…...
Go 排序包 sort
写在前面的使用总结: 排序结构体 实现Len,Less,Swap三个函数 package main import ( "fmt" "sort") type StuScore struct { name string score int } type StuScores []StuScore func (s StuScores) Len(…...
Java Email 发HTML邮件工具 采用 freemarker模板引擎渲染
Java Email 发HTML邮件工具 采用 freemarker模板引擎 1.常用方式对比 Java发送邮件有很多的实现方式 第一种:Java 原生发邮件mail.jar和activation.jar <!-- https://mvnrepository.com/artifact/javax.mail/mail --> <dependency><groupId>jav…...
CNI 网络流量分析(六)Calico 介绍与原理(二)
文章目录CNI 网络流量分析(六)Calico 介绍与原理(二)CNIIPAM指定 IP指定非 IPAM IPCNI 网络流量分析(六)Calico 介绍与原理(二) CNI 支持多种 datapath,默认是 linuxDa…...
短视频标题的几种类型和闭坑注意事项
目录 短视频标题的几种类型 1、悬念式 2、蹭热门式 3、干货式 4、对比式方法 5、总分/分总式方法 6、挑战式方式 7、启发激励式 8、讲故事式 02注意事项 1、避免使用冷门、生僻词汇 标题是点睛之笔,核心是视频内容 短视频标题的几种类型 1、悬念式 通过…...
操作系统——1.操作系统的概念、定义和目标
目录 1.概念 1.1 操作系统的种类 1.2电脑的组成 1.3电脑组成的介绍 1.4操作系统的概念(定义) 2.操作系统的功能和目标 2.1概述 2.2 操作系统作为系统资源的管理者 2.3 操作系统作为用户和计算机硬件间的接口 2.3.1用户接口的解释 2.3.2 GUI 2.3.3接…...
【html弹框拖拽和div拖拽功能】原生html页面引入vue语法后通过自定义指令简单实现div和弹框拖拽功能
前言 这是html版本的。只是引用了vue的语法。 这是很多公司会出现的一种情况,就是原生的页面,引入vue的语法开发 这就导致有些vue上很简单的功能。放到这里需要转换一下 以前写过一个vue版本的帖子,现在再加一个html版本的。 另一个vue版本…...
2023新华为OD机试题 - 计算网络信号(JavaScript) | 刷完必过
计算网络信号 题目 网络信号经过传递会逐层衰减,且遇到阻隔物无法直接穿透,在此情况下需要计算某个位置的网络信号值。 注意:网络信号可以绕过阻隔物 array[m][n] 的二维数组代表网格地图,array[i][j] = 0代表 i 行 j 列是空旷位置;array[i][j] = x(x 为正整数)代表 i 行 …...
27.边缘系统的架构
文章目录27 Architecures for the Edge 边缘系统的架构27.1 The Ecosystem of Edge-Dominant Systems 边缘主导系统的生态系统27.2 Changes to the Software Development Life Cycle 软件开发生命周期的变化27.3 Implications for Architecture 对架构的影响27.4 Implications …...
机器学习强基计划8-1:图解主成分分析PCA算法(附Python实现)
目录0 写在前面1 为什么要降维?2 主成分分析原理3 PCA与SVD的联系4 Python实现0 写在前面 机器学习强基计划聚焦深度和广度,加深对机器学习模型的理解与应用。“深”在详细推导算法模型背后的数学原理;“广”在分析多个机器学习模型…...
Hudi-集成Spark之spark-shell 方式
Hudi集成Spark之spark-shell 方式 启动 spark-shell (1)启动命令 #针对Spark 3.2 spark-shell \--conf spark.serializerorg.apache.spark.serializer.KryoSerializer \--conf spark.sql.catalog.spark_catalogorg.apache.spark.sql.hudi.catalog.Hoo…...
Python爬虫:从js逆向了解西瓜视频的下载链接的生成
前言 最近花费了几天时间,想获取西瓜视频这个平台上某个视频的下载链接,运用js逆向进行获取。其实,如果小编一开始就注意到这一点(就是在做js逆向时,打了断点之后,然后执行相关代码,查看相关变量的值,结果一下子就蹦出很多视频相关的数据,查看了网站下的相关api链接,也…...
Numpy-如何对数组进行切割
前言 本文是该专栏的第24篇,后面会持续分享python的数据分析知识,记得关注。 继上篇文章,详细介绍了使用numpy对数组进行叠加。本文再详细来介绍,使用numpy如何对数组进行切割。说句题外话,前面有重点介绍numpy的各个知识点。 感兴趣的同学,可查看笔者之前写的详细内容…...
Python之字符串精讲(下)
前言 今天继续讲解字符串下半部分,内容包括字符串的检索、大小写转换、去除字符串中空格和特殊字符。 一、检索字符串 在Python中,字符串对象提供了很多用于字符串查找的方法,主要给大家介绍以下几种方法。 1. count() 方法 count() 方法…...
Python图像卡通化animegan2-pytorch实例演示
先看下效果图: 左边是原图,右边是处理后的图片,使用的 face_paint_512_v2 模型。 项目获取: animegan2-pytorch 下载解压后 cmd 可进入项目地址的命令界面。 其中 img 是我自己建的,用于存放图片。 需要 torch 版本 …...
谢希仁版《计算机网络》期末总复习【完结】
文章目录说明第一章 计算机网络概述计算机网络和互联网网络边缘网络核心分组交换网的性能网络体系结构控制平面和数据平面第二章 IP地址分类编址子网划分无分类编址特殊用途的IP地址IP地址规划和分配第三章 应用层应用层协议原理万维网【URL / HTML / HTTP】域名系统DNS动态主机…...
问:React的useState和setState到底是同步还是异步呢?
先来思考一个老生常谈的问题,setState是同步还是异步? 再深入思考一下,useState是同步还是异步呢? 我们来写几个 demo 试验一下。 先看 useState 同步和异步情况下,连续执行两个 useState 示例 function Component() {const…...
深度理解机器学习16-门控循环单元
评估简单循环神经网络的缺点。 描述门控循环单元(Gated Recurrent Unit,GRU)的架构。 使用GRU进行情绪分析。 将GRU应用于文本生成。 基本RNN通常由输入层、输出层和几个互连的隐藏层组成。最简单的RNN有一个缺点,那就是它们不…...
Python中Generators教程
要想创建一个iterator,必须实现一个有__iter__()和__next__()方法的类,类要能够跟踪内部状态并且在没有元素返回的时候引发StopIteration异常. 这个过程很繁琐而且违反直觉.Generator能够解决这个问题. python generator是一个简单的创建iterator的途径…...
数据结构与算法基础-学习-10-线性表之栈的清理、销毁、压栈、弹栈
一、函数实现1、ClearSqStack(1)用途清理栈的空间。只需要栈顶指针和栈底指针相等,就说明栈已经清空,后续新入栈的数据可以直接覆盖,不用实际清理数据,提升了清理效率。(2)源码Statu…...
Leetcode 每日一题 1234. 替换子串得到平衡字符串
Halo,这里是Ppeua。平时主要更新C语言,C,数据结构算法......感兴趣就关注我吧!你定不会失望。 🌈个人主页:主页链接 🌈算法专栏:专栏链接 我会一直往里填充内容哒! &…...
【MYSQL中级篇】数据库数据查询学习
🍁博主简介 🏅云计算领域优质创作者 🏅华为云开发者社区专家博主 🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入! 相关文章 文章名文章地址【MYSQL初级篇】入门…...
华为OD机试真题JAVA实现【火星文计算】真题+解题思路+代码(20222023)
🔥系列专栏 华为OD机试(JAVA)真题目录汇总华为OD机试(Python)真题目录汇总华为OD机试(C++)真题目录汇总华为OD机试(JavaScript)真题目录汇总文章目录 🔥系列专栏题目输入输出描述示例一输入输出说明解题思路核心知识点Code运行结果版...
Linux基础知识
♥️作者:小刘在C站 ♥️个人主页:小刘主页 ♥️每天分享云计算网络运维课堂笔记,努力不一定有收获,但一定会有收获加油!一起努力,共赴美好人生! ♥️夕阳下,是最美的绽放࿰…...
贵港市建设局网站/推广之家
一. RMAN 备份的一些优点和OS命令备份方式相比,使用RMAN的优点1 备份执行期间不需要人工干预,因此减少了误操作的机会;2 可以有效的将备份和恢复结合起来;3 支持除逻辑备份以外的所有备份类型,包括完全备…...
网站建设定制设计/品牌推广策划方案案例
iwconfig命令用于系统配置无线网络设备或显示无线网络设备信息。iwconfig命令类似于ifconfig命令,但是他配置对象是无线网卡,它对网络设备进行无线操作,如设置无线通信频段。auto 自动模式essid 设置ESSIDnwid 设置网络IDfreq 设置无线网络通…...
甘肃省第九建设集团网站首页/志鸿优化设计答案
朋友门!等待终于结束了! 我们的社区在短短一周内已经发展到了超过10,000名Twitter粉丝,12,000名Discord成员,以及7,000名Telegram成员!为了感谢您的大力支持,我们将继续努力。为了感谢您的大力支持,我们很高…...
学校网站开发研究的意义和目的/企业营销
本篇文章主要给大家介绍一下如何使用htmlcss实现元素的水平与垂直居中效果,这也是我们网页在编码制作中会经常用到的问题。 1)单行文本的居中 主要实现css代码: 水平居中:text-align:center;垂直居中:line-height:X…...
做网站标题/百度优化点击软件
6174问题 时间限制:1000 ms | 内存限制:65535 KB 难度:2描述 假设你有一个各位数字互不相同的四位数,把所有的数字从大到小排序后得到a,从小到大后得到b,然后用a-b替换原来这个数,并且继续操作。例如,从1…...
个人网站在那建设/短视频剪辑培训班速成
Thread.currentThread():static 当前线程 方法: setName() getName() isAlive() 优先级: 概率 非绝对的优先级 t1.setPriority(Thread.MAX_PRIORITY); t1.getPriority(); * MAX_PRIORITY 10 * NORM_PRIORITY 5 * MIN_PRIORITY 1 /*** 描述:优先级 概率 非绝对的优先级* MA…...