当前位置: 首页 > news >正文

VOLO实战:使用VOLO实现图像分类任务(二)

文章目录

  • 训练部分
    • 导入项目使用的库
    • 设置随机因子
    • 设置全局参数
    • 图像预处理与增强
    • 读取数据
    • 设置Loss
    • 设置模型
    • 设置优化器和学习率调整策略
    • 设置混合精度,DP多卡,EMA
    • 定义训练和验证函数
      • 训练函数
      • 验证函数
      • 调用训练和验证方法
  • 运行以及结果查看
  • 测试
  • 完整的代码

在上一篇文章中完成了前期的准备工作,见链接:
VOLO实战:使用VOLO实现图像分类任务(一)
前期的工作主要是数据的准备,安装库文件,数据增强方式的讲解,模型的介绍和实验效果等内容。接下来,这篇主要是讲解如何训练和测试

训练部分

完成上面的步骤后,就开始train脚本的编写,新建train.py

导入项目使用的库

在train.py导入

import json
import os
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
import torch.nn.parallel
import torch.optim as optim
import torch.utils.data
import torch.utils.data.distributed
import torchvision.transforms as transforms
from timm.utils import accuracy, AverageMeter, ModelEma
from sklearn.metrics import classification_report
from timm.data.mixup import Mixup
from timm.loss import SoftTargetCrossEntropy
from models.volo import volo_d1
from torchvision import datasetstorch.backends.cudnn.benchmark = False
import warningswarnings.filterwarnings("ignore")
os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

当您需要在具有多个GPU的机器上指定用于训练的GPU时,可以通过设置环境变量CUDA_VISIBLE_DEVICES来实现。这个环境变量的值是一个由逗号分隔的GPU索引列表,索引从0开始。例如,如果您的机器上有8块GPU,并且您希望仅使用前两块GPU(即索引为0和1的GPU)进行训练,您应该设置:

os.environ['CUDA_VISIBLE_DEVICES'] = "0,1"

这样,只有索引为0和1的GPU会被系统识别并用于训练。类似地,如果您希望使用第三块(索引为2)和第六块(索引为5)GPU进行训练,您应该相应地设置:

os.environ['CUDA_VISIBLE_DEVICES'] = "2,5"

通过这种方式,您可以灵活地选择任意数量的GPU进行训练,而无需担心其他GPU的干扰。

设置随机因子

def seed_everything(seed=42):# 设置Python的哈希种子os.environ['PYTHONHASHSEED'] = str(seed)# 设置PyTorch的CPU随机种子torch.manual_seed(seed)# 如果使用CUDA,设置CUDA的随机种子if torch.cuda.is_available():torch.cuda.manual_seed(seed)torch.cuda.manual_seed_all(seed)  # 如果你的代码在多个GPU上运行# 启用CUDA的确定性行为(对卷积等操作的确定性有帮助)torch.backends.cudnn.benchmark = Falsetorch.backends.cudnn.deterministic = True# 使用示例
seed_everything(42)

这里有一些额外的说明和注意事项:

  1. torch.cuda.manual_seed_all(seed):这个调用是可选的,但如果你在多GPU环境中工作(比如使用DataParallelDistributedDataParallel),它确保所有GPU上的随机操作都将从相同的种子开始。如果你的代码只在一个GPU上运行,这个调用不是必需的,但也不会造成问题。

  2. torch.backends.cudnn.benchmark = False:当设置为True时,cuDNN会在运行时自动选择算法来优化性能。然而,这可能会导致每次运行时的行为不完全相同,因为算法的选择可能会基于输入数据的形状和大小而变化。为了实验的可重复性,最好将其设置为False

  3. 图片加载顺序:虽然设置随机种子有助于确保模型的随机操作(如初始化权重、dropout等)是可重复的,但它本身并不直接控制图片加载的顺序。图片加载顺序通常由数据集加载器(如DataLoader)的shuffle参数控制。如果你想要固定的加载顺序,确保在创建DataLoader时将shuffle=False

  4. 其他随机性来源:请注意,即使你设置了这些随机种子,还可能存在其他随机性来源,如操作系统级别的调度或硬件层面的差异(如GPU的浮点精度差异)。在极端情况下,这些差异可能会影响结果的精确可重复性。然而,在大多数情况下,上述设置应该足以确保实验在相同的软件和环境配置下是可重复的。

设置全局参数

if __name__ == '__main__':# 创建保存模型的文件夹file_dir = 'checkpoints/VOLO/'if os.path.exists(file_dir):print('true')os.makedirs(file_dir, exist_ok=True)else:os.makedirs(file_dir)# 设置全局参数model_lr = 1e-4BATCH_SIZE = 16EPOCHS = 300DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')use_amp = True  # 是否使用混合精度use_dp = True  # 是否开启dp方式的多卡训练classes = 12resume = NoneCLIP_GRAD = 5.0Best_ACC = 0  # 记录最高得分use_ema = Falsemodel_ema_decay = 0.9998start_epoch = 1seed = 1seed_everything(seed)

创建一个名为 ‘checkpoints/VOLO/’ 的文件夹,用于保存训练过程中的模型。如果该文件夹已经存在,则不会再次创建,否则会创建该文件夹。

设置训练模型的全局参数,包括学习率、批次大小、训练轮数、设备选择(是否使用 GPU)、是否使用混合精度、是否开启数据并行等。

注:建议使用GPU,CPU太慢了。

参数的详细解释:

model_lr:学习率,根据实际情况做调整。

BATCH_SIZE:batchsize,根据显卡的大小设置。

EPOCHS:epoch的个数,一般300够用。

use_amp:是否使用混合精度。

use_dp :是否开启dp方式的多卡训练?如果您打算使用多GPU训练将use_dp 设置为 True。

classes:类别个数。

resume:再次训练的模型路径,如果不为None,则表示加载resume指向的模型继续训练。

CLIP_GRAD:梯度的最大范数,在梯度裁剪里设置。

Best_ACC:记录最高ACC得分。

use_ema:是否使用ema,如果没有使用预训练模型,直接打开use_ema会造成不上分的情况。可以先关闭ema训练几个epoch,然后,将训练的权重赋值到resume,再将启用ema

model_ema_decay:设置了EMA的衰减率。衰减率决定了当前模型权重和之前的EMA权重在更新新的EMA权重时的相对贡献。具体来说,每次更新EMA权重时,都会按照以下公式进行:
newemaweight = decay × oldemaweight + ( 1 − decay ) × currentmodelweight \text{newemaweight} = \text{decay} \times \text{oldemaweight} + (1 - \text{decay}) \times \text{currentmodelweight} newemaweight=decay×oldemaweight+(1decay)×currentmodelweight
例如,衰减率被设置为0.9998。这意味着在更新EMA权重时,大约99.98%的权重来自之前的EMA权重,而剩下的0.02%来自当前的模型权重。由于衰减率非常接近1,EMA权重会更多地依赖于之前的EMA权重,而不是当前的模型权重。这有助于平滑模型权重的波动,并减少噪声对最终模型性能的影响。

start_epoch:开始的epoch,默认是1,如果重新训练时,需要给start_epoch重新赋值。

SEED:随机因子,数值可以随意设定,但是设置后,不要随意更改,更改后,图片加载的顺序会改变,影响测试结果。

  file_dir = 'checkpoints/VOLO/'

这是存放VOLO模型的路径。

图像预处理与增强

   # 数据预处理7transform = transforms.Compose([transforms.RandomRotation(10),transforms.GaussianBlur(kernel_size=(5,5),sigma=(0.1, 3.0)),transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5),transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.3281186, 0.28937867, 0.20702125], std= [0.09407319, 0.09732835, 0.106712654])])transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.3281186, 0.28937867, 0.20702125], std= [0.09407319, 0.09732835, 0.106712654])])mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=classes)

数据处理和增强比较简单,加入了随机10度的旋转、高斯模糊、色彩饱和度明亮度的变化、Mixup等比较常用的增强手段,做了Resize和归一化。

 transforms.Normalize(mean=[0.3281186, 0.28937867, 0.20702125], std= [0.09407319, 0.09732835, 0.106712654])

这里设置为计算mean和std。
这里注意下Resize的大小,由于选用的模型输入是224×224的大小,所以要Resize为224×224。

数据预处理流程结合了多种常用的数据增强技术,包括随机旋转、高斯模糊、色彩抖动(ColorJitter)、Resize以及归一化,还引入了Mixup和可能的CutMix技术来进一步增强模型的泛化能力。参数详解:

  • transforms.RandomRotation(10): 随机旋转图像最多10度,有助于模型学习旋转不变性。
  • transforms.GaussianBlur(kernel_size=(5,5), sigma=(0.1, 3.0)): 应用高斯模糊,模拟图像的模糊情况,增强模型对模糊图像的鲁棒性。
  • transforms.ColorJitter(brightness=0.5, contrast=0.5, saturation=0.5): 调整图像的亮度、对比度和饱和度,增加数据的多样性。
  • transforms.Resize((224, 224)): 将图像大小调整为224x224,以符合模型的输入要求。
  • transforms.ToTensor(): 将PIL Image或NumPy ndarray转换为FloatTensor,并归一化到[0.0, 1.0]。
  • transforms.Normalize(mean, std): 使用指定的均值和标准差对图像进行归一化处理,有助于模型训练。
 mixup_fn = Mixup(mixup_alpha=0.8, cutmix_alpha=1.0, cutmix_minmax=None,prob=0.1, switch_prob=0.5, mode='batch',label_smoothing=0.1, num_classes=classes)

定义了一个 Mixup 函数。Mixup 是一种在图像分类任务中常用的数据增强技术,它通过将两张图像以及其对应的标签进行线性组合来生成新的数据和标签。

Mixup 是一种正则化技术,通过混合输入数据和它们的标签来增强模型的泛化能力。在您的代码中,Mixup 类还包含了 CutMix 的参数,但具体实现可能需要根据您使用的库(如 timm 或自定义实现)来确定。参数详解:

mixup_alpha: Mixup 中用于Beta分布的α参数,控制混合强度的分布。 cutmix_alpha: CutMix
中用于Beta分布的α参数,同样控制混合强度的分布。 cutmix_minmax: CutMix 中裁剪区域的最小和最大比例,但在这里设为
None,可能表示使用默认的或根据 cutmix_alpha 自动计算的比例。 prob: 应用Mixup或CutMix的概率。
switch_prob: 在Mixup和CutMix之间切换的概率(如果Mixup和CutMix都被启用)。 mode:
指定Mixup是在整个批次上进行还是在单个样本之间进行。 label_smoothing: 标签平滑参数,用于减少模型对硬标签的过度自信。
num_classes: 类别数,用于标签平滑计算。

读取数据

   # 读取数据dataset_train = datasets.ImageFolder('data/train', transform=transform)dataset_test = datasets.ImageFolder("data/val", transform=transform_test)with open('class.txt', 'w') as file:file.write(str(dataset_train.class_to_idx))with open('class.json', 'w', encoding='utf-8') as file:file.write(json.dumps(dataset_train.class_to_idx))# 导入数据train_loader = torch.utils.data.DataLoader(dataset_train, batch_size=BATCH_SIZE,num_workers=8,pin_memory=True,shuffle=True,drop_last=True)test_loader = torch.utils.data.DataLoader(dataset_test, batch_size=BATCH_SIZE, shuffle=False)
  • 使用pytorch默认读取数据的方式,然后将dataset_train.class_to_idx打印出来,预测的时候要用到。

  • 对于train_loader ,drop_last设置为True,因为使用了Mixup数据增强,必须保证每个batch里面的图片个数为偶数(不能为零),如果最后一个batch里面的图片为奇数,则会报错,所以舍弃最后batch的迭代,pin_memory设置为True,可以加快运行速度,num_workers多进程加载图像,不要超过CPU 的核数。

  • 将dataset_train.class_to_idx保存到txt文件或者json文件中。

class_to_idx的结果:

{'Black-grass': 0, 'Charlock': 1, 'Cleavers': 2, 'Common Chickweed': 3, 'Common wheat': 4, 'Fat Hen': 5, 'Loose Silky-bent': 6, 'Maize': 7, 'Scentless Mayweed': 8, 'Shepherds Purse': 9, 'Small-flowered Cranesbill': 10, 'Sugar beet': 11}

设置Loss

# 设置loss函数  
# 训练的loss函数为SoftTargetCrossEntropy,用于处理具有软目标(soft targets)的训练场景  
criterion_train = SoftTargetCrossEntropy()  # 验证的loss函数为nn.CrossEntropyLoss(),适用于多分类问题的标准交叉熵损失  
criterion_val = torch.nn.CrossEntropyLoss() 

设置loss函数,训练的loss为:SoftTargetCrossEntropy,验证的loss:nn.CrossEntropyLoss()。

设置模型

    # 设置模型model_ft = volo_d1(pretrained=True)print(model_ft)num_fr = model_ft.head.in_featuresmodel_ft.head = nn.Linear(num_fr, classes)num_fr = model_ft.aux_head.in_featuresmodel_ft.aux_head = nn.Linear(num_fr, classes)print(model_ft)if resume:model = torch.load(resume)print(model['state_dict'].keys())model_ft.load_state_dict(model['state_dict'])Best_ACC = model['Best_ACC']start_epoch = model['epoch'] + 1model_ft.to(DEVICE)
  • 设置模型为volo_d1,然后,找到head的in_features,修改为数据集的类别,也就是classes。

  • 如果resume设置为已经训练的模型的路径,则加载模型接着resume指向的模型接着训练,使用模型里的Best_ACC初始化Best_ACC,使用epoch参数初始化start_epoch。

  • 如果模型输出是classes的长度,则表示修改正确了。

在这里插入图片描述

设置优化器和学习率调整策略

   # 选择简单暴力的Adam优化器,学习率调低optimizer = optim.AdamW(model_ft.parameters(),lr=model_lr)cosine_schedule = optim.lr_scheduler.CosineAnnealingLR(optimizer=optimizer, T_max=20, eta_min=1e-6)
  • 优化器设置为adamW。
  • 学习率调整策略选择为余弦退火。

设置混合精度,DP多卡,EMA

    if use_amp:scaler = torch.cuda.amp.GradScaler()if torch.cuda.device_count() > 1 and use_dp:print("Let's use", torch.cuda.device_count(), "GPUs!")model_ft = torch.nn.DataParallel(model_ft)if use_ema:model_ema = ModelEma(model_ft,decay=model_ema_decay,device=DEVICE,resume=resume)else:model_ema=None

定义训练和验证函数

训练函数

# 定义训练过程
def train(model, device, train_loader, optimizer, epoch,model_ema):model.train()loss_meter = AverageMeter()acc1_meter = AverageMeter()acc5_meter = AverageMeter()total_num = len(train_loader.dataset)print(total_num, len(train_loader))for batch_idx, (data, target) in enumerate(train_loader):data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)samples, targets = mixup_fn(data, target)output = model(samples)[0]optimizer.zero_grad()if use_amp:with torch.cuda.amp.autocast():loss = torch.nan_to_num(criterion_train(output, targets))scaler.scale(loss).backward()torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)# Unscales gradients and calls# or skips optimizer.step()scaler.step(optimizer)# Updates the scale for next iterationscaler.update()else:loss = criterion_train(output, targets)torch.nn.utils.clip_grad_norm_(model.parameters(), CLIP_GRAD)loss.backward()optimizer.step()if model_ema is not None:model_ema.update(model)lr = optimizer.state_dict()['param_groups'][0]['lr']loss_meter.update(loss.item(), target.size(0))acc1, acc5 = accuracy(output, target, topk=(1, 5))acc1_meter.update(acc1.item(), target.size(0))acc5_meter.update(acc5.item(), target.size(0))if (batch_idx + 1) % 10 == 0:print('Train Epoch: {} [{}/{} ({:.0f}%)]\tLoss: {:.6f}\tLR:{:.9f}'.format(epoch, (batch_idx + 1) * train_loader.batch_size, len(train_loader.dataset),100. * (batch_idx + 1) / len(train_loader), loss.item(), lr))ave_loss =loss_meter.avgacc = acc1_meter.avgprint('epoch:{}\tloss:{:.2f}\tacc:{:.2f}'.format(epoch, ave_loss, acc))return ave_loss, acc

训练的主要步骤:

1、使用AverageMeter保存自定义变量,包括loss,ACC1,ACC5。

2、进入循环,将data和target放入device上,non_blocking设置为True。如果pin_memory=True的话,将数据放入GPU的时候,也应该把non_blocking打开,这样就只把数据放入GPU而不取出,访问时间会大大减少。
如果pin_memory=False时,则将non_blocking设置为False。

3、将数据输入mixup_fn生成mixup数据。

4、将第三部生成的mixup数据输入model,输出预测结果,然后再计算loss。

5、 optimizer.zero_grad() 梯度清零,把loss关于weight的导数变成0。

6、如果使用混合精度,则

  • with torch.cuda.amp.autocast(),开启混合精度。
  • 计算loss。torch.nan_to_num将输入中的NaN、正无穷大和负无穷大替换为NaN、posinf和neginf。默认情况下,nan会被替换为零,正无穷大会被替换为输入的dtype所能表示的最大有限值,负无穷大会被替换为输入的dtype所能表示的最小有限值。
  • scaler.scale(loss).backward(),梯度放大。
  • torch.nn.utils.clip_grad_norm_,梯度裁剪,放置梯度爆炸。
  • scaler.step(optimizer) ,首先把梯度值unscale回来,如果梯度值不是inf或NaN,则调用optimizer.step()来更新权重,否则,忽略step调用,从而保证权重不更新。
  • 更新下一次迭代的scaler。

否则,直接反向传播求梯度。torch.nn.utils.clip_grad_norm_函数执行梯度裁剪,防止梯度爆炸。

7、如果use_ema为True,则执行model_ema的updata函数,更新模型。

8、 torch.cuda.synchronize(),等待上面所有的操作执行完成。

9、接下来,更新loss,ACC1,ACC5的值。

等待一个epoch训练完成后,计算平均loss和平均acc

验证函数

# 验证过程
@torch.no_grad()
def val(model, device, test_loader):global Best_ACCmodel.eval()loss_meter = AverageMeter()acc1_meter = AverageMeter()acc5_meter = AverageMeter()total_num = len(test_loader.dataset)print(total_num, len(test_loader))val_list = []pred_list = []for data, target in test_loader:for t in target:val_list.append(t.data.item())data, target = data.to(device, non_blocking=True), target.to(device, non_blocking=True)output = model(data)loss = criterion_val(output, target)_, pred = torch.max(output.data, 1)for p in pred:pred_list.append(p.data.item())acc1, acc5 = accuracy(output, target, topk=(1, 5))loss_meter.update(loss.item(), target.size(0))acc1_meter.update(acc1.item(), target.size(0))acc5_meter.update(acc5.item(), target.size(0))acc = acc1_meter.avgprint('\nVal set: Average loss: {:.4f}\tAcc1:{:.3f}%\tAcc5:{:.3f}%\n'.format(loss_meter.avg, acc, acc5_meter.avg))if acc > Best_ACC:if isinstance(model, torch.nn.DataParallel):torch.save(model.module, file_dir + '/' + 'best.pth')else:torch.save(model, file_dir + '/' + 'best.pth')Best_ACC = accif isinstance(model, torch.nn.DataParallel):state = {'epoch': epoch,'state_dict': model.module.state_dict(),'Best_ACC': Best_ACC}if use_ema:state['state_dict_ema'] = model.module.state_dict()torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')else:state = {'epoch': epoch,'state_dict': model.state_dict(),'Best_ACC': Best_ACC}if use_ema:state['state_dict_ema'] = model.state_dict()torch.save(state, file_dir + "/" + 'model_' + str(epoch) + '_' + str(round(acc, 3)) + '.pth')return val_list, pred_list, loss_meter.avg, acc

验证集和训练集大致相似,主要步骤:

1、在val的函数上面添加@torch.no_grad(),作用:所有计算得出的tensor的requires_grad都自动设置为False。即使一个tensor(命名为x)的requires_grad = True,在with torch.no_grad计算,由x得到的新tensor(命名为w-标量)requires_grad也为False,且grad_fn也为None,即不会对w求导。

2、定义参数:
loss_meter: 测试的loss
acc1_meter:top1的ACC。
acc5_meter:top5的ACC。
total_num:总的验证集的数量。
val_list:验证集的label。
pred_list:预测的label。

3、进入循环,迭代test_loader:

将label保存到val_list。

将data和target放入device上,non_blocking设置为True。

将data输入到model中,求出预测值,然后输入到loss函数中,求出loss。

调用torch.max函数,将预测值转为对应的label。

将输出的预测值的label存入pred_list。

调用accuracy函数计算ACC1和ACC5

更新loss_meter、acc1_meter、acc5_meter的参数。

4、本次epoch循环完成后,求得本次epoch的acc、loss。
5、接下来是保存模型的逻辑
如果ACC比Best_ACC高,则保存best模型
判断模型是否为DP方式训练的模型。

如果是DP方式训练的模型,模型参数放在model.module,则需要保存model.module。
否则直接保存model。
注:保存best模型,我们采用保存整个模型的方式,这样保存的模型包含网络结构,在预测的时候,就不用再重新定义网络了。

6、接下来保存每个epoch的模型。
判断模型是否为DP方式训练的模型。

如果是DP方式训练的模型,模型参数放在model.module,则需要保存model.module.state_dict()。

新建个字典,放置Best_ACC、epoch和 model.module.state_dict()等参数。然后将这个字典保存。判断是否是使用EMA,如果使用,则还需要保存一份ema的权重。
否则,新建个字典,放置Best_ACC、epoch和 model.state_dict()等参数。然后将这个字典保存。判断是否是使用EMA,如果使用,则还需要保存一份ema的权重。

注意:对于每个epoch的模型只保存了state_dict参数,没有保存整个模型文件。

调用训练和验证方法

    # 训练与验证is_set_lr = Falselog_dir = {}train_loss_list, val_loss_list, train_acc_list, val_acc_list, epoch_list = [], [], [], [], []if resume and os.path.isfile(file_dir+"result.json"):with open(file_dir+'result.json', 'r', encoding='utf-8') as file:logs = json.load(file)train_acc_list = logs['train_acc']train_loss_list = logs['train_loss']val_acc_list = logs['val_acc']val_loss_list = logs['val_loss']epoch_list = logs['epoch_list']for epoch in range(start_epoch, EPOCHS + 1):epoch_list.append(epoch)log_dir['epoch_list'] = epoch_listtrain_loss, train_acc = train(model_ft, DEVICE, train_loader, optimizer, epoch,model_ema)train_loss_list.append(train_loss)train_acc_list.append(train_acc)log_dir['train_acc'] = train_acc_listlog_dir['train_loss'] = train_loss_listif use_ema:val_list, pred_list, val_loss, val_acc = val(model_ema.ema, DEVICE, test_loader)else:val_list, pred_list, val_loss, val_acc = val(model_ft, DEVICE, test_loader)val_loss_list.append(val_loss)val_acc_list.append(val_acc)log_dir['val_acc'] = val_acc_listlog_dir['val_loss'] = val_loss_listlog_dir['best_acc'] = Best_ACCwith open(file_dir + '/result.json', 'w', encoding='utf-8') as file:file.write(json.dumps(log_dir))print(classification_report(val_list, pred_list, target_names=dataset_train.class_to_idx))if epoch < 600:cosine_schedule.step()else:if not is_set_lr:for param_group in optimizer.param_groups:param_group["lr"] = 1e-6is_set_lr = Truefig = plt.figure(1)plt.plot(epoch_list, train_loss_list, 'r-', label=u'Train Loss')# 显示图例plt.plot(epoch_list, val_loss_list, 'b-', label=u'Val Loss')plt.legend(["Train Loss", "Val Loss"], loc="upper right")plt.xlabel(u'epoch')plt.ylabel(u'loss')plt.title('Model Loss ')plt.savefig(file_dir + "/loss.png")plt.close(1)fig2 = plt.figure(2)plt.plot(epoch_list, train_acc_list, 'r-', label=u'Train Acc')plt.plot(epoch_list, val_acc_list, 'b-', label=u'Val Acc')plt.legend(["Train Acc", "Val Acc"], loc="lower right")plt.title("Model Acc")plt.ylabel("acc")plt.xlabel("epoch")plt.savefig(file_dir + "/acc.png")plt.close(2)

调用训练函数和验证函数的主要步骤:

1、定义参数:

  • is_set_lr,是否已经设置了学习率,当epoch大于一定的次数后,会将学习率设置到一定的值,并将其置为True。
  • log_dir:记录log用的,将有用的信息保存到字典中,然后转为json保存起来。
  • train_loss_list:保存每个epoch的训练loss。
  • val_loss_list:保存每个epoch的验证loss。
  • train_acc_list:保存每个epoch的训练acc。
  • val_acc_list:保存么每个epoch的验证acc。
  • epoch_list:存放每个epoch的值。

如果是接着上次的断点继续训练则读取log文件,然后把log取出来,赋值到对应的list上。
循环epoch

1、调用train函数,得到 train_loss, train_acc,并将分别放入train_loss_list,train_acc_list,然后存入到logdir字典中。

2、调用验证函数,判断是否使用EMA?
如果使用EMA,则传入model_ema.ema,否则,传入model_ft。得到val_list, pred_list, val_loss, val_acc。将val_loss, val_acc分别放入val_loss_list和val_acc_list中,然后存入到logdir字典中。

3、保存log。

4、打印本次的测试报告。

5、如果epoch大于600,将学习率设置为固定的1e-6。

6、绘制loss曲线和acc曲线。

运行以及结果查看

完成上面的所有代码就可以开始运行了。点击右键,然后选择“run train.py”即可,运行结果如下:

在这里插入图片描述

在每个epoch测试完成之后,打印验证集的acc、recall等指标。

VOLO测试结果:

在这里插入图片描述
在这里插入图片描述

测试

测试,我们采用一种通用的方式。

测试集存放的目录如下图:

VOLO_Demo
├─test
│  ├─1.jpg
│  ├─2.jpg
│  ├─3.jpg
│  ├ ......
└─test.py
import torch.utils.data.distributed
import torchvision.transforms as transforms
from PIL import Image
from torch.autograd import Variable
import osclasses = ('Black-grass', 'Charlock', 'Cleavers', 'Common Chickweed','Common wheat', 'Fat Hen', 'Loose Silky-bent','Maize', 'Scentless Mayweed', 'Shepherds Purse', 'Small-flowered Cranesbill', 'Sugar beet')
transform_test = transforms.Compose([transforms.Resize((224, 224)),transforms.ToTensor(),transforms.Normalize(mean=[0.44127703, 0.4712498, 0.43714803], std=[0.18507297, 0.18050247, 0.16784933])
])DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model=torch.load('checkpoints/VOLO/best.pth')
model.eval()
model.to(DEVICE)path = 'test/'
testList = os.listdir(path)
for file in testList:img = Image.open(path + file)img = transform_test(img)img.unsqueeze_(0)img = Variable(img).to(DEVICE)out = model(img)# Predict_, pred = torch.max(out.data, 1)print('Image Name:{},predict:{}'.format(file, classes[pred.data.item()]))

测试的主要逻辑:

1、定义类别,这个类别的顺序和训练时的类别顺序对应,一定不要改变顺序!!!!

2、定义transforms,transforms和验证集的transforms一样即可,别做数据增强。

3、 torch.load加载model,然后将模型放在DEVICE里,

4、循环 读取图片并预测图片的类别,在这里注意,读取图片用PIL库的Image。不要用cv2,transforms不支持。循环里面的主要逻辑:

  • 使用Image.open读取图片
  • 使用transform_test对图片做归一化和标椎化。
  • img.unsqueeze_(0) 增加一个维度,由(3,224,224)变为(1,3,224,224)
  • Variable(img).to(DEVICE):将数据放入DEVICE中。
  • model(img):执行预测。
  • _, pred = torch.max(out.data, 1):获取预测值的最大下角标。

运行结果:

在这里插入图片描述

完整的代码

完整的代码:

https://download.csdn.net/download/hhhhhhhhhhwwwwwwwwww/90033922

相关文章:

VOLO实战:使用VOLO实现图像分类任务(二)

文章目录 训练部分导入项目使用的库设置随机因子设置全局参数图像预处理与增强读取数据设置Loss设置模型设置优化器和学习率调整策略设置混合精度&#xff0c;DP多卡&#xff0c;EMA定义训练和验证函数训练函数验证函数调用训练和验证方法 运行以及结果查看测试完整的代码 在上…...

【kafka02】消息队列与微服务之Kafka部署

Kafka 部署 Kafka 部署说明 kafka 版本选择 kafka 基于scala语言实现,所以使用kafka需要指定scala的相应的版本.kafka 为多个版本的Scala构建。这仅在使用 Scala 时才重要&#xff0c;并且希望为使用的相同 Scala 版本构建一个版本。否则&#xff0c;任何版本都可以 kafka下…...

MySQL系列之数据类型(Numeric)

导览 前言一、数值类型综述二、数值类型详解1. NUMERIC1.1 UNSIGNED或SIGNED1.2 数据类型划分 2. Integer类型取值和存储要求3. Fixed-Point类型取值和存储要求4. Floating-Point类型取值和存储要求 结语精彩回放 前言 MySQL系列最近三篇均关注了和我们日常工作或学习密切相关…...

BERT简单理解;双向编码器优势

目录 BERT简单理解 一、BERT模型简单理解 二、BERT模型使用举例 三、BERT模型的优势 双向编码器优势 BERT简单理解 (Bidirectional Encoder Representations from Transformers)模型是一种预训练的自然语言处理(NLP)模型,由Google于2018年推出。以下是对BERT模型的简…...

LLamafactory 批量推理与异步 API 调用效率对比实测

背景 在阅读 LLamafactory 的文档时候&#xff0c;发现它支持批量推理: 推理.https://llamafactory.readthedocs.io/zh-cn/latest/getting_started/inference.html 。 于是便想测试一下&#xff0c;它的批量推理速度有多快。本文实现了 下述两种的大模型推理&#xff0c;并对…...

spf算法、三类LSA、区间防环路机制/规则、虚连接

1.构建spf树&#xff1a; 路由器将自己作为最短路经树的树根根据Router-LSA和Network-LSA中的拓扑信息,依次将Cost值最小的路由器添加到SPF树中。路由器以Router ID或者DR标识。广播网络中DR和其所连接路由器的Cost值为0。SPF树中只有单向的最短路径,保证了OSPF区域内路由计管不…...

C语言学习 12(指针学习1)

一.内存和地址 1.内存 在讲内存和地址之前&#xff0c;我们想有个⽣活中的案例&#xff1a; 假设有⼀栋宿舍楼&#xff0c;把你放在楼⾥&#xff0c;楼上有100个房间&#xff0c;但是房间没有编号&#xff0c;你的⼀个朋友来找你玩&#xff0c;如果想找到你&#xff0c;就得挨…...

TypeError: issubclass() arg 1 must be a class

TypeError: issubclass() arg 1 must be a class 报错代码&#xff1a; import spacy 原因&#xff1a; 库版本错误&#xff0c; 解决方法&#xff1a; pip install typing-inspect0.8.0 typing_extensions4.5.0 感谢作者&#xff1a; langchain TypeError: issubclass() …...

Java面试题、八股文学习之JVM篇

1.对象一定分配在堆中吗&#xff1f;有没有了解逃逸分析技术&#xff1f; 对象不一定总是分配在堆中。在Java等一些高级编程语言中&#xff0c;对象的分配位置可以通过编译器或运行时系统的优化来决定。其中&#xff0c;逃逸分析&#xff08;Escape Analysis&#xff09;是用于…...

【eNSP】动态路由协议RIP和OSPF

动态路由RIP&#xff08;Routing Information Protocol&#xff0c;路由信息协议&#xff09;和OSPF&#xff08;Open Shortest Path First&#xff0c;开放式最短路径优先&#xff09;是两种常见的动态路由协议&#xff0c;它们各自具有不同的特点和使用场景。本篇会对这两种协…...

春秋云境 CVE 复现

CVE-2022-4230 靶标介绍 WP Statistics WordPress 插件13.2.9之前的版本不会转义参数&#xff0c;这可能允许经过身份验证的用户执行 SQL 注入攻击。默认情况下&#xff0c;具有管理选项功能 (admin) 的用户可以使用受影响的功能&#xff0c;但是该插件有一个设置允许低权限用…...

Linux入门攻坚——39、Nginx入门

Nginx&#xff1a;engine X Tengine&#xff1a;淘宝改进维护的版本 Registry&#xff1a; 使用了libevent库&#xff1a;高性能的网络库 epoll()函数 Nginx特性&#xff1a; 模块化设计、较好的扩展性&#xff1b;&#xff08;但不支持动态加载模块功能&#…...

计算机网络的类型

目录 按覆盖范围分类 个人区域网&#xff08;PAN&#xff09; 局域网&#xff08;LAN&#xff09; 城域网&#xff08;MAN&#xff09; 4. 广域网&#xff08;WAN&#xff09; 按使用场景和性质分类 公网&#xff08;全球网络&#xff09; 外网 内网&#xff08;私有网…...

解决 MySQL 5.7 安装中的常见问题及解决方案

目录 前言1. 安装MySQL 5.7时的常见错误分析1.1 错误原因及表现1.2 错误的根源 2. 解决方案2.1 修改YUM仓库配置2.2 重新尝试安装2.3 处理GPG密钥错误2.4 解决依赖包问题 3. 安装成功后的配置3.1 启动MySQL服务3.2 获取临时密码3.3 修改root密码 4. 结语 前言 在Linux服务器上…...

VITE+VUE3+TS环境搭建

前言&#xff08;与搭建项目无关&#xff09;&#xff1a; 可以安装一个node管理工具&#xff0c;比如nvm&#xff0c;这样可以顺畅的切换vue2和vue3项目&#xff0c;以免出现项目跑不起来的窘境。我使用的nvm&#xff0c;当前node 22.11.0 目录 搭建项目 添加状态管理库&…...

【设计模式】【创建型模式(Creational Patterns)】之原型模式(Prototype Pattern)

1. 设计模式原理说明 原型模式&#xff08;Prototype Pattern&#xff09; 是一种创建型设计模式&#xff0c;它允许你通过复制现有对象来创建新对象&#xff0c;而无需通过构造函数来创建。这种方式可以提高性能&#xff0c;尤其是在对象初始化需要消耗大量资源或耗时较长的情…...

黄仁勋:人形机器人在内,仅有三种机器人有望实现大规模生产

11月23日&#xff0c;芯片巨头、AI时代“卖铲人”和最大受益者、全球市值最高【英伟达】创始人兼CEO黄仁勋在香港科技大学被授予工程学荣誉博士学位&#xff1b;并与香港科技大学校董会主席沈向洋展开深刻对话&#xff0c;涉及人工智能&#xff08;AI&#xff09;、计算力、领导…...

【C语言】宏定义详解

C语言中的宏定义&#xff08;#define&#xff09;详细解析 在C语言中&#xff0c;宏定义是一种预处理指令&#xff0c;使用 #define 关键字定义。它由预处理器&#xff08;Preprocessor&#xff09;在编译前处理&#xff0c;用于定义常量、代码片段或函数样式的代码替换。宏是…...

LangChain——多向量检索器

每个文档存储多个向量通常是有益的。在许多用例中&#xff0c;这是有益的。 LangChain 有一个基础 MultiVectorRetriever &#xff0c;这使得查询此类设置变得容易。很多复杂性在于如何为每个文档创建多个向量。本笔记本涵盖了创建这些向量和使用 MultiVectorRetriever 的一些常…...

《岩石学报》

本刊主要报道有关岩石学基础理论的岩石学领域各学科包括岩浆岩石学、变质岩石学、沉积岩石学、岩石大地构造学、岩石同位素年代学和同位素地球化学、岩石成矿学、造岩矿物学等方面的重要基础理论和应用研究成果&#xff0c;同时也刊载综述性文章、问题讨论、学术动态以及书评等…...

数据结构 (12)串的存储实现

一、顺序存储结构 顺序存储结构是用一组连续的存储单元来存储串中的字符序列。这种存储方式类似于线性表的顺序存储结构&#xff0c;但串的存储对象仅限于字符。顺序存储结构又可以分为定长顺序存储和堆分配存储两种方式。 定长顺序存储&#xff1a; 使用静态数组存储&#xff…...

职场发展陷阱

一、只有执行&#xff0c;没有思考 二、只有过程&#xff0c;没有结果 三、只有重复&#xff0c;没有精进 四、不懂向上管理 五、定期汇报 六、不要憋大招 七、多同步信息...

Xcode15(iOS17.4)打包的项目在 iOS12 系统上启动崩溃

0x00 启动崩溃 崩溃日志&#xff0c;只有 2 行&#xff0c;看不出啥来。 0x01 默认配置 由于我开发时&#xff0c;使用的 Xcode 14.1&#xff0c;打包在另外一台电脑 Xcode 15.3 Xcode 14.1 Build Settings -> Asset Catalog Compliter - Options Xcode 15.3 Build S…...

极狐GitLab 17.6 正式发布几十项与 DevSecOps 相关的功能【二】

GitLab 是一个全球知名的一体化 DevOps 平台&#xff0c;很多人都通过私有化部署 GitLab 来进行源代码托管。极狐GitLab 是 GitLab 在中国的发行版&#xff0c;专门为中国程序员服务。可以一键式部署极狐GitLab。 学习极狐GitLab 的相关资料&#xff1a; 极狐GitLab 官网极狐…...

PVE相关名词通俗表述方式———多处细节实验(方便理解)

PVE设置初期&#xff0c;对CIDR、 网关、 LinuxBridge、VLAN等很有困惑的朋友一定很需要一篇能够全面通俗易懂的方式去理解PVE 中Linux网桥的工作方式&#xff0c;就像操作一个英雄&#xff0c;多个技能&#xff0c;还是需要一点点去学习理解的&#xff0c;如果你上来就对着别人…...

Ansible--自动化运维工具

Ansible自动化运维工具介绍 1.Ansible介绍 Ansible是一款自动化运维工具&#xff0c;基于Python开发&#xff0c;集合了众多运维工具&#xff08;puppet、cfengine、chef、func、fabric&#xff09;的优点&#xff0c;实现了批量系统配置、批量程序部署、批量运行命令等功能。…...

微信小程序学习指南从入门到精通

&#x1f5fd;微信小程序学习指南从入门到精通&#x1f5fd; &#x1f51d;微信小程序学习指南从入门到精通&#x1f51d;✍前言✍&#x1f4bb;微信小程序学习指南前言&#x1f4bb;一、&#x1f680;文章列表&#x1f680;二、&#x1f52f;教程文章的好处&#x1f52f;1. ✅…...

微服务篇-深入了解使用 RestTemplate 远程调用、Nacos 注册中心基本原理与使用、OpenFeign 的基本使用

&#x1f525;博客主页&#xff1a; 【小扳_-CSDN博客】 ❤感谢大家点赞&#x1f44d;收藏⭐评论✍ 文章目录 1.0 认识微服务 1.1 单体架构 1.2 微服务 1.3 SpringCloud 框架 2.0 服务调用 2.1 RestTemplate 远程调用 3.0 服务注册和发现 3.1 注册中心原理 3.2 Nacos 注册中心 …...

使用 Django 构建支持 Kubernetes API 测试连接的 POST 接口

文章目录 使用 Django 构建支持 Kubernetes API 测试连接的 POST 接口功能需求使用 kubectl 获取 Token命令解析输出示例 完整代码实现Kubernetes API 客户端类功能说明 Django 接口视图关键点解析 路由配置 接口测试请求示例响应结果成功错误 优化建议1. 安全性2. 错误处理3. …...

十二、正则表达式、元字符、替换修饰符、手势和对话框插件

1. 正则表达式 1.1 基本使用 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Document</title&g…...

四川省住房和城乡建设厅门户网站/百度一下就会知道了

项目经理特别是IT类的项目经理&#xff0c;是我们开发软件产品和互联网类产品的项目核心人物&#xff0c;可以这么说一个好的合格的项目经理&#xff0c;是一个IT项目从立项到正式发布上线的成败的关键人物&#xff0c;选对了一个好的项目经理&#xff0c;一个项目可以说成功了…...

模板王网站/查网址

原文博客地址: Hexo博客多台电脑设备同步管理最近一直在折腾Hexo博客, 玩的可谓是不亦乐乎啊; 这里就整理一下之前遗留的一些问题和一些个性化配置如有遇到搭建个人博客时遇到的问题, 这里可参考我的之前的两篇相关博客 基于GitHub和Hexo搭建个人博客NexT主题配置优化-出土指南…...

用dw做网站怎么单独修改字体/全国疫情地区查询最新

一、写在前面之前写过一篇用Python发送天气预报邮件的博客&#xff0c;但是因为要手动输入城市名称&#xff0c;还要打开邮箱才能知道天气情况&#xff0c;这也太麻烦了。于是乎&#xff0c;有了这一篇博客&#xff0c;这次我要做的就是用Python获取本机IP地址&#xff0c;并根…...

建筑工程网签备案合同/西安seo代理计费

leetcode415题字符串相加 进位求和 字符串的拼接 想起来一个将数字反转的题目,简单说一下思路吧, carry num % 10 num num / 10 res carry * 10 class Solution {public String addStrings(String num1, String num2) {//说实话,这种思路还是有点东西的StringBuilder…...

单页网站怎么赚钱/seo全网优化指南

2019独角兽企业重金招聘Python工程师标准>>> public class BaseViewHoler extends RecyclerView.ViewHolder {private Context context;//行布局的viewprivate View mView;//用来装载id的集合 用法和map类似private SparseArray<View> sparseArray;public Bas…...

白沟做网站/企业网站设计制作

文章目录练习9.41练习9.42练习9.43练习9.44练习9.45练习9.46练习9.47练习9.48练习9.49练习9.50练习9.41 编写程序&#xff0c;从一个vector初始化一个string。 vector<char> v{ h, e, l, l, o }; string str(v.cbegin(), v.cend());练习9.42 假定你希望每次读取一个字符存…...