pytorch入门7--自动求导和神经网络
深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。
下面是课程笔记。
一、自动求导
举例说明自动求导。
torch中的张量有两个重要属性:data(值)和grad(梯度),当我们在定义一个张量时设requires_grad=True就是说明后续可以使用自动求导机制。
注意:pytorch里可以设置为自动求导的张量的元素需要是浮点型。
例如,对于e=(a + b) * (b + 1),可以用一个图表示如下:
我们定义张量时通常是从下往上定义,即先定义张量a,b,再定义张量e(由张量a,b的关系式组成),这样张量e的值就由a,b得到,这就是前向传播(前馈),通常定义为forward函数:
当我们要进行求导时,求:
可以看出,求导是从上到下的,逐级相乘再将路径相加。比如求e对b的偏导数,,从b到e有两条路径,每条路径从e开始逐级求导,结果相乘再将多条路径求导结果相加,这个过程加反向传递(反馈),通过pytorch封装好的backward函数实现。
下面的图比我手绘的应该清楚一些:
代码实现如下(以线性逻辑回归为例,y = w * x,给定训练数据集x,y,求最佳参数w拟合x与y的关系函数):
我们直到在深度学习中,我们都是将损失函数对参数求导,使用梯度下降法等方法使得损失函数最好,从而找到参数的最佳值。
# 训练数据集(人眼可以一下看出y=2*x是最好的拟合,但机器不知道,要一直训练
x_data = [1.0,2.0,3.0]
y_data = [2.0,4.0,6.0] # 参数w,初始值设为1.0
w = torch.tensor([1.0],requires_grad=True)# 前向传播
def forward(x):return x * w# 损失函数
def loss(x,y):y_pred = forward(x) # y_pred即为常说的y_hat,是y在当前w的值下计算的估计值,这里即建立了y_pred与w的关系,可以自动求导return (y_pred - y) ** 2# 训练数据集(梯度下降法)
for epoch in range(1000):for x,y in zip(x_data,y_data):l = loss(x,y)l.backward() # 自动求导,l对w求导,反向传播print('\tgrad:',x,y,w,w.grad.item()) # item()用于只含一个元素的tensor中提取值w.data = w.data - 0.01 * w.grad.data # **这里使用data属性就是为了防止使用自动求导机制**w.grad.data.zero_() # 将上一轮的梯度值清除print("progress:",epoch,l.item())# 测试结果
print("predict(after training)",4,forward(4).item()) # 计算当x=4时,根据训练出的模型求y的估计值
可以看出w的值一直在增加,直到加到2可以完全拟合训练集中x与y的关系,最后当x等于4时, 估计值接近8.
二、神经网络
我们知道,神经网络由多层组成,包括输入层、隐含层和输出层,每一层的包含不同个数的结点,每层的结点其实就是当前我们获得的数据的特征值(features),例如输入层(x1,x2,x3,x4,x5)有五个结点分别表示五个特征值,第一层的隐藏层有六个结点,这是就需要一个6 * 5的矩阵w将x的5个特征值转变为6个特征值。当然也可以添加偏置值b如下图所示:
而这个矩阵w就是我们要训练出包含着某种关系的参数矩阵,再一层一层的变换,每层都有一个参数矩阵,最终到达输出矩阵的四个特征,即(y1,y2,y3,y4)。
为了是我们的神经网络模型更好地拟合非线性函数关系,还可以使用激活函数:
激活函数前面的文章讲过,这里不再说了。使用sigmoid激活函数如下:
代码实现:
1.单隐藏层的神经网络模型
pytorch对于神经网络的代码封装得很好。
# 训练数据集
x_data = torch.tensor([[1.0],[2.0],[3.0]])
y_data = torch.tensor([[2.0],[4.0],[6.0]])# 定义一个单隐藏层得神经网络
class LinearModel(torch.nn.Module):# 神经网络的类必须继承类Moduledef __init__(self):super(LinearModel, self).__init__()self.linear = torch.nn.Linear(1,1) # torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1。# torch.nn.Linear()第三个参数是bias,设置为True即含有偏置值b,为False不适用偏置值,默认值为True。def forward(self,x):y_pred = self.linear(x) # 使用封装好的linear()计算y的预测值return y_pred# 生成神经网络的模型
model = LinearModel()# 损失函数
criterion = torch.nn.MSELoss(size_average=False) # size_average=False表示损失函数不求平均值# 优化器(梯度下降)
optimizer = torch.optim.SGD(model.parameters(),lr=0.01) # model.parameters()可以获取神经网络中的所有参数参数矩阵的值,对其进行优化
# lr表示步长# 训练数据集
for epoch in range(100):y_pred = model(x_data) # 将数据传入搭建好的神经网络模型得到估计值loss = criterion(y_pred,y_data) # 计算损失值print(epoch,loss)optimizer.zero_grad() # 清除上次的梯度值loss.backward() # 自动求导optimizer.step() # 优化参数# 输出结果
print('w=',model.linear.weight.item())
print('b=',model.linear.bias.item())# 测试模型
x_test = torch.tensor([4.0])
y_test = model(x_test)
print('y_pred=',y_test.data)
说明:这里的torch.nn.Linear(1,1)表示该神经网络处理的是n * 1的输入,输出也是n * 1,其它情况使用情况如下:
可以看出,训练100轮效果不佳,可以训练1000次看看不同结果。
2.多隐藏层的神经网络模型
与单隐藏层神经网络模型区别如下:
class Model(torch.nn.Module):def __init__(self):super(Model,self).__init__()self.linear1 = torch.nn.Linear(8,6) # 模型从8维变为6维,再从6维变为4维,再从4维变为1维self.linear2 = torch.nn.Linear(6,4)self.linear1 = torch.nn.Linear(4,1)self.sigmoid = torch.nn.Sigmoid() # 使用sigmoid激活函数def forward(self,x):pred1 = self.sigmoid(self.linear1(x)) # 上一层输出结果传给下一层pred2 = self.sigmoid(self.linear2(pred1))y_pred = self.sigmoid(self.linear3(pred2))return nmodel = Model()
相关文章:
pytorch入门7--自动求导和神经网络
深度学习网上自学学了10多天了,看了很多大神的课总是很快被劝退。终于,遇到了一位对小白友好的刘二大人,先附上链接,需要者自取:https://b23.tv/RHlDxbc。 下面是课程笔记。 一、自动求导 举例说明自动求导。 torch中的…...
QT 之wayland 事件处理分析基于qt5wayland5.14.2
1. Qt wayland 初始化 接收鼠标/案件,触摸屏等事件事件 QWaylandNativeInterface : public QPlatformNativeInterface 在QWaylandNativeInterface 继承qpa 接口类QPlatformNativeInterface; 1.1 初始化鼠标: void *QWaylandNativeInterface::nativeR…...
【this 和 super 的区别】
在 Java 中,this 和 super 都是关键字,表示当前对象和父类对象。 this 关键字可以用于以下几种情况: 引用当前对象的成员变量,方法和构造方法,用于区分局部变量和成员变量重名的情况; 调用当前类的另外一…...
K8s:Monokle Desktop 一个集Yaml资源编写、项目管理、集群管理的 K8s IDE
写在前面 Monokle Desktop 是 kubeshop 推出的一个开源的 K8s IDE相关项目还有 Monokle CLI 和 Monokle Cloud相比其他的工具,Monokle Desktop 功能较全面,涉及 k8s 管理的整个生命周期博文内容:Monokle Desktop 下载安装,项目管理…...
自动化测试实战篇(8),jmeter并发测试登录接口,模拟从100到1000个用户同时登录测试服务器压力
首先进行使用jmeter进行并发测试之前就需要搞清楚线程和进程的区别还需要理解什么是并发、高并发、并行。还需要理解高并发中的以及老生常谈的,TCP三次握手协议和TCP四次握手协议**TCP三次握手协议指:****TCP四次挥手协议:**进入Jmeter&#…...
ATTCK v12版本战术实战研究—持久化(二)
一、前言前几期文章中,我们介绍了ATT&CK中侦察、资源开发、初始访问、执行战术、持久化战术的知识。那么从前文中介绍的相关持久化子技术来开展测试,进行更深一步的分析。本文主要内容是介绍攻击者在运用持久化子技术时,在相关的资产服务…...
python函数式编程
1 callable内建函数判断一个名字是否为一个可调用函数 >>> import math >>> x 1 >>> y math.sqrt >>> callable(x) False >>> callable(y) True 2 记录函数(文档字符串) >>> def square(x): …...
3.linux下安装mysql
1.安装前的环境准备 查看是否安装过mysql 首先检测Linux操作系统中是否安装了MySQL: # rpm -qa | grep -i mysql 卸载安装包 如果有信息出现,则进行删除,命令如下: # rpm -e --nodeps 包名 删除老版本mysql的开发头文件和…...
17、MySQL分库分表,原理实战
MySQL分库分表,原理实战 1.MyCAT分布式架构入门及双主架构1.1 主从架构1.2 MyCAT安装1.3 启动和连接1.4 配置文件介绍2.MyCAT读写分离架构2.1 架构说明2.2 创建用户2.3 schema.xml2.4 连接说明2.5 读写测试2.6 当前是单节点3.MyCAT高可用读写分离架构3.1 架构说明3.3 schema.xm…...
【C++的OpenCV】第九课-OpenCV图像常用操作(六):图像形态学-阈值的概念、功能及操作(threshold()函数))
目录一、阈值(thresh)的概念二、阈值在图形学中的用途三、阈值的作用和操作3.1 在OpenCV中可以进行的阈值操作3.2 操作实例3.2.1 threshold()函数介绍3.2.2 实例3.2.3 结果上节课的内容(作者还是鼓励各位同学按照顺序进行学习哦)&…...
[Java代码审计]—MCMS
环境搭建 MCMS 5.2.4:https://gitee.com/mingSoft/MCMS/tree/5.2.4/利用 idea 打开项目 创建数据库 mcms,导入 doc/mcms-5.2.8.sql 修改 src/main/resources/application-dev.yml 中关于数据库设置参数 启动项目登录后台 http://localhost:8080/ms/l…...
《程序员面试金典(第6版)》面试题 01.08. 零矩阵
题目描述 编写代码,移除未排序链表中的重复节点。保留最开始出现的节点。 示例1: 输入:[1, 2, 3, 3, 2, 1] 输出:[1, 2, 3] -示例2: 输入:[1, 1, 1, 1, 2] 输出:[1, 2] 提示: 链表长度在[0, 20000]范…...
初识 Python
文章目录简介用途解释器命令行模式交互模式输入和输出简介 高级编程语言,解释型语言代码在执行时会逐行翻译成 CPU 能理解的机器码代码精简,但运行速度慢基础代码库丰富,还有大量第三方库代码不能加密 用途 网络应用工具软件包装其他语言开…...
常用sql语句分享
SELECT COUNT(DISTINCT money) FROM ac_association_course;#COUNT() 函数返回匹配指定条件的行数SELECT AVG(money) FROM ac_association_course;#AVG 函数返回数值列的平均值。NULL 值不包括在计算中SELECT id FROM ac_association_course order by id desc limit 1;#返回最大…...
极狐GitLab DevSecOps 为企业许可证安全合规保驾护航
本文来自: 小马哥 极狐(GitLab) 技术布道师 开源许可证是开源软件的法律武器,是第三方正确使用开源软件的安全合规依据。 根据 Linux 发布的 SBOM 报告显示,98% 的企业都在使用开源软件(中文版报告详情)。随着开源使用…...
后端程序员的前端基础-前端三剑客之HTML
文章目录1 HTML简介1.1 什么是HTML1.2 HTML能做什么1.3 HTML书写规范2 HTML基本标签2.1 结构标签2.2 排版标签2.3 块标签2.4 基本文字标签2.5 文本格式化标签2.6 标题标签2.7 列表标签(清单标签)2.8 图片标签2.9 链接标签2.10 表格标签3 HTML表单标签3.1 form元素常用属性3.2 i…...
VS2019加载解决方案时不能自动打开之前的文档(回忆消失)
✏️作者:枫霜剑客 📋系列专栏:C实战宝典 🌲上一篇: 错误error c3861 :“_T“:找不到标识符 逐梦编程,让中华屹立世界之巅。 简单的事情重复做,重复的事情用心做,用心的事情坚持做; 文章目录前言一、问题描…...
ConcurrentHashMap-Java八股面试(五)
系列文章目录 第一章 ArrayList-Java八股面试(一) 第二章 HashMap-Java八股面试(二) 第三章 单例模式-Java八股面试(三) 第四章 线程池和Volatile关键字-Java八股面试(四) 提示:动态每日更新算法题,想要学习的可以关注一下 文章目录系列文章目录一、…...
互联网衰退期,测试工程师35岁的路该怎么走...
国内的互联网行业发展较快,所以造成了技术研发类员工工作强度比较大,同时技术的快速更新又需要员工不断的学习新的技术。因此淘汰率也比较高,超过35岁的基层研发类员工,往往因为家庭原因、身体原因,比较难以跟得上工作…...
Windows Cannot Initialize Data Bindings 问题的解决方法
前言 拿到一个调试程序, 怎么折腾都打不开, 在客户那边, 尝试了几个系统版本, 发现Windows 10 21H2 版本可以正常运行。 尝试 系统篇 系统结果公司电脑 Windows 8有问题…下载安装 Windows10 22H2问题依旧下载安装 Windows10 21H2问题依旧家里的 笔记本Window 11正常 网上…...
Leetcode每日一题 1487. 保证文件名唯一
Halo,这里是Ppeua。平时主要更新C语言,C,数据结构算法......感兴趣就关注我吧!你定不会失望。 🌈个人主页:主页链接 🌈算法专栏:专栏链接 我会一直往里填充内容哒! &…...
Linux常用命令——lsusb命令
在线Linux命令查询工具(http://www.lzltool.com/LinuxCommand) lsusb 显示本机的USB设备列表信息 补充说明 lsusb命令用于显示本机的USB设备列表,以及USB设备的详细信息。 lsusb命令是一个学习USB驱动开发,认识USB设备的助手,推荐大家使用…...
Python——我愿称之为最简单的语言
Python——我愿称之为最简单的语言开发工具基础语法变量和数据类型列表和元组字典if语句while语句函数类文件与异常测试代码参考书籍:《python编程从入门到实践》 开发工具 python编程环境分为两个部分:python解释器和文本编辑器。运行.py文件时&#…...
java.io.IOException: Broken pipe
1、问题出现的场景 线上环境,拉取对账单,走的接口的形式,当天单量比较大,就出现了,拉取订单超时,报了个错java.io.IOException: Broken pipe。 2、解决方案 我们设置的超时时间是100S,由于当…...
Python——列表排序和赋值
(1)列表排序: 列表排序方法 ls.sort() 对列表ls 中的数据在原地进行排序 ls [13, 5, 73, 4, 9] ls.sort()ls.sort(reverseFalse) 默认升序,reverseTrue,降序 ls [13, 5, 73, 4, 9] ls.sort(reverseTrue)key指定排序时…...
python+pytest接口自动化(7)-cookie绕过登录(保持登录状态)
在编写接口自动化测试用例或其他脚本的过程中,经常会遇到需要绕过用户名/密码或验证码登录,去请求接口的情况,一是因为有时验证码会比较复杂,比如有些图形验证码,难以通过接口的方式去处理;再者,…...
【连接池】什么是HikariCP?HikariCP 解决了哪些问题?为什么要使用 HikariCP?
文章目录什么是连接池什么是HikariCPHikariCP 解决了哪些问题?为什么要使用 HikariCP?HikariCP 的使用Maven支持数据库什么是连接池 数据库连接池负责分配、管理和释放数据库的连接。 数据库连接复用:重复使用现有的数据库长连接࿰…...
Tapdata Cloud 基础课:新功能详解之「微信告警」,更及时的告警通知渠道
【前言】作为中国的 “Fivetran/Airbyte”, Tapdata 是一个以低延迟数据移动为核心优势构建的现代数据平台,内置 60 数据连接器,拥有稳定的实时采集和传输能力、秒级响应的数据实时计算能力、稳定易用的数据实时服务能力,以及低代码可视化操作…...
【巨人的肩膀】JAVA面试总结(四)
💪、JVM 目录💪、JVM1、说一下JVM的主要组成部分及其作用2、什么是JVM内存结构(谈谈对运行时数据区的理解)3、堆和栈的区别是什么4、堆中存什么?栈中存什么?5、为什么不把基本类型放堆中呢?6、为…...
攒了一冬的甜,米易枇杷借力新电商走出川西大山
“绿暗初迎夏,红残不及春。魏花非老伴,卢橘是乡人。”苏轼文中的卢橘,就是枇杷,在苏轼看来,相较于姚黄魏紫,来自故乡四川的枇杷更为亲近。 四川省攀枝花市米易县是全国枇杷早熟产区之一,得益于…...
企业建站报价方案/自贡网站seo
文章目录R_install参考Anaconda安装R语言解释器及第三方包安装pkgR方式【推荐使用】conda方式手动安装装载包到库中更新更新(R方式):R_install 参考 [genomicranges的学习]https://www.it610.com/article/1288783044492206080.htm 官网 A…...
制作俄语网站/南宁网站推广营销
传入一个需要比较的字符串。例如 [value compare:"********"] ,返回 NSOrderedSame。 options:(NSStringCompareOptions)传入 NSStringCompareOptions 枚举的值 enum{NSCaseInsensitiveSearch 1,//不区分大小写比较NSLiteralSearch 2,//区分大小写比较N…...
山东建设工会网站/百度搜索引擎怎么弄
一丶解释性语言和编译型语言的优缺点 编译型语言: 编译型语言最大的优势之一就是其执行速度。用C/C编写的程序运行速度要比用Java编写的相同程序快30%-70%。编译型程序比解释型程序消耗的内存更少。不利的一面——编译器比解释器要难写得多。编译器在调试程序时提…...
wordpress 自动保存图片/百度地图导航
PHP 表单验证 本章节我们将介绍如何使用PHP验证客户端提交的表单数据。 PHP 表单验证 在处理PHP表单时我们需要考虑安全性。 本章节我们将展示PHP表单数据安全处理,为了防止黑客及垃圾信息我们需要对表单进行数据安全验证。 在本章节介绍的HTML表单中包含以下输入字…...
手机网站css/招商外包公司
1.首先找一张看起来很酷的图(也可以选择自己喜欢的图片); 2. 复制图层,点击添加图层样式,选择混合选项,在高级混合里面的通道选项,有R、G、B三个通道选项,默认是全部勾选的状态&…...
怎么做网站的在线客服/百度seo优化网站
使用dbua升级时,需要手工设置CLUSTER_DATABASE参数么? 来源于: Is Manual Setting Of CLUSTER_DATABASE Parameter Required For DBUA Upgrade? (文档 ID 741081.1) 适用于: Oracle Server - Enterprise Edition - Version: 10.…...