当前位置: 首页 > news >正文

pytorch学习(7)——神经网络优化器torch.optim

1 optim 优化器

PyTorch神经网络优化器(optimizer)通过调整神经网络的参数(weight和bias)来最小化损失函数(Loss)。
学习链接:

https://pytorch.org/docs/stable/optim.html

在这里插入图片描述

1.1 优化器基类

使用时必须构造一个优化器对象,它将保存当前状态,并将根据计算的梯度(grad)更新参数。
调用优化器的step方法。

CLASS torch.optim.Optimizer(params, defaults)

  • Optimizer - 优化器的优化算法。
  • params (iterable) – torch的迭代器。张量s或dict s,指定应该优化什么张量。
  • defaults – (dict): 包含优化选项默认值的字典(在参数组没有指定优化选项时使用)。每个Optimizer算法都有其独特的设置字典。
算法(Optimizer)说明
Adadelta采用Adadelta算法。
Adagrad采用Adagrad算法。
Adam采用Adam算法。
AdamW采用AdamW算法。
SparseAdam采用适合稀疏张量的Adam算法的惰性版本。
Adamax采用Adamax算法(Adam基于无穷范数的变种)。
ASGD采用平均随机梯度下降。
LBFGS采用L-BFGS算法,深受minFunc的启发。
NAdam采用NAdam 算法。
RAdam采用RAdam 算法。
RMSprop采用RMSprop 算法。
Rprop采用有弹性的反向传播算法。
SGD采用随机梯度下降算法。

1.1.1 SGD 随机梯度下降算法

CLASS torch.optim.SGD(params, lr=<required parameter>, momentum=0, dampening=0, weight_decay=0, nesterov=False, *, maximize=False, foreach=None, differentiable=False)

  • params (iterable) – iterable参数优化或字典定义参数组。
  • lr (float) – 学习率,需要用户输入。
  • momentum (float, optional) – 动量系数(默认值为0)。
  • weight_decay (float, optional) – 权重衰减(L2惩罚) (默认值为0)
  • dampening (float, optional) – 动量阻尼(默认值为0)。
  • nesterov (bool, optional) – 使能Nesterov动量(默认值为False)。
    【Nesterov动量(Nesterov Momentum)是一种基于动量法的优化算法,用于加速神经网络的训练过程。它在随机梯度下降(SGD)的基础上进行改进,通过考虑参数更新前的动量信息来调整参数更新的方向。】
  • maximize (bool, optional) – 根据目标最大化参数,而不是最小化参数(默认值为False)。
  • foreach (bool, optional) – 是否使用foreach优化器的实现。如果用户未指定(foreach为None),我们将尝试在CUDA上的for循环实现上使用foreach,因为CUDA通常性能更高(默认值:None)。
  • differentiable (bool, optional) – 是否在训练中的优化器步骤中发生autograd。否则,step()函数在torch.no_grad()上下文中运行。设置为True会影响性能,所以如果你不打算通过这个实例运行autograd,请保留False(默认值为False)。

学习速率(lr)的取值,如果太大,则模型很不稳定;如果太小,学习速度非常缓慢。因此一般先设置较大的学习速率,然后降低学习速率。

python代码如下:

import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01)        # 优化器for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs)          # 神经网络输出# print(targets)          # 目标# print(result_loss)      # 损失函数-交叉熵计算结果opitm.zero_grad()       # 梯度清零,设置断点result_loss.backward()  # 反向传播,求出每个节点的梯度,设置断点opitm.step()            # 对神经网络模型的参数进行调优,设置断点

设置断点,进入程序Debug:
(1)不断运行程序,能够观察到卷积层0的bias梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> grad
在这里插入图片描述

(2)能观察到卷积层0的weight梯度变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight -> grad
在这里插入图片描述

(3)能观察到bias的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> bias-> data
在这里插入图片描述

(4)能观察到weight的变化:mynn -> Protected Attributes -> _modules -> ‘model1’ -> Protected Attributes -> _modules -> ‘0’ -> weight-> data
在这里插入图片描述

(5)结论:运行opitm.zero_grad()后,清空weight和bias的梯度grad;运行result_loss.backward()后,计算得到新的weight和bias的梯度grad;运行opitm.step()后,调整weight和bias的值。

1.1.2 优化器多次循环

修改以上python代码,增加多次循环,观察总体损失值改变。

import torchvision
import torch
from torch import nn, optim
from torch.nn import Linear, Conv2d, MaxPool2d, Flatten, Sequential,CrossEntropyLoss
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriterdataset = torchvision.datasets.CIFAR10(root="G:\\Anaconda\\pycharm_pytorch\\learning_project\\dataset_CIFAR10",train=False,transform=torchvision.transforms.ToTensor(),download=False)dataloader = DataLoader(dataset, batch_size=1)class MYNN(nn.Module):def __init__(self):super(MYNN, self).__init__()self.model1 = Sequential(Conv2d(3, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 32, 5, padding=2, stride=1),MaxPool2d(2),Conv2d(32, 64, 5, padding=2, stride=1),MaxPool2d(2),Flatten(),Linear(1024, 64),Linear(64, 10))def forward(self, x):x = self.model1(x)return xloss = CrossEntropyLoss()
mynn = MYNN()
opitm = optim.SGD(mynn.parameters(), lr=0.01)        # 优化器for epoch in range(20):running_loss = 0.0for data in dataloader:imgs, targets = dataoutputs = mynn(imgs)result_loss = loss(outputs, targets)# print(outputs)          # 神经网络输出# print(targets)          # 目标# print(result_loss)      # 损失函数-交叉熵计算结果opitm.zero_grad()       # 梯度清零result_loss.backward()  # 反向传播,求出每个节点的梯度opitm.step()            # 对神经网络模型的参数进行调优running_loss = running_loss + result_loss#.dataprint(running_loss)

运行结果:

tensor(18746.2012, grad_fn=<AddBackward0>)
tensor(16136.0107, grad_fn=<AddBackward0>)
tensor(15499.3203, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)
tensor(nan, grad_fn=<AddBackward0>)

可以发现running_loss在一开始不断降低,但是以下的nan暂时不知道是什么原因。

相关文章:

pytorch学习(7)——神经网络优化器torch.optim

1 optim 优化器 PyTorch神经网络优化器&#xff08;optimizer&#xff09;通过调整神经网络的参数&#xff08;weight和bias&#xff09;来最小化损失函数&#xff08;Loss&#xff09;。 学习链接&#xff1a; https://pytorch.org/docs/stable/optim.html 1.1 优化器基类 使…...

leetcode做题笔记​101. 对称二叉树

给你一个二叉树的根节点 root &#xff0c; 检查它是否轴对称。 思路一&#xff1a;递归 bool isSymmetric(struct TreeNode* root){if (root NULL) return true;return fun(root->left, root->right); }int fun(struct TreeNode* l_root, struct TreeNode* r_root) {…...

边缘计算相关概念--学习笔记

一.边缘计算概念 边缘计算将数据的处理&#xff0c;应用程序的运行甚至一些功能服务的实现&#xff0c;由网络中心下放到网络边缘的节点上&#xff0c;在网络边缘侧的智能网关上就近采集并且处理数据&#xff0c;不需要将大量未处理的数据上传到远程的大数据平台。边缘计算理论…...

flutter windows编译错误 flutter_assemble.vcxproj

flutter 编译windows是出现错误。 [ 44 ms] d:\Program Files\Microsoft Visual Studio\2022\Community\MSBuild\Microsoft\VC\v170\Microsoft.CppCommon.targets(248,5): error MSB8066: &#xfffd;&#xfffd;E:\work\kkview_kuaichuan\kkview_kuaichuan\build\windows\C…...

通过运行中的容器生成 Docker Compose 配置文件

背景 笔者之前有一次不小心删除了原始的 docker-compose.yml 文件&#xff0c;不过正在运行的 Docker 容器还在&#xff0c;找了许久&#xff0c;发现一个方法可以从这些容器中生成一个等效的 Docker Compose 配置文件。本文将介绍使用 autocompose 工具从正在运行的容器中反向…...

rancher界面无法登陆问题解决,登录超时;

1.找到rancher主机&#xff0c;查看日志 docker ps | grep rancher # rancher 容器 名称 jolly_ptolemy docker logs -f jolly_ptolemy 日志提示&#xff0c; java.sql.SQLException: Got error 28 from storage engine&#xff0c;磁盘满了 2.磁盘管理 df -h #查看磁盘使…...

Django(6)-django项目自动化测试

Django 应用的测试应该写在应用的 tests.py 文件里。测试系统会自动的在所有以 tests 开头的文件里寻找并执行测试代码。 我们的 polls 应用现在有一个小 bug 需要被修复&#xff1a;我们的要求是如果 Question 是在一天之内发布的&#xff0c; Question.was_published_recentl…...

【AUTOSAR】【CAN通信】CanNm

目录 一、概述 二、说明 三、功能说明 3.1 协调算法 3.2 操作模式 3.2.1 网络模式...

拼多多淘宝大量缓存商品数据用什么格式提供比较好?

众所周知&#xff0c;淘宝拼多多是我国主流的电商平台&#xff0c;其上有大量的商品数据。很多商家会通过API来访问他们的商品数据&#xff0c;根据API的调用次数收费。第三方数据公司提供电商数据接口API&#xff0c;采集实时数据。但是&#xff0c;在他们的服务器上有大量的缓…...

【校招VIP】前端校招考点之页面转换算法

考点介绍&#xff1a; 在地址映射过程中&#xff0c;若在页面中发现所要访问的页面不在内存中&#xff0c;则产生缺页中断。当发生缺页中断时&#xff0c;如果操作系统内存中没有空闲页面&#xff0c;则操作系统必须在内存选择一个页面将其移出内存&#xff0c;以便为即将调入的…...

android 下载网络文件

工具类 import android.app.ProgressDialog; import android.content.Context; import android.os.AsyncTask; import android.os.Environment; import android.util.Log;import java.io.BufferedInputStream; import java.io.File; import java.io.FileOutputStream; import …...

springboot定时任务:同时使用定时任务和websocket报错

背景 项目使用了websocket,实现了消息的实时推送。后来项目需要一个定时任务&#xff0c;使用org.springframework.scheduling.annotation的EnableScheduling注解来实现&#xff0c;启动项目之后报错 Bean com.alibaba.cloud.sentinel.custom.SentinelAutoConfiguration of t…...

CSS3渐变及2D转换

CSS3渐变及2D转换 持续更新哦… 1、css3渐变 概念: CSS3渐变(gradient)可以让你在两个或多个指定的颜色之间显示平 稳的过渡。以前&#xff0c;你必须使用图像来实现这些效果&#xff0c;现在通过使用 CSS3的渐变(gradients)即可实现。此外&#xff0c;渐变效果的元素在放大…...

无涯教程-PHP - eregi()函数

eregi() - 语法 int eregi(string pattern, string string, [array regs]); eregi()函数在pattern指定的整个字符串中搜索string指定的字符串,。搜索不区分大小写。 Eregi()在检查字符串的有效性时特别有用。 可选的输入参数regs包含一个由正则表达式中的括号分组的所有匹配…...

Spring与Mybatis整合aop整合pageHelper分页插件

前言 Spring与MyBatis整合的意义在于提供了一种结合优势的方式&#xff0c;以便更好地开发和管理持久层&#xff08;数据库访问&#xff09;代码。 这里也是总结了几点主要意义 简化配置&#xff1a;Spring与MyBatis整合后&#xff0c;可以通过Spring的配置文件来管理和配置M…...

SSL/CA 证书及其相关证书文件(pem、crt、cer、key、csr)

数字证书是网络世界中的身份证&#xff0c;数字证书为实现双方安全通信提供了电子认证。数字证书中含有密钥对所有者的识别信息&#xff0c;通过验证识别信息的真伪实现对证书持有者身份的认证。数字证书可以在网络世界中为互不见面的用户建立安全可靠的信任关系&#xff0c;这…...

【JavaSE】内部类

文章目录 内部类概念局部内部类匿名内部类&#xff08;重点重点&#xff01;&#xff01;&#xff01; &#xff09;成员内部类静态内部类 内部类概念 可以将一个类定义在另一个类或者一个方法的内部&#xff0c;前者称为内部类&#xff0c;后者称为外部类。内部类也是封装的一…...

Django(2)-编写你的第一个 Django 应用

本教程的目的是创建一个网络投票应用程序。 它将由两部分组成&#xff1a; 一个让人们查看和投票的公共站点。 一个让你能添加、修改和删除投票的管理站点。 创建应用 $ python manage.py startapp polls每一个应用是一个python包&#xff0c;一个项目可以包含多个应用。 …...

燃气管网监测系统,24小时守护燃气安全

随着社会的发展和人民生活水平的提高&#xff0c;燃气逐渐成为人们日常生活和工作中不可或缺的一部分。然而&#xff0c;近年来&#xff0c;屡屡发生的燃气爆炸问题&#xff0c;也让人们不禁对燃气的安全性产生了担忧。因此&#xff0c;建立一个高效、实时、准确的燃气管网监测…...

昌硕科技、世硕电子同步上线法大大电子合同

近日&#xff0c;世界500强企业和硕联合旗下上海昌硕科技有限公司&#xff08;以下简称“昌硕科技”&#xff09;、世硕电子&#xff08;昆山&#xff09;有限公司&#xff08;以下简称“世硕电子”&#xff09;的电子签项目正式上线。上线仪式在上海浦东和硕集团科研大楼举行&…...

Docker 离线安装指南

参考文章 1、确认操作系统类型及内核版本 Docker依赖于Linux内核的一些特性&#xff0c;不同版本的Docker对内核版本有不同要求。例如&#xff0c;Docker 17.06及之后的版本通常需要Linux内核3.10及以上版本&#xff0c;Docker17.09及更高版本对应Linux内核4.9.x及更高版本。…...

连锁超市冷库节能解决方案:如何实现超市降本增效

在连锁超市冷库运营中&#xff0c;高能耗、设备损耗快、人工管理低效等问题长期困扰企业。御控冷库节能解决方案通过智能控制化霜、按需化霜、实时监控、故障诊断、自动预警、远程控制开关六大核心技术&#xff0c;实现年省电费15%-60%&#xff0c;且不改动原有装备、安装快捷、…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

(二)原型模式

原型的功能是将一个已经存在的对象作为源目标,其余对象都是通过这个源目标创建。发挥复制的作用就是原型模式的核心思想。 一、源型模式的定义 原型模式是指第二次创建对象可以通过复制已经存在的原型对象来实现,忽略对象创建过程中的其它细节。 📌 核心特点: 避免重复初…...

spring:实例工厂方法获取bean

spring处理使用静态工厂方法获取bean实例&#xff0c;也可以通过实例工厂方法获取bean实例。 实例工厂方法步骤如下&#xff1a; 定义实例工厂类&#xff08;Java代码&#xff09;&#xff0c;定义实例工厂&#xff08;xml&#xff09;&#xff0c;定义调用实例工厂&#xff…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

C# 类和继承(抽象类)

抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

2025盘古石杯决赛【手机取证】

前言 第三届盘古石杯国际电子数据取证大赛决赛 最后一题没有解出来&#xff0c;实在找不到&#xff0c;希望有大佬教一下我。 还有就会议时间&#xff0c;我感觉不是图片时间&#xff0c;因为在电脑看到是其他时间用老会议系统开的会。 手机取证 1、分析鸿蒙手机检材&#x…...

Spring Boot+Neo4j知识图谱实战:3步搭建智能关系网络!

一、引言 在数据驱动的背景下&#xff0c;知识图谱凭借其高效的信息组织能力&#xff0c;正逐步成为各行业应用的关键技术。本文聚焦 Spring Boot与Neo4j图数据库的技术结合&#xff0c;探讨知识图谱开发的实现细节&#xff0c;帮助读者掌握该技术栈在实际项目中的落地方法。 …...

vue3+vite项目中使用.env文件环境变量方法

vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量&#xff0c;这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...