pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)
目录
LeNet-5
LeNet-5 结构
CIFAR-10
pytorch实现
lenet模型
训练模型
1.导入数据
2.训练模型
3.测试模型
测试单张图片
代码
运行结果
LeNet-5
LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要用于手写数字识别任务。它在 MNIST 数据集上表现出色,并且是深度学习历史上的一个重要里程碑。
LeNet-5 结构
LeNet-5 的结构包括以下几个层次:
- 输入层: 32x32 的灰度图像。
- 卷积层 C1: 包含 6 个 5x5 的滤波器,输出尺寸为 28x28x6。
- 池化层 S2: 平均池化层,输出尺寸为 14x14x6。
- 卷积层 C3: 包含 16 个 5x5 的滤波器,输出尺寸为 10x10x16。
- 池化层 S4: 平均池化层,输出尺寸为 5x5x16。
- 卷积层 C5: 包含 120 个 5x5 的滤波器,输出尺寸为 1x1x120。
- 全连接层 F6: 包含 84 个神经元。
- 输出层: 包含 10 个神经元,对应于 10 个类别。
CIFAR-10
CIFAR-10 是一个常用的图像分类数据集,包含 10 个类别的 60,000 张 32x32 彩色图像。每个类别有 6,000 张图像,其中 50,000 张用于训练,10,000 张用于测试。
1. 标注数据量训练集:50000张图像测试集:10000张图像
2. 标注类别数据集共有10个类别。具体分类见图1。
3. 可视化
pytorch实现
lenet模型
- 平均池化(Average Pooling):对池化窗口内所有像素的值取平均,适合保留图像的背景信息。
- 最大池化(Max Pooling):对池化窗口内的最大值进行选择,适合提取显著特征并具有降噪效果。
在实际应用中,最大池化更常用,因为它通常能更好地保留重要特征并提高模型的性能。
import torch.nn as nn
import torch.nn.functional as funcclass LeNet(nn.Module):def __init__(self):super(LeNet, self).__init__()self.conv1 = nn.Conv2d(3, 6, kernel_size=5)self.conv2 = nn.Conv2d(6, 16, kernel_size=5)self.fc1 = nn.Linear(16*5*5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = func.relu(self.conv1(x))x = func.max_pool2d(x, 2)x = func.relu(self.conv2(x))x = func.max_pool2d(x, 2)x = x.view(x.size(0), -1)x = func.relu(self.fc1(x))x = func.relu(self.fc2(x))x = self.fc3(x)return x
训练模型
1.导入数据
导入训练数据和测试数据
def load_data(self):#transforms.RandomHorizontalFlip() 是 pytorch 中用来进行随机水平翻转的函数。它将以一定概率(默认为0.5)对输入的图像进行水平翻转,并返回翻转后的图像。这可以用于数据增强,使模型能够更好地泛化。train_transform = transforms.Compose([transforms.RandomHorizontalFlip(), transforms.ToTensor()])test_transform = transforms.Compose([transforms.ToTensor()])train_set = torchvision.datasets.CIFAR10(root='./data', train=True, download=True, transform=train_transform)self.train_loader = torch.utils.data.DataLoader(dataset=train_set, batch_size=self.train_batch_size, shuffle=True)# shuffle=True 表示在每次迭代时,数据集都会被重新打乱。这可以防止模型在训练过程中过度拟合训练数据,并提高模型的泛化能力。test_set = torchvision.datasets.CIFAR10(root='./data', train=False, download=True, transform=test_transform)self.test_loader = torch.utils.data.DataLoader(dataset=test_set, batch_size=self.test_batch_size, shuffle=False)
2.训练模型
def train(self):print("train:")self.model.train()train_loss = 0train_correct = 0total = 0for batch_num, (data, target) in enumerate(self.train_loader):data, target = data.to(self.device), target.to(self.device)self.optimizer.zero_grad()output = self.model(data)loss = self.criterion(output, target)loss.backward()self.optimizer.step()train_loss += loss.item()prediction = torch.max(output, 1) # second param "1" represents the dimension to be reducedtotal += target.size(0)# train_correct incremented by one if predicted righttrain_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.train_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (train_loss / (batch_num + 1), 100. * train_correct / total, train_correct, total))return train_loss, train_correct / total
3.测试模型
def test(self):print("test:")self.model.eval()test_loss = 0test_correct = 0total = 0with torch.no_grad():for batch_num, (data, target) in enumerate(self.test_loader):data, target = data.to(self.device), target.to(self.device)output = self.model(data)loss = self.criterion(output, target)test_loss += loss.item()prediction = torch.max(output, 1)total += target.size(0)test_correct += np.sum(prediction[1].cpu().numpy() == target.cpu().numpy())progress_bar(batch_num, len(self.test_loader), 'Loss: %.4f | Acc: %.3f%% (%d/%d)'% (test_loss / (batch_num + 1), 100. * test_correct / total, test_correct, total))return test_loss, test_correct / total
测试单张图片
网上随便下载一个图片
然后使用图片编辑工具,把图片设置为32x32大小
通过导入模型,然后测试一下
代码
import torch
import cv2
import torch.nn.functional as F
#from model import Net ##重要,虽然显示灰色(即在次代码中没用到),但若没有引入这个模型代码,加载模型时会找不到模型
from torch.autograd import Variable
from torchvision import datasets, transforms
import numpy as npclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')
if __name__ == '__main__':device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')model = torch.load('lenet.pth') # 加载模型model = model.to(device)model.eval() # 把模型转为test模式img = cv2.imread("bird1.png") # 读取要预测的图片trans = transforms.Compose([transforms.ToTensor()])img = trans(img)img = img.to(device)img = img.unsqueeze(0) # 图片扩展多一维,因为输入到保存的模型中是4维的[batch_size,通道,长,宽],而普通图片只有三维,[通道,长,宽]# 扩展后,为[1,1,28,28]output = model(img)prob = F.softmax(output,dim=1) #prob是10个分类的概率print(prob)value, predicted = torch.max(output.data, 1)print(predicted.item())print(value)pred_class = classes[predicted.item()]print(pred_class)
运行结果
tensor([[1.8428e-01, 1.3935e-06, 7.8295e-01, 8.5042e-04, 3.0219e-06, 1.6916e-04,5.8798e-06, 3.1647e-02, 1.7037e-08, 8.9128e-05]], device='cuda:0',grad_fn=<SoftmaxBackward0>)
2
tensor([4.0915], device='cuda:0')
bird
从结果看,效果还不错。记录一下
相关文章:

pytorch,用lenet5识别cifar10数据集(训练+测试+单张图片识别)
目录 LeNet-5 LeNet-5 结构 CIFAR-10 pytorch实现 lenet模型 训练模型 1.导入数据 2.训练模型 3.测试模型 测试单张图片 代码 运行结果 LeNet-5 LeNet-5 是由 Yann LeCun 等人在 1998 年提出的一种经典卷积神经网络(CNN)模型,主要…...

Word卡顿的处理方法
1. 检查和关闭后台程序 关闭不必要的后台程序,释放系统资源。使用任务管理器(Ctrl + Shift + Esc)查看占用CPU和内存较高的应用,并关闭它们。2. 更新Microsoft Office 确保你的Microsoft Office软件是最新版本。新版本通常修复了已知的性能问题。打开Word,点击文件 > 账…...

在 Linux上常见的10大压缩格式解压命令和它们对应的压缩格式
文章目录 前言一、解压 .zip 文件二、解压 .tar.gz 或 .tgz 文件三、解压 .tar 文件四、解压 .tar.bz2 文件五、解压 .tar.xz 文件六、解压 .gz 文件七、解压 .bz2 文件八、解压 .xz 文件九、解压 .7z 文件十、解压 .rar 文件总结 前言 Linux 命令可以解压不同格式的压缩文件。…...

【数据结构】三、栈和队列:6.链队列、双端队列、队列的应用(树的层次遍历、广度优先BFS、先来先服务FCFS)
文章目录 2.链队列2.1初始化(带头结点)不带头结点 2.2入队(带头结点)2.3出队(带头结点)❗2.4链队列c实例 3.双端队列考点:输出序列合法性栈双端队列 队列的应用1.树的层次遍历2.图的广度优先遍历3.操作系统…...

技术速递|使用 Native Library Interop 为 .NET MAUI 创建绑定
作者:Rachel Kang 排版:Alan Wang 在当今的应用开发领域,通过利用本机功能来扩展 .NET 应用程序的能力非常宝贵。.NET MAUI 处理程序架构使开发人员能够使用 .NET 代码直接操作本机控件,甚至允许无缝创建跨平台自定义控件。然而&a…...

Linux笔记 --- 标准IO
系统IO的最大特点一个是更具通用性,不管是普通文件、管道文件、设备节点文件、接字文件等等都可以使用,另一个是他的简约性,对文件内数据的读写在任何情况下都是带任何格式的,而且数据的读写也都没有经过任何缓冲处理,…...

洛谷:B3625 迷宫寻路
迷宫寻路 题目描述 机器猫被困在一个矩形迷宫里。 迷宫可以视为一个 n m n\times m nm 矩阵,每个位置要么是空地,要么是墙。机器猫只能从一个空地走到其上、下、左、右的空地。 机器猫初始时位于 ( 1 , 1 ) (1, 1) (1,1) 的位置,问能否…...

【C#】explicit、implicit与operator
字面解释 explicit:清楚明白的;易于理解的;(说话)清晰的,明确的;直言的;坦率的;直截了当的;不隐晦的;不含糊的。 implicit:含蓄的;不直接言明的;成为一部分的;内含的;完全的;无疑问的。 operator:操作人员;技工;电话员;接线员;…...

Vue:Vuex-Store使用指南
一、简介 1.1Vuex 是什么 Vuex 是一个专为 Vue.js 应用程序开发的状态管理模式。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。Vuex 也集成到 Vue 的官方调试工具 devtools extension (opens new window)…...

对经典动态规划问题【爬台阶】的一些思考
背景 今天在做Leetcode题目时,做到了一道经典的动态规划问题:爬楼梯,题目的大致意思很简单,有个小孩正在上楼梯,楼梯有n阶台阶,小孩一次可以上1阶、2阶或3阶。实现一种方法,计算小孩有多少种上…...

开发一个能打造虚拟带货直播间的工具!
在当今数字化时代,直播带货已成为电商领域的一股强劲力量,其直观、互动性强的特点极大地提升了消费者的购物体验。 然而,随着技术的不断进步,传统直播带货模式正逐步向更加智能化、虚拟化的方向演进,本文将深入探讨如…...

汽车补光照明实验太阳光模拟器光源
汽车补光照明实验概览 汽车补光照明实验是汽车照明领域的一个重要环节,它涉及到汽车照明系统的性能测试和优化。实验的目的在于确保汽车在各种光照条件下都能提供良好的照明效果,以提高行车安全。实验内容通常包括但不限于灯光的亮度、色温、均匀性、响应…...

MediaPipe人体姿态、手指关键点检测
MediaPipe人体姿态、手指关键点检测 文章目录 MediaPipe人体姿态、手指关键点检测前言一、手指关键点检测二、姿态检测三、3D物体案例检测案例 前言 Mediapipe是google的一个开源项目,用于构建机器学习管道。 提供了16个预训练模型的案例:人脸检测、…...

树上dp之换根dp
基本概念: 换根dp是树上dp的一种 我们在什么时候需要用到换根dp呢? 当题目询问的属性,是需要当前结点为根时的属性,这个时候,我们就要使用换根dp 换根dp的基本思路: 假设题目询问的的属性为x 通常我们…...

2024/8/13 英语每日一段
Mackey says while Whole Foods has become more homogenized under Amazon, it did enable the store to do what it couldn’t have done independently. “People saw us as too expensive and out of touch with our customers,” he says. “The main thing Whole Foods n…...

Java多线程练习(1)
MultiProcessingExercise package MultiProcessingExercise120240813;public class MultiProcessingExercise {public static void main(String[] args) {/*需求:一共有1000张电影票,可以在两个窗口领取,假设每次领取的时间为3000毫秒,请用多线程模拟卖票过程并打印…...

AI高级肖像动画神器LivePortrait
文章目录 前言一、安装1.1 源码安装1.2 windows一键启动包 二、人像生成2.1 浏览器2.2 输入图像2.3 选择驱动视频2.4 生成2.5 结果 三、动物生成3.1 浏览器3.2 输入图片3.3 选择视频3.4 生成3.5 最终结果 四、软件获取 前言 最近,快手可灵大模型团队、中国科学技术…...

Java反射机制深度解析与实践应用
Java反射机制深度解析与实践应用 引言 Java反射是Java语言提供的一种能力,允许程序在运行时访问、检测和修改其自身的属性和行为。反射机制是Java面向对象编程的一大亮点,也是Java框架和库常用的技术之一。 反射的基本概念 反射的核心是java.lang.re…...

Oracle递归查询层级及路径
一、建表及插入数据 ocation_idlocation_nameparent_location_id1广东省NULL2广州市13深圳市14天河区25番禺区26南山区37宝安区3 建表sql: CREATE TABLE locations (location_id NUMBER PRIMARY KEY,location_name VARCHAR2(100),parent_location_id NUMBER ); I…...

leetcode300. 最长递增子序列,动态规划附状态转移方程
leetcode300. 最长递增子序列 给你一个整数数组 nums ,找到其中最长严格递增子序列的长度。 子序列 是由数组派生而来的序列,删除(或不删除)数组中的元素而不改变其余元素的顺序。例如,[3,6,2,7] 是数组 [0,3,1,6,2,2…...

C语言:字符串函数strcpy
该函数用于字符串的拷贝。 使用方法如下: #include<stdio.h> #include<string.h>int main() {char str[10];char* str1 "abcd";//strcpy(str, str1);//把str1复制到str,但此函数不安全所以用strcpy_sstrcpy_s(str, 10, str1);/…...

Day16-指针2
数组指针与指针数组 变量指针:指向变量的地址。 数组指针:指向数组的地址。 指针变量:存放其他变量地址的变量。 指针数组:存放数组元素指针的变量。 数组指针 概念:数组指针是指向数组的指针。特点: 先…...

数据结构(5.5_3)——并查集的进一步优化
Find操作的优化(压缩路径) 压缩路径——Find操作,先找到根节点,再将查找路径上所有结点都挂到根结点下 代码: //Find "查"操作优化,先找到根节点,再进行"路径压缩" int Find(int S[], int x) {…...

(回溯) LeetCode 131. 分割回文串
原题链接 一. 题目描述 给你一个字符串 s,请你将 s 分割成一些子串,使每个子串都是 回文串。返回 s 所有可能的分割方案。 示例 1: 输入:s "aab" 输出:[["a","a","b"],[…...

【Linux】阻塞信号|信号原理|深入理解捕获信号|内核态|用户态|sigaction|可重入函数|volatile|SIGCHILD|万字详解
目录 编辑 一,常见的信号术语 二,信号在内核中的表示 信号标志位 Pending表 Block表 handler表 POSIX.1标准 三,sigset_t 信号集操作函数 sigemptyset sigfillset sigaddset sigdelset sigismember sigprocmask sig…...

基于Linux对 【进程地址空间】的详细讲解
研究背景: ● kernel 2.6.32 ● 32位平台 –❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀-正文开始-❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀–❀– 在学习操作系统中想必大家肯定都见过下面这…...

[python]使用Pandas处理多个Excel文件并汇总数据
在数据分析和处理过程中,经常需要处理多个Excel文件,并将其中的数据进行汇总和分析。本文介绍使用Python的Pandas库来读取多个Excel文件,并汇总不同类型的数据,例如员工工资、工件数量等。 代码示例 以下是一个完整的代码示例&a…...

提升体验:UI设计的可用性原则
在中国,每年都有数十万设计专业毕业生涌入市场,但只有少数能够进入顶尖企业。尽管如此,所有设计师都怀揣着成长和提升的愿望。在评价产品的用户体验时,我们可能会依赖直觉来决定设计方案,或者在寻找改善产品体验的切入…...

x264 编码器 SSIM 算法源码分析
SSIM SSIM(Structural Similarity Index)是一种用于衡量两幅图像之间视觉相似度的指标。它不仅考虑了图像的亮度、对比度和饱和度,还考虑了图像的结构信息。SSIM的值介于-1到1之间,值越接近1表示两幅图像越相似。 SSIM是基于以下三个方面来计算的: 亮度(Luminance):比…...

echarts使图表组件根据屏幕尺寸变更而重新渲染大小
效果图: 通过 window.addEventListener(resize, this.resizeChart); 实现 完整代码: <template><div class"dunBlock"><div class"char2" id"char2" ref"chart"></div></div…...