Channel-wise Knowledge Distillation for Dense Prediction(ICCV 2021)原理与代码解析
paper:Channel-wise Knowledge Distillation for Dense Prediction
official implementation:https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill
摘要
之前大多数用于密集预测dense prediction任务的蒸馏方法在空间域spatial domain中将教师和学生网络的激活图进行对齐,通过标准化normalize每个空间位置的激活值,并减小point-wise and/or pair-wise之间的差异来实现知识的传递。
和之前的方法不同,本文提出对每个通道的激活图进行标准化从而得到一个soft probabality map,通过减小两个网络channel-wise概率图之间的KL散度,使蒸馏过程更关注每个通道最显著的区域,这对密集预测任务非常有价值。
背景
密集预测任务是像素级的预测问题,比图像级的分类问题更具挑战性。以往的研究发现,直接将分类中的蒸馏方法应用到语义分割中得到的效果令人无法满意。严格地对齐教师和学生网络之间poit-wise分类得分或特征图可能会施加过于严格的约束,只能得到次优解。
最近的一些研究主要关注于加强不同空间位置之间的相关性。如图2(a)所示,每个空间位置上的激活值被标准化然后通过聚合不同空间位置的子集来得到一些特定任务的关系,比如pair-wise relations和inter-class relations。这些方法在捕获空间结构信息和提高学生网络的性能方面可能比point-wise对齐效果更好。然而,激活图中的每个空间位置对知识转移的贡献相同,这可能会从教师网络带来冗余的信息。
本文提出了一种新的通道级channel-wise知识蒸馏方法,通过对每个通道中的激活图进行标准化,用于密集预测任务,如图2(b)所示。然后减小教师和学生网络标准化后的通道激活图之间的KL散度。一个channel-wise distribution的例子如图2(c)所示,可以看出每个通道的激活图倾向于编码特定场景类别的显著区域。对于每个通道,引导学生网络更加关注于模拟激活值大的区域,从而在密集预测任务中实现更准确的定位。比如在目标检测任务中,让学生网络更关注于学习前景区域的激活。
本文的贡献
- 与现有的spatial蒸馏方法不同,本文提出了一种新的通道蒸馏范式用于密集预测任务,方法简单有效。
- 在语义分割和目标检测方面,本文提出的通道级蒸馏方法显著优于最先进的KD方法。
- 我们在语义分割和目标检测任务上的四个基准数据集上进行了一致的改进,证明了我们的方法是通用的。鉴于它的简单性和有效性,我们相信我们的方法可以作为一个strong baseline KD方法用于密集预测任务。
方法介绍
为了更好的利用每个通道中的知识,作者提出softly对齐教师网络和学生网络对应通道的激活。为此,首先将一个通道的激活转换为概率分布,这样就可以使用一个概率距离度量例如KL散度来衡量差异。如图2(c)所示,不同通道的激活倾向于编码输入图像中某个特定类别场景的显著saliency区域。此外,一个训练好的语义分割教师模型在每个通道显示出了清晰的类别特定掩码的激活图,这是符合预期的,如图1右侧所示。因此作者提出了一种新的通道蒸馏范式来指导学生从一个训练有素的教师那里学习知识。
首先定义教师网络和学生网络分别为 \(T\) 和 \(S\),\(T\) 和 \(S\) 的激活分别表示为 \(y^{T}\) 和 \(y^{S}\),通道蒸馏损失的一般表示形式如下
其中 \(\phi(\cdot)\) 用于将激活值转换为概率分布,如下
其中 \(c=1,2,...,C\) 表示通道,\(i\) 示一个通道的spatial location的索引,\(\mathcal{T}\) 是温度超参。如果我们使用更大的 \(\mathcal{T}\),概率分布会变得更softer,意味着每个通道中关注的spatial region更加wider。通过使用softmax归一化,消除了大网络和小网络之间尺度的差异。如果教师网络和学生网络的通道不匹配,则使用1x1卷积上采样小网络的通道数使两者相等。\(\varphi(\cdot)\) 评估教师模型和学生模型通道分布之间的差异,具体使用KL散度
KL散度是一种非对称度量。从式4可以看出,如果 \(\phi(y^{T}_{c,i})\) 很大,\(\phi(y^{S}_{c,i})\) 应该像 \(\phi(y^{T}_{c,i})\) 一样大从而减小KL散度。相反,如果 \(\phi(y^{T}_{c,i})\) 非常小,KL散度对减小 \(\phi(y^{S}_{c,i})\) 的关注相对较少。通过教师网络的监督,学生网络倾向于在前景显著区域产生和教师网络相似的激活分布,而教师网络背景区域中的激活对学生网络的影响较小。作者认为KL这种不对称性有助于密集预测任务中蒸馏的学习。
和分类任务中的channel-wise蒸馏的区别
在Channel Distillation: Channel-Wise Attention for Knowledge Distillation这篇文章中,作者也提出了使用通道蒸馏的方式,但主要应用于分类任务。受SENet的启发,通过全局平局池化将一个通道的特征图转换为一个标量,然后应用KL散度衡量教师网络和学生网络对应通道的标量之间的差异。
而本文主要考虑的是密集预测任务,GAP或许对image-level的分类任务有帮助,但所有空间位置的权重相同,丢失了空间信息,不适用于密集预测任务。本文通过softmax标准化的方式,考虑到了不同空间位置重要性的不同,保留了空间位置信息,因此更适用于密集预测任务。
实验结果
表2是本文提出的channel蒸馏和其它spatial蒸馏在Cityscapes数据集上复杂度和验证集上的mIoU的对比,可以看出通道蒸馏的精度最高,且复杂度较小
表5是在Cityscapes数据集上,不同的学生模型用不同的蒸馏方法的精度对比,可以看出,对于不同结构的学生网络,本文提出的通道蒸馏的效果都要好于其它蒸馏方法。
表6是在目标检测任务上与其它蒸馏方法的对比,可以看出,在两阶段、单阶段、anchor-free不同结构的目标检测模型中,本文提出的通道蒸馏的效果都是最好的。
代码解析
官方实现如下,其中归一化采用channel_norm损失采用KL损失是论文中给出的方法,官方实现中还给出了其它的归一化方法和损失函数的选择。
import torch.nn as nnclass ChannelNorm(nn.Module):def __init__(self):super(ChannelNorm, self).__init__()def forward(self, featmap):n, c, h, w = featmap.shapefeatmap = featmap.reshape((n, c, -1))featmap = featmap.softmax(dim=-1)return featmapclass CriterionCWD(nn.Module):def __init__(self, norm_type='none', divergence='mse', temperature=1.0):super(CriterionCWD, self).__init__()# define normalize functionif norm_type == 'channel':self.normalize = ChannelNorm()elif norm_type == 'spatial':self.normalize = nn.Softmax(dim=1)elif norm_type == 'channel_mean':self.normalize = lambda x: x.view(x.size(0), x.size(1), -1).mean(-1)else:self.normalize = Noneself.norm_type = norm_typeself.temperature = 1.0# define loss functionif divergence == 'mse':self.criterion = nn.MSELoss(reduction='sum')elif divergence == 'kl':self.criterion = nn.KLDivLoss(reduction='sum')self.temperature = temperatureself.divergence = divergencedef forward(self, preds_S, preds_T):n, c, h, w = preds_S.shape# import pdb;pdb.set_trace()if self.normalize is not None:norm_s = self.normalize(preds_S / self.temperature)norm_t = self.normalize(preds_T.detach() / self.temperature)else:norm_s = preds_S[0]norm_t = preds_T[0].detach()if self.divergence == 'kl':norm_s = norm_s.log()loss = self.criterion(norm_s, norm_t)# item_loss = [round(self.criterion(norm_t[0][0].log(),norm_t[0][i]).item(),4) for i in range(c)]# import pdb;pdb.set_trace()if self.norm_type == 'channel' or self.norm_type == 'channel_mean':loss /= n * c# loss /= n * h * welse:loss /= n * h * wreturn loss * (self.temperature ** 2)
相关文章:

Channel-wise Knowledge Distillation for Dense Prediction(ICCV 2021)原理与代码解析
paper:Channel-wise Knowledge Distillation for Dense Prediction official implementation:https://github.com/irfanICMLL/TorchDistiller/tree/main/SemSeg-distill 摘要 之前大多数用于密集预测dense prediction任务的蒸馏方法在空间域spatial…...

No.052<软考>《(高项)备考大全》【冲刺6】《软考之 119个工具 (4)》
《软考之 119个工具 (4)》 61.人际交往:62.组织理论:63.预分派:64.谈判:65.招募:66.虚拟团队:67.多标准决策分析:68.人际关系技能:69.培训:70.团队建设活动:71.基本规则:72.集中办公:73.认可与奖励:74.人事评测工具:75.观察和交谈:76.项目绩效评估:77.冲…...

Go | 一分钟掌握Go | 9 - 通道
作者:Mars酱 声明:本文章由Mars酱编写,部分内容来源于网络,如有疑问请联系本人。 转载:欢迎转载,转载前先请联系我! 前言 在Java中,多线程之间的通信方式有哪些?记得吗&…...

【建议收藏】计算机视觉是什么?这几个计算机视觉的核心任务你真的了解吗?
文章目录 📚引言📖计算机视觉的核心任务📑图像分类和对象识别📑目标检测📑语义分割📑实例分割📑图像生成 📖计算机视觉的应用领域📑人脸识别📑自动驾驶&#…...

BatteryChargingSpecification1.2中文详解
1. Introduction 1.1 Scope 规范定义了设备通过USB端口充电的检测、控制和报告机制,这些机制是USB2.0规范的扩展,用于专用 充电器(DCP)、主机(SDP)、hub(SDP)和CDP(大电流充电端口)对设备的充电和power up。这些机制适用 于兼…...

基于Jenkins,docker实现自动化部署(持续交互)【转】
前言 随着业务的增长,需求也开始增多,每个需求的大小,开发周期,发布时间都不一致。基于微服务的系统架构,功能的叠加,对应的服务的数量也在增加,大小功能的快速迭代,更加要求部署的…...

漫谈大数据 - 数据湖认知篇
导语:数据湖是目前比较热的一个概念,许多企业都在构建或者准备构建自己的数据湖。但是在计划构建数据湖之前,搞清楚什么是数据湖,明确一个数据湖项目的基本组成,进而设计数据湖的基本架构,对于数据湖的构建…...

阿里云国际版ACE与国内版ACE区别
1.国际版ACE与国内版ACE有哪些不同 2.国际版ACP/ACE约考流程 2.1 登录VUE官方网站约考 https://www.pearsonvue.com.cn/Clients/Alibaba-Cloud-Certification.aspx 2.2 如果之前有注册过账户,那就直接登录,如果还没有账户,那就创建账户 2.…...

Mysql8.0 gis支持
GIS数据类型 MySQL的GIS功能遵守OGC的OpenGIS Geometry Model,支持其定义的空间数据类型的一个子集,包括以下空间数据类型: GEOMETRY:不可实例化的数据类型,但是可以作为一个列的类型,存储任何一种其他类型的数据POIN…...

汇编---Nasm
文章目录 比较流行的汇编语言有3种:不同风格的汇编语言在语法格式上会有不同: 实战代码:Intrinsic函数手写汇编(8086汇编)调用C的API库函数调用约定实际代码 C调用汇编函数进行计算纯C实现如下:CASM实现:纯ASM实现:ASM打印命令行参…...

NDK OpenGL渲染画面效果
NDK系列之OpenGL渲染画面效果技术实战,本节主要是通过OpenGL Java库(谷歌对OpenGL C库做了JIN封装,核心实现还是在Native层),实现页面渲染,自定义渲染特效。 实现效果: 实现逻辑: 1…...

常见的深度学习框架
框架优点缺点TensorFlow- 由Google开发和维护,社区庞大,学习资源丰富- 具备优秀的性能表现,支持大规模分布式计算- 支持多种编程语言接口,易于使用- 提供了可视化工具TensorBoard,可用于调试和可视化模型- 底层架构复杂…...

【设计模式】七大设计原则--------单一职责原则
文章目录 1.案例1.1 原始案例1.2 改进一:类上遵循单一职责原则1.3 改进二:方法上遵循单一职责原则 2.小结 1.案例 1.1 原始案例 package com.sdnu.principle.singleresponsibility; //客户端 public class singleResponsibility {public static void m…...

MySQL-中间件mycat(一)
目录 🍁mycat基础概念 🍁Mycat安装部署 🍃初始环境 🍃测试环境 🍃下载安装 🍃修改配置文件 🍃启动mycat 🍃测试连接 🦐博客主页:大虾好吃吗的博客 ǹ…...

ARM寄存器组织
ARM有37个32位长的寄存器: 1个用做PC(Program Counter); 1个用做CPSR(Current Program Status Register); 5个用做SPSR(Saved Program Status Registers); 30个通用寄存器。 AR…...

记录一次webdav协议磁盘挂载经验总结
记录一次磁盘挂载经验总结 文章目录 记录一次磁盘挂载经验总结适配环境服务器协议适配方案脚本与详细说明 适配环境 windows 11windows 10windows 7 x86 and x64linuxuos统信国产化linux系统 服务器协议 webdav 适配方案 一、通用 winfsprclone 已验证通过,版…...

安装Django
1. 在物理环境安装Django Python官方的PyPi仓库为我们提供了一个统一的代码托管仓库,所有的第三方库,甚至你自己写的开源模块,都可以发布到这里,让全世界的人分享下载 pip是最有名的Python包管理工具 。提供了对Python包的查找、…...

【前端面经】JS-如何使用 JavaScript 来判断用户设备类型?
在 Web 开发中,有时需要针对不同的设备类型进行不同的处理。例如,对于移动设备,我们可能需要采用不同的布局或者交互方式,以提供更好的用户体验。因此,如何判断用户设备类型成为了一个重要的问题。 1. 使用 navigator…...

压缩HTML引用字体
内容简介 有些网站为了凸显某部分字体,而引入自定义字体,但由于自定义字体相对都比较大(几M),导致页面加载缓慢;所以本文介绍三种压缩字体的方法,可根据项目情况自行选择。 压缩方法 1、利用Fontmin程序&a…...

大厂高频面试:底层的源码逻辑知多少?
你好,我是何辉。今天我们来聊一聊Dubbo的大厂高频面试题。 大厂面试,一般重点考察对技术理解的深度,和中小厂的区别在于,不仅要你精于实战,还要你深懂原理,勤于思考并针对功能进行合理的设计。 网上一直流…...

【学习笔记】CF607E Cross Sum
最后一道数据结构,不能再多了。 而且需要一点计算几何的知识,有点难搞。 分为两个部分求解。 首先考虑找到距离 ≤ r \le r ≤r的交点数量。发现这等价于圆上两段圆弧相交,因此将圆上的点离散化后排序,用一个主席树来求就做完了…...

Python 一元线性回归模型预测实验完整版
一元线性回归预测模型 实验目的 通过一元线性回归预测模型,掌握预测模型的建立和应用方法,了解线性回归模型的基本原理 实验内容 一元线性回归预测模型 实验步骤和过程 (1)第一步:学习一元线性回归预测模型相关知识。 线性回归模型属于…...

GStreamer第一阶段的简单总结
这里写目录标题 前言个人的总结v4l2src插件的简单使用 前言 因为涉及很多细节的GStreamer官方论坛有详细解链接: GStreamer官网,这里不做说明,以下只是涉及到个人的理解和认知,方便后续的查阅。 个人的总结 1)了解pipeline的使用࿰…...

【网络进阶】服务器模型Reactor与Proactor
文章目录 1. Reactor模型2. Proactor模型3. 同步IO模拟Proactor模型 在高并发编程和网络连接的消息处理中,通常可分为两个阶段:等待消息就绪和消息处理。当使用默认的阻塞套接字时(例如每个线程专门处理一个连接),这两…...

使用div替代<frameset><frame>的问题以及解决办法
首先是原版三层框架的html: <html> <head> <title>THPWP</title> </head> <!-- 切记frameset不能写在body里面,以下代表首页由三层模块组成,其中第一层我是用来放菜单高度占比14%,中间的用作主…...

Verilog中的`define与`if的使用
一部分代码可能有时候用,有时候不用,为了避免全部编译占用资源,可以使用条件编译语句。 语法 // Style #1: Only single ifdef ifdef <FLAG>// Statements endif// Style #2: ifdef with else part ifdef <FLAG>// Statements …...

沃尔玛、亚马逊影响listing的转化率4大因素,测评补单自养号解析
1、listing的相关性:前期我们在找词,收集词的时候,我们通过插件来协助我们去筛选词。我们把流量高,中,低的关键词都一一收集,然后我们再进行对收集得来的关键词进行分析,再进行挑词,…...

静态分析和动态分析
在开发早期,发现并修复bug在许多方面都有好处。它可以减少开发时间,降低成本,并且防止数据泄露或其他安全漏洞。特别是对于DevOps,尽早持续地将测试纳入SDLC软件开发生命周期是非常有帮助的。 这就是动态和静态分析测试的用武之地…...

代码随想录_贪心_leetcode 1005 134
leetcode 1005. K 次取反后最大化的数组和 1005. K 次取反后最大化的数组和 给你一个整数数组 nums 和一个整数 k ,按以下方法修改该数组: 选择某个下标 i 并将 nums[i] 替换为 -nums[i] 。 重复这个过程恰好 k 次。可以多次选择同一个下标 i 。 以…...

笔记:对多维torch进行任意维度的多“行”操作
如何取出多维torch指定维度的指定“行” 从二维torch开始新建torch取出某一行取出某一列一次性取出多行取出连续的多行取出不连续的多行 一次取出多列取出连续的多列取出不连续的多列 考虑三维torch取出三维torch的任意两行(means 在dim0上操作)取出连续…...