当前位置: 首页 > news >正文

深度学习中的注意力模块的添加

在深度学习中,骨干网络通常指的是网络的主要结构或主干部分,它负责从原始输入中提取高级特征。骨干网络通常由卷积神经网络(CNN)或者类似的架构组成,用于对图像、文本或其他类型的数据进行特征提取和表示学习。

注意力模块则是一种用于处理序列数据的重要组件,例如在自然语言处理领域中常用的 Transformer 模型中就包含了注意力机制。注意力模块可以让模型更好地关注输入序列中的不同部分,并学习它们之间的相关性,从而提高模型的性能和泛化能力。

骨干网络和注意力模块通常是结合在一起来构建端到端的深度学习模型。这种结合可以通过多种方式实现:

  1. 注意力机制作为模块插入:在骨干网络的某个特定层或者多个层之间插入注意力模块。这样可以让模型在处理输入数据时更加灵活,可以根据任务的需要更加关注特定的信息或特征。

  2. 注意力机制与骨干网络并行:将注意力模块与骨干网络的不同部分并行处理输入数据,然后将它们的输出进行合并或者融合。这种方式可以提供更丰富的特征表征,同时保留了骨干网络和注意力模块各自的特点。

  3. 注意力机制作为整个模型的一部分:有些模型设计中,注意力机制被整合到模型的整个结构中,例如在 Transformer 模型中,注意力机制是模型的核心组件之一,与编码器、解码器等其他模块相互作用,共同完成任务。

总的来说,骨干网络和注意力模块的结合方式取决于具体的任务和模型设计需求。它们相互协作可以提高模型的表现,并且在不同的应用场景中可能会有不同的结合方式和调整方法。

举例:以 ResNet 骨干网络为例,并在其中的一个特定层插入自注意力机制。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50class SelfAttention(nn.Module):def __init__(self, in_channels, out_channels):super(SelfAttention, self).__init__()self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, channels, height, width = x.size()proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)proj_key = self.key_conv(x).view(batch_size, -1, width * height)energy = torch.bmm(proj_query, proj_key)attention = F.softmax(energy, dim=-1)proj_value = self.value_conv(x).view(batch_size, -1, width * height)out = torch.bmm(proj_value, attention.permute(0, 2, 1))out = out.view(batch_size, channels, height, width)out = self.gamma * out + xreturn outclass ResNetWithAttention(nn.Module):def __init__(self, num_classes):super(ResNetWithAttention, self).__init__()self.resnet = resnet50(pretrained=True)# Insert attention module after the second convolutional layerself.resnet.layer1.add_module("self_attention", SelfAttention(256, 256))self.fc = nn.Linear(2048, num_classes)def forward(self, x):x = self.resnet(x)x = F.avg_pool2d(x, x.size()[2:]).view(x.size(0), -1)x = self.fc(x)return x# Example usage:
model = ResNetWithAttention(num_classes=1000)
input_tensor = torch.randn(1, 3, 224, 224)  # Example input tensor
output = model(input_tensor)
print(output.shape)  # Should print: torch.Size([1, 1000])

在这个示例中,我们定义了一个自注意力模块 SelfAttention,并将其插入到了 ResNet 的第一个残差块 layer1 中的第二个卷积层之后。然后我们定义了一个新的模型 ResNetWithAttention,其中包含了 ResNet 的主干部分和我们插入的注意力模块。最后,我们在模型的最后添加了一个全连接层用于分类。

这个示例展示了如何在 PyTorch 中实现将注意力模块插入到现有骨干网络中的过程。通过这种方式,我们可以灵活地设计深度学习模型,以更好地适应不同的任务和数据特点。

举例:在 PyTorch 中实现将注意力机制与骨干网络并行处理输入数据,我们可以在骨干网络的输出上应用注意力机制,然后将其与骨干网络的输出进行合并或融合。下面是一个示例,我们将在 ResNet50 骨干网络的输出上应用自注意力机制,并将其与原始输出进行融合。

import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision.models import resnet50class SelfAttention(nn.Module):def __init__(self, in_channels, out_channels):super(SelfAttention, self).__init__()self.query_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)self.key_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)self.value_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)self.gamma = nn.Parameter(torch.zeros(1))def forward(self, x):batch_size, channels, height, width = x.size()proj_query = self.query_conv(x).view(batch_size, -1, width * height).permute(0, 2, 1)proj_key = self.key_conv(x).view(batch_size, -1, width * height)energy = torch.bmm(proj_query, proj_key)attention = F.softmax(energy, dim=-1)proj_value = self.value_conv(x).view(batch_size, -1, width * height)out = torch.bmm(proj_value, attention.permute(0, 2, 1))out = out.view(batch_size, channels, height, width)out = self.gamma * out + xreturn outclass ResNetWithAttentionParallel(nn.Module):def __init__(self, num_classes):super(ResNetWithAttentionParallel, self).__init__()self.resnet = resnet50(pretrained=True)self.attention = SelfAttention(2048, 2048)self.fc = nn.Linear(2048 * 2, num_classes)  # Concatenating original and attention-enhanced featuresdef forward(self, x):features = self.resnet(x)attention_out = self.attention(features)combined_features = torch.cat((features, attention_out), dim=1)  # Concatenate original and attention-enhanced featuresoutput = self.fc(combined_features.view(features.size(0), -1))return output# Example usage:
model = ResNetWithAttentionParallel(num_classes=1000)
input_tensor = torch.randn(1, 3, 224, 224)  # Example input tensor
output = model(input_tensor)
print(output.shape)  # Should print: torch.Size([1, 1000])

在这个示例中,我们定义了一个自注意力模块 SelfAttention,并在 ResNet50 的输出上应用了这个注意力机制。然后,我们将注意力机制的输出与原始的骨干网络输出进行了融合,通过将它们连接在一起。最后,我们在融合后的特征上添加了一个全连接层用于分类。

这个示例展示了如何在 PyTorch 中实现将注意力机制与骨干网络并行处理输入数据的方法。通过这种方式,我们可以利用注意力机制来增强骨干网络提取的特征,从而提高模型的性能和泛化能力。

举例:一个自注意力(self-attention)机制作为整个模型一部分的例子,这个例子基于 Transformer 模型的结构。在 Transformer 中,自注意力机制被整合到编码器和解码器中,用于处理序列数据。

下面是一个简化版本的 Transformer 编码器,其中包含自注意力层作为整个模型的一部分:

import torch
import torch.nn as nn
import torch.nn.functional as Fclass SelfAttention(nn.Module):def __init__(self, embed_size, heads):super(SelfAttention, self).__init__()self.embed_size = embed_sizeself.heads = headsself.head_dim = embed_size // headsassert (self.head_dim * heads == embed_size), "Embedding size needs to be divisible by heads"self.values = nn.Linear(self.head_dim, self.head_dim, bias=False)self.keys = nn.Linear(self.head_dim, self.head_dim, bias=False)self.queries = nn.Linear(self.head_dim, self.head_dim, bias=False)self.fc_out = nn.Linear(heads * self.head_dim, embed_size)def forward(self, values, keys, query, mask):N = query.shape[0]value_len, key_len, query_len = values.shape[1], keys.shape[1], query.shape[1]# Split the embedding into self.heads different piecesvalues = values.reshape(N, value_len, self.heads, self.head_dim)keys = keys.reshape(N, key_len, self.heads, self.head_dim)queries = query.reshape(N, query_len, self.heads, self.head_dim)values = self.values(values)keys = self.keys(keys)queries = self.queries(queries)# Scaled dot-product attentionenergy = torch.einsum("nqhd,nkhd->nhqk", [queries, keys])if mask is not None:energy = energy.masked_fill(mask == 0, float("-1e20"))attention = torch.softmax(energy / (self.embed_size ** (1 / 2)), dim=3)out = torch.einsum("nhql,nlhd->nqhd", [attention, values]).reshape(N, query_len, self.heads * self.head_dim)out = self.fc_out(out)return outclass TransformerEncoderLayer(nn.Module):def __init__(self, embed_size, heads, dropout, forward_expansion):super(TransformerEncoderLayer, self).__init__()self.attention = SelfAttention(embed_size, heads)self.norm1 = nn.LayerNorm(embed_size)self.norm2 = nn.LayerNorm(embed_size)self.feed_forward = nn.Sequential(nn.Linear(embed_size, forward_expansion * embed_size),nn.ReLU(),nn.Linear(forward_expansion * embed_size, embed_size),)self.dropout = nn.Dropout(dropout)def forward(self, value, key, query, mask):attention = self.attention(value, key, query, mask)# Add skip connection, run through normalization and finally dropoutx = self.dropout(self.norm1(attention + query))forward = self.feed_forward(x)out = self.dropout(self.norm2(forward + x))return outclass TransformerEncoder(nn.Module):def __init__(self,src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length,):super(TransformerEncoder, self).__init__()self.embed_size = embed_sizeself.device = deviceself.word_embedding = nn.Embedding(src_vocab_size, embed_size)self.position_embedding = nn.Embedding(max_length, embed_size)self.layers = nn.ModuleList([TransformerEncoderLayer(embed_size,heads,dropout=dropout,forward_expansion=forward_expansion,)for _ in range(num_layers)])self.dropout = nn.Dropout(dropout)def forward(self, x, mask):N, seq_length = x.shapepositions = torch.arange(0, seq_length).expand(N, seq_length).to(self.device)out = self.dropout(self.word_embedding(x) + self.position_embedding(positions))for layer in self.layers:out = layer(out, out, out, mask)return out# Example usage:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
src_vocab_size = 1000  # Example vocabulary size
max_length = 100  # Example maximum sequence length
embed_size = 256
heads = 8
num_layers = 6
forward_expansion = 4
dropout = 0.2encoder = TransformerEncoder(src_vocab_size,embed_size,num_layers,heads,device,forward_expansion,dropout,max_length,
)# Example input tensor
input_tensor = torch.randint(0, src_vocab_size, (32, 10))  # Batch size: 32, Sequence length: 10
mask = torch.ones(32, 10)  # Example maskoutput = encoder(input_tensor, mask)
print(output.shape)  # Should print: torch.Size([32, 10, 256])

在这个例子中,我们定义了一个简化版本的 Transformer 编码器,其中包含自注意力层作为整个模型的一部分。自注意力层用于处理输入序列,并学习序列中不同位置之间的关系。整个模型接受输入序列并输出相应的表示。

相关文章:

深度学习中的注意力模块的添加

在深度学习中,骨干网络通常指的是网络的主要结构或主干部分,它负责从原始输入中提取高级特征。骨干网络通常由卷积神经网络(CNN)或者类似的架构组成,用于对图像、文本或其他类型的数据进行特征提取和表示学习。 注意力…...

Docker 部署开源远程桌面工具 RustDesk

RustDesk是一款远程控制,远程协助的开源软件。完美替代TeamViewer ,ToDesk,向日葵等平台。关键支持自建服务器,更安全私密远程控制电脑!官网地址:https://rustdesk.com/ 环境准备 1、阿里云服务器一 台&a…...

intellij idea 使用git ,快速合并冲突

可以选择左边的远程分支上的代码,也可以选择右边的代码,而中间是合并的结果。 一个快速合并冲突的小技巧: 如果冲突比较多,想要快速合并冲突。也可以直接点击上图中 Apply non-conflicting changes 旁边的 All 。 这样 Idea 就会…...

AcWing26. 二进制中1的个数。三种解法Java

输入一个 3232 位整数,输出该数二进制表示中 11 的个数。 注意: 负数在计算机中用其绝对值的补码来表示。 数据范围 −100≤ 输入整数 ≤100 样例1 输入:9 输出:2 解释:9的二进制表示是1001,一共有2个…...

【ADB】常见命令汇总(持续更新)

▒ 目录 ▒ 🛫 导读开发环境 1️⃣ 设备连接和识别2️⃣ 应用程序管理3️⃣ 文件传输和管理4️⃣ 设备信息和日志5️⃣ 设备操作和控制6️⃣ 截图相关🛬 文章小结📖 参考资料 🛫 导读 Android调试桥(ADB)是…...

【递归与递推】数的计算|数的划分|耐摔指数

1.数的计算 - 蓝桥云课 (lanqiao.cn) 思路: 1.dfs的变量>每一次递归什么在变? (1)当前数的大小一直在变:sum (2)最高位的数:k 2.递归出口:最高位数字为1 3.注意&#…...

企业案例:金蝶云星空集成钉钉,帆软BI

正文:在数字化转型的大潮中,众多企业开始探索并实践高效的数据流转与集成,以提升内部管理效率和决策质量。本文将以某企业为例,详细介绍如何通过将钉钉审批流程的数据实时同步至金蝶云星空,并进一步在帆软报表平台上实…...

简单设计模式讲解

设计模式是在软件开发中经常使用的最佳实践,用于解决在软件设计中经常遇到的问题。它们提供了可重用的设计,使得代码更加灵活、可维护和可扩展。下面我将为你讲解几种常见的设计模式,并提供相应的C#代码示例。 1. 单例模式(Single…...

基于springboot的社区医疗服务系统

文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式 🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 &…...

影院座位选择简易实现(uniapp)

界面展示 主要使用到uniap中的movable-area&#xff0c;和movable-view组件实现。 代码逻辑分析 1、使用movable-area和movea-view组件&#xff0c;用于座位展示 <div class"ui-seat__box"><movable-area class"ui-movableArea"><movab…...

调用飞书获取用户Id接口成功,但是没有返回相应数据

原因&#xff1a; 该自建应用没有开放相应的数据权限。 解决办法&#xff1a; 在此处配置即可。...

STM32 GPIO输入检测——按键

前言 在嵌入式系统开发中&#xff0c;对GPIO输入进行检测是一项常见且关键的任务。STM32微控制器作为一款功能强大的处理器&#xff0c;具有丰富的GPIO功能&#xff0c;可以轻松实现对外部信号的检测和处理。在本文中&#xff0c;我们将深入探讨如何在STM32微控制器上进行GPIO…...

Rustdesk二次编译,新集成AI功能开源Gpt小程序为远程协助助力,全网首发

环境&#xff1a; Rustdesk1.1.9 sciter版 问题描述&#xff1a; Rustdesk二次编译&#xff0c;新集成AI功能开源Gpt小程序为远程协助助力,全网首发 解决方案&#xff1a; Rustdesk二次编译&#xff0c;新集成开源AI功能Gpt小程序&#xff0c;为远程协助助力&#xff0c…...

面试(03)————多线程和线程池

一、多线程 1、什么是线程?线程和进程的区别? 2、创建线程有几种方式 &#xff1f; 3、Runnable 和 Callable 的区别&#xff1f; 4、如何启动一个新线程、调用 start 和 run 方法的区别&#xff1f; 5、线程有哪几种状态以及各种状态之间的转换&#xff1f; 6、线程…...

纯CSS实现未读消息显示99+

在大佬那看到这个小技巧&#xff0c;我觉得这个功能点还挺常用&#xff0c;所以给大家分享下具体的实现。当未读消息数小于100的时候显示准确数值&#xff0c;大于99的时候显示99。 1. 实现效果 2. 组件封装 <template><span class"col"><sup :styl…...

【C++】C++ primer plus 第十二章--类和动态内存分配

动态内存和类 关于静态数据成员 类之作声明&#xff0c;不分配内存&#xff0c;因此静态成员变量在类中不能进行初始化&#xff0c;需要在类外进行。特殊情况&#xff1a; 存在可以在类中声明静态成员并初始化的情况&#xff0c;成员类型为const整型或者const枚举类型。 特殊…...

分类预测 | Matlab实现GWO-LSSVM灰狼算法优化最小二乘支持向量机数据分类预测

分类预测 | Matlab实现GWO-LSSVM灰狼算法优化最小二乘支持向量机数据分类预测 目录 分类预测 | Matlab实现GWO-LSSVM灰狼算法优化最小二乘支持向量机数据分类预测分类效果基本介绍程序设计参考资料 分类效果 基本介绍 1.Matlab实现GWO-LSSVM灰狼算法优化最小二乘支持向量机数据…...

使用PHP进行极验验证码动态参数提取与逆向分析

在网络安全领域&#xff0c;逆向工程和验证码破解是常见的技术挑战之一。极验验证码作为一种常见的人机验证工具&#xff0c;其动态参数的提取和逆向分析对于验证码的破解至关重要。本文将介绍如何使用PHP语言进行极验验证码动态参数的提取与逆向分析。 1. 准备工作 在开始之前…...

43.1k star, 免费开源的 markdown 编辑器 MarkText

43.1k star, 免费开源的 markdown 编辑器 MarkText 分类 开源分享 项目名: MarkText -- 简单而优雅的开源 Markdown 编辑器 Github 开源地址&#xff1a; https://github.com/marktext/marktext 官网地址&#xff1a; MarkText 支持平台&#xff1a; Linux, macOS 以及 Win…...

ArcGIS Pro怎么进行挖填方计算

在工程实施之前&#xff0c;我们需要充分利用地形&#xff0c;结合实际因素&#xff0c;通过挖填方计算项目的标高&#xff0c;以达到合理控制成本的目的&#xff0c;这里为大家介绍一下ArcGIS Pro中挖填方计算的方法&#xff0c;希望能对你有所帮助。 数据来源 教程所使用的…...

19c补丁后oracle属主变化,导致不能识别磁盘组

补丁后服务器重启&#xff0c;数据库再次无法启动 ORA01017: invalid username/password; logon denied Oracle 19c 在打上 19.23 或以上补丁版本后&#xff0c;存在与用户组权限相关的问题。具体表现为&#xff0c;Oracle 实例的运行用户&#xff08;oracle&#xff09;和集…...

【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)

服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...

【ROS】Nav2源码之nav2_behavior_tree-行为树节点列表

1、行为树节点分类 在 Nav2(Navigation2)的行为树框架中,行为树节点插件按照功能分为 Action(动作节点)、Condition(条件节点)、Control(控制节点) 和 Decorator(装饰节点) 四类。 1.1 动作节点 Action 执行具体的机器人操作或任务,直接与硬件、传感器或外部系统…...

P3 QT项目----记事本(3.8)

3.8 记事本项目总结 项目源码 1.main.cpp #include "widget.h" #include <QApplication> int main(int argc, char *argv[]) {QApplication a(argc, argv);Widget w;w.show();return a.exec(); } 2.widget.cpp #include "widget.h" #include &q…...

【2025年】解决Burpsuite抓不到https包的问题

环境&#xff1a;windows11 burpsuite:2025.5 在抓取https网站时&#xff0c;burpsuite抓取不到https数据包&#xff0c;只显示&#xff1a; 解决该问题只需如下三个步骤&#xff1a; 1、浏览器中访问 http://burp 2、下载 CA certificate 证书 3、在设置--隐私与安全--…...

MySQL 8.0 OCP 英文题库解析(十三)

Oracle 为庆祝 MySQL 30 周年&#xff0c;截止到 2025.07.31 之前。所有人均可以免费考取原价245美元的MySQL OCP 认证。 从今天开始&#xff0c;将英文题库免费公布出来&#xff0c;并进行解析&#xff0c;帮助大家在一个月之内轻松通过OCP认证。 本期公布试题111~120 试题1…...

【JavaSE】绘图与事件入门学习笔记

-Java绘图坐标体系 坐标体系-介绍 坐标原点位于左上角&#xff0c;以像素为单位。 在Java坐标系中,第一个是x坐标,表示当前位置为水平方向&#xff0c;距离坐标原点x个像素;第二个是y坐标&#xff0c;表示当前位置为垂直方向&#xff0c;距离坐标原点y个像素。 坐标体系-像素 …...

Android Bitmap治理全解析:从加载优化到泄漏防控的全生命周期管理

引言 Bitmap&#xff08;位图&#xff09;是Android应用内存占用的“头号杀手”。一张1080P&#xff08;1920x1080&#xff09;的图片以ARGB_8888格式加载时&#xff0c;内存占用高达8MB&#xff08;192010804字节&#xff09;。据统计&#xff0c;超过60%的应用OOM崩溃与Bitm…...

Linux 中如何提取压缩文件 ?

Linux 是一种流行的开源操作系统&#xff0c;它提供了许多工具来管理、压缩和解压缩文件。压缩文件有助于节省存储空间&#xff0c;使数据传输更快。本指南将向您展示如何在 Linux 中提取不同类型的压缩文件。 1. Unpacking ZIP Files ZIP 文件是非常常见的&#xff0c;要在 …...

【JVM面试篇】高频八股汇总——类加载和类加载器

目录 1. 讲一下类加载过程&#xff1f; 2. Java创建对象的过程&#xff1f; 3. 对象的生命周期&#xff1f; 4. 类加载器有哪些&#xff1f; 5. 双亲委派模型的作用&#xff08;好处&#xff09;&#xff1f; 6. 讲一下类的加载和双亲委派原则&#xff1f; 7. 双亲委派模…...