当前位置: 首页 > news >正文

【论文复现】偏标记学习+图像分类

在这里插入图片描述

📝个人主页🌹:Eternity._
🌹🌹期待您的关注 🌹🌹

在这里插入图片描述
在这里插入图片描述

❀ 偏标记学习+图像分类

  • 概述
  • 算法原理
  • 核心逻辑
  • 效果演示
  • 使用方式
  • 参考文献

概述


本文复现论文 Progressive Identification of True Labels for Partial-Label Learning[1] 提出的偏标记学习方法。

随着深度神经网络的发展,机器学习任务对标注数据的需求不断增加。然而,大量的标注数据十分依赖人力资源与标注者的专业知识。弱监督学习可以有效缓解这一问题,因其不需要完全且准确的标注数据。该论文关注一个重要的弱监督学习问题——偏标记学习(Partial Label Learning),其中每个训练实例与一组候选标签相关联,但仅有一个标签是真实的。

在这里插入图片描述
该论文提出了一种渐进式真实标签识别方法,旨在训练过程中逐渐确定样本的真实标签。该论文所提出的方法获得了接近监督学习的性能,且与具体的网络结构、损失函数、随机优化算法无关。

本文所涉及的所有资源的获取方式:这里

算法原理


传统的监督学习常用交叉熵损失和随机梯度下降来优化深度神经网络。交叉熵损失定义如下:
在这里插入图片描述
其中, x 表示样本特征; [ y = [ y 1 , y 2 , … , y c ] ] [ \mathbf{y} = [y_1, y_2, \ldots, y_c] ] [y=[y1,y2,,yc]]表示样本标签,其为独热码,即除了真实标签对应维度值为 1,其余为零; [ f i ( x ; θ ) ] [ f_i(x; \theta) ] [fi(x;θ)]表示模型预测样本 x 标签为 i 的概率。

该论文提出的方法使用一个软标签 [ y ^ = [ y ^ 1 , y ^ 2 , … , y ^ c ] ] [ \hat{y} = [\hat{y}_1, \hat{y}_2, \ldots, \hat{y}_c] ] [y^=[y^1,y^2,,y^c]],其对任意 [ i ∈ [ 0 , c ] ] [ i \in [0, c] ] [i[0,c]]满足 [ ∑ i y ^ i = 1 且 0 ≤ y ^ i ≤ 1 ] [ \sum_{i} \hat{y}_i = 1 \quad \text{且} \quad 0 \leq \hat{y}_i \leq 1 ] [iy^i=10y^i1]为了使用该软标签,论文根据候选标签集 s 对软标签进行初始化:
在这里插入图片描述

为了渐进式地识别真实标签,算法在每次更新参数之前,根据预测结果为下轮训练使用的软标签赋值:
在这里插入图片描述
其中, [ I ( j ∈ s ) = { 1 当且仅当  j ∈ s 为真 0 否则 ] [ I(j \in s) = \begin{cases} 1 & \text{当且仅当 } j \in s \text{ 为真} \\ 0 & \text{否则} \end{cases} ] [I(js)={10当且仅当 js 为真否则]

核心逻辑


具体的核心逻辑如下所示:

import models
import datasets
import torch
from torch.utils.data import DataLoader
import numpy as np
import torch.optim as optim
from torch.optim.lr_scheduler import StepLR
import torch.nn.functional as F
import torchvision.transforms as transforms
from tqdm import tqdmdef CE_loss(probs, targets):"""交叉熵损失函数"""loss = -torch.sum(targets * torch.log(probs), dim = -1)loss_avg = torch.sum(loss)/probs.shape[0]return loss_avgclass Proden:def __init__(self, configs):self.configs = configsdef train(self, save = False):configs = self.configs# 读取数据集dataset_path = configs['dataset path']if configs['dataset'] == 'CIFAR-10':train_data, train_labels, test_data, test_labels = datasets.cifar10_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 10elif configs['dataset'] == 'CIFAR-100':train_data, train_labels, test_data, test_labels = datasets.cifar100_read(dataset_path)train_dataset = datasets.Cifar(train_data, train_labels)test_dataset = datasets.Cifar(test_data, test_labels)output_dimension = 100# 生成偏标记partial_labels = datasets.generate_partial_labels(train_labels, configs['partial rate'])train_dataset.load_partial_labels(partial_labels)# 计算数据的均值和方差,用于模型输入的标准化mean = [np.mean(train_data[:, i, :, :]) for i in range(3)]std = [np.std(train_data[:, i, :, :]) for i in range(3)]normalize = transforms.Normalize(mean, std)# 设备:GPU或CPUdevice = torch.device('cuda' if torch.cuda.is_available() else 'cpu')# 加载模型if configs['model'] == 'ResNet18':model = models.ResNet18(output_dimension = output_dimension).to(device)elif configs['model'] == 'ConvNet':model = models.ConvNet(output_dimension = output_dimension).to(device)# 设置学习率等超参数lr = configs['learning rate']weight_decay = configs['weight decay']momentum = configs['momentum']optimizer = optim.SGD(model.parameters(), lr = lr, weight_decay = weight_decay, momentum = momentum)lr_step = configs['learning rate decay step']lr_decay = configs['learning rate decay rate']lr_scheduler = StepLR(optimizer, step_size=lr_step, gamma=lr_decay)for epoch_id in range(configs['epoch count']):# 训练模型train_dataloader = DataLoader(train_dataset, batch_size = configs['batch size'], shuffle = True)model.train()for batch in tqdm(train_dataloader, desc='Training(Epoch %d)' % epoch_id, ascii=' 123456789#'):ids = batch['ids']# 标准化输入data = normalize(batch['data'].to(device))partial_labels = batch['partial_labels'].to(device)targets = batch['targets'].to(device)optimizer.zero_grad()# 计算预测概率logits = model(data)probs = F.softmax(logits, dim=-1)# 更新软标签with torch.no_grad():new_targets = F.normalize(probs * partial_labels, p=1, dim=-1)train_dataset.targets[ids] = new_targets.cpu().numpy()# 计算交叉熵损失loss = CE_loss(probs, targets)loss.backward()# 更新模型参数optimizer.step()# 调整学习率lr_scheduler.step()

以上代码仅作展示,更详细的代码文件请参见附件。

效果演示


我提前在 CIFAR-10[2] 数据集和 12 层的 ConvNet[3] 网络上训练了一份模型参数。为了测试其准确率,需要配置环境并运行main.py脚本,得到结果如下:
在这里插入图片描述
由图可见,该算法在测试集上获得了 89.8% 的准确率。

进一步地,测试训练出的模型在真实图片上的预测结果。在线部署模型后,将一张轮船的图片输入,可以得到输出的预测类型为 “Ship”:
在这里插入图片描述
在这里插入图片描述
在这里插入图片描述
网站提供了在线演示功能,使用者请输入一张小于1MB、类别为上述十个类别之一、长宽尽可能相等的JPG图像。

使用方式


解压附件压缩包并进入工作目录。如果是Linux系统,请使用如下命令:

unzip Proden-implemention.zip
cd Proden-implemention

代码的运行环境可通过如下命令进行配置:

pip install -r requirements.txt

运行如下命令以下载并解压数据集

bash download.sh

如果希望在本地训练模型,请运行如下命令:

python main.py -c [你的配置文件路径] -r [选择下者之一:"train""test""infer"]

如果希望在线部署,请运行如下命令:

python main-flask.py

参考文献


[1] Lv J, Xu M, Feng L, et al. Progressive identification of true labels for partial-label learning[C]//International conference on machine learning. PMLR, 2020: 6500-6510.

[2] Krizhevsky A, Hinton G. Learning multiple layers of features from tiny images[J]. 2009.

[3] Laine S, Aila T. Temporal ensembling for semi-supervised learning[J]. arXiv preprint arXiv:1610.02242, 2016.


编程未来,从这里启航!解锁无限创意,让每一行代码都成为你通往成功的阶梯,帮助更多人欣赏与学习!

更多内容详见:这里

相关文章:

【论文复现】偏标记学习+图像分类

📝个人主页🌹:Eternity._ 🌹🌹期待您的关注 🌹🌹 ❀ 偏标记学习图像分类 概述算法原理核心逻辑效果演示使用方式参考文献 概述 本文复现论文 Progressive Identification of True Labels for Pa…...

C嘎嘎探索篇:栈与队列的交响:C++中的结构艺术

C嘎嘎探索篇:栈与队列的交响:C中的结构艺术 前言: 小编在之前刚完成了C中栈和队列(stack和queue)的讲解,忘记的小伙伴可以去我上一篇文章看一眼的,今天小编将会带领大家吹奏栈和队列的交响&am…...

AIGC-----AIGC在虚拟现实中的应用前景

AIGC在虚拟现实中的应用前景 引言 随着人工智能生成内容(AIGC)的快速发展,虚拟现实(VR)技术的应用也迎来了新的契机。AIGC与VR的结合为创造沉浸式体验带来了全新的可能性,这种组合不仅极大地降低了VR内容的…...

Django 路由层

1. 路由基础概念 URLconf (URL 配置):Django 的路由系统是基于 urls.py 文件定义的。路径匹配:通过模式匹配 URL,并将请求传递给对应的视图处理函数。命名路由:每个路由可以定义一个名称,用于反向解析。 2. 基本路由配…...

《硬件架构的艺术》笔记(八):消抖技术

简介 在电子设备中两个金属触点随着触点的断开闭合便产生了多个信号,这就是抖动。 消抖是用来确保每一次断开或闭合触点时只有一个信号起作用的硬件设备或软件。(就是每次断开闭合只对应一个操作)。 抖动在某些模拟和逻辑电路中可能产生问…...

Spring 与 Spring MVC 与 Spring Boot三者之间的区别与联系

一.什么是Spring?它解决了什么问题? 1.1什么是Spring? Spring,一般指代的是Spring Framework 它是一个开源的应用程序框架,提供了一个简易的开发方式,通过这种开发方式,将避免那些可能致使代码…...

【算法】连通块问题(C/C++)

目录 连通块问题 解决思路 步骤: 初始化: DFS函数: 复杂度分析 代码实现(C) 题目链接:2060. 奶牛选美 - AcWing题库 解题思路: AC代码: 题目链接:687. 扫雷 -…...

如何选择黑白相机和彩色相机

我们在选择成像解决方案时黑白相机很容易被忽略,因为许多新相机提供鲜艳的颜色,鲜明的对比度和改进的弱光性能。然而,有许多应用,选择黑白相机将是更好的选择,因为他们产生更清晰的图像,更好的分辨率&#…...

Rust 力扣 - 740. 删除并获得点数

文章目录 题目描述题解思路题解代码题目链接 题目描述 题解思路 首先对于这题我们如果将所有点数装入一个切片f中,该切片f中的i号下标表示所有点数为i的点数之和 那么这题就转换成了打家劫舍这道题,也就是求选择了切片中某个下标的元素后,该…...

OpenCV从入门到精通实战(七)——探索图像处理:自定义滤波与OpenCV卷积核

本文主要介绍如何使用Python和OpenCV库通过卷积操作来应用不同的图像滤波效果。主要分为几个步骤:图像的读取与处理、自定义卷积函数的实现、不同卷积核的应用,以及结果的展示。 卷积 在图像处理中,卷积是一种重要的操作,它通过…...

Docker核心概念总结

本文只是对 Docker 的概念做了较为详细的介绍,并不涉及一些像 Docker 环境的安装以及 Docker 的一些常见操作和命令。 容器介绍 Docker 是世界领先的软件容器平台,所以想要搞懂 Docker 的概念我们必须先从容器开始说起。 什么是容器? 先来看看容器较为…...

环形缓冲区

什么是环形缓冲区 环形缓冲区,也称为循环缓冲区或环形队列,是一种特殊的FIFO(先进先出)数据结构。它使用一块固定大小的内存空间来缓存数据,并通过两个指针(读指针和写指针)来管理数据的读写。当任意一个指针到达缓冲区末尾时,会自动回绕到缓冲区开头,形成一个"环"。…...

jQuery-Word-Export 使用记录及完整修正文件下载 jquery.wordexport.js

参考资料: jQuery-Word-Export导出word_jquery.wordexport.js下载-CSDN博客 近期又需要自己做个 Html2Doc 的解决方案,因为客户又不想要 Html2pdf 的下载了,当初还给我费尽心思解决Html转pdf时中文输出的问题(html转pdf文件下载之…...

云服务器部署WebSocket项目

WebSocket是一种在单个TCP连接上进行全双工通信的协议,其设计的目的是在Web浏览器和Web服务器之间进行实时通信(实时Web) WebSocket协议的优点包括: 1. 更高效的网络利用率:与HTTP相比,WebSocket的握手只…...

C#+数据库 实现动态权限设置

将权限信息存储在数据库中,支持动态调整。根据用户所属的角色、特定的功能模块,动态加载权限” 1. 数据库设计 根据这种需求,可以通过以下表设计: 用户表 (Users):存储用户信息。角色表 (Roles):存储角色…...

(原创)Android Studio新老界面UI切换及老版本下载地址

前言 这两天下载了一个新版的Android Studio,发现整个界面都发生了很大改动: 新的界面的一些设置可参考一些博客: Android Studio新版UI常用设置 但是对于一些急着开发的小伙伴来说,没有时间去适应,那么怎么办呢&am…...

Ubuntu24虚拟机-gnome-boxes

推荐使用gnome-boxes, virtualbox构建失败,multipass需要开启防火墙 sudo apt install gnome-boxes创建完毕~...

k8s rainbond centos7/win10 -20241124

参考 https://www.rainbond.com/ 国内一站式云原生平台 对centos7环境支持不太行 [lighthouseVM-16-5-centos ~]$ curl -o install.sh https://get.rainbond.com && bash ./install.sh 2024-11-24 09:56:57 ERROR: Ops! Docker daemon is not running. Start docke…...

SpringBoot+Vue滑雪社区网站设计与实现

【1】系统介绍 研究背景 随着互联网技术的快速发展和冰雪运动的普及,滑雪作为一种受欢迎的冬季运动项目,吸引了越来越多的爱好者。与此同时,社交媒体和在线社区平台的兴起为滑雪爱好者提供了一个交流经验、分享心得、获取信息的重要渠道。滑…...

MySql.2

sql查询语句执行过程 SQL 查询语句的执行过程是一个复杂的过程,涉及多个步骤。以下是典型的关系数据库管理系统 (RDBMS) 中 SQL 查询语句的执行过程概述: 1. ‌客户端发送查询‌ 用户通过 SQL 客户端或应用程序发送 SQL 查询语句给数据库服务器。 2. ‌…...

阿里云ACP云计算备考笔记 (5)——弹性伸缩

目录 第一章 概述 第二章 弹性伸缩简介 1、弹性伸缩 2、垂直伸缩 3、优势 4、应用场景 ① 无规律的业务量波动 ② 有规律的业务量波动 ③ 无明显业务量波动 ④ 混合型业务 ⑤ 消息通知 ⑥ 生命周期挂钩 ⑦ 自定义方式 ⑧ 滚的升级 5、使用限制 第三章 主要定义 …...

FFmpeg 低延迟同屏方案

引言 在实时互动需求激增的当下,无论是在线教育中的师生同屏演示、远程办公的屏幕共享协作,还是游戏直播的画面实时传输,低延迟同屏已成为保障用户体验的核心指标。FFmpeg 作为一款功能强大的多媒体框架,凭借其灵活的编解码、数据…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院挂号小程序

一、开发准备 ​​环境搭建​​: 安装DevEco Studio 3.0或更高版本配置HarmonyOS SDK申请开发者账号 ​​项目创建​​: File > New > Create Project > Application (选择"Empty Ability") 二、核心功能实现 1. 医院科室展示 /…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块,它提供了一个轻量级的 HTTP 服务器实现,主要用于构建基于 HTTP 的应用程序和服务。 功能介绍: 主要功能 HTTP服务器功能: 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

【单片机期末】单片机系统设计

主要内容:系统状态机,系统时基,系统需求分析,系统构建,系统状态流图 一、题目要求 二、绘制系统状态流图 题目:根据上述描述绘制系统状态流图,注明状态转移条件及方向。 三、利用定时器产生时…...

C# 类和继承(抽象类)

抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...

C++.OpenGL (10/64)基础光照(Basic Lighting)

基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释

以Module Federation 插件详为例,Webpack.config.js它可能的配置和含义如下: 前言 Module Federation 的Webpack.config.js核心配置包括: name filename(定义应用标识) remotes(引用远程模块&#xff0…...

es6+和css3新增的特性有哪些

一:ECMAScript 新特性(ES6) ES6 (2015) - 革命性更新 1,记住的方法,从一个方法里面用到了哪些技术 1,let /const块级作用域声明2,**默认参数**:函数参数可以设置默认值。3&#x…...