当前位置: 首页 > news >正文

【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

【【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

      • 1. autograd 包,自动微分
      • 2. 线性模型回归演示
      • 3. GPU进行模型训练

小结:只需要将前向传播设置好,调用反向传播接口,即可实现反向传播的链式求导

1. autograd 包,自动微分

自动微分是机器学习工具包必备的工具,它可以自动计算整个计算图的微分。

PyTorch内建了一个叫做torch.autograd的自动微分引擎,该引擎支持的数据类型为:浮点数Tensor类型 ( half, float, double and bfloat16) 和复数Tensor 类型(cfloat, cdouble)

PyTorch中与自动微分相关的常用的Tensor属性和函数:

属性requires_grad:
默认值为False,表明该Tensor不会被自动微分引擎计算微分。设置为True,表明让自动微分引擎计算该Tensor的微分
属性grad:存储自动微分的计算结果,即调用backward()方法后的计算结果
方法backward(): 计算微分,一般不带参数,等效于:backward(torch.tensor(1.0))。若backward()方法在DAG的root上调用,它会依据链式法则自动计算DAG所有枝叶上的微分。
方法no_grad():禁用自动微分上下文管理, 一般用于模型评估或推理计算这些不需要执行自动微分计算的地方,以减少内存和算力的消耗。另外禁止在模型参数上自动计算微分,即不允许更新该参数,即所谓的冻结参数(frozen parameters)。
zero_grad()方法:PyTorch的微分是自动积累的,需要用zero_grad()方法手动清零

# 模型:z = x@w + b;激活函数:Softmax
x = torch.ones(5)  # 输入张量,shape=(5,)
labels = torch.zeros(3) # 标签值,shape=(3,)
w = torch.randn(5,3,requires_grad=True) # 模型参数,需要计算微分, shape=(5,3)
b = torch.randn(3, requires_grad=True)  # 模型参数,需要计算微分, shape=(3,)
z = x@w + b # 模型前向计算
outputs = torch.nn.functional.softmax(z) # 激活函数
print("z: ",z)
print("outputs: ",outputs)
loss = torch.nn.functional.binary_cross_entropy(outputs, labels)
# 查看loss函数的微分计算函数
print('Gradient function for loss =', loss.grad_fn)
# 调用loss函数的backward()方法计算模型参数的微分
loss.backward()
# 查看模型参数的微分值
print("w: ",w.grad)
print("b.grad: ",b.grad)

在这里插入图片描述

小姐:

方法描述
.requires_grad 设置为True会开始跟踪针对 tensor 的所有操作
.backward()张量的梯度将累积到 .grad 属性
import torchx=torch.rand(1)
b=torch.rand(1,requires_grad=True)
w=torch.rand(1,requires_grad=True)
y = w * x
z = y + bx.requires_grad, w.requires_grad,b.requires_grad,y.requires_grad,z.requires_gradprint("x: ",x, end="\n"),print("b: ",b ,end="\n"),print("w: ",w ,end="\n")
print("y: ",y, end="\n"),print("z: ",z, end="\n")# 反向传播计算
z.backward(retain_graph=True) #注意:如果不清空,b每一次更新,都会自我累加起来,依次为1 2 3 4 。。。w.grad
b.grad

运行结果:
在这里插入图片描述
反向传播求导原理:
在这里插入图片描述

2. 线性模型回归演示

import torch
import torch.nn as nn## 线性回归模型: 本质上就是一个不加 激活函数的 全连接层
class LinearRegressionModel(nn.Module):def __init__(self, input_size, output_size):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):out = self.linear(x)return out
input_size = 1
output_size = 1model = LinearRegressionModel(input_size, output_size)
model# 指定号参数和损失函数
epochs = 500
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()# train model
for epoch in range(epochs):epochs+=1#注意 将numpy格式的输入数据转换成 tensorinputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)#每次迭代梯度清零optimizer.zero_grad()#前向传播outputs = model(inputs)#计算损失loss = criterion(outputs, labels)#反向传播loss.backward()#updates weight and parametersoptimizer.step()if epoch % 50 == 0:print("Epoch: {}, Loss: {}".format(epoch, loss.item()))# predict model test,预测结果并且奖结果转换成np格式
predicted =model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted#model save
torch.save(model.state_dict(),'model.pkl')#model 读取
model.load_state_dict(torch.load('model.pkl'))

在这里插入图片描述

3. GPU进行模型训练

只需要 将模型和数据传入到“cuda”中运行即可,详细实现见截图

import torch
import torch.nn as nn
import numpy as np# #构建一个回归方程 y = 2*x+1#构建输如数据,将输入numpy格式转成tensor格式
x_values = [i for i in range(11)]
x_train = np.array(x_values,dtype=np.float32)
x_train = x_train.reshape(-1,1)y_values = [2*i + 1 for i in x_values]
y_train = np.array(y_values, dtype=np.float32)
y_train = y_train.reshape(-1,1)## 线性回归模型: 本质上就是一个不加 激活函数的 全连接层
class LinearRegressionModel(nn.Module):def __init__(self, input_size, output_size):super(LinearRegressionModel, self).__init__()self.linear = nn.Linear(input_size, output_size)def forward(self, x):out = self.linear(x)return outinput_size = 1
output_size = 1model = LinearRegressionModel(input_size, output_size)device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 指定号参数和损失函数
epochs = 500
learning_rate = 0.01
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)
criterion = nn.MSELoss()# train model
for epoch in range(epochs):epochs+=1#注意 将numpy格式的输入数据转换成 tensorinputs = torch.from_numpy(x_train)labels = torch.from_numpy(y_train)#每次迭代梯度清零optimizer.zero_grad()#前向传播outputs = model(inputs)#计算损失loss = criterion(outputs, labels)#反向传播loss.backward()#updates weight and parametersoptimizer.step()if epoch % 50 == 0:print("Epoch: {}, Loss: {}".format(epoch, loss.item()))# predict model test,预测结果并且奖结果转换成np格式
predicted = model(torch.from_numpy(x_train).requires_grad_()).data.numpy()
predicted#model save
torch.save(model.state_dict(),'model.pkl')

在这里插入图片描述

相关文章:

【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归

【【Pytorch】学习记录分享3——PyTorch 自动微分与线性回归 1. autograd 包,自动微分2. 线性模型回归演示3. GPU进行模型训练 小结:只需要将前向传播设置好,调用反向传播接口,即可实现反向传播的链式求导 1. autograd 包&#x…...

Android Studio实现俄罗斯方块

文章目录 一、项目概述二、开发环境三、详细设计3.1 CacheUtils类3.2 BlockAdapter类3.3 CommonAdapter类3.4 SelectActivity3.5 MainActivity 四、运行演示五、项目总结 一、项目概述 俄罗斯方块是一种经典的电子游戏,最早由俄罗斯人Alexey Pajitnov在1984年创建。…...

【Hive】——DDL(DATABASE)

1 概述 2 创建数据库 create database if not exists test_database comment "this is my first db" with dbproperties (createdByAllen);3 描述数据库信息 describe 可以简写为desc extended 可以展示更多信息 describe database test_database; describe databa…...

【华为OD题库-092】单词加密-java

题目 输入一个英文句子,句子中包含若干个单词,每个单词间有一个空格需要将句子中的每个单词按照要求加密输出。要求: 1)单词中包括元音字符(‘aeuio’、‘AEUIO’,大小写都算),则将元音字符替换成’*) 2)单词中不包括元音字符&…...

构建一个简单的 npm 验证项目

构建一个简单的 npm 验证项目 0. 背景1. 构建过程1-1. 创建项目并初始化1-2. 安装 mjs 支持的 package1-3. 在 package.json 中添加 mjs 脚本1-4. 创建 index.mjs 文件1-5. 执行脚本 2. (Optional)环境变量配置 0. 背景 工作上需要验证一下 npm 程序,所以需要构建一…...

利用vue-okr-tree实现飞书OKR对齐视图

vue-okr-tree-demo 因开发需求需要做一个类似飞书OKR对齐视图的功能,参考了两位大神的代码: 开源组件vue-okr-tree作者博客地址:http://t.csdnimg.cn/5gNfd 对组件二次封装的作者博客地址:http://t.csdnimg.cn/Tjaf0 开源组件v…...

持续集成交付CICD:CentOS 7 安装SaltStack

目录 一、理论 1.SaltStack 二、实验 1.主机一安装master 2.主机二安装第一台minion 3.主机三安装第二台minion 4.测试SaltStack 三、问题 1.CentOS 8 如何安装SaltStack 一、理论 1.SaltStack (1)概念 SaltStack是基于python开发的一套C/S自…...

vscode 环境配置

必备插件 配置调试 {// Use IntelliSense to learn about possible attributes.// Hover to view descriptions of existing attributes.// For more information, visit: https://go.microsoft.com/fwlink/?linkid830387"version": "0.2.0","confi…...

pytorch文本分类(二):引入pytorch处理文本数据

pytorch文本数据处理 目录 pytorch文本数据处理1. Pytorch背景2. 数据分割3. 数据加载Dataset代码分析字典的用途代码修改的目的 Dataloader 4. 练习 原学习任务链接 相关数据链接:https://pan.baidu.com/s/1iwE3LdRv3uAkGGI2fF9BjA?pwdro0v 提取码:ro…...

Centos硬盘操作合集

一、硬盘命令说明 lsblk 列出系统上的所有磁盘列表 查看磁盘列表 参数意义 blkid 列出硬盘UUID [rootzs ~]# blkid /dev/sda1: UUID"77dcd110-dad6-45b8-97d4-fa592dc56d07" TYPE"xfs" /dev/sda2: UUID"oDT0oD-LCIJ-Xh7r-lBfd-axLD-DRiN-Twa…...

三大循环语句

goto 我们看代码去感受goto的循环,那么goto循环最经常搭配的就是loop,那么就像如下代码 这个代码中loop:就是个标志,然后程序正常向下运行,goto loop;就会让她回到loop,然后在运行到goto loop…...

Mybatis详解

MyBatis是什么 MyBatis是一个持久层框架,用于简化数据库操作的开发。它通过将SQL语句和Java方法进行映射,实现了数据库操作的解耦和简化。以下是MyBatis的优点和缺点: 优点: 1. 灵活性:MyBatis允许开发人员编写原生的…...

spring cloud alibaba RocketMQ 最佳实践

目录 概述使用准备工作引入依赖创建Topic代码应用启动消息接收再扩展一个 结束 概述 github 文档地址 rocket mq example RocketMQ 版本为 5.1.4 使用 准备工作 阅读此文需要事先准备 RocketMQ ,如有疑问,请移步 RocketMQ 服务搭建 引入依赖 此处…...

php使用OpenCV实现从照片中截取身份证区域照片

<?php // 获取上传的文件 $file $_FILES[file]; // 获取文件的临时名称 $tmp_name $file[tmp_name]; // 获取文件的类型 $type $file[type]; // 获取文件的大小 $size $file[size]; // 获取文件的错误信息 $error $file[error]; // 检查文件是否上传成功 if ($er…...

抖音ip地址切换会看不到视频吗

随着社交媒体平台的快速发展&#xff0c;抖音已经成为了许多人分享生活点滴、展示才艺的热门平台。然而&#xff0c;有时候使用抖音时会遇到一些问题&#xff0c;比如IP地址切换后无法观看视频。那么&#xff0c;为什么会出现这种情况呢&#xff1f;让我们分析一下。 首先&…...

有关爬虫http/https的请求与响应

简介 HTTP协议&#xff08;HyperText Transfer Protocol&#xff0c;超文本传输协议&#xff09;&#xff1a;是一种发布和接收 HTML页面的方法。 HTTPS&#xff08;Hypertext Transfer Protocol over Secure Socket Layer&#xff09;简单讲是HTTP的安全版&#xff0c;在HTT…...

模块二——滑动窗口:438.找到字符串中所有字母异位词

文章目录 题目描述算法原理滑动窗口哈希表 代码实现 题目描述 题目链接&#xff1a;438.找到字符串中所有字母异位词 算法原理 滑动窗口哈希表 因为字符串p的异位词的⻓度⼀定与字符串p 的⻓度相同&#xff0c;所以我们可以在字符串s 中构造⼀个⻓度为与字符串p的⻓度相同…...

排序算法(二)-冒泡排序、选择排序、插入排序、希尔排序、快速排序、归并排序、基数排序

排序算法(二) 前面介绍了排序算法的时间复杂度和空间复杂数据结构与算法—排序算法&#xff08;一&#xff09;时间复杂度和空间复杂度介绍-CSDN博客&#xff0c;这次介绍各种排序算法——冒泡排序、选择排序、插入排序、希尔排序、快速排序、归并排序、基数排序。 文章目录 排…...

智能优化算法应用:基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码

智能优化算法应用&#xff1a;基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码 文章目录 智能优化算法应用&#xff1a;基于探路者算法3D无线传感器网络(WSN)覆盖优化 - 附代码1.无线传感网络节点模型2.覆盖数学模型及分析3.探路者算法4.实验参数设定5.算法结果6.参考文…...

高效排队,紧急响应:RabbitMQ Priority Queue全面指南【RabbitMQ 九】

欢迎来到我的博客&#xff0c;代码的世界里&#xff0c;每一行都是一个故事 高效排队&#xff0c;紧急响应&#xff1a;RabbitMQ Priority Queue全面指南 引言前言第一&#xff1a;初识RabbitMQ Priority Queue插件插件的背景和目的&#xff1a;为什么需要消息优先级&#xff1…...

XCTF-web-easyupload

试了试php&#xff0c;php7&#xff0c;pht&#xff0c;phtml等&#xff0c;都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接&#xff0c;得到flag...

【kafka】Golang实现分布式Masscan任务调度系统

要求&#xff1a; 输出两个程序&#xff0c;一个命令行程序&#xff08;命令行参数用flag&#xff09;和一个服务端程序。 命令行程序支持通过命令行参数配置下发IP或IP段、端口、扫描带宽&#xff0c;然后将消息推送到kafka里面。 服务端程序&#xff1a; 从kafka消费者接收…...

stm32G473的flash模式是单bank还是双bank?

今天突然有人stm32G473的flash模式是单bank还是双bank&#xff1f;由于时间太久&#xff0c;我真忘记了。搜搜发现&#xff0c;还真有人和我一样。见下面的链接&#xff1a;https://shequ.stmicroelectronics.cn/forum.php?modviewthread&tid644563 根据STM32G4系列参考手…...

Java 8 Stream API 入门到实践详解

一、告别 for 循环&#xff01; 传统痛点&#xff1a; Java 8 之前&#xff0c;集合操作离不开冗长的 for 循环和匿名类。例如&#xff0c;过滤列表中的偶数&#xff1a; List<Integer> list Arrays.asList(1, 2, 3, 4, 5); List<Integer> evens new ArrayList…...

vscode(仍待补充)

写于2025 6.9 主包将加入vscode这个更权威的圈子 vscode的基本使用 侧边栏 vscode还能连接ssh&#xff1f; debug时使用的launch文件 1.task.json {"tasks": [{"type": "cppbuild","label": "C/C: gcc.exe 生成活动文件"…...

UE5 学习系列(三)创建和移动物体

这篇博客是该系列的第三篇&#xff0c;是在之前两篇博客的基础上展开&#xff0c;主要介绍如何在操作界面中创建和拖动物体&#xff0c;这篇博客跟随的视频链接如下&#xff1a; B 站视频&#xff1a;s03-创建和移动物体 如果你不打算开之前的博客并且对UE5 比较熟的话按照以…...

SpringCloudGateway 自定义局部过滤器

场景&#xff1a; 将所有请求转化为同一路径请求&#xff08;方便穿网配置&#xff09;在请求头内标识原来路径&#xff0c;然后在将请求分发给不同服务 AllToOneGatewayFilterFactory import lombok.Getter; import lombok.Setter; import lombok.extern.slf4j.Slf4j; impor…...

【碎碎念】宝可梦 Mesh GO : 基于MESH网络的口袋妖怪 宝可梦GO游戏自组网系统

目录 游戏说明《宝可梦 Mesh GO》 —— 局域宝可梦探索Pokmon GO 类游戏核心理念应用场景Mesh 特性 宝可梦玩法融合设计游戏构想要素1. 地图探索&#xff08;基于物理空间 广播范围&#xff09;2. 野生宝可梦生成与广播3. 对战系统4. 道具与通信5. 延伸玩法 安全性设计 技术选…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中&#xff0c;新增了一个本地验证码接口 /code&#xff0c;使用函数式路由&#xff08;RouterFunction&#xff09;和 Hutool 的 Circle…...

JAVA后端开发——多租户

数据隔离是多租户系统中的核心概念&#xff0c;确保一个租户&#xff08;在这个系统中可能是一个公司或一个独立的客户&#xff09;的数据对其他租户是不可见的。在 RuoYi 框架&#xff08;您当前项目所使用的基础框架&#xff09;中&#xff0c;这通常是通过在数据表中增加一个…...