YOLO即插即用模块---AgentAttention
Agent Attention: On the Integration of Softmax and Linear Attention
论文地址:https://arxiv.org/pdf/2312.08874
问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。
方法: 提出了一个新的注意力机制,名为 Agent Attention,通过引入一组代理 token (A) 来解决计算复杂度过高的问题。
具体步骤:
-
代理聚合 (Agent Aggregation): 将代理 token (A) 作为查询 token (Q) 的代理,从键 (K) 和值 (V) 中聚合信息,形成代理特征 (VA)。
-
代理广播 (Agent Broadcast): 将代理 token (A) 作为键,将全局信息从代理特征 (VA) 广播到每个查询 token (Q),形成最终的输出。
代理 token (A) 的获取方式:
-
可学习的参数
-
从输入特征中提取 (例如,通过池化或卷积)
Agent Attention 模块:
-
包含纯 Agent Attention、代理偏置 (Agent Bias) 和深度可分离卷积 (DWC) 模块。
-
代理偏置用于添加位置信息,帮助不同的代理 token 关注不同的区域。
-
DWC 模块用于保持特征多样性,弥补线性注意力的不足。
-
Agent Attention 的优势:
-
高效计算和高表达能力: 结合了 Softmax 注意力和线性注意力的优点,既降低了计算复杂度,又保持了高表达能力。
-
大感受野: 可以采用更大的感受野,甚至全局感受野,同时保持相同的计算量。P8
实验结果:
-
在图像分类、目标检测、语义分割和图像生成等任务上,Agent Attention 都取得了显著的性能提升。
-
在高分辨率场景中,Agent Attention 表现出优异的性能。
-
将 Agent Attention 应用于 Stable Diffusion,可以加速图像生成过程,并显著提高图像生成质量,无需任何额外的训练。
总结: Agent Attention 是一种高效且高表达的注意力机制,可以有效地解决 Softmax 注意力计算复杂度过高的问题,在各种视觉任务中取得了显著的性能提升,特别是在高分辨率场景中。
即插即用代码:
import torch
import torch.nn as nn
from timm.models.layers import trunc_normal_class AgentAttention(nn.Module):def __init__(self, dim, num_patches, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0.,sr_ratio=1, agent_num=49, **kwargs):super().__init__()assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."self.dim = dimself.num_patches = num_patcheswindow_size = (int(num_patches ** 0.5), int(num_patches ** 0.5))self.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = head_dim ** -0.5self.q = nn.Linear(dim, dim, bias=qkv_bias)self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)self.sr_ratio = sr_ratioif sr_ratio > 1:self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)self.norm = nn.LayerNorm(dim)self.agent_num = agent_numself.dwc = nn.Conv2d(in_channels=dim, out_channels=dim, kernel_size=(3, 3), padding=1, groups=dim)self.an_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.na_bias = nn.Parameter(torch.zeros(num_heads, agent_num, 7, 7))self.ah_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, window_size[0] // sr_ratio, 1))self.aw_bias = nn.Parameter(torch.zeros(1, num_heads, agent_num, 1, window_size[1] // sr_ratio))self.ha_bias = nn.Parameter(torch.zeros(1, num_heads, window_size[0], 1, agent_num))self.wa_bias = nn.Parameter(torch.zeros(1, num_heads, 1, window_size[1], agent_num))trunc_normal_(self.an_bias, std=.02)trunc_normal_(self.na_bias, std=.02)trunc_normal_(self.ah_bias, std=.02)trunc_normal_(self.aw_bias, std=.02)trunc_normal_(self.ha_bias, std=.02)trunc_normal_(self.wa_bias, std=.02)pool_size = int(agent_num ** 0.5)self.pool = nn.AdaptiveAvgPool2d(output_size=(pool_size, pool_size))self.softmax = nn.Softmax(dim=-1)def forward(self, x, H, W):b, n, c = x.shapenum_heads = self.num_headshead_dim = c // num_headsq = self.q(x)if self.sr_ratio > 1:x_ = x.permute(0, 2, 1).reshape(b, c, H, W)x_ = self.sr(x_).reshape(b, c, -1).permute(0, 2, 1)x_ = self.norm(x_)kv = self.kv(x_).reshape(b, -1, 2, c).permute(2, 0, 1, 3)else:kv = self.kv(x).reshape(b, -1, 2, c).permute(2, 0, 1, 3)k, v = kv[0], kv[1]agent_tokens = self.pool(q.reshape(b, H, W, c).permute(0, 3, 1, 2)).reshape(b, c, -1).permute(0, 2, 1)q = q.reshape(b, n, num_heads, head_dim).permute(0, 2, 1, 3)k = k.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)v = v.reshape(b, n // self.sr_ratio ** 2, num_heads, head_dim).permute(0, 2, 1, 3)agent_tokens = agent_tokens.reshape(b, self.agent_num, num_heads, head_dim).permute(0, 2, 1, 3)kv_size = (self.window_size[0] // self.sr_ratio, self.window_size[1] // self.sr_ratio)position_bias1 = nn.functional.interpolate(self.an_bias, size=kv_size, mode='bilinear')position_bias1 = position_bias1.reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias2 = (self.ah_bias + self.aw_bias).reshape(1, num_heads, self.agent_num, -1).repeat(b, 1, 1, 1)position_bias = position_bias1 + position_bias2agent_attn = self.softmax((agent_tokens * self.scale) @ k.transpose(-2, -1) + position_bias)agent_attn = self.attn_drop(agent_attn)agent_v = agent_attn @ vagent_bias1 = nn.functional.interpolate(self.na_bias, size=self.window_size, mode='bilinear')agent_bias1 = agent_bias1.reshape(1, num_heads, self.agent_num, -1).permute(0, 1, 3, 2).repeat(b, 1, 1, 1)agent_bias2 = (self.ha_bias + self.wa_bias).reshape(1, num_heads, -1, self.agent_num).repeat(b, 1, 1, 1)agent_bias = agent_bias1 + agent_bias2q_attn = self.softmax((q * self.scale) @ agent_tokens.transpose(-2, -1) + agent_bias)q_attn = self.attn_drop(q_attn)x = q_attn @ agent_vx = x.transpose(1, 2).reshape(b, n, c)v = v.transpose(1, 2).reshape(b, H // self.sr_ratio, W // self.sr_ratio, c).permute(0, 3, 1, 2)if self.sr_ratio > 1:v = nn.functional.interpolate(v, size=(H, W), mode='bilinear')x = x + self.dwc(v).permute(0, 2, 3, 1).reshape(b, n, c)x = self.proj(x)x = self.proj_drop(x)return xif __name__ == '__main__':dim = 4num_patches = 64block = AgentAttention(dim=dim, num_patches=num_patches)H, W = 8,8x = torch.rand(1, num_patches, dim)output = block(x, H, W)print(f"Input size: {x.size()}")print(f"Output size: {output.size()}")
YOLO小伙伴可进群交流:
相关文章:

YOLO即插即用模块---AgentAttention
Agent Attention: On the Integration of Softmax and Linear Attention 论文地址:https://arxiv.org/pdf/2312.08874 问题: 普遍使用的 Softmax 注意力机制在视觉 Transformer 模型中计算复杂度过高,限制了其在各种场景中的应用。 方法&a…...
探索开源语音识别的未来:高效利用先进的自动语音识别技术20241030
🚀 探索开源语音识别的未来:高效利用自动语音识别技术 🌟 引言 在数字化时代,语音识别技术正在引领人机交互的新潮流,为各行业带来了颠覆性的改变。开源的自动语音识别(ASR)系统,如…...

学习路之TP6--workman安装
一、安装 首先通过 composer 安装 composer require topthink/think-worker 报错: 分析:最新版本需要TP8,或装低版本的 composer require topthink/think-worker:^3.*安装后, 增加目录 vendor\workerman vendor\topthink\think-w…...

.NET内网实战:通过白名单文件反序列化漏洞绕过UAC
01阅读须知 此文所节选自小报童《.NET 内网实战攻防》专栏,主要内容有.NET在各个内网渗透阶段与Windows系统交互的方式和技巧,对内网和后渗透感兴趣的朋友们可以订阅该电子报刊,解锁更多的报刊内容。 02基本介绍 03原理分析 在渗透测试和红…...
AI Agents - 自动化项目:计划、评估和分配
Agents: Role 角色Goal 目标Backstory 背景故事 Tasks: Description 描述Expected Output 期望输出Agent 代理 Automated Project: Planning, Estimation, and Allocation Initial Imports 1.本地文件helper.py # Add your utilities or helper functions to…...
Git的.gitignore文件
一、各语言对应的.gitignore模板文件 项目地址:https://github.com/github/gitignore 二、.gitignore文件不生效 .gitignore文件只是ignore没有被追踪的文件,已被追踪的文件,要先删除缓存文件。 # 单个文件 git rm --cached file/path/to…...
网站安全,WAF网站保护暴力破解
雷池的核心功能 通过过滤和监控 Web 应用与互联网之间的 HTTP 流量,功能包括: SQL 注入保护:防止恶意 SQL 代码的注入,保护网站数据安全。跨站脚本攻击 (XSS):阻止攻击者在用户浏览器中执行恶意脚本。暴力破解防护&a…...

深度学习:梯度下降算法简介
梯度下降算法简介 梯度下降算法 我们思考这样一个问题,现在需要用一条直线来回归拟合这三个点,直线的方程是 y w ^ x b y \hat{w}x b yw^xb,我们假设斜率 w ^ \hat{w} w^是已知的,现在想要找到一个最好的截距 b b b。 一条…...

SparkSQL整合Hive后,如何启动hiveserver2服务
当spark sql与hive整合后,我们就无法启动hiveserver2的服务了,每次都要先启动hive的元数据服务(nohup hive --service metastore)才能启动hive,之前的beeline命令也用不了,hiveserver2的无法启动,这也导致我…...

前端路由如何从0开始配置?vue-router 的使用
在 Web 开发中,路由是指根据 URL 的不同部分将请求分发到不同的处理函数或页面的过程。路由是单页应用(SPA, Single Page Application)和服务器端渲染(SSR, Server-Side Rendering)应用中的一个重要概念。 在开发中如何…...

Java中的运算符【与C语言的区别】
目录 1. 算术运算符 1.0 赋值运算符: 1.1 四则运算符: - * / % 【取余与C有点不同】 1.2 增量运算符: - * / % * 【右侧运算结果会自动转换类型】 1.3 自增、自减:、-- 2. 关系运算符 3. 逻辑运算符 3.1 短路求值 3.2 【…...

二、基础语法
入门了解 注释 **作用:**在代码中加一些注释和说明,方便自己或者其他程序员阅读代码 两种格式: 单行注释:// 描述信息 通常放在一行代码的上方,或者一条语句的末尾,对该行代码进行说明 多行注释&#x…...

DB-GPT系列(一):DB-GPT能帮你做什么?
DB-GPT是一个开源的AI原生数据应用开发框架(AI Native Data App Development framework with AWEL and Agents),围绕大模型提供灵活、可拓展的AI原生数据应用管理与开发能力,可以帮助企业快速构建、部署智能AI数据应用,通过智能数据分析、洞察…...

【Python各个击破】numpy
简介 NumPy是一个开源的Python库,它提供了一个强大的N维数组对象和许多用于操作这些数组的函数。它是大多数Python科学计算的基础,包括Pandas、SciPy和scikit-learn等库都建立在NumPy之上。 安装 !pip install numpy导入 import numpy as np用法 # …...

【STM32 Blue Pill编程实例】-4位7段数码管使用
4位7段数码管使用 文章目录 4位7段数码管使用1、7段数码介绍2、硬件准备与接线3、模块配置4、代码实现在本文中,我们将介绍如何将 STM32 Blue Pill开发板与 4 位 7 段数码管连接,并在 STM32CubeIDE 中对其进行编程。 在文章中首先将介绍 4 位 7 段数码管及其与 STM32 Blue Pi…...
[进阶]java基础之集合(三)数据结构
文章目录 数据结构概述常见的数据结构数据结构(栈)数据结构(队列)数据结构(数组)数据结构(链表) 数据结构 概述 数据结构是计算机底层存储、组织数据的方式。是指数据相互之间是以什么方式排列在一起的。数据结构是为了更加方便的管理和使用数据,需要结合具体的业…...
《Apache Cordova/PhoneGap 使用技巧分享》
一、引言 在移动应用开发的领域中,Apache Cordova(也被称为 PhoneGap)是一个强大的工具,它允许开发者使用 HTML、CSS 和 JavaScript 等 Web 技术来构建跨平台的移动应用。这种方式不仅能够提高开发效率,还能降低开发成…...
SCP(Secure Copy
SCP(Secure Copy)是Linux系统下基于SSH协议的安全文件传输工具,用于在本地和远程主机间安全、快速地传输文件和目录。SCP命令通过加密传输确保数据的安全性,并且不占用过多系统资源。 SCP的基本用法 基本语法:…...
uniApp 省市区自定义数据
关于自定义省市区选择 其实也是用了 uniApp的内置组件 picker <picker mode"multiSelector" change"bindRegionChange" columnchange"bindMultiPickerColumnChange" :value"valueRegion" :range"multiArray"><v…...

图解Redis 06 | Hash数据类型的原理及应用场景
介绍 Hash 类型特别适合存储对象,例如用户信息等。 String类型也可以用于存储用户信息,Hash与String存储用户信息的区别如下图所示: 内部实现 Hash 类型 的底层数据结构是通过压缩列表(Ziplist)或哈希表ÿ…...

使用docker在3台服务器上搭建基于redis 6.x的一主两从三台均是哨兵模式
一、环境及版本说明 如果服务器已经安装了docker,则忽略此步骤,如果没有安装,则可以按照一下方式安装: 1. 在线安装(有互联网环境): 请看我这篇文章 传送阵>> 点我查看 2. 离线安装(内网环境):请看我这篇文章 传送阵>> 点我查看 说明:假设每台服务器已…...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...

《通信之道——从微积分到 5G》读书总结
第1章 绪 论 1.1 这是一本什么样的书 通信技术,说到底就是数学。 那些最基础、最本质的部分。 1.2 什么是通信 通信 发送方 接收方 承载信息的信号 解调出其中承载的信息 信息在发送方那里被加工成信号(调制) 把信息从信号中抽取出来&am…...

k8s业务程序联调工具-KtConnect
概述 原理 工具作用是建立了一个从本地到集群的单向VPN,根据VPN原理,打通两个内网必然需要借助一个公共中继节点,ktconnect工具巧妙的利用k8s原生的portforward能力,简化了建立连接的过程,apiserver间接起到了中继节…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战
在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...

【论文阅读28】-CNN-BiLSTM-Attention-(2024)
本文把滑坡位移序列拆开、筛优质因子,再用 CNN-BiLSTM-Attention 来动态预测每个子序列,最后重构出总位移,预测效果超越传统模型。 文章目录 1 引言2 方法2.1 位移时间序列加性模型2.2 变分模态分解 (VMD) 具体步骤2.3.1 样本熵(S…...

Redis数据倾斜问题解决
Redis 数据倾斜问题解析与解决方案 什么是 Redis 数据倾斜 Redis 数据倾斜指的是在 Redis 集群中,部分节点存储的数据量或访问量远高于其他节点,导致这些节点负载过高,影响整体性能。 数据倾斜的主要表现 部分节点内存使用率远高于其他节…...
Java多线程实现之Thread类深度解析
Java多线程实现之Thread类深度解析 一、多线程基础概念1.1 什么是线程1.2 多线程的优势1.3 Java多线程模型 二、Thread类的基本结构与构造函数2.1 Thread类的继承关系2.2 构造函数 三、创建和启动线程3.1 继承Thread类创建线程3.2 实现Runnable接口创建线程 四、Thread类的核心…...
#Uniapp篇:chrome调试unapp适配
chrome调试设备----使用Android模拟机开发调试移动端页面 Chrome://inspect/#devices MuMu模拟器Edge浏览器:Android原生APP嵌入的H5页面元素定位 chrome://inspect/#devices uniapp单位适配 根路径下 postcss.config.js 需要装这些插件 “postcss”: “^8.5.…...
LangChain知识库管理后端接口:数据库操作详解—— 构建本地知识库系统的基础《二》
这段 Python 代码是一个完整的 知识库数据库操作模块,用于对本地知识库系统中的知识库进行增删改查(CRUD)操作。它基于 SQLAlchemy ORM 框架 和一个自定义的装饰器 with_session 实现数据库会话管理。 📘 一、整体功能概述 该模块…...