当前位置: 首页 > news >正文

基于注意力的知识蒸馏Attention Transfer原理与代码解析

paper:Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfer

code:https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/AT.py

背景

一个流行的假设是存在非注意non-attentional和注意attentional感知过程,非注意感知有助于从整体观察一个场景并获取high-level信息,注意力感知过程会将我们导向并更关注某个局部。不同的观察者有不同的知识,不同的目标,因此有不同的注意力策略,从而以不同方式看待同一场景。本文研究的主题是一个教师网络能否通过向学生网络传递它的注意力信息(即它更关注哪些区域)来提升学生网络的性能。

本文的创新点

本文提出将注意力作为一种将知识从教师网络传递给学生网络的机制。对于一个给定的卷积神经网络,首先需要合适地定义注意力,作者将注意力看做是一组spatial maps,给定输入spatial map上编码了网络最关注的空间区域。本文还提出了同时使用基于激活的和基于梯度的空间注意力图,并通过实验表明了基于注意力的知识蒸馏的有效性,同时表明了基于激活的注意力传递比单纯基于激活的知识传递效果更好。

方法介绍

给定CNN某层的输出激活张量 \(A\in R^{C\times H\times W}\),输入到一个基于激活的映射函数 \(\mathcal{F}\) 得到得到一个spatial attention map

这里隐含的假设是一个隐含神经元激活值的绝对值可以用来表明该神经元对某个特定输入的重要程度,比如注意力图上某个位置的值越大说明网络越关注该位置。基于该假设,我们可以沿通道维度上计算这些值的统计数据来构建一个空间注意力图。如下图所示

本文主要考虑了以下三种统计方法

我们假设教师网络和学生网络之间的注意力传递发生在相同分辨率的注意力图上,当分辨率不一致时也可以通过插值来进行匹配。一个示例如下图所示,其中教师和学生网络都是残差网络,在每个stage的最后进行注意力知识的传递,即计算对应attention map之间的损失。

定义 \(S,T\) 和 \(\mathbf{W_{S}},\mathbf{W_{T}}\) 分别表示学生和教师网络以及对应的模型权重,\(\mathcal{L}\left ( \mathbf{W},x \right ) \) 是交叉熵损失,\(\mathcal{I}\) 是所有教师和学生网络对应注意力图的索引。完整损失函数如下

其中 \(Q_{S}^{j}=vec(F(A^{j}_{S}))\) 和 \(Q_{T}^{j}=vec(F(A^{j}_{T}))\) 别是学生网络和教师网络第 \(j\) 层的向量形式的注意力图,\(p\) 是范数类型本文默认 \(p=2\)。注意力的知识传递还可以和知识蒸馏结合使用,只需要在式(2)中加一项教师和学生软化标签分布之间的交叉熵损失项即可。

代码解析

其中函数at_loss从教师和学生网络中一一取出对应层的特征图f_sf_t,函数single_stage_at_loss计算对应单层之间的注意力损失。

import torch
import torch.nn as nn
import torch.nn.functional as Ffrom ._base import Distillerdef single_stage_at_loss(f_s, f_t, p):def _at(feat, p):# (64,64,32,32)->(64,64,32,32)->(64,32,32)->(64,1024)->(64,1024)return F.normalize(feat.pow(p).mean(1).reshape(feat.size(0), -1))  # 沿通道取means_H, t_H = f_s.shape[2], f_t.shape[2]if s_H > t_H:f_s = F.adaptive_avg_pool2d(f_s, (t_H, t_H))elif s_H < t_H:f_t = F.adaptive_avg_pool2d(f_t, (s_H, s_H))# (64,1024)-(64,1024)->(64,1024)->()return (_at(f_s, p) - _at(f_t, p)).pow(2).mean()def at_loss(g_s, g_t, p):return sum([single_stage_at_loss(f_s, f_t, p) for f_s, f_t in zip(g_s, g_t)])class AT(Distiller):"""Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfersrc code: https://github.com/szagoruyko/attention-transfer"""def __init__(self, student, teacher, cfg):super(AT, self).__init__(student, teacher)self.p = cfg.AT.Pself.ce_loss_weight = cfg.AT.LOSS.CE_WEIGHTself.feat_loss_weight = cfg.AT.LOSS.FEAT_WEIGHTdef forward_train(self, image, target, **kwargs):logits_student, feature_student = self.student(image)with torch.no_grad():_, feature_teacher = self.teacher(image)# lossesloss_ce = self.ce_loss_weight * F.cross_entropy(logits_student, target)loss_feat = self.feat_loss_weight * at_loss(feature_student["feats"][1:], feature_teacher["feats"][1:], self.p)losses_dict = {"loss_ce": loss_ce,"loss_kd": loss_feat,}return logits_student, losses_dict

相关文章:

基于注意力的知识蒸馏Attention Transfer原理与代码解析

paper&#xff1a;Paying More Attention to Attention: Improving the Performance of Convolutional Neural Networks via Attention Transfercode&#xff1a;https://github.com/megvii-research/mdistiller/blob/master/mdistiller/distillers/AT.py背景一个流行的假设是存…...

利尔达在北交所上市:总市值突破29亿元,叶文光为董事长

2月17日&#xff0c;利尔达科技集团股份有限公司&#xff08;下称“利尔达”&#xff0c;BJ:832149&#xff09;在北京证券交易所上市。本次上市&#xff0c;利尔达的发行价格为5.00元/股&#xff0c;发行数量为1980万股&#xff0c;发行市盈率为12.29倍&#xff0c;募资总额为…...

C#操作字符串方法 [万余字总结 · 详细]

C#操作字符串方法总结C#常用字符串函数大全C#常用字符串操作方法C#操作字符串方法总结C#常用字符串函数大全 Compare 比较字符串的内容&#xff0c;考虑文化背景(场所)&#xff0c;确定某些字符是否相等 CompareOrdinal 与Compare一样&#xff0c;但不考虑文化背景 Format 格…...

极兔一面:10亿级ES海量搜索狂飙10倍,该怎么办?

背景说明&#xff1a; ES高性能全文索引&#xff0c;如果不会用&#xff0c;或者没有用过&#xff0c;在面试中&#xff0c;会非常吃亏。 所以ES的实操和底层原理&#xff0c;大家要好好准备。 另外&#xff0c;ES调优是一个非常、非常核心的面试知识点&#xff0c;大家要非…...

【Mysql基础 —— SQL语句(一)】

文章目录概述使用启动/停止mysql服务连接mysql客户端数据模型SQLSQL语句分类DDL数据库操作表操作查询创建数据类型修改删除DML添加数据修改数据删除数据DQL基础查询条件查询聚合函数分组查询排序查询分页查询执行顺序DCL管理用户权限控制概述 数据库&#xff08;Database&#…...

华为OD机试 - 新员工座位安排系统(Python) | 机试题算法思路

最近更新的博客 华为OD机试 - 招聘(Python) | 备考思路,刷题要点,答疑 【新解法】华为OD机试 - 五键键盘 | 备考思路,刷题要点,答疑 【新解法】华为OD机试 - 热点网络统计 | 备考思路,刷题要点,答疑 【新解法】华为OD机试 - 路灯照明 | 备考思路,刷题要点,答疑 【新解…...

MySQL - 介绍

前言 本篇介绍认识MySQL&#xff0c;重装mysql操作 如有错误&#xff0c;请在评论区指正&#xff0c;让我们一起交流&#xff0c;共同进步&#xff01; 本文开始 1.什么是数据库? 数据库: 一种通过SQL语言操作管理数据的软件; 重装数据库的卸载数据库步骤 : ① 停止MySQL服…...

ChatGPT国内镜像站初体验:聊天、Python代码生成等

ChatGPT国内镜像站初体验&#xff0c;聊天、Python代码生成。 (本文获得CSDN质量评分【92】)【学习的细节是欢悦的历程】Python 官网&#xff1a;https://www.python.org/ Free&#xff1a;大咖免费“圣经”教程《 python 完全自学教程》&#xff0c;不仅仅是基础那么简单………...

SAP数据导入工具(LSMW) 超级详细教程(批量导入内部订单)

目录 第一步&#xff1a;记录批导步骤编辑数据源对应字段 第二步&#xff1a;维护数据源 第三步&#xff1a;维护数据源对应字段&#xff08;重要&#xff09; 第四步&#xff1a;维护数据源关系。 第五步&#xff1a;维护数据源与导入字段的对应关系。 第六步&#xff0…...

第9天-商品服务(电商核心概念,属性分组开发及分类和品牌的级联更新)

1.电商核心概念 1.1.SPU与SKU SPU&#xff1a;Standard Product Unit&#xff08;标准化产品单元&#xff09; 是商品信息聚合的最小单位&#xff0c;是一组可复用、易检索的标准化信息的集合&#xff0c;该集合描述了一个 产品的特性。 决定商品属性的值 SKU&#xff1a;Stock…...

动漫人物眼睛画法

本期的动漫绘画课程教大家来学习动漫人物眼睛画法&#xff0c;结合板绘软件从草稿开始一步步教你画出动漫人物眼睛&#xff0c;不用报动漫培训班也能学会&#xff0c;快来跟着本期的动漫人物眼睛画法教程试试吧&#xff01; 动漫人物眼睛画法步骤教程&#xff1a; 注意&#x…...

张晨光-JAVA零基础保姆式JDBC技术教程

JDBC文档 JDBC概述 JDBC概述 Java DataBase Connectivity Java 数据库连接技术 JDBC的作用 通过Java语言操作数据库&#xff0c;操作表中的数据 SUN公司为**了简化、**统一对数据库的操作&#xff0c;定义了一套Java操作数据库的规范&#xff0c;称之为JDBC JDBC的本质 是官方…...

华为OD机试 - 最多提取子串数目(Python)

最多提取子串数目 题目 给定由 [a-z] 26 个英文小写字母组成的字符串 A 和 B,其中 A 中可能存在重复字母,B 中不会存在重复字母 现从字符串 A 中按规则挑选一些字母,可以组成字符串 B。 挑选规则如下: 1) 同一个位置的字母只能被挑选一次 2) 被挑选字母的相对先后顺序不…...

LeetCode-1237. 找出给定方程的正整数解【双指针,二分查找】

LeetCode-1237. 找出给定方程的正整数解【双指针&#xff0c;二分查找】题目描述&#xff1a;解题思路一&#xff1a;双指针。首先我们不管f是什么&#xff0c;即function_id等于什么不管。但是我们可以调用customfunction中的f函数&#xff0c;然后我们遍历x,y(1 < x, y &l…...

广度优先搜索算法 - 迷宫找路

广度优先搜索算法1 思考问题1.1 这个迷宫需不需要指定入口和出口&#xff1f;2 先粗略实现2.1 源码2.2 源码解释3 优化代码3.1 优化读取文件部分3.2 增加错误处理4 再优化-让程序变得更加灵活4.1 用户外部可以循环输入入口和出口5 完整代码这是一个提问者的提出的问题&#xff…...

泡脚材料简记

文章目录一般条件中药包&#xff08;药粉&#xff09;泡脚丸中药包&#xff08;药材&#xff09;艾叶生姜益母草藏红花食盐花椒白醋柠檬藿香泡脚私方紫苏叶、白术、白芍、黄芪、青皮、柴胡、夜交藤、丹参、当归&#xff0c;每种各10g艾叶、花椒、肉桂、桂枝、红花干姜30克、小茴…...

【计算机网络】因特网概述

文章目录因特网概述网络、互联网和因特网互联网历史与ISP标准化与RFC因特网的组成三种交换方式电路交换分组交换和报文交换三种交换方式的对比与总结计算机网络的定义和分类计算机网络的定义计算机网络的分类计算机网络的性能指标速率带宽吞吐量时延时延带宽积往返时间利用率丢…...

STC单片机 VS/HX1838红外接收和发送实验

STC单片机 VS/HX1838红外接收和发送实验 📌相关篇《STC单片机获取红外解码从串口输出》🔨所使用的红外接收头VS1838 📋VS1838引脚定义🌿5MM发射头,940nm红外发射二极管 红外遥控发射头。(外观看起来和普通的发光二极管没有什么差异,购买时需要注意确认)。 🔰采用的…...

前端开发常用案例(一)

前端开发常用案例1.实现三角形百度热榜样式分页效果小米商城自动轮播图效果二级下拉菜单效果时间轴效果展示音乐排行榜效果鼠标移入文字加载动画鼠标悬停缩放效果1.实现三角形 <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8…...

Linux 日志查找常用命令

1.1 cat、zcat cat -n app.log | grep "error"&#xff1a;查询日志中含有某个关键字error的信息&#xff0c;显示行号。 cat -n app.log | grep "error" --color&#xff1a;查询日志中含有某个关键字error的信息&#xff0c;显示行号&#xff0c;带颜色…...

CleanMyMac4.12.5最新版安装下载教程

告别硬盘空间不足&#xff0c;让您的Mac极速如新CleanMyMac是一款强大的 Mac 清理、加速工具和健康卫士&#xff0c;让您的 Mac 加快启动速度。CleanMyMac是一款专业的Mac清理软件&#xff0c;可智能清理mac磁盘垃圾和多余语言安装包&#xff0c;快速释放电脑内存&#xff0c;轻…...

RFID射频识别技术(四) RFID高频电路基础|课堂笔记|10月11日

2022年10月11日 week7 ​​​​​​​ 目录 ​​​​​​​ 第四讲: RFID高频电路基础 一、RLC(串联)电路的阻抗...

数据库系统是什么?它由哪几部分组成?

数据库系统&#xff08;Database System&#xff0c;DBS&#xff09;由硬件和软件共同构成。硬件主要用于存储数据库中的数据&#xff0c;包括计算机、存储设备等。软件部分主要包括数据库管理系统、支持数据库管理系统运行的操作系统&#xff0c;以及支持多种语言进行应用开发…...

华为OD机试题 - 任务混部(JavaScript)

最近更新的博客 2023新华为OD机试题 - 斗地主(JavaScript)2023新华为OD机试题 - 箱子之形摆放(JavaScript)2023新华为OD机试题 - 考古学家(JavaScript)2023新华为OD机试题 - 相同数字的积木游戏 1(JavaScript)2023新华为OD机试题 - 最多等和不相交连续子序列(JavaScri…...

键盘输入a,到屏幕显示,操作系统做了什么

首先&#xff0c;假定操作系统有中断系统。 等待的键盘写入的时候&#xff0c;txt进程被read函数阻塞。输入a之后&#xff0c;首先控制器&#xff0c;把扫描到的a放入到了控制器的寄存器中。触发硬中断通知cpu—> 中断IO控制方式&#xff0c;由硬件触发的。键盘读入中断cpu…...

Python机器学习入门笔记(2)—— 分类算法

目录 转换器&#xff08;transformer&#xff09;和估计器&#xff08;estimator&#xff09; K-近邻&#xff08;K-Nearest Neighbors&#xff0c;简称KNN&#xff09;算法 模型选择与调优 交叉验证&#xff08;Cross-validation&#xff09; GridSearchCV API 朴素贝叶…...

Docker镜像发布到阿里云和私有库

目录 一、Docker镜像 &#xff08;一&#xff09;概述 &#xff08;二&#xff09;Docker镜像加载原理 &#xff08;三&#xff09;镜像分层结构优势 &#xff08;四&#xff09;重点理解 &#xff08;五&#xff09;docker commit操作实例 &#xff08;六&#xff09;总…...

初识CSS,美化HTML

CSS称为&#xff1a;层叠样式表&#xff08;Cascading style sheets&#xff09;美化HTML即给页面种的HTML标签设置样式CSS语法规则css要写在head标签的里边&#xff0c;title标签的下面&#xff0c;用style标签框住<head> <title>...</title> <style>…...

华为OD机试 - 二维矩阵的最大值(Python)

题目二维矩阵的最大值 给定一个仅包含0和1的n*n二维矩阵 请计算二维矩阵的最大值 计算规则如下 每行元素按下标顺序组成一个二进制数(下标越大约排在低位), 二进制数的值就是该行的值,矩阵各行之和为矩阵的值允许通过向左或向右整体循环移动每个元素来改变元素在行中的位置 …...

华为OD机试 - 快递业务站(Python)

快递业务站 题目 快递业务范围有 N 个站点,A 站点与 B 站点可以中转快递,则认为 A-B 站可达, 如果 A-B 可达,B-C 可达,则 A-C 可达。 现在给 N 个站点编号 0、1、…n-1,用 s[i][j]表示 i-j 是否可达, s[i][j] = 1表示 i-j可达,s[i][j] = 0表示 i-j 不可达。 现用二维…...

网站的建设原始代码/重庆网站建设外包

一、 ng-class ng-class 指令用于给 HTML 元素动态绑定一个或多个 CSS 类。 ng-class 指令的值可以是字符串&#xff0c;对象&#xff0c;或一个数组。 如果是字符串&#xff0c;多个类名使用空格分隔。如果是对象&#xff0c;需要使用 key-value 对&#xff0c;key 为你想要添…...

西安做网站排名/做网页

创建表的两种办法&#xff1a; 使用DBMS 提供的交互式创建和管理数据库表的工具&#xff1b;直接用SQL 语句创建。表创建基础 创建表示例&#xff1a; 1 CREATE TABLE Products 2 ( 3 prod_id CHAR(10) NOT NULL,--是否可以为null 4 vend_id CHAR(10) NOT NULL, 5 …...

文字网站和图片网站哪个难做/站长之家0

6月18日&#xff0c;为期三天的第九届全球云计算大会中国站(Cloud Connect China 2021)在宁波圆满落幕。TcaplusDB作为专注于高性能高可用特性的NoSQL数据库&#xff0c;在本届“云鼎奖”的评选中&#xff0c;与GCloudSDK联合申请&#xff0c;斩获了**“2020-2021年度优秀解决方…...

前端做网站都要做哪些/seo排名优化联系13火星软件

Description有一个a*b的整数组成的矩阵&#xff0c;现请你从中找出一个n*n的正方形区域&#xff0c;使得该区域所有数中的最大值和最小值的差最小。Input第一行为3个整数&#xff0c;分别表示a,b,n的值第二行至第a1行每行为b个非负整数&#xff0c;表示矩阵中相应位置上的数。每…...

wordpress删掉不需要的/新冠疫苗接种最新消息

事件机制和消息循环原理1、鼠标和键盘事件2、event 对象常用属性3、源代码1、鼠标和键盘事件 代码说明<Button-1> <ButtonPress-1> <1>鼠标左键按下&#xff0c;2 表示中间的滚轮&#xff0c;3 表示右键<ButtonRelease-1>鼠标左键释放<B1-Motion&g…...

建设独立服务器网站/建立企业网站步骤

首先引入程序入口 对于很多编程语言来说&#xff0c;程序都必须要有一个入口&#xff0c;比如 C&#xff0c;C&#xff0c;以及完全面向对象的编程语言 Java&#xff0c;C# 等。如果你接触过这些语言&#xff0c;对于程序入口这个概念应该很好理解&#xff0c;C 和 C 都需要有一…...