当前位置: 首页 > news >正文

pytorch学习(7)——神经网络优化器torch.optim

1 optim 优化器

PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。
学习链接:

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述

1.1 优化器基类

使用时必须构造一个优化器对象,它将保存当前状态,并将根据计算的梯度(grad)更新参数。
调用优化器的step方法。

CLASS torch.optim.Optimizer(params, defaults)

  • Optimizer - 优化器的优化算法。
  • params (iterable) – torch的迭代器。张量s或dict s,指定应该优化什么张量。
  • defaults – (dict): 包含优化选项默认值的字典(在参数组没有指定优化选项时使用)。每个Optimizer算法都有其独特的设置字典。
算法(Optimizer)说明
Adadelta采用Adadelta算法。
Adagrad采用Adagrad算法。
Adam采用Adam算法。
AdamW采用AdamW算法。
SparseAdam采用适合稀疏张量的Adam算法的惰性版本。
Adamax采用Adamax算法(Adam基于无穷范数的变种)。
ASGD采用平均随机梯度下降。
LBFGS采用L-BFGS算法,深受minFunc的启发。
NAdam采用NAdam 算法。
RAdam采用RAdam 算法。
RMSprop采用RMSprop 算法。
Rprop采用有弹性的反向传播算法。
SGD采用随机梯度下降算法。

1.1.1 SGD 随机梯度下降算法

CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

  • params (iterable) – iterable参数优化或字典定义参数组。
  • lr (float) – 学习率,需要用户输入。
  • momentum (float, optional) – 动量系数(默认值为0)。
  • weight_decay (float, optional) – 权重衰减(L2惩罚) (默认值为0)
  • dampening (float, optional) – 动量阻尼(默认值为0)。
  • nesterov (bool, optional) – 使能Nesterov动量(默认值为False)。
    【Nesterov动量(Nesterov Momentum)是一种基于动量法的优化算法,用于加速神经网络的训练过程。它在随机梯度下降(SGD)的基础上进行改进,通过考虑参数更新前的动量信息来调整参数更新的方向。】
  • maximize (bool, optional) – 根据目标最大化参数,而不是最小化参数(默认值为False)。
  • foreach (bool, optional) – 是否使用foreach优化器的实现。如果用户未指定(foreach为None),我们将尝试在CUDA上的for循环实现上使用foreach,因为CUDA通常性能更高(默认值:None)。
  • differentiable (bool, optional) – 是否在训练中的优化器步骤中发生autograd。否则,step()函数在torch.no_grad()上下文中运行。设置为True会影响性能,所以如果你不打算通过这个实例运行autograd,请保留False(默认值为False)。

学习速率(lr)的取值,如果太大,则模型很不稳定;如果太小,学习速度非常缓慢。因此一般先设置较大的学习速率,然后降低学习速率。

python代码如下:

import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01)        # 优化器for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs)          # 神经网络输出# print(targets)          # 目标# print(result_loss)      # 损失函数-交叉熵计算结果opitm.zero_grad()       # 梯度清零,设置断点result_loss.backward()  # 反向传播,求出每个节点的梯度,设置断点opitm.step()            # 对神经网络模型的参数进行调优,设置断点

设置断点,进入程序Debug:
(1)不断运行程序,能够观察到卷积层0的bias梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> grad
在这里插入图片描述

(2)能观察到卷积层0的weight梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight -> grad
在这里插入图片描述

(3)能观察到bias的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> data
在这里插入图片描述

(4)能观察到weight的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight-> data
在这里插入图片描述

(5)结论:运行opitm.zero_grad()后,清空weight和bias的梯度grad;运行result_loss.backward()后,计算得到新的weight和bias的梯度grad;运行opitm.step()后,调整weight和bias的值。

1.1.2 优化器多次循环

修改以上python代码,增加多次循环,观察总体损失值改变。

import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01)        # 优化器for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs)          # 神经网络输出# print(targets)          # 目标# print(result_loss)      # 损失函数-交叉熵计算结果opitm.zero_grad()       # 梯度清零result_loss.backward()  # 反向传播,求出每个节点的梯度opitm.step()            # 对神经网络模型的参数进行调优running_loss = running_loss + result_loss#.dataprint(running_loss)

运行结果:

tensor(18746.2012, grad_fn=<AddBackward0>)
tensor(16136.0107, grad_fn=<AddBackward0>)
tensor(15499.3203, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

可以发现running_loss在一开始不断降低,但是以下的nan暂时不知道是什么原因。

相关文章:

pytorch学习(7)——神经网络优化器torch.optim

1 optim 优化器 PyTorch神经网络优化器&#xff08;optimizer&#xff09;通过调整神经网络的参数&#xff08;weight和bias&#xff09;来最小化损失函数&#xff08;Loss&#xff09;。 学习链接&#xff1a; https://pytorch.org/docs/stable/optim.html 1.1 优化器基类 使…...

leetcode做题笔记​101. 对称二叉树

给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 思路一&#xff1a;递归 bool isSymmetric(struct TreeNode* root){if (root NULL) return true;return fun(root->left, root->right); }int fun(struct TreeNode* l_root, struct TreeNode* r_root) {…...

边缘计算相关概念--学习笔记

一.边缘计算概念 边缘计算将数据的处理&#xff0c;应用程序的运行甚至一些功能服务的实现&#xff0c;由网络中心下放到网络边缘的节点上&#xff0c;在网络边缘侧的智能网关上就近采集并且处理数据&#xff0c;不需要将大量未处理的数据上传到远程的大数据平台。边缘计算理论…...

flutter windows编译错误 flutter_assemble.vcxproj

flutter 编译windows是出现错误。 [ 44 ms] d:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Microsoft\VC\v170\Microsoft.CppCommon.targets(248,5): error MSB8066: &#xfffd;&#xfffd;E:\work\kkview_kuaichuan\kkview_kuaichuan\build\windows\C…...

通过运行中的容器生成 Docker Compose 配置文件

背景 笔者之前有一次不小心删除了原始的 docker-compose.yml 文件&#xff0c;不过正在运行的 Docker 容器还在&#xff0c;找了许久&#xff0c;发现一个方法可以从这些容器中生成一个等效的 Docker Compose 配置文件。本文将介绍使用 autocompose 工具从正在运行的容器中反向…...

rancher界面无法登陆问题解决,登录超时;

1.找到rancher主机&#xff0c;查看日志 docker ps | grep rancher # rancher 容器 名称 jolly_ptolemy docker logs -f jolly_ptolemy 日志提示&#xff0c; java.sql.SQLException: Got error 28 from storage engine&#xff0c;磁盘满了 2.磁盘管理 df -h #查看磁盘使…...

Django(6)-django项目自动化测试

Django 应用的测试应该写在应用的 tests.py 文件里。测试系统会自动的在所有以 tests 开头的文件里寻找并执行测试代码。 我们的 polls 应用现在有一个小 bug 需要被修复&#xff1a;我们的要求是如果 Question 是在一天之内发布的&#xff0c; Question.was_published_recentl…...

【AUTOSAR】【CAN通信】CanNm

目录 一、概述 二、说明 三、功能说明 3.1 协调算法 3.2 操作模式 3.2.1 网络模式...

拼多多淘宝大量缓存商品数据用什么格式提供比较好?

众所周知&#xff0c;淘宝拼多多是我国主流的电商平台&#xff0c;其上有大量的商品数据。很多商家会通过API来访问他们的商品数据&#xff0c;根据API的调用次数收费。第三方数据公司提供电商数据接口API&#xff0c;采集实时数据。但是&#xff0c;在他们的服务器上有大量的缓…...

【校招VIP】前端校招考点之页面转换算法

考点介绍&#xff1a; 在地址映射过程中&#xff0c;若在页面中发现所要访问的页面不在内存中&#xff0c;则产生缺页中断。当发生缺页中断时&#xff0c;如果操作系统内存中没有空闲页面&#xff0c;则操作系统必须在内存选择一个页面将其移出内存&#xff0c;以便为即将调入的…...

android 下载网络文件

工具类 import android.app.ProgressDialog; import android.content.Context; import android.os.AsyncTask; import android.os.Environment; import android.util.Log;import java.io.BufferedInputStream; import java.io.File; import java.io.FileOutputStream; import …...

springboot定时任务:同时使用定时任务和websocket报错

背景 项目使用了websocket,实现了消息的实时推送。后来项目需要一个定时任务&#xff0c;使用org.springframework.scheduling.annotation的EnableScheduling注解来实现&#xff0c;启动项目之后报错 Bean com.alibaba.cloud.sentinel.custom.SentinelAutoConfiguration of t…...

CSS3渐变及2D转换

CSS3渐变及2D转换 持续更新哦… 1、css3渐变 概念: CSS3渐变(gradient)可以让你在两个或多个指定的颜色之间显示平 稳的过渡。以前&#xff0c;你必须使用图像来实现这些效果&#xff0c;现在通过使用 CSS3的渐变(gradients)即可实现。此外&#xff0c;渐变效果的元素在放大…...

无涯教程-PHP - eregi()函数

eregi() - 语法 int eregi(string pattern, string string, [array regs]); eregi()函数在pattern指定的整个字符串中搜索string指定的字符串,。搜索不区分大小写。 Eregi()在检查字符串的有效性时特别有用。 可选的输入参数regs包含一个由正则表达式中的括号分组的所有匹配…...

Spring与Mybatis整合aop整合pageHelper分页插件

前言 Spring与MyBatis整合的意义在于提供了一种结合优势的方式&#xff0c;以便更好地开发和管理持久层&#xff08;数据库访问&#xff09;代码。 这里也是总结了几点主要意义 简化配置&#xff1a;Spring与MyBatis整合后&#xff0c;可以通过Spring的配置文件来管理和配置M…...

SSL/CA 证书及其相关证书文件(pem、crt、cer、key、csr)

数字证书是网络世界中的身份证&#xff0c;数字证书为实现双方安全通信提供了电子认证。数字证书中含有密钥对所有者的识别信息&#xff0c;通过验证识别信息的真伪实现对证书持有者身份的认证。数字证书可以在网络世界中为互不见面的用户建立安全可靠的信任关系&#xff0c;这…...

【JavaSE】内部类

文章目录 内部类概念局部内部类匿名内部类&#xff08;重点重点&#xff01;&#xff01;&#xff01; &#xff09;成员内部类静态内部类 内部类概念 可以将一个类定义在另一个类或者一个方法的内部&#xff0c;前者称为内部类&#xff0c;后者称为外部类。内部类也是封装的一…...

Django(2)-编写你的第一个 Django 应用

本教程的目的是创建一个网络投票应用程序。 它将由两部分组成&#xff1a; 一个让人们查看和投票的公共站点。 一个让你能添加、修改和删除投票的管理站点。 创建应用 $ python manage.py startapp polls每一个应用是一个python包&#xff0c;一个项目可以包含多个应用。 …...

燃气管网监测系统,24小时守护燃气安全

随着社会的发展和人民生活水平的提高&#xff0c;燃气逐渐成为人们日常生活和工作中不可或缺的一部分。然而&#xff0c;近年来&#xff0c;屡屡发生的燃气爆炸问题&#xff0c;也让人们不禁对燃气的安全性产生了担忧。因此&#xff0c;建立一个高效、实时、准确的燃气管网监测…...

昌硕科技、世硕电子同步上线法大大电子合同

近日&#xff0c;世界500强企业和硕联合旗下上海昌硕科技有限公司&#xff08;以下简称“昌硕科技”&#xff09;、世硕电子&#xff08;昆山&#xff09;有限公司&#xff08;以下简称“世硕电子”&#xff09;的电子签项目正式上线。上线仪式在上海浦东和硕集团科研大楼举行&…...

es的索引管理

概念 &#xff08;1&#xff09;集群&#xff08;Cluster&#xff09;&#xff1a; ES可以作为一个独立的单个搜索服务器。不过&#xff0c;为了处理大型数据集&#xff0c;实现容错和高可用性&#xff0c;ES可以运行在许多互相合作的服务器上。这些服务器的集合称为集群。 &…...

Rust 的四大类型的宏 (元编程)

文章目录 概念函数宏或声明宏&#xff08;Function Macro&#xff09;过程宏&#xff08;Procedural Macro&#xff09;类函数的过程宏&#xff08;Function-like-procedural-macros&#xff09;派生宏&#xff08;Derive Macro&#xff09;派生宏附加其他属性 属性宏&#xff…...

探索数据湖中的巨兽:Apache Hive分布式SQL计算平台浅度剖析!

文章目录 ◆ Apache Hive 概述1.1 分布式SQL计算1.2 Hive的优势 ◆ 模拟实现Hive功能2.1 元数据管理2.2 解析器2.3 基础架构2.4 Hive架构 ◆ Hive基础架构3.1 Hive架构图3.2 Hive组件3.2.1 元数据存储3.2.2 Driver驱动程序3.2.3 用户接口 ◆ Hive部署4.1 VMware虚拟机部署步骤一…...

Node.js 的 Buffer 是什么?一站式了解指南

在 Node.js 中&#xff0c;Buffer 是一种用于处理二进制数据的机制。它允许你在不经过 JavaScript 垃圾回收机制的情况下直接操作原始内存&#xff0c;从而更高效地处理数据&#xff0c;特别是在处理网络流、文件系统操作和其他与 I/O 相关的任务时。Buffer 是一个全局对象&…...

延时盲注技术:SQL 注入漏洞检测入门指南

部分数据来源:ChatGPT 引言 在网络安全领域中,SQL 注入漏洞一直是常见的安全隐患之一。它可以利用应用程序对用户输入的不恰当处理,导致攻击者能够执行恶意的 SQL 查询语句,进而获取、修改或删除数据库中的数据。为了帮助初学者更好地理解和检测 SQL 注入漏洞,本文将介绍…...

【Midjourney电商与平面设计实战】创作效率提升300%

不得不说&#xff0c;最近智能AI的话题火爆圈内外啦。这不&#xff0c;战火已经从IT行业燃烧到设计行业里了。 刚研究完ChatGPT&#xff0c;现在又出来一个AI作图Midjourney。 其视觉效果令不少网友感叹&#xff1a;“AI已经不逊于人类画师了!” 现如今&#xff0c;在AIGC 热…...

URI、URL、URIBuilder、UriBuilder、UriComponentsBuilder说明及基本使用

之前想过直接获取url通过拼接字符串的方式实现,但是这种只是暂时的,后续地址如果有变化或参数很多,去岂不是要拼接很长,由于这些等等原因,所以找了一些方法实现 java.net.URI URI全称是Uniform Resource Identifier,也就是统一资源标识符,它是一种采用特定的语法标识一…...

抓包 - 简要总结 - Windows和Android抓包

抓包 - 简要总结 - Windows和Android抓包 前言 小巧且强大的抓包工具“Fiddler”安装可参考我的另一篇博客&#xff1a;抓包 - 经典抓包工具Fiddler的安装与初使用 本文主要介绍如何使用Fiddler抓包Windows和安卓。 Windows 抓包Windows很简单&#xff0c;安装证书&#x…...

iOS脱壳技术(二):深入探讨dumpdecrypted工具的高级使用方法

前言 应用程序脱壳是指从iOS应用程序中提取其未加密的二进制可执行文件&#xff0c;通常是Mach-O格式。这可以帮助我们深入研究应用程序的底层代码、算法、逻辑以及数据结构。这在逆向工程、性能优化、安全性分析等方面都有着重要的应用。 在上一篇内容中我们已经介绍了Clutc…...

4.RabbitMQ高级特性 幂等 可靠消息 等等

一、如何保证生产者生产消息100%的投递成功 保障消息的成功发出保障MQ节点的成功接收发送端收到MQ节点&#xff08;Broker&#xff09;确认应答完善的消息进行补偿机制 1. 理解Confirm确认消息机制 消息的确认&#xff0c;是指生产者投递消息后&#xff0c;如果Broker收到消…...

编辑网站内容怎么做滚动图片/多地优化完善疫情防控措施

一、Oracle--创建表create table test (id varchar2(200) primary key not null,sort number,name varchar(200))--字段加注释comment on column test.id is ‘id‘;comment on column test.sort is ‘序号‘;--表加注释comment on table test is ‘测试表‘二.Mysql--创建表cr…...

企业网站的建设论文/给我免费的视频在线观看

实验结论 一、实验一&#xff1a;成在屏幕上输出内存单元中的十进制两位数 1、源代码 1 ; 在屏幕上输出内存单元中的十进制两位数2 assume cs:code, ds:data3 data segment4 db 125 db ?,? ; 前一个字节用于保存商&#xff0c;后一个字节用于保存余数6 data ends…...

淘宝图片做链接的网站/百度竞价排名是以什么形式来计费的广告?

1、摘要 本文主要讲解&#xff1a;cnn-bigru-attention对电池寿命进行预测 主要思路&#xff1a; 建立cnn-bigru-attention模型读取数据&#xff0c;将数据截成时序块训练模型&#xff0c;调参&#xff0c;评估模型&#xff0c;保存模型 2、数据介绍 电池寿命数据 3、相关…...

如何自己做淘宝客网站/创建网站怎么创

为什么80%的码农都做不了架构师&#xff1f;>>> 问题来源: http://www.cnblogs.com/del/archive/2010/01/07/1641084.html#1742127 程序使用了 GDI 的新接口: http://www.cnblogs.com/del/archive/2009/12/11/1621790.html uses GdiPlus;procedure TForm1.Button1C…...

佛山网站建设服务/企业建设网站公司

测试自动化有助于提高开发速度&#xff0c;同时减少成本和工作量。在本文中&#xff0c;将分享如何进行自动化测试&#xff0c;以帮助保持测试自动化活动在正确的轨道上&#xff0c;以及测试执行、设计和维护大型企业应用程序的关键技巧。 选用合适的自动化测试工具 每个自动化…...

浙江省建设厅网站在哪里/关键词智能调词工具

场景分析 使用DeclareParents注解为实现类SoloService和KafkaService添加新方法ICommonService.valid方法。 实战案例 接口层 package com.jaemon.service;public interface IConsensusService {void consensus(); }实现类 package com.jaemon.service.impl;Slf4j Service…...