YOLOv5-Backbone模块实现
🍨 本文为🔗365天深度学习训练营 中的学习记录博客
🍦 参考文章地址: 365天深度学习训练营-第P8周:YOLOv5-Backbone模块实现
🍖 作者:K同学啊
一、前期准备
1.设置GPU
import torch
from torch import nn
import torchvision
from torchvision import transforms,datasets,models
import matplotlib.pyplot as plt
import os,PIL,pathlib
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cuda')
2.导入数据
data_dir = './weather_photos/'
data_dir = pathlib.Path(data_dir)data_paths = list(data_dir.glob('*'))
classNames = [str(path).split('\\')[1] for path in data_paths]
classNames
['cloudy', 'rain', 'shine', 'sunrise']
# 关于transforms.Compose的更多介绍可以参考:https://blog.csdn.net/qq_38251616/article/details/124878863
train_transforms = transforms.Compose([transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸# transforms.RandomHorizontalFlip(), # 随机水平翻转transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])test_transform = transforms.Compose([transforms.Resize([224, 224]), # 将输入图片resize成统一尺寸transforms.ToTensor(), # 将PIL Image或numpy.ndarray转换为tensor,并归一化到[0,1]之间transforms.Normalize( # 标准化处理-->转换为标准正太分布(高斯分布),使模型更容易收敛mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) # 其中 mean=[0.485,0.456,0.406]与std=[0.229,0.224,0.225] 从数据集中随机抽样计算得到的。
])total_data = datasets.ImageFolder(data_dir,transform=train_transforms)
total_data
Dataset ImageFolder
Number of datapoints: 1125
Root location: weather_photos
StandardTransform
Transform: Compose(
Resize(size=[224, 224], interpolation=PIL.Image.BILINEAR)
ToTensor()
Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
)
total_data.class_to_idx
{'cloudy': 0, 'rain': 1, 'shine': 2, 'sunrise': 3}
3.划分数据集
train_size = int(0.8*len(total_data))
test_size = len(total_data) - train_size
train_dataset, test_dataset = torch.utils.data.random_split(total_data,[train_size,test_size])
train_dataset,test_dataset
(<torch.utils.data.dataset.Subset at 0x1e42b97f4f0>,
<torch.utils.data.dataset.Subset at 0x1e42b196a30>)
batch_size = 4
train_dl = torch.utils.data.DataLoader(train_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
test_dl = torch.utils.data.DataLoader(test_dataset,batch_size=batch_size,shuffle=True,num_workers=1)
for X,y in test_dl:print('Shape of X [N, C, H, W]:', X.shape)print('Shape of y:', y.shape)break
Shape of X [N, C, H, W]: torch.Size([4, 3, 224, 224])
Shape of y: torch.Size([4])
二、搭建包含Backbone模块的模型
1.搭建模型
import torch.nn.functional as Fdef autopad(k, p=None): # kernel, padding# Pad to 'same'if p is None:p = k // 2 if isinstance(k, int) else [x // 2 for x in k] # auto-padreturn pclass Conv(nn.Module):# Standard convolutiondef __init__(self, c1, c2, k=1, s=1, p=None, g=1, act=True): # ch_in, ch_out, kernel, stride, padding, groupssuper().__init__()self.conv = nn.Conv2d(c1, c2, k, s, autopad(k, p), groups=g, bias=False)self.bn = nn.BatchNorm2d(c2)self.act = nn.SiLU() if act is True else (act if isinstance(act, nn.Module) else nn.Identity())def forward(self, x):return self.act(self.bn(self.conv(x)))class Bottleneck(nn.Module):# Standard bottleneckdef __init__(self, c1, c2, shortcut=True, g=1, e=0.5): # ch_in, ch_out, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e) # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c_, c2, 3, 1, g=g)self.add = shortcut and c1 == c2def forward(self, x):return x + self.cv2(self.cv1(x)) if self.add else self.cv2(self.cv1(x))class C3(nn.Module):# CSP Bottleneck with 3 convolutionsdef __init__(self, c1, c2, n=1, shortcut=True, g=1, e=0.5): # ch_in, ch_out, number, shortcut, groups, expansionsuper().__init__()c_ = int(c2 * e) # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c1, c_, 1, 1)self.cv3 = Conv(2 * c_, c2, 1) # act=FReLU(c2)self.m = nn.Sequential(*(Bottleneck(c_, c_, shortcut, g, e=1.0) for _ in range(n)))def forward(self, x):return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), dim=1))class SPPF(nn.Module):# Spatial Pyramid Pooling - Fast (SPPF) layer for YOLOv5 by Glenn Jocherdef __init__(self, c1, c2, k=5): # equivalent to SPP(k=(5, 9, 13))super().__init__()c_ = c1 // 2 # hidden channelsself.cv1 = Conv(c1, c_, 1, 1)self.cv2 = Conv(c_ * 4, c2, 1, 1)self.m = nn.MaxPool2d(kernel_size=k, stride=1, padding=k // 2)def forward(self, x):x = self.cv1(x)with warnings.catch_warnings():warnings.simplefilter('ignore') # suppress torch 1.9.0 max_pool2d() warningy1 = self.m(x)y2 = self.m(y1)return self.cv2(torch.cat([x, y1, y2, self.m(y2)], 1))
"""
这个是YOLOv5, 6.0版本的主干网络,这里进行复现
(注:有部分删改,详细讲解将在后续进行展开)
"""
class YOLOv5_backbone(nn.Module):def __init__(self):super(YOLOv5_backbone, self).__init__()self.Conv_1 = Conv(3, 64, 3, 2, 2) self.Conv_2 = Conv(64, 128, 3, 2) self.C3_3 = C3(128,128)self.Conv_4 = Conv(128, 256, 3, 2) self.C3_5 = C3(256,256)self.Conv_6 = Conv(256, 512, 3, 2) self.C3_7 = C3(512,512)self.Conv_8 = Conv(512, 1024, 3, 2) self.C3_9 = C3(1024, 1024)self.SPPF = SPPF(1024, 1024, 5)# 全连接网络层,用于分类self.classifier = nn.Sequential(nn.Linear(in_features=65536, out_features=100),nn.ReLU(),nn.Linear(in_features=100, out_features=4))def forward(self, x):x = self.Conv_1(x)x = self.Conv_2(x)x = self.C3_3(x)x = self.Conv_4(x)x = self.C3_5(x)x = self.Conv_6(x)x = self.C3_7(x)x = self.Conv_8(x)x = self.C3_9(x)x = self.SPPF(x)x = torch.flatten(x, start_dim=1)x = self.classifier(x)return xdevice = "cuda" if torch.cuda.is_available() else "cpu"
print("Using {} device".format(device))model = YOLOv5_backbone().to(device)
model
略
2.查看详细模型
# 统计模型参数量以及其他指标
import torchsummary as summary
summary.summary(model, (3, 224, 224))
略
三、训练模型
1.编写训练函数
# 训练循环
def train(dataloader, model, loss_fn, optimizer):size = len(dataloader.dataset) # 训练集的大小,一共900张图片num_batches = len(dataloader) # 批次数目,29(900/32)train_loss, train_acc = 0, 0 # 初始化训练损失和正确率for X, y in dataloader: # 获取图片及其标签X, y = X.to(device), y.to(device)# 计算预测误差pred = model(X) # 网络输出loss = loss_fn(pred, y) # 计算网络输出和真实值之间的差距,targets为真实值,计算二者差值即为损失# 反向传播optimizer.zero_grad() # grad属性归零loss.backward() # 反向传播optimizer.step() # 每一步自动更新# 记录acc与losstrain_acc += (pred.argmax(1) == y).type(torch.float).sum().item()train_loss += loss.item()train_acc /= sizetrain_loss /= num_batchesreturn train_acc, train_loss
2.编写测试函数
def test (dataloader, model, loss_fn):size = len(dataloader.dataset) # 测试集的大小,一共10000张图片num_batches = len(dataloader) # 批次数目,8(255/32=8,向上取整)test_loss, test_acc = 0, 0# 当不进行训练时,停止梯度更新,节省计算内存消耗with torch.no_grad():for imgs, target in dataloader:imgs, target = imgs.to(device), target.to(device)# 计算losstarget_pred = model(imgs)loss = loss_fn(target_pred, target)test_loss += loss.item()test_acc += (target_pred.argmax(1) == target).type(torch.float).sum().item()test_acc /= sizetest_loss /= num_batchesreturn test_acc, test_loss
3.正式训练
import copyoptimizer = torch.optim.Adam(model.parameters(), lr= 1e-4)
loss_fn = nn.CrossEntropyLoss() # 创建损失函数epochs = 20train_loss = []
train_acc = []
test_loss = []
test_acc = []best_acc = 0 # 设置一个最佳准确率,作为最佳模型的判别指标for epoch in range(epochs):model.train()epoch_train_acc, epoch_train_loss = train(train_dl, model, loss_fn, optimizer)model.eval()epoch_test_acc, epoch_test_loss = test(test_dl, model, loss_fn)# 保存最佳模型到 best_modelif epoch_test_acc > best_acc:best_acc = epoch_test_accbest_model = copy.deepcopy(model)train_acc.append(epoch_train_acc)train_loss.append(epoch_train_loss)test_acc.append(epoch_test_acc)test_loss.append(epoch_test_loss)# 获取当前的学习率lr = optimizer.state_dict()['param_groups'][0]['lr']template = ('Epoch:{:2d}, Train_acc:{:.1f}%, Train_loss:{:.3f}, Test_acc:{:.1f}%, Test_loss:{:.3f}, Lr:{:.2E}')print(template.format(epoch+1, epoch_train_acc*100, epoch_train_loss, epoch_test_acc*100, epoch_test_loss, lr))# 保存最佳模型到文件中
PATH = './best_model.pth' # 保存的参数文件名
torch.save(model.state_dict(), PATH)print('Done')
。。。
Epoch:18, Train_acc:95.0%, Train_loss:0.142, Test_acc:91.6%, Test_loss:0.236, Lr:1.00E-04
Epoch:19, Train_acc:92.8%, Train_loss:0.193, Test_acc:88.0%, Test_loss:0.278, Lr:1.00E-04
Epoch:20, Train_acc:94.6%, Train_loss:0.160, Test_acc:92.0%, Test_loss:0.220, Lr:1.00E-04
Done
四、结果可视化
1.Loss与Accuracy图
import matplotlib.pyplot as plt
#隐藏警告
import warnings
warnings.filterwarnings("ignore") #忽略警告信息
plt.rcParams['font.sans-serif'] = ['SimHei'] # 用来正常显示中文标签
plt.rcParams['axes.unicode_minus'] = False # 用来正常显示负号
plt.rcParams['figure.dpi'] = 100 #分辨率epochs_range = range(epochs)plt.figure(figsize=(12, 3))
plt.subplot(1, 2, 1)plt.plot(epochs_range, train_acc, label='Training Accuracy')
plt.plot(epochs_range, test_acc, label='Test Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')plt.subplot(1, 2, 2)
plt.plot(epochs_range, train_loss, label='Training Loss')
plt.plot(epochs_range, test_loss, label='Test Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()
2.模型评估
# 将参数加载到model当中
best_model.load_state_dict(torch.load(PATH, map_location=device))
epoch_test_acc, epoch_test_loss = test(test_dl, best_model, loss_fn)
epoch_test_acc, epoch_test_loss
(0.92, 0.21799196774352886)
# 查看是否与我们记录的最高准确率一致
epoch_test_acc
0.92
相关文章:
YOLOv5-Backbone模块实现
🍨 本文为🔗365天深度学习训练营 中的学习记录博客🍦 参考文章地址: 365天深度学习训练营-第P8周:YOLOv5-Backbone模块实现🍖 作者:K同学啊一、前期准备1.设置GPUimport torch from torch impor…...
【C语言】程序环境和预处理
🌇个人主页:平凡的小苏 📚学习格言:别人可以拷贝我的模式,但不能拷贝我不断往前的激情 🛸C语言专栏:https://blog.csdn.net/vhhhbb/category_12174730.html 小苏希望大家能从这篇文章中收获到许…...
9.关系查询处理和查询优化
其他章节索引 梳理 名词解释 代数优化:是指关系代数表达式的优化,也即按照一定规则,通过对关系代数表达式进行等价变换,改变代数表达式中操作的次序和组合,使查询更高效物理优化:是指存取路径和底层操作算…...
计算机组成原理(三)
5.掌握定点数的表示和应用(主要是无符号数和有符号数的表示、机器数的定点表示、数的机器码表示); 定点数:小数点位置固定不变。 定点小数:小数点固定在数值位与符号位之间; 定点整数:小…...
C. Least Prefix Sum codeforces每日一题
🚀前言 🚀 大家好啊,这里是幸麟 🧩 一名普通的大学牲,最近在学习算法 🧩每日一题的话难度的话是根据博主水平来找的 🧩所以可能难度比较低,以后会慢慢提高难度的 🧩此题标…...
ASEMI三相整流模块MDS100-16图片,MDS100-16尺寸
编辑-Z ASEMI三相整流模块MDS100-16参数: 型号:MDS100-16 最大重复峰值反向电压(VRRM):1600V 最大RMS电桥输入电压(VRMS):1700V 最大平均正向整流输出电流(IF&#…...
【第37天】斐波那契数列与爬楼梯 | 迭代的鼻祖,递推与记忆化
本文已收录于专栏🌸《Java入门一百例》🌸学习指引序、专栏前言一、递推与记忆化二、【例题1】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、【例题1】1、题目描述2.解题思路3、模板代码4、代码解析5、原题链接三、推荐专栏四、课后习题序…...
Map集合
Map集合 Map接口的简介 Map用于保存具有映射关系的数据,Map里保存着两组数据:key和value,它们都可以使任何引用类型的数据,但key不能重复。所以通过指定的key就可以取出对应的value。 Map 没有继承 Collection 接口,…...
PyQt5编程扩展 3.2 资源文件的使用
目录 本例运行效果: 设计Qt窗体 建立项目 放一个Group Box 放三个Label 放一个Horizontal Slider 放两个Line Edit 层次结构 布局 放一个Group Box 放两个Label 放两个Line Edit 放一个Push Button 层次结构 布局 放一个frame 层次结构 布局 窗体…...
Linux系统之文件共享目录设置方法
Linux系统之文件共享目录设置方法一、本次实践目的二、检查本地系统环境1.检查系统版本2.检查系统内核三、创建相关用户及用户组1.创建共享目录2.创建测试用户账号3.创建用户组4.设置用户的属组5.查看admin和IT用户组成员6.查看所有用户信息四、共享目录权限设置1.设置/data/so…...
上海亚商投顾:三大指数均涨超1% 芯片板块集体大涨
上海亚商投顾前言:无惧大盘涨跌,解密龙虎榜资金,跟踪一线游资和机构资金动向,识别短期热点和强势个股。市场情绪三大指数今日低开高走,午后集体涨超1%,创业板指盘中涨超1.7%。芯片板块集体大涨,…...
Harbor私有仓库部署与管理
目录 前言 一、Harbor概述 二、Harbor 的特性 三、Harbor的构成 四、Harbor构建Docker私有仓库 1、环境配置 2、案例需求 3、部署Harbor服务 3.1、部署docker compose服务 3.2 下载或上传Harbor安装程序 3.3、启动Harbor 3.4、查看Harbor启动镜像 4、物理机访问se…...
互联网架构之 “高可用” 详解
一、什么是高可用 高可用HA(High Availability)是分布式系统架构设计中必须考虑的因素之一,它通常是指,通过设计减少系统不能提供服务的时间。 假设系统一直能够提供服务,我们说系统的可用性是100%。 如果系统每运行…...
分布式高级篇4 —— 商城业务(2)
一、订单服务1、订单基本概念2、订单基本构成3、订单状态4、订单流程5、配置拦截器拦截订单请求6、订单确认页模型抽取7、订单确认页vo封装8、Feign 远程调用请求头丢失问题\*\*\*\*\* 惨痛教训9、Feign 异步调用请求头丢失问题10、查看库存状态11、模拟计算运费12、接口幂等性…...
二分查找基本原理
二分查找基本原理1.二分查找1.1 基本概念1.2 二分查找查找步骤1.2.1 中间索引不能整除,取整数作为中间索引1.2.2 索引不能整除,整数1作为中间索引1.3 二分查找大O记法表示2. 二分查找代码实现1.二分查找 1.1 基本概念 二分法(折半查找)是一…...
【Python实战案例】Python3网络爬虫:“可惜你不看火影,也不明白这个视频的分量......”m3u8视频下载,那些事儿~
前言 哈喽!上午好嘞,各位小可爱们!有没有等着急了呀~ 由于最近一直在学习新的内容,所以耽搁了一下下,抱歉.jpg 双手合十。 所有文章完整的素材源码都在👇👇 粉丝白嫖源码福利,请移…...
UE4:使用样条生成随机路径,并使物体沿着路径行走
一、关于样条的相关知识 参考自:样条函数 - 馒头and花卷 - 博客园 三次样条(cubic spline)插值 - 知乎 B-Spline(三)样条曲线的性质 - Fun With GeometryFun With Geometry 个人理解的也不是非常深,但是大概要知道的就是样条具…...
计算机组成原理(判断题)
计算机控制器是根据事先编好的程序,根据其指令来进行控制只会每一步骤的操作; 面向主存的双总线结构计算机系统,因在CPU与主存之间增加了一组存储器总线,由于通过存储器总线访存,提高了CPU的访存速度,也减轻…...
error: failed to push some refs to ... 就这篇,一定帮你解决
目录 一、问题产生原因 二、解决办法 三、如果还是出问题,怎么办?(必杀) 一、问题产生原因 当你直接在github上在线修改了代码,或者是直接向某个库中添加文件,但是没有对本地库同步,接着你想…...
DAMA数据管理知识体系指南之数据仓库和商务智能管理
第9章 数据仓库和商务智能管理 9.1简介 数据仓库(Data Warehouse,DW)由两个主要部分构成:首先是一个整合的决策支持数据库,其次是用于收集、清洗、转换、存储来自于各种操作型数据源和外部数据源数据的相关软件程序。两者结合以支持历史的、…...
PHP的五种常见设计模式
工厂模式 最初在设计模式 一书中,许多设计模式都鼓励使用松散耦合。要理解这个概念,让我们最好谈一下许多开发人员从事大型系统的艰苦历程。在更改一个代码片段时,就会发生问题,系统其他部分 —— 您曾认为完全不相关的部分中也有…...
教你搞懂线段树,从基础到提高
秋名山码民的主页 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 🙏作者水平有限,如发现错误,还请私信或者评论区留言! 目录前言线段树逻辑概念线段树的俩个重要用处代码实现线段树题目巩固最后…...
C语言进阶——自定义类型:结构体
🌇个人主页:_麦麦_ 📚今日名言:生活不可能像你想象的那么好,也不会像你想象的那么糟。——莫泊桑《羊脂球》 目录 一、前言 二、正文 1结构体 1.1结构体的基础知识 1.2结构的声明 1.3特殊的声明 1.4结构体变量的…...
SpringSecurity学习笔记01
目录 一、课程介绍 二、框架概述 三、入门案例 四、基本原理(过滤器链) 五、基本原理(过滤器加载过程) 六、基本原理(两个重要的接口) 七、web权限方案-用户认证(设置用户名密码上) 八、…...
Python语言零基础入门教程(十一)
Python 列表(List) 序列是Python中最基本的数据结构。序列中的每个元素都分配一个数字 - 它的位置,或索引,第一个索引是0,第二个索引是1,依此类推。 Python有6个序列的内置类型,但最常见的是列表和元组。 序列都可以…...
现货白银基础知识
任何活动,任何项目,任何工作都离不开基础知识,这是肯定的。万丈高楼平地起,要想要简称百层高楼,首先得把低级打好!现货白银投资也是一样的道理,现在我们就来一起聊聊现货白银基础知识的问题&…...
数据库原理及应用基础知识点
数据库原理基础知识点大全数据库原理及应用1、数据库系统概述1.1 基本概念1.2 数据模型1.3 数据库系统的结构2、实体 -- 联系模型2.1 基本概念2.2 实体-联系图2.3 弱实体集3、关系数据模型3.1 关系数据库的结构3.2 从ER模型到关系模型3.3 关系操作、完整性约束、关系代数4、关系…...
【数据结构】栈(stack)
写在前面本篇文章开始讲解栈的有关知识,其实把顺序表和链表学好,那么这一章便不在话下,栈实际上就是顺序表或链表的一些特殊情况。用顺序表实现的栈叫做顺序栈用链表实现的栈叫做链栈文章的内容分为几个部分,希望读者能快速了解文…...
初识shell
文章目录一、shell基本知识1.1为什么学习和使用Shell编程1.2 什么是Shell1.2.1 shell的起源1.2.2 shell的功能1.3 shell的分类1.4 作为程序设计的语言——shell1.5 如何学好shell1.6 shell脚本的基本元素1.7 shell脚本编写规范1.8shell脚本的执行方式1.9 执行脚本的方法1.10 sh…...
程序员如何编写好开发技术文档 如何编写优质的API文档工作
编写技术文档,是令众多开发者望而生畏的任务之一。它本身是一件费时费力才能做好的工作。可是大多数时候,人们却总是想抄抄捷径,这样做的结果往往非常令人遗憾的,因为优质的技术文档是决定你的项目是否引人关注的重要因素。无论开…...
做网站需要什么认证/只要做好关键词优化
专栏 | 九章算法网址 | http://www.jiuzhang.com问1动态规划是个什么鸟蛋?答:动态规划是一种通过“大而化小”的思路解决问题的算法。区别于一些固定形式的算法,如二分法,宽度优先搜索法,动态规划没有实际的步骤来规定…...
盐城网站推广哪家好/免费seo工具汇总
认识主机板煮 鸡板....咳!「主机板」(Motherboard)不算电脑里最先进的零组件,但绝对是塞最多东西的零组件。事实上,现在新的主机板简直像怪物,上面 可能有数十个长长短短、大大小小、圆的方的、各式各样的插…...
美乐乐网站模板/成都网站优化公司
今天安装了一下TortoiseSVN,然后建了个test测试文件,在add或者check out 、update的时候,虽然文件是最新的,但是文件上没有对应的状态显示,即感叹号或者绿色对勾。百度了一下,找到了解决办法,在…...
做iframe跳转怎么自适应网站/seo全网推广营销软件
1、QRCode QRCode最简单的使用 import qrcode qrcode.make("第一个二维码").get_image().show() 根据文本生成二维码并且直接显示。 根据文本或URL生成二维码,保存到指定目录并显示二维码 import qrcode import os text input("请输入文本或者URL:&…...
化工网站制作/网页制作公司排名
工作小计: 参照: http://xiaomaimai.blog.51cto.com/1182965/449729 Omnitty ,一款基于ssh批量管理操作.当需要登录到远程机器时,需要确认当前用户用户权限,以避免带的损失! 下载地址: http://prdownloads.sourceforge.net/rote/rote-0.2.8.tar.gz?download http://prdownload…...
做珠宝首饰网站/电商平台运营方案思路
展开全部1、可以62616964757a686964616fe58685e5aeb931333365633939用python自带的安装工具,pip install numpy scipy 等。2、如果没有pip的话,可以试试easy-install numpy scipy。打开cmd,在里面输入这些命令。Python程序员的常见错误&#…...