当前位置: 首页 > news >正文

使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题

在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著


1. 小尺寸图像如何加剧样本不均衡?

(1) 局部裁剪导致类别分布偏差
  • 问题:遥感图像中某些类别(如道路、建筑)可能稀疏分布。小尺寸裁剪后,部分训练样本可能完全不含某些类别(例如一块纯农田的补丁),导致模型对这些类别缺乏学习机会。
  • 示例
    • 原图中“道路”占比5%,若裁剪为 256x256 的小图,部分小图中可能完全无道路像素。
    • 极端情况下,某些类别可能仅在极少数小图中出现,形成“长尾分布”。
(2) 批次内类别覆盖不足
  • 问题:小尺寸图像的批训练(batch training)中,若单个批次内缺少某些类别,梯度更新会偏向多数类。
  • 示例:若一个batch中80%的补丁以“植被”为主,模型会倾向于将模糊区域预测为植被。
(3) 像素级不平衡放大
  • 问题:即使原图类别均衡,小尺寸裁剪可能导致局部像素比例失衡。
    • 例如,原图中“水体”占10%,但某个小图中水体可能占90%(河流区域)或0%(干旱区域)。

2. 样本不均衡的典型影响

  • 模型偏向多数类:对高频类别(如植被、背景)过拟合,低频类别(如车辆、道路)漏检。
  • 边界模糊:模型对类别交界处的预测置信度低,导致分割边缘不连续。
  • 评估指标失真:全局指标(如整体准确率)虚高,但关键类别(如灾害损毁区域)的IoU/F1值极低。

3. 针对小尺寸图像的解决方案

(1) 数据层面的优化
  • 定向裁剪(Guided Cropping)
    • 根据类别分布优先裁剪包含稀有类别的小图。
    • 工具:使用滑动窗口统计每个候选补丁的类别比例,筛选包含目标类别的补丁。
  • 过采样(Oversampling)
    • 对包含稀有类别的小图增加采样概率。
    • 例如:若某小图中含“道路”,则其在训练集中的出现次数增加3倍。
  • 数据增强强化
    • 对小图中稀有类别区域进行针对性增强:
      • 局部旋转、缩放、亮度调整(避免全局变换导致稀有目标失真)。
      • 复制-粘贴增强(Copy-Paste):将稀有目标粘贴到其他背景中(如将车辆粘贴到农田补丁上)。
(2) 损失函数设计
  • 加权交叉熵(Weighted Cross-Entropy)
    • 根据类别像素频率反向加权,例如权重与类别频率成反比:
      weight = 1 / (class_freq + epsilon)  # 防止除零
      
  • Focal Loss
    • 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
      loss = -α * (1 - p)^γ * log(p)  # α平衡类别,γ聚焦难样本
      
  • Dice Loss / Tversky Loss
    • 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
      Dice Loss = 1 - (2*|X∩Y|) / (|X| + |Y|)
      Tversky Loss = 1 - (|X∩Y|) / (|X∩Y| + α|X-Y| + β|Y-X|)  # 调整α,β权衡假阳/假阴
      
(3) 模型架构改进
  • 上下文感知模块
    • 使用空洞卷积(Dilated Convolution)或注意力机制(如SE Block、Non-local Networks),增强模型对稀疏目标的捕捉能力。
  • 多尺度特征融合
    • 通过金字塔池化(PSPNet)或U-Net++结构,融合不同尺度的特征,缓解因小尺寸输入丢失的上下文信息。
  • 辅助监督(Auxiliary Supervision)
    • 在中间层添加辅助损失函数,强制模型关注细粒度特征。
(4) 训练策略调整
  • 小批次大迭代
    • 使用小batch size但增加迭代次数,确保稀有类别在多个epoch中被充分学习。
  • 动态类别权重
    • 根据当前batch内的类别分布实时调整损失权重。
  • 困难样本挖掘(Hard Example Mining)
    • 在每个epoch后,筛选对稀有类别预测误差大的样本,下一轮训练中增加其采样概率。

4. 实验验证建议

  • 监控类别指标:除了整体准确率,跟踪每个类别的IoU、F1-score。
  • 可视化错误样本:检查模型在稀有类别上的失败案例,针对性优化数据或模型。
  • 消融实验:对比不同损失函数、数据增强策略的效果。

小尺寸图像训练会放大样本不均衡问题,但通过定向数据采样、损失函数优化、模型结构改进三者结合,可显著缓解影响。关键是根据任务特点(如目标大小、类别分布)选择组合策略,例如:

  • 稀疏小目标:Focal Loss + Copy-Paste增强 + 空洞卷积。
  • 长尾分布:加权交叉熵 + 过采样 + 动态类别权重。

在 PyTorch 中,虽然没有直接解决语义分割样本不均衡的“万能模块”,但可以通过组合现有模块社区成熟库高效实现解决方案。


1. 数据层面:加权采样与增强

(1) 加权随机采样(WeightedRandomSampler)

PyTorch 内置 WeightedRandomSampler,可对包含稀有类别的图像补丁过采样:

import numpy as npdef compute_weight_for_patch(patch):image, mask = patch# 假设 mask 是一个二维数组,每个像素值表示类别标签# 计算每个类别的像素数量class_counts = np.bincount(mask.flatten())# 计算总像素数量total_pixels = mask.size# 计算每个类别的比例class_ratios = class_counts / total_pixels# 计算所有类别的权重class_weights = 1.0 / (class_ratios + 1e-6)  # 避免除以零,添加一个小的常数# 应用 sigmoid 函数class_weights = 1.0 / (1.0 + np.exp(-class_weights))# 计算样本的权重sample_weight = np.sum(class_weights)print("Total samples weights:", sample_weight)return class_weights
from torch.utils.data import WeightedRandomSampler# 假设 dataset 返回 (image, mask),且每个样本有一个权重 weight
weights = [compute_weight_for_patch(patch) for patch in dataset]  # 根据补丁中稀有类别比例计算权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
(2) 数据增强库(Albumentations)

Albumentations 提供针对分割任务的增强,支持对特定类别区域增强:

import albumentations as Atransform = A.Compose([A.RandomCrop(256, 256),A.OneOf([A.RandomRotate90(),A.HorizontalFlip(),A.VerticalFlip()]),A.RandomBrightnessContrast(p=0.5),# 对特定类别区域增强(如仅增强“车辆”区域)A.RandomCropNearBBox(p=0.5, max_part_shift=0.3)
])

2. 损失函数:直接调用社区实现

(1) Focal Loss

使用 torchvision.ops 或第三方库:

# 使用 torchvision(需 0.10+ 版本)
from torchvision.ops import sigmoid_focal_lossloss = sigmoid_focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction="mean")# 或自定义多类别 Focal Loss
class FocalLoss(nn.Module):def __init__(self, alpha=0.25, gamma=2):super().__init__()self.alpha = alphaself.gamma = gammadef forward(self, inputs, targets):ce_loss = F.cross_entropy(inputs, targets, reduction="none")pt = torch.exp(-ce_loss)loss = self.alpha * (1 - pt) ** self.gamma * ce_lossreturn loss.mean()
(2) Dice Loss

社区标准实现(或使用 segmentation_models_pytorch 库):

class DiceLoss(nn.Module):def __init__(self, smooth=1e-6):super().__init__()self.smooth = smoothdef forward(self, inputs, targets):inputs = F.softmax(inputs, dim=1)targets = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2)intersection = (inputs * targets).sum()union = inputs.sum() + targets.sum()dice = (2 * intersection + self.smooth) / (union + self.smooth)return 1 - dice
(3) 直接调用 segmentation_models_pytorch 损失函数
import segmentation_models_pytorch as smploss = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1, 2])  # 指定关注类别
loss = smp.losses.FocalLoss(mode="multiclass", normalized=True)   # 归一化版本

3. 模型层面:集成注意力与多尺度模块

(1) 使用预建模型库

segmentation_models_pytorch(SMP)提供即用的模型和模块:

import segmentation_models_pytorch as smpmodel = smp.Unet(encoder_name="resnet34",encoder_weights="imagenet",in_channels=3,classes=5,decoder_attention_type="scse",  # 添加空间-通道注意力
)
(2) 空洞卷积(Dilated Convolution)

直接使用 PyTorch 的 Conv2d 实现:

class DilatedConvBlock(nn.Module):def __init__(self, in_channels, out_channels, dilation_rate=2):super().__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation_rate, dilation=dilation_rate)self.norm = nn.BatchNorm2d(out_channels)self.act = nn.ReLU()def forward(self, x):return self.act(self.norm(self.conv(x)))# 在 U-Net 的 decoder 中插入空洞卷积块

4. 类别权重计算工具

(1) 自动计算类别权重
from sklearn.utils.class_weight import compute_class_weight# 统计训练集所有像素的类别分布
class_counts = np.bincount(all_pixel_labels.flatten())
class_weights = compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=all_pixel_labels.flatten()
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)# 在损失函数中使用
criterion = nn.CrossEntropyLoss(weight=class_weights)

5. 完整 Pipeline 示例

import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import segmentation_models_pytorch as smp
import albumentations as A# 1. 定义数据集和采样器
dataset = YourDataset(transform=albumentations_transform)
weights = compute_patch_weights(dataset)  # 根据补丁中目标类别比例计算
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)# 2. 定义模型和损失
model = smp.Unet(encoder_name="resnet34", classes=5, decoder_attention_type="scse")
criterion = smp.losses.DiceLoss(mode="multiclass") + smp.losses.FocalLoss(mode="multiclass")# 3. 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(100):for images, masks in dataloader:outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()

关键工具总结

问题类型PyTorch 原生支持推荐第三方库(直接调用)
数据采样WeightedRandomSamplerAlbumentations(定向增强)
损失函数自定义(需手写)segmentation_models_pytorch.losses
模型结构手动添加模块(空洞卷积、注意力)segmentation_models_pytorch 预建模型
类别权重计算sklearn.utils.class_weight内置自动统计工具(如 SMP 数据集类)

注意事项

  1. 灵活组合策略:例如同时使用 WeightedRandomSamplerFocal Loss 可能过度偏向少数类,需通过实验调整。
  2. 监控类别指标:使用 torchmetrics 库计算每个类别的 IoU:
    from torchmetrics import JaccardIndex
    iou = JaccardIndex(num_classes=5, task="multiclass")
    iou.update(outputs, targets)
    print(f"IoU: {iou.compute()}")
    
  3. 混合精度训练:使用 torch.cuda.amp 加速训练,缓解显存压力:
    scaler = torch.cuda.amp.GradScaler()
    with torch.cuda.amp.autocast():outputs = model(images)loss = criterion(outputs, masks)
    scaler.scale(loss).backward()
    scaler.step(optimizer)
    scaler.update()
    

相关文章:

使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题

在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著。 1. 小尺寸图像如何加剧样本不均衡? (1) 局部裁剪导致类别分布偏差 问题:遥感图像中某些类别(如道路、建…...

0.91英寸OLED显示屏一种具有小尺寸、高分辨率、低功耗特性的显示器件

0.91英寸OLED显示屏是一种具有小尺寸、高分辨率、低功耗特性的显示器件。以下是对0.91英寸OLED显示屏的详细介绍: 一、基本参数 尺寸:0.91英寸分辨率:通常为128x32像素,意味着显示屏上有128列和32行的像素点,总共409…...

读书笔记--分布式服务架构对比及优势

本篇是在上一篇的基础上,主要对共享服务平台建设所依赖的分布式服务架构进行学习,主要记录和思考如下,供大家学习参考。随着企业各业务数字化转型工作的推进,之前在传统的单一系统(或单体应用)模式中&#…...

HTML5 新的 Input 类型详解

HTML5 引入了许多新的输入类型,极大地增强了表单的功能和用户体验。这些新的输入类型不仅提供了更好的输入控制,还支持内置的验证功能,减少了开发者手动编写验证逻辑的工作量。本文将全面介绍 HTML5 中新增的输入类型,并结合代码示…...

ESP32-CAM实验集(WebServer)

WebServer 效果图 已连接 web端 platformio.ini ; PlatformIO Project Configuration File ; ; Build options: build flags, source filter ; Upload options: custom upload port, speed and extra flags ; Library options: dependencies, extra library stor…...

Case逢无意难休——深度解析JAVA中case穿透问题

Case逢无意难休——深度解析JAVA中case穿透问题~ 不作溢美之词,不作浮夸文章,此文与功名进取毫不相关也!与大家共勉!! 更多文章:个人主页 系列文章:JAVA专栏 欢迎各位大佬来访哦~互三必回&#…...

Golang笔记——常用库context和runtime

大家好,这里是Good Note,关注 公主号:Goodnote,专栏文章私信限时Free。本文详细介绍Golang的常用库context和runtime,包括库的基本概念和基本函数的使用等。 文章目录 contextcontext 包的基本概念主要类型和函数1. **…...

2000-2020年各省第二产业增加值占GDP比重数据

2000-2020年各省第二产业增加值占GDP比重数据 1、时间:2000-2020年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区名称、年份、第二产业增加值占GDP比重 4、范围:31省 5、指标解释:第二产业增加值占GDP比重…...

unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?

需要使用NavGraph.IsPointOnNavmesh(Vector3 point) 如果点位于导航网的可步行部分,则为真。 如果一个点在可步行导航网表面之上或之下,在任何距离,如果它不在更近的不可步行节点之上 / 之下,则认为它在导航网上。 使用方法 Ast…...

Day24-【13003】短文,数据结构与算法开篇,什么是数据元素?数据结构有哪些类型?什么是抽象类型?

文章目录 13003数据结构与算法全书框架考试题型的分值分布如何? 本次内容概述绪论第一节概览什么是数据、数据元素,数据项,数据项的值?什么是数据结构?分哪两种集合形式(逻辑和存储)&#xff1f…...

富文本 tinyMCE Vue2 组件使用简易教程

参考官方教程 TinyMCE Vue.js integration technical reference Vue2 项目需要使用 tinyMCE Vue2 组件(tinymce/tinymce-vue)的第 3 版 安装组件 npm install --save "tinymce/tinymce-vue^3" 编写组件调用 <template><Editorref"editor"v-m…...

强化学习在自动驾驶中的实现与挑战

强化学习在自动驾驶中的实现与挑战 自动驾驶技术作为当今人工智能领域的前沿之一,正通过各种方式改变我们的出行方式。而强化学习(Reinforcement Learning, RL),作为机器学习的一大分支,在自动驾驶的实现中扮演了至关重要的角色。它通过模仿人类驾驶员的决策过程,为车辆…...

记录 | MaxKB创建本地AI智能问答系统

目录 前言一、重建MaxKBStep1 复制路径Step2 删除MaxKBStep3 创建数据存储文件夹Step4 重建 二、创建知识库Step1 新建知识库Step2 下载测试所用的txtStep3 上传本地文档Step4 选择模型补充智谱的API Key如何获取 Step5 查看是否成功 三、创建应用Step1 新建应用Step2 配置AI助…...

特种作业操作之低压电工考试真题

1.下面&#xff08; &#xff09;属于顺磁性材料。 A. 铜 B. 水 C. 空气 答案&#xff1a;C 2.事故照明一般采用&#xff08; &#xff09;。 A. 日光灯 B. 白炽灯 C. 压汞灯 答案&#xff1a;B 3.人体同时接触带电设备或线路中的两相导体时&#xff0c;电流从一相通过人体流…...

[免费]基于Python的Django博客系统【论文+源码+SQL脚本】

大家好&#xff0c;我是java1234_小锋老师&#xff0c;看到一个不错的基于Python的Django博客系统&#xff0c;分享下哈。 项目视频演示 【免费】基于Python的Django博客系统 Python毕业设计_哔哩哔哩_bilibili 项目介绍 随着互联网技术的飞速发展&#xff0c;信息的传播与…...

Cannot resolve symbol ‘XXX‘ Maven 依赖问题的解决过程

一、问题描述 在使用 Maven 管理项目依赖时&#xff0c;遇到了一个棘手的问题。具体表现为&#xff1a;在 pom.xml 文件中导入了所需的依赖&#xff0c;并且在 IDE 中导入语句没有显示为红色&#xff08;表示 IDE 没有提示依赖缺失&#xff09;&#xff0c;但是在实际使用这些依…...

我们需要有哪些知识体系,知识体系里面要有什么哪些内容?

01、管理知识体系的学习知识体系 主要内容&#xff1a; 1、知识管理框架的外部借鉴、和自身知识体系的搭建&#xff1b; 2、学习能力、思维逻辑能力等的塑造&#xff1b; 3、知识管理工具的使用&#xff1b; 4、学习资料的导入和查找资料的渠道&#xff1b; 5、深层关键的…...

什么是vue.js组件开发,我们需要做哪些准备工作?

Vue.js 是一个非常流行的前端框架,用于构建用户界面。组件开发是 Vue.js 的核心概念之一,通过将界面拆分为独立的组件,可以提高代码的可维护性和复用性。以下是一个详细的 Vue.js 组件开发指南,包括基础概念、开发流程和代码示例。 一、Vue.js 组件开发基础 1. 组件的基本…...

网络工程师 (3)指令系统基础

一、寻址方式 &#xff08;一&#xff09;指令寻址 顺序寻址&#xff1a;通过程序计数器&#xff08;PC&#xff09;加1&#xff0c;自动形成下一条指令的地址。这是计算机中最基本、最常用的寻址方式。 跳跃寻址&#xff1a;通过转移类指令直接或间接给出下一条指令的地址。跳…...

第4章 神经网络【1】——损失函数

4.1.从数据中学习 实际的神经网络中&#xff0c;参数的数量成千上万&#xff0c;因此&#xff0c;需要由数据自动决定权重参数的值。 4.1.1.数据驱动 数据是机器学习的核心。 我们的目标是要提取出特征量&#xff0c;特征量指的是从输入数据/图像中提取出的本质的数 …...

【Python】第五弹---深入理解函数:从基础到进阶的全面解析

✨个人主页&#xff1a; 熬夜学编程的小林 &#x1f497;系列专栏&#xff1a; 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、函数 1.1、函数是什么 1.2、语法格式 1.3、函数参数 1.4、函数返回值 1.5、变量作用域 1.6、函数…...

【MQ】如何保证消息队列的高性能?

零拷贝 Kafka 使用到了 mmap 和 sendfile 的方式来实现零拷贝。分别对应 Java 的 MappedByteBuffer 和 FileChannel.transferTo 顺序写磁盘 Kafka 采用顺序写文件的方式来提高磁盘写入性能。顺序写文件&#xff0c;基本减少了磁盘寻道和旋转的次数完成一次磁盘 IO&#xff0…...

RAG是否被取代(缓存增强生成-CAG)吗?

引言&#xff1a; 本文深入研究一种名为缓存增强生成&#xff08;CAG&#xff09;的新技术如何工作并减少/消除检索增强生成&#xff08;RAG&#xff09;弱点和瓶颈。 LLMs 可以根据输入给他的信息给出对应的输出&#xff0c;但是这样的工作方式很快就不能满足应用的需要: 因…...

用C++编写一个2048的小游戏

以下是一个简单的2048游戏的实现。这个实现使用了控制台输入和输出&#xff0c;适合在终端或命令行环境中运行。 2048游戏的实现 1.游戏逻辑 2048游戏的核心逻辑包括&#xff1a; • 初始化一个4x4的网格。 • 随机生成2或4。 • 处理玩家的移动操作&#xff08;上、下、左、…...

为何SAP S4系统中要设置MRP区域?MD04中可否同时显示工厂级、库存地点级的数据?

【SAP系统PP模块研究】 一、物料主数据的MRP区域设置 SAP ECC系统中想要指定不影响MRP运算的库存地点,是针对库存地点设置MRP标识,路径为:SPRO->生产->物料需求计划->计划->定义每一个工厂的存储地点MRP,如下图所示: 另外,在给物料主数据MMSC扩充库存地点时…...

Windows10官方系统下载与安装保姆级教程【U盘-官方ISO直装】

Windows 10 官方系统安装/重装 制作启动盘的U盘微软官网下载Win10安装包创建启动盘U盘 安装Win10 本文采用U盘安装Windows10官方系统。 制作启动盘的U盘 微软官网下载Win10安装包 微软官网下载Win10安装包链接&#xff1a;https://www.microsoft.com/zh-cn/software-downloa…...

第05章 07 切片图等值线代码一则

绘制脑部切面图的阈值等值线是一个常见的任务&#xff0c;通常涉及使用VTK&#xff08;Visualization Toolkit&#xff09;库来处理医学图像数据。以下是一个基于VTK/C的示例代码&#xff0c;展示如何读取脑部DICOM图像数据&#xff0c;应用阈值过滤器来提取特定组织的等值线&a…...

【深度学习】线性回归的简洁实现

线性回归的简洁实现 在过去的几年里&#xff0c;出于对深度学习强烈的兴趣&#xff0c;许多公司、学者和业余爱好者开发了各种成熟的开源框架。 这些框架可以自动化基于梯度的学习算法中重复性的工作。 目前&#xff0c;我们只会运用&#xff1a; &#xff08;1&#xff09;通…...

渗透测试技法之口令安全

一、口令安全威胁 口令泄露途径 代码与文件存储不当&#xff1a;在软件开发和系统维护过程中&#xff0c;开发者可能会将口令以明文形式存储在代码文件、配置文件或注释中。例如&#xff0c;在开源代码托管平台 GitHub 上&#xff0c;一些开发者由于疏忽&#xff0c;将包含数据…...

【R语言】数学运算

一、基础运算 R语言中能实现加、减、乘、除、求模、取整、取绝对值、指数、对数等运算。 x <- 2 y <- 10 # 求模 y %% x # 整除 y %/% x # 取绝对值 abs(-x) # 指数运算 y ^x y^1/x #对数运算 log(x) #log()函数默认情况下以 e 为底 双等号“”的作用等同于identical(…...