现代卷积网络实战系列2:PyTorch构建训练函数、LeNet网络
🌈🌈🌈现代卷积网络实战系列 总目录
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
1、MNIST数据集处理、加载、网络初始化、测试函数
2、训练函数、PyTorch构建LeNet网络
3、PyTorch从零构建AlexNet训练MNIST数据集
4、PyTorch从零构建VGGNet训练MNIST数据集
5、PyTorch从零构建GoogLeNet训练MNIST数据集
6、PyTorch从零构建ResNet训练MNIST数据集
4、训练函数
4.1 调用训练函数
train(epochs, net, train_loader, device, optimizer, test_loader, true_value)
因为每一个epoch训练结束后,我们需要测试一下这个网络的性能,所有会在训练函数中频繁调用测试函数,所有测试函数中所有需要的参数,训练函数都需要
这七个参数,是训练一个神经网络所需要的最少参数
4.2 训练函数
训练函数中,所有训练集进行多次迭代,而每次迭代又会将数据分成多个批次进行迭代
def train(epochs, net, train_loader, device, optimizer, test_loader, true_value):for epoch in range(1, epochs + 1):net.train()all_train_loss = []for batch_idx, (data, target) in enumerate(train_loader):data = data.to(device)target = target.to(device)optimizer.zero_grad()output = net(data)loss = F.cross_entropy(output, target)loss.backward()optimizer.step()cur_train_loss = loss.item()all_train_loss.append(cur_train_loss)train_loss = np.round(np.mean(all_train_loss) * 1000, 2)print('\nepoch step:', epoch)print('training loss: ', train_loss)test(net, test_loader, device, true_value, epoch)print("\nTraining finished")
- 定义训练函数
- 安装epochs迭代数据
- 进入pytorch的训练模式
- all_train_loss 存放训练集5万张图片的损失值
- 按照batch取数据
- 数据进入GPU
- 标签进入GPU
- 梯度清零
- 当前batch进入网络后得到输出
- 根据输出得到当前损失
- 反向传播
- 梯度下降
- 获取损失的损失值(PyTorch框架中的数据)
- 把当前batch的损失加入all_train_loss数组中,结束batch的迭代
- 将5张图片的损失计算出来并且进行求平均,这里乘以1000是因为我觉得计算出的损失太小了,所以乘以1000,方便看损失的变化,保留两位有效数字
- 打印当前epoch
- 打印损失
- 调用测试函数,测试当前训练的网络的性能,结束epoch的迭代
- 打印训练完成
5、LeNet
向传播来优化学习策略,而是采用的无监督学习的方案,这其实限制了Neocognitron模型。反向传播算法于1974年哈佛大学的 Paul Werbos 提出,并由LeCun于1989将反向传播算法引入了卷积神经网络并且用于手写数字识别任务上,这个就是LeNet-1,通过几年的迭代,LeNet在1998的手写体数字识别任务上取得了很大的成功,这个版本的LeNet就是著名的LeNet-5。为什么LeNet-5这么被广泛使用呢?因为LeNet-5在美国被大规模用于自动对银行支票上的手写数字进行分类。在LeNet之前,字符识别主要是通过手工特征工程来完成特征提取,然后利用机器学习模型来学习手工特征进行分类。因此,特征工程就是一个很大的问题,究竟什么样的特征是需要的特征呢?LeNet-5可以自己学习图像的特征,这就意味着,网络模型自己学习特征成为可能,手工提取特征将成为过去式。卷积还可以被看作是“滑动平均”的推广。
5.1 网络结构
LeNet可以说是首次提出卷积神经网络的模型
主要包含下面的网络层:
- 5*5的二维卷积
- sigmoid激活函数(这里使用了relu)
- 5*5的二维卷积
- sigmoid激活函数
- 数据一维化
- 全连接层
- 全连接层
- softmax分类器
将网络结构打印出来:
LeNet(
-------(conv1): Conv2d(1, 10, kernel_size=(5, 5), stride=(1, 1))
-------(conv2): Conv2d(10, 20, kernel_size=(5, 5), stride=(1, 1))
-------(conv2_drop): Dropout2d(p=0.5, inplace=False)
-------(fc1): Linear(in_features=320, out_features=50, bias=True)
-------(fc2): Linear(in_features=50, out_features=10, bias=True)
)
5.2 PyTorch构建LeNet
class LeNet(nn.Module):def __init__(self, num_classes):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(1, 10, kernel_size=5)self.conv2 = nn.Conv2d(10, 20, kernel_size=5)self.conv2_drop = nn.Dropout2d()self.fc1 = nn.Linear(320, 50)self.fc2 = nn.Linear(50, num_classes)def forward(self, x):x = F.relu(F.max_pool2d(self.conv1(x), 2))x = F.relu(F.max_pool2d(self.conv2_drop(self.conv2(x)), 2))x = x.view(-1, 320)x = F.relu(self.fc1(x))x = F.dropout(x, training=self.training)x = self.fc2(x)return F.log_softmax(x, dim=1)
这个时候已经是一个完整的项目了,看看10个epoch训练过程的打印:
D:\conda\envs\pytorch\python.exe A:\0_MNIST\train.py
Reading data…
train_data: (60000, 28, 28) train_label (60000,)
test_data: (10000, 28, 28) test_label (10000,)
Initialize neural network
test loss: 2301.68
test accuracy: 11.3 %
epoch step: 1
training loss: 634.74
test loss: 158.03
test accuracy: 95.29 %
epoch step: 2
training loss: 324.04
test loss: 107.62
test accuracy: 96.55 %
epoch step: 3
training loss: 271.25
test loss: 88.43
test accuracy: 97.04 %
epoch step: 4
training loss: 236.69
test loss: 70.94
test accuracy: 97.61 %
epoch step: 5
training loss: 211.05
test loss: 69.69
test accuracy: 97.72 %
epoch step: 6
training loss: 199.28
test loss: 62.04
test accuracy: 97.98 %
epoch step: 7
training loss: 187.11
test loss: 59.65
test accuracy: 97.98 %
epoch step: 8
training loss: 178.79
test loss: 53.89
test accuracy: 98.2 %
epoch step: 9
training loss: 168.75
test loss: 51.83
test accuracy: 98.43 %
epoch step: 10
training loss: 160.83
test loss: 50.35
test accuracy: 98.4 %
Training finished
进程已结束,退出代码为 0
可以看出基本上只要一个epoch就可以得到很好的训练效果了,后续的epoch中的提升比较小
相关文章:
现代卷积网络实战系列2:PyTorch构建训练函数、LeNet网络
🌈🌈🌈现代卷积网络实战系列 总目录 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 1、MNIST数据集处理、加载、网络初始化、测试函数 2、训练函数、PyTorch构建LeNet网络 3、PyTorch从零构建AlexNet训练MNIST数据…...
leetCode 62.不同路径 动态规划 + 空间复杂度优化
62. 不同路径 - 力扣(LeetCode) 一个机器人位于一个 m x n 网格的左上角 (起始点在下图中标记为 “Start” )。机器人每次只能向下或者向右移动一步。机器人试图达到网格的右下角(在下图中标记为 “Finish” …...
在 .NET 8 Release Candidate 1 中推出 .NET MAUI:质量
作者:David Ortinau 排版:Alan Wang 今天,我们很高兴地宣布 .NET MAUI 在 .NET 8 Release Candidate 1 中已经可用,该版本带有适用于生产应用程序的正式许可证,因此您可以放心地将此版本用于生产环境。我们在 .NET 8 中…...
Spring 学习(八)事务管理
1. 事务 1.1 事务的 ACID 原则 数据库事务(transaction)是访问并可能操作各种数据项的一个数据库操作序列。事务必须满足 ACID 原则——即原子性(Atomicity)、一致性(Consistency)、隔离性(Iso…...
CodeTON Round 6 (Div 1 + Div 2, Rated, Prizes!)(A - E)
CodeTON Round 6 (Div. 1 Div. 2, Rated, Prizes!)(A - E) CodeTON Round 6 (Div. 1 Div. 2, Rated, Prizes!) A. MEXanized Array(分类讨论) 可以发现当 n < k 或者 k > x 1 的时候无法构成 , 其余的时候贪心的用 x 最大化贡献即…...
Spring 源码分析(五)——Spring三级缓存的作用分别是什么?
Spring 的三级缓存是经典面试题,也会看到一些文章讲三级缓存与循环依赖之的关系。那么,三级缓存分别存储的什么呢?他们的作用又分别是什么? 一、一、二级缓存 一级缓存是一个名为 singletonObjects 的 ConcurrentHashMap&#x…...
Django基于类视图实现增删改查
第一步:导入View from django.views import View 第二步:新建这个基类 class CLS_executer(View):db DB_executerdef get(self, request):executer_list list(self.db.objects.all().values())return HttpResponse(json.dumps(executer_list), conte…...
matplotlib绘图实现中文宋体的两种方法(亲测)
方法一:这种方法我没有测试。 第一步 找宋体字体 (win11系统) 2.matplotlib字体目录,如果不知道的话,可以通过以下代码查询: matplotlib.matplotlib_fname() 如果你是Anaconda3 安装的matplotlib&#x…...
非常有用的JavaScript高阶面试技巧!
🍀一、闭包 闭包是指函数中定义的函数,它可以访问外部函数的变量。闭包可以用来创建私有变量和方法,从而保护代码不受外界干扰。 // 例1 function outerFunction() {const privateVariable "私有变量";function innerFunction()…...
windows 安装Linux子系统 Ubuntu 并配置python3
环境说明: Windows 11 Ubuntu 20.04.6 安装步骤以及问题: 1、开启Windows Subsystem for Linux dism.exe /online /enable-feature /featurename:Microsoft-Windows-Subsystem-Linux /all /norestart 2、开启虚拟机特性 dism.exe /online /enabl…...
pytorch的pixel_shuffle转tflite文件
torch.pixel_shuffle()是pytorch里面上采样比较常用的方法,但是和tensoflow的depth_to_space不是完全一样的,虽然看起来功能很像,但是细微是有差异的 def tf_pixelshuffle(input, upscale_factor):temp []depth upscale_factor *upscale_f…...
sentinel-dashboard-1.8.0.jar开机自启动脚本
启动阿里巴巴的流控组件控制面板需要运行一个jar包,通常需要运行如下命令: java -server -Xms4G -Xmx4G -Dserver.port8080 -Dcsp.sentinel.dashboard.server127.0.0.1:8080 -Dproject.namesentinel-dashboard -jar sentinel-dashboard-1.8.0.jar &…...
c++堆排序-建堆-插入-删除-排序
本文以大根堆为例,用数组实现,它的nums[0]是数组最大值。 时间复杂度分析: 建堆o(n) 插入删除o(logn) 堆排序O(nlogn) 首先上代码 #include<bits/stdc.h>using namespace std; void down(vector<int>&nums, int idx, i…...
使用代理后pip install 出现ssl错误
window直接设置代理 httphttp://127.0.0.1:7890;httpshttp://127.0.0.1...
护眼灯什么价位的好?最具性价比的护眼台灯推荐
到了晚上光线比较弱,这时候就需要开灯,要是孩子需要近距离看字学习等等,给孩子选择的灯具要特别的重视。护眼灯就是目前颇受学生家长青睐的灯具之一,越来越多的人会购买一个护眼灯给自己的孩子让孩子能够在灯光下学习的时候&#…...
vue event bus 事件总线
vue event bus 事件总线 创建 工程: H:\java_work\java_springboot\vue_study ctrl按住不放 右键 悬着 powershell H:\java_work\java_springboot\js_study\Vue2_3入门到实战-配套资料\01-随堂代码素材\day04\准备代码\08-事件总线-扩展 vue --version vue crea…...
深信服云桌面用户忘记密码后的处理
深信服云桌面用户忘记了密码,分两种情况,一个是忘记了登录深信服云桌面的密码,另外一个是忘记了进入操作系统的密码。 一、忘记了登录深信服云桌面的密码 登录虚拟桌面接入管理系统界面,在用户管理中选择用户后,点击后…...
Cocos Creator3.8 实战问题(一)cocos creator prefab 无法显示内容
问题描述: cocos creator prefab 无法显示内容, 或者只显示一部分内容。 creator编辑器中能看见: 预览时,看不见内容: **问题原因:** prefab node 所在的layer,默认是default。 解决方法&…...
朴素贝叶斯深度解码:从原理到深度学习应用
目录 一、简介贝叶斯定理的历史和重要性定义例子 朴素贝叶斯分类器的应用场景定义例子常见应用场景 二、贝叶斯定理基础条件概率定义例子 贝叶斯公式定义例子 三、朴素贝叶斯算法原理基本构成定义例子 分类过程定义例子 不同变体定义例子 四、朴素贝叶斯的种类高斯朴素贝叶斯&a…...
RUST 每日一省:闭包
Rust中的闭包是一种可以存入外层函数中变量或作为参数传递给其他函数的匿名函数。你可以在一个地方创建闭包,然后在不同的上下文环境中调用该闭包来完成运算。和一般的函数不同,闭包可以从定义它的作用域中捕获值。 语法 闭包由“||”和“{}”组合而成。…...
Ubuntu下文件的解压缩操作:常用zip和unzip
Ubuntu下文件的解\压缩 压缩一个文件夹为zip包,加参数-r: zip -r MyWeb.zip MyWeb需要排除目录里某个文件夹?例如我要去掉node_modules,以显著减小压缩包体积,此时该怎么做? zip -r MyWeb.zip ./MyWeb…...
Linux学习第22天:Linux中断驱动开发(一): 突如其来
Linux版本号4.1.15 芯片I.MX6ULL 大叔学Linux 品人间百味 思文短情长 中断作为驱动开发中很重要的一个概念,在实际的项目实践中经常用到。本节的主要内容包括中断简介、硬件原理分析、驱动程序开发及运行测试。其中驱动程…...
IDEA 2019 Springboot 3.1.3 运行异常
项目场景: 在IDEA 2019 中集成Springboot 3.1.3 框架,运行异常。 <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSch…...
【JAVA】飞机大战
代码和图片放在这个地址了: https://gitee.com/r77683962/fighting/tree/master 最新的代码运行,可以有两架飞机,分别通过WASD(方向),F(发子弹);上下左右(控…...
Midjourney 生成油画技巧
基本 prompt oil painting, a cute corgi dog surrounded with colorful flowers技法 Pointillism 点描绘法 笔刷比较细,图像更精细 oil painting, a cute corgi dog surrounded with colorful flowers, pontillismImpasto 厚涂绘法 笔刷比较粗,图像…...
26559-2021 机械式停车设备 分类
声明 本文是学习GB-T 26559-2021 机械式停车设备 分类. 而整理的学习笔记,分享出来希望更多人受益,如果存在侵权请及时联系我们 1 范围 本文件规定了机械式停车设备的分类及有关的型式、型号和适停汽车组别、尺寸及质量。 本文件适用于 GB/T 3730.1—2001定义的乘用车及商用…...
xxe攻击(XML外部实体)
1.定义 XML用于标记电子文件使其具有结构性的标记语言,可以用来标记数据、定义数据类型,是一种允许用户对自己的标记语言进行定义的源语言。XML文档结构包括XML声明、DTD文档类型定义(可选)、文档元素。 http://www.w3school.com.…...
大数据-hadoop
1.hadoop介绍 1.1 起源 1.2 版本 1.3生产环境版本选择 Hadoop三大发行版本:Apache、Cloudera、Hortonworks Apache版本最原始的版本 Cloudera在大型互联网企业中用的较多 Hortonworks文档较好 1.4架构 hadoop由三个模块组成 分布式存储HDFS 分布式计算MapReduce 资源调度引擎Y…...
容器启动报错
容器启动报错 docker: Error response from daemon: driver failed programming external connectivity on endpoint XXX 如下: 据百度: 在docker启动后在,再对防火墙firewalld进行操作,就会发生上述报错 详细原因:…...
求生之路2服务器搭建插件安装及详细的游戏参数配置教程linux
求生之路2服务器搭建插件安装及详细的游戏参数配置教程linux 大家好我是艾西,在上一篇文章中我用windows系统给搭建演示了一遍怎么搭建自己的L4D2游戏。 那么也有不少小伙伴想知道linux系统的搭建方式以及在这个过程中有什么区别。 那么艾西今天就跟大家分享下用lin…...
怎么查看网站有没有做ssl/2345系统导航
% and_hand.m 手算实现与逻辑%% 清理close allclear,clc%% 定义变量P[0,0,1,1;0,1,0,1] % 输入向量P[ones(1,4);P] % 包含偏置的输入向量d[0,0,0,1] % 期望输出向量% 初始化w…...
开了个网站用年份做名字好吗/软广告经典案例
总是觉得自己内心空洞,空虚无聊,该怎么办呢? 我们先定义一下什么是空虚?百度给出的定义是:百无聊赖、闲散寂寞的消极心态。即人们常说的“没劲”,是心理不充实的表现。简单来说就是过于无聊,导致…...
买个网站域名要多少钱一年/全网自媒体平台
为什么80%的码农都做不了架构师?>>> 行为型模式:Strategy 策略模式 1、算法与对象的耦合 对象可能经常需要使用多种不同的算法,但是如果变化频繁,会将类型变得脆弱... 2、动机(…...
郑州做网站推广的公司哪家好/无排名优化
最近在网上找了个vue搭建的后台管理的框架,在使用的时候发现没有了config和build文件夹,所以当时就蒙圈了,以为是作者自己改了什么东西,所以感觉自己不知道从何下手了,不过通过查资料发现原来是vue-cli2和3的config不相…...
wordpress官方响应式主题/深圳推广网络
◆如何去金山海滩?2007-8-26 在锦江乐园附近的西南汽车站,乘石梅线,票价十元,先购票,中间不停站,大约1小时,但是石梅线坐的人比较多,排队等车大概要半小时。卫梅线也可以到。从石化车…...
网站开发文档/安卓优化大师最新版下载
什么是阿里云轻量应用服务器? 轻量应用服务器是面向入门级云计算及简单应用用户,提供基于单台云服务器的域名管理、应用部署、安全和运维管理的一站式综合服务。用户可以选择精品应用镜像(比如wordpress),并可在控制…...