当前位置: 首页 > news >正文

Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

🚩🚩🚩Transformer实战-系列教程总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Pycharm中进行
本篇文章配套的代码资源已经上传
点我下载源码

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

6、WindowAttention类

6.1 构造函数

class WindowAttention(nn.Module):def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):super().__init__()self.dim = dimself.window_size = window_sizeself.num_heads = num_headshead_dim = dim // num_headsself.scale = qk_scale or head_dim ** -0.5self.relative_position_bias_table = nn.Parameter(torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))coords_h = torch.arange(self.window_size[0])coords_w = torch.arange(self.window_size[1])coords = torch.stack(torch.meshgrid([coords_h, coords_w]))coords_flatten = torch.flatten(coords, 1)relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]relative_coords = relative_coords.permute(1, 2, 0).contiguous()relative_coords[:, :, 0] += self.window_size[0] - 1relative_coords[:, :, 1] += self.window_size[1] - 1relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1relative_position_index = relative_coords.sum(-1)self.register_buffer("relative_position_index", relative_position_index)self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)trunc_normal_(self.relative_position_bias_table, std=.02)self.softmax = nn.Softmax(dim=-1)
  1. dim:输入特征维度
  2. window_size:窗口大小
  3. num_heads:多头注意力头数
  4. head_dim:每头注意力的头数
  5. scale :缩放因子
  6. relative_position_bias_table:相对位置偏置表,它对每个头存储不同窗口位置之间的偏置,以模拟位置信息
  7. coords_h 、coords_w、coords:窗口内每个位置的坐标
  8. coords_flatten :将坐标展平,为计算相对位置做准备
  9. 第1个relative_coords:计算窗口内每个位置相对于其他位置的坐标差
  10. 第2个relative_coords:重排坐标差的维度以符合预期的格式
  11. relative_coords[:, :, 0]、relative_coords[:, :, 1]、relative_coords[:, :, 0]:调整坐标差,使其能够映射到相对位置偏置表中的索引
  12. relative_position_index :计算每对位置之间的相对位置索引
  13. register_buffer:将相对位置索引注册为模型的缓冲区,这样它就不会在训练过程中被更新
  14. qkv :创建一个线性层,用于生成QKV
  15. attn_drop、proj、proj_drop:初始化注意力dropout、输出投影层及其dropout
  16. trunc_normal_:使用截断正态分布初始化相对位置偏置表
  17. softmax :初始化softmax层,用于计算注意力权重

6.2 前向传播

    def forward(self, x, mask=None):B_, N, C = x.shapeqkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)q = q * self.scaleattn = (q @ k.transpose(-2, -1))relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nHrelative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Wwattn = attn + relative_position_bias.unsqueeze(0)if mask is not None:nW = mask.shape[0]attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)attn = attn.view(-1, self.num_heads, N, N)attn = self.softmax(attn)else:attn = self.softmax(attn)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B_, N, C)x = self.proj(x)x = self.proj_drop(x)return x
  1. B_, N, C = x.shape原始输入: torch.Size([256, 49, 96]),B_, N, C即原始输入的维度
  2. qkv = self.qkv(x).reshape...qkv: torch.Size([3, 256, 3, 49, 32]),被重塑的一个五维张量,分别代表qkv三个维度、256个窗口、3个注意力头数但是不会一直是3越往后会越多、49是一个窗口有7*7=49元素、每个头的特征维度。在之前的Transformer以及Vision Transformer中,都是用x接上各自的全连接后分别生成QKV,这这里直接一起生成了。
  3. q: torch.Size([256, 3, 49, 32]),k: torch.Size([256, 3, 49, 32]),v: torch.Size([256, 3, 49, 32]),从qkv中分解出q、k、v,而且已经包含了多头注意力机制
  4. attn: torch.Size([256, 3, 49, 49]),attn是q和k的点积
  5. relative_position_bias: torch.Size([49, 49, 3]),从相对位置偏置表中索引出每对位置之间的偏置,并重塑以匹配注意力分数的形状
  6. relative_position_bias: torch.Size([3, 49, 49]),重新排列,位置编码在Transformer中一直当成偏置加进去的,而这个位置编码是对一个窗口的,所以每一个窗口的都对应了相同的位置编码
  7. attn: torch.Size([256, 3, 49, 49]),将位置编码加到注意力分数上,到这里就算完了全部的注意力机制了
  8. attn: torch.Size([256, 3, 49, 49]),掩码加到注意力分数上,使用softmax函数归一化注意力分数,得到注意力权重,应用注意力dropout
  9. x: torch.Size([256, 49, 96]),使用注意力权重对v向量进行重构,然后对结果进行转置和重塑
  10. x: torch.Size([256, 49, 96]),将加权的注意力输出通过一个线性投影层,应用输出dropout,这就是最后WindowAttention的输出,一共256个窗口,每个窗口有49个特征,每个特征对应96维的向量

SwinTransformer 算法原理
SwinTransformer 源码解读1(项目配置/SwinTransformer类)
SwinTransformer 源码解读2(PatchEmbed类/BasicLayer类)
SwinTransformer 源码解读3(SwinTransformerBlock类)
SwinTransformer 源码解读4(WindowAttention类)
SwinTransformer 源码解读5(Mlp类/PatchMerging类)

相关文章:

Transformer实战-系列教程11:SwinTransformer 源码解读4(WindowAttention类)

🚩🚩🚩Transformer实战-系列教程总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Pycharm中进行 本篇文章配套的代码资源已经上传 点我下载源码 SwinTransformer 算法原理 SwinTransformer 源码解读1(项目配置/SwinTr…...

Jenkins(本地Windows上搭建)上传 Pipeline构建前端项目并将生成dist文件夹上传至指定服务器

下载安装jdk https://www.oracle.com/cn/java/technologies/downloads/#jdk21-windows 下载jenkins window版 双击安装 https://www.jenkins.io/download/thank-you-downloading-windows-installer-stable/ 网页输入 http://localhost:8088/ 输入密码、设置账号、安装推…...

Elasticsearch 安装和配置脚本文档

Elasticsearch 安装和配置脚本文档 目录 **Elasticsearch 安装和配置脚本文档**0.**概述**1.**使用方法:**2.**脚本步骤:**3. **完整代码如下:** 0.概述 此Bash脚本用于自动化在CentOS 7系统上安装和配置Elasticsearch(ES&#x…...

【Android辟邪】之:gradle——在项目间共享依赖关系版本

翻译和简单修改自:https://docs.gradle.org/current/userguide/platforms.html#sec:sharing-catalogs 建议看原文(有能力的话) 现在 Gradle 脚本可以使用两种语法编写:Kotlin 和 Groovy 本文只使用kotlin脚本语法,更…...

Qt 项目树工程,拷贝子项目dll到子项目exe运行路径

1、项目树工程 2、项目树列表 ---- BuildAll -------- App (exe) -------- Database (dll) 注:使用 子项目–>添加库–>内部库 的方式 3、qmake 内置的变量 $$OUT_PWD 表示输出文件(如可执行文件…...

进程间通信方式

1>内核提供的原始通信方式有三种 1)无名管道 2)有名管道 3)信号 2>System V提供了三种通信方式 4)消息队列 5)共享内存 6)信号量(信号灯集) 3>套接字通信 7)socke…...

[linux]:匿名管道和命名管道(什么是管道,怎么创建管道(函数),匿名管道和命名管道的区别,代码例子)

目录 一、匿名管道 1.什么是管道?什么是匿名管道? 2.怎么创建匿名管道(函数) 3.匿名管道的4种情况 4.匿名管道有5种特性 5.怎么使用匿名管道?匿名管道有什么用?(例子) 二、命名…...

Python调用matlab程序

matlab官网:https://ww2.mathworks.cn/?s_tidgn_logo matlab外部语言和库接口,包括 Python、Java、C、C、.NET 和 Web 服务。 matlab和python的版本 安装依赖配置 安装matlab的engine 找到matlab的安装目录:“xxx\ extern\engines\python…...

FlinkSql 窗口函数

Windowing TVF 以前用的是Grouped Window Functions(分组窗口函数),但是分组窗口函数只支持窗口聚合 现在FlinkSql统一都是用的是Windowing TVFs(窗口表值函数),Windowing TVFs更符合 SQL 标准且更加强大…...

十分钟GIS——geoserver+postgis+udig从零开始发布地图服务

1数据库部署 1.1PostgreSql安装 下载到安装文件后(postgresql-9.2.19-1-windows-x64.exe),双击安装。 指定安装目录,如下图所示 指定数据库文件存放目录位置,如下图所示 指定数据库访问管理员密码,如下图所…...

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Span组件

鸿蒙(HarmonyOS)项目方舟框架(ArkUI)之Span组件 一、操作环境 操作系统: Windows 10 专业版、IDE:DevEco Studio 3.1、SDK:HarmonyOS 3.1 二、Span组件 鸿蒙(HarmonyOS)作为Text组件的子组件&#xff0…...

Leetcode—42. 接雨水【困难】

2024每日刷题&#xff08;112&#xff09; Leetcode—42. 接雨水 空间复杂度为O(n)的算法思想 实现代码 class Solution { public:int trap(vector<int>& height) {int ans 0;int n height.size();vector<int> l(n);vector<int> r(n);for(int i 0; …...

[Python] opencv - 什么是直方图?如何绘制图像的直方图?如何对直方图进行均匀化处理?

什么是直方图&#xff1f; 直方图是一种统计图&#xff0c;用于展示数据的分布情况。它将数据按照一定的区间或者组进行划分&#xff0c;然后计算在每个区间或组内的数据频数或频率&#xff08;即数据出现的次数或占比&#xff09;&#xff0c;然后用矩形或者柱形图的形式将这…...

ppi rust开发 python调用

创建python的一个测试工程 python -m venv venv .\venv\Scripts\activatepip install cffi创建一个rust的lib项目 cargo new --lib pyrustlib.rs #[no_mangle] pub extern "C" fn rust_add(x: i32, y: i32) -> i32 {x y }Cargo.toml [package] name "p…...

网站后端开发 thinkphp6 入门教程合集(更新中)

thinkphp6 入门&#xff08;1&#xff09;--安装、路由规则、多应用模式 thinkphp6 入门&#xff08;1&#xff09;--安装、路由规则、多应用模式_软件工程小施同学的博客-CSDN博客 thinkphp6 入门&#xff08;2&#xff09;--视图、渲染html页面、赋值 thinkphp6 入门&#x…...

Web前端框架-Vue(初识)

文章目录 web前端三大主流框架**1.Angular****2.React****3.Vue**什么是Vue.js 为什么要学习流行框架框架和库和插件的区别一.简介指令v-cloakv-textv-htmlv-pre**v-once**v-onv-on事件函数中传入参数事件修饰符双向数据绑定v-model 按键修饰符自定义按键修饰符别名v-bind(属性…...

配置dns服务的正反向解析

服务端IP客户端IP网址192.168.153.137192.168.153.www.openlab.com 1&#xff1a;正向解析 1.1关闭客户端和服务端的安全软件&#xff0c;安装bind软件 [rootserver ~]# setenforce 0 [rootserver ~]# systemctl stop firewalld [rootserver ~]# yum install bind -y [rootnod…...

小白水平理解面试经典题目LeetCode 71. Simplify Path【Stack类】

71. 简化路径 小白渣翻译 给定一个字符串 path &#xff0c;它是 Unix 风格文件系统中文件或目录的绝对路径&#xff08;以斜杠 ‘/’ 开头&#xff09;&#xff0c;将其转换为简化的规范路径。 在 Unix 风格的文件系统中&#xff0c;句点 ‘.’ 指的是当前目录&#xff0c;…...

电力负荷预测 | 电力系统负荷预测模型(Python线性回归、随机森林、支持向量机、BP神经网络、GRU、LSTM)

文章目录 效果一览文章概述源码设计参考资料效果一览 文章概述 电力系统负荷预测模型(Python线性回归、随机森林、支持向量机、BP神经网络、GRU、LSTM) 所谓预测,就是指通过对事物进行分析及研究,并运用合理的方法探索事物的发展变化规律,对其未来发展做出预先估计和判断。…...

YY调音台:音频后期处理

我从事影视后期处理的工作&#xff0c;主要负责音频、音效合成这块工作内容。在影视作品中&#xff0c;声音不仅仅是背景元素&#xff0c;它在叙事和创造情感氛围上发挥着至关重要的作用。我们的工作不仅要让听众听到声音&#xff0c;更要让他们通过声音感受到情感的波动和故事…...

基于FPGA的PID算法学习———实现PID比例控制算法

基于FPGA的PID算法学习 前言一、PID算法分析二、PID仿真分析1. PID代码2.PI代码3.P代码4.顶层5.测试文件6.仿真波形 总结 前言 学习内容&#xff1a;参考网站&#xff1a; PID算法控制 PID即&#xff1a;Proportional&#xff08;比例&#xff09;、Integral&#xff08;积分&…...

JavaScript 中的 ES|QL:利用 Apache Arrow 工具

作者&#xff1a;来自 Elastic Jeffrey Rengifo 学习如何将 ES|QL 与 JavaScript 的 Apache Arrow 客户端工具一起使用。 想获得 Elastic 认证吗&#xff1f;了解下一期 Elasticsearch Engineer 培训的时间吧&#xff01; Elasticsearch 拥有众多新功能&#xff0c;助你为自己…...

如何在看板中有效管理突发紧急任务

在看板中有效管理突发紧急任务需要&#xff1a;设立专门的紧急任务通道、重新调整任务优先级、保持适度的WIP&#xff08;Work-in-Progress&#xff09;弹性、优化任务处理流程、提高团队应对突发情况的敏捷性。其中&#xff0c;设立专门的紧急任务通道尤为重要&#xff0c;这能…...

Cinnamon修改面板小工具图标

Cinnamon开始菜单-CSDN博客 设置模块都是做好的&#xff0c;比GNOME简单得多&#xff01; 在 applet.js 里增加 const Settings imports.ui.settings;this.settings new Settings.AppletSettings(this, HTYMenusonichy, instance_id); this.settings.bind(menu-icon, menu…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​&#xff1a; 下载安装DevEco Studio 4.0&#xff08;支持HarmonyOS 5&#xff09;配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​&#xff1a; ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案

随着新能源汽车的快速普及&#xff0c;充电桩作为核心配套设施&#xff0c;其安全性与可靠性备受关注。然而&#xff0c;在高温、高负荷运行环境下&#xff0c;充电桩的散热问题与消防安全隐患日益凸显&#xff0c;成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...

Matlab | matlab常用命令总结

常用命令 一、 基础操作与环境二、 矩阵与数组操作(核心)三、 绘图与可视化四、 编程与控制流五、 符号计算 (Symbolic Math Toolbox)六、 文件与数据 I/O七、 常用函数类别重要提示这是一份 MATLAB 常用命令和功能的总结,涵盖了基础操作、矩阵运算、绘图、编程和文件处理等…...

回溯算法学习

一、电话号码的字母组合 import java.util.ArrayList; import java.util.List;import javax.management.loading.PrivateClassLoader;public class letterCombinations {private static final String[] KEYPAD {"", //0"", //1"abc", //2"…...

Caliper 负载(Workload)详细解析

Caliper 负载(Workload)详细解析 负载(Workload)是 Caliper 性能测试的核心部分,它定义了测试期间要执行的具体合约调用行为和交易模式。下面我将全面深入地讲解负载的各个方面。 一、负载模块基本结构 一个典型的负载模块(如 workload.js)包含以下基本结构: use strict;/…...

从物理机到云原生:全面解析计算虚拟化技术的演进与应用

前言&#xff1a;我的虚拟化技术探索之旅 我最早接触"虚拟机"的概念是从Java开始的——JVM&#xff08;Java Virtual Machine&#xff09;让"一次编写&#xff0c;到处运行"成为可能。这个软件层面的虚拟化让我着迷&#xff0c;但直到后来接触VMware和Doc…...