当前位置: 首页 > news >正文

VGG卷积神经网络实现Cifar10图片分类-Pytorch实战


前言

当涉足深度学习,选择合适的框架是至关重要的一步。PyTorch作为三大主流框架之一,以其简单易用的特点,成为初学者们的首选。相比其他框架,PyTorch更像是一门易学的编程语言,让我们专注于实现项目的功能,而无需深陷于底层原理的细节。

就像我们使用汽车时,更重要的是了解如何驾驭,而不是花费过多时间研究轮子是如何制造的。我将以一系列专门针对深度学习框架的文章,逐步深入理论知识和实践操作。但这需要在对深度学习有一定了解后才能进行,现阶段我们的重点是学会如何灵活使用PyTorch工具。深度学习涉及大量数学理论和计算原理,对于初学者来说可能会有些繁琐。然而,只有通过实际操作,我们才能真正理解所写代码在神经网络中的作用。我将努力将知识简化,转化为我们熟悉的内容,让大家能够理解和熟练使用神经网络框架。

如果你发现深度学习看似难以掌握,我将尽力简化知识,将其转化为我们更容易理解的内容。我会确保你能够理解知识并顺利运用到实践中。在后期,我将发布一系列专门解析深度学习框架的文章,但在开始学习之前,我们需要对深度学习的理论知识和实践操作有一定的熟悉度。

作为一个从事数据建模五年的专业人士,我参与了许多数学建模项目,了解各种模型的原理、建模流程和题目分析方法。我希望通过这个专栏让你能够快速掌握各类数学模型、机器学习和深度学习知识,并掌握相应的代码实现。每篇文章都包含实际项目和可运行的代码。我会紧跟各类数模比赛,将最新的思路和代码分享给你,保证你能够高效地学习这些知识。

博主非常期待与你一同探索这个精心打造的专栏,里面充满了丰富的实战项目和可运行的代码,希望你不要错过:专栏链接


一、VGGNet概述

VGGNet(Visual Geometry Group Network)是由牛津大学视觉几何组(Visual Geometry Group)提出的深度卷积神经网络架构,它在2014年的ImageNet图像分类挑战中取得了优异的成绩。VGGNet之所以著名,一方面是因为其简洁而高效的网络结构,另一方面是因为它通过深度堆叠的方式展示了深度卷积神经网络的强大能力。

VGGNet探索了卷积神经网络的深度与其性能之间的关系,成功地构筑了16~19层深的卷积神经网络,证明了增加网络的深度能够在一定程度上影响网络最终的性能,使错误率大幅下降,同时拓展性又很强,迁移到其它图片数据上的泛化性也非常好。到目前为止,VGG仍然被用来提取图像特征。

VGGNet包含两种结构,分别为16层和19层。VGGNet结构中,所有卷积层的kernel都只有3*3。VGGNet中连续使用3组3*3kernel的原因是它与使用1个7*7kernel产生的效果相同,然而更深的网络结构还会学习到更复杂的非线性关系,从而使得模型的效果更好。该操作带来的另一个好处是参数数量的减少,因为对于一个包含了C个kernel的卷积层来说,原来的参数个数为7*7*C,而新的参数个数为3*(3*3*C)。
下图给出了VGG16的具体结构示意图:

 根据VGG16进行具体分析,包含:

  • 13个卷积层(Convolutional Layer)
  • 3个全连接层(Fully connected Layer)
  • 5个池化层(Pool layer)

其中,卷积层和全连接层具有权重系数,因此也被称为权重层,总数目为13+3=16,这即是VGG16中16的来源。

 内存消耗主要来自早期的卷积,而参数量的激增则发生在后期的全连接层。由于采用了大量的卷积层,导致VGGNet的参数数量较大,训练和推理过程需要更多的计算资源。而且参数量较大,需要更多的数据来避免过拟合问题。

二、PyTorch网络搭建

我们参考上述网络结构,利用pytorch进行网络搭建,首先我们可以先搭建输出层,根据我上述提供的每一层具体的parameters搭建即可:

def __init__(self, num_classes=1000):super(VGG,self).__init()__self.features = self._make_layers()self.classifier = nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(True),nn.Dropout(),nn,Linear(4096,4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096,num_classes))

 接下来我们来搭建卷积和全连接层,可以利用循环帮助我们省去每个步骤繁琐的写层:

        
def _make_layers(self):layers = []in_clannels = 3cfg =[64,64,'M',128,128,'M',256,256,256,'M',512,512,512,'M']for v in cfg:if v =='M':layers +=[nn.MaxPool2d(kernel_size=2,stride=2)]else:conv2d = nn.Conv2d(in_channels,v,kernel_size)layers +=[conv2d,nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)

 然后写入每个神经网络必备的传播:

def forward(self, x):x = self.features(x)x = x.view(x.size(0), -1)x = self.classifier(x)return x

 总体网络结构为:

VGGNet((features): Sequential((0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): ReLU(inplace=True)(2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(3): ReLU(inplace=True)(4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(6): ReLU(inplace=True)(7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(8): ReLU(inplace=True)(9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(11): ReLU(inplace=True)(12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(13): ReLU(inplace=True)(14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(15): ReLU(inplace=True)(16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)(17): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(18): ReLU(inplace=True)(19): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(20): ReLU(inplace=True)(21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(22): ReLU(inplace=True)(23): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(classifier): Sequential((0): Linear(in_features=25088, out_features=4096, bias=True)(1): ReLU(inplace=True)(2): Dropout(p=0.5, inplace=False)(3): Linear(in_features=4096, out_features=4096, bias=True)(4): ReLU(inplace=True)(5): Dropout(p=0.5, inplace=False)(6): Linear(in_features=4096, out_features=1000, bias=True))
)

定义损失函数和优化方法:

#定义损失函数和优化方式
criterion = nn.CrossEntropyLoss() #定义损失函数:交叉熵
optimizer = torch.optim.SGD(net.parameters(),lr=0.001,momentum=0.9)#定义优化方法,随机梯度下降

 进行卷积网络训练,这里需要微调一下原来vgg的模型,Cifar10的数据集有10个类别而且图片转换的矩阵需要加入自适应池化层,要一些改进:

import torch.nn as nn# 设置随机种子以保证实验的可复现性
torch.manual_seed(0)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = Falseclass VGGNet(nn.Module):def __init__(self, num_classes=10):super(VGGNet, self).__init__()self.features = self._make_layers()self.avgpool = nn.AdaptiveAvgPool2d((7, 7))self.classifier = nn.Sequential(nn.Linear(512*7*7,4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096,4096),nn.ReLU(True),nn.Dropout(),nn.Linear(4096,num_classes))def _make_layers(self):layers = []in_channels = 3cfg =[64,64,'M',128,128,'M',256,256,256,'M',512,512,512,'M',512, 512, 512, 'M']for v in cfg:if v =='M':layers +=[nn.MaxPool2d(kernel_size=2,stride=2)]else:conv2d = nn.Conv2d(in_channels,v,kernel_size=3, padding=1)layers +=[conv2d,nn.ReLU(inplace=True)]in_channels = vreturn nn.Sequential(*layers)def forward(self, x):x = self.features(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.classifier(x)return x

 需要注意到是我们需要初始化网络的权重,不更新权重的话10000张图片和实际不借助算法猜测图片的概率是一致的,我们先不初始化网络的权重进行训练:

for epoch in range(1):train_loss=0.0for batch_idx,data in enumerate(train_loader,0):#初始化inputs,labels = data #获取数据inputs = inputs.to(device)labels = labels.to(device)optimizer.zero_grad() #梯度置0#优化过程outputs = net(inputs) #将数据输入到网络,得到第一轮网络前向传播的预测结果outputsloss = criterion(outputs,labels) #预测结果outputs和labels通过之前定义的交叉熵计算损失loss.backward() #误差反向传播optimizer.step() #随机梯度下降优化权重#查看网络训练状态train_loss += loss.item()if batch_idx % 2000 == 1 :print(batch_idx)print('[%d,%5d] loss: %.3f' % (epoch + 1,batch_idx + 1,train_loss / 2000))train_loss = 0.0print('Saving epoch %d model ...'%(epoch + 1))state = {'net':net.state_dict(),'epoch':epoch+1,}if not os.path.isdir('checkpoint'):os.mkdir('checkpoint')#torch.save(state,'./checkpoint/cifar10_epoch_%d.ckpt'%(epoch+1))print('Finished Training')

 然后我们去计算整个测试集的预测效果:

#批量计算整个测试集的预测效果
correct= 0
total = 0
with torch.no_grad():for data in test_loader:images,labels = dataimages = images.to(device)labels = labels.to(device)outputs = net(images)_,predicted = torch.max(outputs.data,1)total += labels.size(0)correct += (predicted == labels ).sum().item() #当标记的label种类和预测的种类一致时认为正确,并计数print('Accurary of the network on the 10000 test images : %d %%'%(100*correct/total))

 很明显和实际猜测的概率是一模一样的,总共十个类别1/10很正常:

Accurary of the network on the 10000 test images : 10 %

我们需要先进行初始化网络权重在训练:

def initialize_weights(module):if isinstance(module, nn.Conv2d):nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')if module.bias is not None:nn.init.constant_(module.bias, 0)elif isinstance(module, nn.Linear):nn.init.normal_(module.weight, 0, 0.01)nn.init.constant_(module.bias, 0)

之后在训练预测一版:

Accurary of the network on the 10000 test images : 47 %

 效果就十分明显了。


点关注,防走丢,如有纰漏之处,请留言指教,非常感谢

以上就是本期全部内容。我是fanstuck ,有问题大家随时留言讨论 ,我们下期见。

相关文章:

VGG卷积神经网络实现Cifar10图片分类-Pytorch实战

前言 当涉足深度学习,选择合适的框架是至关重要的一步。PyTorch作为三大主流框架之一,以其简单易用的特点,成为初学者们的首选。相比其他框架,PyTorch更像是一门易学的编程语言,让我们专注于实现项目的功能&#xff0…...

CentOS 7文件系统中的软链接和硬链接

软链接(Symbolic Link) 软链接,也称为符号链接,是一个指向另一个文件或目录的特殊类型的文件。它是一个指向目标文件的符号,就像快捷方式一样。软链接的创建和使用非常灵活,适用于各种情况。 创建软链接 …...

【AI】深度学习——前馈神经网络——全连接前馈神经网络

文章目录 1.1 全连接前馈神经网络1.1.1 符号说明超参数参数活性值 1.1.2 信息传播公式通用近似定理 1.1.3 神经网络与机器学习结合二分类问题多分类问题 1.1.4 参数学习矩阵求导链式法则更为高效的参数学习反向传播算法目标计算 ∂ z ( l ) ∂ w i j ( l ) \frac{\partial z^{…...

超简单的视频截取方法,迅速提取所需片段!

“视频可以截取吗?用相机拍摄了一段视频,但是中途相机发生了故障,录进去了很多不需要的片段,现在想截取一部分视频出来,但是不知道方法,想问问广大的网友,知不知道视频截取的方法。” 无论是工…...

ArcGIS/GeoScene脚本:基于粒子群优化的支持向量机回归模型

参数输入 1.样本数据必须包含需要回归的字段 2.回归字段是数值类型 3.影响因子是栅格数据,可添加多个 4.随机种子可以确保每次运行的训练集和测试集一致 5.训练集占比为0-1之间的小数 6.迭代次数:迭代次数越高精度越高,但是运行时间越长…...

vue3组件的通信方式

一、vue3组件通信方式 通信仓库地址:vue3_communication: 当前仓库为贾成豪老师使用组件通信案例 不管是vue2还是vue3,组件通信方式很重要,不管是项目还是面试都是经常用到的知识点。 比如:vue2组件通信方式 props:可以实现父子组件、子父组件、甚至兄弟组件通信 自定义事件:可…...

Qt QPair

QPair 文章目录 QPair 摘要QPairQPair 特点代码示例QPair 与 QMap 区别 关键字: Qt、 QPair、 QMap、 键值、 容器 摘要 今天在观摩小伙伴撸代码的时候,突然听到了QPair自己使用Qt开发这么就,竟然都不知道,所以趁没有被人发…...

K8S云计算系列-(3)

K8S Kubeadm案例实战 Kubeadm 是一个K8S部署工具,它提供了kubeadm init 以及 kubeadm join 这两个命令来快速创建kubernetes集群。 Kubeadm 通过执行必要的操作来启动和运行一个最小可用的集群。它故意被设计为只关心启动集群,而不是之前的节点准备工作…...

ardupilot罗盘数据计算航向

目录 文章目录 目录摘要1.数据特点2.数据结论1.结论2.结论摘要 本节主要记录ardupilot 根据罗盘数据计算航向的过程。 如果知道了一组罗盘数据,我们可以粗略估计航向:主要后面我们所说的X和Y都是表示的飞机里面的坐标系,也就是X前Y右边,如果按照罗盘坐标系Y实际在左边。 我…...

第六章:最新版零基础学习 PYTHON 教程—Python 正则表达式(第一节 - Python 正则表达式)

在本教程中,您将了解RegEx并了解各种正则表达式。 常用表达为什么使用正则表达式基本正则表达式更多正则表达式编译的正则表达式 目录​​​​​​​ 元字符 为什么是正则表达式?...

docker安装Jenkins完整教程

1.docker拉取 Jenkins镜像并启动容器 新版本的Jenkins依赖于JDK11 我们选择docker中jdk11版本的镜像 # 拉取镜像 docker pull jenkins/jenkins:2.346.3-2-lts-jdk11 2.宿主机上创建文件夹 # 创建Jenkins目录文件夹 mkdir -p /data/jenkins_home # 设置权限 chmod 777 -R /dat…...

[CISCN 2019初赛]Love Math - RCE(异或绕过)

[CISCN 2019初赛]Love Math 1 解题流程1.1 分析1.2 解题题目代码: <?php //听说你很喜欢数学,不知道你是否爱它胜过爱flag if(!isset($_GET[c]))...

C++ 使用getline()从文件中读取一行字符串

我们知道,getline() 方法定义在 istream 类中,而 fstream 和 ifstream 类继承自 istream 类,因此 fstream 和 ifstream 的类对象可以调用 getline() 成员方法。 当文件流对象调用 getline() 方法时,该方法的功能就变成了从指定文件中读取一行字符串。 该方法有以下 2 种语…...

JS进阶-原型

原型 原型就是一个对象&#xff0c;也称为原型对象 构造函数通过原型分配的函数是所有对象所共享的 JavaScript规定&#xff0c;每一个构造函数都有一个prototype属性&#xff0c;指向另一个对象&#xff0c;所以我们也称为原型对象 这个对象可以挂载函数&#xff0c;对象实…...

虹科方案 | 汽车CAN/LIN总线数据采集解决方案

全文导读&#xff1a;现代汽车配备了复杂的电子系统&#xff0c;CAN和LIN总线已成为这些系统之间实现通信的标准协议&#xff0c;为了开发和优化汽车的电子功能&#xff0c;汽车制造商和工程师需要可靠的数据采集解决方案。基于PCAN和PLIN设备&#xff0c;虹科提供了一种高效、…...

HTML5+CSSDAY4综合案例一--热词

样式展示图&#xff1a; 代码如下&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>热词…...

【源码】hamcrest 源码阅读 泛型 extends 和迭代器模式

文章目录 前言1. 泛型参数和自定义迭代器1.1 使用场景1.2 实现 2. 值得一提 前言 官方文档 Hamcrest Tutorial 上篇文章 Hamcrest 源码阅读及空对象模式、模板方法模式的应用 本篇文章 迭代器模式 1. 泛型参数和自定义迭代器 hamcrest 作为一个matcher库&#xff0c;把某个…...

IntelliJ IDEA 2023.1 版本可以安装了

Maven 的导入时间更加快了。 收到的有邮件提醒安装。 安装后的版本&#xff0c;其实就是升级下&#xff0c;并没有什么主要改变。 IntelliJ IDEA 2023.1 版本可以安装了 - 软件技术 - OSSEZMaven 的导入时间更加快了。 收到的有邮件提醒安装。 安装后的版本&#xff0c;其实就是…...

安全论坛和外包平台汇总

文章目录 一. 网络安全论坛汇总二. 外包平台汇总1. 国内&#xff1a;2. 国外 一. 网络安全论坛汇总 安全焦点BugTraq&#xff1a;http://www.fuzzysecurity.com/Exploit-DB&#xff1a;https://www.exploit-db.com/hackone&#xff1a;https://www.hackerone.com/FreeBuf&…...

9-2-Dataset创建-import调用

文章目录 utils_dataset.pymain-调用utils_dateset.pyutils_dataset.py 1默认:没有改变尺寸,数据集中的图像可以是任意形状尺寸。dataloader中必须令batch_size=1 transforms.Resize((宽,高))(image) 和 batch_size=1 必须用其一 原因:当batch_size>1时,每个batch的数…...

XSS原理

原理&#xff1a; 这是一种将任意 Javascript 代码插入到其他Web用户页面里执行以达到攻击目的的漏洞。攻击者利用浏览器的动态展示数据功能&#xff0c;在HTML页面里嵌入恶意代码。当用户浏览改页时&#xff0c;这些潜入在HTML中的恶意代码会被执行&#xff0c;用户浏览器被攻…...

记一个带批注、表头样式的导入导出excel方法(基于easyexcel)

技术栈&#xff1a;easyexcel-2.2.10&#xff0c;poi-4.1.2&#xff0c;lombok&#xff0c;hutool-5.8.19&#xff1b;公司自用导入导出方法&#xff0c;可能不是那么的优雅&#xff0c;但胜在稳定实用。 /*** Author 955* Date 2023-10-10 11:52* Description 错误批注信息对…...

二叉搜索树--新增节点-力扣 701 题

例题细节二叉搜索树的基础操作-CSDN博客也讲过了&#xff08;put&#xff09;&#xff0c;下面给出递归实现 public TreeNode insertIntoBST(TreeNode node, int val) {//找到空位了if(node null) {return new TreeNode(val);}if(val < node.val) {//一直找到有null的位置…...

C++ - 智能指针 - auto_ptr - unique_ptr - std::shared_ptr - weak_ptr

前言 C当中的内存管理机制需要我们自己来进行控制&#xff0c;比如 在堆上 new 了一块空间&#xff0c;那么当这块空间不需要再使用的时候。我们需要手动 delete 掉这块空间&#xff0c;我们不可能每一次都会记得&#xff0c;而且在很大的项目程序当中&#xff0c;造成内存泄漏…...

【快速入门】JVM之类加载机制与Native

感慨&#xff1a; 如何定义一个合格的Java程序员&#xff0c;Java程序员要了解掌握哪些知识点&#xff0c;网上的面试题太多了&#xff0c;后端需要了解掌握的知识点太多太多了&#xff0c;Java基础、数据结构、异常、多线程、Spring、Spring boot、事务、算法、数据库&#xf…...

R实现数据分布特征的视觉化——多笔数据之间的比较

大家好&#xff0c;我是带我去滑雪&#xff01; 如果要对两笔数据或者多笔数据的分布情况进行比较&#xff0c;Q-Q图、柱状图、星形图都是非常好的选择&#xff0c;下面开始实战。 &#xff08;1&#xff09;绘制Q-Q图 首先导入数据bankwage.csv文件&#xff0c;该数据集…...

TCPUDP

TCP 1.什么是TCP TCP是处于运输层的通信协议&#xff0c;该协议能够实现数据的可靠性传输。 2.TCP报文格式 源端口和目的端口&#xff1a;各占两个字节&#xff0c;发送进程的端口和接收进程的端口号。 序号&#xff1a;占4个字节,序号如果增加到溢出&#xff0c;则下一个序…...

设计模式 - 备忘录模式

目录 一. 前言 二. 实现 三. 优缺点 一. 前言 备忘录模式又称快照模式&#xff0c;是一种行为型设计模式。它可以在不破坏封装性的前提下捕获一个对象的内部状态&#xff0c;并在对象之外保存这个状态&#xff0c;以便在需要的时候恢复到原先保存的状态。在不违反封装的情况…...

OpenCV4(C++)—— 几何图形的绘制

文章目录 一、基本图形1、线2、线圆3、线椭圆4、矩形 二、多边形 一、基本图形 1、线 绘制线&#xff0c;要给出两个点坐标 void cv::line(InputOutputArray img, Point pt1, Point pt2, const Scalar& color, int thickness 1, int lineType LINE_8, int shift 0);…...

智能优化算法常用指标一键导出为EXCEL,CEC2017函数集最优值,平均值,标准差,最差值,中位数,秩和检验,箱线图...

声明&#xff1a;对于作者的原创代码&#xff0c;禁止转售倒卖&#xff0c;违者必究&#xff01; 之前出了一篇关于CEC2005函数集的智能算法指标一键统计&#xff0c;然而后台有很多小伙伴在询问其他函数集该怎么调用。今天采用CEC2017函数集为例&#xff0c;进行展示。 为了突…...