pytorch入门7--自动求导和神经网络
深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。
下面是课程笔记。
一、自动求导
举例说明自动求导。
torch中的张量有两个重要属性:data(值)和grad(梯度),当我们在定义一个张量时设requires_grad=True就是说明后续可以使用自动求导机制。
注意:pytorch里可以设置为自动求导的张量的元素需要是浮点型。
例如,对于e=(a + b) * (b + 1),可以用一个图表示如下:
我们定义张量时通常是从下往上定义,即先定义张量a,b,再定义张量e(由张量a,b的关系式组成),这样张量e的值就由a,b得到,这就是前向传播(前馈),通常定义为forward函数:
当我们要进行求导时,求:
可以看出,求导是从上到下的,逐级相乘再将路径相加。比如求e对b的偏导数,,从b到e有两条路径,每条路径从e开始逐级求导,结果相乘再将多条路径求导结果相加,这个过程加反向传递(反馈),通过pytorch封装好的backward函数实现。
下面的图比我手绘的应该清楚一些:
代码实现如下(以线性逻辑回归为例,y = w * x,给定训练数据集x,y,求最佳参数w拟合x与y的关系函数):
我们直到在深度学习中,我们都是将损失函数对参数求导,使用梯度下降法等方法使得损失函数最好,从而找到参数的最佳值。
# 训练数据集(人眼可以一下看出y=2*x是最好的拟合,但机器不知道,要一直训练
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0] # 参数w,初始值设为1.0
w = torch.tensor([1.0],requires_grad=True)# 前向传播
def forward(x):return x * w# 损失函数
def loss(x,y):y_pred = forward(x) # y_pred即为常说的y_hat,是y在当前w的值下计算的估计值,这里即建立了y_pred与w的关系,可以自动求导return (y_pred - y) ** 2# 训练数据集(梯度下降法)
for epoch in range(1000):for x,y in zip(x_data,y_data):l = loss(x,y)l.backward() # 自动求导,l对w求导,反向传播print('\tgrad:',x,y,w,w.grad.item()) # item()用于只含一个元素的tensor中提取值w.data = w.data - 0.01 * w.grad.data # **这里使用data属性就是为了防止使用自动求导机制**w.grad.data.zero_() # 将上一轮的梯度值清除print("progress:",epoch,l.item())# 测试结果
print("predict(after training)",4,forward(4).item()) # 计算当x=4时,根据训练出的模型求y的估计值
可以看出w的值一直在增加,直到加到2可以完全拟合训练集中x与y的关系,最后当x等于4时, 估计值接近8.
二、神经网络
我们知道,神经网络由多层组成,包括输入层、隐含层和输出层,每一层的包含不同个数的结点,每层的结点其实就是当前我们获得的数据的特征值(features),例如输入层(x1,x2,x3,x4,x5)有五个结点分别表示五个特征值,第一层的隐藏层有六个结点,这是就需要一个6 * 5的矩阵w将x的5个特征值转变为6个特征值。当然也可以添加偏置值b如下图所示:
而这个矩阵w就是我们要训练出包含着某种关系的参数矩阵,再一层一层的变换,每层都有一个参数矩阵,最终到达输出矩阵的四个特征,即(y1,y2,y3,y4)。
为了是我们的神经网络模型更好地拟合非线性函数关系,还可以使用激活函数:
激活函数前面的文章讲过,这里不再说了。使用sigmoid激活函数如下:
代码实现:
1.单隐藏层的神经网络模型
pytorch对于神经网络的代码封装得很好。
# 训练数据集
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])# 定义一个单隐藏层得神经网络
class LinearModel(torch.nn.Module):# 神经网络的类必须继承类Moduledef __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1,1) # torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1。# torch.nn.Linear()第三个参数是bias,设置为True即含有偏置值b,为False不适用偏置值,默认值为True。def forward(self,x):y_pred = self.linear(x) # 使用封装好的linear()计算y的预测值return y_pred# 生成神经网络的模型
model = LinearModel()# 损失函数
criterion = torch.nn.MSELoss(size_average=False) # size_average=False表示损失函数不求平均值# 优化器(梯度下降)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) # model.parameters()可以获取神经网络中的所有参数参数矩阵的值,对其进行优化
# lr表示步长# 训练数据集
for epoch in range(100):y_pred = model(x_data) # 将数据传入搭建好的神经网络模型得到估计值loss = criterion(y_pred,y_data) # 计算损失值print(epoch,loss)optimizer.zero_grad() # 清除上次的梯度值loss.backward() # 自动求导optimizer.step() # 优化参数# 输出结果
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())# 测试模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred=',y_test.data)
说明:这里的torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1,其它情况使用情况如下:
可以看出,训练100轮效果不佳,可以训练1000次看看不同结果。
2.多隐藏层的神经网络模型
与单隐藏层神经网络模型区别如下:
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6) # 模型从8维变为6维,再从6维变为4维,再从4维变为1维self.linear2 = torch.nn.Linear(6,4)self.linear1 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid() # 使用sigmoid激活函数def forward(self,x):pred1 = self.sigmoid(self.linear1(x)) # 上一层输出结果传给下一层pred2 = self.sigmoid(self.linear2(pred1))y_pred = self.sigmoid(self.linear3(pred2))return nmodel = Model()
相关文章:

pytorch入门7--自动求导和神经网络
深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。 下面是课程笔记。 一、自动求导 举例说明自动求导。 torch中的…...
QT 之wayland 事件处理分析基于qt5wayland5.14.2
1. Qt wayland 初始化 接收鼠标/案件,触摸屏等事件事件 QWaylandNativeInterface : public QPlatformNativeInterface 在QWaylandNativeInterface 继承qpa 接口类QPlatformNativeInterface; 1.1 初始化鼠标: void *QWaylandNativeInterface::nativeR…...
【this 和 super 的区别】
在 Java 中,this 和 super 都是关键字,表示当前对象和父类对象。 this 关键字可以用于以下几种情况: 引用当前对象的成员变量,方法和构造方法,用于区分局部变量和成员变量重名的情况; 调用当前类的另外一…...

K8s:Monokle Desktop 一个集Yaml资源编写、项目管理、集群管理的 K8s IDE
写在前面 Monokle Desktop 是 kubeshop 推出的一个开源的 K8s IDE相关项目还有 Monokle CLI 和 Monokle Cloud相比其他的工具,Monokle Desktop 功能较全面,涉及 k8s 管理的整个生命周期博文内容:Monokle Desktop 下载安装,项目管理…...

自动化测试实战篇(8),jmeter并发测试登录接口,模拟从100到1000个用户同时登录测试服务器压力
首先进行使用jmeter进行并发测试之前就需要搞清楚线程和进程的区别还需要理解什么是并发、高并发、并行。还需要理解高并发中的以及老生常谈的,TCP三次握手协议和TCP四次握手协议**TCP三次握手协议指:****TCP四次挥手协议:**进入Jmeter&#…...

ATTCK v12版本战术实战研究—持久化(二)
一、前言前几期文章中,我们介绍了ATT&CK中侦察、资源开发、初始访问、执行战术、持久化战术的知识。那么从前文中介绍的相关持久化子技术来开展测试,进行更深一步的分析。本文主要内容是介绍攻击者在运用持久化子技术时,在相关的资产服务…...
python函数式编程
1 callable内建函数判断一个名字是否为一个可调用函数 >>> import math >>> x 1 >>> y math.sqrt >>> callable(x) False >>> callable(y) True 2 记录函数(文档字符串) >>> def square(x): …...
3.linux下安装mysql
1.安装前的环境准备 查看是否安装过mysql 首先检测Linux操作系统中是否安装了MySQL: # rpm -qa | grep -i mysql 卸载安装包 如果有信息出现,则进行删除,命令如下: # rpm -e --nodeps 包名 删除老版本mysql的开发头文件和…...
17、MySQL分库分表,原理实战
MySQL分库分表,原理实战 1.MyCAT分布式架构入门及双主架构1.1 主从架构1.2 MyCAT安装1.3 启动和连接1.4 配置文件介绍2.MyCAT读写分离架构2.1 架构说明2.2 创建用户2.3 schema.xml2.4 连接说明2.5 读写测试2.6 当前是单节点3.MyCAT高可用读写分离架构3.1 架构说明3.3 schema.xm…...

【C++的OpenCV】第九课-OpenCV图像常用操作(六):图像形态学-阈值的概念、功能及操作(threshold()函数))
目录一、阈值(thresh)的概念二、阈值在图形学中的用途三、阈值的作用和操作3.1 在OpenCV中可以进行的阈值操作3.2 操作实例3.2.1 threshold()函数介绍3.2.2 实例3.2.3 结果上节课的内容(作者还是鼓励各位同学按照顺序进行学习哦)&…...

[Java代码审计]—MCMS
环境搭建 MCMS 5.2.4:https://gitee.com/mingSoft/MCMS/tree/5.2.4/利用 idea 打开项目 创建数据库 mcms,导入 doc/mcms-5.2.8.sql 修改 src/main/resources/application-dev.yml 中关于数据库设置参数 启动项目登录后台 http://localhost:8080/ms/l…...
《程序员面试金典(第6版)》面试题 01.08. 零矩阵
题目描述 编写代码,移除未排序链表中的重复节点。保留最开始出现的节点。 示例1: 输入:[1, 2, 3, 3, 2, 1] 输出:[1, 2, 3] -示例2: 输入:[1, 1, 1, 1, 2] 输出:[1, 2] 提示: 链表长度在[0, 20000]范…...
初识 Python
文章目录简介用途解释器命令行模式交互模式输入和输出简介 高级编程语言,解释型语言代码在执行时会逐行翻译成 CPU 能理解的机器码代码精简,但运行速度慢基础代码库丰富,还有大量第三方库代码不能加密 用途 网络应用工具软件包装其他语言开…...
常用sql语句分享
SELECT COUNT(DISTINCT money) FROM ac_association_course;#COUNT() 函数返回匹配指定条件的行数SELECT AVG(money) FROM ac_association_course;#AVG 函数返回数值列的平均值。NULL 值不包括在计算中SELECT id FROM ac_association_course order by id desc limit 1;#返回最大…...

极狐GitLab DevSecOps 为企业许可证安全合规保驾护航
本文来自: 小马哥 极狐(GitLab) 技术布道师 开源许可证是开源软件的法律武器,是第三方正确使用开源软件的安全合规依据。 根据 Linux 发布的 SBOM 报告显示,98% 的企业都在使用开源软件(中文版报告详情)。随着开源使用…...
后端程序员的前端基础-前端三剑客之HTML
文章目录1 HTML简介1.1 什么是HTML1.2 HTML能做什么1.3 HTML书写规范2 HTML基本标签2.1 结构标签2.2 排版标签2.3 块标签2.4 基本文字标签2.5 文本格式化标签2.6 标题标签2.7 列表标签(清单标签)2.8 图片标签2.9 链接标签2.10 表格标签3 HTML表单标签3.1 form元素常用属性3.2 i…...

VS2019加载解决方案时不能自动打开之前的文档(回忆消失)
✏️作者:枫霜剑客 📋系列专栏:C实战宝典 🌲上一篇: 错误error c3861 :“_T“:找不到标识符 逐梦编程,让中华屹立世界之巅。 简单的事情重复做,重复的事情用心做,用心的事情坚持做; 文章目录前言一、问题描…...

ConcurrentHashMap-Java八股面试(五)
系列文章目录 第一章 ArrayList-Java八股面试(一) 第二章 HashMap-Java八股面试(二) 第三章 单例模式-Java八股面试(三) 第四章 线程池和Volatile关键字-Java八股面试(四) 提示:动态每日更新算法题,想要学习的可以关注一下 文章目录系列文章目录一、…...

互联网衰退期,测试工程师35岁的路该怎么走...
国内的互联网行业发展较快,所以造成了技术研发类员工工作强度比较大,同时技术的快速更新又需要员工不断的学习新的技术。因此淘汰率也比较高,超过35岁的基层研发类员工,往往因为家庭原因、身体原因,比较难以跟得上工作…...

Windows Cannot Initialize Data Bindings 问题的解决方法
前言 拿到一个调试程序, 怎么折腾都打不开, 在客户那边, 尝试了几个系统版本, 发现Windows 10 21H2 版本可以正常运行。 尝试 系统篇 系统结果公司电脑 Windows 8有问题…下载安装 Windows10 22H2问题依旧下载安装 Windows10 21H2问题依旧家里的 笔记本Window 11正常 网上…...
反向工程与模型迁移:打造未来商品详情API的可持续创新体系
在电商行业蓬勃发展的当下,商品详情API作为连接电商平台与开发者、商家及用户的关键纽带,其重要性日益凸显。传统商品详情API主要聚焦于商品基本信息(如名称、价格、库存等)的获取与展示,已难以满足市场对个性化、智能…...
Spring Boot 实现流式响应(兼容 2.7.x)
在实际开发中,我们可能会遇到一些流式数据处理的场景,比如接收来自上游接口的 Server-Sent Events(SSE) 或 流式 JSON 内容,并将其原样中转给前端页面或客户端。这种情况下,传统的 RestTemplate 缓存机制会…...
Oracle查询表空间大小
1 查询数据库中所有的表空间以及表空间所占空间的大小 SELECTtablespace_name,sum( bytes ) / 1024 / 1024 FROMdba_data_files GROUP BYtablespace_name; 2 Oracle查询表空间大小及每个表所占空间的大小 SELECTtablespace_name,file_id,file_name,round( bytes / ( 1024 …...
【磁盘】每天掌握一个Linux命令 - iostat
目录 【磁盘】每天掌握一个Linux命令 - iostat工具概述安装方式核心功能基础用法进阶操作实战案例面试题场景生产场景 注意事项 【磁盘】每天掌握一个Linux命令 - iostat 工具概述 iostat(I/O Statistics)是Linux系统下用于监视系统输入输出设备和CPU使…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)
UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中,UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化…...
稳定币的深度剖析与展望
一、引言 在当今数字化浪潮席卷全球的时代,加密货币作为一种新兴的金融现象,正以前所未有的速度改变着我们对传统货币和金融体系的认知。然而,加密货币市场的高度波动性却成为了其广泛应用和普及的一大障碍。在这样的背景下,稳定…...

【7色560页】职场可视化逻辑图高级数据分析PPT模版
7种色调职场工作汇报PPT,橙蓝、黑红、红蓝、蓝橙灰、浅蓝、浅绿、深蓝七种色调模版 【7色560页】职场可视化逻辑图高级数据分析PPT模版:职场可视化逻辑图分析PPT模版https://pan.quark.cn/s/78aeabbd92d1...
苹果AI眼镜:从“工具”到“社交姿态”的范式革命——重新定义AI交互入口的未来机会
在2025年的AI硬件浪潮中,苹果AI眼镜(Apple Glasses)正在引发一场关于“人机交互形态”的深度思考。它并非简单地替代AirPods或Apple Watch,而是开辟了一个全新的、日常可接受的AI入口。其核心价值不在于功能的堆叠,而在于如何通过形态设计打破社交壁垒,成为用户“全天佩戴…...

抽象类和接口(全)
一、抽象类 1.概念:如果⼀个类中没有包含⾜够的信息来描绘⼀个具体的对象,这样的类就是抽象类。 像是没有实际⼯作的⽅法,我们可以把它设计成⼀个抽象⽅法,包含抽象⽅法的类我们称为抽象类。 2.语法 在Java中,⼀个类如果被 abs…...

高考志愿填报管理系统---开发介绍
高考志愿填报管理系统是一款专为教育机构、学校和教师设计的学生信息管理和志愿填报辅助平台。系统基于Django框架开发,采用现代化的Web技术,为教育工作者提供高效、安全、便捷的学生管理解决方案。 ## 📋 系统概述 ### 🎯 系统定…...