当前位置: 首页 > news >正文

注意力机制(SE,ECA,CBAM) Pytorch代码

注意力机制

    • 1 SENet
    • 2 ECANet
    • 3 CBAM
      • 3.1 通道注意力
      • 3.2 空间注意力
      • 3.3 CBAM
    • 4 展示网络层具体信息

1 SENet

SE注意力机制(Squeeze-and-Excitation Networks):是一种通道类型的注意力机制,就是在通道维度上增加注意力机制,主要内容是是squeezeexcitation.

在这里插入图片描述
就是使用另外一个新的神经网络(两个Linear层),针对通道维度的数据进行学习,获取到特征图每个通道的重要程度,然后再和原始通道数据相乘即可。
具体参考Blog:
CNN中的注意力机制

小结:

  1. SENet的核心思想是通过全连接网络根据loss损失来自动学习特征权重,而不是直接根据特征通道的数值分配来判断,使有效的特征通道的权重大。

  2. 论文认为excitation操作中使用两个全连接层相比直接使用一个全连接层,它的好处在于,具有更多的非线性,可以更好地拟合通道间的复杂关联。

代码:
拆解步骤,forward代码写的比较细节


import torch
from torch import nn
from torchstat import stat  # 查看网络参数# 定义SE注意力机制的类
class se_block(nn.Module):# 初始化, in_channel代表输入特征图的通道数, ratio代表第一个全连接下降通道的倍数def __init__(self, in_channel, ratio=4):# 继承父类初始化方法super(se_block, self).__init__()# 属性分配# 全局平均池化,输出的特征图的宽高=1self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)# 第一个全连接层将特征图的通道数下降4倍self.fc1 = nn.Linear(in_features=in_channel, out_features=in_channel//ratio, bias=False)# relu激活self.relu = nn.ReLU()# 第二个全连接层恢复通道数self.fc2 = nn.Linear(in_features=in_channel//ratio, out_features=in_channel, bias=False)# sigmoid激活函数,将权值归一化到0-1self.sigmoid = nn.Sigmoid()# 前向传播def forward(self, inputs):  # inputs 代表输入特征图# 获取输入特征图的shapeb, c, h, w = inputs.shape# 全局平均池化 [b,c,h,w]==>[b,c,1,1]x = self.avg_pool(inputs)# 维度调整 [b,c,1,1]==>[b,c]x = x.view([b,c])# 第一个全连接下降通道 [b,c]==>[b,c//4]  # 这里也是使用Linear层的原因,只是对Channel进行线性变换x = self.fc1(x)x = self.relu(x)# 第二个全连接上升通道 [b,c//4]==>[b,c]  # 再通过Linear层恢复Channel数目x = self.fc2(x)# 对通道权重归一化处理  # 将数值转化为(0,1)之间,体现不同通道之间重要程度x = self.sigmoid(x)# 调整维度 [b,c]==>[b,c,1,1]  x = x.view([b,c,1,1])# 将输入特征图和通道权重相乘outputs = x * inputsreturn outputs

结果展示:
在这里插入图片描述
提示: in_channel/ratio需要大于0,否则线性层输入是0维度,没有意义,可以根据自己需求调整ratio的大小

2 ECANet

作者表明 SENet 中的降维会给通道注意力机制带来副作用,并且捕获所有通道之间的依存关系是效率不高的,而且是不必要的。
参考Blog:
CNN中的注意力机制

代码:
详细版本:在forward中,介绍了每一步的作用

import torch
from torch import nn
import math
from torchstat import stat  # 查看网络参数# 定义ECANet的类
class eca_block(nn.Module):# 初始化, in_channel代表特征图的输入通道数, b和gama代表公式中的两个系数def __init__(self, in_channel, b=1, gama=2):# 继承父类初始化super(eca_block, self).__init__()# 根据输入通道数自适应调整卷积核大小kernel_size = int(abs((math.log(in_channel, 2)+b)/gama))# 如果卷积核大小是奇数,就使用它if kernel_size % 2:kernel_size = kernel_size# 如果卷积核大小是偶数,就把它变成奇数else:kernel_size = kernel_size + 1# 卷积时,为例保证卷积前后的size不变,需要0填充的数量padding = kernel_size // 2# 全局平均池化,输出的特征图的宽高=1self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)# 1D卷积,输入和输出通道数都=1,卷积核大小是自适应的# 这个1维卷积需要好好了解一下机制,这是改进SENet的重要不同点self.conv = nn.Conv1d(in_channels=1, out_channels=1, kernel_size=kernel_size,bias=False, padding=padding)# sigmoid激活函数,权值归一化self.sigmoid = nn.Sigmoid()# 前向传播def forward(self, inputs):# 获得输入图像的shapeb, c, h, w = inputs.shape# 全局平均池化 [b,c,h,w]==>[b,c,1,1]x = self.avg_pool(inputs)# 维度调整,变成序列形式 [b,c,1,1]==>[b,1,c]x = x.view([b,1,c])   # 这是为了给一维卷积# 1D卷积 [b,1,c]==>[b,1,c]x = self.conv(x)# 权值归一化x = self.sigmoid(x)# 维度调整 [b,1,c]==>[b,c,1,1]x = x.view([b,c,1,1])# 将输入特征图和通道权重相乘[b,c,h,w]*[b,c,1,1]==>[b,c,h,w]outputs = x * inputsreturn outputs

精简版:

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchinfo import summary
import mathclass EfficientChannelAttention(nn.Module):           # Efficient Channel Attention moduledef __init__(self, c, b=1, gamma=2):super(EfficientChannelAttention, self).__init__()t = int(abs((math.log(c, 2) + b) / gamma))k = t if t % 2 else t + 1self.avg_pool = nn.AdaptiveAvgPool2d(1)self.conv1 = nn.Conv1d(1, 1, kernel_size=k, padding=int(k/2), bias=False)self.sigmoid = nn.Sigmoid()def forward(self, x):x = self.avg_pool(x)# 这里可以对照上一版代码,理解每一个函数的作用x = self.conv1(x.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)out = self.sigmoid(x)return out

效果展示:
在这里插入图片描述
总结:
ECANet参数更少!

3 CBAM

CBAM注意力机制是由**通道注意力机制(channel)空间注意力机制(spatial)**组成。
先通道注意力,后空间注意力的顺序注意力模块!
在这里插入图片描述

3.1 通道注意力

在这里插入图片描述
输入数据,对数据分别做最大池化操作和平均池化操作(输出都是batchchannel11),然后使用SENet的方法,针对channel进行先降维后升维操作,之后将输出的两个结果相加,再使用Sigmoid得到通道权重,再之后使用View函数恢复**(batchchannel11)**维度,和原始数据相乘得到通道注意力结果!
通道注意力代码:

#(1)通道注意力机制
class channel_attention(nn.Module):# 初始化, in_channel代表输入特征图的通道数, ratio代表第一个全连接的通道下降倍数def __init__(self, in_channel, ratio=4):# 继承父类初始化方法super(channel_attention, self).__init__()# 全局最大池化 [b,c,h,w]==>[b,c,1,1]self.max_pool = nn.AdaptiveMaxPool2d(output_size=1)# 全局平均池化 [b,c,h,w]==>[b,c,1,1]self.avg_pool = nn.AdaptiveAvgPool2d(output_size=1)# 第一个全连接层, 通道数下降4倍self.fc1 = nn.Linear(in_features=in_channel, out_features=in_channel//ratio, bias=False)# 第二个全连接层, 恢复通道数self.fc2 = nn.Linear(in_features=in_channel//ratio, out_features=in_channel, bias=False)# relu激活函数self.relu = nn.ReLU()# sigmoid激活函数self.sigmoid = nn.Sigmoid()# 前向传播def forward(self, inputs):# 获取输入特征图的shapeb, c, h, w = inputs.shape# 输入图像做全局最大池化 [b,c,h,w]==>[b,c,1,1]max_pool = self.max_pool(inputs)# 输入图像的全局平均池化 [b,c,h,w]==>[b,c,1,1]avg_pool = self.avg_pool(inputs)# 调整池化结果的维度 [b,c,1,1]==>[b,c]max_pool = max_pool.view([b,c])avg_pool = avg_pool.view([b,c])# 第一个全连接层下降通道数 [b,c]==>[b,c//4]x_maxpool = self.fc1(max_pool)x_avgpool = self.fc1(avg_pool)# 激活函数x_maxpool = self.relu(x_maxpool)x_avgpool = self.relu(x_avgpool)# 第二个全连接层恢复通道数 [b,c//4]==>[b,c]x_maxpool = self.fc2(x_maxpool)x_avgpool = self.fc2(x_avgpool)# 将这两种池化结果相加 [b,c]==>[b,c]x = x_maxpool + x_avgpool# sigmoid函数权值归一化x = self.sigmoid(x)# 调整维度 [b,c]==>[b,c,1,1]x = x.view([b,c,1,1])# 输入特征图和通道权重相乘 [b,c,h,w]outputs = inputs * xreturn outputs

3.2 空间注意力

在这里插入图片描述
针对输入数据,分别选取数据中最大值所在的维度(batch1h*w),和按照维度进行数据平均操作(batch1hw),然后将两个数据做通道连接(batch2hw),使用卷积操作,将channel维度降为1,之后对结果取sigmoid,得到空间注意力权重,和原始数据相乘得到空间注意力结果。

代码:

#(2)空间注意力机制
class spatial_attention(nn.Module):# 初始化,卷积核大小为7*7def __init__(self, kernel_size=7):# 继承父类初始化方法super(spatial_attention, self).__init__()# 为了保持卷积前后的特征图shape相同,卷积时需要paddingpadding = kernel_size // 2# 7*7卷积融合通道信息 [b,2,h,w]==>[b,1,h,w]self.conv = nn.Conv2d(in_channels=2, out_channels=1, kernel_size=kernel_size,padding=padding, bias=False)# sigmoid函数self.sigmoid = nn.Sigmoid()# 前向传播def forward(self, inputs):# 在通道维度上最大池化 [b,1,h,w]  keepdim保留原有深度# 返回值是在某维度的最大值和对应的索引x_maxpool, _ = torch.max(inputs, dim=1, keepdim=True)# 在通道维度上平均池化 [b,1,h,w]x_avgpool = torch.mean(inputs, dim=1, keepdim=True)# 池化后的结果在通道维度上堆叠 [b,2,h,w]x = torch.cat([x_maxpool, x_avgpool], dim=1)# 卷积融合通道信息 [b,2,h,w]==>[b,1,h,w]x = self.conv(x)# 空间权重归一化x = self.sigmoid(x)# 输入特征图和空间权重相乘outputs = inputs * xreturn outputs

3.3 CBAM

将通道注意力模块和空间注意力模块顺序串联得到CBAM模块!
代码:

class cbam(nn.Module):# 初始化,in_channel和ratio=4代表通道注意力机制的输入通道数和第一个全连接下降的通道数# kernel_size代表空间注意力机制的卷积核大小def __init__(self, in_channel, ratio=4, kernel_size=7):# 继承父类初始化方法super(cbam, self).__init__()# 实例化通道注意力机制self.channel_attention = channel_attention(in_channel=in_channel, ratio=ratio)# 实例化空间注意力机制self.spatial_attention = spatial_attention(kernel_size=kernel_size)# 前向传播def forward(self, inputs):# 先将输入图像经过通道注意力机制x = self.channel_attention(inputs)# 然后经过空间注意力机制x = self.spatial_attention(x)return x

结果:
在这里插入图片描述

4 展示网络层具体信息

安装包

pip install torchstat

使用

from torchstat import stat net = cbam(16)
stat(net, (16, 256, 256))  # 不需要Batch维度

注意力机制后期学习到再持续更新!!

参考博客:
CNN注意力机制
ECANet

相关文章:

注意力机制(SE,ECA,CBAM) Pytorch代码

注意力机制1 SENet2 ECANet3 CBAM3.1 通道注意力3.2 空间注意力3.3 CBAM4 展示网络层具体信息1 SENet SE注意力机制(Squeeze-and-Excitation Networks):是一种通道类型的注意力机制,就是在通道维度上增加注意力机制,主要内容是是…...

Vue2笔记03 脚手架(项目结构),常用属性配置,ToDoList(本地存储,组件通信)

Vue脚手架 vue-cli 向下兼容可以选择较高版本 初始化 全局安装脚手架 npm install -g vue/cli 创建项目:切换到项目所在目录 vue create xxx 按照指引选择vue版本 创建成功 根据指引依次输入上面指令即可运行项目 也可使用vue ui在界面上完成创建&…...

Java程序的执行顺序、简述对线程池的理解

点个关注,必回关 文章目录一、Java程序是如何执行的二、合理利用线程池能够带来三个好处一、Java程序是如何执行的 我们日常的工作中都使用开发工具(IntelliJ IDEA 或 Eclipse 等)可以很方便的调试程序,或者是通 过打包工具把项目…...

【前言】嵌入式系统简介

随手拍拍💁‍♂️📷 日期: 2022.12.01 地点: 杭州 介绍: 2022.11.30下午两点时,杭州下了一场特别大的雪。隔天的12月路过食堂时,边上的井盖上发现了这个小雪人。此时边上的雪已经融化殆尽,只有这个雪人依旧维持着原状⛄…...

React设计原理—1框架原理

阅读前须知 本文是笔者学习卡颂的《React设计原理》的读书笔记,对书中有价值内容以Q&A方式进行呈现,同时结合了自己的理解🤔阅读时推荐先看问题,想想自己的答案,再和答案比对一下本文属于前端框架科普,…...

(C00034)基于Springboot+html前后端分离技术的宿舍管理系统-有文档

基于Springboothtml技术的宿舍管理系统-有文档项目简介项目获取开发环境项目技术运行截图项目简介 基于Springboothtml的前后端分离技术的宿舍管理系统项目为了方便对学生宿舍进行管理而设计,分为后勤、宿管、学生三种用户,后勤对整体宿舍进行管理、宿管…...

Flink面试题

一 基础篇Flink的执行图有哪几种?分别有什么作用Flink中的执行图一般是可以分为四类,按照生成顺序分别为:StreamGraph-> JobGraph-> ExecutionGraph->物理执行图。1)StreamGraph顾名思义,这里代表的是我们编写…...

Python学习笔记

前言:又从仓库翻出来了一些以前总结的文档,以下内容是我初学Python时网上找的或是图书馆借书抄写的笔记,现在再看有点零散不成体系,但是也还是纪念一下子吧。 Python学习笔记 对于初学编程的人来说,Python可以缩短编…...

最适合入门的100个深度学习实战项目

🚨注意🚨:最近经粉丝反馈,发现有些订阅者将此专栏内容进行二次售卖,特在此声明,本专栏内容仅供学习,不得以任何方式进行售卖,未经作者许可不得对本专栏内容行使发表权、署名权、修改…...

AssertionError: 618 columns passed, passed data had 508 columns【已解决】

问题描述 程序中断,报错如下AssertionError: 618 columns passed, passed data had 508 columns Exception has occurred: ValueError 618 columns passed, passed data had 508 columns AssertionError: 618 columns passed, passed data had 508 columnsThe abo…...

166_技巧_Power BI 窗口函数处理连续发生业务问题

166_技巧_Power BI 窗口函数处理连续发生业务问题 一、背景 在生产经营的数据监控中,会有一类指标需要监控是否连续发生,从而根据其在设定区间中的连续频次来评价业务。 例如: 员工连续迟到天数。销售金额连续上升或者下降。用户连续登陆…...

电子科技大学人工智能期末复习笔记(五):机器学习

目录 前言 监督学习 vs 无监督学习 回归 vs 分类 Regression vs Classification 训练集 vs 测试集 vs 验证集 泛化和过拟合 Generalization & Overfitting 线性分类器 Linear Classifiers 激活函数 - 概率决策 ⚠线性回归 决策树 Decision Trees 决策树构建递归…...

使用DDD指导业务设计的总结思考

领域驱动设计(DDD) 是 Eric Evans 提出的一种软件设计方法和思想,主要解决业务系统的设计和建模。DDD 有大量难以理解的概念,尤其是翻译的原因,某些词汇非常生涩,例如:模型、限界上下文、聚合、…...

面试官问:如何确保缓存和数据库的一致性?

如果你对这个问题有过研究,应该可以发现这个问题其实很好回答,如果第一次听到或者第一次遇到这个问题,估计会有点懵,今天我们来聊聊这个话题。 1、问题分析 首先我们来看看为什么会有这个问题! 我们在日常开发中&am…...

16.数据库Redis

一、基本概念 Redis(Remote Dictionary Server)译为“远程字典服务”,它是一款基于内存实现的键值型 NoSQL 数据库, 通常也被称为数据结构服务器,这是因为它可以存储多种数据类型,比如 string(字…...

【Redis高级-集群分片】

单机安装Redis首先需要安装Redis所需要的依赖:yum install -y gcc tclRedis安装包上传到虚拟机的任意目录:我放到了/tmp目录:解压缩:tar -zxvf /tmp/redis-6.2.4.tar.gz -C /tmp解压后:进入redis目录:cd /t…...

CSDN - CSDN27题解

文章目录幸运数字题目描述解题思路AC代码投篮题目描述解题思路AC代码通货膨胀-x国货币题目描述解题思路AC代码最后一位题目描述解题思路AC代码CSDN编程竞赛报名地址:https://edu.csdn.net/contest/detail/41 这次题目描述刚开始好像有些问题,之后被修正了…...

docker拉取mysql

搜索mysql版本docker search mysql搜索获赞数(星星数量) 大于 1000 的镜像docker search --filterstars1000 mysql搜索官方发布的版本docker search --filter is-officialtrue mysql搜索版本号docker search mysql57拉取docker pull devbeta/mysql57查看下载镜像docker images启…...

在Linux上安装Python3

记录:373场景:在CentOS 7.9操作系统上,安装Python-3.8.9环境。版本:JDK 1.8 Python-3.8.9官网地址:https://www.python.org下载地址:https://www.python.org/ftp/python/1.安装基础依赖1.1安装gcc(1)安装命…...

23 种设计模式的通俗解释,看完秒懂

01 工厂方法 追 MM 少不了请吃饭了,麦当劳的鸡翅和肯德基的鸡翅都是 MM 爱吃的东西,虽然口味有所不同,但不管你带 MM 去麦当劳或肯德基,只管向服务员说「来四个鸡翅」就行了。麦当劳和肯德基就是生产鸡翅的 Factory 工厂模式&…...

如何做好需求管理?经验方法、模型、工具

需求管理能力是衡量产品经理能力的一个重要指标。因为需求是产品的基石,只有选取恰当的方法进行需求分析及管理,才能更好的构建产品方案,从而输出精准的产品定义。结合本人学习和自身经验,打算将需求管理分”需求挖掘”、”需求分…...

怎么用期货做风险对冲(如何利用期货对冲风险)

不同期货市场的同一期货品种的对冲交易怎么做 不同 期货市场 的同一期货品种的 对冲交易 。 因为地域和 制度环境 不同,同一种期货品种在不同市场的同一时间的价格很可能是不一样的,并且也是在不断变化的。 这样在一个市场做多头买进&#xff0…...

C++标准模板库type_traits源码剖析

一、type_traits源码介绍 1、type_traits是C11提供的模板元基础库。 2、type_traits可实现在编译期计算。包括添加修饰、萃取、判断查询、类型推导等等功能。 3、type_traits提供了编译期的true和false。 二、type_traits的作用 1、根据不同类型,模板匹配不同版本…...

Python获取公众号(pc客户端)数据,使用Fiddler抓包工具

前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! 今天来教大家如何使用Fiddler抓包工具,获取公众号(PC客户端)的数据。 Fiddler是一个http协议调试代理工具,它能够记录并检查所有你的电脑和互联网之间的http通讯,…...

Maven进阶

这里写目录标题1.分模块开发1.1 模块更新后,会造成的影响2.依赖管理2.1 依赖传递2.2 可选依赖(隐藏自己的依赖,不让别人用)2.3 排除依赖(用别人的资源,把不用的去了)3.聚合与继承3.1 为什么要使用聚合工程?3.2 聚合工程开发2.1 聚合工程三级目录1.分模块开发 我们之前做的项目…...

AXI实战(一)-为AXI总线搭建简单的仿真测试环境

AXI实战(一)-搭建简单仿真环境 看完在本文后,你将可能拥有: 一个可以仿真AXI/AXI_Lite总线的完美主端(Master)或从端(Slave)一个使用SystemVerilog仿真模块的船信体验小何的AXI实战系列开更了,以下是初定的大纲安排: 欢迎感兴趣的朋友关注并支持,以下为正文部分 文章目录…...

数据库管理-第五十六期 监控(20230210)

数据库管理 2023-02-10第五十六期 监控1 怎么监控2 直观3 历史分析4 另一个BUG总结第五十六期 监控 春节后的7天班过后就来到了2月份,本周对之前发现X8M上的那个bug进行补丁修复和协助从12.2迁移了一套PDB到这个一体机上面,2次割接。这周还和原厂老大哥…...

测试开发,测试架构师为什么能拿50 60k呢需要掌握哪些技能呢

这篇文章是软件工程系列知识总结的第五篇,同样我会以自己的理解来阐述软件工程中关于架构设计相关的知识。相比于我们常见的研发架构师,测试架构师是近几年才出现的一个岗位,当然岗位title其实没有特殊的含义,在我看来测试架构师其…...

Miniblink 入门

miniblink官网:入门之前强烈建议将Miniblink介绍仔细看一遍。 MB内核组件标准版接口文档:这里列举了所有的api以及简单的说明,但是本人建议还是看wke.h更方便,里面都是宏实现的,直接搜相关函数即可。 mb demo下载和参…...

[python入门㊷] - python存储数据

目录 ❤ json.dump()存储数据 ❤ json.laod()读取数据 ❤ 保存和读取用户生成的数据 ❤ 重构 JSON(JavaScript Object Notation)格式最初是为JavaScript开发的,但随后成了一种常见格式,被包括Python在内的众多语言采用 ❤ json.dump()存储数据…...

广州贸易网站/b站推广网站入口202

原标题:如何计算MySQL中的QPS及TPS指标指标介绍•QPS :Queries Per Second查询量/秒,是一台服务器每秒能够相应的查询次数,是对一个特定的查询服务器在规定时间内所处理查询量多少的衡量标准。 •TPS : Transactions Per Second 事…...

东莞科技网站建设/网站seo推广

借鉴文章: 用vector实现普通平衡树!_致上-CSDN博客_vector平衡树 您需要写一种数据结构(可参考题目标题),来维护一些数,其中需要提供以下操作: 1.插入数值 x 2.删除数值 x (若有多个相同的数…...

网站开发所得税/爱战网官网

题目链接 链接&#xff1a;https://leetcode.com/problems/max-sum-of-rectangle-no-larger-than-k/description/ 题解&代码 1、暴力枚举所有的情况&#xff0c;时间复杂度O(n^2*m^2)&#xff0c;实际耗时759 ms class Solution { public:int maxSumSubmatrix(vector<ve…...

商城建站流程/设计网站logo

啦啦外卖41.7独立版最新外卖源码全开源 依旧是开源安装版&#xff0c;不是市面上的网站打包的垃圾版本。 包含完整vue源码 安装演示 版本验证 插件列表&#xff1a; vue源码 小程序新接口可以正确获取用户信息&#xff1a;...

网站建设技术支持有什么/一手项目对接app平台

击上方蓝色字体&#xff0c;选择“标星公众号”优质文章&#xff0c;第一时间送达 大家好&#xff0c;我是燕子&#xff01;Sentinel是阿里巴巴开源的限流器熔断器&#xff0c;并且带有可视化操作界面。在日常开发中&#xff0c;限流功能时常被使用&#xff0c;用于对某些接口…...

做品牌网站哪个好用/深圳百度seo培训

1、Windows平台 在windows命令行窗口下执行&#xff1a; C:/>netstat -ano 我们可以知道某一端口被那个进程&#xff08;对应PID&#xff09;占用&#xff1b; 然后我们可以打开任务管理器&#xff1b;查看某一PID对应的进程名&#xff1b; 如果PID没有显示&#xff0c;菜单…...