当前位置: 首页 > news >正文

pytorch-复现经典深度学习模型-LeNet5

Neural Networks

  • 使用torch.nn包来构建神经网络。nn包依赖autograd包来定义模型并求导。 一个nn.Module包含各个层和一个forward(input)方法,该方法返回output

  • 一个简单的前馈神经网络,它接受一个输入,然后一层接着一层地传递,最后输出计算的结果。

    • 在这里插入图片描述
  • 神经网络的典型训练过程如下:

      1. 定义包含一些可学习的参数(或者叫权重)神经网络模型;
      2. 在数据集上迭代;
      3. 通过神经网络处理输入;
      4. 计算损失(输出结果和正确值的差值大小);
      5. 将梯度反向传播回网络的参数;
      6. 更新网络的参数,主要使用如下简单的更新原则: weight = weight - learning_rate * gradient

开始定义一个Lenet5网络:

  • import torch
    import torch.nn as nn
    import torch.nn.functional as F
    class Net(nn.Module):def __init__(self):super(Net, self).__init__()# 1 input image channel, 6 output channels, 5x5 square convolution# kernelself.conv1 = nn.Conv2d(1, 6, 5)self.conv2 = nn.Conv2d(6, 16, 5)# an affine operation: y = Wx + bself.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):# Max pooling over a (2, 2) windowx = F.max_pool2d(F.relu(self.conv1(x)), (2, 2))# If the size is a square you can only specify a single numberx = F.max_pool2d(F.relu(self.conv2(x)), 2)x = x.view(-1, self.num_flat_features(x))x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return xdef num_flat_features(self, x):size = x.size()[1:]  # all dimensions except the batch dimensionnum_features = 1for s in size:num_features *= sreturn num_features
    net = Net()
    print(net)
    
  • Net((conv1): Conv2d(1, 6, kernel_size=(5, 5), stride=(1, 1))(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))(fc1): Linear(in_features=400, out_features=120, bias=True)(fc2): Linear(in_features=120, out_features=84, bias=True)(fc3): Linear(in_features=84, out_features=10, bias=True)
    )
    
  • 在模型中必须要定义 forward 函数,backward 函数(用来计算梯度)会被autograd自动创建。 可以在 forward 函数中使用任何针对 Tensor 的操作。torch.nn 只支持小批量输入。整个 torch.nn 包都只支持小批量样本,而不支持单个样本。 例如,nn.Conv2d 接受一个4维的张量, 每一维分别是sSamples * nChannels * Height * Width(样本数*通道数*高*宽)。 如果你有单个样本,只需使用 input.unsqueeze(0) 来添加其它的维数net.parameters()返回可被学习的参数(权重)列表和值

  • params = list(net.parameters())
    print(len(params))
    for i in range(len(params)):print(f"第{i}层需要学习参数的参数量矩阵大小:",params[i].size())  )  
    
  • 100层需要学习参数的参数量矩阵大小: torch.Size([6, 1, 5, 5])1层需要学习参数的参数量矩阵大小: torch.Size([6])2层需要学习参数的参数量矩阵大小: torch.Size([16, 6, 5, 5])3层需要学习参数的参数量矩阵大小: torch.Size([16])4层需要学习参数的参数量矩阵大小: torch.Size([120, 400])5层需要学习参数的参数量矩阵大小: torch.Size([120])6层需要学习参数的参数量矩阵大小: torch.Size([84, 120])7层需要学习参数的参数量矩阵大小: torch.Size([84])8层需要学习参数的参数量矩阵大小: torch.Size([10, 84])9层需要学习参数的参数量矩阵大小: torch.Size([10])
    
  • 测试随机输入32×32。 注:这个网络(LeNet)期望的输入大小是32×32,如果使用MNIST数据集来训练这个网络,请把图片大小重新调整到32×32。

  • input = torch.randn(1, 1, 32, 32)
    out = net(input)
    print(out)
    
  • tensor([[ 0.0501,  0.1101, -0.0294,  0.0030,  0.0629,  0.0379,  0.0860,  0.0104,0.1108,  0.0916]], grad_fn=<AddmmBackward0>)
    
  • 将所有参数的梯度缓存清零,然后进行随机梯度的的反向传播:

  • net.zero_grad()
    out.backward(torch.randn(1, 10).data)
    print(out)
    
  • tensor([[-0.0275, -0.0630, -0.0253, -0.0811,  0.0872, -0.0282,  0.0926,  0.1020,-0.1034,  0.0727]], grad_fn=<AddmmBackward0>)
    
  • torch.Tensor:一个用过自动调用 backward()实现支持自动梯度计算的 多维数组 , 并且保存关于这个向量的梯度 w.r.t.

  • nn.Module:神经网络模块。封装参数、移动到GPU上运行、导出、加载等。

  • nn.Parameter:一种变量,当把它赋值给一个Module时,被 自动 地注册为一个参数。

  • autograd.Function:实现一个自动求导操作的前向和反向定义,每个变量操作至少创建一个函数节点,每一个Tensor的操作都回创建一个接到创建Tensor编码其历史 的函数的Function节点。

损失函数

  • 一个损失函数接受一对 (output, target) 作为输入,计算一个值来估计网络的输出和目标值相差多少。nn包中有很多不同的损失函数。 nn.MSELoss是一个比较简单的损失函数,它计算输出和目标间的均方误差, 例如:

  • output = net(input)
    target = torch.randn(10)  # 随机值作为样例
    target = target.view(1, -1)  # 使target和output的shape相同
    criterion = nn.MSELoss()
    loss = criterion(output, target)
    print(loss)
    
  • tensor(0.8873, grad_fn=<MseLossBackward0>)
    
  • 反向过程中跟随loss , 使用它的 .grad_fn 属性,将看到如下所示的计算图。

  • input -> conv2d -> relu -> maxpool2d -> conv2d -> relu -> maxpool2d-> view -> linear -> relu -> linear -> relu -> linear-> MSELoss-> loss
    
  • 当我们调用 loss.backward()时,整张计算图都会 根据loss进行微分,而且图中所有设置为requires_grad=True的张量 将会拥有一个随着梯度累积的.grad 张量。

反向传播

  • 调用loss.backward()获得反向传播的误差。但是在调用前需要清除已存在的梯度,否则梯度将被累加到已存在的梯度。现在,我们将调用loss.backward(),并查看conv1层的偏差(bias)项在反向传播前后的梯度。

  • net.zero_grad()     # 清除梯度
    print('conv1.bias.grad before backward')
    print(net.conv1.bias.grad)
    loss.backward()
    print('conv1.bias.grad after backward')
    print(net.conv1.bias.grad)
    
  • conv1.bias.grad before backward
    tensor([0., 0., 0., 0., 0., 0.])
    conv1.bias.grad after backward
    tensor([-0.0136,  0.0067,  0.0111,  0.0210,  0.0111, -0.0072])
    

更新权重

  • 当使用神经网络是想要使用各种不同的更新规则时,比如SGD、Nesterov-SGD、Adam、RMSPROP等,PyTorch中构建了一个包torch.optim实现了所有的这些规则。 使用它们非常简单:

  • import torch.optim as optim
    # create your optimizer
    optimizer = optim.SGD(net.parameters(), lr=0.01)
    # in your training loop:
    for i in range(10):optimizer.zero_grad()   # zero the gradient buffersoutput = net(input)loss = criterion(output, target)loss.backward()optimizer.step()    # Does the updateprint(f"第{i}轮训练:",loss)
    
  • 0轮训练: tensor(0.8873, grad_fn=<MseLossBackward0>)1轮训练: tensor(0.8725, grad_fn=<MseLossBackward0>)2轮训练: tensor(0.8582, grad_fn=<MseLossBackward0>)3轮训练: tensor(0.8447, grad_fn=<MseLossBackward0>)4轮训练: tensor(0.8309, grad_fn=<MseLossBackward0>)5轮训练: tensor(0.8173, grad_fn=<MseLossBackward0>)6轮训练: tensor(0.8035, grad_fn=<MseLossBackward0>)7轮训练: tensor(0.7897, grad_fn=<MseLossBackward0>)8轮训练: tensor(0.7764, grad_fn=<MseLossBackward0>)9轮训练: tensor(0.7623, grad_fn=<MseLossBackward0>)
    

035, grad_fn=)
第7轮训练: tensor(0.7897, grad_fn=)
第8轮训练: tensor(0.7764, grad_fn=)
第9轮训练: tensor(0.7623, grad_fn=)

相关文章:

pytorch-复现经典深度学习模型-LeNet5

Neural Networks 使用torch.nn包来构建神经网络。nn包依赖autograd包来定义模型并求导。 一个nn.Module包含各个层和一个forward(input)方法&#xff0c;该方法返回output。 一个简单的前馈神经网络&#xff0c;它接受一个输入&#xff0c;然后一层接着一层地传递&#xff0c;…...

【C++】类和对象(上)

文章目录对象的介绍类的介绍类的两种定义方式类的访问限定符及封装访问限定符封装类的作用域类的实例化类的对象模型对象的介绍 C语言是面向过程的&#xff0c;关注的是过程&#xff0c;分析出求解问题的步骤&#xff0c;通过函数调用逐步解决问题&#xff1b;   C是基于面向…...

工作中责任链模式用法及其使用场景?

前言 笔者是金融保险行业&#xff0c;有这么一种场景&#xff0c;业务员录完单后提交核保&#xff0c;这时候系统会对保单数据进行校验&#xff0c;如不允许手续费超限校验&#xff0c;客户真实性校验、费率限额校验等等&#xff0c;当校验一多时&#xff0c;维护起来特别麻烦…...

三八女神节有哪些数码好物?2023年三八女神节数码好物清单

2023年的三八女神节就快到了&#xff0c;大家还在烦恼&#xff0c;不知道有哪些数码好物&#xff1f;在此&#xff0c;我来给大家分享几款三八女神节实用性强的数码好物&#xff0c;一起来看看吧。 一、蓝牙耳机&#xff1a;南卡小音舱 参考价&#xff1a;239 推荐理由&…...

FairGuard-Windows加固工具版本更新日志

FairGuard-Windows加固工具1.2.2版本更新日志&#xff1a; ■ 增加Unity Resources资源加密的支持; ■ 增加单独Assetbundle资源加密&#xff0c;并同时支持压缩包和文件夹作为输入的方式; ■ 增加对游戏原文件夹加固的支持; Windows加固方案介绍 FairGuard专为游戏量身定…...

基于RT-Thread完整版搭建的极简Bootloader

项目背景Agile Upgrade: 用于快速构建 bootloader 的中间件。example 文件夹提供 PC 上的示例特性适配 RT-Thread 官方固件打包工具 (图形化工具及命令行工具)使用纯 C 开发&#xff0c;不涉及任何硬件接口&#xff0c;可在任何形式的硬件上直接使用加密、压缩支持如下&#xf…...

3.flinkDateStreamAPI介绍env与source

执行环境 Flink可以在不同的环境上下文中运行.可以本地集成开发环境中运行也可以提交到远程集群环境运行. 不同的运行环境对应的flink的运行过程不同,需要首先获取flink的运行环境,才能将具体的job调度到不同的TaskManager 在flink中可以通过StreamExecutionEnvironment类获取…...

$ 2 :数据类型

1.数据类型 1.1基本类型 a、整型int b、浮点型float c、字符型char 1.2构造类型 a、数组[ ] b、结构体struct 1.3指针类型 * 1.4空类型(void) 2.关键字 autoconstdoublefloatintshortstructunsignedbreakcontinueelseforlongsignedswitchvoidcasedefaultenumgotoregistersiz…...

类和对象 - 上

本文已收录至《C语言》专栏&#xff01; 作者&#xff1a;ARMCSKGT 目录 前言 正文 面向过程与面向对象 面向过程的解决方法 面向对象的解决方法 面向对象的优势 类的引入 早期C类的实现 class定义类 class定义规则 类成员的两种定义方式 类的访问限定符及封装 访…...

补档:红黑树代码实现+简略讲解

红黑树讲解和实现1 红黑树介绍1.1 红黑树特性1.2 红黑树的插入1.3 红黑树的删除2 完整代码实现2.1 rtbtree.h头文件2.2 main.c源文件1 红黑树介绍 红黑树( Red-Black tree&#xff0c;简称RB树)是一种自平衡二叉查找树&#xff0c;是计算机科学中常见的一种数据结构&#xff0c…...

FirePower X2 14.0.1 for RAD Studio Alexandria

介绍 FirePower X2 FirePower X2 集成了 RAD Studio 11.0 Alexandria 中的新功能&#xff0c;并预览了我们的新特色组件 TwwDataGrouper。 FirePower X2 还允许您为 Apple 的新 M1 芯片构建应用程序&#xff0c;这样您就可以进一步利用 M1 芯片来提高本机应用程序的性能&#x…...

二十九、MongoDB 恢复数据( mongorestore )

MongoDB mongorestore 脚本命令可以用来恢复备份的数据 语法 MongoDB mongorestore 命令脚本语法如下 $ mongorestore -h <hostname><:port> -d dbname <path> 参数说明 -h <:port>, -h<:port> MongoDB 所在服务器地址&#xff0c;默认为 l…...

【数据分析】缺失数据如何处理?pandas

本文目录1. 基础概念1.1. 缺失值分类1.2. 缺失值处理方法2. 缺失观测及其类型2.1. 了解缺失信息2.2. 三种缺失符号2.3. Nullable类型与NA符号2.4. NA的特性2.5. convert_dtypes方法3. 缺失数据的运算与分组 3.1. 加号与乘号规则3.2. groupby方法中的缺失值4. 填充与剔除4.1. fi…...

嵌入式开发--STM32H750VBT6开发中,新版本CubeMX的时钟问题,不能设置到最高速度480MHZ

嵌入式开发–STM32H750VBT6开发中&#xff0c;新版本CubeMX的时钟问题&#xff0c;不能设置到最高速度480MHZ 问题描述 之前开发的项目&#xff0c;开发环境是CubeMX6.6.1&#xff0c;H7系列的支持包版本是1.10.0。跑得没问题&#xff0c;最近需要对项目做修改&#xff0c;同…...

一文读懂PaddleSpeech中英混合语音识别技术

语音识别技术能够让计算机理解人类的语音&#xff0c;从而支持多种语音交互的场景&#xff0c;如手机应用、人车协同、机器人对话、语音转写等。然而&#xff0c;在这些场景中&#xff0c;语音识别的输入并不总是单一的语言&#xff0c;有时会出现多语言混合的情况。例如&#…...

问题三十四:傅立叶变换——高通滤波

高通滤波器是一种可以通过去除图像低频信息来增强高频信息的滤波器。在图像处理中&#xff0c;高通滤波器常常用于去除模糊或平滑效果&#xff0c;以及增强边缘或细节。在本篇回答中&#xff0c;我们将使用Python和OpenCV实现高通滤波器。 Step 1&#xff1a;加载图像并进行傅…...

flink 键控状态(keyed state)

github开源项目flink-note的笔记。本博客的实现代码都写在项目的flink-state/src/main/java/state/keyed/KeyedStateDemo.java文件中。 项目github地址: github 1. flink键控状态 flink键控状态是作用与flink KeyedStream上的,也就是说需要将DataStream先进行keyby之后才能使…...

【ChatGPT】sqlachmey 多表连表查询语句

感受下科技带来的魅力&#xff0c;这篇文章是通过ChatGPT自动生成的&#xff0c;不得不说技术强大!!! 在SQLAlchemy中进行多表连接查询可以使用join()方法或join()函数&#xff0c;具体用法如下&#xff1a; join()方法 join()方法可以在SQLAlchemy ORM中的查询中使用。假设…...

win11 系统登录问题,PIN 设置问题

我的电脑配置是华为MateBook X Pro 12&#xff0c;i7处理器&#xff0c;16G&#xff0c;1T&#xff0c;win11 系统通过微软账户登录&#xff0c;下午一直登录不进去&#xff0c;网络能连外网&#xff0c;分析应该是连微软服务器不行。连续登录几十次&#xff0c;偶尔可能有一次…...

数据结构六大排序

1.插入排序 思路&#xff1a; 从第一个元素开始认为是有序的&#xff0c;去一个元素tem从有序序列从后往前扫描&#xff0c;如果该元素大于tem&#xff0c;将该元素一刀下一位&#xff0c;循环步骤3知道找到有序序列中小于等于的元素将tem插入到该元素后&#xff0c;如果已排序…...

快速生成QR码的方法:教你变成QR Code Master

目录 简介: 具体实现步骤&#xff1a; 一、可以使用Python中的qrcode和tkinter模块来生成QR码。以下是一个简单的例子&#xff0c;演示如何在Tkinter窗口中获取用户输入并使用qrcode生成QR码。 1&#xff09;首先需要安装qrcode模块&#xff0c;可以使用以下命令在终端或命令…...

tensorflow1.14.0安装教程--保姆级

//方法不止一种&#xff0c;下面仅展示一种。 注&#xff1a;本人电脑为win11&#xff0c;anaconda的python版本为3.9&#xff0c;但tensorflow需要python版本为3.7&#xff0c;所以下面主要阐述将python版本改为3.7后的安装过程以及常遇到的问题。 1.首先电脑安装好anaconda…...

AcWing算法提高课-3.1.3香甜的黄油

宣传一下算法提高课整理 <— CSDN个人主页&#xff1a;更好的阅读体验 <— 题目传送门点这里 题目描述 农夫John发现了做出全威斯康辛州最甜的黄油的方法&#xff1a;糖。 把糖放在一片牧场上&#xff0c;他知道 N 只奶牛会过来舔它&#xff0c;这样就能做出能卖好价…...

私库搭建1:Nexus 安装 Docker 版

本文内容以语雀为准 文档 https://hub.docker.com/r/sonatype/nexus3Docker 安装&#xff1a;https://www.yuque.com/xuxiaowei-com-cn/gitlab-k8s/docker-install 安装 创建文件夹 由于 Nexus 的数据可能会很大&#xff0c;比如&#xff1a;作为 Docker、Maven 私库时&…...

LeetCode-面试题 05.02. 二进制数转字符串【数学,字符串,位运算】

LeetCode-面试题 05.02. 二进制数转字符串【数学&#xff0c;字符串&#xff0c;位运算】题目描述&#xff1a;解题思路一&#xff1a;简单暴力。小数点后面的二进制&#xff0c;now首先从0.5开始之和每次除以2。然后依次判断当前数是否大于now&#xff0c;是则答案加1。若等于…...

pandas: 三种算法实现递归分析Excel中各列相关性

目录 前言 目的 思路 代码实现 1. 循环遍历整个SDGs列&#xff0c;两两拿到数据 2. 调用pandas库函数直接进行分析 完整源码 运行效果 总结 前言 博主之前刚刚被学弟邀请参与了2023美赛&#xff0c;这也是第一次正式接触数学建模竞赛&#xff0c;现在已经提交等待结果…...

【Python百日进阶-Web开发-Vue3】Day543 - Vue3 商城后台 03:登录页面初建

文章目录 一、创建登录页面 login.vue二、登录页面响应式处理,以适应不同大小的屏幕2.1 element-plus 的layout布局中关于响应式的说明2.2 修改login.vue文件2.2.1 :lg=16 大于1200px 横排 2:12.2.2 :md=12 大于992小于1200px 横排 1:12.2.3 小于992 竖排三、引入Element-plus…...

python画直方图,刻画数据分布

先展示效果 准备一维数据 n 个数据元素计算最大值&#xff0c;最小值、均值、标准差、以及直方图分组 import numpy as np data list() for i in range(640):data.append(np.random.normal(1)) print(data)z np.histogram(data, bins64) print(list(z[0])) ### 对应 x 轴数据…...

几何学小课堂:非欧几何(广义相对论采用黎曼几何作为数学工具)【学数学关键是要学会在什么情况下,知道使用什么工具。】

文章目录 引言I 非欧几何1.1 黎曼几何1.2 共形几何1.3 罗氏几何II 黎曼几何的应用2.1 广义相对论2.2 超弦III 理解不同的几何体系的共存3.1 更扎实的欧氏几何3.2 殊途同归引言 公理有错会得到两种情况: 如果某一条自己设定的新公理和现有的公理相矛盾,那么相应的知识体系就建…...

Ubuntu配置静态IP的方法

Ubuntu配置静态IP的方法前言一、查看虚机分配的网卡IP二、查看网卡的网关IP三、配置静态IP1.配置IPv4地址2.执行netplan apply使改动生效3.配置的网卡未生效&#xff0c;修改50-cloud-init.yaml文件解决4.测试vlan网络通信总结前言 Ubuntu18.04 欧拉环境 vlan网络支持ipv6场景…...

网站开发后端 书/排超最新积分榜

Day1 题目描述 计算字符串最后一个单词的长度&#xff0c;单词以空格隔开。 输入描述: 一行字符串&#xff0c;非空&#xff0c;长度小于5000。 输出描述: 整数N&#xff0c;最后一个单词的长度。 示例1 输入 hello world 输出 5 #include <iostream> #inclu…...

企业门户网站建设的必要性/百度搜索关键词推广

最近在用Java调用ffmpeg的命令&#xff0c;所以记录下踩到的坑如果要在Java中调用shell脚本时&#xff0c;可以使用Runtime.exec或ProcessBuilder.start。它们都会返回一个Process对象&#xff0c;通过这个Process可以对获取脚本执行的输出&#xff0c;然后在Java中进行相应处理…...

网站建设唐山/关键词排名优化教程

点击上方蓝色字体&#xff0c;选择“标星公众号”优质文章&#xff0c;第一时间送达关注公众号后台回复pay或mall获取实战项目资料视频新款5G手机&#xff0c;王者上分就是这么爽&#xff01;包邮送一个&#xff01;文章作者: Jitwxs 链接: https://jitwxs.cn/4135e0a9.html一、…...

家居网站开发项目计划书/网络推广怎么做效果好

IBM SPSS Statistics为用户提供了三种相关性分析的方法&#xff0c;分别是双变量分析、偏相关分析和距离分析&#xff0c;三种相关分析方法各针对不同的数据情况&#xff0c;接下来我们将为大家介绍如何使用SPSS相关性分析中的距离分析。 一、数据简述 距离分析和其他两类相关…...

web网站漏洞扫描/网页设计模板网站免费

坑边闲话&#xff1a;如何在 Word 里添加代码片段呢&#xff1f;这个可以通过 VBA 编程实现&#xff0c;加上某些可以导出 HTML 格式的源码编辑器&#xff0c;基本无缝操作。但是 Word 插入代码并自动更新&#xff0c;真的是让人非常恼火&#xff0c;写完这篇文章&#xff0c;我…...

企业网站开发/网络舆情分析报告模板

摘要&#xff1a; 一、背景介绍近年来&#xff0c;越来越热的云计算被推倒风口浪尖&#xff0c;各大中型企业纷纷把企业服务迁移到云上&#xff0c;众多的创业公司也把云服务器作为数据服务的首选。那么问题来了&#xff0c;有些企业的运维开始担心上云的过程是否能做到简单和平…...