注意力机制讲解与代码解析
一、SEBlock(通道注意力机制)
先在H*W维度进行压缩,全局平均池化将每个通道平均为一个值。
(B, C, H, W)---- (B, C, 1, 1)
利用各channel维度的相关性计算权重
(B, C, 1, 1) --- (B, C//K, 1, 1) --- (B, C, 1, 1) --- sigmoid
与原特征相乘得到加权后的。
import torch
import torch.nn as nnclass SELayer(nn.Module):def __init__(self, channel, reduction = 4):super(SELayer, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1) //自适应全局池化,只需要给出池化后特征图大小self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//reduction, 1, bias = False),nn.ReLu(implace = True),nn.Conv2d(channel//reduction, channel, 1, bias = False),nn.sigmoid())def forward(self, x):y = self.avg_pool(x)y_out = self.fc1(y)return x * y
二、CBAM(通道注意力+空间注意力机制)
CBAM里面既有通道注意力机制,也有空间注意力机制。
通道注意力同SE的大致相同,但额外加入了全局最大池化与全局平均池化并行。
空间注意力机制:先在channel维度进行最大池化和均值池化,然后在channel维度合并,MLP进行特征交融。最终和原始特征相乘。
import torch
import torch.nn as nnclass ChannelAttention(nn.Module):def __init__(self, channel, rate = 4):super(ChannelAttention, self).__init__()self.avg_pool = nn.AdaptiveAvgPool2d(1)self.max_pool = nn.AdaptiveMaxPool2d(1)self.fc1 = nn.Sequential(nn.Conv2d(channel, channel//rate, 1, bias = False)nn.ReLu(implace = True)nn.Conv2d(channel//rate, channel, 1, bias = False) )self.sig = nn.sigmoid()def forward(self, x):avg = sefl.avg_pool(x)avg_feature = self.fc1(avg)max = self.max_pool(x)max_feature = self.fc1(max)out = max_feature + avg_featureout = self.sig(out)return x * out
import torch
import torch.nn as nnclass SpatialAttention(nn.Module):def __init__(self):super(SpatialAttention, self).__init__()//(B,C,H,W)---(B,1,H,W)---(B,2,H,W)---(B,1,H,W)self.conv1 = nn.Conv2d(2, 1, kernel_size = 3, padding = 1, bias = False)self.sigmoid = nn.sigmoid()def forward(self, x):mean_f = torch.mean(x, dim = 1, keepdim = True)max_f = torch.max(x, dim = 1, keepdim = True)cat = torch.cat([mean_f, max_f], dim = 1)out = self.conv1(cat)return x*self.sigmod(out)
三、transformer里的注意力机制
Scaled Dot-Product Attention
该注意力机制的输入是QKV。
1.先Q,K相乘。
2.scale
3.softmax
4.求output
import torch
import torch.nn as nnclass ScaledDotProductAttention(nn.Module):def __init__(self, scale):super(ScaledDotProductAttention, self)self.scale = scaleself.softmax = nn.softmax(dim = 2)def forward(self, q, k, v):u = torch.bmm(q, k.transpose(1, 2))u = u / scaleattn = self.softmax(u)output = torch.bmm(attn, v)return outputscale = np.power(d_k, 0.5) //缩放系数为K维度的根号。
//Q (B, n_q, d_q) , K (B, n_k, d_k) V (B, n_v, d_v),Q与K的特征维度一定要一样。KV的个数一定要一样。
MultiHeadAttention
将QKVchannel维度转换为n*C的形式,相当于分成n份,分别做注意力机制。
1.QKV单头变多头 channel ----- n * new_channel通过linear变换,然后把head和batch先合并
2.求单头注意力机制输出
3.维度拆分 将最终的head和channel合并。
4.linear得到最终输出维度
import torch
import torch.nn as nnclass MultiHeadAttention(nn.Module):def __init__(self, n_head, d_k, d_k_, d_v, d_v_, d_o):super(MultiHeadAttention, self)self.n_head = n_headself.d_k = d_kself.d_v = d_vself.fc_k = nn.Linear(d_k_, n_head * d_k)self.fc_v = nn.Linear(d_v_, n_head * d_v)self.fc_q = nn.Linear(d_k_, n_head * d_k)self.attention = ScaledDotProductAttention(scale=np.power(d_k, 0.5))self.fc_o = nn.Linear(n_head * d_v, d_0)def forward(self, q, k, v):batch, n_q, d_q_ = q.size()batch, n_k, d_k_ = k.size()batch, n_v, d_v_ = v.size()q = self.fc_q(q)k = self.fc_k(k)v = self.fc_v(v)q = q.view(batch, n_q, n_head, d_q).permute(2, 0, 1, 3).contiguous().view(-1, n_q, d_q)k = k.view(batch, n_k, n_head, d_k).permute(2, 0, 1, 3).contiguous().view(-1, n_k, d_k)v = v.view(batch, n_v, n_head, d_v).permute(2, 0, 1, 3).contiguous().view(-1. n_v, d_v) output = self.attention(q, k, v)output = output.view(n_head, batch, n_q, d_v).permute(1, 2, 0, 3).contiguous().view(batch, n_q, -1)output = self.fc_0(output)return output
相关文章:
注意力机制讲解与代码解析
一、SEBlock(通道注意力机制) 先在H*W维度进行压缩,全局平均池化将每个通道平均为一个值。 (B, C, H, W)---- (B, C, 1, 1) 利用各channel维度的相关性计算权重 (B, C, 1, 1) --- (B, C//K, 1, 1) --- (B, C, 1, 1) --- sigmoid 与原特征相…...
微调 TrOCR – 训练 TrOCR 识别弯曲文本
TrOCR(基于 Transformer 的光学字符识别)模型是性能最佳的 OCR 模型之一。在我们之前的文章中,我们分析了它们在单行打印和手写文本上的表现。然而,与任何其他深度学习模型一样,它们也有其局限性。TrOCR 在处理开箱即用的弯曲文本时表现不佳。本文将通过在弯曲文本数据集上…...
Jetsonnano B01 笔记7:Mediapipe与人脸手势识别
今日继续我的Jetsonnano学习之路,今日学习安装使用的是:MediaPipe 一款开源的多媒体机器学习模型应用框架。可在移动设备、工作站和服务 器上跨平台运行,并支持移动 GPU 加速。 介绍与程序搬运官方,只是自己的学习记录笔记&am…...
vue学习之v-if/v-else/v-else-if
v-else/v-else-if 创建 demo7.html,内容如下 <!DOCTYPE html> <html lang"en"><head><meta charset"UTF-8"><meta name"viewport" content"widthdevice-width, initial-scale1.0"><title>Docum…...
ansible的安装和简单的块使用
目录 一、概述 二、安装 1、选择源 2、安装ansible 3、模块查看 三、实验 1、拓扑编辑 2、设置组、ping模块 3、hostname模块 4、file模块 编辑 5、stat模块 6、copy模块(本地拷贝到远程) 7、fetch模块与copy模块类似,但作用…...
Android 状态栏显示运营商名称
Android 原生设计中在锁屏界面会显示运营商名称,用户界面中,大概是基于 icon 数量长度显示考虑,对运营商名称不作显示。但是国内基本都加上运营商名称。对图标显示长度优化基本都是:缩小运营商字体、限制字数长度、信号图标压缩上…...
10.Xaml ListBox控件
1.运行界面 2.运行源码 a.Xaml 源码 <Grid Name="Grid1"><!--IsSelected="True" 表示选中--><ListBox x:Name="listBo...
基于vue3和element-plus的省市区级联组件
git地址:https://github.com/ht-sauce/elui-china-area-dht 使用:npm i elui-china-area-dht 默认使用 使用方法 <template><div class"app"><!--默认使用--><elui-china-area-dht change"onChange"></elui-china…...
Paper: 利用RNN来提取恶意软件家族的API调用模式
论文 摘要 恶意软件家族分类是预测恶意软件特征的好方法,因为属于同一家族的恶意软件往往有相似的行为特征恶意软件检测或分类方法分静态分析和动态分析两种: 静态分析基于恶意软件中包含的特定签名进行分析,优点是分析的范围覆盖了整个代码…...
sdkman 安装以及 graalvm安装
sdkman安装以及graalvm安装全过程, (可能需要梯子) tiamTiam-Lenovo:~$ curl -s "https://get.sdkman.io" | bash-syyyyyyys:/yho: -yd./yh/ m..oho. hy ..sh/ :N -/…...
如何正确使用 WEB 接口的 HTTP 状态码和业务状态码?
当设计和开发 Web 接口时,必然会和 HTTP 状态码与业务状态码这两个概念打交道。很多同学可能没有注意过这两个概念或者两者的区别,做得稀里糊涂,接下来详细讲解下二者的定义、区别和使用方法。 HTTP 状态码 HTTP 状态码是由 HTTP 协议定义的…...
Spark【Spark SQL(三)DataSet】
DataSet DataFrame 的出现,让 Spark 可以更好地处理结构化数据的计算,但存在一个问题:编译时的类型安全问题,为了解决它,Spark 引入了 DataSet API(DataFrame API 的扩展)。DataSet 是分布式的数…...
制作立体图像实用软件:3DMasterKit 10.7 Crack
3DMasterKit 软件专为创建具有逼真 3D 和运动效果的光栅图片而设计:翻转、动画、变形和缩放。 打印机、广告工作室、摄影工作室和摄影师将发现 3DMasterKit 是一种有用且经济高效的解决方案,可将其业务扩展到新的维度,提高生成的 3D 图像和光…...
高校 Web 站点网络安全面临的主要的威胁
校园网 Web 站点的主要安全威胁来源于计算机病毒、内部用户恶意攻击和 破坏、内部用户非恶意的错误操作和网络黑客入侵等。 2.1 计算机病毒 计算机病毒是指编制者在计算机程序中插入的破坏计算机功能或者数据, 影响计算机使用并且能够自我复制的一组计算机指令或…...
vue前端解决跨域
1,首先 axios请求,看后端接口路径,http://122.226.146.110:25002/api/xx/ResxxList,所以baseURL地址改成 ‘/api’ let setAxios originAxios.create({baseURL: /api, //这里要改掉timeout: 20000 // request timeout}); export default s…...
【Cicadaplayer】解码线程及队列实现
4.4分支https://github.com/alibaba/CicadaPlayer/blob/release/0.4.4/framework/codec/ActiveDecoder.h对外:送入多个包,获取一个帧 int send_packet(std::unique_ptr<IAFPacket> &packet, uint64_t timeOut) override;int getFrame(std::u...
把文件上传到Gitee的详细步骤
目录 第一步:创建一个空仓库 第二步:找到你想上传的文件所在的地址,打开命令窗口,git init 第三步:git add 想上传的文件 ,git commit -m "给这次提交取个名字" 第四步:和咱们在第…...
基于keras中Lenet对于mnist的处理
文章目录 MNIST导入必要的包加载数据可视化数据集查看数据集的分布开始训练画出loss图画出accuracy图 使用数据外的图来测试图片可视化转化灰度图的可视化可视化卷积层的特征图第一层卷积 conv1 和 pool1第二层卷积 conv2 和 pool2 MNIST MNIST(Modified National …...
Python爬虫 教程:IP池的使用
前言 嗨喽~大家好呀,这里是魔王呐 ❤ ~! python更多源码/资料/解答/教程等 点击此处跳转文末名片免费获取 一、简介 爬虫中为什么需要使用代理 一些网站会有相应的反爬虫措施,例如很多网站会检测某一段时间某个IP的访问次数,如果访问频率…...
Ansible之playbook剧本
一、playbook概述1.1 playbook 介绍1.2 playbook 组成部分 二、playbook 示例2.1 playbook 启动及检测2.2 实例一2.3 vars 定义、引用变量2.4 指定远程主机sudo切换用户2.5 when条件判断2.6 迭代2.7 Templates 模块1.先准备一个以 .j2 为后缀的 template 模板文件,设…...
unique_ptr的大小探讨
unique_ptr大小和删除器有很大关系,具体区别看如下代码的分析。不要让unique_ptr占用的空间太大,否则不会达到裸指针同样的效果。 #include <iostream> #include <memory> using namespace std;class Widget {int m_x;int m_y;int m_z;publ…...
人工智能TensorFlow PyTorch物体分类和目标检测合集【持续更新】
1. 基于TensorFlow2.3.0的花卉识别 基于TensorFlow2.3.0的花卉识别Android APP设计_基于安卓的花卉识别_lilihewo的博客-CSDN博客 2. 基于TensorFlow2.3.0的垃圾分类 基于TensorFlow2.3.0的垃圾分类Android APP设计_def model_load(img_shape(224, 224, 3)_lilihewo的博客-CS…...
ElementPlus·面包屑导航实现
面包屑导航 使用vue3中的UI框架elementPlus的 <el-breadcrumb> 实现面包屑导航 <template><!-- 面包屑 --><div class"bread-container" ><el-breadcrumb separator">"><el-breadcrumb-item :to"{ path:/ }&quo…...
【项目管理】PM vs PMO 18点区别
导读:项目经理跟PMO主要有哪些区别?首先从定义上了解,然后根据其他维度进行对比分析,基本可以了解这二者的区别,文中罗列18点区别供各位参考。 目录 1、定义 1.1 PMO 1.2 PM 2、两者区别 2.1 ROI 2.2 项目成功率…...
13 Python使用Json
概述 在上一节,我们介绍了如何在Python中使用xml,包括:SAX、DOM、ElementTree等内容。在这一节,我们将介绍如何在Python中使用Json。Json的英文全称为JavaScript Object Notation,中文为JavaScript对象表示法ÿ…...
PDFBOX和ASPOSE.PDF
一、aspose.pdf 文档 https://docs.aspose.com/pdf/java/ 1、按段落分段 /*** docx文本按段分段*/ public static void main(String[] args) {int i 1;try {// 打开文件流FileInputStream file new FileInputStream("I:\\范文.docx");// 创建 Word 文档对象XWPFDo…...
第51节:cesium 范围查询(含源码+视频)
结果示例: 完整源码: <template><div class="viewer"><el-button-group class="top_item"><el-button type=...
YOLOv5改进算法之添加CA注意力机制模块
目录 1.CA注意力机制 2.YOLOv5添加注意力机制 送书活动 1.CA注意力机制 CA(Coordinate Attention)注意力机制是一种用于加强深度学习模型对输入数据的空间结构理解的注意力机制。CA 注意力机制的核心思想是引入坐标信息,以便模型可以更好地…...
Jmeter系列-阶梯加压线程组Stepping Thread Group详解(6)
前言 tepping Thread Group是第一个自定义线程组但,随着版本的迭代,已经有更好的线程组代替Stepping Thread Group了【Concurrency Thread Group】,所以说Stepping Thread Group已经是过去式了,但还是介绍一下 Stepping Thread …...
图像的几何变换(缩放、平移、旋转)
图像的几何变换 学习目标 掌握图像的缩放、平移、旋转等了解数字图像的仿射变换和透射变换 1 图像的缩放 缩放是对图像的大小进行调整,即 使图像放大或缩小 cv2.resize(src,dsize,fx0,fy0,interpolationcv2.INTER_LINEAR) 参数: src :输入图像dsize…...
日产一卡二卡四/重庆seo公司
背景:今天公司终端上有一个功能打开异常,报500错误,我用Fiddler找到链接,然后在IE里打开,报500.23错误:检测到在集成的托管管道模式下不适用的ASP.NET设置。后台是一个IIS7和tomcat7集成的环境,…...
wordpress is admin/长春做网站公司长春seo公司
概述 任何一个Java程序都有一个或多个class,程序运行时,需要将class文件加载到JVM中才可以使用,负责加载class文件的就是ClassLoader。每一个Class对象内部都有一个classLoader字段来标记该对象应该由哪一个ClassLoader加载。 public final…...
甲蛙网站建设/长春网站建设开发
西雅图IT圈:seattleit【今日作者】Powerball选号机身体和灵魂总有一个要走在买PowerBall的路上01本周二,华盛顿3位参议员公布了民主党税改计划的新细节,计划对大公司申报的收入征收最低15%的税,税收收入将用于支持拜登政府的大规模…...
备案后怎么建设网站/重庆网络推广公司
一、nn.Sequential torch.nn.Sequential是一个Sequential容器,模块将按照构造函数中传递的顺序添加到模块中。另外,也可以传入一个有序模块。 为了更容易理解,官方给出了一些案例: # Sequential使用实例model nn.Sequential(nn…...
制作企业网站的方法/北京网络seo经理
面对对象: 1.面向对象是一种编程思想,把一切东西看成是一个个对象,比如人、耳机、鼠标、水杯等,他们各自都有属性,比如:耳机是白色的,鼠标是黑色的,水杯是圆柱形的等等,把…...
网站发展趋势和前景/百度引流推广怎么做
参考书籍:《视觉SLAM 14讲》 代码: https://github.com/gaoxiang12/slambook 第1讲:前言; 学习需具备的知识: >高等数学、线性代数、概率论 >C语言基础 >Linux基础 SLAM 概念: Simultaneous…...