当前位置: 首页 > news >正文

pytorch入门7--自动求导和神经网络

深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。
下面是课程笔记。
一、自动求导
举例说明自动求导。
torch中的张量有两个重要属性:data(值)和grad(梯度),当我们在定义一个张量时设requires_grad=True就是说明后续可以使用自动求导机制。
在这里插入图片描述
注意:pytorch里可以设置为自动求导的张量的元素需要是浮点型。
例如,对于e=(a + b) * (b + 1),可以用一个图表示如下:
在这里插入图片描述
我们定义张量时通常是从下往上定义,即先定义张量a,b,再定义张量e(由张量a,b的关系式组成),这样张量e的值就由a,b得到,这就是前向传播(前馈),通常定义为forward函数:
在这里插入图片描述
当我们要进行求导时,求:
在这里插入图片描述
在这里插入图片描述
可以看出,求导是从上到下的,逐级相乘再将路径相加。比如求e对b的偏导数,,从b到e有两条路径,每条路径从e开始逐级求导,结果相乘再将多条路径求导结果相加,这个过程加反向传递(反馈),通过pytorch封装好的backward函数实现。
下面的图比我手绘的应该清楚一些:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
代码实现如下(以线性逻辑回归为例,y = w * x,给定训练数据集x,y,求最佳参数w拟合x与y的关系函数):
我们直到在深度学习中,我们都是将损失函数对参数求导,使用梯度下降法等方法使得损失函数最好,从而找到参数的最佳值。

# 训练数据集(人眼可以一下看出y=2*x是最好的拟合,但机器不知道,要一直训练
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0] # 参数w,初始值设为1.0
w = torch.tensor([1.0],requires_grad=True)# 前向传播
def forward(x):return x * w# 损失函数
def loss(x,y):y_pred = forward(x) # y_pred即为常说的y_hat,是y在当前w的值下计算的估计值,这里即建立了y_pred与w的关系,可以自动求导return (y_pred - y) ** 2# 训练数据集(梯度下降法)
for epoch in range(1000):for x,y in zip(x_data,y_data):l = loss(x,y)l.backward() # 自动求导,l对w求导,反向传播print('\tgrad:',x,y,w,w.grad.item()) # item()用于只含一个元素的tensor中提取值w.data = w.data - 0.01 * w.grad.data # **这里使用data属性就是为了防止使用自动求导机制**w.grad.data.zero_() # 将上一轮的梯度值清除print("progress:",epoch,l.item())# 测试结果
print("predict(after training)",4,forward(4).item()) # 计算当x=4时,根据训练出的模型求y的估计值

在这里插入图片描述
在这里插入图片描述
可以看出w的值一直在增加,直到加到2可以完全拟合训练集中x与y的关系,最后当x等于4时, 估计值接近8.

二、神经网络
在这里插入图片描述
我们知道,神经网络由多层组成,包括输入层、隐含层和输出层,每一层的包含不同个数的结点,每层的结点其实就是当前我们获得的数据的特征值(features),例如输入层(x1,x2,x3,x4,x5)有五个结点分别表示五个特征值,第一层的隐藏层有六个结点,这是就需要一个6 * 5的矩阵w将x的5个特征值转变为6个特征值。当然也可以添加偏置值b如下图所示:
在这里插入图片描述
在这里插入图片描述
而这个矩阵w就是我们要训练出包含着某种关系的参数矩阵,再一层一层的变换,每层都有一个参数矩阵,最终到达输出矩阵的四个特征,即(y1,y2,y3,y4)。
为了是我们的神经网络模型更好地拟合非线性函数关系,还可以使用激活函数:
在这里插入图片描述
激活函数前面的文章讲过,这里不再说了。使用sigmoid激活函数如下:
在这里插入图片描述
代码实现:
1.单隐藏层的神经网络模型
pytorch对于神经网络的代码封装得很好。

# 训练数据集
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])# 定义一个单隐藏层得神经网络
class LinearModel(torch.nn.Module):# 神经网络的类必须继承类Moduledef __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1,1) # torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1。# torch.nn.Linear()第三个参数是bias,设置为True即含有偏置值b,为False不适用偏置值,默认值为True。def forward(self,x):y_pred = self.linear(x) # 使用封装好的linear()计算y的预测值return y_pred# 生成神经网络的模型
model = LinearModel()# 损失函数
criterion = torch.nn.MSELoss(size_average=False) # size_average=False表示损失函数不求平均值# 优化器(梯度下降)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) # model.parameters()可以获取神经网络中的所有参数参数矩阵的值,对其进行优化
# lr表示步长# 训练数据集
for epoch in range(100):y_pred = model(x_data) # 将数据传入搭建好的神经网络模型得到估计值loss = criterion(y_pred,y_data) # 计算损失值print(epoch,loss)optimizer.zero_grad() # 清除上次的梯度值loss.backward() # 自动求导optimizer.step() # 优化参数# 输出结果
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())# 测试模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred=',y_test.data)

在这里插入图片描述
在这里插入图片描述
说明:这里的torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1,其它情况使用情况如下:
在这里插入图片描述

可以看出,训练100轮效果不佳,可以训练1000次看看不同结果。

2.多隐藏层的神经网络模型
与单隐藏层神经网络模型区别如下:

class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6) # 模型从8维变为6维,再从6维变为4维,再从4维变为1维self.linear2 = torch.nn.Linear(6,4)self.linear1 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid() # 使用sigmoid激活函数def forward(self,x):pred1 = self.sigmoid(self.linear1(x)) # 上一层输出结果传给下一层pred2 = self.sigmoid(self.linear2(pred1))y_pred = self.sigmoid(self.linear3(pred2))return nmodel = Model()

相关文章:

pytorch入门7--自动求导和神经网络

深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。 下面是课程笔记。 一、自动求导 举例说明自动求导。 torch中的…...

QT 之wayland 事件处理分析基于qt5wayland5.14.2

1. Qt wayland 初始化 接收鼠标/案件,触摸屏等事件事件 QWaylandNativeInterface : public QPlatformNativeInterface 在QWaylandNativeInterface 继承qpa 接口类QPlatformNativeInterface; 1.1 初始化鼠标: void *QWaylandNativeInterface::nativeR…...

【this 和 super 的区别】

在 Java 中,this 和 super 都是关键字,表示当前对象和父类对象。 this 关键字可以用于以下几种情况: 引用当前对象的成员变量,方法和构造方法,用于区分局部变量和成员变量重名的情况; 调用当前类的另外一…...

K8s:Monokle Desktop 一个集Yaml资源编写、项目管理、集群管理的 K8s IDE

写在前面 Monokle Desktop 是 kubeshop 推出的一个开源的 K8s IDE相关项目还有 Monokle CLI 和 Monokle Cloud相比其他的工具,Monokle Desktop 功能较全面,涉及 k8s 管理的整个生命周期博文内容:Monokle Desktop 下载安装,项目管理…...

自动化测试实战篇(8),jmeter并发测试登录接口,模拟从100到1000个用户同时登录测试服务器压力

首先进行使用jmeter进行并发测试之前就需要搞清楚线程和进程的区别还需要理解什么是并发、高并发、并行。还需要理解高并发中的以及老生常谈的,TCP三次握手协议和TCP四次握手协议**TCP三次握手协议指:****TCP四次挥手协议:**进入Jmeter&#…...

ATTCK v12版本战术实战研究—持久化(二)

一、前言前几期文章中,我们介绍了ATT&CK中侦察、资源开发、初始访问、执行战术、持久化战术的知识。那么从前文中介绍的相关持久化子技术来开展测试,进行更深一步的分析。本文主要内容是介绍攻击者在运用持久化子技术时,在相关的资产服务…...

python函数式编程

1 callable内建函数判断一个名字是否为一个可调用函数 >>> import math >>> x 1 >>> y math.sqrt >>> callable(x) False >>> callable(y) True 2 记录函数(文档字符串) >>> def square(x): …...

3.linux下安装mysql

1.安装前的环境准备 查看是否安装过mysql 首先检测Linux操作系统中是否安装了MySQL: # rpm -qa | grep -i mysql 卸载安装包 如果有信息出现,则进行删除,命令如下: # rpm -e --nodeps 包名 删除老版本mysql的开发头文件和…...

17、MySQL分库分表,原理实战

MySQL分库分表,原理实战 1.MyCAT分布式架构入门及双主架构1.1 主从架构1.2 MyCAT安装1.3 启动和连接1.4 配置文件介绍2.MyCAT读写分离架构2.1 架构说明2.2 创建用户2.3 schema.xml2.4 连接说明2.5 读写测试2.6 当前是单节点3.MyCAT高可用读写分离架构3.1 架构说明3.3 schema.xm…...

【C++的OpenCV】第九课-OpenCV图像常用操作(六):图像形态学-阈值的概念、功能及操作(threshold()函数))

目录一、阈值(thresh)的概念二、阈值在图形学中的用途三、阈值的作用和操作3.1 在OpenCV中可以进行的阈值操作3.2 操作实例3.2.1 threshold()函数介绍3.2.2 实例3.2.3 结果上节课的内容(作者还是鼓励各位同学按照顺序进行学习哦)&…...

[Java代码审计]—MCMS

环境搭建 MCMS 5.2.4:https://gitee.com/mingSoft/MCMS/tree/5.2.4/利用 idea 打开项目 创建数据库 mcms,导入 doc/mcms-5.2.8.sql 修改 src/main/resources/application-dev.yml 中关于数据库设置参数 启动项目登录后台 http://localhost:8080/ms/l…...

《程序员面试金典(第6版)》面试题 01.08. 零矩阵

题目描述 编写代码,移除未排序链表中的重复节点。保留最开始出现的节点。 示例1: 输入:[1, 2, 3, 3, 2, 1] 输出:[1, 2, 3] -示例2: 输入:[1, 1, 1, 1, 2] 输出:[1, 2] 提示: 链表长度在[0, 20000]范…...

初识 Python

文章目录简介用途解释器命令行模式交互模式输入和输出简介 高级编程语言,解释型语言代码在执行时会逐行翻译成 CPU 能理解的机器码代码精简,但运行速度慢基础代码库丰富,还有大量第三方库代码不能加密 用途 网络应用工具软件包装其他语言开…...

常用sql语句分享

SELECT COUNT(DISTINCT money) FROM ac_association_course;#COUNT() 函数返回匹配指定条件的行数SELECT AVG(money) FROM ac_association_course;#AVG 函数返回数值列的平均值。NULL 值不包括在计算中SELECT id FROM ac_association_course order by id desc limit 1;#返回最大…...

极狐GitLab DevSecOps 为企业许可证安全合规保驾护航

本文来自: 小马哥 极狐(GitLab) 技术布道师 开源许可证是开源软件的法律武器,是第三方正确使用开源软件的安全合规依据。 根据 Linux 发布的 SBOM 报告显示,98% 的企业都在使用开源软件(中文版报告详情)。随着开源使用…...

后端程序员的前端基础-前端三剑客之HTML

文章目录1 HTML简介1.1 什么是HTML1.2 HTML能做什么1.3 HTML书写规范2 HTML基本标签2.1 结构标签2.2 排版标签2.3 块标签2.4 基本文字标签2.5 文本格式化标签2.6 标题标签2.7 列表标签(清单标签)2.8 图片标签2.9 链接标签2.10 表格标签3 HTML表单标签3.1 form元素常用属性3.2 i…...

VS2019加载解决方案时不能自动打开之前的文档(回忆消失)

✏️作者:枫霜剑客 📋系列专栏:C实战宝典 🌲上一篇: 错误error c3861 :“_T“:找不到标识符 逐梦编程,让中华屹立世界之巅。 简单的事情重复做,重复的事情用心做,用心的事情坚持做; 文章目录前言一、问题描…...

ConcurrentHashMap-Java八股面试(五)

系列文章目录 第一章 ArrayList-Java八股面试(一) 第二章 HashMap-Java八股面试(二) 第三章 单例模式-Java八股面试(三) 第四章 线程池和Volatile关键字-Java八股面试(四) 提示:动态每日更新算法题,想要学习的可以关注一下 文章目录系列文章目录一、…...

互联网衰退期,测试工程师35岁的路该怎么走...

国内的互联网行业发展较快,所以造成了技术研发类员工工作强度比较大,同时技术的快速更新又需要员工不断的学习新的技术。因此淘汰率也比较高,超过35岁的基层研发类员工,往往因为家庭原因、身体原因,比较难以跟得上工作…...

Windows Cannot Initialize Data Bindings 问题的解决方法

前言 拿到一个调试程序, 怎么折腾都打不开, 在客户那边, 尝试了几个系统版本, 发现Windows 10 21H2 版本可以正常运行。 尝试 系统篇 系统结果公司电脑 Windows 8有问题…下载安装 Windows10 22H2问题依旧下载安装 Windows10 21H2问题依旧家里的 笔记本Window 11正常 网上…...

SpringBoot-17-MyBatis动态SQL标签之常用标签

文章目录 1 代码1.1 实体User.java1.2 接口UserMapper.java1.3 映射UserMapper.xml1.3.1 标签if1.3.2 标签if和where1.3.3 标签choose和when和otherwise1.4 UserController.java2 常用动态SQL标签2.1 标签set2.1.1 UserMapper.java2.1.2 UserMapper.xml2.1.3 UserController.ja…...

web vue 项目 Docker化部署

Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage)&#xff1a…...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

脑机新手指南(八):OpenBCI_GUI:从环境搭建到数据可视化(下)

一、数据处理与分析实战 (一)实时滤波与参数调整 基础滤波操作 60Hz 工频滤波:勾选界面右侧 “60Hz” 复选框,可有效抑制电网干扰(适用于北美地区,欧洲用户可调整为 50Hz)。 平滑处理&…...

蓝桥杯 2024 15届国赛 A组 儿童节快乐

P10576 [蓝桥杯 2024 国 A] 儿童节快乐 题目描述 五彩斑斓的气球在蓝天下悠然飘荡,轻快的音乐在耳边持续回荡,小朋友们手牵着手一同畅快欢笑。在这样一片安乐祥和的氛围下,六一来了。 今天是六一儿童节,小蓝老师为了让大家在节…...

Nginx server_name 配置说明

Nginx 是一个高性能的反向代理和负载均衡服务器,其核心配置之一是 server 块中的 server_name 指令。server_name 决定了 Nginx 如何根据客户端请求的 Host 头匹配对应的虚拟主机(Virtual Host)。 1. 简介 Nginx 使用 server_name 指令来确定…...

sqlserver 根据指定字符 解析拼接字符串

DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...

docker 部署发现spring.profiles.active 问题

报错&#xff1a; org.springframework.boot.context.config.InvalidConfigDataPropertyException: Property spring.profiles.active imported from location class path resource [application-test.yml] is invalid in a profile specific resource [origin: class path re…...

C#学习第29天:表达式树(Expression Trees)

目录 什么是表达式树&#xff1f; 核心概念 1.表达式树的构建 2. 表达式树与Lambda表达式 3.解析和访问表达式树 4.动态条件查询 表达式树的优势 1.动态构建查询 2.LINQ 提供程序支持&#xff1a; 3.性能优化 4.元数据处理 5.代码转换和重写 适用场景 代码复杂性…...

4. TypeScript 类型推断与类型组合

一、类型推断 (一) 什么是类型推断 TypeScript 的类型推断会根据变量、函数返回值、对象和数组的赋值和使用方式&#xff0c;自动确定它们的类型。 这一特性减少了显式类型注解的需要&#xff0c;在保持类型安全的同时简化了代码。通过分析上下文和初始值&#xff0c;TypeSc…...