pytorch 加权CE_loss实现(语义分割中的类不平衡使用)
加权CE_loss和BCE_loss稍有不同
1.标签为long类型,BCE标签为float类型
2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。
参数配置torch.nn.CrossEntropyLoss(weight=None,size_average=None,ignore_index=-100,reduce=None,reduction=‘mean’,label_smoothing=0.0)
增加加权的CE_loss代码实现
# 总之, CrossEntropyLoss() = softmax + log + NLLLoss() = log_softmax + NLLLoss(), 具体等价应用如下:
import torch
import torch.nn as nn
import torch.nn.functional as F
import random
import numpy as npclass CrossEntropyLoss2d(nn.Module):def __init__(self, weight=None):super(CrossEntropyLoss2d, self).__init__()self.nll_loss = nn.CrossEntropyLoss(weight, reduction='mean')def forward(self, preds, targets):return self.nll_loss(preds, targets)
语义分割类别计算
class CE_w_loss(nn.Module):def __init__(self,ignore_index=255):super(CE_w_loss, self).__init__()self.ignore_index = ignore_index# self.CE = nn.CrossEntropyLoss(ignore_index=self.ignore_index)def forward(self, outputs, targets):class_num = outputs.shape[1]# print("class_num :",class_num )# # 计算每个类别在整个 batch 中的像素数占比class_pixel_counts = torch.bincount(targets.flatten(), minlength=class_num) # 假设有class_num个类别class_pixel_proportions = class_pixel_counts.float() / torch.numel(targets)# # 根据类别占比计算权重class_weights = 1.0 / (torch.log(1.02 + class_pixel_proportions)).double() # 使用对数变换平衡权重# # print("class_weights :",class_weights)## 定义交叉熵损失函数,并使用动态计算的类别权重criterion = nn.CrossEntropyLoss(ignore_index=self.ignore_index,weight= class_weights)# 计算损失loss = criterion(outputs, targets)print(loss.item()) # 打印损失值return lossnp.random.seed(666)pred = np.ones((2, 5, 256,256))seg = np.ones((2, 5, 256, 256)) # 灰度label = np.ones((2, 256, 256)) # 灰度pred = torch.from_numpy(pred)seg = torch.from_numpy(seg).int() # 灰度label = torch.from_numpy(label).long()ce = CE_w_loss()loss = ce(pred, label)print("loss:",loss.item())
调用库(手动设置权重)
import torch
import torch.nn as nn# 假设有一些模型输出和目标标签
model_output = torch.randn(3, 5) # 假设有5个类别
target = torch.empty(3, dtype=torch.long).random_(5)# 定义权重
weights = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0])# 定义交叉熵损失函数,并设置权重
criterion = nn.CrossEntropyLoss(weight=weights)# 计算损失
loss = criterion(model_output, target)
print(loss)
自适应计算权重
import torch
import torch.nn as nn
import numpy as np# 假设我们有一个包含10个样本的批次,每个样本属于4个类别之一
batch_size = 10
num_classes = 4# 随机生成未经过 softmax 的logits输出(网络的最后一层输出)
logits = torch.randn(batch_size, num_classes, requires_grad=True)# 真实的标签(每个样本的类别索引),例如 [0, 2, 1, 3, 0, 0, 1, 2, 3, 3]
labels = torch.tensor([0, 2, 1, 3, 0, 0, 1, 2, 3, 3])# 统计每个类别的频率
class_counts = torch.bincount(labels, minlength=num_classes).float()# 计算每个类别的权重,权重可以为类别频率的倒数
# 为了防止分母为零,这里加一个小的常数epsilon
epsilon = 1e-6
class_weights = 1.0 / (class_counts + epsilon)# 归一化权重,使其和为1
class_weights /= class_weights.sum()print('Class Counts:', class_counts)
print('Class Weights:', class_weights)# 创建带权重的交叉熵损失函数
criterion = nn.CrossEntropyLoss(weight=class_weights)# 计算损失值
loss = criterion(logits, labels)print('Logits:\n', logits)
print('Labels:\n', labels)
print('Weighted Cross-Entropy Loss:', loss.item())# 反向传播梯度
loss.backward()
报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).float() 报错
Weight=torch.from_numpy(np.array([0.1, 0.8, 1.0, 1.0])).double() 正确
参考:[1]https://blog.csdn.net/CSDN_of_ding/article/details/111515226
[2] https://blog.csdn.net/qq_40306845/article/details/137651442
[3] https://www.zhihu.com/question/400443029/answer/2477658229
相关文章:
pytorch 加权CE_loss实现(语义分割中的类不平衡使用)
加权CE_loss和BCE_loss稍有不同 1.标签为long类型,BCE标签为float类型 2.当reduction为mean时计算每个像素点的损失的平均,BCE除以像素数得到平均值,CE除以像素对应的权重之和得到平均值。 参数配置torch.nn.CrossEntropyLoss(weightNone,…...
【iOS】UI——关于UIAlertController类(警告对话框)
目录 前言关于UIAlertController具体操作及代码实现总结 前言 在UI的警告对话框的学习中,我们发现UIAlertView在iOS 9中已经被废弃,我们找到UIAlertController来代替UIAlertView实现弹出框的功能,从而有了这篇关于UIAlertController的学习笔记…...
django支持https
测试环境,可以用django自带的证书 安装模块 sudo pip3 install django_sslserver服务端https启动 python3 manage.py runsslserver 127.0.0.1:8001https访问 https://127.0.0.1:8001/quota/api/XXX...
算法题day41(补5.27日卡:动态规划01)
一、动态规划基础知识:在动态规划中每一个状态一定是由上一个状态推导出来的。 动态规划五部曲: 1.确定dp数组 以及下标的含义 2.确定递推公式 3.dp数组如何初始化 4.确定遍历顺序 5.举例推导dp数组 debug方式:打印 二、刷题…...
【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo
系列文章目录 【附带源码】机械臂MoveIt2极简教程(一)、moveit2安装 【附带源码】机械臂MoveIt2极简教程(二)、move_group交互 【附带源码】机械臂MoveIt2极简教程(三)、URDF/SRDF介绍 【附带源码】机械臂MoveIt2极简教程(四)、第一个入门demo 目录 系列文章目录1. 创…...
基于蚁群算法的二维路径规划算法(matlab)
微♥关注“电击小子程高兴的MATLAB小屋”获得资料 一、理论基础 1、路径规划算法 路径规划算法是指在有障碍物的工作环境中寻找一条从起点到终点、无碰撞地绕过所有障碍物的运动路径。路径规划算法较多,大体上可分为全局路径规划算法和局部路径规划算法两大类。其…...
政务云参考技术架构
行业优势 总体架构 政务云平台技术框架图,由机房环境、基础设施层、支撑软件层及业务应用层组成,在运维、安全和运营体系的保障下,为政务云使用单位提供统一服务支撑。 功能架构 标准双区隔离 参照国家电子政务规范,打造符合标准的…...
android 13 aosp 预置so库
展讯对应的main.mk配置 device/sprd/qogirn**/ums***/product/***_native/main.mk $(call inherit-product-if-exists, vendor/***/build.mk)vendor/***/build.mk PRODUCT_PACKAGES \libtestvendor///Android.bp cc_prebuilt_library_shared{name:"libtest",srcs:…...
mongo篇---mongoDB Compass连接数据库
mongo篇—mongoDB Compass连接数据库 mongoDB笔记 – 第一条 一、mongoDB Compass连接远程数据库,配置URL。 URL: mongodb://username:passwordhost:port点击connect即可。 注意:host最好使用名称,防止出错连接超时。...
基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真
目录 1.程序功能描述 2.测试软件版本以及运行结果展示 3.核心程序 4.本算法原理 5.完整程序 1.程序功能描述 基于SOA海鸥优化算法的三维曲面最高点搜索matlab仿真,输出收敛曲线以及三维曲面最高点搜索结果。 2.测试软件版本以及运行结果展示 MATLAB2022A版本…...
前端js解析websocket推送的gzip压缩json的Blob数据
主要依赖插件pako https://www.npmjs.com/package/pako 1、安装 npm install pako 2、使用, pako.inflate(reader.result, {to: "string"}) 解压后的string 对象,需要JSON.parse转成json this.ws.onmessage (evt) > {console.log("…...
【wiki知识库】06.文档管理接口的实现--SpringBoot后端部分
目录 一、🔥今日目标 二、🎈SpringBoot部分类的添加 1.调用MybatisGenerator 2.添加DocSaveParam 3.添加DocQueryVo 三、🚆后端新增接口 3.1添加DocController 3.1.1 /all/{ebokId} 3.1.2 /doc/save 3.1.3 /doc/delete/{idStr} …...
c,c++,go语言字符串的演进
#include <stdio.h> #include <string.h> int main() {char str[] {a,b,c,\0,d,d,d};printf("string:[%s], len:%d \n", str, strlen(str) );return 0; } string:[abc], len:3 c语言只有数组的概念,数组本身没有长度的概念,需…...
vue-cli 快速入门
vue-cli (目前向Vite发展) 介绍:Vue-cli 是Vue官方提供一个脚手架,用于快速生成一个Vue的项目模板。 Vue-cli提供了如下功能: 统一的目录结构 本地调试 热部署 单元测试 集成打包上线 依赖环境:NodeJ…...
机器人--矩阵运算
两个矩阵相乘的含义 P点在坐标系B中的坐标系PB,需要乘以B到A到变换矩阵TAB。 M点在B坐标系中的位姿MB,怎么计算M在A中的坐标系? 两个矩阵相乘 一个矩阵*另一个矩阵的逆矩阵...
期末复习【汇总】
期末复习【汇总】 前言版权推荐期末复习【汇总】最后 前言 2024-5-12 20:52:17 截止到今天,所有期末复习的汇总 以下内容源自《【创作模板】》 仅供学习交流使用 版权 禁止其他平台发布时删除以下此话 本文首次发布于CSDN平台 作者是CSDN日星月云 博客主页是ht…...
【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet
【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet 文章目录 【IM即时通讯】MQTT协议的详解(3)- CONNACK Packet前言说明一、固定同步详解、可变头部详解总结 前言 关于所有的类型的数据示例已经在上面一篇博客说完: …...
Linux - 深入理解/proc虚拟文件系统:从基础到高级
文章目录 Linux /proc虚拟文件系统/proc/self使用 /proc/self 的优势/proc/self 的使用案例案例1:获取当前进程的状态信息案例2:获取当前进程的命令行参数案例3:获取当前进程的内存映射案例4:获取当前进程的文件描述符 /proc中进程…...
Django DeleteView视图
Django 的 DeleteView 是一个基于类的视图,用于处理对象的删除操作。 1,添加视图函数 Test/app3/views.py from django.shortcuts import render# Create your views here. from .models import Bookfrom django.views.generic import ListView class B…...
代码杂谈 之 pyspark如何做相似度计算
在 PySpark 中,计算 DataFrame 两列向量的差可以通过使用 UDF(用户自定义函数)和 Vector 类型完成。这里有一个示例,展示了如何使用 PySpark 的 pyspark.ml.linalg.Vectorspyspark.sql.functions.udf 来实现这一功能:…...
混剪素材哪里找?分享8个热门素材网站
今天我们来深入探讨如何获取高质量的混剪素材,为您的短视频和自媒体制作提供最佳资源。在这篇指南中,我将介绍几个热门的素材网站,让您轻松掌握素材获取的技巧,并根据百度SEO排名规则,优化关键词的使用,确保…...
临床应用的深度学习在视网膜疾病的诊断和转诊中的应用| 文献速递-视觉通用模型与疾病诊断
Title 题目 Clinically applicable deep learning for diagnosis and referral in retinal disease 临床应用的深度学习在视网膜疾病的诊断和转诊中的应用 01 文献速递介绍 诊断成像的数量和复杂性正在以比人类专家可用性更快的速度增加。人工智能在分类一些常见疾病的二…...
中继器简介
一、网络信号衰减问题 现在的网路信号有两种,一种是电信号,另一种的光信号,电信号在网线、电话线或者电视闭路线中传输,光信号在光缆中传输,但是不管是以那种信号进行传输,随着传输距离的增加,电…...
websocket 前端项目js示例
websocket前端 和服务端websocket通信示例, 前端直接使用h5的内置对象 WebSocket 来创建和管理 WebSocket 连接,以及可以通过该连接发送和接收数据。 这个对象都是是事件方式来处理和与后端交互数据, 他们分别是 onopen打开, onclose关闭, o…...
webapi跨越问题
由于浏览器存在同源策略,为了防止 钓鱼问题,浏览器直接请求才不会有跨越的问题 浏览器要求JavaScript或Cookie只能访问同域下的内容 浏览器也是一个应用程序,有很多限制,不能访问和使用电脑信息(获取cpu、硬盘等&#…...
你知道 npmrc 文档吗? ---- npmrc 关键作用介绍
你知道 npmrc 文档吗? ---- npmrc 关键作用介绍 你知道 npmrc 文档吗? ---- npmrc 关键作用介绍如何修改配置呢?日常开放常常需要置哪些信息呢?registry 信息配置限定包认证信息代理配置缓存配置安装行为 参考 你知道 npmrc 文档吗…...
发现 Laravel 中的 api 响应时间明显过长
背景 近期在排查网站后台页面功能时 发现,部分查询页面,明显响应时间过长(12秒),不合理 优先排查 接口运行时长 经过打印,发现代码是正常的,且时间仅需不到一秒 进一步怀疑是 VUE框架的渲染加载…...
如何在MySQL中创建不同的索引和用途?
目录 1 基本的 CREATE INDEX 语法 2 创建单列索引 3 创建多列索引 4 创建唯一索引 5 创建全文索引 6 在表创建时添加索引 7 使用 ALTER TABLE 添加索引 8 删除索引 9 索引管理的最佳实践 10 示例 在 MySQL 中,索引(index)是一种用于…...
maxwell同步mysql到kafka(一个服务器启动多个)
创建mysql同步用户 CREATE USER maxwell% IDENTIFIED BY 123456; GRANT ALL ON maxwell.* TO maxwell%; GRANT SELECT, REPLICATION CLIENT, REPLICATION SLAVE on *.* to maxwell%; 开启mysql binlog a.修改 /etc/my.cnf 配置 log-binmysql-bin # 开启binlog binlog-forma…...
实用软件分享---简单菜谱 0.3版本 几千种美食(安卓)
专栏介绍:本专栏主要分享一些实用的软件(Po Jie版); 声明1:软件不保证时效性;只能保证在写本文时,该软件是可用的;不保证后续时间该软件能一直正常运行;不保证没有bug;如果软件不可用了,我知道后会第一时间在题目上注明(已失效)。介意者请勿订阅。 声明2:本专栏的…...
客户对网站设计的要求/网站推广的作用在哪里
随着“数字新基建”以及5G的发展,很多领域的技术都有了更深厚的基础,而数字孪生、智慧工厂、智慧城市、虚拟仿真教学的发展也越来越快。部署方案也从最初的本地部署,到现在的webGL本地网络方式、实时云渲染技术加持等多种方案可选。每种方案各…...
wordpress全屏主题/企业宣传方式有哪些
目录 一、绑定 HTML class 1. 绑定对象 2. 绑定数组 3. 在组件上使用 二、绑定内联样式 1. 绑定对象 2. 绑定数组 3. 自动前缀 4. 样式多值 数据绑定的一个常见需求场景是操纵元素的 CSS class 列表和内联样式。因为 class 和 style 都是 attribute,我们可…...
建一个收费网站/阿森纳英超积分
本文转载自:Spring事务管理知多少?面试时会讲吗?工作时会用吗?来点简单的 Spring事务的其实就是数据库对事务的支持,使用JDBC的事务管理机制,就是利用java.sql.Connection对象完成对事务的提交 事务是指在一系列的数据…...
设计师可以做兼职的网站有哪些/网页设计与制作模板
RestCloud API服务编排平台,更轻量、更高性能的API可视化编排平台,基于微服务架构、快速构建企业服务总线、全面提升敏捷集成能力、每日调度API流程超过100W。 一、真正的高性能服务编排引擎 1、首创基于纯内存的流程调度引擎,是支持高频调度…...
电商店铺设计/关键词seo排名怎么样
数智融合时代,必将唤起思想与技术的嬗变与觉醒!绘蓝图 揭秘“数矩觉醒”归纳起来,数智融合时代,企业用户将面临以下三方面的挑战:如何迎接快速增长的数据洪流?如何有效地提升数据利用率?如何实现…...
大连网站制作推广/常用的网络推广手段有哪些
启动和关闭服务指令 启动:redis-server.exe --service-start 关闭:redis-server.exe --service-stop 下面也可以 redis-server --service-start redis-server --service-stop 卸载服务 指令:redis-server --service-uninstall...