当前位置: 首页 > news >正文

pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同

1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
在这里插入图片描述

参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)

增加加权的CE_loss代码实现

# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)

语义分割类别计算

class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num)  # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double()  # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item())  # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256))  # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int()  # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())

调用库(手动设置权重)

import torch
import torch.nn as nn# 假设有一些模型输出和目标标签
model_output = torch.randn(3, 5)  # 假设有5个类别
target = torch.empty(3, dtype=torch.long).random_(5)# 定义权重
weights = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])# 定义交叉熵损失函数,并设置权重
criterion = nn.CrossEntropyLoss(weight=weights)# 计算损失
loss = criterion(model_output, target)
print(loss)

自适应计算权重

import torch
import torch.nn as nn
import numpy as np# 假设我们有一个包含10个样本的批次,每个样本属于4个类别之一
batch_size = 10
num_classes = 4# 随机生成未经过 softmax 的logits输出(网络的最后一层输出)
logits = torch.randn(batch_size, num_classes, requires_grad=True)# 真实的标签(每个样本的类别索引),例如 [0, 2, 1, 3, 0, 0, 1, 2, 3, 3]
labels = torch.tensor([0, 2, 1, 3, 0, 0, 1, 2, 3, 3])# 统计每个类别的频率
class_counts = torch.bincount(labels, minlength=num_classes).float()# 计算每个类别的权重,权重可以为类别频率的倒数
# 为了防止分母为零,这里加一个小的常数epsilon
epsilon = 1e-6
class_weights = 1.0 / (class_counts + epsilon)# 归一化权重,使其和为1
class_weights /= class_weights.sum()print('Class Counts:', class_counts)
print('Class Weights:', class_weights)# 创建带权重的交叉熵损失函数
criterion = nn.CrossEntropyLoss(weight=class_weights)# 计算损失值
loss = criterion(logits, labels)print('Logits:\n', logits)
print('Labels:\n', labels)
print('Weighted Cross-Entropy Loss:', loss.item())# 反向传播梯度
loss.backward()

报错

Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确

参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229

相关文章:

pytorch 加权CE_loss实现(语义分割中的类不平衡使用)

加权CE_loss和BCE_loss稍有不同 1.标签为long类型,BCE标签为float类型 2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。 参数配置torch.nn.CrossEntropyLoss(weightNone,…...

【iOS】UI——关于UIAlertController类(警告对话框)

目录 前言关于UIAlertController具体操作及代码实现总结 前言 在UI的警告对话框的学习中,我们发现UIAlertView在iOS 9中已经被废弃,我们找到UIAlertController来代替UIAlertView实现弹出框的功能,从而有了这篇关于UIAlertController的学习笔记…...

django支持https

测试环境,可以用django自带的证书 安装模块 sudo pip3 install django_sslserver服务端https启动 python3 manage.py runsslserver 127.0.0.1:8001https访问 https://127.0.0.1:8001/quota/api/XXX...

算法题day41(补5.27日卡:动态规划01)

一、动态规划基础知识:在动态规划中每一个状态一定是由上一个状态推导出来的。 动态规划五部曲: 1.确定dp数组 以及下标的含义 2.确定递推公式 3.dp数组如何初始化 4.确定遍历顺序 5.举例推导dp数组 debug方式:打印 二、刷题&#xf…...

【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo

系列文章目录 【附带源码】机械臂MoveIt2极简教程(一)、moveit2安装 【附带源码】机械臂MoveIt2极简教程(二)、move_group交互 【附带源码】机械臂MoveIt2极简教程(三)、URDF/SRDF介绍 【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo 目录 系列文章目录1. 创…...

基于蚁群算法的二维路径规划算法(matlab)

微♥关注“电击小子程高兴的MATLAB小屋”获得资料 一、理论基础 1、路径规划算法 路径规划算法是指在有障碍物的工作环境中寻找一条从起点到终点、无碰撞地绕过所有障碍物的运动路径。路径规划算法较多,大体上可分为全局路径规划算法和局部路径规划算法两大类。其…...

政务云参考技术架构

行业优势 总体架构 政务云平台技术框架图,由机房环境、基础设施层、支撑软件层及业务应用层组成,在运维、安全和运营体系的保障下,为政务云使用单位提供统一服务支撑。 功能架构 标准双区隔离 参照国家电子政务规范,打造符合标准的…...

android 13 aosp 预置so库

展讯对应的main.mk配置 device/sprd/qogirn**/ums***/product/***_native/main.mk $(call inherit-product-if-exists, vendor/***/build.mk)vendor/***/build.mk PRODUCT_PACKAGES \libtestvendor///Android.bp cc_prebuilt_library_shared{name:"libtest",srcs:…...

mongo篇---mongoDB Compass连接数据库

mongo篇—mongoDB Compass连接数据库 mongoDB笔记 – 第一条 一、mongoDB Compass连接远程数据库,配置URL。 URL: mongodb://username:passwordhost:port点击connect即可。 注意:host最好使用名称,防止出错连接超时。...

基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真

目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真,输出收敛曲线以及三维曲面最高点搜索结果。 2.测试软件版本以及运行结果展示 MATLAB2022A版本…...

前端js解析websocket推送的gzip压缩json的Blob数据

主要依赖插件pako https://www.npmjs.com/package/pako 1、安装 npm install pako 2、使用, pako.inflate(reader.result, {to: "string"}) 解压后的string 对象,需要JSON.parse转成json this.ws.onmessage (evt) > {console.log("…...

【wiki知识库】06.文档管理接口的实现--SpringBoot后端部分

目录 一、🔥今日目标 二、🎈SpringBoot部分类的添加 1.调用MybatisGenerator 2.添加DocSaveParam 3.添加DocQueryVo 三、🚆后端新增接口 3.1添加DocController 3.1.1 /all/{ebokId} 3.1.2 /doc/save 3.1.3 /doc/delete/{idStr} …...

c,c++,go语言字符串的演进

#include <stdio.h> #include <string.h> int main() {char str[] {a,b,c,\0,d,d,d};printf("string:[%s], len:%d \n", str, strlen(str) );return 0; } string:[abc], len:3 c语言只有数组的概念&#xff0c;数组本身没有长度的概念&#xff0c;需…...

vue-cli 快速入门

vue-cli &#xff08;目前向Vite发展&#xff09; 介绍&#xff1a;Vue-cli 是Vue官方提供一个脚手架&#xff0c;用于快速生成一个Vue的项目模板。 Vue-cli提供了如下功能&#xff1a; 统一的目录结构 本地调试 热部署 单元测试 集成打包上线 依赖环境&#xff1a;NodeJ…...

机器人--矩阵运算

两个矩阵相乘的含义 P点在坐标系B中的坐标系PB&#xff0c;需要乘以B到A到变换矩阵TAB。 M点在B坐标系中的位姿MB&#xff0c;怎么计算M在A中的坐标系&#xff1f; 两个矩阵相乘 一个矩阵*另一个矩阵的逆矩阵...

期末复习【汇总】

期末复习【汇总】 前言版权推荐期末复习【汇总】最后 前言 2024-5-12 20:52:17 截止到今天&#xff0c;所有期末复习的汇总 以下内容源自《【创作模板】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作者是CSDN日星月云 博客主页是ht…...

【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet

【IM即时通讯】MQTT协议的详解&#xff08;3&#xff09;- CONNACK Packet 文章目录 【IM即时通讯】MQTT协议的详解&#xff08;3&#xff09;- CONNACK Packet前言说明一、固定同步详解、可变头部详解总结 前言 关于所有的类型的数据示例已经在上面一篇博客说完&#xff1a; …...

Linux - 深入理解/proc虚拟文件系统:从基础到高级

文章目录 Linux /proc虚拟文件系统/proc/self使用 /proc/self 的优势/proc/self 的使用案例案例1&#xff1a;获取当前进程的状态信息案例2&#xff1a;获取当前进程的命令行参数案例3&#xff1a;获取当前进程的内存映射案例4&#xff1a;获取当前进程的文件描述符 /proc中进程…...

Django DeleteView视图

Django 的 DeleteView 是一个基于类的视图&#xff0c;用于处理对象的删除操作。 1&#xff0c;添加视图函数 Test/app3/views.py from django.shortcuts import render# Create your views here. from .models import Bookfrom django.views.generic import ListView class B…...

代码杂谈 之 pyspark如何做相似度计算

在 PySpark 中&#xff0c;计算 DataFrame 两列向量的差可以通过使用 UDF&#xff08;用户自定义函数&#xff09;和 Vector 类型完成。这里有一个示例&#xff0c;展示了如何使用 PySpark 的 pyspark.ml.linalg.Vectorspyspark.sql.functions.udf 来实现这一功能&#xff1a…...

C++:std::is_convertible

C++标志库中提供is_convertible,可以测试一种类型是否可以转换为另一只类型: template <class From, class To> struct is_convertible; 使用举例: #include <iostream> #include <string>using namespace std;struct A { }; struct B : A { };int main…...

工业安全零事故的智能守护者:一体化AI智能安防平台

前言&#xff1a; 通过AI视觉技术&#xff0c;为船厂提供全面的安全监控解决方案&#xff0c;涵盖交通违规检测、起重机轨道安全、非法入侵检测、盗窃防范、安全规范执行监控等多个方面&#xff0c;能够实现对应负责人反馈机制&#xff0c;并最终实现数据的统计报表。提升船厂…...

安宝特方案丨XRSOP人员作业标准化管理平台:AR智慧点检验收套件

在选煤厂、化工厂、钢铁厂等过程生产型企业&#xff0c;其生产设备的运行效率和非计划停机对工业制造效益有较大影响。 随着企业自动化和智能化建设的推进&#xff0c;需提前预防假检、错检、漏检&#xff0c;推动智慧生产运维系统数据的流动和现场赋能应用。同时&#xff0c;…...

Opencv中的addweighted函数

一.addweighted函数作用 addweighted&#xff08;&#xff09;是OpenCV库中用于图像处理的函数&#xff0c;主要功能是将两个输入图像&#xff08;尺寸和类型相同&#xff09;按照指定的权重进行加权叠加&#xff08;图像融合&#xff09;&#xff0c;并添加一个标量值&#x…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代&#xff0c;智能代理&#xff08;agents&#xff09;不再是孤立的个体&#xff0c;而是能够像一个数字团队一样协作。然而&#xff0c;当前 AI 生态系统的碎片化阻碍了这一愿景的实现&#xff0c;导致了“AI 巴别塔问题”——不同代理之间…...

论文浅尝 | 基于判别指令微调生成式大语言模型的知识图谱补全方法(ISWC2024)

笔记整理&#xff1a;刘治强&#xff0c;浙江大学硕士生&#xff0c;研究方向为知识图谱表示学习&#xff0c;大语言模型 论文链接&#xff1a;http://arxiv.org/abs/2407.16127 发表会议&#xff1a;ISWC 2024 1. 动机 传统的知识图谱补全&#xff08;KGC&#xff09;模型通过…...

NLP学习路线图(二十三):长短期记忆网络(LSTM)

在自然语言处理(NLP)领域,我们时刻面临着处理序列数据的核心挑战。无论是理解句子的结构、分析文本的情感,还是实现语言的翻译,都需要模型能够捕捉词语之间依时序产生的复杂依赖关系。传统的神经网络结构在处理这种序列依赖时显得力不从心,而循环神经网络(RNN) 曾被视为…...

大学生职业发展与就业创业指导教学评价

这里是引用 作为软工2203/2204班的学生&#xff0c;我们非常感谢您在《大学生职业发展与就业创业指导》课程中的悉心教导。这门课程对我们即将面临实习和就业的工科学生来说至关重要&#xff0c;而您认真负责的教学态度&#xff0c;让课程的每一部分都充满了实用价值。 尤其让我…...

蓝桥杯3498 01串的熵

问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798&#xff0c; 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...

初探Service服务发现机制

1.Service简介 Service是将运行在一组Pod上的应用程序发布为网络服务的抽象方法。 主要功能&#xff1a;服务发现和负载均衡。 Service类型的包括ClusterIP类型、NodePort类型、LoadBalancer类型、ExternalName类型 2.Endpoints简介 Endpoints是一种Kubernetes资源&#xf…...