当前位置: 首页 > news >正文

【ECCV2022】DaViT: Dual Attention Vision Transformers

DaViT: Dual Attention Vision Transformers, ECCV2022

解读:【ECCV2022】DaViT: Dual Attention Vision Transformers - 高峰OUC - 博客园 (cnblogs.com)

DaViT:双注意力Vision Transformer - 知乎 (zhihu.com) 

DaViT: Dual Attention Vision Transformers - 知乎 (zhihu.com) 

论文:https://arxiv.org/abs/2204.03645

代码:https://github.com/dingmyu/davit

动机

以往的工作一般是,在分辨率、全局上下文和计算复杂度之间权衡:像素级和patch级的self-attention要么是有二次计算成本,要么损失全局上下文信息。除了像素级和patch级的self-attention的变化之外,是否可以设计一个图像级的self-attention机制来捕获全局信息?

作者提出了Dual Attention Vision Transformers (DaViT),能够在保持计算效率的同时捕获全局上下文。提出的方法具有层次结构和细粒度局部注意的优点,同时采用 group channel attention,有效地建模全局环境。

创新点:

  • 提出 Dual Attention Vision Transformers(DaViT),它交替地应用spatial window attentionchannel group attention来捕获长短依赖关系。
  • 提出 channel group attention,将特征通道划分为几个组,并在每个组内进行图像级别的交互。通过group attention,作者将空间和通道维度的复杂性降低到线性。

方法

dual attention

双attention机制是从两个正交的角度来进行self-attention:

一是对spatial tokens进行self-attention,此时空间维度(HW)定义了tokens的数量,而channel维度(C)定义了tokens的特征大小,这其实也是ViT最常采用的方式;

二是对channel tokens进行self-attention,这和前面的处理完全相反,此时channel维度(C)定义了tokens的数量,而空间维度(HW)定义了tokens的特征大小。

可以看出两种self-attention完全是相反的思路。为了减少计算量,两种self-attention均采用分组的attention:对于spatial token而言,就是在空间维度上划分成不同的windows,这就是Swin中所提出的window attention,论文称之为spatial window attention;而对于channel tokens,同样地可以在channel维度上划分成不同的groups,论文称之为channel group attention

 (a)空间窗口多头自注意将空间维度分割为局部窗口,其中每个窗口包含多个空间token。每个token也被分成多个头。(b)通道组单自注意组将token分成多组。在每个通道组中使用整个图像级通道作为token进行Attention。在(a)中也突出显示了捕获全局信息的通道级token。交替地使用这两种类型的注意力机制来获得局部的细粒度,以及全局特征。

两种attention能够实现互补:spatial window attention能够提取windows内的局部特征,而channel group attention能学习到全局特征,这是因为每个channel token在图像空间上都是全局的。

dual attention block

dual attention block的模型架构,它包含两个transformer block:空间window self-attention block和通道group self-attention block。通过交替使用这两种类型的attention机制,作者的模型能实现局部细粒度和全局图像级交互。图3(a)展示了作者的dual attention block的体系结构,包括一个空间window attention block和一个通道group attention block。

Spatial Window Attention

将patchs按照空间结构划分为Nw个window,每个window 里的patchs(Pw)单独计算self-attention:(P=Nw*Pw)

 Channel Group Attention

将channels分为Ng个group,每个group的channel数量为Cg,有C=Ng*Cg,计算如下: 

关键代码

class SpatialBlock(nn.Module):r""" Windows Block.Args:dim (int): Number of input channels.num_heads (int): Number of attention heads.window_size (int): Window size.mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: Truedrop_path (float, optional): Stochastic depth rate. Default: 0.0act_layer (nn.Module, optional): Activation layer. Default: nn.GELUnorm_layer (nn.Module, optional): Normalization layer.  Default: nn.LayerNorm"""def __init__(self, dim, num_heads, window_size=7,mlp_ratio=4., qkv_bias=True, drop_path=0.,act_layer=nn.GELU, norm_layer=nn.LayerNorm,ffn=True, cpe_act=False):super().__init__()self.dim = dimself.ffn = ffnself.num_heads = num_headsself.window_size = window_sizeself.mlp_ratio = mlp_ratio# conv位置编码self.cpe = nn.ModuleList([ConvPosEnc(dim=dim, k=3, act=cpe_act),ConvPosEnc(dim=dim, k=3, act=cpe_act)])self.norm1 = norm_layer(dim)self.attn = WindowAttention(dim,window_size=to_2tuple(self.window_size),num_heads=num_heads,qkv_bias=qkv_bias)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()if self.ffn:self.norm2 = norm_layer(dim)mlp_hidden_dim = int(dim * mlp_ratio)self.mlp = Mlp(in_features=dim,hidden_features=mlp_hidden_dim,act_layer=act_layer)def forward(self, x, size):H, W = sizeB, L, C = x.shapeassert L == H * W, "input feature has wrong size"shortcut = self.cpe[0](x, size) # depth-wise convx = self.norm1(shortcut)x = x.view(B, H, W, C)pad_l = pad_t = 0pad_r = (self.window_size - W % self.window_size) % self.window_sizepad_b = (self.window_size - H % self.window_size) % self.window_sizex = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))_, Hp, Wp, _ = x.shapex_windows = window_partition(x, self.window_size)x_windows = x_windows.view(-1, self.window_size * self.window_size, C)# W-MSA/SW-MSAattn_windows = self.attn(x_windows)# merge windowsattn_windows = attn_windows.view(-1,self.window_size,self.window_size,C)x = window_reverse(attn_windows, self.window_size, Hp, Wp)if pad_r > 0 or pad_b > 0:x = x[:, :H, :W, :].contiguous()x = x.view(B, H * W, C)x = shortcut + self.drop_path(x)x = self.cpe[1](x, size) # 第2个depth-wise convif self.ffn:x = x + self.drop_path(self.mlp(self.norm2(x)))return x, sizeclass ChannelAttention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False):super().__init__()self.num_heads = num_heads # 这里的num_heads实际上是num_groupshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.proj = nn.Linear(dim, dim)def forward(self, x):B, N, C = x.shape# 得到query,key和value,是在channel维度上进行线性投射qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]k = k * self.scaleattention = k.transpose(-1, -2) @ v # 对维度进行反转attention = attention.softmax(dim=-1)x = (attention @ q.transpose(-1, -2)).transpose(-1, -2)x = x.transpose(1, 2).reshape(B, N, C)x = self.proj(x)return x

DaViT采用金字塔结构,共包含4个stages,每个stage的开始时都插入一个 patch embedding 层。作者在每个stage叠加dual attention block,这个block就是将两种attention(还包含FFN)交替地堆叠在一起,其分辨率和特征维度保持不变。

采用stride=4的7x7 conv,然后是4个stage,各stage通过stride=2的2x2 conv来进行降采样。其中DaViT-Tiny,DaViT-Small和DaViT-Base三个模型的配置如下所示:

 

实验

 ​​​​​​

 

相关文章:

【ECCV2022】DaViT: Dual Attention Vision Transformers

DaViT: Dual Attention Vision Transformers, ECCV2022 解读:【ECCV2022】DaViT: Dual Attention Vision Transformers - 高峰OUC - 博客园 (cnblogs.com) DaViT:双注意力Vision Transformer - 知乎 (zhihu.com) DaViT: Dual Attention Vision Trans…...

Apache 配置与应用

Apache 配置与应用 一、构建虚拟 Web 主机httpd服务支持的虚拟主机类型包括以下三种: 二、基于域名的虚拟主机1.为虚拟主机提供域名解析方法一:部署DNS域名解析服务器 来提供域名解析方法二:在/etc/hosts 文件中临时配置域名与IP地址的映射关…...

OpenGL 纹理

1.简介 纹理是一个2D图片(甚至也有1D和3D的纹理),它可以用来添加物体的细节;你可以想象纹理是一张绘有砖块的纸,无缝折叠贴合到你的3D的房子上,这样你的房子看起来就像有砖墙外表了。 为了能够把纹理映射(M…...

Jeston Orin Nnao 安装pytorch与torchvision环境

大家好,我是虎哥,Jeston Orin nano 8G模块,提供高达 40 TOPS 的 AI 算力,安装好了Jetpack5.1之后,我们需要配置一些支持环境,来为我们后续的深度学习开发提供支持。本章内容,我将主要围绕安装对…...

ROS:常用可视化工具的使用

目录 一、日志输出工具——rqt_console二、绘制数据曲线——rqt_plot三、图像渲染工具——rqt_image_view四、图形界面总接口——rqt五、Rviz六、Gazebo 一、日志输出工具——rqt_console 启动海龟键盘控制节点,打开日志输出工具 roscorerosrun turtlesim turtles…...

智能应用搭建平台——LCHub低代码表单 vs 流程表单 vs 仪表盘

1. LCHub低代码如何选择 「流程表单」:填报数据,并带有流程审批功能,适合报销、请假申请或其他工作流; 「表单」:填报数据,并带有数据协作功能,如修改、删除、导入、导出,并可以给不同的人不同的管理权限; 「仪表盘」:数据分析处理、结果展示功能,如数据汇总、趋…...

Mac下通过Docker安装ElasticSearch集群

1、安装ElasticSearch 使用docker直接获取es镜像,执行命令docker pull elasticsearch:7.7.0 执行完成后,执行docker images即可看到上一步拉取的镜像。 2、创建数据挂在目录,以及配置ElasticSearch集群配置文件 创建数据文件挂载目录 mkdir -…...

MySQL redo log、undo log、binlog

MySQL是一个广泛使用的关系型数据库管理系统,它通过一系列的日志来保证数据的一致性和持久性。在MySQL中,有三个重要的日志组件,它们分别是redo log(重做日志)、undo log(回滚日志)和binlog&…...

文件包含漏洞

一、原理解析。 开发人员通常会把可重复使用的函数写到单个文件中,在使用到某些函数时,可直接调用此文件,而无须再次编写,这种调用文件的过程被称为包含。 注意:对于开发人员来讲,文件包含很有用&#xf…...

Python 中的 F-Test

文章目录 F 统计量和 P 值方差(ANOVA) 分析中的 F 值 本篇文章介绍 F 统计、F 分布以及如何使用 Python 对数据执行 F-Test 测试。 F 统计量是在方差分析检验或回归分析后获得的数字,以确定两个总体的平均值是否存在显着差异。 它类似于 T 检验的 T 统计量&#xf…...

Docker网络模式

一、docker网络概述 1、docker网络实现的原理 Docker使用Linux桥接,在宿主机虚拟一个Docker容器网桥(docker0),Docker启动一个容器时会根据Docker网桥的网段分配给容器一个IP地址,称为Container-IP, 同时Docker网桥是 每个容器的…...

百度离线资源治理

作者 | 百度MEG离线优化团队 导读 近些年移动互联网的高速发展驱动了数据爆发式的增长,各大公司之间都在通过竞争获得更大的增长空间,大数据计算的效果直接影响到公司的发展,而这背后其实依赖庞大的算力及数据作为支撑,因此在满足…...

C++进阶之继承

文章目录 前言一、继承的概念及定义1.继承概念2.继承格式与访问限定符3.继承基类与派生类的访问关系变化4.总结 二、基类和派生类对象赋值转换基本概念与规则 三、继承中的作用域四、派生类的默认成员函数五、继承与友元六、继承与静态成员六、复杂的菱形继承及菱形虚拟继承七、…...

在 Transformers 中使用约束波束搜索引导文本生成

引言 本文假设读者已经熟悉文本生成领域波束搜索相关的背景知识,具体可参见博文 如何生成文本: 通过 Transformers 用不同的解码方法生成文本。 与普通的波束搜索不同,约束 波束搜索允许我们控制所生成的文本。这很有用,因为有时我们确切地知…...

Centos7更换OpenSSL版本

OpenSSL 1.1.0 用户应升级至 1.1.0aOpenSSL 1.0.2 用户应升级至 1.0.2iOpenSSL 1.0.1 用户应升级至 1.0.1u 查看openssl版本 openssl version -v选择升级版本 我的版本是OpenSSL 1.0.2系列,所以要升级1.0.2i https://www.openssl.org/source/old/1.0.2/openssl-…...

基于摄影测量的三维重建【终极指南】

我们生活的时代非常令人兴奋,如果你对 3D 东西感兴趣,更是如此。 我们有能力使用任何相机,从感兴趣的物体中捕捉一些图像数据,并在眨眼间将它们变成 3D 资产! 这种通过简单的数据采集阶段进行的 3D 重建过程是许多行业…...

配置ThreadPoolExecutor

ThreadPoolExecutor为一些Executor 提供了基本的实现,这些Executor 是由Executors中的newCachedThreadPool、newFixedThreadPool和newScheduledThreadExecutor 等工厂方法返回的。ThreadPoolExecutor是一个灵活的、稳定的线程池,允许进行各种定制。 如果默认的执行策略不能满足…...

Yolov5s算法从训练到部署

文章目录 PyTorch GPU环境搭建查看显卡CUDA版本Anaconda安装PyTorch环境安装PyCharm中验证 训练算法模型克隆Yolov5代码工程制作数据集划分训练集、验证集修改工程相关文件配置预训练权重文件配置数据文件配置模型文件配置 超参数配置 测试训练出来的算法模型 量化转换算法模型…...

分布式补充技术 01.AOP技术

01.AOP技术是对于面向对象编程(OOP)的补充。是按照OCP原则进行的编写,(ocp是修改模块权限不行,扩充可以) 02.写一个例子: 创建一个新的java项目,在main主启动类中,写如下代码。 package com.co…...

QT 多对一服务插件 CTK开发(五)

CTK在软件的开发过程中可以很好的降低复杂性、使用 CTK Plugin Framework 提供统一的框架来进行开发增加了复用性 将同一功能打包可以提供多个应用程序使用避免重复性工作、可以进行版本控制提供了良好的版本更新迭代需求、并且支持动态热拔插 动态更新、开发更加简单快捷 方便…...

[Windows]_[初级]_[创建目录和文件的名字注意事项]

场景 在开发Windows程序时,会出现目录生成了,但是函数无法在目录里创建文件,怎么回事?说明 在之前说过Windows上有些字符是不能作为文件名的[1],但是检查了下出问题的目录名没有非法字符,所以不是这个原因。 把文件的绝对路径打印出来就发现了问题,目录名后边带了空格,…...

「QT」QT5程序设计目录

✨博客主页:何曾参静谧的博客 📌文章专栏:「QT」QT5程序设计 目录 📑【QT的基础知识篇】📑【QT的GUI编程篇】📑【QT的项目示例篇】📑【QT的网络编程篇】📑【QT的数据库编程篇】📑【QT的跨平台编程篇】📑【QT的高级编程篇】📑【QT的开发工具篇】📑【QT的调…...

ConcurrentHashMap核心源码(JDK1.8)

一、ConcurrentHashMap的前置知识扫盲 ConcurrentHashMap的存储结构? 数组 链表 红黑树 二、ConcurrentHashMap的DCL操作 HashMap线程不安全,在并发情况下,或者多个线程同时操作时,肯定要使用ConcurrentHashMap 无论是HashM…...

【Python】文件 读取 写 os模块 shutil模块 pickle模块

目录 1.文件 1.1 读取操作 1.2 写操作 1.3 os:文件管理 1.4 os.path:获取文件属性 1.5 shutil:文件的拷贝删除移动解压缩 1.6 pickle:数据永久存储 1.文件 文件编码 编码是一种规则集合,记录内容和二进制间相互…...

PAT A1087 All Roads Lead to Rome

1087 All Roads Lead to Rome 分数 30 作者 CHEN, Yue 单位 浙江大学 Indeed there are many different tourist routes from our city to Rome. You are supposed to find your clients the route with the least cost while gaining the most happiness. Input Specific…...

浅谈HttpURLConnection所有方法详解

HttpURLConnection 类是 Java 中用于实现 HTTP 协议的基础类,它提供了一系列方法来建立与 HTTP 服务器的连接、发送请求并读取响应信息。下面是 HttpURLConnection 类中常用的方法以及其详细解释: ---------------------------------------------------…...

前端快速创建web3应用模版分享

一、起因 一直以来都有一个创建前端Dapp模版的愿望,一来是工作中也有这样的需要,避免每次都要抽离重复的代码。二来是这样的模版也能帮助其他前端快速了解到web3应用的脚手架以及框架结构。于是决定整理一个模版并开源,希望我的代码能帮助到大…...

越权漏洞讲解

越权漏洞是指系统或应用程序中存在的安全漏洞,允许攻击者以超越其授权范围的方式访问系统资源或执行特权操作。这种漏洞可能会导致严重的安全风险,因为攻击者可以利用它来获取敏感信息、修改系统设置或执行恶意操作。 下面是一些常见的越权漏洞类型和它…...

短视频矩阵源码系统打包.源码

Masayl是一款基于区块链技术的去中心化应用程序开发平台,可帮助开发者快速、便捷地创建去中心化应用程序。Masayl拥有丰富的API和SDK,为开发者们提供了支持。此外,Masayl还采用了高效的智能合约技术,确保应用程序的稳定、安全和高…...

云南LED、LCD显示屏系统建设,户外、室内广告大屏建设方案

LED大屏幕显示系统是LED高清晰数字显示技术、显示单元无缝拼接技术、多屏图像处理技术、信号切换技术、网络技术等科技手段的应用综合为一体,形成一个拥有高亮度、高清晰度、技术先进、功能强大、使用方便的大屏幕投影显示系统。通过大屏幕显示系统,可以…...