ConvNeXt V2实战:使用ConvNeXt V2实现图像分类任务(一)
文章目录
- 摘要
- 安装包
- 安装timm
- 安装 grad-cam
- 数据增强Cutout和Mixup
- EMA
- 项目结构
- 计算mean和std
- 生成数据集
- 关于不上分的问题
摘要
论文:https://arxiv.org/pdf/2301.00808.pdf
论文翻译:https://wanghao.blog.csdn.net/article/details/128541957
官方源码: https://github.com/facebookresearch/ConvNeXt-V2
当前的主干网络几乎是Transformers的时代,ConvNeXt为数不多的的高性能CNN网络,V1版本就证明了其强大的存在,在V2版本中,作者提出了一个全卷积掩码自编码器框架和一个新的全局响应归一化(GRN)层,添加到ConvNeXt架构中,以增强通道间的特征竞争。作者将这种自监督学习技术和架构改进的共同设计命名为ConvNeXt V2,它显著提高了纯卷积在各种识别基准上的性能,包括ImageNet分类、COCO检测和ADE20K分割。在ImageNet上取得了88.9%的精度。
这篇文章主要讲解如何使用ConvNeXt V2完成图像分类任务,接下来我们一起完成项目的实战。本例选用的模型是convnextv2_base,在植物幼苗数据集上实现了96
%的准确率。
通过这篇文章能让你学到:
- 如何使用数据增强,包括transforms的增强、CutOut、MixUp、CutMix等增强手段?
- 如何实现ConvNeXt V2模型实现训练?
- 如何使用pytorch自带混合精度?
- 如何使用梯度裁剪防止梯度爆炸?
- 如何使用DP多显卡训练?
- 如何绘制loss和acc曲线?
- 如何生成val的测评报告?
- 如何编写测试脚本测试测试集?
- 如何使用余弦退火策略调整学习率?
- 如何使用AverageMeter类统计ACC和loss等自定义变量?
- 如何理解和统计ACC1和ACC5?
- 如何使用EMA?
- 如果使用Grad-CAM 实现热力图可视化?
如果基础薄弱,对上面的这些功能难以理解可以看我的专栏:经典主干网络精讲与实战
这个专栏,从零开始时,一步一步的讲解这些,让大家更容易接受。
安装包
安装timm
使用pip就行,命令:
pip install timm
本文实战用的timm里面的模型。
安装 grad-cam
pip install grad-cam
数据增强Cutout和Mixup
为了提高成绩我在代码中加入Cutout和Mixup这两种增强方式。实现这两种增强需要安装torchtoolbox。安装命令:
pip install torchtoolbox
Cutout实现,在transforms中。
from torchtoolbox.transform import Cutout
# 数据预处理
transform = transforms.Compose([transforms.Resize((224, 224)),Cutout(),transforms.ToTensor(),transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])])
需要导入包:from timm.data.mixup import Mixup,
定义Mixup,和SoftTargetCrossEntropy
mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=12)criterion_train = SoftTargetCrossEntropy()
参数详解:
mixup_alpha (float): mixup alpha 值,如果 > 0,则 mixup 处于活动状态。
cutmix_alpha (float):cutmix alpha 值,如果 > 0,cutmix 处于活动状态。
cutmix_minmax (List[float]):cutmix 最小/最大图像比率,cutmix 处于活动状态,如果不是 None,则使用这个 vs alpha。
如果设置了 cutmix_minmax 则cutmix_alpha 默认为1.0
prob (float): 每批次或元素应用 mixup 或 cutmix 的概率。
switch_prob (float): 当两者都处于活动状态时切换cutmix 和mixup 的概率 。
mode (str): 如何应用 mixup/cutmix 参数(每个’batch’,‘pair’(元素对),‘elem’(元素)。
correct_lam (bool): 当 cutmix bbox 被图像边框剪裁时应用。 lambda 校正
label_smoothing (float):将标签平滑应用于混合目标张量。
num_classes (int): 目标的类数。
EMA
EMA(Exponential Moving Average)是指数移动平均值。在深度学习中的做法是保存历史的一份参数,在一定训练阶段后,拿历史的参数给目前学习的参数做一次平滑。具体实现如下:
import logging
from collections import OrderedDict
from copy import deepcopy
import torch
import torch.nn as nn_logger = logging.getLogger(__name__)class ModelEma:def __init__(self, model, decay=0.9999, device='', resume=''):# make a copy of the model for accumulating moving average of weightsself.ema = deepcopy(model)self.ema.eval()self.decay = decayself.device = device # perform ema on different device from model if setif device:self.ema.to(device=device)self.ema_has_module = hasattr(self.ema, 'module')if resume:self._load_checkpoint(resume)for p in self.ema.parameters():p.requires_grad_(False)def _load_checkpoint(self, checkpoint_path):checkpoint = torch.load(checkpoint_path, map_location='cpu')assert isinstance(checkpoint, dict)if 'state_dict_ema' in checkpoint:new_state_dict = OrderedDict()for k, v in checkpoint['state_dict_ema'].items():# ema model may have been wrapped by DataParallel, and need module prefixif self.ema_has_module:name = 'module.' + k if not k.startswith('module') else kelse:name = knew_state_dict[name] = vself.ema.load_state_dict(new_state_dict)_logger.info("Loaded state_dict_ema")else:_logger.warning("Failed to find state_dict_ema, starting from loaded model weights")def update(self, model):# correct a mismatch in state dict keysneeds_module = hasattr(model, 'module') and not self.ema_has_modulewith torch.no_grad():msd = model.state_dict()for k, ema_v in self.ema.state_dict().items():if needs_module:k = 'module.' + kmodel_v = msd[k].detach()if self.device:model_v = model_v.to(device=self.device)ema_v.copy_(ema_v * self.decay + (1. - self.decay) * model_v)
加入到模型中。
#初始化
if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device='cpu',resume=resume)# 训练过程中,更新完参数后,同步update shadow weights
def train():optimizer.step()if model_ema is not None:model_ema.update(model)# 将model_ema传入验证函数中
val(model_ema.ema, DEVICE, test_loader)
针对没有预训练的模型,容易出现EMA不上分的情况,这点大家要注意啊!
项目结构
ConvNeXtV2_Demo
├─data1
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
├─models
│ ├─convnextv2.py
│ └─utils.py
├─mean_std.py
├─makedata.py
├─train.py
├─cam_image.py
└─test.py
models:来源官方代码,对面的代码做了一些适应性修改。
mean_std.py:计算mean和std的值。
makedata.py:生成数据集。
ema.py:EMA脚本
train.py:训练PoolFormer模型
cam_image.py:热力图可视化
为了能在DP方式中使用混合精度,还需要在模型的forward函数前增加@autocast(),如果使用GPU训练导入包from torch.cuda.amp import autocast,如果使用CPU,则导入from torch.cpu.amp import autocast。
计算mean和std
为了使模型更加快速的收敛,我们需要计算出mean和std的值,新建mean_std.py,插入代码:
from torchvision.datasets import ImageFolder
import torch
from torchvision import transformsdef get_mean_and_std(train_data):train_loader = torch.utils.data.DataLoader(train_data, batch_size=1, shuffle=False, num_workers=0,pin_memory=True)mean = torch.zeros(3)std = torch.zeros(3)for X, _ in train_loader:for d in range(3):mean[d] += X[:, d, :, :].mean()std[d] += X[:, d, :, :].std()mean.div_(len(train_data))std.div_(len(train_data))return list(mean.numpy()), list(std.numpy())if __name__ == '__main__':train_dataset = ImageFolder(root=r'data1', transform=transforms.ToTensor())print(get_mean_and_std(train_dataset))
数据集结构:
运行结果:
([0.3281186, 0.28937867, 0.20702125], [0.09407319, 0.09732835, 0.106712654])
把这个结果记录下来,后面要用!
生成数据集
我们整理还的图像分类的数据集结构是这样的
data
├─Black-grass
├─Charlock
├─Cleavers
├─Common Chickweed
├─Common wheat
├─Fat Hen
├─Loose Silky-bent
├─Maize
├─Scentless Mayweed
├─Shepherds Purse
├─Small-flowered Cranesbill
└─Sugar beet
pytorch和keras默认加载方式是ImageNet数据集格式,格式是
├─data
│ ├─val
│ │ ├─Black-grass
│ │ ├─Charlock
│ │ ├─Cleavers
│ │ ├─Common Chickweed
│ │ ├─Common wheat
│ │ ├─Fat Hen
│ │ ├─Loose Silky-bent
│ │ ├─Maize
│ │ ├─Scentless Mayweed
│ │ ├─Shepherds Purse
│ │ ├─Small-flowered Cranesbill
│ │ └─Sugar beet
│ └─train
│ ├─Black-grass
│ ├─Charlock
│ ├─Cleavers
│ ├─Common Chickweed
│ ├─Common wheat
│ ├─Fat Hen
│ ├─Loose Silky-bent
│ ├─Maize
│ ├─Scentless Mayweed
│ ├─Shepherds Purse
│ ├─Small-flowered Cranesbill
│ └─Sugar beet
新增格式转化脚本makedata.py,插入代码:
import glob
import os
import shutilimage_list=glob.glob('data1/*/*.png')
print(image_list)
file_dir='data'
if os.path.exists(file_dir):print('true')#os.rmdir(file_dir)shutil.rmtree(file_dir)#删除再建立os.makedirs(file_dir)
else:os.makedirs(file_dir)from sklearn.model_selection import train_test_split
trainval_files, val_files = train_test_split(image_list, test_size=0.3, random_state=42)
train_dir='train'
val_dir='val'
train_root=os.path.join(file_dir,train_dir)
val_root=os.path.join(file_dir,val_dir)
for file in trainval_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(train_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)for file in val_files:file_class=file.replace("\\","/").split('/')[-2]file_name=file.replace("\\","/").split('/')[-1]file_class=os.path.join(val_root,file_class)if not os.path.isdir(file_class):os.makedirs(file_class)shutil.copy(file, file_class + '/' + file_name)
完成上面的内容就可以开启训练和测试了。
关于不上分的问题
经过多次训练,我发现了一个玄学的问题:有时候莫名其妙的不上分,为了防止大家出现这种情况后,不知道如何去调试,我做个说明。
1、查看ema是否开启,如果开启了,先把它关掉。训练几个epoch后再开启,加载模型继续训练。
2、改变seed的值,这个就非常玄学了,设置一个自己的幸运数字。
相关文章:
ConvNeXt V2实战:使用ConvNeXt V2实现图像分类任务(一)
文章目录摘要安装包安装timm安装 grad-cam数据增强Cutout和MixupEMA项目结构计算mean和std生成数据集关于不上分的问题摘要 论文:https://arxiv.org/pdf/2301.00808.pdf 论文翻译:https://wanghao.blog.csdn.net/article/details/128541957 官方源码&am…...
3.2 报错整理
报错1: 报错:RuntimeError: DataLoader worker (pid 93789) is killed by signal: Killed.原因:显存不够报错2: 报错:TqdmWarning: IProgress not found. Please update jupyter and ipywidgets.解决:pip i…...
从0开始学python -46
Python CGI编程 什么是CGI CGI 目前由NCSA维护,NCSA定义CGI如下: CGI(Common Gateway Interface),通用网关接口,它是一段程序,运行在服务器上如:HTTP服务器,提供同客户端HTML页面的接口。 网页浏览 为了更好的了解CGI是如何工作…...
JavaScript事件委托机制详解
一、什么是事件委托机制 事件委托机制就是:我们给元素添加click事件时不在该元素上添加,而是委托给某个公共的祖辈元素,告诉祖辈元素如果接收到了click事件,并且这个click事件是由该元素触发的,就执行祖辈元素上委托绑…...
【项目实战】MySQL中union和union all的相同点与不同点
一、union和union all的相同点 在MySQL中,Union和Union All都是用来合并两个或者多个查询结果集的关键字 二、union和union all的不同点 union复杂,union all简单 2.1 自动压缩,自动求并集、去重、排序操作 (1)unio…...
ChatGPT最牛应用,让它帮你更新网站新闻吧!
谁能想到,ChatGPT火了!既能对话入流,又能写诗歌论文、出面试题、编代码,甚至还通过了谷歌面试拿到L3工程师offer,放在一年之前,没人相信这是当前AI能够达到的水平。ChatGPT自面世以来,凭借其极为…...
乌班图安装kvm并配置网络
乌班图22安装KVM 1.安装KVM sudo apt install qemu-kvm libvirt-daemon-system libvirt-clients bridge-utils virt-manager virtinstsudo adduser id -un libvirt sudo adduser id -un kvm sudo apt install virtinst qemu-efi sudo systemctl enable --now libvirtd sudo s…...
蓝库云|ERP系统在企业数字化转型中最常用的八大功能
ERP系统和与企业数字化转型 随着数字化发展的兴起,规划和管理已成为企业产生富有成效的成果的关键。许多企业采用了企业资源规划 (ERP) 等先进工具,使企业所有者能够以高效的方式规划和管理其资源和运营。 ERP系统负责整合业务的不同流程并向决策者提供…...
Pytorch学习笔记#1:拟合函数/梯度下降
学习自https://pytorch.org/tutorials/beginner/pytorch_with_examples.html 概念 Pytorch Tensor在概念上和Numpy的array一样是一个nnn维向量的。不过Tensor可以在GPU中进行计算,且可以跟踪计算图(computational graph)和梯度(…...
挑战图像处理100问(24)——伽玛校正
伽马校正(Gamma Correction)是一种图像处理技术,用于校正显示设备的非线性响应。通过对图像进行伽马变换,可以将图像的亮度范围映射到显示设备的亮度范围内,从而提高图像的对比度和细节,改善图像的视觉效果…...
高级信息系统项目管理师(高项)软考论文评分标准(附历年高项论文题目汇总)
1、如果您想了解如何高分通过高级信息系统项目管理师(高项)你可以点击一下链接: 高级信息系统项目管理师(高项)高分通过经验分享_高项经验 2、如果您想了解更多的高级信息系统项目管理(高项 软考)原创论文࿰…...
MySQL实战记录篇2
事务? 1、事务的特性:原子性、一致性、隔离性、持久性 (ACID) 2、多事务同时执行的时候,可能会出现的问题:脏读、不可重复读、幻读 3、事务隔离级别:读未提交、读提交、可重复读、串行化 4、不…...
C++实现AVL树
目录 一、搜索二叉树 1.1 搜索二叉树概念 二、模拟实现二叉搜索树 2.1 框架 2.2 构造函数 2.2.1 构造函数 2.2.2 拷贝构造 2.2.3 赋值拷贝 2.3 插入函数 2.3.1 insert() 2.3.2 RcInsert() 递归实现 2.4 删除结点函数 2.4.1 Erase() 2.4.2 RcErase() 2.5 中序遍历…...
高并发语言erlang编程初步
初步 下载安装与初步使用 下载并安装,然后开始菜单中有对应的图标,打开就能进入erlang的命令行。当然也可以将其安装路径的bin文件夹加入环境变量,然后就可以在命令行中输入erl进入erlang了。 在erlang语言中,语句结束需要用.标…...
springboot 问题记录
部署到Tomcat中的时候,找不到需要部署的项目; project facets severt-name severt-class安装lombok.jar eclipse添加lombok插件后闪退打不开Clean 项目,project clean clean的作用检查插件部署项目Springboot修改端口号:applica…...
【PAT甲级题解记录】1034 Head of a Gang (30 分)
【PAT甲级题解记录】1034 Head of a Gang (30 分) 前言 Problem:1034 Head of a Gang (30 分) Tags:图的遍历 连通分量统计 DFS Difficulty:剧情模式 想流点汗 想流点血 死而无憾 Address:1034 Head of a Gang (30 分) 问题描述 …...
Python搭建一个steam钓鱼网站,只要免费领游戏,一钓一个准
前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 我们日常上网的时候,总是会碰到一些盗号的网站,或者是别人发一些链接给你, 里面的内容是一些可以免费购物网站的优惠券、游戏官网上可以免费领取皮肤、打折的游戏。 这些盗号网站统一的目…...
maven 私服nexus安装与使用
一、下载nexus Sonatype公司的一款maven私服产品 1、官网下载地址:https://help.sonatype.com/repomanager3/product-information/download 2、csdn下载地址:https://download.csdn.net/download/u010197591/87522994 二、安装与配置 1、下载后解压如…...
详解数据结构中的顺序表的手动实现,顺序表功能接口【数据结构】
文章目录线性表顺序表接口实现尾插尾删头插头删指定位置插入指定位置删除练习线性表 线性表(linear list)是n个具有相同特性的数据元素的有限序列。 线性表是一种在实际中广泛使用的数据结构,常见的线性表:顺序表、链表、栈、队列…...
【二】kubernetes操作
k8s卸载重置 名词解释 1、Namespace:名称用来隔离资源,不隔离网络 创建名称空间 一、命名空间namesapce 方式一:命令行创建 kubectl create ns hello删除名称空间 kubectl delete ns hello查询指定的名称空间 kubectl get pod -n kube-s…...
如何在 C++ 中调用 python 解析器来执行 python 代码(五)?
本节研究如何对 import 做白名单 目录 如何在 C 中调用 python 解析器来执行 python 代码(一)?如何在 C 中调用 python 解析器来执行 python 代码(二)?如何在 C 中调用 python 解析器来执行 python 代码&…...
八 SpringMVC【拦截器】登录验证
目录🚩一 SpringMVC拦截器✅ 1.配置文件✅2.登录验证代码(HandlerInterceptor)✅3.继承HandlerInterceptorAdapter(不建议使用)✅4.登录页面jsp✅5.主页面(操作页面)✅6.crud用户在访问页面时 只…...
PhotoShop基础使用
49:图片分类 1:像素图 特点:放大后可见,右一个个色块(像素)组合而成。 优点:容量小,纯天然 JPG、JPEG、png、GIF 2:矢量图 面向对象图像 绘图图像 特点:不…...
借助阿里云 AHPA,苏打智能轻松实现降本增效
作者:元毅 “高猛科技已在几个主要服务 ACK 集群上启用了 AHPA。相比于 HPA 的方案,AHPA 的主动预测模式额外降低了 12% 的资源成本。同时 AHPA 能够提前资源预热、自动容量规划,能够很好的应对突发流量。” ——赵劲松 (高猛科技高级后台工…...
美团2面:如何保障 MySQL 和 Redis 数据一致性?这样答,让面试官爱到 死去活来
美团2面:如何保障 MySQL 和 Redis 的数据一致性? 说在前面 在尼恩的(50)读者社群中,经常遇到一个 非常、非常高频的一个面试题,但是很不好回答,类似如下: 如何保障 MySQL 和 Redis…...
react hooks学习记录
react hook学习记录1.什么是hooks2.State Hook3.Effect Hook4.Ref Hook1.什么是hooks (1). Hook是React 16.8.0版本增加的新特性/新语法 (2). 可以让你在函数组件中使用 state 以及其他的 React 特性 貌似现在更多的也是使用函数式组件的了,重要 2.State Hook imp…...
创新型中小企业认定评定标准
一、公告条件评价得分达到 60 分以上(其中创新能力指标得分不低于 20 分、成长性指标及专业化指标得分均不低于 15 分),或满足下列条件之一:(一)近三年内获得过国家级、省级科技奖励。(二&#…...
记录一次nginx转发代理skywalking白屏 以及nginx鉴权配置
上nginx代码 #user nobody; worker_processes 1; #error_log logs/error.log; #error_log logs/error.log notice; #error_log logs/error.log info; #pid logs/nginx.pid; events { worker_connections 1024; } http { include mime.types; …...
如何使用FarsightAD在活动目录域中检测攻击者部署的持久化机制
关于FarsightAD FarsightAD是一款功能强大的PowerShell脚本,该工具可以帮助广大研究人员在活动目录域遭受到渗透攻击之后,检测到由攻击者部署的持久化机制。 该脚本能够生成并导出各种对象及其属性的CSV/JSON文件,并附带从元数据副本中获取…...
Python - 操作txt文件
文章目录打开txt文件读取txt文件写入txt文件删除txt文件打开txt文件 open(file, moder, bufferingNone, encodingNone, errorsNone, newlineNone, closefdTrue)函数用来打开txt文件。 #方法1,这种方式使用后需要关闭文件 f open("data.txt","r&qu…...
新网站怎么发外链/电脑培训机构
将我正在制作的游戏应用程序导出为可运行的.jar文件后,我从命令提示符处运行了它,只是发现了一些问题。 我不得不解决一些问题,找不到音频路径,也无法读取/写入.txt文件。 我看过并能够解决的音频文件,我看过并接近解决…...
建设网站免费/站长工具seo综合查询官网
她, 一岁了, 可爱的大白; 她, 一个女孩, 活泼的兔兔; 她, 一个贪吃鬼, 胖胖的吃货; 虽叫兔兔但并非兔子, 是一只龙猫! 首先先介绍下龙猫的由来: 由于毛丝鼠的长相与日本动画大师宫崎俊的动画片《龙猫》中的主人公龙猫太郎相似,故后被港台称之为龙猫. 毛丝鼠…...
做网站需要招什么条件/百度竞价推广开户联系方式
对于一个普通用户来说,你的项目就是一组简单的视图集合,用户直接通过跟视图进行交互来操作你的应用,对于一个开发人员来说,视图是一个项目的入口,虽然大部分时候最有价值的部分是在model层和control层,所以…...
北京网站建设的价格天/淄博网络推广公司哪家好
早晨起床时间:8:20 晚上休息时间:23:28 全天处理事件:1.监督小妹复习前一周学的内容和布置的作业。2.调试移植的nrf24l01模块的驱动程序,使其工作正常。 处事经验总结:经过这三天的调试,最总发现是模块的供…...
中国平湖首页规划建设局网站/成都seo优化排名公司
1:单元测试要求至少达到语句覆盖。 2:单元测试开始要跟踪每一条语句,并观察数据流及变量的变化。 3:清理、整理或优化后的代码要经过审查及测试。 4:代码版本升级要经过严格测试。 5:使用工具软件对代码版本…...
深圳小程序网站开发公司/关键词异地排名查询
2019独角兽企业重金招聘Python工程师标准>>> for /r E:\XX %i in (CVS) do rd /s /q %i 上述命令的含义是 删除E盘下XX问价下的CVS文件夹 转载于:https://my.oschina.net/u/1262156/blog/363242...