当前位置: 首页 > news >正文

MMSegmentation改进:增加Kappa系数评价指数

将mmseg\evaluation\metrics\iou_metric.py文件中的内容替换成以下内容即可:

支持输出单类Kappa系数和平均Kappa系数。

使用方法:将dataset的config文件中:val_evaluator 添加'mKappa',如 val_evaluator = dict(type='mmseg.IoUMetric', iou_metrics=['mFscore', 'mIoU', 'mKappa'])。

欢迎关注大地主的CSDN与 ABCnutter (github.com),敬请期待更多精彩内容

# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
from collections import OrderedDict
from typing import Dict, List, Optional, Sequenceimport numpy as np
import torch
from mmengine.dist import is_main_process
from mmengine.evaluator import BaseMetric
from mmengine.logging import MMLogger, print_log
from mmengine.utils import mkdir_or_exist
from PIL import Image
from prettytable import PrettyTablefrom mmseg.registry import METRICS@METRICS.register_module()
class IoUMetric(BaseMetric):"""IoU evaluation metric.Args:ignore_index (int): Index that will be ignored in evaluation.Default: 255.iou_metrics (list[str] | str): Metrics to be calculated, the optionsinclude 'mIoU', 'mDice', 'mFscore', and 'Kappa'.nan_to_num (int, optional): If specified, NaN values will be replacedby the numbers defined by the user. Default: None.beta (int): Determines the weight of recall in the combined score.Default: 1.collect_device (str): Device name used for collecting results fromdifferent ranks during distributed training. Must be 'cpu' or'gpu'. Defaults to 'cpu'.output_dir (str): The directory for output prediction. Defaults toNone.format_only (bool): Only format result for results commit withoutperform evaluation. It is useful when you want to save the resultto a specific format and submit it to the test server.Defaults to False.prefix (str, optional): The prefix that will be added in the metricnames to disambiguate homonymous metrics of different evaluators.If prefix is not provided in the argument, self.default_prefixwill be used instead. Defaults to None."""def __init__(self,ignore_index: int = 255,iou_metrics: List[str] = ['mIoU'],nan_to_num: Optional[int] = None,beta: int = 1,collect_device: str = 'cpu',output_dir: Optional[str] = None,format_only: bool = False,prefix: Optional[str] = None,**kwargs) -> None:super().__init__(collect_device=collect_device, prefix=prefix)self.ignore_index = ignore_indexself.metrics = iou_metricsself.nan_to_num = nan_to_numself.beta = betaself.output_dir = output_dirif self.output_dir and is_main_process():mkdir_or_exist(self.output_dir)self.format_only = format_onlydef process(self, data_batch: dict, data_samples: Sequence[dict]) -> None:"""Process one batch of data and data_samples.The processed results should be stored in ``self.results``, which willbe used to compute the metrics when all batches have been processed.Args:data_batch (dict): A batch of data from the dataloader.data_samples (Sequence[dict]): A batch of outputs from the model."""num_classes = len(self.dataset_meta['classes'])for data_sample in data_samples:pred_label = data_sample['pred_sem_seg']['data'].squeeze()# format_only always for test dataset without ground truthif not self.format_only:label = data_sample['gt_sem_seg']['data'].squeeze().to(pred_label)self.results.append(self.intersect_and_union(pred_label, label, num_classes,self.ignore_index))# format_resultif self.output_dir is not None:basename = osp.splitext(osp.basename(data_sample['img_path']))[0]png_filename = osp.abspath(osp.join(self.output_dir, f'{basename}.png'))output_mask = pred_label.cpu().numpy()# The index range of official ADE20k dataset is from 0 to 150.# But the index range of output is from 0 to 149.# That is because we set reduce_zero_label=True.if data_sample.get('reduce_zero_label', False):output_mask = output_mask + 1output = Image.fromarray(output_mask.astype(np.uint8))output.save(png_filename)def compute_metrics(self, results: list) -> Dict[str, float]:"""Compute the metrics from processed results.Args:results (list): The processed results of each batch.Returns:Dict[str, float]: The computed metrics. The keys are the names ofthe metrics, and the values are corresponding results. The keymainly includes aAcc, mIoU, mAcc, mDice, mFscore, mPrecision,mRecall, and Kappa."""logger: MMLogger = MMLogger.get_current_instance()if self.format_only:logger.info(f'results are saved to {osp.dirname(self.output_dir)}')return OrderedDict()# convert list of tuples to tuple of lists, e.g.# [(A_1, B_1, C_1, D_1), ...,  (A_n, B_n, C_n, D_n)] to# ([A_1, ..., A_n], ..., [D_1, ..., D_n])results = tuple(zip(*results))assert len(results) == 4total_area_intersect = sum(results[0])total_area_union = sum(results[1])total_area_pred_label = sum(results[2])total_area_label = sum(results[3])ret_metrics = self.total_area_to_metrics(total_area_intersect, total_area_union, total_area_pred_label,total_area_label, self.metrics, self.nan_to_num, self.beta)class_names = self.dataset_meta['classes']# summary tableret_metrics_summary = OrderedDict({ret_metric: np.round(np.nanmean(ret_metric_value) * 100, 2)for ret_metric, ret_metric_value in ret_metrics.items()})metrics = dict()for key, val in ret_metrics_summary.items():if key == 'aAcc':metrics[key] = valelse:metrics['m' + key] = val# each class tableret_metrics.pop('aAcc', None)# ret_metrics.pop('Kappa', None)ret_metrics_class = OrderedDict({ret_metric: np.round(ret_metric_value * 100, 2)for ret_metric, ret_metric_value in ret_metrics.items()})ret_metrics_class.update({'Class': class_names})ret_metrics_class.move_to_end('Class', last=False)class_table_data = PrettyTable()for key, val in ret_metrics_class.items():class_table_data.add_column(key, val)print_log('per class results:', logger)print_log('\n' + class_table_data.get_string(), logger=logger)return metrics@staticmethoddef intersect_and_union(pred_label: torch.tensor, label: torch.tensor,num_classes: int, ignore_index: int):"""Calculate Intersection and Union.Args:pred_label (torch.tensor): Prediction segmentation mapor predict result filename. The shape is (H, W).label (torch.tensor): Ground truth segmentation mapor label filename. The shape is (H, W).num_classes (int): Number of categories.ignore_index (int): Index that will be ignored in evaluation.Returns:torch.Tensor: The intersection of prediction and ground truthhistogram on all classes.torch.Tensor: The union of prediction and ground truth histogram onall classes.torch.Tensor: The prediction histogram on all classes.torch.Tensor: The ground truth histogram on all classes."""mask = (label != ignore_index)pred_label = pred_label[mask]label = label[mask]intersect = pred_label[pred_label == label]area_intersect = torch.histc(intersect.float(), bins=(num_classes), min=0,max=num_classes - 1).cpu()area_pred_label = torch.histc(pred_label.float(), bins=(num_classes), min=0,max=num_classes - 1).cpu()area_label = torch.histc(label.float(), bins=(num_classes), min=0,max=num_classes - 1).cpu()area_union = area_pred_label + area_label - area_intersectreturn area_intersect, area_union, area_pred_label, area_label@staticmethoddef total_area_to_metrics(total_area_intersect: np.ndarray,total_area_union: np.ndarray,total_area_pred_label: np.ndarray,total_area_label: np.ndarray,metrics: List[str] = ['mIoU'],nan_to_num: Optional[int] = None,beta: int = 1):"""Calculate evaluation metricsArgs:total_area_intersect (np.ndarray): The intersection of predictionand ground truth histogram on all classes.total_area_union (np.ndarray): The union of prediction and groundtruth histogram on all classes.total_area_pred_label (np.ndarray): The prediction histogram onall classes.total_area_label (np.ndarray): The ground truth histogram onall classes.metrics (List[str] | str): Metrics to be evaluated, 'mIoU', 'mDice','mFscore', and 'Kappa'.nan_to_num (int, optional): If specified, NaN values will bereplaced by the numbers defined by the user. Default: None.beta (int): Determines the weight of recall in the combined score.Default: 1.Returns:Dict[str, np.ndarray]: per category evaluation metrics,shape (num_classes, )."""def f_score(precision, recall, beta=1):"""calculate the f-score value.Args:precision (float | torch.Tensor): The precision value.recall (float | torch.Tensor): The recall value.beta (int): Determines the weight of recall in the combinedscore. Default: 1.Returns:[torch.tensor]: The f-score value."""score = (1 + beta**2) * (precision * recall) / ((beta**2 * precision) + recall)return scoreif isinstance(metrics, str):metrics = [metrics]allowed_metrics = ['mIoU', 'mDice', 'mFscore', 'mKappa']if not set(metrics).issubset(set(allowed_metrics)):raise KeyError(f'metrics {metrics} is not supported')all_acc = total_area_intersect.sum() / total_area_label.sum()ret_metrics = OrderedDict({'aAcc': all_acc})for metric in metrics:if metric == 'mIoU':iou = total_area_intersect / total_area_unionacc = total_area_intersect / total_area_labelret_metrics['IoU'] = iouret_metrics['Acc'] = accelif metric == 'mDice':dice = 2 * total_area_intersect / (total_area_pred_label + total_area_label)acc = total_area_intersect / total_area_labelret_metrics['Dice'] = diceret_metrics['Acc'] = accelif metric == 'mFscore':precision = total_area_intersect / total_area_pred_labelrecall = total_area_intersect / total_area_labelf_value = torch.tensor([f_score(x[0], x[1], beta) for x in zip(precision, recall)])ret_metrics['Fscore'] = f_valueret_metrics['Precision'] = precisionret_metrics['Recall'] = recallelif metric == 'mKappa':total = total_area_label.sum()po = total_area_intersect / total_area_labelpe = (total_area_pred_label * total_area_label) / (total ** 2)kappa = (po - pe) / (1 - pe)ret_metrics['Kappa'] = kapparet_metrics = {metric: value.numpy() if isinstance(value, torch.Tensor) else valuefor metric, value in ret_metrics.items()}if nan_to_num is not None:ret_metrics = OrderedDict({metric: np.nan_to_num(metric_value, nan=nan_to_num)for metric, metric_value in ret_metrics.items()})return ret_metrics

相关文章:

MMSegmentation改进:增加Kappa系数评价指数

将mmseg\evaluation\metrics\iou_metric.py文件中的内容替换成以下内容即可: 支持输出单类Kappa系数和平均Kappa系数。 使用方法:将dataset的config文件中:val_evaluator 添加mKappa,如 val_evaluator dict(typemmseg.IoUMetri…...

专栏【汇总】

专栏【汇总】 前言版权推荐专栏【汇总】付费 汇总置顶在读在学我的面试计算机重要课程java面试Java基础数据存储Java框架java提高计算机科学与技术课程算法杂项 最后 前言 2024-5-12 21:13:02 以下内容源自《【专栏】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此…...

成功解决IndexError: index 0 is out of bounds for axis 1 with size 0

成功解决IndexError: index 0 is out of bounds for axis 1 with size 0 🛠️ 成功解决IndexError: index 0 is out of bounds for axis 1 with size 0摘要引言正文内容(详细介绍)🤔 错误分析:为什么会发生IndexError&…...

C# MES通信从入门到精通(11)——C#如何使用Json字符串

前言 我们在开发上位机软件的过程中,经常需要和Mes系统进行数据交互,并且最常用的数据格式是Json,本文就是详细介绍Json格式的类型,以及我们在与mes系统进行交互时如何组织Json数据。 1、在C#中如何调用Json 在C#中调用Json相关…...

ON DUPLICATE KEY UPDATE 子句

ON DUPLICATE KEY UPDATE 是 MySQL 中的一个 SQL 语句中的子句,主要用于在执行 INSERT 操作时处理可能出现的重复键值冲突。当尝试插入的记录导致唯一索引或主键约束冲突时(即试图插入的记录的键值已经存在于表中),此子句会触发一…...

perl use HTTP::Server::Simple 轻量级 http server

cpan -i HTTP::Server::Simple 返回:已是 up to date. 但是我在 D:\Strawberry\perl\site\lib\ 找不到 HTTP\Server 手工安装:下载 HTTP-Server-Simple-0.52.tar.gz 解压 tar zxvf HTTP-Server-Simple-0.52.tar.gz cd D:\perl\HTTP-Server-Simple-…...

【STM32】基于I2C协议的OLED显示(利用U82G库)

【STM32】基于I2C协议的OLED显示(利用U82G库) 文章目录 【STM32】基于I2C协议的OLED显示(利用U82G库)一、实验背景二、U8g2介绍(一)获取(二)简介 三、实践(一)CubexMX配置(二)U8g2配…...

掌握Python3输入输出:轻松实现用户交互、日志记录与数据处理

Python 是一门简洁且强大的编程语言,广泛应用于各个领域。在 Python 编程中,输入和输出是基本而重要的操作。无论是进行用户交互、记录日志信息,还是将计算结果输出到控制台或文件,掌握这些操作都是编写高效 Python 程序的关键。本…...

用于每个平台的最佳WordPress LMS主题

你已选择在 WordPress 上构建学习管理系统 (LMS)了。恭喜! 你甚至可能已经选择了要使用的 LMS 插件,这已经是成功的一半了。 现在是时候弄清楚哪个 WordPress LMS 主题要与你的插件配对。 我将解释 LMS 主题和插件之间的区别,以便你了解要…...

pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同 1.标签为long类型,BCE标签为float类型 2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。 参数配置torch.nn.CrossEntropyLoss(weightNone,…...

【iOS】UI——关于UIAlertController类(警告对话框)

目录 前言关于UIAlertController具体操作及代码实现总结 前言 在UI的警告对话框的学习中,我们发现UIAlertView在iOS 9中已经被废弃,我们找到UIAlertController来代替UIAlertView实现弹出框的功能,从而有了这篇关于UIAlertController的学习笔记…...

django支持https

测试环境,可以用django自带的证书 安装模块 sudo pip3 install django_sslserver服务端https启动 python3 manage.py runsslserver 127.0.0.1:8001https访问 https://127.0.0.1:8001/quota/api/XXX...

算法题day41(补5.27日卡:动态规划01)

一、动态规划基础知识:在动态规划中每一个状态一定是由上一个状态推导出来的。 动态规划五部曲: 1.确定dp数组 以及下标的含义 2.确定递推公式 3.dp数组如何初始化 4.确定遍历顺序 5.举例推导dp数组 debug方式:打印 二、刷题&#xf…...

【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo

系列文章目录 【附带源码】机械臂MoveIt2极简教程(一)、moveit2安装 【附带源码】机械臂MoveIt2极简教程(二)、move_group交互 【附带源码】机械臂MoveIt2极简教程(三)、URDF/SRDF介绍 【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo 目录 系列文章目录1. 创…...

基于蚁群算法的二维路径规划算法(matlab)

微♥关注“电击小子程高兴的MATLAB小屋”获得资料 一、理论基础 1、路径规划算法 路径规划算法是指在有障碍物的工作环境中寻找一条从起点到终点、无碰撞地绕过所有障碍物的运动路径。路径规划算法较多,大体上可分为全局路径规划算法和局部路径规划算法两大类。其…...

政务云参考技术架构

行业优势 总体架构 政务云平台技术框架图,由机房环境、基础设施层、支撑软件层及业务应用层组成,在运维、安全和运营体系的保障下,为政务云使用单位提供统一服务支撑。 功能架构 标准双区隔离 参照国家电子政务规范,打造符合标准的…...

android 13 aosp 预置so库

展讯对应的main.mk配置 device/sprd/qogirn**/ums***/product/***_native/main.mk $(call inherit-product-if-exists, vendor/***/build.mk)vendor/***/build.mk PRODUCT_PACKAGES \libtestvendor///Android.bp cc_prebuilt_library_shared{name:"libtest",srcs:…...

mongo篇---mongoDB Compass连接数据库

mongo篇—mongoDB Compass连接数据库 mongoDB笔记 – 第一条 一、mongoDB Compass连接远程数据库,配置URL。 URL: mongodb://username:passwordhost:port点击connect即可。 注意:host最好使用名称,防止出错连接超时。...

基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真,输出收敛曲线以及三维曲面最高点搜索结果。 2.测试软件版本以及运行结果展示 MATLAB2022A版本…...

前端js解析websocket推送的gzip压缩json的Blob数据

主要依赖插件pako https://www.npmjs.com/package/pako 1、安装 npm install pako 2、使用, pako.inflate(reader.result, {to: "string"}) 解压后的string 对象,需要JSON.parse转成json this.ws.onmessage (evt) > {console.log("…...

抖音增长新引擎:品融电商,一站式全案代运营领跑者

抖音增长新引擎:品融电商,一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中,品牌如何破浪前行?自建团队成本高、效果难控;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?

论文网址:pdf 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...

CocosCreator 之 JavaScript/TypeScript和Java的相互交互

引擎版本: 3.8.1 语言: JavaScript/TypeScript、C、Java 环境:Window 参考:Java原生反射机制 您好,我是鹤九日! 回顾 在上篇文章中:CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)

笔记整理:刘治强,浙江大学硕士生,研究方向为知识图谱表示学习,大语言模型 论文链接:http://arxiv.org/abs/2407.16127 发表会议:ISWC 2024 1. 动机 传统的知识图谱补全(KGC)模型通过…...

全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比

目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...

均衡后的SNRSINR

本文主要摘自参考文献中的前两篇,相关文献中经常会出现MIMO检测后的SINR不过一直没有找到相关数学推到过程,其中文献[1]中给出了相关原理在此仅做记录。 1. 系统模型 复信道模型 n t n_t nt​ 根发送天线, n r n_r nr​ 根接收天线的 MIMO 系…...

laravel8+vue3.0+element-plus搭建方法

创建 laravel8 项目 composer create-project --prefer-dist laravel/laravel laravel8 8.* 安装 laravel/ui composer require laravel/ui 修改 package.json 文件 "devDependencies": {"vue/compiler-sfc": "^3.0.7","axios": …...

C++课设:简易日历程序(支持传统节假日 + 二十四节气 + 个人纪念日管理)

名人说:路漫漫其修远兮,吾将上下而求索。—— 屈原《离骚》 创作者:Code_流苏(CSDN)(一个喜欢古诗词和编程的Coder😊) 专栏介绍:《编程项目实战》 目录 一、为什么要开发一个日历程序?1. 深入理解时间算法2. 练习面向对象设计3. 学习数据结构应用二、核心算法深度解析…...

「全栈技术解析」推客小程序系统开发:从架构设计到裂变增长的完整解决方案

在移动互联网营销竞争白热化的当下,推客小程序系统凭借其裂变传播、精准营销等特性,成为企业抢占市场的利器。本文将深度解析推客小程序系统开发的核心技术与实现路径,助力开发者打造具有市场竞争力的营销工具。​ 一、系统核心功能架构&…...

数据结构:递归的种类(Types of Recursion)

目录 尾递归(Tail Recursion) 什么是 Loop(循环)? 复杂度分析 头递归(Head Recursion) 树形递归(Tree Recursion) 线性递归(Linear Recursion)…...