当前位置: 首页 > news >正文

P5:使用pytorch实现运动鞋识别

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊
    我的环境
    语言环境:python 3.7.12
    编译器:pycharm
    深度学习环境:tensorflow 2.7.0
    数据:本地数据集-运动鞋
    在这里插入图片描述

一、代码

# 1 设置GPU
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision
from torchvision import transforms, datasetsimport os,PIL,pathlibdevice = torch.device("cuda" if torch.cuda.is_available() else "cpu")print(device)# 2 导入数据
import os,PIL,random,pathlibdata_dir = './data_sneakers/'
data_dir = pathlib.Path(data_dir)data_paths  = list(data_dir.glob('*'))
classeNames = [str(path).split("/")[1] for path in data_paths]
print(classeNames)
# 构建数据集
# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸# transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]),  # 将输入图片resize成统一尺寸transforms.ToTensor(),          # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize(           # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406],std=[0.229, 0.224, 0.225])  # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])train_dataset = datasets.ImageFolder("./data_sneakers/train/",transform=train_transforms)
test_dataset  = datasets.ImageFolder("./data_sneakers/test/",transform=test_transform)train_dataset.class_to_idxbatch_size = 32train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)for X, y in test_dl:print("Shape of X [N, C, H, W]: ", X.shape)print("Shape of y: ", y.shape, y.dtype)break# 3 构建简单的CNN
import torch.nn.functional as Fclass Model(nn.Module):def __init__(self):super(Model, self).__init__()self.conv1 = nn.Sequential(nn.Conv2d(3, 12, kernel_size=5, padding=0),  # 12*220*220nn.BatchNorm2d(12),nn.ReLU())self.conv2 = nn.Sequential(nn.Conv2d(12, 12, kernel_size=5, padding=0),  # 12*216*216nn.BatchNorm2d(12),nn.ReLU())self.pool3 = nn.Sequential(nn.MaxPool2d(2))  # 12*108*108self.conv4 = nn.Sequential(nn.Conv2d(12, 24, kernel_size=5, padding=0),  # 24*104*104nn.BatchNorm2d(24),nn.ReLU())self.conv5 = nn.Sequential(nn.Conv2d(24, 24, kernel_size=5, padding=0),  # 24*100*100nn.BatchNorm2d(24),nn.ReLU())self.pool6 = nn.Sequential(nn.MaxPool2d(2))  # 24*50*50self.dropout = nn.Sequential(nn.Dropout(0.2))self.fc = nn.Sequential(nn.Linear(24 * 50 * 50, len(classeNames)))def forward(self, x):batch_size = x.size(0)x = self.conv1(x)  # 卷积-BN-激活x = self.conv2(x)  # 卷积-BN-激活x = self.pool3(x)  # 池化x = self.conv4(x)  # 卷积-BN-激活x = self.conv5(x)  # 卷积-BN-激活x = self.pool6(x)  # 池化x = self.dropout(x)x = x.view(batch_size, -1)  # flatten 变成全连接网络需要的输入 (batch, 24*50*50) ==> (batch, -1), -1 此处自动算出的是24*50*50x = self.fc(x)return xdevice = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = Model().to(device)
print(model)# 4 训练模型
# 训练函数
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset)  # 训练集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)train_loss, train_acc = 0, 0  # 初始化训练损失和正确率for X, y in dataloader:  # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X)  # 网络输出loss = loss_fn(pred, y)  # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad()  # grad属性归零loss.backward()  # 反向传播optimizer.step()  # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss# 测试函数
def test(dataloader, model, loss_fn):size = len(dataloader.dataset)  # 测试集的大小num_batches = len(dataloader)  # 批次数目, (size/batch_size,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
# 设置动态学习率
def adjust_learning_rate(optimizer, epoch, start_lr):# 每 2 个epoch衰减到原来的 0.92lr = start_lr * (0.92 ** (epoch // 2))for param_group in optimizer.param_groups:param_group['lr'] = lrlearn_rate = 1e-4 # 初始学习率
optimizer  = torch.optim.SGD(model.parameters(), lr=learn_rate)# # 调用官方动态学习率接口时使用
# lambda1 = lambda epoch: (0.92 ** (epoch // 2))
# optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
# scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) #选定调整方法# 5 训练
loss_fn = nn.CrossEntropyLoss()  # 创建损失函数
epochs = 40train_loss = []
train_acc = []
test_loss = []
test_acc = []for epoch in range(epochs):# 更新学习率(使用自定义学习率时使用)adjust_learning_rate(optimizer, epoch, learn_rate)model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)# scheduler.step() # 更新学习率(调用官方动态学习率接口时使用)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch + 1, epoch_train_acc * 100, epoch_train_loss,epoch_test_acc * 100, epoch_test_loss, lr))
print('Done')# 6 结果可视化
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore")               #忽略警告信息
plt.rcParams['font.sans-serif']    = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False      # 用来正常显示负号
plt.rcParams['figure.dpi']         = 100        #分辨率from datetime import datetime
current_time = datetime.now() # 获取当前时间epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')
plt.xlabel(current_time) # 打卡请带上时间戳,否则代码截图无效plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()# 预测指定图片
from PIL import Imageclasses = list(train_dataset.class_to_idx)def predict_one_image(image_path, model, transform, classes):test_img = Image.open(image_path).convert('RGB')# plt.imshow(test_img)  # 展示预测的图片test_img = transform(test_img)img = test_img.to(device).unsqueeze(0)model.eval()output = model(img)_, pred = torch.max(output, 1)pred_class = classes[pred]print(f'预测结果是:{pred_class}')# 预测训练集中的某张照片
predict_one_image(image_path='./data_sneakers/test/adidas/1.jpg',model=model,transform=train_transforms,classes=classes)# 模型保存
PATH = './model.pth'  # 保存的参数文件名
torch.save(model.state_dict(), PATH)# 将参数加载到model当中
model.load_state_dict(torch.load(PATH, map_location=device))

二、结果

在这里插入图片描述
在这里插入图片描述

在这里插入图片描述

三、总结

3.1导入数据步骤

● 第一步:使用pathlib.Path()函数将字符串类型的文件夹路径转换为pathlib.Path对象。
● 第二步:使用glob()方法获取data_dir路径下的所有文件路径,并以列表形式存储在data_paths中。
● 第三步:通过split()函数对data_paths中的每个文件路径执行分割操作,获得各个文件所属的类别名称,并存储在classeNames中
● 第四步:打印classeNames列表,显示每个文件所属的类别名称。

3.2 模型结构

在这里插入图片描述

3.3训练函数与测试函数区别

由于测试不进行梯度下降对网络权重进行更新,所以不需要传入优化器

3.4动态学习率

1. torch.optim.lr_scheduler.StepLR

等间隔动态调整方法,每经过step_size个epoch,做一次学习率decay,以gamma值为缩小倍数。

函数原型:
torch.optim.lr_scheduler.StepLR(optimizer, step_size, gamma=0.1, last_epoch=-1)

关键参数详解:
● optimizer(Optimizer):是之前定义好的需要优化的优化器的实例名
● step_size(int):是学习率衰减的周期,每经过每个epoch,做一次学习率decay
● gamma(float):学习率衰减的乘法因子。Default:0.1

用法示例:

optimizer = torch.optim.SGD(net.parameters(), lr=0.001 )
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.1)

2. lr_scheduler.LambdaLR

根据自己定义的函数更新学习率。

函数原型:
torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda, last_epoch=-1, verbose=False)

关键参数详解:
● optimizer(Optimizer):是之前定义好的需要优化的优化器的实例名
● lr_lambda(function):更新学习率的函数

用法示例:

lambda1 = lambda epoch: (0.92 ** (epoch // 2) # 第二组参数的调整方法
optimizer = torch.optim.SGD(model.parameters(), lr=learn_rate)
scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda1) #选定调整方法

3. lr_scheduler.MultiStepLR

在特定的 epoch 中调整学习率

函数原型:
torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones, gamma=0.1, last_epoch=-1, verbose=False)

关键参数详解:
● optimizer(Optimizer):是之前定义好的需要优化的优化器的实例名
● milestones(list):是一个关于epoch数值的list,表示在达到哪个epoch范围内开始变化,必须是升序排列
● gamma(float):学习率衰减的乘法因子。Default:0.1

用法示例:

optimizer = torch.optim.SGD(net.parameters(), lr=0.001 )
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[2,6,15], #调整学习率的epoch数gamma=0.1)

更多的官方动态学习率设置方式可参考:https://pytorch.org/docs/stable/optim.html
调用官方接口示例:

model = [Parameter(torch.randn(2, 2, requires_grad=True))]
optimizer = SGD(model, 0.1)
scheduler = ExponentialLR(optimizer, gamma=0.9)for epoch in range(20):for input, target in dataset:optimizer.zero_grad()output = model(input)loss = loss_fn(output, target)loss.backward()optimizer.step()scheduler.step()

相关文章:

P5:使用pytorch实现运动鞋识别

🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍖 原作者:K同学啊 我的环境 语言环境:python 3.7.12 编译器:pycharm 深度学习环境:tensorflow 2.7.0 数据:本地数据集-运动鞋 一…...

讲解下SpringBoot中MySql和MongoDB的配合使用

在Spring Boot中,MySQL和MongoDB可以配合使用,以充分发挥关系型数据库和非关系型数据库的优势。MySQL适合处理结构化数据,而MongoDB适合处理非结构化或半结构化数据。以下是如何在Spring Boot中同时使用MySQL和MongoDB的详细讲解。 1. 添加依…...

《手札·行业篇》开源Odoo MES系统与SKF Observer Phoenix API在化工行业的双向对接方案

一、项目背景 化工行业生产过程复杂,设备运行条件恶劣,对设备状态监测、生产数据采集和质量控制的要求极高。通过开源Odoo MES系统与SKF Observer Phoenix API的双向对接,可以实现设备状态的实时监测、生产数据的自动化采集以及质量数据的同步…...

数据结构与算法之数组: LeetCode 905. 按奇偶排序数组 (Ts版)

按奇偶排序数组 https://leetcode.cn/problems/sort-array-by-parity/description/ 描述 给你一个整数数组 nums,将 nums 中的的所有偶数元素移动到数组的前面,后跟所有奇数元素。 返回满足此条件的 任一数组 作为答案。 示例 1 输入:n…...

【STM32】HAL库Host MSC读写外部U盘及FatFS文件系统的USB Disk模式

【STM32】HAL库Host MSC读写外部U盘及FatFS文件系统的USB Disk模式 在先前 分别介绍了FatFS文件系统和USB虚拟U盘MSC配置 前者通过MCU读写Flash建立文件系统 后者通过MSC连接电脑使其能够被操作 这两者可以合起来 就能够实现同时在MCU、USB中操作Flash的文件系统 【STM32】通过…...

docker nginx 配置文件详解

在平常的开发工作中,我们经常需要访问静态资源(图片、HTML页面等)、访问文件目录、部署项目时进行负载均衡等。那么我们就会使用到Nginx,nginx.conf 的配置至关重要。那么今天主要结合访问静态资源、负载均衡等总结下 nginx.conf …...

如何实现华为云+deepseek?

在华为云上实现跨账号迁移数据或部署DeepSeek模型,可以通过以下步骤完成: 跨账号数据迁移 创建委托:在源账号中创建一个委托(Agency),授予目标账号访问数据的权限。 复制镜像:在源账号中&…...

【学习笔记】计算机网络(三)

第3章 数据链路层 文章目录 第3章 数据链路层3.1数据链路层的几个共同问题3.1.1 数据链路和帧3.1.2 三个基本功能3.1.3 其他功能 - 滑动窗口机制 3.2 点对点协议PPP(Point-to-Point Protocol)3.2.1 PPP 协议的特点3.2.2 PPP协议的帧格式3.2.3 PPP 协议的工作状态 3.3 使用广播信…...

稀土抑烟剂——为汽车火灾安全增添防线

一、稀土抑烟剂的基本概念 稀土抑烟剂是一类基于稀土元素(如稀土氧化物和稀土金属化合物)开发的高效阻燃材料。它可以显著提高汽车内饰材料的阻燃性能,减少火灾发生时有毒气体和烟雾的产生。稀土抑烟剂不仅能提升火灾时的安全性,…...

Qt Pro、Pri、Prf

一、概述 1、在Qt中,通常使用.pro(project)、pri(private include)、prf(project file)三种文件扩展名来组织项目。对于模块化编程,Qt提供了Pro和Pri,Pro管理项目,Pri管理模块。 2、pro文件是Qt项目的核心文件,包含了…...

基于AIOHTTP、Websocket和Vue3一步步实现web部署平台,无延迟控制台输出,接近原生SSH连接

背景:笔者是一名Javaer,但是最近因为某些原因迷上了Python和它的Asyncio,至于什么原因?请往下看。在着迷”犯浑“的过程中,也接触到了一些高并发高性能的组件,通过简单的学习和了解,aiohttp这个…...

如何在MacOS上查看edge/chrome的扩展源码

步骤 进入管理扩展页面点击详细信息复制对应id在命令行键入 open ~/Library/Application Support/Microsoft Edge/Default/Extensions/${你刚刚复制的id} 即可打开访达中对应的更目录 注意 由于原生命令行无法直接处理空格 ,所以需要加转义符\,即:open ~/Librar…...

【xdoj-离散线上练习H】T234(C++)

解题心得: 写递归函数的时候,首先写终止条件,这有助于对整个递归函数的把握。 题目:输入集合A和B,输出A到B上的所有函数。 问题描述 给定非空数字集合A和B,求出集合A到集合B上的所有函数。 输入格式 第一行…...

Docker Desktop Windows 安装

一、先下载Docker desktop WIndows 下载地址 二、安装 安装超简单 一路 下一步 三、安装之后,桌面会出现一个 小蓝鲸图标,打开它 》更新至最新版本,不然小蓝鲸打开,一会就退出了。 》wsl --update (这个有时比较慢…...

springCloud-2021.0.9 之 GateWay 示例

文章目录 前言springCloud-2021.0.9 之 GateWay 示例1. GateWay 官网2. GateWay 三个关键名称3. GateWay 工作原理的高级概述4. 示例4.1. POM4.2. 启动类4.3. 过滤器4.4. 配置 5. 启动/测试 前言 如果您觉得有用的话,记得给博主点个赞,评论,收…...

JDK8 stream API用法汇总

目录 1.集合处理数据的弊端 2. Steam流式思想概述 3. Stream流的获取方式 3.1 根据Collection获取 3.1 通过Stream的of方法 4.Stream常用方法介绍 4.1 forEach 4.2 count 4.3 filter 4.4 limit 4.5 skip 4.6 map 4.7 sorted 4.8 distinct 4.9 match 4.10 find …...

windows生成SSL的PFX格式证书

生成crt证书: 安装openssl winget install -e --id FireDaemon.OpenSSL 生成cert openssl req -x509 -newkey rsa:2048 -keyout private.key -out certificate.crt -days 365 -nodes -subj "/CN=localhost" 转换pfx openssl pkcs12 -export -out certificate.pfx…...

玩转大语言模型——使用Kiln AI可视化环境进行大语言模型微调数据合成

系列文章目录 玩转大语言模型——使用langchain和Ollama本地部署大语言模型 玩转大语言模型——三分钟教你用langchain提示词工程获得猫娘女友 玩转大语言模型——ollama导入huggingface下载的模型 玩转大语言模型——langchain调用ollama视觉多模态语言模型 玩转大语言模型—…...

2025 西湖论剑wp

web Rank-l 打开题目环境: 发现一个输入框,看一下他是用上面语言写的 发现是python,很容易想到ssti 密码随便输,发现没有回显 但是输入其他字符会报错 确定为ssti注入 开始构造payload, {{(lipsum|attr(‘global…...

FPGA 28 ,基于 Vivado Verilog 的呼吸灯效果设计与实现( 使用 Vivado Verilog 实现呼吸灯效果 )

目录 前言 一. 设计流程 1.1 需求分析 1.2 方案设计 1.3 PWM解析 二. 实现流程 2.1 确定时间单位和精度 2.2 定义参数和寄存器 2.3 实现计数器逻辑 2.4 控制 LED 状态 三. 整体流程 3.1 全部代码 3.2 代码逻辑 1. 参数定义 2. 分级计数 3. 状态切换 4. LED 输…...

智慧医疗能源事业线深度画像分析(上)

引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...

DeepSeek 赋能智慧能源:微电网优化调度的智能革新路径

目录 一、智慧能源微电网优化调度概述1.1 智慧能源微电网概念1.2 优化调度的重要性1.3 目前面临的挑战 二、DeepSeek 技术探秘2.1 DeepSeek 技术原理2.2 DeepSeek 独特优势2.3 DeepSeek 在 AI 领域地位 三、DeepSeek 在微电网优化调度中的应用剖析3.1 数据处理与分析3.2 预测与…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用:实现组件通用属性的渐变过渡效果,提升用户体验。支持属性:width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项: 布局类属性(如宽高)变化时&#…...

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决

Spring Cloud Gateway 中自定义验证码接口返回 404 的排查与解决 问题背景 在一个基于 Spring Cloud Gateway WebFlux 构建的微服务项目中,新增了一个本地验证码接口 /code,使用函数式路由(RouterFunction)和 Hutool 的 Circle…...

以光量子为例,详解量子获取方式

光量子技术获取量子比特可在室温下进行。该方式有望通过与名为硅光子学(silicon photonics)的光波导(optical waveguide)芯片制造技术和光纤等光通信技术相结合来实现量子计算机。量子力学中,光既是波又是粒子。光子本…...

HarmonyOS运动开发:如何用mpchart绘制运动配速图表

##鸿蒙核心技术##运动开发##Sensor Service Kit(传感器服务)# 前言 在运动类应用中,运动数据的可视化是提升用户体验的重要环节。通过直观的图表展示运动过程中的关键数据,如配速、距离、卡路里消耗等,用户可以更清晰…...

【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论

路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...

作为测试我们应该关注redis哪些方面

1、功能测试 数据结构操作:验证字符串、列表、哈希、集合和有序的基本操作是否正确 持久化:测试aof和aof持久化机制,确保数据在开启后正确恢复。 事务:检查事务的原子性和回滚机制。 发布订阅:确保消息正确传递。 2、性…...

c# 局部函数 定义、功能与示例

C# 局部函数:定义、功能与示例 1. 定义与功能 局部函数(Local Function)是嵌套在另一个方法内部的私有方法,仅在包含它的方法内可见。 • 作用:封装仅用于当前方法的逻辑,避免污染类作用域,提升…...

沙箱虚拟化技术虚拟机容器之间的关系详解

问题 沙箱、虚拟化、容器三者分开一一介绍的话我知道他们各自都是什么东西,但是如果把三者放在一起,它们之间到底什么关系?又有什么联系呢?我不是很明白!!! 就比如说: 沙箱&#…...