CNN——LeNet
1.LeNet概述
LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。
当时,LeNet取得了与支持向量机(support vector machines)性能相媲美的成果,成为监督学习的主流方法。 LeNet当时被广泛用于自动取款机(ATM)机中,帮助识别处理支票的数字。
下面是整个网络的结构图
LeNet共有8层,其中包括输入层,3个卷积层,2个子采样层(也就是现在的池化层),1个全连接层和1个高斯连接层。
上图中用C代表卷积层,用S代表采样层,用F代表全连接层。输入size固定在1*32*32,LeNet图片的输入是二值图像。网络的输出为0~9十个数字的RBF度量,可以理解为输入图像属于0~9数字的可能性大小。
2.详解LeNet
下面对图中每一层做详细的介绍:
- LeNet使用的卷积核大小都为5*5,步长为1,无填充,只是卷积深度不一样(卷积核个数导致生成的特征图的通道数)
- 激活函数为Sigmoid
- 下采样层都是使用最大池化实现,池化的核都为2*2,步长为2,无填充
input输入层,尺寸为1*32*32的二值图
C1层是一个卷积层。该层使用6个卷积核,生成特征图尺寸为32-5+1=28,输出为6个大小为28*28的特征图。再经过一个Sigmoid激活函数非线性变换。
S2层是一个下采样层。生成特征图尺寸为28/2=14,得到6个14*14的特征图。
C3层是一个卷积层,该层使用16个卷积核,生成特征图尺寸为14-5+1=10,输出为16个10*10的特征图。再经过一个Sigmoid激活函数非线性变换。
S4层是一个下采样层,生成特征图尺寸为10/2=5,得到16个5*5的特征图
C5层是一个卷积层,卷积核数量增加至120。生成特征图尺寸为5-5+1=1。得到120个1*1的特征图。这里实际上相当于S4全连接了,但仍将其标为卷积层,原因是如果LeNet-5的输入图片尺寸变大,其他保持不变,那该层特征图的维数也会大于1*1,那就不是全连接了。再经过一个Sigmoid激活函数非线性变换。
F6层是一个全连接层,该层与C5层全连接,输出84张特征图。再经过一个Sigmoid激活函数非线性变换。
输出层:输出层由欧式径向基函数(高斯)单元组成,每个类别(0~9数字)对应一个径向基函数单元,每个单元有84个输入。也就是说,每个输出RBF单元计算输入向量和该类别标记向量之间的欧式距离,距离越远,PRF输出越大,同时我们也会将与标记向量欧式距离最近的类别作为数字识别的输出结果。当然现在通常使用的Softmax实现。
3.使用LeNet实现Mnist数据集分类
1.导入所需库
import torch
import torch.nn as nn
from torchsummary import summary
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from tqdm import tqdm # 显示训练进度条
2.使用GPU
device = 'cuda' if torch.cuda.is_available() else 'cpu'
3.读取Mnist数据集
# 定义数据转换以进行数据标准化
transform = transforms.Compose([transforms.ToTensor(), # 将图像转换为 PyTorch 张量
])# 下载并加载 MNIST 训练和测试数据集
train_dataset = datasets.MNIST(root='./dataset', train=True, download=True, transform=transform)
test_dataset = datasets.MNIST(root='./dataset', train=False, download=True, transform=transform)# 创建数据加载器以批量加载数据
batch_size = 256
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
4.搭建LeNet
需要注意的是torch.nn.CrossEntropyLoss自带了softmax函数,所以最后一层使用全连接即可。
class LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()# Mnist尺寸为28*28,这里设置填充变成32*32self.conv1 = nn.Conv2d(1, 6, kernel_size=5, padding=2) self.sigmoid = nn.Sigmoid()self.pool = nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.flatten = nn.Flatten()self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(self.sigmoid(self.conv1(x)))x = self.pool(self.sigmoid(self.conv2(x)))x = self.flatten(x)x = self.sigmoid(self.fc1(x))x = self.sigmoid(self.fc2(x))x = self.fc3(x)return x
# 实例化模型
model = LeNet().to(device)
summary(model, (1, 28, 28))
5.训练函数
def train(model, lr, epochs):# 将模型放入GPUmodel = model.to(device)# 使用交叉熵损失函数loss_fn = nn.CrossEntropyLoss().to(device)# SGDoptimizer = torch.optim.SGD(model.parameters(), lr=lr)# 记录训练与验证数据train_losses = []train_accuracies = []# 开始迭代 for epoch in range(epochs): # 切换训练模式model.train() # 记录变量train_loss = 0.0correct_train = 0total_train = 0# 读取训练数据并使用 tqdm 显示进度条for i, (inputs, targets) in tqdm(enumerate(train_dataloader), total=len(train_dataloader), desc=f"Epoch {epoch+1}/{epochs}", unit='batch'):# 训练数据移入GPUinputs = inputs.to(device)targets = targets.to(device)# 模型预测outputs = model(inputs)# 计算损失loss = loss_fn(outputs, targets)# 梯度清零optimizer.zero_grad()# 反向传播loss.backward()# 使用优化器优化参数optimizer.step()# 记录损失train_loss += loss.item()# 计算训练正确个数_, predicted = torch.max(outputs, 1)total_train += targets.size(0)correct_train += (predicted == targets).sum().item()# 计算训练正确率并记录train_loss /= len(train_dataloader)train_accuracy = correct_train / total_traintrain_losses.append(train_loss)train_accuracies.append(train_accuracy)# 输出训练信息print(f"Epoch [{epoch + 1}/{epochs}] - Train Loss: {train_loss:.4f}, Train Acc: {train_accuracy:.4f}")# 绘制损失和正确率曲线plt.figure(figsize=(10, 5))plt.subplot(1, 2, 1)plt.plot(range(epochs), train_losses, label='Training Loss')plt.xlabel('Epoch')plt.ylabel('Loss')plt.legend()plt.subplot(1, 2, 2)plt.plot(range(epochs), train_accuracies, label='Accuracy')plt.xlabel('Epoch')plt.ylabel('Accuracy')plt.legend()plt.tight_layout()plt.show()
6.模型训练
model = LeNet()
lr = 0.9 # sigmoid两端容易饱和,gradient比较小,学得比较慢,所以学习率要大一些
epochs = 20
train(model,lr,epochs)
7.模型测试
def test(model, test_dataloader, device, model_path):# 将模型设置为评估模式model.eval()# 将模型移动到指定设备上model.to(device)# 从给定路径加载模型的状态字典model.load_state_dict(torch.load(model_path))correct_test = 0total_test = 0# 不计算梯度with torch.no_grad():# 遍历测试数据加载器for inputs, targets in test_dataloader: # 将输入数据和标签移动到指定设备上inputs = inputs.to(device)targets = targets.to(device)# 模型进行推理outputs = model(inputs)# 获取预测结果中的最大值_, predicted = torch.max(outputs, 1)total_test += targets.size(0)# 统计预测正确的数量correct_test += (predicted == targets).sum().item()# 计算并打印测试数据的准确率test_accuracy = correct_test / total_testprint(f"Accuracy on Test: {test_accuracy:.4f}")return test_accuracy
model_path = save_path
test(model, test_dataloader, device, save_path)
相关文章:

CNN——LeNet
1.LeNet概述 LeNet是Yann LeCun于1988年提出的用于手写体数字识别的网络结构,它是最早发布的卷积神经网络之一,可以说LeNet是深度CNN网络的基石。 当时,LeNet取得了与支持向量机(support vector machines)性能相…...

分类模型评估方法
1.数据集划分 1.1 为什么要划分数据集? 思考:我们有以下场景: 将所有的数据都作为训练数据,训练出一个模型直接上线预测 每当得到一个新的数据,则计算新数据到训练数据的距离,预测得到新数据的类别 存在问题&…...

RabbitMQ高级
文章目录 一.消息可靠性1.生产者消息确认 MQ的一些常见问题 1.消息可靠性问题:如何确保发送的消息至少被消费一次 2.延迟消息问题:如何实现消息的延迟投递 3.高可用问题:如何避免单点的MQ故障而导致的不可用问题 4.消息堆积问题:如何解决数百万消息堆积,无法及时…...

SonarQube 漏洞扫描
SonarQube 漏洞扫描 一、部署服务 1.1 docker方式部署 #安装docker curl -L download.beyourself.org.cn/shell-project/os/get-docker-latest.sh | sh yum install -y docker-compose #进去输入:set paste可以保证不穿行 [rootlocalhost sonar]# vim docker-compose.yml v…...

Web前端篇——ElementUI的Backtop 不显示问题
在使用ElementUI的Backtop回到顶部组件时,单独复制这一行代码 <el-backtop :right"100" :bottom"100" /> 发现页面在向下滚动时,并未出现Backtop组件。 可从以下3个方向进行分析: 指定target属性,且…...

MySQL 管理工具
1、MySQL 管理 系统数据库 a. mysql 命令 语法:mysql [options] [database] -u,--username 指定用户名-p,--password[name] 指定密码-h, --hostname 指定服务器IP或域名-P, --portport 指定连接端-e,--executename 执行SQL语句并退出 mysql -h192.168.200.202 -…...

LeetCode 33 搜索旋转排序数组
题目描述 搜索旋转排序数组 整数数组 nums 按升序排列,数组中的值 互不相同 。 在传递给函数之前,nums 在预先未知的某个下标 k(0 < k < nums.length)上进行了 旋转,使数组变为 [nums[k], nums[k1], ..., num…...

分类预测 | Python实现基于SVM-RFE-LSTM的特征选择算法结合LSTM神经网络的多输入单输出分类预测
分类预测 | Python实现基于SVM-RFE-LSTM的特征选择算法结合LSTM神经网络的多输入单输出分类预测 目录 分类预测 | Python实现基于SVM-RFE-LSTM的特征选择算法结合LSTM神经网络的多输入单输出分类预测分类效果基本描述程序设计参考资料 分类效果 基本描述 基于SVM-RFE-LSTM的特征…...

JetBrains Rider使用总结
简介: JetBrains Rider 诞生于2016年,一款适配于游戏开发人员,是JetBrains旗下一款非常年轻的跨平台 .NET IDE。目前支持包括.NET 桌面应用、服务和库、Unity 和 Unreal Engine 游戏、Xamarin 、ASP.NET 和 ASP.NET Core web 等多种应用程序…...

C# Emgu.CV4.8.0读取rtsp流录制mp4可分段保存
【官方框架地址】 https://github.com/emgucv/emgucv 【算法介绍】 EMGU CV(Emgu Computer Vision)是一个开源的、基于.NET框架的计算机视觉库,它提供了对OpenCV(开源计算机视觉库)的封装。EMGU CV使得在.NET应用程序…...

java碳排放数据信息管理系统Myeclipse开发mysql数据库web结构java编程计算机网页项目
一、源码特点 java Web碳排放数据信息管理系统是一套完善的java web信息管理系统,对理解JSP java编程开发语言有帮助,系统具有完整的源代码和数据库,系统主要采用B/S模式开发。开发环 境为TOMCAT7.0,Myeclipse8.5开发,数据库为…...

K8S陈述式资源管理(1)
命令行: kubectl命令行工具 优点: 90%以上的场景都可以满足对资源的增,删,查比较方便,对改不是很友好 缺点:命令比较冗长,复杂,难记声明式 声明式:K8S当中的yaml文件来实现资源管理 GUI:图形…...

STL map容器与pair类模板(解决扫雷问题)
CSTL之Map容器 - 数据结构教程 - C语言网 (dotcpp.com)https://www.dotcpp.com/course/118CSTL之Pair类模板 - 数据结构教程 - C语言网 (dotcpp.com)https://www.dotcpp.com/course/119 刷到一个扫雷的题目,之前没有玩怎么过扫雷,于是我就去玩了玩…...

【React系列】Portals、Fragment
本文来自#React系列教程:https://mp.weixin.qq.com/mp/appmsgalbum?__bizMzg5MDAzNzkwNA&actiongetalbum&album_id1566025152667107329) Portals 某些情况下,我们希望渲染的内容独立于父组件,甚至是独立于当前挂载到的DOM元素中&am…...

ByteTrack算法流程的简单示例
ByteTrack ByteTrack算法是将t帧检测出来的检测框集合 D t {\mathcal{D}_{t}} Dt 和t-1帧预测轨迹集合 T ~ t − 1 {\tilde{T}_{t-1}} T~t−1 进行匹配关联得到t帧的轨迹集合 T t {T_{t}} Tt。 首先使用检测器检测t帧的图像得到检测框集合 D t {\mathcal{D}_{t}} …...

免费的GPT4来了,你还不知道吗?
程序员的公众号:源1024,获取更多资料,无加密无套路! 最近整理了一波电子书籍资料,包含《Effective Java中文版 第2版》《深入JAVA虚拟机》,《重构改善既有代码设计》,《MySQL高性能-第3版》&…...

win10报错“zlib.dll文件丢失,软件无法启动”,修复方法,亲测有效
zlib.dll文件是一个由Zlib创建的动态链接库文件,它是用于Windows操作系统的数据压缩和解压缩的软件。Zlib是一个广泛使用的软件库,广泛应用在许多不同类型的软件中,包括游戏、浏览器和操作系统。 zlib.dll的主要作用是提供数据压缩和解压缩的…...

MFC中如何使用CListCtrl可以编辑,并添加鼠标右键及双击事件。
要在MFC中使用CListCtrl来实现编辑功能,可以按照以下步骤进行操作: 在对话框资源中添加CListCtrl控件,并设置合适的属性。在对话框类的头文件中添加成员变量来管理CListCtrl控件,例如: CListCtrl m_listCtrl; 3. 在O…...

[每周一更]-(第81期):PS抠图流程(扭扭曲曲的身份证修正)
应朋友之急,整理下思路,分享一下~~ 分两步走:先用磁性套索工具圈出要处理的图;然后使用透视剪裁工具,将扭曲的图片拉平即可;(macbook pro) 做事有规则,才能更高效;用什么工具,先列举…...

Kafka安全认证机制详解之SASL_PLAIN
一、概述 官方文档: https://kafka.apache.org/documentation/#security 在官方文档中,kafka有五种加密认证方式,分别如下: SSL:用于测试环境SASL/GSSAPI (Kerberos) :使用kerberos认证,密码是…...

2023南京理工大学通信工程818信号系统及数电考试大纲
注:(Δ)表示重点内容。具体内容详见博睿泽信息通信考研论坛 参考书目: [1] 钱玲,谷亚林,王海青. 信号与系统(第五版). 北京:电子工业出版社 [2] 郑君里,应…...

wsl(ubuntu)创建用户
我们打卡ubuntu窗口,如果没有创建用户,那么默认是root用户 用户的增删改查 查 查询所有的用户列表 cat /etc/passwd | cut -d: -f1cat /etc/passwd: 这个命令用于显示 /etc/passwd 文件的内容。/etc/passwd 文件包含了系统上所有用户的基本信息。每一…...

[足式机器人]Part2 Dr. CAN学习笔记-自动控制原理Ch1-8Lag Compensator滞后补偿器
本文仅供学习使用 本文参考: B站:DR_CAN Dr. CAN学习笔记-自动控制原理Ch1-8Lag Compensator滞后补偿器 从稳态误差入手(steady state Error) 误差 Error : E ( s ) R ( s ) − X ( s ) R ( s ) − E ( s ) ⋅ K G …...

swift-碰到的问题
如何让工程不使用storyboard和scene 删除info.plist里面的Application Scene mainifest 删除SceneDelegate.swift 删除AppDelegate.swift里面的这两个方法 func application(_ application: UIApplication, configurationForConnecting connectingSceneSession: UISceneSession…...

安全与认证Week4
目录 目录 Web Security (TLS/SSL) 各层安全协议 Transport Layer Security (TLS)传输层安全性(TLS) SSL和TLS的联系与区别 TLS connection&session 连接与会话 题目2答案点 TLS ArchitectureTLS架构(5个协议) 题目1答案点 Handshake Proto…...

Golang高质量编程与性能调优实战
1.1 简介 高质量:编写的代码能否达到正确可靠、简洁清晰的目标 各种边界条件是否考虑完备异常情况处理,稳定性保证易读易维护编程原则 简单性 消除多余的重复性,以简单清晰的逻辑编写代码不理解的代码无法修复改进可读性 代码是写给人看的,并不是机器编写可维护代码的第一…...

vite 如何打包 dist 文件到 zip 使用插件 vite-plugin-zip-pack,vue3 ts
vite 如何打包 dist 文件到 zip 使用插件 vite-plugin-zip-pack,vue3 ts 开发过程中一个经常做的事就是将 ./dist 文件夹打包成 zip 分发。 每次手动打包还是很费劲的, vite 同样也有能把 ./dist 文件夹打包成 .zip 的插件,当然这个打包的文…...

jdbc源码研究
JDBC介绍 JDBC(Java Data Base Connectivity,java数据库连接)是一种用于执行SQL语句的Java API,可以为多种关系数据库提供统一访问,它由一组用Java语言编写的类和接口组成。 开发者不必为每家数据通信协议的不同而疲于奔命&#…...

挠性及刚挠结合印制电路技术
1.1挠性印制电路板概述 20世纪70年代末期,以日本厂商为主导,逐渐将挠性印制电路板(flexible printedcircuit board,FPCB,简称为FPC)广泛应用于计算机、照相机、打印机、汽车音响、硬盘驱动器等电子信息产品中。20世纪90年代初期&…...

Python+OpenGL绘制3D模型(七)制作3dsmax导出插件
系列文章 一、逆向工程 Sketchup 逆向工程(一)破解.skp文件数据结构 Sketchup 逆向工程(二)分析三维模型数据结构 Sketchup 逆向工程(三)软件逆向工程从何处入手 Sketchup 逆向工程(四…...