当前位置: 首页 > news >正文

pytorch学习(7)——神经网络优化器torch.optim

1 optim 优化器

PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。
学习链接:

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述

1.1 优化器基类

使用时必须构造一个优化器对象,它将保存当前状态,并将根据计算的梯度(grad)更新参数。
调用优化器的step方法。

CLASS torch.optim.Optimizer(params, defaults)

  • Optimizer - 优化器的优化算法。
  • params (iterable) – torch的迭代器。张量s或dict s,指定应该优化什么张量。
  • defaults – (dict): 包含优化选项默认值的字典(在参数组没有指定优化选项时使用)。每个Optimizer算法都有其独特的设置字典。
算法(Optimizer)说明
Adadelta采用Adadelta算法。
Adagrad采用Adagrad算法。
Adam采用Adam算法。
AdamW采用AdamW算法。
SparseAdam采用适合稀疏张量的Adam算法的惰性版本。
Adamax采用Adamax算法(Adam基于无穷范数的变种)。
ASGD采用平均随机梯度下降。
LBFGS采用L-BFGS算法,深受minFunc的启发。
NAdam采用NAdam 算法。
RAdam采用RAdam 算法。
RMSprop采用RMSprop 算法。
Rprop采用有弹性的反向传播算法。
SGD采用随机梯度下降算法。

1.1.1 SGD 随机梯度下降算法

CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

  • params (iterable) – iterable参数优化或字典定义参数组。
  • lr (float) – 学习率,需要用户输入。
  • momentum (float, optional) – 动量系数(默认值为0)。
  • weight_decay (float, optional) – 权重衰减(L2惩罚) (默认值为0)
  • dampening (float, optional) – 动量阻尼(默认值为0)。
  • nesterov (bool, optional) – 使能Nesterov动量(默认值为False)。
    【Nesterov动量(Nesterov Momentum)是一种基于动量法的优化算法,用于加速神经网络的训练过程。它在随机梯度下降(SGD)的基础上进行改进,通过考虑参数更新前的动量信息来调整参数更新的方向。】
  • maximize (bool, optional) – 根据目标最大化参数,而不是最小化参数(默认值为False)。
  • foreach (bool, optional) – 是否使用foreach优化器的实现。如果用户未指定(foreach为None),我们将尝试在CUDA上的for循环实现上使用foreach,因为CUDA通常性能更高(默认值:None)。
  • differentiable (bool, optional) – 是否在训练中的优化器步骤中发生autograd。否则,step()函数在torch.no_grad()上下文中运行。设置为True会影响性能,所以如果你不打算通过这个实例运行autograd,请保留False(默认值为False)。

学习速率(lr)的取值,如果太大,则模型很不稳定;如果太小,学习速度非常缓慢。因此一般先设置较大的学习速率,然后降低学习速率。

python代码如下:

import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01)        # 优化器for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs)          # 神经网络输出# print(targets)          # 目标# print(result_loss)      # 损失函数-交叉熵计算结果opitm.zero_grad()       # 梯度清零,设置断点result_loss.backward()  # 反向传播,求出每个节点的梯度,设置断点opitm.step()            # 对神经网络模型的参数进行调优,设置断点

设置断点,进入程序Debug:
(1)不断运行程序,能够观察到卷积层0的bias梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> grad
在这里插入图片描述

(2)能观察到卷积层0的weight梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight -> grad
在这里插入图片描述

(3)能观察到bias的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> data
在这里插入图片描述

(4)能观察到weight的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight-> data
在这里插入图片描述

(5)结论:运行opitm.zero_grad()后,清空weight和bias的梯度grad;运行result_loss.backward()后,计算得到新的weight和bias的梯度grad;运行opitm.step()后,调整weight和bias的值。

1.1.2 优化器多次循环

修改以上python代码,增加多次循环,观察总体损失值改变。

import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01)        # 优化器for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs)          # 神经网络输出# print(targets)          # 目标# print(result_loss)      # 损失函数-交叉熵计算结果opitm.zero_grad()       # 梯度清零result_loss.backward()  # 反向传播,求出每个节点的梯度opitm.step()            # 对神经网络模型的参数进行调优running_loss = running_loss + result_loss#.dataprint(running_loss)

运行结果:

tensor(18746.2012, grad_fn=<AddBackward0>)
tensor(16136.0107, grad_fn=<AddBackward0>)
tensor(15499.3203, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

可以发现running_loss在一开始不断降低,但是以下的nan暂时不知道是什么原因。

相关文章:

pytorch学习(7)——神经网络优化器torch.optim

1 optim 优化器 PyTorch神经网络优化器&#xff08;optimizer&#xff09;通过调整神经网络的参数&#xff08;weight和bias&#xff09;来最小化损失函数&#xff08;Loss&#xff09;。 学习链接&#xff1a; https://pytorch.org/docs/stable/optim.html 1.1 优化器基类 使…...

leetcode做题笔记​101. 对称二叉树

给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 思路一&#xff1a;递归 bool isSymmetric(struct TreeNode* root){if (root NULL) return true;return fun(root->left, root->right); }int fun(struct TreeNode* l_root, struct TreeNode* r_root) {…...

边缘计算相关概念--学习笔记

一.边缘计算概念 边缘计算将数据的处理&#xff0c;应用程序的运行甚至一些功能服务的实现&#xff0c;由网络中心下放到网络边缘的节点上&#xff0c;在网络边缘侧的智能网关上就近采集并且处理数据&#xff0c;不需要将大量未处理的数据上传到远程的大数据平台。边缘计算理论…...

flutter windows编译错误 flutter_assemble.vcxproj

flutter 编译windows是出现错误。 [ 44 ms] d:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Microsoft\VC\v170\Microsoft.CppCommon.targets(248,5): error MSB8066: &#xfffd;&#xfffd;E:\work\kkview_kuaichuan\kkview_kuaichuan\build\windows\C…...

通过运行中的容器生成 Docker Compose 配置文件

背景 笔者之前有一次不小心删除了原始的 docker-compose.yml 文件&#xff0c;不过正在运行的 Docker 容器还在&#xff0c;找了许久&#xff0c;发现一个方法可以从这些容器中生成一个等效的 Docker Compose 配置文件。本文将介绍使用 autocompose 工具从正在运行的容器中反向…...

rancher界面无法登陆问题解决,登录超时;

1.找到rancher主机&#xff0c;查看日志 docker ps | grep rancher # rancher 容器 名称 jolly_ptolemy docker logs -f jolly_ptolemy 日志提示&#xff0c; java.sql.SQLException: Got error 28 from storage engine&#xff0c;磁盘满了 2.磁盘管理 df -h #查看磁盘使…...

Django(6)-django项目自动化测试

Django 应用的测试应该写在应用的 tests.py 文件里。测试系统会自动的在所有以 tests 开头的文件里寻找并执行测试代码。 我们的 polls 应用现在有一个小 bug 需要被修复&#xff1a;我们的要求是如果 Question 是在一天之内发布的&#xff0c; Question.was_published_recentl…...

【AUTOSAR】【CAN通信】CanNm

目录 一、概述 二、说明 三、功能说明 3.1 协调算法 3.2 操作模式 3.2.1 网络模式...

拼多多淘宝大量缓存商品数据用什么格式提供比较好?

众所周知&#xff0c;淘宝拼多多是我国主流的电商平台&#xff0c;其上有大量的商品数据。很多商家会通过API来访问他们的商品数据&#xff0c;根据API的调用次数收费。第三方数据公司提供电商数据接口API&#xff0c;采集实时数据。但是&#xff0c;在他们的服务器上有大量的缓…...

【校招VIP】前端校招考点之页面转换算法

考点介绍&#xff1a; 在地址映射过程中&#xff0c;若在页面中发现所要访问的页面不在内存中&#xff0c;则产生缺页中断。当发生缺页中断时&#xff0c;如果操作系统内存中没有空闲页面&#xff0c;则操作系统必须在内存选择一个页面将其移出内存&#xff0c;以便为即将调入的…...

android 下载网络文件

工具类 import android.app.ProgressDialog; import android.content.Context; import android.os.AsyncTask; import android.os.Environment; import android.util.Log;import java.io.BufferedInputStream; import java.io.File; import java.io.FileOutputStream; import …...

springboot定时任务:同时使用定时任务和websocket报错

背景 项目使用了websocket,实现了消息的实时推送。后来项目需要一个定时任务&#xff0c;使用org.springframework.scheduling.annotation的EnableScheduling注解来实现&#xff0c;启动项目之后报错 Bean com.alibaba.cloud.sentinel.custom.SentinelAutoConfiguration of t…...

CSS3渐变及2D转换

CSS3渐变及2D转换 持续更新哦… 1、css3渐变 概念: CSS3渐变(gradient)可以让你在两个或多个指定的颜色之间显示平 稳的过渡。以前&#xff0c;你必须使用图像来实现这些效果&#xff0c;现在通过使用 CSS3的渐变(gradients)即可实现。此外&#xff0c;渐变效果的元素在放大…...

无涯教程-PHP - eregi()函数

eregi() - 语法 int eregi(string pattern, string string, [array regs]); eregi()函数在pattern指定的整个字符串中搜索string指定的字符串,。搜索不区分大小写。 Eregi()在检查字符串的有效性时特别有用。 可选的输入参数regs包含一个由正则表达式中的括号分组的所有匹配…...

Spring与Mybatis整合aop整合pageHelper分页插件

前言 Spring与MyBatis整合的意义在于提供了一种结合优势的方式&#xff0c;以便更好地开发和管理持久层&#xff08;数据库访问&#xff09;代码。 这里也是总结了几点主要意义 简化配置&#xff1a;Spring与MyBatis整合后&#xff0c;可以通过Spring的配置文件来管理和配置M…...

SSL/CA 证书及其相关证书文件(pem、crt、cer、key、csr)

数字证书是网络世界中的身份证&#xff0c;数字证书为实现双方安全通信提供了电子认证。数字证书中含有密钥对所有者的识别信息&#xff0c;通过验证识别信息的真伪实现对证书持有者身份的认证。数字证书可以在网络世界中为互不见面的用户建立安全可靠的信任关系&#xff0c;这…...

【JavaSE】内部类

文章目录 内部类概念局部内部类匿名内部类&#xff08;重点重点&#xff01;&#xff01;&#xff01; &#xff09;成员内部类静态内部类 内部类概念 可以将一个类定义在另一个类或者一个方法的内部&#xff0c;前者称为内部类&#xff0c;后者称为外部类。内部类也是封装的一…...

Django(2)-编写你的第一个 Django 应用

本教程的目的是创建一个网络投票应用程序。 它将由两部分组成&#xff1a; 一个让人们查看和投票的公共站点。 一个让你能添加、修改和删除投票的管理站点。 创建应用 $ python manage.py startapp polls每一个应用是一个python包&#xff0c;一个项目可以包含多个应用。 …...

燃气管网监测系统,24小时守护燃气安全

随着社会的发展和人民生活水平的提高&#xff0c;燃气逐渐成为人们日常生活和工作中不可或缺的一部分。然而&#xff0c;近年来&#xff0c;屡屡发生的燃气爆炸问题&#xff0c;也让人们不禁对燃气的安全性产生了担忧。因此&#xff0c;建立一个高效、实时、准确的燃气管网监测…...

昌硕科技、世硕电子同步上线法大大电子合同

近日&#xff0c;世界500强企业和硕联合旗下上海昌硕科技有限公司&#xff08;以下简称“昌硕科技”&#xff09;、世硕电子&#xff08;昆山&#xff09;有限公司&#xff08;以下简称“世硕电子”&#xff09;的电子签项目正式上线。上线仪式在上海浦东和硕集团科研大楼举行&…...

QMC5883L的驱动

简介 本篇文章的代码已经上传到了github上面&#xff0c;开源代码 作为一个电子罗盘模块&#xff0c;我们可以通过I2C从中获取偏航角yaw&#xff0c;相对于六轴陀螺仪的yaw&#xff0c;qmc5883l几乎不会零飘并且成本较低。 参考资料 QMC5883L磁场传感器驱动 QMC5883L磁力计…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

微服务商城-商品微服务

数据表 CREATE TABLE product (id bigint(20) UNSIGNED NOT NULL AUTO_INCREMENT COMMENT 商品id,cateid smallint(6) UNSIGNED NOT NULL DEFAULT 0 COMMENT 类别Id,name varchar(100) NOT NULL DEFAULT COMMENT 商品名称,subtitle varchar(200) NOT NULL DEFAULT COMMENT 商…...

MySQL 知识小结(一)

一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库&#xff0c;分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷&#xff0c;但是文件存放起来数据比较冗余&#xff0c;用二进制能够更好管理咱们M…...

Redis:现代应用开发的高效内存数据存储利器

一、Redis的起源与发展 Redis最初由意大利程序员Salvatore Sanfilippo在2009年开发&#xff0c;其初衷是为了满足他自己的一个项目需求&#xff0c;即需要一个高性能的键值存储系统来解决传统数据库在高并发场景下的性能瓶颈。随着项目的开源&#xff0c;Redis凭借其简单易用、…...

STM32---外部32.768K晶振(LSE)无法起振问题

晶振是否起振主要就检查两个1、晶振与MCU是否兼容&#xff1b;2、晶振的负载电容是否匹配 目录 一、判断晶振与MCU是否兼容 二、判断负载电容是否匹配 1. 晶振负载电容&#xff08;CL&#xff09;与匹配电容&#xff08;CL1、CL2&#xff09;的关系 2. 如何选择 CL1 和 CL…...

第一篇:Liunx环境下搭建PaddlePaddle 3.0基础环境(Liunx Centos8.5安装Python3.10+pip3.10)

第一篇&#xff1a;Liunx环境下搭建PaddlePaddle 3.0基础环境&#xff08;Liunx Centos8.5安装Python3.10pip3.10&#xff09; 一&#xff1a;前言二&#xff1a;安装编译依赖二&#xff1a;安装Python3.10三&#xff1a;安装PIP3.10四&#xff1a;安装Paddlepaddle基础框架4.1…...

Linux安全加固:从攻防视角构建系统免疫

Linux安全加固:从攻防视角构建系统免疫 构建坚不可摧的数字堡垒 引言:攻防对抗的新纪元 在日益复杂的网络威胁环境中,Linux系统安全已从被动防御转向主动免疫。2023年全球网络安全报告显示,高级持续性威胁(APT)攻击同比增长65%,平均入侵停留时间缩短至48小时。本章将从…...

用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章

用 Rust 重写 Linux 内核模块实战&#xff1a;迈向安全内核的新篇章 ​​摘要&#xff1a;​​ 操作系统内核的安全性、稳定性至关重要。传统 Linux 内核模块开发长期依赖于 C 语言&#xff0c;受限于 C 语言本身的内存安全和并发安全问题&#xff0c;开发复杂模块极易引入难以…...

GraphRAG优化新思路-开源的ROGRAG框架

目前的如微软开源的GraphRAG的工作流程都较为复杂&#xff0c;难以孤立地评估各个组件的贡献&#xff0c;传统的检索方法在处理复杂推理任务时可能不够有效&#xff0c;特别是在需要理解实体间关系或多跳知识的情况下。先说结论&#xff0c;看完后感觉这个框架性能上不会比Grap…...