当前位置: 首页 > news >正文

使用 PyTorch 实现 ZFNet 进行 MNIST 图像分类

         在本篇博客中,我们将通过两个主要部分来演示如何使用 PyTorch 实现 ZFNet,并在 MNIST 数据集上进行训练和测试。ZFNet(ZFNet)是基于卷积神经网络(CNN)的图像分类模型,广泛用于图像识别任务。

环境准备

        在开始之前,请确保你的环境已经安装了以下依赖:

pip install torch torchvision matplotlib tqdm

一、训练部分:训练 ZFNet 模型

首先,我们需要准备训练数据、定义 ZFNet 模型,并进行模型训练。

1. 数据加载与预处理

MNIST 数据集由 28x28 的手写数字图像组成。我们将通过 torchvision.datasets 来加载数据,并进行必要的预处理。

import torch
import torch.optim as optim
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from zfnet import ZFNet  # 假设 ZFNet 定义在 zfnet.py 文件中
from tqdm import tqdm  # 导入 tqdm
from torch.cuda.amp import autocast, GradScaler  # 导入混合精度训练def prepare_data(batch_size=128, num_workers=2, data_dir='D:/workspace/data'):"""准备 MNIST 数据集并返回数据加载器:param batch_size: 批处理大小:param num_workers: 数据加载的工作线程数:param data_dir: 数据存储的目录:return: 训练数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))  # 正则化])trainset = datasets.MNIST(root=data_dir, train=True, download=True, transform=transform)trainloader = DataLoader(trainset, batch_size=batch_size, shuffle=True, pin_memory=True, num_workers=num_workers)return trainloader

2. 初始化模型与优化器

在这里,我们将初始化模型和优化器。我们选择 Adam 优化器,并且为提高计算效率,我们采用混合精度训练。

def initialize_device():"""初始化计算设备(GPU 或 CPU):return: 计算设备"""device = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(f"Using device: {device}")return devicedef initialize_model(device):"""初始化模型并移动到指定设备:param device: 计算设备:return: 初始化好的模型"""model = ZFNet().to(device)  # 假设 ZFNet 是自定义模型return modeldef initialize_optimizer(model, lr=0.001):"""初始化优化器:param model: 需要优化的模型:param lr: 学习率:return: 优化器"""optimizer = optim.Adam(model.parameters(), lr=lr)return optimizer

3. 训练模型

使用训练数据进行训练,并且每训练一个 epoch 就更新一次进度条,同时使用混合精度训练来提高效率。

def train_model(model, trainloader, criterion, optimizer, num_epochs=5, device='cuda'):"""训练模型:param model: 训练的模型:param trainloader: 数据加载器:param criterion: 损失函数:param optimizer: 优化器:param num_epochs: 训练的轮数:param device: 计算设备"""scaler = GradScaler()  # 用于自动缩放梯度for epoch in range(num_epochs):model.train()running_loss = 0.0# 使用 tqdm 包裹 DataLoader 来显示进度条with tqdm(trainloader, unit="batch", desc=f"Epoch {epoch + 1}/{num_epochs}") as tepoch:for inputs, labels in tepoch:# 直接将数据和标签移动到 GPUinputs, labels = inputs.to(device, non_blocking=True), labels.to(device, non_blocking=True)optimizer.zero_grad()# 混合精度前向和反向传播with autocast():  # 自动混合精度outputs = model(inputs)loss = criterion(outputs, labels)# 反向传播与优化scaler.scale(loss).backward()  # 使用混合精度反向传播scaler.step(optimizer)  # 更新参数scaler.update()  # 更新缩放因子running_loss += loss.item()# 更新进度条显示tepoch.set_postfix(loss=running_loss / (tepoch.n + 1))# 打印每个 epoch 的平均损失print(f"Epoch {epoch + 1}, Loss: {running_loss / len(trainloader)}")# 保存模型torch.save(model.state_dict(), 'zfnet_model.pth')print("Model saved as zfnet_model.pth")

4. 主函数

在主函数中,我们会初始化设备、模型、损失函数,并启动训练过程。

if __name__ == '__main__':"""主函数:组织所有步骤的执行"""# 数据加载trainloader = prepare_data()# 设备选择device = initialize_device()# 模型初始化model = initialize_model(device)# 损失函数criterion = torch.nn.CrossEntropyLoss()# 优化器初始化optimizer = initialize_optimizer(model)# 启动训练train_model(model, trainloader, criterion, optimizer, num_epochs=5, device=device)

二、测试部分:评估 ZFNet 模型

训练完成后,我们将加载训练好的模型,并在测试集上评估其性能。

1. 加载和预处理数据
import torch
import matplotlib.pyplot as plt
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
from zfnet import ZFNet  # 假设 ZFNet 定义在 zfnet.py 文件中def load_and_preprocess_data(batch_size=1000):"""加载并预处理 MNIST 数据集:param batch_size: 数据加载的批次大小:return: 测试数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 下载 MNIST 测试集testset = datasets.MNIST(root='D:/workspace/data', train=False, download=True, transform=transform)# 数据加载器testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)return testloader

2. 加载训练好的模型
def load_and_preprocess_data(batch_size=1000):"""加载并预处理 MNIST 数据集:param batch_size: 数据加载的批次大小:return: 测试数据加载器"""transform = transforms.Compose([transforms.ToTensor(),transforms.Normalize((0.5,), (0.5,))])# 下载 MNIST 测试集testset = datasets.MNIST(root='D:/workspace/data', train=False, download=True, transform=transform)# 数据加载器testloader = DataLoader(testset, batch_size=batch_size, shuffle=False)return testloaderdef load_trained_model(model_path='zfnet_model.pth'):"""加载训练好的模型:param model_path: 模型文件路径:return: 加载的模型"""model = ZFNet()model.load_state_dict(torch.load(model_path))model.eval()  # 设置为评估模式return model

3. 评估模型
def evaluate_model(model, testloader):"""评估模型在测试集上的表现:param model: 训练好的模型:param testloader: 测试数据加载器:return: 模型准确率"""correct = 0total = 0with torch.no_grad():for inputs, labels in testloader:outputs = model(inputs)_, predicted = torch.max(outputs, 1)total += labels.size(0)correct += (predicted == labels).sum().item()accuracy = correct / totalreturn accuracy

4. 可视化预测结果
def visualize_predictions(model, testloader, num_images=6):"""可视化模型对多张测试图片的预测结果:param model: 训练好的模型:param testloader: 测试数据加载器:param num_images: 显示图像的数量"""model.eval()data_iter = iter(testloader)images, labels = next(data_iter)outputs = model(images)_, predicted = torch.max(outputs, 1)# 绘制结果fig, axes = plt.subplots(2, 3, figsize=(10, 7))axes = axes.ravel()for i in range(num_images):ax = axes[i]img = images[i].numpy().transpose(1, 2, 0)  # 将 Tensor 转换为 NumPy 数组并转置为 HWC 格式ax.imshow(img.squeeze(), cmap='gray')  # squeeze 去除单通道维度ax.set_title(f"Pred: {predicted[i].item()} | Actual: {labels[i].item()}")ax.axis('off')plt.tight_layout()plt.show()

5. 主函数

在测试阶段,我们加载模型并在测试数据集上评估它。

def main():"""主函数,组织数据加载、模型加载、评估和可视化步骤"""# 加载并预处理数据testloader = load_and_preprocess_data()# 加载训练好的模型model = load_trained_model()# 评估模型accuracy = evaluate_model(model, testloader)print(f"Accuracy: {accuracy * 100:.2f}%")# 可视化预测结果visualize_predictions(model, testloader, num_images=6)if __name__ == '__main__':main()


结语

通过本文的介绍,我们实现了一个基于 ZFNet 模型的图像分类任务,使用 PyTorch 对 MNIST 数据集进行训练与测试,并展示了如何进行混合精度训练以提高效率。在未来,你可以根据不同的任务修改模型结构、优化器或者训练策略,进一步提升性能。


完整项目ZFNet-PyTorch: 使用 PyTorch 实现 ZFNet 进行 MNIST 图像分类icon-default.png?t=O83Ahttps://gitee.com/qxdlll/zfnet-py-torch


  

相关文章:

使用 PyTorch 实现 ZFNet 进行 MNIST 图像分类

在本篇博客中,我们将通过两个主要部分来演示如何使用 PyTorch 实现 ZFNet,并在 MNIST 数据集上进行训练和测试。ZFNet(ZFNet)是基于卷积神经网络(CNN)的图像分类模型,广泛用于图像识别任务。 环…...

车轮上的科技:Spring Boot汽车新闻集散地

1系统概述 1.1 研究背景 随着计算机技术的发展以及计算机网络的逐渐普及,互联网成为人们查找信息的重要场所,二十一世纪是信息的时代,所以信息的管理显得特别重要。因此,使用计算机来管理汽车资讯网站的相关信息成为必然。开发合适…...

IDEA2023 SpringBoot整合Web开发(二)

一、SpringBoot介绍 由Pivotal团队提供的全新框架,其设计目的是用来简化Spring应用的初始搭建以及开发过程。该框架使用了特定的方式来进行配置,从而使开发人员不再需要定义样板化的配置。SpringBoot提供了一种新的编程范式,可以更加快速便捷…...

国产三维CAD 2025新动向:推进MBD模式,联通企业设计-制造数据

本文为CAD芯智库原创整理,未经允许请勿复制、转载! 上一篇文章阿芯分享了影响企业数字化转型的「MBD」是什么、对企业优化产品设计流程有何价值——这也是国产三维CAD软件中望3D 2024发布会上,胡其登先生(中望软件产品规划与GTM中…...

ubuntu 之 安装mysql8

安装 # 如果 ubuntu 版本 > 20.04 则不用执行 wget 这步 wget https://dev.mysql.com/get/mysql-apt-config_0.8.12-1_all.debsudo apt-get updatesudo apt-get install mysql-server mysql-client 安装过程中如果没有提示输入密码 sudo cat /etc/mysql/debian.cnf # 查…...

Flink Lookup Join(维表 Join)

Lookup Join 定义(支持 Batch\Streaming) Lookup Join 其实就是维表 Join,比如拿离线数仓来说,常常会有用户画像,设备画像等数据,而对应到实时数仓场景中,这种实时获取外部缓存的 Join 就叫做维…...

Elasticsearch retrievers 通常与 Elasticsearch 8.16.0 一起正式发布!

作者:来自 Elastic Panagiotis Bailis Elasticsearch 检索器经过了重大改进,现在可供所有人使用。了解其架构和用例。 在这篇博文中,我们将再次深入探讨检索器(retrievers)。我们已经在之前的博文中讨论过它们&#xf…...

【并发模式】Go 常见并发模式实现Runner、Pool、Work

通过并发编程在 Go 程序中实现的3种常见的并发模式。 参考:https://cloud.tencent.com/developer/article/1720733 1、Runner 定时任务 Runner 模式有代表性,能把(任务队列,超时,系统中断信号)等结合起来…...

【前端知识】Javascript前端框架Vue入门

前端框架VUE入门 概述基础语法介绍组件特性组件注册Props 属性声明事件组件 v-model(双向绑定)插槽Slots内容与出口 组件生命周期样式文件使用1. 直接在<style>标签中写CSS2. 引入外部CSS文件3. 使用CSS预处理器4. 在main.js中全局引入CSS文件5. 使用CSS Modules6. 使用P…...

Springboot3.3.5 启动流程之 Bean创建流程

在文章Springboot3.3.5 启动流程&#xff08;源码分析&#xff09;中我们只是粗略的介绍了bean 的装配(Bean的定义)流程和实例化流程分别开始于 finishBeanFactoryInitialization 和 preInstantiateSingletons. 其实,在Spring boot中&#xff0c;Bean 的装配是多阶段的&#xf…...

golang反射函数注册

package main import ( “fmt” “reflect” ) type Job interface { New([]interface{}) interface{} Run() (interface{}, error) } type DetEd struct { Name string Age int } // 为什么这样设计 // 这样就避免了 在创建新的实例的之后 结构体的方法中接受者为指针类型…...

【Spring】Bean

Spring 将管理对象称为 Bean。 Spring 可以看作是一个大型工厂&#xff0c;用于生产和管理 Spring 容器中的 Bean。如果要使用 Spring 生产和管理 Bean&#xff0c;那么就需要将 Bean 配置在 Spring 的配置文件中。Spring 框架支持 XML 和 Properties 两种格式的配置文件&#…...

深入解析TK技术下视频音频不同步的成因与解决方案

随着互联网和数字视频技术的飞速发展&#xff0c;音视频同步问题逐渐成为网络视频播放、直播、编辑等过程中不可忽视的技术难题。尤其是在采用TK&#xff08;Transmission Keying&#xff09;技术进行视频传输时&#xff0c;由于其特殊的时序同步要求&#xff0c;音视频不同步现…...

为什么要使用Ansible实现Linux管理自动化?

自动化和Linux系统管理 多年来&#xff0c;大多数系统管理和基础架构管理都依赖于通过图形或命令行用户界面执行的手动任务。系统管理员通常使用清单、其他文档或记忆的例程来执行标准任务。 这种方法容易出错。系统管理员很容易跳过某个步骤或在某个步骤上犯错误。验证这些步…...

Android:任意层级树形控件(有效果图和Demo示例)

先上效果图&#xff1a; 1.创建treeview文件夹 2.treeview -> adapter -> SimpleTreeAdapter.java import android.content.Context; import android.view.View; import android.view.ViewGroup; import android.widget.ImageView; import android.widget.ListView; i…...

C++ 容器全面剖析:掌握 STL 的奥秘,从入门到高效编程

引言 C 标准模板库&#xff08;STL&#xff09;提供了一组功能强大的容器类&#xff0c;用于存储和操作数据集合。不同的容器具有独特的特性和应用场景&#xff0c;因此选择合适的容器对于程序的性能和代码的可读性至关重要。对于刚接触 C 的开发者来说&#xff0c;了解这些容…...

C++---类型转换

文章目录 C的类型转换C的4种强制类型转换RTTI C的类型转换 类型转换 内置类型之间的转换 // a、内置类型之间 // 1、隐式类型转换 整形之间/整形和浮点数之间 // 2、显示类型的转换 指针和整形、指针之间 int main() {int i 1;// 隐式类型转换double d i;printf("%d…...

CSS基础学习练习题

编程题 1.为下面这段文字定义字体样式&#xff0c;要求字体类型指定多种、大小为14px、粗细为粗体、颜色为蓝色。 “有规划的人生叫蓝图&#xff0c;没规划的人生叫拼图。​” 代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><me…...

TypeScript知识点总结和案例使用

TypeScript 是一种由微软开发的开源编程语言&#xff0c;它是 JavaScript 的超集&#xff0c;提供了静态类型检查和其他一些增强功能。以下是一些 TypeScript 的重要知识点总结&#xff1a; 1. 基本类型 TypeScript 支持多种基本数据类型&#xff0c;包括&#xff1a; numbe…...

解决BUG: Since 17.0, the “attrs“ and “states“ attributes are no longer used.

从Odoo 17.0开始&#xff0c;attrs和states属性不再使用&#xff0c;取而代之的是使用depends和domain属性来控制字段的可见性和其他行为。如果您想要在选择国家之后继续选择州&#xff0c;并且希望在选择了国家之后才显示州字段&#xff0c;您可以使用depends属性来实现这一点…...

单片机GPIO中断+定时器 实现模拟串口接收

单片机GPIO中断定时器 实现模拟串口接收 解决思路代码示例 解决思路 串口波特率9600bps,每个bit约为1000000us/9600104.16us&#xff1b; 定时器第一次定时时间设为52us即半个bit的时间&#xff0c;其目的是偏移半个bit时间&#xff0c;之后的每104us采样并读取1bit数据。使得…...

《深入理解 Spring MVC 工作流程》

一、Spring MVC 架构概述 Spring MVC 是一个基于 Java 的轻量级 Web 应用框架&#xff0c;它遵循了经典的 MVC&#xff08;Model-View-Controller&#xff09;设计模式&#xff0c;将请求、响应和业务逻辑分离&#xff0c;从而构建出灵活可维护的 Web 应用程序。 在 Spring MV…...

HTML简介

知识点一 HTML 什么是HTML&#xff1f; 超文本标记语言(HyperTextMarkup Language&#xff0c;简称HTML) 怎么学HTML&#xff1f; HTML 是一门标记语言&#xff0c;标记语言由一套标记标签组成&#xff0c;学习 HTML&#xff0c;其实就是学习标签 开发工具 编辑器: Pycha…...

Linux系统Centos设置开机默认root用户

目录 一. 教程 二. 部分第三方工具配置也无效 一. 教程 使用 Linux 安装Centos系统的小伙伴大概都知道&#xff0c;我们进入系统后&#xff0c;通常都是自己设置的普通用户身份&#xff0c;而不是 root 超级管理员用户&#xff0c;导致我们在操作文件夹时往往爆出没有权限&am…...

【网络安全 | 甲方建设】双/多因素认证、TOTP原理及实现

未经许可,不得转载。 文章目录 背景双因素、多因素认证双因素认证(2FA)多因素认证(MFA)TOTP实现TOTP生成流程TOTP算法TOTP代码示例(JS)Google Authenticator总结背景 在传统的在线银行系统中,用户通常只需输入用户名和密码就可以访问自己的账户。然而,如果密码不慎泄…...

Nuxt3 动态路由URL不更改的前提下参数更新,NuxtLink不刷新不跳转,生命周期无响应解决方案

Nuxt3 动态路由URL不更改的前提下参数更新&#xff0c;NuxtLink不刷新不跳转&#xff0c;生命周期无响应解决方案 首先说明一点&#xff0c;Nuxt3 的动态路由响应机制是根据 URL 是否更改&#xff0c;参数的更改并不会触发 Router 去更新页面&#xff0c;这在 Vue3 上同样存在…...

2024华为java面经

华为2024年Java招聘面试题目可能会涵盖Java基础知识、核心技术、框架与工具、项目经验以及算法与数据结构等多个方面。以下是考的内容。 一、Java基础知识 Java中有哪些基本数据类型&#xff1f; Java为什么能够跨平台运行&#xff1f; String是基本数据类型吗&#xff1f;能…...

2021 年 9 月青少年软编等考 C 语言三级真题解析

目录 T1. 课程冲突思路分析T2. 余数相同问题思路分析T3. 生成括号思路分析T4. 广义格雷码思路分析T5. 菲波那契数列思路分析T1. 课程冲突 小 A 修了 n n n 门课程,第 i i i 门课程是从第 a i a_i ai​ 天一直上到第 b i b_i bi​ 天。 定义两门课程的冲突程度为:有几天…...

深度解析FastDFS:构建高效分布式文件存储的实战指南(下)

接上篇&#xff1a;《深度解析FastDFS&#xff1a;构建高效分布式文件存储的实战指南&#xff08;上&#xff09;》 传送门: link 文章目录 六、常用命令七、FastDFS配置详解7.1 tracker配置文件7.2 tracker目录及文件结构7.3 storage配置文件7.4 storage服务器的目录结构和文件…...

Python学习29天

二分查找 # 定义函数冒泡排序法从大到小排列 def bbble_sort(list):# i控制排序次数for i in range(len(list) - 1):# j控制每次排序比较次数for j in range(len(list) - 1 - i):if list[j] < list[j 1]:list[j], list[j 1] list[j 1], list[j] # 定义二分查找函数 def…...

0基础做网站多久/网址域名大全

外部命令&#xff1a; impala-shell –h 可以帮助我们查看帮助手册 impala-shell –r 刷新impala元数据 impala-shell –f 文件路径 执行指的的sql查询文件。 impala-shell –i 指定连接运行 impalad 守护进程的主机。 impala-shell –o 保存执行结果到文件当中去。 内部命令 c…...

dw网站制作模板/营销外包

转自&#xff1a;https://www.jb51.net/article/184639.htm 前言 时间是宝贵的&#xff0c;我们无时无刻不在和时间打交道&#xff0c;这个任务明天下班前截止&#xff0c;你点的外卖还有5分钟才能送到&#xff0c;那个程序已经运行了整整48个小时&#xff0c;既然时间和我们…...

建设留学网站/产品网络营销方案

1、 1&#xff09; 2) 3&#xff09; 4&#xff09; 2、 1&#xff09; 2&#xff09;AVL 3&#xff09;B树 B树 3、哈希表 转载于:https://www.cnblogs.com/KennyRom/p/6222387.html...

烟台网站建设费用/电商网络推广是什么

kyeye项目介绍win10风格的一套系统&#xff0c;前端采用layui作为前端框架&#xff0c;后端采用SpringBoot作为服务框架&#xff0c;采用自封装的xml对所有请求进行参数校验&#xff0c;以保证接口安全性。启动方式直接运行com.skyeye.SkyEyeApplication即可&#xff0c;启动完…...

天津模板做网站/点金推广优化公司

在先前的文章&#xff02;在Ubuntu上的传感器&#xff02;中&#xff0c;我们已经从QML中&#xff0c;展示了如何在Ubuntu平台中利用Sensor来给我所需要的数据&#xff0e;在今天的例程中&#xff0c;我们将通过C的API例举所有的Sensor&#xff0c;并展示他们所有的属性&#x…...

油管代理网页/网站seo收录

1.设置选中tree的节点 var node $(#tt).tree(find, 1);//找到id为”tt“这个树的节点id为”1“的对象$(#tt).tree(select, node.target);//设置选中该节点 2.获取选中节点的值 $("#tt").tree(getSelected).id $("#tt").tree(getSelected).text 2.通过子节…...