当前位置: 首页 > news >正文

LeNet-5上手敲代码

LeNet-5

LeNet-5Yann LeCun在1998年提出,旨在解决手写数字识别问题,被认为是卷积神经网络的开创性工作之一。该网络是第一个被广泛应用于数字图像识别的神经网络之一,也是深度学习领域的里程碑之一。

LeNet-5的整体架构:

在这里插入图片描述

总体来看LeNet-5由两个部分组成:

  • 卷积编码器:由两个卷积层和两个下采样层组成;
  • 全连接层密集块:由三个全连接层组成

特点:

1.相比MLPLeNet使用了相对更少的参数,获得了更好的结果。

2.设计了MaxPool来提取特征

代码实现

1. 模型文件的实现

通过观察模型的整体架构,可以知到LeNet-5只用了三个基本的层——卷积层、下采样层、全连接层,因此我们很容易写出模型的基本框架。

其中Gaussian connections也是一个全连接层。Gaussian Connections利用的是RBF函数(径向欧式距离函数),计算输入向量和参数向量之间的欧式距离。目前该方式基本已淘汰,取而代之的是Softmax

为了提高模型的性能,我们会在卷积层与下采样层之间添加一个Relu激活函数,因此模型的整体流程架构为:

Convolutions -> Relu->Subsampling -> Convolutions -> Relu-> Subsampling -> Full connection -> Full connection -> Full connection

pytorch中,卷积层对应的是nn.Conv2d()方法, 下采样层可以使用pytorch中的最大池化下采样nn.MaxPool2d()方法来实现,全连接层可以使用nn.Linear()方法来实现。

确定参数:

卷积层:对于LeNet-5论文中输入的图片是 32 × 32 32 \times 32 32×32大小的图片(图片通道个数为3)。因此第一个卷积层的输入的通道个数为3,输出的通道个数为16,也就是说一共有16个卷积核。卷积核的个数等于通过卷积后图片的通道个数

我们可以根据如下公式来计算出卷积核的大小。

计算卷积后图像宽和高的公式

  • I n p u t : ( N , C i n , H i n , W i n ) Input:(N, C_{in},H_{in},W_{in}) Input(NCinHinWin)

  • O u t p u t : ( N , C o u t , H o u t , W o u t ) Output:(N,C_{out},H_{out},W_{out}) Output(NCoutHoutWout)

H o u t = [ H i n + 2 × p a d d i n g [ 0 ] − d i l a t i o n [ 0 ] × ( k e r n e l _ s i z e [ 0 ] − 1 ) − 1 s t r i d e [ 0 ] + 1 ] H_{out} = [\frac{H_{in} + 2 \times padding[0] - dilation[0] \times (kernel\_size[0] - 1) - 1}{stride[0]} + 1] Hout=[stride[0]Hin+2×padding[0]dilation[0]×(kernel_size[0]1)1+1]

W o u t = [ W i n + 2 × p a d d i n g [ 1 ] − d i l a t i o n [ 1 ] × ( k e r n e l _ s i z e [ 1 ] − 1 ) − 1 s t r i d e [ 1 ] + 1 ] W_{out} = [\frac{W_{in} + 2 \times padding[1] - dilation[1] \times (kernel\_size[1] - 1) - 1}{stride[1]} + 1] Wout=[stride[1]Win+2×padding[1]dilation[1]×(kernel_size[1]1)1+1]

公式中dilation我们没有使用,默认情况为1,输入的图片为 32 × 32 × 3 32 \times 32 \times 3 32×32×3输出为 28 × 28 × 6 28 \times 28 \times 6 28×28×6,通过公式,我们很容易算出 k e r n e l s i z e = ( 5 , 5 ) kernel_{size} = (5, 5) kernelsize=(5,5)【通常情况下如果通过卷积层后的图片的大小没有很明显的缩小(成倍数缩小),那么stride一般为默认值1,通过以上公式,我们可以求得每一个卷积核的大小 。

最大池化下采样:由于特征图通过最大池化下采样层之后,图片的大小变为原来的一半,因此我们知道在长度方向上每两个像素之间取一个最大值,这样才能将长度变为原来的一半,宽度方向上每两个像素之间取一个最大值,这样才能将宽度变为原来的一半。结合起来得到池化层的每一个滑动窗口的大小为 2 × 2 2 \times 2 2×2,也就是说,每四个像素取一个最大值。

在这里插入图片描述

全连接层:输入为上一个层的输出数据大小,输出为自定义大小,对于第一个全连接层,输入为下采样层的输出,即: 5 × 5 × 16 5 \times 5 \times 16 5×5×16 个矩阵值。输出为下一个全连接层单元的个数(第二个全连接层的单元个数为84个),可以推出所有全连接层的单元个数。

model.py

import torch
import torch.nn as nn
import torch.nn.functional as Fclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, (5, 5))self.pool1 = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, (5, 5))self.pool2 = nn.MaxPool2d(2, 2)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = F.relu(self.conv1(x))x = self.pool1(x)x = F.relu(self.conv2(x))x = self.pool2(x)x = x.view(-1, 16 * 5 * 5)   # 改变张量形状为一个二维张量,第一个维度是自动推断的,第二个维度设定为16 * 5 * 5x = self.fc1(x)x = self.fc2(x)x = self.fc3(x)return xif __name__ == '__main__':model =  LeNet()x = torch.randn((3, 32, 32))output = model(x)print(x)

2. 训练程序

写训练程序的基本步骤为:

  1. 加载训练数据
  2. 初始化模型
  3. 设定损失函数
  4. 设定优化器
  5. 设定迭代次数
  6. 根据情况保存模型权重文件

训练数据我们使用的是CIFAR10中的训练数据,验证集的数据也使用的是CIFAR10中的数据,同时将训练集和验证集的数据进行转换(转换为tensor类型,进行归一化)。设置dataloader,训练集的batch_size64,并且进行随机打乱,设置num_workers2,验证集的batch_size5000,进行随机打乱,设置num_workers2

num_workers:用于设置是否使用多线程读取数据,开启后会加快数据读取速度,但是会占用更多内存,内存较小的电脑可以设置为2或者0

训练数据时,我们在每次的500步之后进行一次验证,验证的方式为,加载验证集,然后输入到网络中进行预测,得到输出的最大值的索引,然后再与真实标签进行比较,统计为True的个数,然后除以所有的标签的个数,得到最后的模型的正确率。

predict_y = torch.max(outputs, dim=1)[1]
accuracy = torch.eq(predict_y, test_label).sum().item() / test_label.size(0)  # .item() 方法将结果转换为标量,即 Python 中的普通数字类型。

在迭代完所有的步数之后进行保存模型的权重文件。

train.py

import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoaderfrom model import LeNetdef main():device = torch.device("cuda" if torch.cuda.is_available() else "cpu")transform = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), torchvision.transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])# 训练集train_set = torchvision.datasets.CIFAR10(root="./data", train=True, download=False, transform=transform)train_loader = DataLoader(dataset=train_set, batch_size=64, shuffle=True, num_workers=2)# 验证集test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=False, transform=transform)test_loader = DataLoader(dataset=test_set, batch_size=5000, shuffle=True, num_workers=0)# 实例化网络,损失函数,优化器net = LeNet().to(device)net.load_state_dict(torch.load('LeNet_200.pth'))  # 加载权重loss_function = nn.CrossEntropyLoss().to(device)optimizer = optim.Adam(net.parameters(), lr=0.001)epochs = 200epoch = 0# 开始训练print("training...")while epoch <= epochs:epoch += 1running_loss = 0.0for step, data in enumerate(train_loader):print(f"epoc: {epoch}, step: {step}")inputs, lables = datainputs, lables = inputs.to(device), lables.to(device)   # 将数据移动到GPU上optimizer.zero_grad()output = net(inputs)loss = loss_function(output, lables)loss.backward()optimizer.step()running_loss += loss.item()if step % 500 == 499:   # 每500个batch_size之后进行验证一次with torch.no_grad():test_image, test_label = next(iter(test_loader))  # iter(test_loader)作用是设定一个迭代器,这行代码的作用是取出验证集中的一个batch_size的图片和对应的标签。test_image, test_label = test_image.to(device), test_label.to(device)  # 将数据移动到 GPU 上outputs = net(test_image)predict_y = torch.max(outputs, dim=1)[1]accuracy = torch.eq(predict_y, test_label).sum().item() / test_label.size(0)  # .item() 方法将结果转换为标量,即 Python 中的普通数字类型。print('[%d, %5d] train_loss: %.3f  test_accuracy: %.3f' %(epoch + 1, step + 1, running_loss / 500, accuracy))running_loss = 0.0print(f"The epoc is {epoch}")print("Finish Training")save_path = "./LeNet.pth"torch.save(net.state_dict(), save_path)if __name__ == '__main__':main()

3. 验证程序

验证程序,首先需要加载图片,然后进行转换(包括裁剪为模型的输入形状大小【这里为 32 × 32 32 \times 32 32×32】,然后转换为tensor类型,最后进行归一化),将预处理后的图片送入到模型中,模型输出的是一个batch_size个一维向量,每一个一维向量有10个数,表示输出的类别一共有10个,取10个中值最大的数的索引作为预测的类别,可以使用以下代码:predict = torch.max(outputs, dim=1)[1].numpy(),这表示在模型输出的结果中,取第一个维度上的10个数取最大值的索引,并将其转换为numpy类型的数据。然后将这个数对照标签的映射关系,可以得到最终预测的类别。

varify.py

import torch
import torchvision.transforms as transforms
from PIL import Imagefrom model import LeNetdef main():transform = transforms.Compose([transforms.Resize((32, 32)),transforms.ToTensor(),transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])classes = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')net = LeNet()net.load_state_dict(torch.load('LeNet_250.pth'))im = Image.open('2.jpg')  # 加载图片im = transform(im)  # [C, H, W]im = torch.unsqueeze(im, dim=0)  # [N, C, H, W]with torch.no_grad():  # 用于设置在该上下文中不进行梯度计算,因为推断时不需要计算梯度,可以提高计算效率。outputs = net(im)predict = torch.max(outputs, dim=1)[1].numpy()print(classes[int(predict)])if __name__ == '__main__':main()

相关文章:

LeNet-5上手敲代码

LeNet-5 LeNet-5由Yann LeCun在1998年提出&#xff0c;旨在解决手写数字识别问题&#xff0c;被认为是卷积神经网络的开创性工作之一。该网络是第一个被广泛应用于数字图像识别的神经网络之一&#xff0c;也是深度学习领域的里程碑之一。 LeNet-5的整体架构&#xff1a; 总体…...

javaWeb入门(自用)

1. vue学习 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"https://unpkg.com/vue2"></script> </head> <body><div id"…...

web3风格的网页怎么设计?分享几个,找找感觉。

web3风格的网站是指基于区块链技术和去中心化理念的网站设计风格。这种设计风格强调开放性、透明性和用户自治&#xff0c;体现了Web3的核心价值观。 以下是一些常见的Web3风格网站设计元素&#xff1a; 去中心化标志&#xff1a;在网站的设计中使用去中心化的标志&#xff0…...

ASP.NET MVC(-)表单的提交、获取表单数据

FromCollection 方式...

[AIGC] 《MyBatis-Plus 结合 Spring Boot 的动态数据源介绍及 Demo 演示》

在现代的 Web 应用开发中&#xff0c;Spring Boot 已经成为了一种流行的框架选择。而 MyBatis-Plus 则为 MyBatis 框架提供了更强大的功能和便利。当它们结合使用时&#xff0c;动态数据源的运用变得更加简单和高效。 动态数据源的概念允许我们在运行时根据不同的条件或需求选…...

【华为OD机试C卷D卷】部门人力分配(C++/Java/Python)

【华为OD机试】-(A卷+B卷+C卷+D卷)-2024真题合集目录 【华为OD机试】-(C卷+D卷)-2024最新真题目录 题目描述 部门在进行需求开发时需要进行人力安排。 当前部门需要完成 N 个需求,需求用 requirements 表述,requirements[i] 表示第 i 个需求的工作量大小,单位:人月。 这部…...

毕业设计:《基于 Prometheus 和 ELK 的基础平台监控系统设计与实现》

前言 《基于 Prometheus 和 ELK 的基础平台监控系统设计与实现》&#xff0c;这是我在本科阶段的毕业设计&#xff0c;通过引入 Prometheus 和 ELK 架构实现企业对指标与日志的全方位监控。并且基于云原生&#xff0c;使用容器化持续集成部署的开发方式&#xff0c;通过 Sprin…...

docker私有仓库部署与管理

一、搭建本地公有仓库 1.1 首先下载registry镜像 docker pull registry 1.2 在daemon.json文件中添加私有镜像仓库地址并重新启动docker服务 vim /etc/docker/daemon.json 1.3 运行registry容器 docker run -itd -v /data/registry:/var/lib/registry -p 5000:5000 --restartal…...

2024第六届济南国际大健康产业博会将于5月27日如期开幕

由山东省城市经济学会、山东省科学养生协会主办的第六届中国&#xff08;济南&#xff09;国际大健康产业博览会&#xff0c;将于5月27-29日&#xff0c;在济南黄河国际会展中心盛大举办。 近年来&#xff0c;健康越来越受到大众的重视&#xff0c;在我国经济重要的转型阶段成…...

计算方法实验9:Romberg积分求解速度、位移

任务 输出质点的轨迹 ( x ( t ) , y ( t ) ) , t ∈ { 0.1 , 0.2 , 0.3 , . . . , 10 } (x(t), y(t)), t\in \{0.1, 0.2, 0.3, ..., 10\} (x(t),y(t)),t∈{0.1,0.2,0.3,...,10}&#xff0c;并在二维平面中画出该轨迹.请比较M分别取4, 8, 12, 16, 20 时&#xff0c;Romberg积分达…...

设计模式有哪些基本原则

目录 开闭原则(Open Closed Principle) 里氏替换原则(Liskov Substitution principle) 单一职责原则(Single Responsibility Principle,SRP)...

别再出错了!华为交换机到底如何配置access、trunk、hybird端口?

号主&#xff1a;老杨丨11年资深网络工程师&#xff0c;更多网工提升干货&#xff0c;请关注公众号&#xff1a;网络工程师俱乐部 下午好&#xff0c;我的网工朋友。 我们都知道&#xff0c;网络工程师的工作离不开对交换机的熟练操作。华为交换机的配置&#xff0c;绝对是考验…...

OceanBase 分布式数据库【信创/国产化】- OceanBase 平台产品 - 迁移评估工具 OMA

本心、输入输出、结果 文章目录 OceanBase 分布式数据库【信创/国产化】- OceanBase 平台产品 - 迁移评估工具 OMA前言OceanBase 数据更新架构OceanBase 平台产品 - 迁移评估工具 OMA兼容性评估性能评估导出 OceanBase 数据库对象和 SQL 语句OceanBase 分布式数据库【信创/国产…...

UE5入门学习笔记(六)——编译低版本插件

对于有些低版本的插件&#xff0c;可以通过此方法自己编译到高版本而无需等待插件作者更新 使用工具&#xff1a;如图所示 步骤1&#xff1a;打开cmd&#xff0c;并使用cd命令切换到此目录 步骤2&#xff1a;输入如下指令 RunUAT.bat BuildPlugin -Plugin“路径1” -Package“…...

MySQL全局锁、表级锁、行锁、死锁、索引选择

文章目录 全局锁表级锁表锁元数据锁 MDL 如何安全的给小表添加字段1. 理解和监控长事务2. 使用NOWAIT和WAIT语法示例 3. 选择合适的时间窗口4. 分阶段执行5. 使用在线DDL工具 行锁死锁普通索引和唯一索引的选择索引基础业务场景分析性能考量实践建议索引及其选择机制索引选择错…...

深入解析算法效率核心:时间与空间复杂度概览及优化策略

算法复杂度&#xff0c;即时间复杂度与空间复杂度&#xff0c;衡量算法运行时资源消耗。时间复杂度反映执行时间随数据规模增长的关系&#xff0c;空间复杂度表明额外内存需求。优化策略&#xff0c;如选择合适数据结构、算法改进、循环展开等&#xff0c;对于提升程序效率、减…...

虚拟机装CentOS镜像

起先&#xff0c;是先安装一个VM虚拟机&#xff0c;再去官方网站之类的下载一些镜像&#xff0c;常见镜像有CentOS镜像&#xff0c;ubantu镜像&#xff0c;好像还有一个树莓还是什么的&#xff0c;软件这块&#xff0c;日新月异&#xff0c;更新太快&#xff0c;好久没碰&#…...

SpringCloud 集成consul,消费者报I/O error on GET request for...

创建消费者微服务&#xff0c;去调用生产者微服务的请求过程中&#xff0c;出现以下错误&#xff1a; 报错原因 因为在使用SpringCloudAlibaba中的Nacos框架时&#xff0c;自动整合了SpringCloud中的Ribbon框架中的负载均衡&#xff0c;因为微服务提供者有两个&#xff0c;在消…...

pytest的测试标记marks

引用打标的marks文档 Python的pytest框架(5)--测试标记(Markers)_pytest执行指定的marker-CSDN博客 https://www.cnblogs.com/pipile/p/12696226.html 给用例自定义打标签的代码示例 #coding:utf-8 import pytest pytest.mark.smoke def test_1():print("smoke的测试用…...

端口占用解决方法

1、查询端口 打开cmd命令提示符窗口&#xff0c;输入以下指令查询所有端口 netstat -ano //查询所有端口 netstat -ano|findstr 8080 //查询指定端口 2、杀死进程 taskkill /t /f /im 进程号(PID)...

Java毕设之基于springboot的医护人员排班系统

运行环境 开发语言:java 框架:springboot&#xff0c;vue JDK版本:JDK1.8 数据库:mysql5.7(推荐5.7&#xff0c;8.0也可以) 数据库工具:Navicat11 开发软件:idea/eclipse(推荐idea) 系统详细实现 医护类型管理 医护人员排班系统的系统管理员可以对医护类型添加修改删除以及…...

OpenCV4.8 VS2019 MFC编程出现的诡异现象

OpenCV4.8及OpenCV4.4 VS2019MFC编程在调用imred&#xff08;&#xff09;函数时&#xff0c;debug X64试运行没问题。 release X64试运行时出现下面错误。 void CEasyPictureDlg::OnBnClickedOpen() {CFileDialog fdlg(TRUE, NULL, 0, OFN_HIDEREADONLY | OFN_OVERWRITEPROMP…...

游戏辅助 -- 三种分析角色坐标方法(CE、xdbg、龙龙遍历工具)

所用工具下载地址&#xff1a; https://pan.quark.cn/s/d54e7cdc55e6 在上次课程中&#xff0c;我们成功获取了人物对象的基址&#xff1a;[[[0xd75db8]1C]28]&#xff0c;而人物血量的地址则是基址再加上偏移量278。 接下来&#xff0c;我们需要执行以下步骤来进一步操作&a…...

【VTKExamples::Rendering】第一期 TestAmbientSpheres(环境照明系数)

很高兴在雪易的CSDN遇见你 VTK技术爱好者 QQ:870202403 公众号:VTK忠粉 前言 本文分享VTK样例TestAmbientShperes,介绍环境照明系数对Actor颜色的影响,希望对各位小伙伴有所帮助! 感谢各位小伙伴的点赞+关注,小易会继续努力分享,一起进步! 你的点赞就是我的动…...

代码随想录leetcode200题之栈与队列

目录 1 介绍2 训练3 参考 1 介绍 本博客用来记录代码随想录leetcode200题中栈与队列部分的题目。 2 训练 题目1&#xff1a;232. 用栈实现队列 C代码如下&#xff0c; #include <stack>class MyQueue { private:stack<int> a;stack<int> b; //辅助栈 pu…...

使用Python实现2048小游戏

使用Python实现2048小游戏源码分享。实现效果如下所示。 实现效果图 游戏开始效果图 游戏结束效果图 部分源码截图 下载链接 基于如下的运行环境。运行需要安装tkinter /Library/Frameworks/Python.framework/Versions/3.7/bin/python/bin/python /Users/nihui/Documents/P…...

漏洞管理是如何在攻击者之前识别漏洞从而帮助人们阻止攻击的

漏洞管理 是主动查找、评估和缓解组织 IT 环境中的安全漏洞、弱点、差距、错误配置和错误的过程。该过程通常扩展到整个 IT 环境&#xff0c;包括网络、应用程序、系统、基础设施、软件和第三方服务等。鉴于所涉及的高成本&#xff0c;组织根本无法承受网络攻击和数据泄露。如果…...

LNMT部署jpress

LNMT部署jpress 环境要求&#xff1a; MySQL版本5.6/5.7 tomcat版本9.0.65 源码安装MySQL5.7版 //源码安装MySQL5.7版1关闭防火墙 2创建mysql用户 3上传mysql5.7包&#xff08;https://downloads.mysql.com/archives/get/p/23/file/mysql-5.7.30-linux-glibc2.12-x86_64.tar.g…...

汽车软件研发工具链丨怿星科技新产品重磅发布

“创新引领未来”聚焦汽车软件新基建&#xff0c;4月27日下午&#xff0c;怿星科技2024新产品发布会在北京圆满举行&#xff01;智能汽车领域的企业代表、知名大企业负责人、投资机构代表、研究机构代表齐聚现场&#xff0c;线上直播同步开启&#xff0c;共同见证怿星科技从单点…...

Faiss原理及使用总结

Faiss&#xff08;Facebook AI Similarity Search&#xff09;是一个用于高效相似性搜索和密集向量聚类的库。 一、原理 向量表示与相似度度量&#xff1a;在Faiss中&#xff0c;数据通常被表示为高维向量&#xff0c;这些向量可以来自深度学习模型的特征提取&#xff0c;也可…...

seo网站架构设计/企业网站建设的作用

1、Linux系统简单介绍 Linux是一套免费使用, 支持多用户、多任务、支持多线程和多个核心CPU的操作系统&#xff1b;很多中型, 大型甚至是巨型项目都在使用Linux。 Linux的发行版说简单点就是将Linux与应用软件做一个打包, 目前市面上比较知名的发行版有: Ubuntu, RedHat, Cen…...

做网站推广可行吗/简述网站推广的意义和方法

之前买了个荔枝派&#xff0c;全志的A3S芯片。折腾了两天&#xff0c;写一下编译和SD烧录的过程。 目录 1.直接烧录镜像文件 2.uboot编译 3.kernel编译 4.rootfs编译 5.烧录 6.串口登录 1.直接烧录镜像文件 百度到了一堆的资料&#xff0c;下面是网盘链接 链接&#x…...

帮人网站开发维护违法/windows10优化软件

点击蓝色字免费订阅&#xff0c;每天收到这样的好信息一、数据治理与数据分类分级《DAMA 数据管理知识体系指南》给出的定义&#xff1a;数据治理是对数据资产管理行使权力和控制的活动集合(规划、监控和执行)。数据治理的职能是指导其他数据管理职能如何执行。数据治理就是以服…...

网站开发项目交接/同城推广引流平台

搞惯导、组合导航领域的专家严恭敏老师&#xff0c;在新浪博客上有一系列专业文章。这里记录下博客地址&#xff1a; http://blog.sina.com.cn/s/articlelist_1089338825_0_1.html...

要维护公司的网站该怎么做/免费注册网址

2019独角兽企业重金招聘Python工程师标准>>> 今天决定看看开源中国安卓版app&#xff0c;并试着重构一下。好的进入主题。 创建MainActivity public class MainActivity extends ActionBarActivity implementsNavigationDrawerFragment.NavigationDrawerCallbacks,O…...

传奇网站装备动态图怎么做/网站关键词排名seo

ubuntu 10.10以前的操作方法&#xff1a;1 第一步&#xff0c;具体命令及操作如下&#xff1a;sudo vi /etc/init/rc-sysinit.confenv DEFAULT_RUNLEVEL3 <------将原来的env DEFAULT_RUNLEVEL2修改为env DEFAULT_RUNLEVEL32 第二步&#xff0c;具体命令及操作如下&#xff…...