YOLOv5改进系列(1)——添加CBAM注意力机制
一、如何理解注意力机制
假设你正在阅读一本书,同时有人在你旁边说话。当你听到某些关键字时,比如“你的名字”或者“你感兴趣的话题”,你会自动把注意力从书上转移到他们的谈话上,尽管你并没有完全忽略书本的内容。这就是注意力机制的核心思想:动态地根据重要性来分配注意力,而不是对输入的信息一视同仁地处理。
在计算机视觉或自然语言处理(NLP)中,注意力机制让模型能够灵活地聚焦于输入数据的某些重要部分。例如,在图像分类任务中,模型不需要对整张图像的每一个像素都一视同仁,而是可以专注于那些关键区域,如目标物体的边缘或特征。在句子处理时,模型可以根据句子的上下文,专注于那些对语义理解最为关键的单词或短语,而忽略不太重要的部分。
注意力机制就像在复杂信息处理中自动筛选和重点关注的过程,帮助模型更智能地选择和处理最有用的信息。
作用
- 提升准确性:注意力机制聚焦关键信息,提升预测精度。
- 增强可解释性:能更清晰展示模型决策过程。
- 处理变长数据:适用于文本、语音等不定长序列数据。
不足
- 计算开销大:需要计算每个位置的权重,耗时长。
- 易过拟合:复杂权重可能导致模型在训练集上表现过好,泛化能力弱。
- 数据需求高:需要大量数据训练,否则效果不佳。
二、CBAM注意力机制
论文名称:《CBAM: Convolutional Block Attention Module》
论文地址:https://arxiv.org/pdf/1807.06521
论文代码:GitHub - Jongchan/attention-module: Official PyTorch code for "BAM: Bottleneck Attention Module (BMVC2018)" and "CBAM: Convolutional Block Attention Module (ECCV2018)"
CBAM从通道channel和空间spatial两个作用域出发,实现从通道到空间的顺序注意力结构。空间注意力可使神经网络更加关注在图像分类中决定性作用的像素区域而忽略无关紧要的区域,通道注意力则用于处理特征图通道的分配关系,同时对两个维度进行注意力分配加强了注意力机制对模型性能的提升效果。
2.1 CAM通道注意力模块
shared MLP
-
输入特征图(H×W×C):
- 先将输入的特征图(H×W×C)分别经过基于宽度和高度的最大池化和平均池化,对特征图按两个维度压缩,得到两个1×1×C的特征图。
-
池化后的特征图进行处理:
- 将最大池化和平均池化的结果利用共享的全连接层(Shared MLP)进行处理:
- 先通过一个全连接层下降通道数(C -> C/4)。
- 然后再通过另一个全连接层恢复通道数(C/4 -> C)。
- 将最大池化和平均池化的结果利用共享的全连接层(Shared MLP)进行处理:
-
生成权重:
- 将共享的全连接层所得到的结果相加后再使用Sigmoid激活函数,生成最终的channel attention feature,得到每个通道的权重值(0~1之间)。
-
特征调整:
- 将权重通过逐通道相加到输入特征图上,生成最终调整后的特征图。
代码如下所示:
import torch import torch.nn as nnclass ChannelAttentionModule(nn.Module):def __init__(self, in_channels, reduction=4):super(ChannelAttentionModule, self).__init__()# 使用最大池化和平均池化self.avg_pool = nn.AdaptiveAvgPool2d(1) # 输出形状 1x1self.max_pool = nn.AdaptiveMaxPool2d(1) # 输出形状 1x1# 全连接层用于减少通道数再增加回来self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False) # 1x1卷积,降维self.relu = nn.ReLU() # 激活函数ReLUself.fc2 = nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False) # 1x1卷积,升维# 使用 Sigmoid 激活函数self.sigmoid = nn.Sigmoid()def forward(self, x):# 输入特征图的形状为 (B, C, H, W)avg_out = self.avg_pool(x) # 平均池化max_out = self.max_pool(x) # 最大池化# 使用共享的全连接层(MLP)进行处理avg_out = self.fc2(self.relu(self.fc1(avg_out)))max_out = self.fc2(self.relu(self.fc1(max_out)))# 将池化结果加起来out = avg_out + max_out# 使用Sigmoid激活out = self.sigmoid(out)# 通过逐通道相乘的方式调整输入特征图return x * out# 测试模块 if __name__ == "__main__":# 输入张量 (batch_size, channels, height, width)input_tensor = torch.randn(1, 64, 32, 32) # 示例输入 (B=1, C=64, H=32, W=32)# 实例化通道注意力模块cam = ChannelAttentionModule(in_channels=64)# 前向传播output = cam(input_tensor)print(output.shape) # 输出特征图的形状
2.2 SAM空间注意力模块
具体流程如下:
将上面CAM模块输出的特征图F'作为本模块的输入特征图。
首先,对输入特征图在通道维度下做最大池化和平均池化,将池化后的两张特征图在通道维度堆叠(concat)。
然后,经过一个7×7卷积(7×7比3×3效果更好)操作,降维为1个channel,即叠积核融合通道信息,特征图的shape从b,2,h,w - >b,1,h,w。
最后,将卷积后的结果经过sigmoid函数对特征图的空间权重归一化,再将输入特征图和权重相乘。
class spatial_attention(nn.Module):def __init__(self, kernel_size=7):super(spatial_attention, self).__init__()# 为了保持卷积前后的特征图shape相同,卷积时需要paddingpadding = kernel_size // 2 # 确保7x7卷积后输出形状与输入一致# 7×7卷积融合通道信息,输入[b, 2, h, w],输出[b, 1, h, w]self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=padding, bias=False)# sigmoid激活函数self.sigmoid = nn.Sigmoid()def forward(self, inputs):# 在通道维度上最大池化 [b, 1, h, w],keepdim保留原有深度x_maxpool, _ = torch.max(inputs, dim=1, keepdim=True)# 在通道维度上平均池化 [b, 1, h, w]x_avgpool = torch.mean(inputs, dim=1, keepdim=True)# 池化后的结果在通道维度上堆叠 [b, 2, h, w]x = torch.cat([x_maxpool, x_avgpool], dim=1)# 卷积融合通道信息, [b, 2, h, w] ==> [b, 1, h, w]x = self.conv(x)# 空间权重归一化x = self.sigmoid(x)# 输入特征图和空间权重相乘outputs = inputs * xreturn outputs
-
输入特征图:假设输入一个中间特征图 F∈RC×H×W,其中 C、H、W 分别表示通道数、高度和宽度。
-
通道注意力模块:
- 对输入特征图 F 进行全局最大池化(Max Pooling)和全局平均池化(Average Pooling)操作,生成两个不同的空间信息描述。即通过最大池化和平均池化将特征图 F 沿空间维度压缩为1维向量。
- 将池化后的结果通过一个共享的多层感知机(MLP),以融合不同空间信息并输出通道维度的权重 Mc(F)。
- 最后通过公式 (3) 计算通道注意力:
- 空间注意力模块:
- 将通道加权后的特征图 F′ 进一步经过空间注意力机制。首先,再次对特征图 F′ 进行全局最大池化和全局平均池化操作,但这次是在通道维度进行操作。
- 将池化结果拼接后,经过一个卷积操作生成空间注意力图 Ms(F′),公式 (4) 表示为:
- 最终将空间注意力图 Ms(F′) 与特征图 F′ 按元素相乘,得到最终的输出特征图 F′′。
- CBAM模块的完整过程:
- 通道注意力模块和空间注意力模块可以串联使用,先通过通道注意力调整特征,再通过空间注意力调整,完成特征图的两次加权。
三、CBAM注意力机制添加过程
1.在common.py中添加网络结构
将下面代码复制到common.py文件最下面
class ChannelAttentionModule(nn.Module):def __init__(self, in_channels, reduction=4):super(ChannelAttentionModule, self).__init__()# 使用最大池化和平均池化self.avg_pool = nn.AdaptiveAvgPool2d(1) # 输出形状 1x1self.max_pool = nn.AdaptiveMaxPool2d(1) # 输出形状 1x1# 全连接层用于减少通道数再增加回来self.fc1 = nn.Conv2d(in_channels, in_channels // reduction, 1, bias=False) # 1x1卷积,降维self.relu = nn.ReLU() # 激活函数ReLUself.fc2 = nn.Conv2d(in_channels // reduction, in_channels, 1, bias=False) # 1x1卷积,升维# 使用 Sigmoid 激活函数self.sigmoid = nn.Sigmoid()def forward(self, x):b, c, h, w = x.size() # 获取输入的形状 (batch_size, channels, height, width)# 最大池化和平均池化,输出维度 [b, c, 1, 1]max_pool = self.max_pool(x).view(b, c) # 调整池化结果维度为 [b, c]avg_pool = self.avg_pool(x).view(b, c) # 调整池化结果维度为 [b, c]# 第一个全连接层降通道数 [b, c] => [b, c/4]x_maxpool = self.fc1(max_pool)x_avgpool = self.fc1(avg_pool)# 激活函数x_maxpool = self.relu(x_maxpool)x_avgpool = self.relu(x_avgpool)# 第二个全连接层恢复通道数 [b, c/4] => [b, c]x_maxpool = self.fc2(x_maxpool)x_avgpool = self.fc2(x_avgpool)# 将最大池化和平均池化的结果相加 [b, c]x = x_maxpool + x_avgpool# Sigmoid函数权重归一化x = self.sigmoid(x)# 调整维度 [b, c] => [b, c, 1, 1]x = x.view(b, c, 1, 1)# 输入特征图和通道权重相乘 [b, c, h, w]outputs = x * xreturn outputs# def forward(self, x):# # 输入特征图的形状为 (B, C, H, W)# b,c,h,w = x.shape# max_out = self.max_pool(x) # 最大池化# avg_out = self.avg_pool(x) # 平均池化## # 使用共享的全连接层(MLP)进行处理# avg_out = self.fc2(self.relu(self.fc1(avg_out)))# max_out = self.fc2(self.relu(self.fc1(max_out)))## # 将池化结果加起来# out = avg_out + max_out## # 使用Sigmoid激活# out = self.sigmoid(out)## # 通过逐通道相乘的方式调整输入特征图# return x * outclass spatial_attention(nn.Module):def __init__(self, kernel_size=7):super(spatial_attention, self).__init__()# 为了保持卷积前后的特征图shape相同,卷积时需要paddingpadding = kernel_size // 2 # 确保7x7卷积后输出形状与输入一致# 7×7卷积融合通道信息,输入[b, 2, h, w],输出[b, 1, h, w]self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size, padding=padding, bias=False)# sigmoid激活函数self.sigmoid = nn.Sigmoid()def forward(self, inputs):# 在通道维度上最大池化 [b, 1, h, w],keepdim保留原有深度x_maxpool, _ = torch.max(inputs, dim=1, keepdim=True)# 在通道维度上平均池化 [b, 1, h, w]x_avgpool = torch.mean(inputs, dim=1, keepdim=True)# 池化后的结果在通道维度上堆叠 [b, 2, h, w]x = torch.cat([x_maxpool, x_avgpool], dim=1)# 卷积融合通道信息, [b, 2, h, w] ==> [b, 1, h, w]x = self.conv(x)# 空间权重归一化x = self.sigmoid(x)# 输入特征图和空间权重相乘outputs = inputs * xreturn outputsclass CBAM(nn.Module):def __init__(self, c1, c2, ratio=16, kernel_size=7): # ch_insuper(CBAM, self).__init__()self.channel_attention = ChannelAttentionModule(c1, ratio)self.spatial_attention = spatial_attention(kernel_size)def forward(self, x):out = self.channel_attention(x) * x # 通道注意力加权# c*h*w (通道数、高度、宽度)out = self.spatial_attention(out) * out # 空间注意力加权# c*h*w * 1*h*w(空间维度权重)return out
2.在yolo.py中添加CBAM结构
在下面找到 def parse_model(d, ch): 函数,往下找到if m in 这一行 添加CBAM
3.在yolov5s_CBAM.yaml中添加CBAM结构
在 yolov5s_CBAM.yaml
文件中添加 CBAM 模块。具体来说,有两种常见的添加方式:
- 在主干(backbone)的 SPPF(Spatial Pyramid Pooling-Fast)层之前添加一层 CBAM 模块。
- 将 Backbone 中所有的 C3 层全部替换为 CBAM 模块。
4.在yolo.py中修改yolov5s.yaml为yolov5s_CBAM.yaml
5.添加成功结果展示
6.修改train下的parse_opt开始训练
相关文章:

YOLOv5改进系列(1)——添加CBAM注意力机制
一、如何理解注意力机制 假设你正在阅读一本书,同时有人在你旁边说话。当你听到某些关键字时,比如“你的名字”或者“你感兴趣的话题”,你会自动把注意力从书上转移到他们的谈话上,尽管你并没有完全忽略书本的内容。这就是注意力机…...

无头单向非循环java版的模拟实现
【本节目标】 1.ArrayList的缺陷 2.链表 1. ArrayList的缺陷 上节课已经熟悉了 ArrayList 的使用,并且进行了简单模拟实现。通过源码知道, ArrayList 底层使用数组来存储元素: public class ArrayList<E> extends AbstractList<…...
Bert Score-文本相似性评估
Bert Score Bert Score 是基于BERT模型的一种方法。它通过计算两个句子在BERT模型中的嵌入编码之间的余弦相似度来评估它们的相似度。BERTScore考虑了上下文信息和语义信息,因此能够更准确地衡量句子之间的相似度。 安装 pip install bert-score 使用例子 一个…...
Pyenv管理Python版本,conda之外的另一套python版本管理解决方案
简介 Pyenv 是一个 python 解释器管理工具,可以对计算机中的多个 python 版本进行管理和切换。为什么要用 pyenv 管理python呢,用过的 python 人都知道,python 虽然是易用而强大的编程语言,但是 python 解释器却有多个版本&#…...

快速实现AI搜索!Fivetran 支持 Milvus 作为数据迁移目标
Fivetran 现已支持 Milvus 向量数据库作为数据迁移的目标,能够有效简化 RAG 应用和 AI 搜索中数据源接入的流程。 数据是 AI 应用的支柱,无缝连接数据是充分释放数据潜力的关键。非结构化数据对于企业搜索和检索增强生成(RAG)聊天…...
css的页面布局属性
CSS Flexbox(Flexible Box Layout)是一种用于页面布局的CSS3规范,它提供了一种更加高效的方式来布置、对齐和分配容器内元素的空间,即使它们的大小是未知或者动态变化的。Flexbox很容易处理一维布局,即在一个方向上&am…...

RTE 大会报名丨AI 时代新基建:云边端架构和 AI Infra ,RTE2024 技术专场第二弹!
所有 AI Infra 都在探寻规格和性能的最佳平衡,如何构建高可用的云边端协同架构? 语音 AI 实现 human-like 的最后一步是什么? AI 视频的爆炸增长,给新一代编解码技术提出了什么新挑战? 当大模型进化到实时多模态&am…...

【React】入门Day01 —— 从基础概念到实战应用
目录 一、React 概述 二、开发环境创建 三、JSX 基础 四、React 的事件绑定 五、React 组件基础使用 六、组件状态管理 - useState 七、组件的基础样式处理 快速入门 – React 中文文档 一、React 概述 React 是什么 由 Meta 公司开发,是用于构建 Web 和原生…...

<<机器学习实战>>10-11节笔记:生成器与线性回归手动实现
10生成器与python实现 如果是曲线规律的数据集,则需要把模型变复杂。如果是噪音较大,则需要做特征工程。 随机种子的知识点补充: 根据不同库中的随机过程,需要用对应的随机种子: 比如 llist(range(5)) random.shuf…...

链表OJ经典题目及思路总结(一)
目录 前言1.移除元素1.1 链表1.2 数组 2.双指针2.1 找链表的中间结点2.2 找倒数第k个结点 总结 前言 解代码题 先整体:首先数据结构链表的题一定要多画图,捋清问题的解决思路; 后局部:接着考虑每一步具体如何实现,框架…...

初识chatgpt
GPT到底是什么 首先,我们需要了解GPT的全称:Generative Pre-trained Transformer,即三个关键词:生成式 预训练 变换模型。 (1)什么是生成式? 即能够生成新的文本序列。 (2&#…...
【60天备战2024年11月软考高级系统架构设计师——第33天:云计算与大数据架构——大数据处理框架的应用场景】
随着大数据技术的发展,越来越多的企业开始采用大数据处理框架来解决实际问题。理解这些框架的应用场景对于架构师来说至关重要。 大数据处理框架的应用场景 实时数据分析:使用Apache Kafka与Apache Spark结合,可以实现对实时数据流的处理与…...

如何设计具体项目的数据库管理
### 例三:足协的数据库管理算法 #### 角色: - **ESFP学生**:小明 - **ENTP老师**:张老师 #### 主题:足协的数据库管理算法 --- **张老师**:小明,今天我们来讨论一下足协的数据库管理算法。你…...

对于 Vue CLI 项目如何引入Echarts以及动态获取数据
🚀个人主页:一颗小谷粒 🚀所属专栏:Web前端开发 很荣幸您能阅读我的文章,诚请评论指点,欢迎欢迎 ~ 目录 1、数据画卷—Echarts介绍 1.1 什么是Echarts? 1.2 Echarts官网地址 2、Vue CLI 项目…...

【Linux笔记】在VMware中,为基于NAT模式运行的CentOS虚拟机设置固定的网络IP地址
一、配置VMware虚拟网络 1、打开VMware虚拟网络编辑器: 点击VMware主界面上方的“编辑”菜单,选择“虚拟网络编辑器”。 2、选择NAT模式网络: 在虚拟网络编辑器中,选择VMnet8(或其他NAT模式的网络)。 取消勾…...

一文上手Kafka【中】
一、发送消息细节 在发送消息的特别注意: 在版本 3.0 中,以前返回 ListenableFuture 的方法已更改为返回 CompletableFuture。为了便于迁移,2.9 版本添加了一个方法 usingCompletableFuture(),该方法为 CompletableFu…...
Ubuntu如何如何安装tcpdump
在Ubuntu上安装tcpdump非常简单,可以通过以下步骤完成: 打开终端。 更新包列表: 首先,更新你的包管理器的包列表: sudo apt update 安装tcpdump: 使用以下命令安装tcpdump: sudo apt install …...

3-3 AUTOSAR RTE 对SR Port的作用
返回总目录->返回总目录<- 一、前言 RTE作为SWC和BSW之间的通信机构,支持Sender-Receiver方式实现ECU内及ECU间的通信。 对于Sender-Receiver Port支持三种模式: 显式访问:若运行实体采用显示模式的S/R通信方式,数据读写是即时的;隐式访问:当多个运行实体需要读取…...
hive/impala/mysql几种数据库的sql常用写法和函数说明
做大数据开发的时候,会在几种库中来回跳,同一个需求,不同库函数和写法会有出入,在此做汇总沉淀。 1. hive 1. 日期差 DATEDIFF(CURRENT_DATE(),wdjv.creation_date) < 30 30天内的数据 2.impala 3. spark 4. mysql 1.时间差…...
论文阅读:LM-Cocktail: Resilient Tuning of Language Models via Model Merging
论文链接 代码链接 Abstract 预训练的语言模型不断进行微调,以更好地支持下游应用。然而,此操作可能会导致目标领域之外的通用任务的性能显著下降。为了克服这个问题,我们提出了LM Cocktail,它使微调后的模型在总体上保持弹性。我们的方法以模型合并(Model Merging)的形…...

业务系统对接大模型的基础方案:架构设计与关键步骤
业务系统对接大模型:架构设计与关键步骤 在当今数字化转型的浪潮中,大语言模型(LLM)已成为企业提升业务效率和创新能力的关键技术之一。将大模型集成到业务系统中,不仅可以优化用户体验,还能为业务决策提供…...
DockerHub与私有镜像仓库在容器化中的应用与管理
哈喽,大家好,我是左手python! Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库,用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

.Net框架,除了EF还有很多很多......
文章目录 1. 引言2. Dapper2.1 概述与设计原理2.2 核心功能与代码示例基本查询多映射查询存储过程调用 2.3 性能优化原理2.4 适用场景 3. NHibernate3.1 概述与架构设计3.2 映射配置示例Fluent映射XML映射 3.3 查询示例HQL查询Criteria APILINQ提供程序 3.4 高级特性3.5 适用场…...
MySQL中【正则表达式】用法
MySQL 中正则表达式通过 REGEXP 或 RLIKE 操作符实现(两者等价),用于在 WHERE 子句中进行复杂的字符串模式匹配。以下是核心用法和示例: 一、基础语法 SELECT column_name FROM table_name WHERE column_name REGEXP pattern; …...
大数据学习(132)-HIve数据分析
🍋🍋大数据学习🍋🍋 🔥系列专栏: 👑哲学语录: 用力所能及,改变世界。 💖如果觉得博主的文章还不错的话,请点赞👍收藏⭐️留言Ǵ…...
Angular微前端架构:Module Federation + ngx-build-plus (Webpack)
以下是一个完整的 Angular 微前端示例,其中使用的是 Module Federation 和 npx-build-plus 实现了主应用(Shell)与子应用(Remote)的集成。 🛠️ 项目结构 angular-mf/ ├── shell-app/ # 主应用&…...
MFE(微前端) Module Federation:Webpack.config.js文件中每个属性的含义解释
以Module Federation 插件详为例,Webpack.config.js它可能的配置和含义如下: 前言 Module Federation 的Webpack.config.js核心配置包括: name filename(定义应用标识) remotes(引用远程模块࿰…...

高保真组件库:开关
一:制作关状态 拖入一个矩形作为关闭的底色:44 x 22,填充灰色CCCCCC,圆角23,边框宽度0,文本为”关“,右对齐,边距2,2,6,2,文本颜色白色FFFFFF。 拖拽一个椭圆,尺寸18 x 18,边框为0。3. 全选转为动态面板状态1命名为”关“。 二:制作开状态 复制关状态并命名为”开…...

解决MybatisPlus使用Druid1.2.11连接池查询PG数据库报Merge sql error的一种办法
目录 前言 一、问题重现 1、环境说明 2、重现步骤 3、错误信息 二、关于LATERAL 1、Lateral作用场景 2、在四至场景中使用 三、问题解决之道 1、源码追踪 2、关闭sql合并 3、改写处理SQL 四、总结 前言 在博客:【写在创作纪念日】基于SpringBoot和PostG…...
PostgreSQL 对 IPv6 的支持情况
PostgreSQL 对 IPv6 的支持情况 PostgreSQL 全面支持 IPv6 网络协议,包括连接、存储和操作 IPv6 地址。以下是详细说明: 一、网络连接支持 1. 监听 IPv6 连接 在 postgresql.conf 中配置: listen_addresses 0.0.0.0,:: # 监听所有IPv4…...