当前位置: 首页 > news >正文

Channel-wise Knowledge Distillation for Dense Prediction(ICCV 2021)原理与代码解析

paper:Channel-wise Knowledge Distillation for Dense Prediction

official implementation:https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill 

摘要 

之前大多数用于密集预测dense prediction任务的蒸馏方法在空间域spatial domain中将教师和学生网络的激活图进行对齐,通过标准化normalize每个空间位置的激活值,并减小point-wise and/or pair-wise之间的差异来实现知识的传递。

和之前的方法不同,本文提出对每个通道的激活图进行标准化从而得到一个soft probabality map,通过减小两个网络channel-wise概率图之间的KL散度,使蒸馏过程更关注每个通道最显著的区域,这对密集预测任务非常有价值。

背景

密集预测任务是像素级的预测问题,比图像级的分类问题更具挑战性。以往的研究发现,直接将分类中的蒸馏方法应用到语义分割中得到的效果令人无法满意。严格地对齐教师和学生网络之间poit-wise分类得分或特征图可能会施加过于严格的约束,只能得到次优解。

最近的一些研究主要关注于加强不同空间位置之间的相关性。如图2(a)所示,每个空间位置上的激活值被标准化然后通过聚合不同空间位置的子集来得到一些特定任务的关系,比如pair-wise relations和inter-class relations。这些方法在捕获空间结构信息和提高学生网络的性能方面可能比point-wise对齐效果更好。然而,激活图中的每个空间位置对知识转移的贡献相同,这可能会从教师网络带来冗余的信息。 

本文提出了一种新的通道级channel-wise知识蒸馏方法,通过对每个通道中的激活图进行标准化,用于密集预测任务,如图2(b)所示。然后减小教师和学生网络标准化后的通道激活图之间的KL散度。一个channel-wise distribution的例子如图2(c)所示,可以看出每个通道的激活图倾向于编码特定场景类别的显著区域。对于每个通道,引导学生网络更加关注于模拟激活值大的区域,从而在密集预测任务中实现更准确的定位。比如在目标检测任务中,让学生网络更关注于学习前景区域的激活。

本文的贡献

  • 与现有的spatial蒸馏方法不同,本文提出了一种新的通道蒸馏范式用于密集预测任务,方法简单有效。
  • 在语义分割和目标检测方面,本文提出的通道级蒸馏方法显著优于最先进的KD方法。
  • 我们在语义分割和目标检测任务上的四个基准数据集上进行了一致的改进,证明了我们的方法是通用的。鉴于它的简单性和有效性,我们相信我们的方法可以作为一个strong baseline KD方法用于密集预测任务。

方法介绍

为了更好的利用每个通道中的知识,作者提出softly对齐教师网络和学生网络对应通道的激活。为此,首先将一个通道的激活转换为概率分布,这样就可以使用一个概率距离度量例如KL散度来衡量差异。如图2(c)所示,不同通道的激活倾向于编码输入图像中某个特定类别场景的显著saliency区域。此外,一个训练好的语义分割教师模型在每个通道显示出了清晰的类别特定掩码的激活图,这是符合预期的,如图1右侧所示。因此作者提出了一种新的通道蒸馏范式来指导学生从一个训练有素的教师那里学习知识。

首先定义教师网络和学生网络分别为 \(T\) 和 \(S\),\(T\) 和 \(S\) 的激活分别表示为 \(y^{T}\) 和 \(y^{S}\),通道蒸馏损失的一般表示形式如下

其中 \(\phi(\cdot)\) 用于将激活值转换为概率分布,如下

其中 \(c=1,2,...,C\) 表示通道,\(i\) 示一个通道的spatial location的索引,\(\mathcal{T}\) 是温度超参。如果我们使用更大的 \(\mathcal{T}\),概率分布会变得更softer,意味着每个通道中关注的spatial region更加wider。通过使用softmax归一化,消除了大网络和小网络之间尺度的差异。如果教师网络和学生网络的通道不匹配,则使用1x1卷积上采样小网络的通道数使两者相等。\(\varphi(\cdot)\) 评估教师模型和学生模型通道分布之间的差异,具体使用KL散度

KL散度是一种非对称度量。从式4可以看出,如果 \(\phi(y^{T}_{c,i})\) 很大,\(\phi(y^{S}_{c,i})\) 应该像 \(\phi(y^{T}_{c,i})\) 一样大从而减小KL散度。相反,如果 \(\phi(y^{T}_{c,i})\) 非常小,KL散度对减小 \(\phi(y^{S}_{c,i})\) 的关注相对较少。通过教师网络的监督,学生网络倾向于在前景显著区域产生和教师网络相似的激活分布,而教师网络背景区域中的激活对学生网络的影响较小。作者认为KL这种不对称性有助于密集预测任务中蒸馏的学习。

和分类任务中的channel-wise蒸馏的区别

在Channel Distillation: Channel-Wise Attention for Knowledge Distillation这篇文章中,作者也提出了使用通道蒸馏的方式,但主要应用于分类任务。受SENet的启发,通过全局平局池化将一个通道的特征图转换为一个标量,然后应用KL散度衡量教师网络和学生网络对应通道的标量之间的差异。

而本文主要考虑的是密集预测任务,GAP或许对image-level的分类任务有帮助,但所有空间位置的权重相同,丢失了空间信息,不适用于密集预测任务。本文通过softmax标准化的方式,考虑到了不同空间位置重要性的不同,保留了空间位置信息,因此更适用于密集预测任务。

实验结果

表2是本文提出的channel蒸馏和其它spatial蒸馏在Cityscapes数据集上复杂度和验证集上的mIoU的对比,可以看出通道蒸馏的精度最高,且复杂度较小

表5是在Cityscapes数据集上,不同的学生模型用不同的蒸馏方法的精度对比,可以看出,对于不同结构的学生网络,本文提出的通道蒸馏的效果都要好于其它蒸馏方法。

 

表6是在目标检测任务上与其它蒸馏方法的对比,可以看出,在两阶段、单阶段、anchor-free不同结构的目标检测模型中,本文提出的通道蒸馏的效果都是最好的。

代码解析

官方实现如下,其中归一化采用channel_norm损失采用KL损失是论文中给出的方法,官方实现中还给出了其它的归一化方法和损失函数的选择。 

import torch.nn as nnclass ChannelNorm(nn.Module):def __init__(self):super(ChannelNorm, self).__init__()def forward(self, featmap):n, c, h, w = featmap.shapefeatmap = featmap.reshape((n, c, -1))featmap = featmap.softmax(dim=-1)return featmapclass CriterionCWD(nn.Module):def __init__(self, norm_type='none', divergence='mse', temperature=1.0):super(CriterionCWD, self).__init__()# define normalize functionif norm_type == 'channel':self.normalize = ChannelNorm()elif norm_type == 'spatial':self.normalize = nn.Softmax(dim=1)elif norm_type == 'channel_mean':self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)else:self.normalize = Noneself.norm_type = norm_typeself.temperature = 1.0# define loss functionif divergence == 'mse':self.criterion = nn.MSELoss(reduction='sum')elif divergence == 'kl':self.criterion = nn.KLDivLoss(reduction='sum')self.temperature = temperatureself.divergence = divergencedef forward(self, preds_S, preds_T):n, c, h, w = preds_S.shape# import pdb;pdb.set_trace()if self.normalize is not None:norm_s = self.normalize(preds_S / self.temperature)norm_t = self.normalize(preds_T.detach() / self.temperature)else:norm_s = preds_S[0]norm_t = preds_T[0].detach()if self.divergence == 'kl':norm_s = norm_s.log()loss = self.criterion(norm_s, norm_t)# item_loss = [round(self.criterion(norm_t[0][0].log(),norm_t[0][i]).item(),4) for i in range(c)]# import pdb;pdb.set_trace()if self.norm_type == 'channel' or self.norm_type == 'channel_mean':loss /= n * c# loss /= n * h * welse:loss /= n * h * wreturn loss * (self.temperature ** 2)

 

相关文章:

Channel-wise Knowledge Distillation for Dense Prediction(ICCV 2021)原理与代码解析

paper:Channel-wise Knowledge Distillation for Dense Prediction official implementation:https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill 摘要 之前大多数用于密集预测dense prediction任务的蒸馏方法在空间域spatial…...

No.052<软考>《(高项)备考大全》【冲刺6】《软考之 119个工具 (4)》

《软考之 119个工具 (4)》 61.人际交往:62.组织理论:63.预分派:64.谈判:65.招募:66.虚拟团队:67.多标准决策分析:68.人际关系技能:69.培训:70.团队建设活动:71.基本规则:72.集中办公:73.认可与奖励:74.人事评测工具:75.观察和交谈:76.项目绩效评估:77.冲…...

Go | 一分钟掌握Go | 9 - 通道

作者:Mars酱 声明:本文章由Mars酱编写,部分内容来源于网络,如有疑问请联系本人。 转载:欢迎转载,转载前先请联系我! 前言 在Java中,多线程之间的通信方式有哪些?记得吗&…...

【建议收藏】计算机视觉是什么?这几个计算机视觉的核心任务你真的了解吗?

文章目录 📚引言📖计算机视觉的核心任务📑图像分类和对象识别📑目标检测📑语义分割📑实例分割📑图像生成 📖计算机视觉的应用领域📑人脸识别📑自动驾驶&#…...

BatteryChargingSpecification1.2中文详解

1. Introduction 1.1 Scope 规范定义了设备通过USB端口充电的检测、控制和报告机制,这些机制是USB2.0规范的扩展,用于专用 充电器(DCP)、主机(SDP)、hub(SDP)和CDP(大电流充电端口)对设备的充电和power up。这些机制适用 于兼…...

基于Jenkins,docker实现自动化部署(持续交互)【转】

前言 随着业务的增长,需求也开始增多,每个需求的大小,开发周期,发布时间都不一致。基于微服务的系统架构,功能的叠加,对应的服务的数量也在增加,大小功能的快速迭代,更加要求部署的…...

漫谈大数据 - 数据湖认知篇

导语:数据湖是目前比较热的一个概念,许多企业都在构建或者准备构建自己的数据湖。但是在计划构建数据湖之前,搞清楚什么是数据湖,明确一个数据湖项目的基本组成,进而设计数据湖的基本架构,对于数据湖的构建…...

阿里云国际版ACE与国内版ACE区别

1.国际版ACE与国内版ACE有哪些不同 2.国际版ACP/ACE约考流程 2.1 登录VUE官方网站约考 https://www.pearsonvue.com.cn/Clients/Alibaba-Cloud-Certification.aspx ​ 2.2 如果之前有注册过账户,那就直接登录,如果还没有账户,那就创建账户 2.…...

Mysql8.0 gis支持

GIS数据类型 MySQL的GIS功能遵守OGC的OpenGIS Geometry Model,支持其定义的空间数据类型的一个子集,包括以下空间数据类型: GEOMETRY:不可实例化的数据类型,但是可以作为一个列的类型,存储任何一种其他类型的数据POIN…...

汇编---Nasm

文章目录 比较流行的汇编语言有3种:不同风格的汇编语言在语法格式上会有不同: 实战代码:Intrinsic函数手写汇编(8086汇编)调用C的API库函数调用约定实际代码 C调用汇编函数进行计算纯C实现如下:CASM实现:纯ASM实现:ASM打印命令行参…...

NDK OpenGL渲染画面效果

NDK系列之OpenGL渲染画面效果技术实战,本节主要是通过OpenGL Java库(谷歌对OpenGL C库做了JIN封装,核心实现还是在Native层),实现页面渲染,自定义渲染特效。 实现效果: 实现逻辑: 1…...

常见的深度学习框架

框架优点缺点TensorFlow- 由Google开发和维护,社区庞大,学习资源丰富- 具备优秀的性能表现,支持大规模分布式计算- 支持多种编程语言接口,易于使用- 提供了可视化工具TensorBoard,可用于调试和可视化模型- 底层架构复杂…...

【设计模式】七大设计原则--------单一职责原则

文章目录 1.案例1.1 原始案例1.2 改进一:类上遵循单一职责原则1.3 改进二:方法上遵循单一职责原则 2.小结 1.案例 1.1 原始案例 package com.sdnu.principle.singleresponsibility; //客户端 public class singleResponsibility {public static void m…...

MySQL-中间件mycat(一)

目录 🍁mycat基础概念 🍁Mycat安装部署 🍃初始环境 🍃测试环境 🍃下载安装 🍃修改配置文件 🍃启动mycat 🍃测试连接 🦐博客主页:大虾好吃吗的博客 &#x1f9…...

ARM寄存器组织

ARM有37个32位长的寄存器: 1个用做PC(Program Counter); 1个用做CPSR(Current Program Status Register); 5个用做SPSR(Saved Program Status Registers); 30个通用寄存器。 AR…...

记录一次webdav协议磁盘挂载经验总结

记录一次磁盘挂载经验总结 文章目录 记录一次磁盘挂载经验总结适配环境服务器协议适配方案脚本与详细说明 适配环境 windows 11windows 10windows 7 x86 and x64linuxuos统信国产化linux系统 服务器协议 webdav 适配方案 一、通用 winfsprclone 已验证通过,版…...

安装Django

1. 在物理环境安装Django Python官方的PyPi仓库为我们提供了一个统一的代码托管仓库,所有的第三方库,甚至你自己写的开源模块,都可以发布到这里,让全世界的人分享下载 pip是最有名的Python包管理工具 。提供了对Python包的查找、…...

【前端面经】JS-如何使用 JavaScript 来判断用户设备类型?

在 Web 开发中,有时需要针对不同的设备类型进行不同的处理。例如,对于移动设备,我们可能需要采用不同的布局或者交互方式,以提供更好的用户体验。因此,如何判断用户设备类型成为了一个重要的问题。 1. 使用 navigator…...

压缩HTML引用字体

内容简介 有些网站为了凸显某部分字体,而引入自定义字体,但由于自定义字体相对都比较大(几M),导致页面加载缓慢;所以本文介绍三种压缩字体的方法,可根据项目情况自行选择。 压缩方法 1、利用Fontmin程序&a…...

大厂高频面试:底层的源码逻辑知多少?

你好,我是何辉。今天我们来聊一聊Dubbo的大厂高频面试题。 大厂面试,一般重点考察对技术理解的深度,和中小厂的区别在于,不仅要你精于实战,还要你深懂原理,勤于思考并针对功能进行合理的设计。 网上一直流…...

rknn优化教程(二)

文章目录 1. 前述2. 三方库的封装2.1 xrepo中的库2.2 xrepo之外的库2.2.1 opencv2.2.2 rknnrt2.2.3 spdlog 3. rknn_engine库 1. 前述 OK,开始写第二篇的内容了。这篇博客主要能写一下: 如何给一些三方库按照xmake方式进行封装,供调用如何按…...

《Qt C++ 与 OpenCV:解锁视频播放程序设计的奥秘》

引言:探索视频播放程序设计之旅 在当今数字化时代,多媒体应用已渗透到我们生活的方方面面,从日常的视频娱乐到专业的视频监控、视频会议系统,视频播放程序作为多媒体应用的核心组成部分,扮演着至关重要的角色。无论是在个人电脑、移动设备还是智能电视等平台上,用户都期望…...

.Net框架,除了EF还有很多很多......

文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...

【2025年】解决Burpsuite抓不到https包的问题

环境:windows11 burpsuite:2025.5 在抓取https网站时,burpsuite抓取不到https数据包,只显示: 解决该问题只需如下三个步骤: 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

Spring Boot面试题精选汇总

🤟致敬读者 🟩感谢阅读🟦笑口常开🟪生日快乐⬛早点睡觉 📘博主相关 🟧博主信息🟨博客首页🟫专栏推荐🟥活动信息 文章目录 Spring Boot面试题精选汇总⚙️ **一、核心概…...

安宝特方案丨船舶智造的“AR+AI+作业标准化管理解决方案”(装配)

船舶制造装配管理现状:装配工作依赖人工经验,装配工人凭借长期实践积累的操作技巧完成零部件组装。企业通常制定了装配作业指导书,但在实际执行中,工人对指导书的理解和遵循程度参差不齐。 船舶装配过程中的挑战与需求 挑战 (1…...

C++.OpenGL (14/64)多光源(Multiple Lights)

多光源(Multiple Lights) 多光源渲染技术概览 #mermaid-svg-3L5e5gGn76TNh7Lq {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-3L5e5gGn76TNh7Lq .error-icon{fill:#552222;}#mermaid-svg-3L5e5gGn76TNh7Lq .erro…...

LINUX 69 FTP 客服管理系统 man 5 /etc/vsftpd/vsftpd.conf

FTP 客服管理系统 实现kefu123登录,不允许匿名访问,kefu只能访问/data/kefu目录,不能查看其他目录 创建账号密码 useradd kefu echo 123|passwd -stdin kefu [rootcode caozx26420]# echo 123|passwd --stdin kefu 更改用户 kefu 的密码…...

STM32HAL库USART源代码解析及应用

STM32HAL库USART源代码解析 前言STM32CubeIDE配置串口USART和UART的选择使用模式参数设置GPIO配置DMA配置中断配置硬件流控制使能生成代码解析和使用方法串口初始化__UART_HandleTypeDef结构体浅析HAL库代码实际使用方法使用轮询方式发送使用轮询方式接收使用中断方式发送使用中…...

MySQL:分区的基本使用

目录 一、什么是分区二、有什么作用三、分类四、创建分区五、删除分区 一、什么是分区 MySQL 分区(Partitioning)是一种将单张表的数据逻辑上拆分成多个物理部分的技术。这些物理部分(分区)可以独立存储、管理和优化,…...