当前位置: 首页 > news >正文

YOLO即插即用模块---AgentAttention

Agent Attention: On the Integration of Softmax and Linear Attention

论文地址:https://arxiv.org/pdf/2312.08874

问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。

方法: 提出了一个新的注意力机制,名为 Agent Attention,通过引入一组代理 token (A) 来解决计算复杂度过高的问题。

具体步骤

  1. 代理聚合 (Agent Aggregation): 将代理 token (A) 作为查询 token (Q) 的代理,从键 (K) 和值 (V) 中聚合信息,形成代理特征 (VA)。

  2. 代理广播 (Agent Broadcast): 将代理 token (A) 作为键,将全局信息从代理特征 (VA) 广播到每个查询 token (Q),形成最终的输出。

代理 token (A) 的获取方式

  • 可学习的参数

  • 从输入特征中提取 (例如,通过池化或卷积)

Agent Attention 模块

  • 包含纯 Agent Attention、代理偏置 (Agent Bias) 和深度可分离卷积 (DWC) 模块。

  • 代理偏置用于添加位置信息,帮助不同的代理 token 关注不同的区域。

  • DWC 模块用于保持特征多样性,弥补线性注意力的不足。

Agent Attention 的优势

  • 高效计算和高表达能力: 结合了 Softmax 注意力和线性注意力的优点,既降低了计算复杂度,又保持了高表达能力。

  • 大感受野: 可以采用更大的感受野,甚至全局感受野,同时保持相同的计算量。P8

实验结果

  • 在图像分类、目标检测、语义分割和图像生成等任务上,Agent Attention 都取得了显著的性能提升。

  • 在高分辨率场景中,Agent Attention 表现出优异的性能。

  • 将 Agent Attention 应用于 Stable Diffusion,可以加速图像生成过程,并显著提高图像生成质量,无需任何额外的训练。

总结: Agent Attention 是一种高效且高表达的注意力机制,可以有效地解决 Softmax 注意力计算复杂度过高的问题,在各种视觉任务中取得了显著的性能提升,特别是在高分辨率场景中。

即插即用代码:

import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_class AgentAttention(nn.Module):def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,sr_ratio=1, agent_num=49, **kwargs):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_patches = num_patcheswindow_size = (int(num_patches ** 0.5), int(num_patches ** 0.5))self.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.sr_ratio = sr_ratioif sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)self.norm = nn.LayerNorm(dim)self.agent_num = agent_numself.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))trunc_normal_(self.an_bias, std=.02)trunc_normal_(self.na_bias, std=.02)trunc_normal_(self.ah_bias, std=.02)trunc_normal_(self.aw_bias, std=.02)trunc_normal_(self.ha_bias, std=.02)trunc_normal_(self.wa_bias, std=.02)pool_size = int(agent_num ** 0.5)self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))self.softmax = nn.Softmax(dim=-1)def forward(self, x, H, W):b, n, c = x.shapenum_heads = self.num_headshead_dim = c // num_headsq = self.q(x)if self.sr_ratio > 1:x_ = x.permute(0, 2, 1).reshape(b, c, H, W)x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1)x_ = self.norm(x_)kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3)else:kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3)k, v = kv[0], kv[1]agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio)position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias = position_bias1 + position_bias2agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)agent_attn = self.attn_drop(agent_attn)agent_v = agent_attn @ vagent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)agent_bias = agent_bias1 + agent_bias2q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)q_attn = self.attn_drop(q_attn)x = q_attn @ agent_vx = x.transpose(1, 2).reshape(b, n, c)v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)if self.sr_ratio > 1:v = nn.functional.interpolate(v, size=(H, W), mode='bilinear')x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)x = self.proj(x)x = self.proj_drop(x)return xif __name__ == '__main__':dim = 4num_patches = 64block = AgentAttention(dim=dim, num_patches=num_patches)H, W = 8,8x = torch.rand(1, num_patches, dim)output = block(x, H, W)print(f"Input size: {x.size()}")print(f"Output size: {output.size()}")

YOLO小伙伴可进群交流:

相关文章:

YOLO即插即用模块---AgentAttention

Agent Attention: On the Integration of Softmax and Linear Attention 论文地址:https://arxiv.org/pdf/2312.08874 问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。 方法&a…...

探索开源语音识别的未来:高效利用先进的自动语音识别技术20241030

🚀 探索开源语音识别的未来:高效利用自动语音识别技术 🌟 引言 在数字化时代,语音识别技术正在引领人机交互的新潮流,为各行业带来了颠覆性的改变。开源的自动语音识别(ASR)系统,如…...

学习路之TP6--workman安装

一、安装 首先通过 composer 安装 composer require topthink/think-worker 报错: 分析:最新版本需要TP8,或装低版本的 composer require topthink/think-worker:^3.*安装后, 增加目录 vendor\workerman vendor\topthink\think-w…...

.NET内网实战:通过白名单文件反序列化漏洞绕过UAC

01阅读须知 此文所节选自小报童《.NET 内网实战攻防》专栏,主要内容有.NET在各个内网渗透阶段与Windows系统交互的方式和技巧,对内网和后渗透感兴趣的朋友们可以订阅该电子报刊,解锁更多的报刊内容。 02基本介绍 03原理分析 在渗透测试和红…...

AI Agents - 自动化项目:计划、评估和分配

Agents: Role 角色Goal 目标Backstory 背景故事 Tasks: Description 描述Expected Output 期望输出Agent 代理 Automated Project: Planning, Estimation, and Allocation Initial Imports 1.本地文件helper.py # Add your utilities or helper functions to…...

Git的.gitignore文件

一、各语言对应的.gitignore模板文件 项目地址:https://github.com/github/gitignore 二、.gitignore文件不生效 .gitignore文件只是ignore没有被追踪的文件,已被追踪的文件,要先删除缓存文件。 # 单个文件 git rm --cached file/path/to…...

网站安全,WAF网站保护暴力破解

雷池的核心功能 通过过滤和监控 Web 应用与互联网之间的 HTTP 流量,功能包括: SQL 注入保护:防止恶意 SQL 代码的注入,保护网站数据安全。跨站脚本攻击 (XSS):阻止攻击者在用户浏览器中执行恶意脚本。暴力破解防护&a…...

深度学习:梯度下降算法简介

梯度下降算法简介 梯度下降算法 我们思考这样一个问题,现在需要用一条直线来回归拟合这三个点,直线的方程是 y w ^ x b y \hat{w}x b yw^xb,我们假设斜率 w ^ \hat{w} w^是已知的,现在想要找到一个最好的截距 b b b。 一条…...

SparkSQL整合Hive后,如何启动hiveserver2服务

当spark sql与hive整合后,我们就无法启动hiveserver2的服务了,每次都要先启动hive的元数据服务(nohup hive --service metastore)才能启动hive,之前的beeline命令也用不了,hiveserver2的无法启动,这也导致我…...

前端路由如何从0开始配置?vue-router 的使用

在 Web 开发中,路由是指根据 URL 的不同部分将请求分发到不同的处理函数或页面的过程。路由是单页应用(SPA, Single Page Application)和服务器端渲染(SSR, Server-Side Rendering)应用中的一个重要概念。 在开发中如何…...

Java中的运算符【与C语言的区别】

目录 1. 算术运算符 1.0 赋值运算符: 1.1 四则运算符: - * / % 【取余与C有点不同】 1.2 增量运算符: - * / % * 【右侧运算结果会自动转换类型】 1.3 自增、自减:、-- 2. 关系运算符 3. 逻辑运算符 3.1 短路求值 3.2 【…...

二、基础语法

入门了解 注释 **作用:**在代码中加一些注释和说明,方便自己或者其他程序员阅读代码 两种格式: 单行注释:// 描述信息 通常放在一行代码的上方,或者一条语句的末尾,对该行代码进行说明 多行注释&#x…...

DB-GPT系列(一):DB-GPT能帮你做什么?

DB-GPT是一个开源的AI原生数据应用开发框架(AI Native Data App Development framework with AWEL and Agents),围绕大模型提供灵活、可拓展的AI原生数据应用管理与开发能力,可以帮助企业快速构建、部署智能AI数据应用,通过智能数据分析、洞察…...

【Python各个击破】numpy

简介 NumPy是一个开源的Python库,它提供了一个强大的N维数组对象和许多用于操作这些数组的函数。它是大多数Python科学计算的基础,包括Pandas、SciPy和scikit-learn等库都建立在NumPy之上。 安装 !pip install numpy导入 import numpy as np用法 # …...

【STM32 Blue Pill编程实例】-4位7段数码管使用

4位7段数码管使用 文章目录 4位7段数码管使用1、7段数码介绍2、硬件准备与接线3、模块配置4、代码实现在本文中,我们将介绍如何将 STM32 Blue Pill开发板与 4 位 7 段数码管连接,并在 STM32CubeIDE 中对其进行编程。 在文章中首先将介绍 4 位 7 段数码管及其与 STM32 Blue Pi…...

[进阶]java基础之集合(三)数据结构

文章目录 数据结构概述常见的数据结构数据结构(栈)数据结构(队列)数据结构(数组)数据结构(链表) 数据结构 概述 数据结构是计算机底层存储、组织数据的方式。是指数据相互之间是以什么方式排列在一起的。数据结构是为了更加方便的管理和使用数据,需要结合具体的业…...

《Apache Cordova/PhoneGap 使用技巧分享》

一、引言 在移动应用开发的领域中,Apache Cordova(也被称为 PhoneGap)是一个强大的工具,它允许开发者使用 HTML、CSS 和 JavaScript 等 Web 技术来构建跨平台的移动应用。这种方式不仅能够提高开发效率,还能降低开发成…...

SCP(Secure Copy

SCP(Secure Copy)‌是Linux系统下基于SSH协议的安全文件传输工具,用于在本地和远程主机间安全、快速地传输文件和目录。SCP命令通过加密传输确保数据的安全性,并且不占用过多系统资源‌。 SCP的基本用法 ‌基本语法‌&#xff1a…...

uniApp 省市区自定义数据

关于自定义省市区选择 其实也是用了 uniApp的内置组件 picker <picker mode"multiSelector" change"bindRegionChange" columnchange"bindMultiPickerColumnChange" :value"valueRegion" :range"multiArray"><v…...

图解Redis 06 | Hash数据类型的原理及应用场景

介绍 Hash 类型特别适合存储对象&#xff0c;例如用户信息等。 String类型也可以用于存储用户信息&#xff0c;Hash与String存储用户信息的区别如下图所示&#xff1a; 内部实现 Hash 类型 的底层数据结构是通过压缩列表&#xff08;Ziplist&#xff09;或哈希表&#xff…...

在 Windows 系统上设置 MySQL8.0以支持远程连接

在 Windows 系统上设置 MySQL8.0以支持远程连接的步骤如下&#xff1a; 步骤1: 修改 MySQL 配置文件1. 找到配置文件&#xff1a; MySQL 的配置文件通常为 my.ini&#xff0c;通常位于 C:\ProgramData\MySQL\MySQL Server8.0\&#xff08;确保查看隐藏文件和文件夹&#xff09…...

四种基本的编程命名规范

目前&#xff0c;共有四种基本的编程命名规范&#xff0c;分别是匈牙利命名法、驼峰式命名法、帕斯卡命名法和下划线命名法&#xff0c;其中前三种命名法较为流行。 例如&#xff1a;iMyData是一个匈牙利命名法&#xff1b;myData是一个驼峰式命名法&#xff1b;MyData是一个帕…...

【前端】在 TypeScript 中使用 AbortController 取消异步请求

在 TypeScript 中使用 AbortController 来取消异步请求&#xff0c;尤其是像 fetch 这样的请求&#xff0c;可以提供一种优雅的方式来中止长时间运行的操作。下面是一个详细的步骤说明&#xff0c;展示如何在 TypeScript 中使用 AbortController 取消 fetch 请求。 步骤 1&…...

k8s知识点总结

docker 名称空间 分类 Docker中的名称空间用于提供进程隔离&#xff0c;确保容器之间的资源相互独立。主要分类包括&#xff1a; PID Namespace&#xff1a;进程ID隔离&#xff0c;使每个容器有自己的进程树&#xff0c;容器内的进程不会干扰其他容器或主机上的进程。 NET Nam…...

论文阅读:三星-TinyClick

《Single-Turn Agent for Empowering GUI Automation》 赋能GUI自动化的单轮代理 摘要 我们介绍了一个用于图形用户界面&#xff08;GUI&#xff09;交互任务的单轮代理&#xff0c;使用了视觉语言模型Florence-2-Base。该代理的主要任务是识别与用户指令相对应的UI元素的屏幕…...

Windows on ARM上使用sherpa-onnx实现语音识别

Windows on ARM上使用sherpa-onnx实现语音识别 下载模型准备声音文件测试下载模型 模型所在的地址在这里(),通过git命令将模型下载下来 模型:hfd地址 git clone https://hf-mirror.com/csukuangfj/sherpa-onnx-streaming-paraformer-bilingual-zh-en将如下的代码保存成一个…...

Unity 打包AB Timeline 引用丢失,错误问题

1、裁剪 在 link.xml 添加 <assembly fullname"Unity.Timeline" preserve"all"/> 上面这一步我其实做了&#xff0c;但还是不行&#xff0c;各种搜索&#xff0c;不得解&#xff0c;还有创建一个空的Timeline 放到 Resources目录下的&#xff0c;也…...

【Kettle的安装与使用】使用Kettle实现mysql和hive的数据传输(使用Kettle将mysql数据导入hive、将hive数据导入mysql)

文章目录 一、安装1、解压2、修改字符集3、启动 二、实战1、将hive数据导入mysql2、将mysql数据导入到hive 一、安装 Kettle的安装包在文章结尾 1、解压 在windows中解压到一个非中文路径下 2、修改字符集 修改 spoon.bat 文件 "-Dfile.encodingUTF-8"3、启动…...

STM32的hal库在实现延时函数(例如:Delay_ms 等)为什么用滴答定时(Systick)而不是定时器定时中断,也不是RTC?

STM32的HAL库在实现延时函数&#xff08;如Delay_ms等&#xff09;时选择使用滴答定时器&#xff08;Systick&#xff09;而非定时器定时中断或RTC&#xff08;实时时钟&#xff09;&#xff0c;主要基于以下几个原因&#xff1a; Systick定时器的优势 集成在NVIC中&#xff…...

刚刚买的域名被DNS劫持了怎么处理

在当今数字化的时代&#xff0c;域名作为网络世界的重要标识&#xff0c;对于个人和企业的在线业务都至关重要。然而&#xff0c;有时会遭遇令人头疼的问题&#xff0c;比如新买的域名被DNS劫持。这不仅会影响网站的正常访问&#xff0c;还可能导致用户信息泄露、业务受损等严重…...

罗岗网站建设公司/网站收录服务

首先,要分别在两个文件中实现以下两个类 class Object { public: NewType ToType(); }; class NewType : public Object { } -------------------------------------------------------------------------------- 做法1 -------------------------------------------------…...

做信息网站怎么赚钱/代写文章价格表

自从2008年以来&#xff0c;太多的同学、朋友&#xff0c;QQ等网络帐号被盗。 然后&#xff0c;盗号者来骗钱。比如 借用账号、帮忙支付费用等。 盗号者固然可恶&#xff0c;传统骗子的网络版。 可是&#xff0c;这些帐号的主人就仅仅是可怜么&#xff1f; 自己的号被盗&#x…...

如何把网站主关键词做到百度首页/网上做广告推广

MySQL 有一个和优秀的语法 create table ... like &#xff0c; 可以快速复制一张表&#xff0c;创建其副本。 PostgreSQL 也有类似的语法&#xff0c;而且更加灵活&#xff0c;不过要注意些细节。先来看看MySQL 语法&#xff1a; create table ... like原始表T1&#xff0c;结…...

做养生网站需要资质吗/成都百度推广优化创意

下载mongodb压缩包。官网下载即可。 安装还是比较简单&#xff0c;需要把压缩包解压然后配置环境变量启动即可&#xff0c;其中还有些小问题&#xff0c;后续会提到&#xff0c;比如外网访问等。 连接服务器&#xff0c;将mongo压缩包传到服务器上&#xff0c;在服务器上新建…...

合肥网站建设技术/营销策划咨询

2018-11-20 回答就楼主的问题以及这机子的情况说几句&#xff1a;1.你这是第三代ivb系列i3&#xff0c;属于双核四线程&#xff0c;并非四核u。但综合性能是非常不错的我这里要展开篇幅&#xff0c;说说关于核心工作原理的问题&#xff0c;让楼主从核心数量决定一切的误区中走出…...

免费博客 wordpress/seo平台优化服务

1. markdown-index 最近做了一个Jetbrains的插件&#xff0c;叫markdown-index&#xff0c;它的作用是为Markdown文档的标题自动添加序号&#xff0c;效果如下&#xff1a; 目前已经可以在Jetbrains全家桶的插件市场中搜索到。 2. 为什么我要做这个插件 我习惯用Markdown写完…...