当前位置: 首页 > news >正文

Vision Transformer(vit)的主干

图解:

代码:

class VisionTransformer(nn.Module):def __init__(self, img_size=224, patch_size=16, in_c=3, num_classes=1000,embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.0, qkv_bias=True,qk_scale=None, representation_size=None, distilled=False, drop_ratio=0.,attn_drop_ratio=0., drop_path_ratio=0., embed_layer=PatchEmbed, norm_layer=None,act_layer=None):"""Args:img_size (int, tuple): input image size
#输入图像的大小,通常是 224 或其他标准尺寸patch_size (int, tuple): patch size
#每个块(patch)的大小,例如 16x16in_c (int): number of input channels
#输入图像的通道数,RGB 图像是 3num_classes (int): number of classes for classification head
#最终分类的类别数,默认 1000 类embed_dim (int): embedding dimension
#嵌入维度,即每个 patch 被映射到的向量的维度,默认是 768depth (int): depth of transformer
#Transformer 的深度,即堆叠的块(Block)数量。num_heads (int): number of attention heads
#注意力头的数量,默认设为 12mlp_ratio (int): ratio of mlp hidden dim to embedding dim
# MLP 隐藏层的维度与嵌入维度的比例。qkv_bias (bool): enable bias for qkv if True
#是否为 QKV(查询、键、值)矩阵添加偏置qk_scale (float): override default qk scale of head_dim ** -0.5 if set
#如果设定,将会覆盖默认的 qk 缩放因子representation_size (Optional[int]): enable and set representation layer (pre-logits) to this value if set
#如果设置了这个值,将会有一个表示层(pre-logits)distilled (bool): model includes a distillation token and head as in DeiT models
#vit中可以不管这个参数drop_ratio (float): dropout rate
# Dropout 的比例attn_drop_ratio (float): attention dropout rate
#注意力层的 Dropout 比例drop_path_ratio (float): stochastic depth rate
#droppath比例embed_layer (nn.Module): patch embedding layer
#用于嵌入图像的层,默认使用 PatchEmbednorm_layer: (nn.Module): normalization layer
#正则化层,通常是 LayerNorm"""super(VisionTransformer, self).__init__()self.num_classes = num_classesself.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
# 与 embed_dim 保持一致,表示嵌入的维度。self.num_tokens = 2 if distilled else 1
#不管distilled所以distilled=1norm_layer = norm_layer or partial(nn.LayerNorm, eps=1e-6)
#使用 LayerNorm作为默认的规范化层act_layer = act_layer or nn.GELU
#默认使用 GELU 作为激活函数self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_c=in_c, embed_dim=embed_dim)
#Embedding层结构num_patches = self.patch_embed.num_patches
#patches的个数self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
#这是用于分类的分类标记(Class Token),它是一个可学习的参数,初始值为零self.dist_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) if distilled else None
#不管distilled所以self.dist_token=Noneself.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
#位置编码(Position Embedding)self.pos_drop = nn.Dropout(p=drop_ratio)
#位置编码后的 Dropout 操作dpr = [x.item() for x in torch.linspace(0, drop_path_ratio, depth)]  # stochastic depth decay rule
#用于控制每个 Block 的 DropPath 比例self.blocks = nn.Sequential(*[Block(dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, qk_scale=qk_scale,drop_ratio=drop_ratio, attn_drop_ratio=attn_drop_ratio, drop_path_ratio=dpr[i],norm_layer=norm_layer, act_layer=act_layer)for i in range(depth)])
#使用 Block 类构建了Transformer的主体部分,包括注意力和MLP层,并使用残差连接和 DropPath self.norm = norm_layer(embed_dim)
#最后的归一化层,用于 Transformer 输出的处理# Representation layerif representation_size and not distilled:
#设置了 representation_size则会增加一个表示层 pre_logits,not distilled=trueself.has_logits = Trueself.num_features = representation_sizeself.pre_logits = nn.Sequential(OrderedDict([("fc", nn.Linear(embed_dim, representation_size)),("act", nn.Tanh())]))
#pre_logits层结构一个全连接和tanh激活函数else:self.has_logits = Falseself.pre_logits = nn.Identity()# Classifier head(s)self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()self.head_dist = Noneif distilled:self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
#distilled为none不用管# Weight initnn.init.trunc_normal_(self.pos_embed, std=0.02)if self.dist_token is not None:nn.init.trunc_normal_(self.dist_token, std=0.02)nn.init.trunc_normal_(self.cls_token, std=0.02)self.apply(_init_vit_weights)
#权重初始化def forward_features(self, x):# [B, C, H, W] -> [B, num_patches, embed_dim]x = self.patch_embed(x)  # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层# [1, 1, 768] -> [B, 1, 768]cls_token = self.cls_token.expand(x.shape[0], -1, -1)if self.dist_token is None:x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]else:x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropoutx = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理x = self.norm(x)
#进行归一化
#vit中self.dist_token is None所以模型只有分类标记 (class token)。if self.dist_token is None:return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。else:return x[:, 0], x[:, 1]def forward(self, x):x = self.forward_features(x)
#首先获取 Transformer 的特征输出if self.head_dist is not None:x, x_dist = self.head(x[0]), self.head_dist(x[1])if self.training and not torch.jit.is_scripting():# during inference, return the average of both classifier predictionsreturn x, x_distelse:return (x + x_dist) / 2else:
#self.head_dist为none只看head层就是最后的全连接层输出为num_classesx = self.head(x)return x

 操作:

代码:

        # [B, C, H, W] -> [B, num_patches, embed_dim]
        x = self.patch_embed(x)  # [B, 196, 768]
#将输入的图像 x 切分为多个 patch 并嵌入,通过Embedding层

操作:

代码:

    # [1, 1, 768] -> [B, 1, 768]
        cls_token = self.cls_token.expand(x.shape[0], -1, -1)
        if self.dist_token is None:
            x = torch.cat((cls_token, x), dim=1)  # [B, 197, 768]
        else:
            x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
#分类标记如果有将cls_token加入,因为dist_token为none,所以在维度1上拼接

操作:

代码:

x = self.pos_drop(x + self.pos_embed)
#添加位置编码并应用 Dropout

操作:

代码:

        x = self.blocks(x)
#通过 Transformer 的 Block 堆叠进行处理
        x = self.norm(x)
#进行归一化

操作:

代码:

#vit中self.dist_token is None所以模型只有分类标记 (class token)。
        if self.dist_token is None:
            return self.pre_logits(x[:, 0])
#x[:, 0]表示提取分类标记(class token) 的输出向量。这个向量是用于分类任务的主要特征表示。
        else:
            return x[:, 0], x[:, 1]

操作:

代码:

#self.head_dist为none只看head层就是最后的全连接层输出为num_classes
            x = self.head(x)

分类标记 (Class Token):

是一种特殊的 输入 token,在 Transformer 模型中被用来聚合全局特征。

它在模型中起到了类似于 CNN 中全局池化 (Global Pooling) 的作用,负责从所有 patch 的信息中提取一个全局表示。

这个 token 的输出向量被用作分类任务的特征输入,之后会被送入分类头 (classifier head) 进行最终的类别预测。

embedding层:

Vision Transformer(vit)的Embedding层结构-CSDN博客

Multi-Head Self-Attention:

Vision Transformer(vit)的Multi-Head Self-Attention(多头注意力机制)结构-CSDN博客

MLP模块:

Vision Transformer(vit)的MLP模块-CSDN博客

Encoder block:

Vision Transformer(vit)的Encoder层结构-CSDN博客

详解:Vision Transformer详解-CSDN博客

相关文章:

Vision Transformer(vit)的主干

图解: 代码: class VisionTransformer(nn.Module):def __init__(self, img_size224, patch_size16, in_c3, num_classes1000,embed_dim768, depth12, num_heads12, mlp_ratio4.0, qkv_biasTrue,qk_scaleNone, representation_sizeNone, distilledFalse,…...

手撸了一个文件传输工具

在日常的开发与运维中,文件传输工具是不可或缺的利器。无论是跨服务器传递配置文件,还是快速从一台机器下载日志文件,一个高效、可靠且简单的文件传输工具能够显著提高工作效率。今天,我想分享我自己手撸一个文件传输工具的全过程…...

Java程序调kubernetes(k8s1.30.7)core API简单示例,并解决403权限验证问题,即何进行进行权限授权以及验证

简单记录问题 一、问题描述 希望通过Java程序使用Kubernetes提供的工具包实现对Kubernetes集群core API的调用&#xff0c;但是在高版本上遇见权限验证问题4xx。 <dependency><groupId>io.kubernetes</groupId><artifactId>client-java</artifact…...

java八股-Redis Stream和RocketMQ实现的解决方案

文章目录 Redis Stream方案&#xff1a;ShortLinkStatsSaveProducer.javaShortLinkStatsSaveConsumer.java RocketMQ方案ShortLinkStatsSaveProducer.javaShortLinkStatsSaveConsumer.java Redis Stream方案&#xff1a; ShortLinkStatsSaveProducer.java package com.nageoff…...

第29天 MCU入门

目录 MCU介绍 MCU的组成与作用 电子产品项目开发流程 硬件开发流程 常用元器件初步了解 硬件原理图与PCB板 常见电源符号和名称 电阻 电阻的分类 贴片电阻的封装说明&#xff1a; 色环电阻的计算 贴片电阻阻值计算 上拉电阻与下拉电阻 电容 电容的读数 二极管 LED 灯电路 钳位作…...

【Python网络爬虫笔记】6- 网络爬虫中的Requests库

一、概述 Requests 是一个用 Python 语言编写的、简洁且功能强大的 HTTP 库。它允许开发者方便地发送各种 HTTP 请求&#xff0c;如 GET、POST、PUT、DELETE 等&#xff0c;并且可以轻松地处理请求的响应。这个库在 Python 生态系统中被广泛使用&#xff0c;无论是简单的网页数…...

Linux网络_网络协议_网络传输_网络字节序

一.协议 1.概念 协议&#xff08;Protocol&#xff09; 是一组规则和约定&#xff0c;用于定义计算机网络中不同设备之间如何进行通信和数据交换。协议规定了数据的格式、传输方式、传输顺序等详细规则&#xff0c;确保不同设备和系统能够有效地互联互通。 在网络通信中&#…...

浅谈网络 | 应用层之流媒体与P2P协议

目录 流媒体名词系列视频的本质视频压缩编码过程如何在直播中看到帅哥美女&#xff1f;RTMP 协议 P2PP2P 文件下载种子文件 (.torrent)去中心化网络&#xff08;DHT&#xff09;哈希值与 DHT 网络DHT 网络是如何查找 流媒体 直播系统组成与协议 近几年直播比较火&#xff0c;…...

css vue vxe-text-ellipsis table 实现多行文本超出隐藏省略

分享 vxe-text-ellipsis table grid 多行文本溢出省略的用法 正常情况下如果需要使用文本超出隐藏&#xff0c;通过 css 就可以完成 overflow: hidden; text-overflow: ellipsis; white-space: nowrap;但是如果需要实现多行文本溢出&#xff0c;就很难实现里&#xff0c;谷歌…...

基于hexo框架的博客搭建流程

这篇博文讲一讲hexo博客的搭建及文章管理&#xff0c;也算是我对于暑假的一个交代 &#xff01;&#xff01;&#xff01;注意&#xff1a;下面的操作是基于你已经安装了node.js和git的前提下进行的&#xff0c;并且拥有github账号 创建一个blog目录 在磁盘任意位置创建一个…...

数据结构-简单排序

一.前提 二.冒泡排序 三.插入排序 #include<iostream> using namespace std; typedef int ElemengType; void Bubble_Sort(ElemengType A[], int N) {for (int p N - 1; p > 0; p--) {int flag 0;for (int i 0; i < p; i) {if (A[i] > A[i 1]) {swap(A[i], …...

三十一:HTTP多种重定向跳转方式的差异

在现代网站开发中&#xff0c;HTTP 重定向是一种常见的技术&#xff0c;用于将用户的请求从一个 URL 跳转到另一个 URL。重定向机制广泛应用于网站迁移、SEO 优化、以及内容管理系统中。不同的 HTTP 状态码代表不同的重定向方式&#xff0c;每种方式的行为和适用场景各有不同。…...

利用Python爬虫精准获取淘宝商品详情的深度解析

在数字化时代&#xff0c;数据的价值日益凸显&#xff0c;尤其是在电子商务领域。淘宝作为中国最大的电商平台之一&#xff0c;拥有海量的商品数据&#xff0c;对于研究市场趋势、分析消费者行为等具有重要意义。本文将详细介绍如何使用Python编写爬虫程序&#xff0c;精准获取…...

架构师的英文:Architect

中文版 软件架构师 的英文是 “Software Architect”。 Software: 软件Architect: 架构师&#xff0c;通常指的是设计和规划某种系统或结构的人。 Software Architect 通常负责软件系统的整体设计、技术选型、架构规划&#xff0c;确保系统的可扩展性、可维护性和高效性等。…...

数据结构 ——— 计数排序算法的实现

目录 计数排序算法的思想 计数排序算法的实现 计数排序算法的思想 遍历数组&#xff0c;找出数组中的最大值 max 和 最小值 min 最大值 max 减去最小值 min 再加 1 得出数组元素的范围 range 利用 range 的大小 malloc 一个 count 数组用来计数 再对 count 数组进行初始化…...

k8s搭建Istio环境,案例pod一直处在Init:CrashLoopBackOff

1 部署calico网络环境&#xff0c;网上去找k8s版本对应的calico的配置文件&#xff0c;k8s2.8.0我用的3.28 2 安装istio环境 curl -L https://istio.io/downloadIstio | sh - # 省略istioctl生效的步骤 source <(istioctl completion zsh) istioctl install --set profile…...

Jenkins升级到最新版本后无法启动

1. 场景还原 最近在web界面将jenkins升级到最新版本后&#xff0c;后台无法启动jenkins服务&#xff0c;服务状态如下&#xff1a; 运行jenkins命令提示invalid Java version jenkins --version jenkins: invalid Java version: java version "1.8.0_202" Java(TM)…...

用户界面创建一个新的运动类型

● 现在我们需要根据我们之前规划的架构步骤来实现在用户界面创建一个运动类型 ● 首先我们在要获取用户在表单中输入的数据 //从表单中获取数据const type inputType.value;const distance inputDistance.value;const duration inputDuration.value;● 然后针对与不同的运动…...

ubuntu防火墙入门(一)——设置服务、关闭端口

本机想通过git clone gitgithub.com:skumra/robotic-grasping.git下载代码&#xff0c;firewall-config中需要为当前区域的防火墙开启SSH服务吗 是的&#xff0c;如果你想通过 git clone gitgithub.com:skumra/robotic-grasping.git 使用 SSH 协议从 GitHub 下载代码&#xff0…...

分治算法——二分查找(c++)(详解)

大家好&#xff0c;今天进入一个实用算法&#xff1a;分治算法。 1.分治算法介绍 分治算法&#xff0c;大概就是将一个大问题拆解成若干个小问题&#xff0c;将小问题一一解决&#xff0c;大问题也就迎刃而解。它包含了多种算法&#xff0c;比如递归、递推等。这里就讲解一下其…...

Binder架构

一、架构 如上图&#xff0c;binder 分为用户层和驱动层两部分&#xff0c;用户层有客户端&#xff08;Client&#xff09;、服务端&#xff08;Server&#xff09;、服务管理&#xff08;ServiceManager&#xff09;。 从用户空间的角度&#xff0c;使用步骤如下&#xff08;…...

大数据治理:解锁数据价值,引领未来创新

目录 引言 一、大数据治理的定义 二、大数据治理的重要性 三、大数据治理的核心组件 四、大数据治理的实践案例 1. 数据标准化 2. 数据质量管理 案例一&#xff1a;医疗行业的大数据治理——智能医疗助手守护健康 引言 在数字化时代&#xff0c;数据已成为企业最宝贵的…...

解决windows下php8.x及以上版本,在Apache2.4中无法加载CURL扩展的问题

本文已首发于&#xff1a;秋码记录 若你也想搭建一个个人博客&#xff0c;可参考&#xff1a;国内 gitee.com Pages 下线了&#xff0c;致使众多站长纷纷改用 github、gitlab Pages 托管平台 在日新月异的信息化下&#xff0c;软件也在跟随着互联网的脚步&#xff0c;逐步推进…...

【韩顺平老师Java反射笔记】

反射 文章目录 基本使用反射机制java程序在计算机有三个阶段反射相关的主要类 反射调用优化Class类的常用方法获取Class对象的6种方式哪些类型有Class对象类加载类加载时机类加载过程图 通过反射获取类的结构信息第一组&#xff1a;java.lang.Class类第二组&#xff1a;java.la…...

Arrays.asList()新增报错,该怎么解决

一、前言 在 Java 开发中&#xff0c;Arrays.asList() 是一个常用的工具方法&#xff0c;它允许开发者快速将数组转换为列表。尽管这个方法非常方便&#xff0c;但许多开发者在使用时可能会遭遇一个常见的错误&#xff1a;尝试向由 Arrays.asList() 返回的列表中添加元素时抛出…...

【热门主题】000072 分布式数据库:开启数据管理新纪元

前言&#xff1a;哈喽&#xff0c;大家好&#xff0c;今天给大家分享一篇文章&#xff01;并提供具体代码帮助大家深入理解&#xff0c;彻底掌握&#xff01;创作不易&#xff0c;如果能帮助到大家或者给大家一些灵感和启发&#xff0c;欢迎收藏关注哦 &#x1f495; 目录 【热…...

基于Springboot开发的云野旅游平台

一、功能介绍 云野旅游平台包含管理员、用户两个角色以及前后台系统。 前台系统功能 用户登录成功后&#xff0c;可以进行查看旅游路线、最新线路、旅游资讯、个人中心、后台管理、购物车、客服等功能模块。进行相对应操作。 后台系统功能 管理员或用户登录成功后&#xf…...

2024金盾信安杯线上赛 MISC ezpng[wp]

下载题目发现给了个password和png 图片发现损坏的 password丢随波逐流一键解 base64 给出解码的结果是 cimbar搜索发现在Github有工具 然后对附件中的图片进行小厨房xor 得到一张新图片 利用工具进行跑出答案...

搭建业务的性能优化指南

这是一篇搭建业务优化的心路历程&#xff0c;也是写给搭建业务的性能优化指南。 前言 直到今天&#xff0c;淘内的页面大多都迁移到了 SSR&#xff0c;从我们终端平台 - 搭建研发团队的视角看&#xff0c;业务大致可以分为两类 —— 搭建派 和 源码派。 这两者互不冲突&#xf…...

电脑提示报错“Directx error”怎么解决?是什么原因导致的?游戏软件提示“Directx error”错误的解决方案

DirectX Error&#xff08;DX错误&#xff09;通常指的是在使用基于DirectX技术的应用程序&#xff08;尤其是游戏&#xff09;时遇到的问题。这个问题可能由多种因素导致&#xff0c;以下是一些可能的原因及相应的解决方案&#xff1a; 可能的原因 DirectX版本不匹配&#x…...

怎样做网站模板/湖人排名最新

今天主要学习了列表&#xff0c;python的列表真的事太强大了&#xff0c;由于内容比较多&#xff0c;今天就先简单的介绍一下新学的几个成员函数吧。 首先我们要了解list是一种序列类型&#xff0c;其构造方式有四种形式&#xff1a; &#xff08;1&#xff09;空列表 [] &…...

做网站简单还是做app简单/网络销售怎么找客源

海信POS机可编程键值定义&#xff1a;Q&#xff1a;档键设置&#xff0c;即右仙K0\K1\K2\K3\K4\K5键的设置&#xff0c;其对应键上的L\R\S\X\Z键&#xff0c;如何将其第一层键值放到R档上&#xff0c;第二层键值放到S档上?A&#xff1a;将对应R档的K1键设为空&#xff0c;对应…...

做网站主图多少钱/微帮推广平台怎么加入

[TOC]1.JDBC入门1.1.什么是JDBCJDBC从物理结构上来说就是java语言访问数据库的一套接口集合&#xff0c;本质上是java语言根数据库之间的协议。JDBC提供一组类和接口&#xff0c;通过使用JDBC&#xff0c;开发人员可以使用java代码发送sql语句&#xff0c;来操作数据库1.2.使用…...

建设网站是普通办公吗/天津做网站的网络公司

1&#xff09;实验平台&#xff1a;正点原子MPSoC开发板 2&#xff09;平台购买地址&#xff1a;https://detail.tmall.com/item.htm?id692450874670 3&#xff09;全套实验源码手册视频下载地址&#xff1a; http://www.openedv.com/thread-340252-1-1.html 第四章开发环境搭…...

文章做模板 wordpress/网站规划

1.编译行为带来的缺陷 预处理器将头文件中的代码直接插入源文件 编译器只通过预处理后的源文件产生目标文件 因此&#xff0c;规则中以源文件为依赖&#xff0c;命令可能无法执行 实现想法&#xff1a; 通过命令自动生成对文件的依赖 将生成的依赖自动包含进makefile中 当头文…...

网站上的销售怎么做/网络营销案例分析

通常遇到的大文本文件是log日志文件&#xff0c;GB级别的log文件很常见通常在打开log文件时头痛&#xff0c;因为常用的一些文本文件工具都不好用了&#xff0c;比如UE&#xff0c;notepad等&#xff0c;记事本就不用提了今天&#xff0c;我需要在1.5G的log文件中查找标签&…...