卷积神经网络——LeNet——FashionMNIST
目录
- 一、整体结构
- 二、model.py
- 三、model_train.py
- 四、model_test.py
GitHub地址
一、整体结构

二、model.py
import torch
from torch import nn
from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.c1 = nn.Conv2d(in_channels=1,out_channels=6,kernel_size=5,padding=2)self.sig = nn.Sigmoid()self.s2 = nn.AvgPool2d(kernel_size=2,stride=2)self.c3 = nn.Conv2d(in_channels=6,out_channels=16,kernel_size=5)self.s4 = nn.AvgPool2d(kernel_size=2,stride=2)self.flatten = nn.Flatten()self.f5 = nn.Linear(in_features=5*5*16,out_features=120)self.f6 = nn.Linear(in_features=120,out_features=84)self.f7 = nn.Linear(in_features=84,out_features=10)def forward(self,x):x = self.sig(self.c1(x))x = self.s2(x)x = self.sig(self.c3(x))x = self.s4(x)x = self.flatten(x)x = self.f5(x)x = self.f6(x)x = self.f7(x)return x# if __name__ =="__main__":
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#
# model = LeNet().to(device)
#
# print(summary(model,input_size=(1,28,28)))
三、model_train.py
# 导入所需的Python库
from torchvision.datasets import FashionMNIST
from torchvision import transforms
import torch.utils.data as Data
import torch
from torch import nn
import time
import copy
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from model import LeNet # model.py中定义了LeNet模型
from tqdm import tqdm # 导入tqdm库,用于显示进度条# 定义数据加载和处理函数
def train_val_data_process():# 加载FashionMNIST数据集,Resize到28x28尺寸,并转换为Tensortrain_data = FashionMNIST(root="./data",train=True,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)# 将加载的数据集分为80%的训练数据和20%的验证数据train_data, val_data = Data.random_split(train_data, lengths=[round(0.8 * len(train_data)), round(0.2 * len(train_data))])# 为训练数据和验证数据创建DataLoader,设置批量大小为32,洗牌,2个进程加载数据train_dataloader = Data.DataLoader(dataset=train_data,batch_size=32,shuffle=True,num_workers=2)val_dataloader = Data.DataLoader(dataset=val_data,batch_size=32,shuffle=True,num_workers=2)# 返回训练和验证的DataLoaderreturn train_dataloader, val_dataloader# 定义模型训练和验证过程的函数
def train_model_process(model, train_dataloader, val_dataloader, num_epochs):# 设置使用CUDA如果可用device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 打印使用的设备dev = "cuda" if torch.cuda.is_available() else "cpu"print(f'当前模型训练设备为: {dev}')# 初始化Adam优化器和交叉熵损失函数optimizer = torch.optim.Adam(model.parameters(), lr=0.001)criterion = nn.CrossEntropyLoss()# 将模型移动到选定的设备上model = model.to(device)# 复制模型权重用于后续更新最佳模型best_model_wts = copy.deepcopy(model.state_dict())best_acc = 0.0 # 初始化最佳准确度# 初始化用于记录训练和验证过程中损失和准确度的列表train_loss_all = []val_loss_all = []train_acc_all = []val_acc_all = []# 记录训练开始时间start_time = time.time()# 迭代指定的训练轮数for epoch in range(1, num_epochs + 1):# 记录每个epoch开始的时间since = time.time()# 打印分隔符和当前epoch信息print("-" * 10)print(f"Epoch: {epoch}/{num_epochs}")# 初始化训练和验证过程中的损失和正确预测数量train_loss = 0.0train_corrects = 0val_loss = 0.0val_corrects = 0# 初始化批次计数器train_num = 0val_num = 0# 创建训练进度条progress_train_bar = tqdm(total=len(train_dataloader), desc=f'Training {epoch}', unit='batch')# 训练数据集的遍历for step, (b_x, b_y) in enumerate(train_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 训练模型model.train()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 清空梯度optimizer.zero_grad()# 反向传播loss.backward()# 更新权重optimizer.step()# 累加损失和正确预测数量train_loss += loss.item() * b_x.size(0)train_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器train_num += b_x.size(0)# 更新训练进度条progress_train_bar.update(1)# 关闭训练进度条progress_train_bar.close()# 创建验证进度条progress_val_bar = tqdm(total=len(val_dataloader), desc=f'Validation {epoch}', unit='batch')# 验证数据集的遍历for step, (b_x, b_y) in enumerate(val_dataloader):# 将数据移动到相应的设备上b_x = b_x.to(device)b_y = b_y.to(device)# 评估模型model.eval()# 前向传播output = model(b_x)# 计算预测标签pre_label = torch.argmax(output, dim=1)# 计算损失loss = criterion(output, b_y)# 累加损失和正确预测数量val_loss += loss.item() * b_x.size(0)val_corrects += torch.sum(pre_label == b_y.data)# 更新批次计数器val_num += b_x.size(0)# 更新验证进度条progress_val_bar.update(1)# 关闭验证进度条progress_val_bar.close()# 计算并记录epoch的平均损失和准确度train_loss_all.append(train_loss / train_num)train_acc_all.append(train_corrects.double().item() / train_num)val_loss_all.append(val_loss / val_num)val_acc_all.append(val_corrects.double().item() / val_num)# 打印训练和验证的损失与准确度print(f'{epoch} Train Loss: {train_loss_all[-1]:.4f} Train Acc: {train_acc_all[-1]:.4f}')print(f'{epoch} Val Loss: {val_loss_all[-1]:.4f} Val Acc: {val_acc_all[-1]:.4f}')# 计算并打印epoch训练耗费的时间time_use = time.time() - sinceprint(f'第 {epoch} 个 epoch 训练耗费时间: {time_use // 60:.0f}m {time_use % 60:.0f}s')# 若当前epoch的验证准确度为最佳,则更新最佳模型权重if val_acc_all[-1] > best_acc:best_acc = val_acc_all[-1]best_model_wts = copy.deepcopy(model.state_dict())# 训练结束,保存最佳模型权重torch.save(best_model_wts, 'D:/Pycharm/deepl/LeNet/weight/best_model.pth')# 如果当前epoch为总epoch数,则保存最终模型权重if epoch == num_epochs:torch.save(model.state_dict(), f'D:/Pycharm/deepl/LeNet/weight/{num_epochs}_model.pth')# 将训练过程中的统计数据整理成DataFrametrain_process = pd.DataFrame(data={"epoch": range(1, num_epochs + 1),"train_loss_all": train_loss_all,"val_loss_all": val_loss_all,"train_acc_all": train_acc_all,"val_acc_all": val_acc_all})# 打印总训练时间consume_time = time.time() - start_timeprint(f'总耗时:{consume_time // 60:.0f}m {consume_time % 60:.0f}s')# 返回包含训练过程统计数据的DataFramereturn train_process# 定义绘制训练和验证过程中损失与准确度的函数
def matplot_acc_loss(train_process):# 创建图形和子图plt.figure(figsize=(12, 4))# 绘制训练和验证损失plt.subplot(1, 2, 1)plt.plot(train_process["epoch"], train_process["train_loss_all"], 'ro-', label="train_loss")plt.plot(train_process["epoch"], train_process["val_loss_all"], 'bs-', label="val_loss")plt.legend()plt.xlabel("epoch")plt.ylabel("loss")# 保存损失图像plt.savefig('./result_picture/training_loss_accuracy.png', bbox_inches='tight')# 绘制训练和验证准确度plt.subplot(1, 2, 2)plt.plot(train_process["epoch"], train_process["train_acc_all"], 'ro-', label="train_acc")plt.plot(train_process["epoch"], train_process["val_acc_all"], 'bs-', label="val_acc")plt.legend()plt.xlabel("epoch")plt.ylabel("accuracy")# 保存准确率曲线图plt.savefig('./result_picture/training_accuracy.png', bbox_inches='tight')plt.show()if __name__ == "__main__":model = LeNet()train_dataloader, val_dataloader = train_val_data_process()train_process = train_model_process(model, train_dataloader, val_dataloader, num_epochs=20)matplot_acc_loss(train_process)
四、model_test.py
import torch
import torch.utils.data as Data
from torchvision import transforms
from torchvision.datasets import FashionMNIST
from model import LeNet
from sklearn.metrics import confusion_matrix
import seaborn as sns
import matplotlib.pyplot as plt
# t代表testdef t_data_process():test_data = FashionMNIST(root="./data",train=False,transform=transforms.Compose([transforms.Resize(size=28), transforms.ToTensor()]),download=True)test_dataloader = Data.DataLoader(dataset=test_data,batch_size=1,shuffle=True,num_workers=0)return test_dataloaderdef t_model_process(model, test_dataloader):if model is not None:print('Successfully loaded the model.')device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)# 初始化参数test_corrects = 0.0test_num = 0all_preds = [] # 存储所有预测标签all_labels = [] # 存储所有实际标签# 只进行前向传播,不计算梯度with torch.no_grad():for test_x, test_y in test_dataloader:test_x = test_x.to(device)test_y = test_y.to(device)# 设置模型为验证模式model.eval()# 前向传播得到一个batch的结果output = model(test_x)# 查找最大值对应的行标pre_lab = torch.argmax(output, dim=1)# 收集预测和实际标签all_preds.extend(pre_lab.tolist())all_labels.extend(test_y.tolist())# 计算准确率test_corrects += torch.sum(pre_lab == test_y.data)# 将所有的测试样本进行累加test_num += test_x.size(0)# 计算准确率test_acc = test_corrects.double().item() / test_numprint(f'测试的准确率:{test_acc}')# 绘制混淆矩阵conf_matrix = confusion_matrix(all_labels, all_preds)sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues')plt.xlabel('Predicted labels')plt.ylabel('True labels')plt.title('Confusion Matrix')plt.show()plt.savefig('./result_picture/Confusion_Matrix.png', bbox_inches='tight')if __name__=="__main__":# 加载模型model = LeNet()print('loading model')# 加载权重model.load_state_dict(torch.load('D:/Pycharm/deepl/LeNet/weight/best_model.pth'))# 加载测试数据test_dataloader = t_data_process()# 加载模型测试的函数t_model_process(model,test_dataloader)device = "cuda" if torch.cuda.is_available() else "cpu"model = model.to(device)classes = ['T-shirt/top','Trouser','Pullover','Dress','coat','Sandal','Shirt','Sneaker','Bag','Ankle boot']with torch.no_grad():for b_x,b_y in test_dataloader:b_x = b_x.to(device)b_y = b_y.to(device)model.eval()output = model(b_x)pre_lab = torch.argmax(output,dim=1)result = pre_lab.item()label = b_y.item()print(f'预测值:{classes[result]}',"-----------",f'真实值:{classes[label]}')
相关文章:
卷积神经网络——LeNet——FashionMNIST
目录 一、整体结构二、model.py三、model_train.py四、model_test.py GitHub地址 一、整体结构 二、model.py import torch from torch import nn from torchsummary import summaryclass LeNet(nn.Module):def __init__(self):super(LeNet,self).__init__()self.c1 nn.Conv…...
k8s-第十二节-DaemonSet
DaemonSet是什么? DaemonSet 是一个确保全部或者某些节点上必须运行一个 Pod的工作负载资源(守护进程),当有node(节点)加入集群时, 也会为他们新增一个 Pod。 下面是常用的使用案例: 可以用来部署以下进程的pod 集群守护进程,如Kured、node-problem-detector日志收集…...
Mysql-内置函数
一.什么是函数? 函数是指一段可以直接被另外一段程序调用的程序或代码。 mysql内置了很多的函数,我们只需要调用即可。 二.字符串函数 MySQL中内置了很多字符串函数: 三.根据需求完成以下SQL编写 由于业务需求变更,企业员工的工号,统一为5位数,目前不足5位数的全…...
新浪API系列:支付API打造无缝支付体验,畅享便利生活(3)
在当今数字化时代,支付功能已经成为各类应用和平台的必备要素之一。作为开发者,要构建出安全、便捷的支付解决方案,新浪支付API是你不可或缺的利器。新浪支付API提供了全面而强大的接口和功能,帮助开发者轻松实现在线支付的集成和…...
终于弄明白了什么是EI!
EI是Engineering Index的缩写,中文意为“工程索引”,是由美国工程信息公司(Engineering Information, Inc.)编辑出版的著名检索工具。它始创于1884年,拥有超过一个世纪的历史,是全球工程界最权威的文献检索系统之一。EI虽然名为“…...
微信小程序常见页面跳转方式
1. wx.navigateTo() 保留当前页,跳转到不是 tabbar 的页面,会新增页面到页面栈。通过返回按钮或 wx.navigateBack()返回上一个页面。 2. wx.redirectTo() 跳转到不是 tabbar 的页面,替换当前页面。不能返回。 3. wx.switchTab() 跳转到 …...
Vim常用整理快捷键
一、光标跳转 参数释义w下一行首字符e下一行尾字符0跳至行首$跳至行尾gg跳至文首5gg跳至第五行gd标记跳转到当前光标所在的变量的定义位置fn找当前行后的n字符,跳转到n字符位置 二、修改类操作 参数释义D删除光标之后的字符dd删除整行x删除当前字符yy复制一行p向…...
【docker 把系统盘空间耗没了!】windows11 更改 ubuntu 子系统存储位置
系统:win11 ubuntu 22 子系统,docker 出现问题:系统盘突然没空间了,一片红 经过排查,发现 AppData\Local\packages\CanonicalGroupLimited.Ubuntu22.04LTS_79rhkp1fndgsc\ 这个文件夹竟然有 90GB 下面提供解决办法 步…...
前端如何让网页页面完美适配不同大小和分辨率屏幕
推荐使用postcss插件,它会自动将项目所有的px单位统一转换为vw等单位(包括npm安装的第三方组件),从而实现适配,具体配置规则可参考官网或npm网站介绍。 另外对于大屏的适配,需要缩放网页,可使用…...
gitlab-runner安装部署CI/CD
手动安装 卸载旧版: gitlab-runner --version gitlab-runner stop yum remove gitlab-runner下载gitlab对应版本的runner # https://docs.gitlab.com/runner/install/bleeding-edge.html#download-any-other-tagged-releasecurl -L --output /usr/bin/gitlab-run…...
数据分析案例-2024 年全电动汽车数据集可视化分析
🤵♂️ 个人主页:艾派森的个人主页 ✍🏻作者简介:Python学习者 🐋 希望大家多多支持,我们一起进步!😄 如果文章对你有帮助的话, 欢迎评论 💬点赞Ǵ…...
H桥驱动器芯片详解
H桥驱动器芯片详解 上一篇文章讲解了H桥驱动器的控制原理,本文以汽车行业广泛应用的DRV8245芯片为例,详细讲解基于集成电路的H桥驱动器芯片。 1.概述 DRV824x-Q1系列器件是德州仪器(TI)的一款专为汽车应用设计的全集成H桥驱动器…...
哪个充电宝口碑比较好?怎么选充电宝?2024年口碑优秀充电宝推荐
在如今快节奏的生活中,充电宝已然成为我们日常生活中的必备品。然而,市场上充电宝品牌众多,质量参差不齐,如何选择一款安全、可靠且口碑优秀的充电宝成为了消费者关注的焦点。安全性能不仅关系到充电宝的使用寿命,更关…...
Memcached 介绍与详解及在Java Spring Boot项目中的使用与集成
Memcached 介绍 Memcached 是一种高性能的分布式内存对象缓存系统,主要用于加速动态Web应用以减少数据库负载,从而提高访问速度和性能。作为一个开源项目,Memcached 被广泛应用于许多大型互联网公司,如Facebook、Twitter 和 YouT…...
淮北在选择SCADA系统时,哪些因素会影响其稳定性?
关键字:LP-SCADA系统, 传感器可视化, 设备可视化, 独立SPC系统, 智能仪表系统,SPC可视化,独立SPC系统 在选择SCADA系统时,稳定性是一个关键因素,因为它直接影响到生产过程的连续性和安全性。以下是一些影响SCADA系统稳定性的因素: 硬件质量…...
Linux: 命令行参数和环境变量究竟是什么?
Linux: 命令行参数和环境变量究竟是什么? 一、命令行参数1.1 main函数参数意义1.2 命令行参数概念1.3 命令行参数实例 二、环境变量2.1 环境变量概念2.2 环境变量:PATH2.2.1 如何查看PATH中的内容2.2.2 如何让自己的可执行文件不带路径运行 2.3 环境变量…...
数学系C++ 类与对象 STL(九)
目录 目录 面向对象:py,c艹,Java都是,但c是面向过程 特征: 对象 内敛成员函数【是啥】: 构造函数和析构函数 构造函数 复制构造函数/拷贝构造函数: 【……】 实参与形参的传递方式:值…...
CSS技巧专栏:一日一例 2.纯CSS实现 多彩边框按钮特效
大家好,今天是 CSS技巧一日一例 专栏的第二篇《纯CSS实现多彩边框按钮特效》 先看图: 开工前的准备工作 正如昨日所讲,为了案例的表现,也处于书写的习惯,在今天的案例开工前,先把昨天的准备工作重做一遍。 清除浏览器的默认样式定义页面基本颜色设定body的样式清除butt…...
JCEF 在idea 开发 java 应用
JCEF(Java Chromium Embedded Framework)是一个Java库,用于在Java应用程序中嵌入Chromium浏览器引擎。如果您想在IDEA开发环境中使用JCEF,您可以按照以下步骤进行操作: 1. 下载JCEF库文件:您可以从JCEF的官…...
绝区伍--2024年AI发展路线图
2024 年将是人工智能具有里程碑意义的一年。随着新模式、融资轮次和进步以惊人的速度出现,很难跟上人工智能世界发生的一切。让我们深入了解 2024 年可能定义人工智能的关键事件、产品发布、研究突破和趋势。 2024 年第一季度 2024 年第一季度将推出一些主要车型并…...
循环冗余码校验CRC码 算法步骤+详细实例计算
通信过程:(白话解释) 我们将原始待发送的消息称为 M M M,依据发送接收消息双方约定的生成多项式 G ( x ) G(x) G(x)(意思就是 G ( x ) G(x) G(x) 是已知的)࿰…...
UDP(Echoserver)
网络命令 Ping 命令 检测网络是否连通 使用方法: ping -c 次数 网址ping -c 3 www.baidu.comnetstat 命令 netstat 是一个用来查看网络状态的重要工具. 语法:netstat [选项] 功能:查看网络状态 常用选项: n 拒绝显示别名&#…...
汽车生产虚拟实训中的技能提升与生产优化
在制造业蓬勃发展的大背景下,虚拟教学实训宛如一颗璀璨的新星,正发挥着不可或缺且日益凸显的关键作用,源源不断地为企业的稳健前行与创新发展注入磅礴强大的动力。就以汽车制造企业这一极具代表性的行业主体为例,汽车生产线上各类…...
2.Vue编写一个app
1.src中重要的组成 1.1main.ts // 引入createApp用于创建应用 import { createApp } from "vue"; // 引用App根组件 import App from ./App.vue;createApp(App).mount(#app)1.2 App.vue 其中要写三种标签 <template> <!--html--> </template>…...
五年级数学知识边界总结思考-下册
目录 一、背景二、过程1.观察物体小学五年级下册“观察物体”知识点详解:由来、作用与意义**一、知识点核心内容****二、知识点的由来:从生活实践到数学抽象****三、知识的作用:解决实际问题的工具****四、学习的意义:培养核心素养…...
【python异步多线程】异步多线程爬虫代码示例
claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...
在鸿蒙HarmonyOS 5中使用DevEco Studio实现录音机应用
1. 项目配置与权限设置 1.1 配置module.json5 {"module": {"requestPermissions": [{"name": "ohos.permission.MICROPHONE","reason": "录音需要麦克风权限"},{"name": "ohos.permission.WRITE…...
用docker来安装部署freeswitch记录
今天刚才测试一个callcenter的项目,所以尝试安装freeswitch 1、使用轩辕镜像 - 中国开发者首选的专业 Docker 镜像加速服务平台 编辑下面/etc/docker/daemon.json文件为 {"registry-mirrors": ["https://docker.xuanyuan.me"] }同时可以进入轩…...
MFC 抛体运动模拟:常见问题解决与界面美化
在 MFC 中开发抛体运动模拟程序时,我们常遇到 轨迹残留、无效刷新、视觉单调、物理逻辑瑕疵 等问题。本文将针对这些痛点,详细解析原因并提供解决方案,同时兼顾界面美化,让模拟效果更专业、更高效。 问题一:历史轨迹与小球残影残留 现象 小球运动后,历史位置的 “残影”…...
【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统
Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...
