使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题
在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著。
1. 小尺寸图像如何加剧样本不均衡?
(1) 局部裁剪导致类别分布偏差
- 问题:遥感图像中某些类别(如道路、建筑)可能稀疏分布。小尺寸裁剪后,部分训练样本可能完全不含某些类别(例如一块纯农田的补丁),导致模型对这些类别缺乏学习机会。
- 示例:
- 原图中“道路”占比5%,若裁剪为
256x256
的小图,部分小图中可能完全无道路像素。 - 极端情况下,某些类别可能仅在极少数小图中出现,形成“长尾分布”。
- 原图中“道路”占比5%,若裁剪为
(2) 批次内类别覆盖不足
- 问题:小尺寸图像的批训练(batch training)中,若单个批次内缺少某些类别,梯度更新会偏向多数类。
- 示例:若一个batch中80%的补丁以“植被”为主,模型会倾向于将模糊区域预测为植被。
(3) 像素级不平衡放大
- 问题:即使原图类别均衡,小尺寸裁剪可能导致局部像素比例失衡。
- 例如,原图中“水体”占10%,但某个小图中水体可能占90%(河流区域)或0%(干旱区域)。
2. 样本不均衡的典型影响
- 模型偏向多数类:对高频类别(如植被、背景)过拟合,低频类别(如车辆、道路)漏检。
- 边界模糊:模型对类别交界处的预测置信度低,导致分割边缘不连续。
- 评估指标失真:全局指标(如整体准确率)虚高,但关键类别(如灾害损毁区域)的IoU/F1值极低。
3. 针对小尺寸图像的解决方案
(1) 数据层面的优化
- 定向裁剪(Guided Cropping):
- 根据类别分布优先裁剪包含稀有类别的小图。
- 工具:使用滑动窗口统计每个候选补丁的类别比例,筛选包含目标类别的补丁。
- 过采样(Oversampling):
- 对包含稀有类别的小图增加采样概率。
- 例如:若某小图中含“道路”,则其在训练集中的出现次数增加3倍。
- 数据增强强化:
- 对小图中稀有类别区域进行针对性增强:
- 局部旋转、缩放、亮度调整(避免全局变换导致稀有目标失真)。
- 复制-粘贴增强(Copy-Paste):将稀有目标粘贴到其他背景中(如将车辆粘贴到农田补丁上)。
- 对小图中稀有类别区域进行针对性增强:
(2) 损失函数设计
- 加权交叉熵(Weighted Cross-Entropy):
- 根据类别像素频率反向加权,例如权重与类别频率成反比:
weight = 1 / (class_freq + epsilon) # 防止除零
- 根据类别像素频率反向加权,例如权重与类别频率成反比:
- Focal Loss:
- 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
loss = -α * (1 - p)^γ * log(p) # α平衡类别,γ聚焦难样本
- 抑制易分类样本(如背景)的损失贡献,聚焦难样本(如小目标):
- Dice Loss / Tversky Loss:
- 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
Dice Loss = 1 - (2*|X∩Y|) / (|X| + |Y|) Tversky Loss = 1 - (|X∩Y|) / (|X∩Y| + α|X-Y| + β|Y-X|) # 调整α,β权衡假阳/假阴
- 直接优化分割重叠区域(IoU),对类别不平衡更鲁棒:
(3) 模型架构改进
- 上下文感知模块:
- 使用空洞卷积(Dilated Convolution)或注意力机制(如SE Block、Non-local Networks),增强模型对稀疏目标的捕捉能力。
- 多尺度特征融合:
- 通过金字塔池化(PSPNet)或U-Net++结构,融合不同尺度的特征,缓解因小尺寸输入丢失的上下文信息。
- 辅助监督(Auxiliary Supervision):
- 在中间层添加辅助损失函数,强制模型关注细粒度特征。
(4) 训练策略调整
- 小批次大迭代:
- 使用小batch size但增加迭代次数,确保稀有类别在多个epoch中被充分学习。
- 动态类别权重:
- 根据当前batch内的类别分布实时调整损失权重。
- 困难样本挖掘(Hard Example Mining):
- 在每个epoch后,筛选对稀有类别预测误差大的样本,下一轮训练中增加其采样概率。
4. 实验验证建议
- 监控类别指标:除了整体准确率,跟踪每个类别的IoU、F1-score。
- 可视化错误样本:检查模型在稀有类别上的失败案例,针对性优化数据或模型。
- 消融实验:对比不同损失函数、数据增强策略的效果。
小尺寸图像训练会放大样本不均衡问题,但通过定向数据采样、损失函数优化、模型结构改进三者结合,可显著缓解影响。关键是根据任务特点(如目标大小、类别分布)选择组合策略,例如:
- 稀疏小目标:Focal Loss + Copy-Paste增强 + 空洞卷积。
- 长尾分布:加权交叉熵 + 过采样 + 动态类别权重。
在 PyTorch 中,虽然没有直接解决语义分割样本不均衡的“万能模块”,但可以通过组合现有模块和社区成熟库高效实现解决方案。
1. 数据层面:加权采样与增强
(1) 加权随机采样(WeightedRandomSampler)
PyTorch 内置 WeightedRandomSampler
,可对包含稀有类别的图像补丁过采样:
import numpy as npdef compute_weight_for_patch(patch):image, mask = patch# 假设 mask 是一个二维数组,每个像素值表示类别标签# 计算每个类别的像素数量class_counts = np.bincount(mask.flatten())# 计算总像素数量total_pixels = mask.size# 计算每个类别的比例class_ratios = class_counts / total_pixels# 计算所有类别的权重class_weights = 1.0 / (class_ratios + 1e-6) # 避免除以零,添加一个小的常数# 应用 sigmoid 函数class_weights = 1.0 / (1.0 + np.exp(-class_weights))# 计算样本的权重sample_weight = np.sum(class_weights)print("Total samples weights:", sample_weight)return class_weights
from torch.utils.data import WeightedRandomSampler# 假设 dataset 返回 (image, mask),且每个样本有一个权重 weight
weights = [compute_weight_for_patch(patch) for patch in dataset] # 根据补丁中稀有类别比例计算权重
sampler = WeightedRandomSampler(weights, num_samples=len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)
(2) 数据增强库(Albumentations)
Albumentations 提供针对分割任务的增强,支持对特定类别区域增强:
import albumentations as Atransform = A.Compose([A.RandomCrop(256, 256),A.OneOf([A.RandomRotate90(),A.HorizontalFlip(),A.VerticalFlip()]),A.RandomBrightnessContrast(p=0.5),# 对特定类别区域增强(如仅增强“车辆”区域)A.RandomCropNearBBox(p=0.5, max_part_shift=0.3)
])
2. 损失函数:直接调用社区实现
(1) Focal Loss
使用 torchvision.ops
或第三方库:
# 使用 torchvision(需 0.10+ 版本)
from torchvision.ops import sigmoid_focal_lossloss = sigmoid_focal_loss(outputs, targets, alpha=0.25, gamma=2, reduction="mean")# 或自定义多类别 Focal Loss
class FocalLoss(nn.Module):def __init__(self, alpha=0.25, gamma=2):super().__init__()self.alpha = alphaself.gamma = gammadef forward(self, inputs, targets):ce_loss = F.cross_entropy(inputs, targets, reduction="none")pt = torch.exp(-ce_loss)loss = self.alpha * (1 - pt) ** self.gamma * ce_lossreturn loss.mean()
(2) Dice Loss
社区标准实现(或使用 segmentation_models_pytorch
库):
class DiceLoss(nn.Module):def __init__(self, smooth=1e-6):super().__init__()self.smooth = smoothdef forward(self, inputs, targets):inputs = F.softmax(inputs, dim=1)targets = F.one_hot(targets, num_classes=inputs.shape[1]).permute(0, 3, 1, 2)intersection = (inputs * targets).sum()union = inputs.sum() + targets.sum()dice = (2 * intersection + self.smooth) / (union + self.smooth)return 1 - dice
(3) 直接调用 segmentation_models_pytorch
损失函数
import segmentation_models_pytorch as smploss = smp.losses.DiceLoss(mode="multiclass", classes=[0, 1, 2]) # 指定关注类别
loss = smp.losses.FocalLoss(mode="multiclass", normalized=True) # 归一化版本
3. 模型层面:集成注意力与多尺度模块
(1) 使用预建模型库
segmentation_models_pytorch
(SMP)提供即用的模型和模块:
import segmentation_models_pytorch as smpmodel = smp.Unet(encoder_name="resnet34",encoder_weights="imagenet",in_channels=3,classes=5,decoder_attention_type="scse", # 添加空间-通道注意力
)
(2) 空洞卷积(Dilated Convolution)
直接使用 PyTorch 的 Conv2d
实现:
class DilatedConvBlock(nn.Module):def __init__(self, in_channels, out_channels, dilation_rate=2):super().__init__()self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=dilation_rate, dilation=dilation_rate)self.norm = nn.BatchNorm2d(out_channels)self.act = nn.ReLU()def forward(self, x):return self.act(self.norm(self.conv(x)))# 在 U-Net 的 decoder 中插入空洞卷积块
4. 类别权重计算工具
(1) 自动计算类别权重
from sklearn.utils.class_weight import compute_class_weight# 统计训练集所有像素的类别分布
class_counts = np.bincount(all_pixel_labels.flatten())
class_weights = compute_class_weight(class_weight="balanced", classes=np.arange(num_classes), y=all_pixel_labels.flatten()
)
class_weights = torch.tensor(class_weights, dtype=torch.float32).to(device)# 在损失函数中使用
criterion = nn.CrossEntropyLoss(weight=class_weights)
5. 完整 Pipeline 示例
import torch
from torch.utils.data import DataLoader, WeightedRandomSampler
import segmentation_models_pytorch as smp
import albumentations as A# 1. 定义数据集和采样器
dataset = YourDataset(transform=albumentations_transform)
weights = compute_patch_weights(dataset) # 根据补丁中目标类别比例计算
sampler = WeightedRandomSampler(weights, len(dataset), replacement=True)
dataloader = DataLoader(dataset, batch_size=16, sampler=sampler)# 2. 定义模型和损失
model = smp.Unet(encoder_name="resnet34", classes=5, decoder_attention_type="scse")
criterion = smp.losses.DiceLoss(mode="multiclass") + smp.losses.FocalLoss(mode="multiclass")# 3. 训练循环
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)
for epoch in range(100):for images, masks in dataloader:outputs = model(images)loss = criterion(outputs, masks)loss.backward()optimizer.step()
关键工具总结
问题类型 | PyTorch 原生支持 | 推荐第三方库(直接调用) |
---|---|---|
数据采样 | WeightedRandomSampler | Albumentations(定向增强) |
损失函数 | 自定义(需手写) | segmentation_models_pytorch.losses |
模型结构 | 手动添加模块(空洞卷积、注意力) | segmentation_models_pytorch 预建模型 |
类别权重计算 | sklearn.utils.class_weight | 内置自动统计工具(如 SMP 数据集类) |
注意事项
- 灵活组合策略:例如同时使用
WeightedRandomSampler
和Focal Loss
可能过度偏向少数类,需通过实验调整。 - 监控类别指标:使用
torchmetrics
库计算每个类别的 IoU:from torchmetrics import JaccardIndex iou = JaccardIndex(num_classes=5, task="multiclass") iou.update(outputs, targets) print(f"IoU: {iou.compute()}")
- 混合精度训练:使用
torch.cuda.amp
加速训练,缓解显存压力:scaler = torch.cuda.amp.GradScaler() with torch.cuda.amp.autocast():outputs = model(images)loss = criterion(outputs, masks) scaler.scale(loss).backward() scaler.step(optimizer) scaler.update()
相关文章:

使用小尺寸的图像进行逐像素语义分割训练,出现样本不均衡训练效果问题
在使用小尺寸图像进行逐像素语义分割训练时,确实可能出现样本不均衡问题,且这种问题可能比大尺寸图像更显著。 1. 小尺寸图像如何加剧样本不均衡? (1) 局部裁剪导致类别分布偏差 问题:遥感图像中某些类别(如道路、建…...

0.91英寸OLED显示屏一种具有小尺寸、高分辨率、低功耗特性的显示器件
0.91英寸OLED显示屏是一种具有小尺寸、高分辨率、低功耗特性的显示器件。以下是对0.91英寸OLED显示屏的详细介绍: 一、基本参数 尺寸:0.91英寸分辨率:通常为128x32像素,意味着显示屏上有128列和32行的像素点,总共409…...

读书笔记--分布式服务架构对比及优势
本篇是在上一篇的基础上,主要对共享服务平台建设所依赖的分布式服务架构进行学习,主要记录和思考如下,供大家学习参考。随着企业各业务数字化转型工作的推进,之前在传统的单一系统(或单体应用)模式中&#…...

HTML5 新的 Input 类型详解
HTML5 引入了许多新的输入类型,极大地增强了表单的功能和用户体验。这些新的输入类型不仅提供了更好的输入控制,还支持内置的验证功能,减少了开发者手动编写验证逻辑的工作量。本文将全面介绍 HTML5 中新增的输入类型,并结合代码示…...

ESP32-CAM实验集(WebServer)
WebServer 效果图 已连接 web端 platformio.ini ; PlatformIO Project Configuration File ; ; Build options: build flags, source filter ; Upload options: custom upload port, speed and extra flags ; Library options: dependencies, extra library stor…...

Case逢无意难休——深度解析JAVA中case穿透问题
Case逢无意难休——深度解析JAVA中case穿透问题~ 不作溢美之词,不作浮夸文章,此文与功名进取毫不相关也!与大家共勉!! 更多文章:个人主页 系列文章:JAVA专栏 欢迎各位大佬来访哦~互三必回&#…...

Golang笔记——常用库context和runtime
大家好,这里是Good Note,关注 公主号:Goodnote,专栏文章私信限时Free。本文详细介绍Golang的常用库context和runtime,包括库的基本概念和基本函数的使用等。 文章目录 contextcontext 包的基本概念主要类型和函数1. **…...

2000-2020年各省第二产业增加值占GDP比重数据
2000-2020年各省第二产业增加值占GDP比重数据 1、时间:2000-2020年 2、来源:国家统计局、统计年鉴 3、指标:行政区划代码、地区名称、年份、第二产业增加值占GDP比重 4、范围:31省 5、指标解释:第二产业增加值占GDP比重…...

unity商店插件A* Pathfinding Project如何判断一个点是否在导航网格上?
需要使用NavGraph.IsPointOnNavmesh(Vector3 point) 如果点位于导航网的可步行部分,则为真。 如果一个点在可步行导航网表面之上或之下,在任何距离,如果它不在更近的不可步行节点之上 / 之下,则认为它在导航网上。 使用方法 Ast…...

Day24-【13003】短文,数据结构与算法开篇,什么是数据元素?数据结构有哪些类型?什么是抽象类型?
文章目录 13003数据结构与算法全书框架考试题型的分值分布如何? 本次内容概述绪论第一节概览什么是数据、数据元素,数据项,数据项的值?什么是数据结构?分哪两种集合形式(逻辑和存储)?…...

富文本 tinyMCE Vue2 组件使用简易教程
参考官方教程 TinyMCE Vue.js integration technical reference Vue2 项目需要使用 tinyMCE Vue2 组件(tinymce/tinymce-vue)的第 3 版 安装组件 npm install --save "tinymce/tinymce-vue^3" 编写组件调用 <template><Editorref"editor"v-m…...

强化学习在自动驾驶中的实现与挑战
强化学习在自动驾驶中的实现与挑战 自动驾驶技术作为当今人工智能领域的前沿之一,正通过各种方式改变我们的出行方式。而强化学习(Reinforcement Learning, RL),作为机器学习的一大分支,在自动驾驶的实现中扮演了至关重要的角色。它通过模仿人类驾驶员的决策过程,为车辆…...

记录 | MaxKB创建本地AI智能问答系统
目录 前言一、重建MaxKBStep1 复制路径Step2 删除MaxKBStep3 创建数据存储文件夹Step4 重建 二、创建知识库Step1 新建知识库Step2 下载测试所用的txtStep3 上传本地文档Step4 选择模型补充智谱的API Key如何获取 Step5 查看是否成功 三、创建应用Step1 新建应用Step2 配置AI助…...

特种作业操作之低压电工考试真题
1.下面( )属于顺磁性材料。 A. 铜 B. 水 C. 空气 答案:C 2.事故照明一般采用( )。 A. 日光灯 B. 白炽灯 C. 压汞灯 答案:B 3.人体同时接触带电设备或线路中的两相导体时,电流从一相通过人体流…...

[免费]基于Python的Django博客系统【论文+源码+SQL脚本】
大家好,我是java1234_小锋老师,看到一个不错的基于Python的Django博客系统,分享下哈。 项目视频演示 【免费】基于Python的Django博客系统 Python毕业设计_哔哩哔哩_bilibili 项目介绍 随着互联网技术的飞速发展,信息的传播与…...

Cannot resolve symbol ‘XXX‘ Maven 依赖问题的解决过程
一、问题描述 在使用 Maven 管理项目依赖时,遇到了一个棘手的问题。具体表现为:在 pom.xml 文件中导入了所需的依赖,并且在 IDE 中导入语句没有显示为红色(表示 IDE 没有提示依赖缺失),但是在实际使用这些依…...

我们需要有哪些知识体系,知识体系里面要有什么哪些内容?
01、管理知识体系的学习知识体系 主要内容: 1、知识管理框架的外部借鉴、和自身知识体系的搭建; 2、学习能力、思维逻辑能力等的塑造; 3、知识管理工具的使用; 4、学习资料的导入和查找资料的渠道; 5、深层关键的…...

什么是vue.js组件开发,我们需要做哪些准备工作?
Vue.js 是一个非常流行的前端框架,用于构建用户界面。组件开发是 Vue.js 的核心概念之一,通过将界面拆分为独立的组件,可以提高代码的可维护性和复用性。以下是一个详细的 Vue.js 组件开发指南,包括基础概念、开发流程和代码示例。 一、Vue.js 组件开发基础 1. 组件的基本…...

网络工程师 (3)指令系统基础
一、寻址方式 (一)指令寻址 顺序寻址:通过程序计数器(PC)加1,自动形成下一条指令的地址。这是计算机中最基本、最常用的寻址方式。 跳跃寻址:通过转移类指令直接或间接给出下一条指令的地址。跳…...

第4章 神经网络【1】——损失函数
4.1.从数据中学习 实际的神经网络中,参数的数量成千上万,因此,需要由数据自动决定权重参数的值。 4.1.1.数据驱动 数据是机器学习的核心。 我们的目标是要提取出特征量,特征量指的是从输入数据/图像中提取出的本质的数 …...

【Python】第五弹---深入理解函数:从基础到进阶的全面解析
✨个人主页: 熬夜学编程的小林 💗系列专栏: 【C语言详解】 【数据结构详解】【C详解】【Linux系统编程】【MySQL】【Python】 目录 1、函数 1.1、函数是什么 1.2、语法格式 1.3、函数参数 1.4、函数返回值 1.5、变量作用域 1.6、函数…...

【MQ】如何保证消息队列的高性能?
零拷贝 Kafka 使用到了 mmap 和 sendfile 的方式来实现零拷贝。分别对应 Java 的 MappedByteBuffer 和 FileChannel.transferTo 顺序写磁盘 Kafka 采用顺序写文件的方式来提高磁盘写入性能。顺序写文件,基本减少了磁盘寻道和旋转的次数完成一次磁盘 IO࿰…...

RAG是否被取代(缓存增强生成-CAG)吗?
引言: 本文深入研究一种名为缓存增强生成(CAG)的新技术如何工作并减少/消除检索增强生成(RAG)弱点和瓶颈。 LLMs 可以根据输入给他的信息给出对应的输出,但是这样的工作方式很快就不能满足应用的需要: 因…...

用C++编写一个2048的小游戏
以下是一个简单的2048游戏的实现。这个实现使用了控制台输入和输出,适合在终端或命令行环境中运行。 2048游戏的实现 1.游戏逻辑 2048游戏的核心逻辑包括: • 初始化一个4x4的网格。 • 随机生成2或4。 • 处理玩家的移动操作(上、下、左、…...

为何SAP S4系统中要设置MRP区域?MD04中可否同时显示工厂级、库存地点级的数据?
【SAP系统PP模块研究】 一、物料主数据的MRP区域设置 SAP ECC系统中想要指定不影响MRP运算的库存地点,是针对库存地点设置MRP标识,路径为:SPRO->生产->物料需求计划->计划->定义每一个工厂的存储地点MRP,如下图所示: 另外,在给物料主数据MMSC扩充库存地点时…...

Windows10官方系统下载与安装保姆级教程【U盘-官方ISO直装】
Windows 10 官方系统安装/重装 制作启动盘的U盘微软官网下载Win10安装包创建启动盘U盘 安装Win10 本文采用U盘安装Windows10官方系统。 制作启动盘的U盘 微软官网下载Win10安装包 微软官网下载Win10安装包链接:https://www.microsoft.com/zh-cn/software-downloa…...

第05章 07 切片图等值线代码一则
绘制脑部切面图的阈值等值线是一个常见的任务,通常涉及使用VTK(Visualization Toolkit)库来处理医学图像数据。以下是一个基于VTK/C的示例代码,展示如何读取脑部DICOM图像数据,应用阈值过滤器来提取特定组织的等值线&a…...

【深度学习】线性回归的简洁实现
线性回归的简洁实现 在过去的几年里,出于对深度学习强烈的兴趣,许多公司、学者和业余爱好者开发了各种成熟的开源框架。 这些框架可以自动化基于梯度的学习算法中重复性的工作。 目前,我们只会运用: (1)通…...

渗透测试技法之口令安全
一、口令安全威胁 口令泄露途径 代码与文件存储不当:在软件开发和系统维护过程中,开发者可能会将口令以明文形式存储在代码文件、配置文件或注释中。例如,在开源代码托管平台 GitHub 上,一些开发者由于疏忽,将包含数据…...

【R语言】数学运算
一、基础运算 R语言中能实现加、减、乘、除、求模、取整、取绝对值、指数、对数等运算。 x <- 2 y <- 10 # 求模 y %% x # 整除 y %/% x # 取绝对值 abs(-x) # 指数运算 y ^x y^1/x #对数运算 log(x) #log()函数默认情况下以 e 为底 双等号“”的作用等同于identical(…...