VIT用于图像分类 学习笔记(附代码)
论文地址:https://arxiv.org/abs/2010.11929
代码地址:https://github.com/bubbliiiing/classification-pytorch
1.是什么?
Vision Transformer(VIT)是一种基于Transformer架构的图像分类模型。它将图像分割成一系列的图像块,并将每个图像块作为输入序列传递给Transformer模型。VIT通过自注意力机制来捕捉图像中的全局上下文信息,并使用多层感知机(MLP)来进行特征提取和分类。
VIT的核心思想是将图像转换为序列数据,这使得模型能够利用Transformer的强大表达能力来处理图像。通过将图像分割成图像块,并将它们展平为序列,VIT能够在不依赖传统卷积神经网络的情况下实现图像分类任务。
2.为什么?
从2020年,transformer开始在CV领域大放异彩:图像分类(ViT, DeiT),目标检测(DETR,Deformable DETR),语义分割(SETR,MedT),图像生成(GANsformer)等。而从深度学习暴发以来,CNN一直是CV领域的主流模型,而且取得了很好的效果,相比之下transformer却独霸NLP领域,transformer在CV领域的探索正是研究界想把transformer在NLP领域的成功借鉴到CV领域。对于图像问题,卷积具有天然的先天优势(inductive bias):平移等价性(translation equivariance)和局部性(locality)。而transformer虽然不并具备这些优势,但是transformer的核心self-attention的优势不像卷积那样有固定且有限的感受野,self-attention操作可以获得long-range信息(相比之下CNN要通过不断堆积Conv layers来获取更大的感受野),但训练的难度就比CNN要稍大一些。
ViT(vision transformer)是Google在2020年提出的直接将transformer应用在图像分类的模型,后面很多的工作都是基于ViT进行改进的。这篇论文也是受到其启发,尝试将Transformer应用到CV领域通过这篇文章的实验,给出的最佳模型在ImageNet1K上能够达到88.55%的准确率(先在Google自家的JFT数据集上进行了预训练),说明Transformer在CV领域确实是有效的,而且效果还挺惊人。
3.怎么样?
3.1网络结构
与寻常的分类网络类似,整个Vision Transformer可以分为两部分,一部分是特征提取部分,另一部分是分类部分。
在特征提取部分,VIT所做的工作是特征提取。特征提取部分在图片中的对应区域是Patch+Position Embedding和Transformer Encoder。Patch+Position Embedding的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。在获得序列信息后,传入Transformer Encoder进行特征提取,这是Transformer特有的Multi-head Self-attention结构,通过自注意力机制,关注每个图片块的重要程度。
在分类部分,VIT所做的工作是利用提取到的特征进行分类。在进行特征提取的时候,我们会在图片序列中添加上Cls Token,该Token会作为一个单位的序列信息一起进行特征提取,提取的过程中,该Cls Token会与其它的特征进行特征交互,融合其它图片序列的特征。最终,我们利用Multi-head Self-attention结构提取特征后的Cls Token进行全连接分类。
3.2特征提取部分介绍
3.2.1Patch
Patch的作用主要是对输入进来的图片进行分块处理,每隔一定的区域大小划分图片块。然后将划分后的图片块组合成序列。
该部分首先对输入进来的图片进行分块处理,处理方式其实很简单,使用的是现成的卷积。也就是说,不是把图片分割,是做了一次简单的卷积,可以理解为初步特征提取,或者说是映射。
由于卷积使用的是滑动窗口的思想,我们只需要设定特定的步长,就可以输入进来的图片进行分块处理了。在VIT中,我们常设置这个卷积的卷积核大小为16x16,步长也为16x16,此时卷积就会每隔16个像素点进行一次特征提取,由于卷积核大小为16x16,两个图片区域的特征提取过程就不会有重叠。当我们输入的图片是224, 224, 3的时候,我们可以获得一个14, 14, 768的特征层。
在代码实现中,直接通过一个卷积层来实现。 以ViT-B/16为例,直接使用一个卷积核大小为16x16,步距为16,卷积核个数为768的卷积来实现。通过卷积[224, 224, 3] -> [14, 14, 768],然后把H以及W两个维度展平即可[14, 14, 768] -> [196, 768],此时正好变成了一个二维矩阵,正是Transformer想要的。
3.2.2Position Embedding
Position Embedding的作用主要是对组合序列加上[class]token以及Position Embedding。
在原论文中,作者说参考BERT,在刚刚得到的一堆tokens中插入一个专门用于分类的[class]token,这个[class]token是一个可训练的参数,数据格式和其他token一样都是一个向量,以ViT-B/16为例,就是一个长度为768的向量,与之前从图片中生成的tokens拼接在一起,Cat([1, 768], [196, 768]) -> [197, 768]。然后关于Position Embedding就是之前Transformer中讲到的Positional Encoding,这里的Position Embedding采用的是一个可训练的参数(1D Pos. Emb.),是直接叠加在tokens上的(add),所以shape要一样。以ViT-B/16为例,刚刚拼接[class]token后shape是[197, 768],那么这里的Position Embedding的shape也是[197, 768]。
对于Position Embedding作者也有做一系列对比试验,在源码中默认使用的是1D Pos. Emb
.
,对比不使用Position Embedding准确率提升了大概3个点,和2D Pos. Emb.
比起来没太大差别。
3.2.3Transformer Encoder
Transformer Encoder其实就是重复堆叠Encoder Block L次,下图是太阳花的小绿豆绘制的Encoder Block,主要由以下几部分组成:
- Layer Norm,这种Normalization方法主要是针对NLP领域提出的,这里是对每个token进行Norm处理,之前也有讲过Layer Norm不懂的可以参考链接
- Multi-Head Attention,看懂Self-attention结构,其实看懂下面这个动图就可以了,动图中存在一个序列的三个单位输入,每一个序列单位的输入都可以通过三个处理(比如全连接)获得Query、Key、Value,Query是查询向量、Key是键向量、Value值向量。
- Dropout/DropPath,在原论文的代码中是直接使用的Dropout层,在但rwightman实现的代码中使用的是DropPath(stochastic depth),可能后者会更好一点。
- MLP Block,如图右侧所示,就是全连接+GELU激活函数+Dropout组成也非常简单,需要注意的是第一个全连接层会把输入节点个数翻4倍[197, 768] -> [197, 3072],第二个全连接层会还原回原节点个数[197, 3072] -> [197, 768]
3.3 分类部分
上面通过Transformer Encoder后输出的shape和输入的shape是保持不变的,以ViT-B/16为例,输入的是[197, 768]输出的还是[197, 768]。注意,在Transformer Encoder后其实还有一个Layer Norm没有画出来,后面有我自己画的ViT的模型可以看到详细结构。这里我们只是需要分类的信息,所以我们只需要提取出[class]token生成的对应结果就行,即[197, 768]中抽取出[class]token对应的[1, 768]。接着我们通过MLP Head得到我们最终的分类结果。MLP Head原论文中说在训练ImageNet21K时是由Linear+tanh激活函数+Linear组成。但是迁移到ImageNet1K上或者你自己的数据上时,只用一个Linear即可。
3.4别人画的网络结构图
3.5代码实现
Patch+Position Embedding
class PatchEmbed(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):super().__init__()self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)self.flatten = flattenself.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(num_features) if norm_layer else nn.Identity()def forward(self, x):x = self.proj(x)if self.flatten:x = x.flatten(2).transpose(1, 2) # BCHW -> BNCx = self.norm(x)return xclass VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()#-----------------------------------------------## 224, 224, 3 -> 196, 768#-----------------------------------------------#self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)num_patches = (224 // patch_size) * (224 // patch_size)self.num_features = num_featuresself.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]#--------------------------------------------------------------------------------------------------------------------## classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。## 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。#--------------------------------------------------------------------------------------------------------------------## 196, 768 -> 197, 768self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))#--------------------------------------------------------------------------------------------------------------------## 为网络提取到的特征添加上位置信息。# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。#--------------------------------------------------------------------------------------------------------------------## 197, 768 -> 197, 768self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))def forward_features(self, x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1)cls_token_pe = self.pos_embed[:, 0:1, :]img_token_pe = self.pos_embed[:, 1: , :]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)x = self.pos_drop(x + pos_embed)
TransformerBlock
class Mlp(nn.Module):""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresdrop_probs = (drop, drop)self.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])self.fc2 = nn.Linear(hidden_features, out_features)self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):super().__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.norm2 = norm_layer(dim)self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return x
VIT
整个VIT模型由一个Patch+Position Embedding加上多个TransformerBlock组成。典型的TransforerBlock的数量为12个。
class VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()#-----------------------------------------------## 224, 224, 3 -> 196, 768#-----------------------------------------------#self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)num_patches = (224 // patch_size) * (224 // patch_size)self.num_features = num_featuresself.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]#--------------------------------------------------------------------------------------------------------------------## classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。## 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。#--------------------------------------------------------------------------------------------------------------------## 196, 768 -> 197, 768self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))#--------------------------------------------------------------------------------------------------------------------## 为网络提取到的特征添加上位置信息。# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。#--------------------------------------------------------------------------------------------------------------------## 197, 768 -> 197, 768self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))self.pos_drop = nn.Dropout(p=drop_rate)#-----------------------------------------------## 197, 768 -> 197, 768 12次#-----------------------------------------------#dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.Sequential(*[Block(dim = num_features, num_heads = num_heads, mlp_ratio = mlp_ratio, qkv_bias = qkv_bias, drop = drop_rate,attn_drop = attn_drop_rate, drop_path = dpr[i], norm_layer = norm_layer, act_layer = act_layer)for i in range(depth)])self.norm = norm_layer(num_features)self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1)cls_token_pe = self.pos_embed[:, 0:1, :]img_token_pe = self.pos_embed[:, 1: , :]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)x = self.pos_drop(x + pos_embed)x = self.blocks(x)x = self.norm(x)return x[:, 0]def forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef freeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Falseexcept:module.requires_grad = Falsedef Unfreeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Trueexcept:module.requires_grad = True
Vision Transforme的构建代码
import math
from collections import OrderedDict
from functools import partialimport numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F#--------------------------------------#
# Gelu激活函数的实现
# 利用近似的数学公式
#--------------------------------------#
class GELU(nn.Module):def __init__(self):super(GELU, self).__init__()def forward(self, x):return 0.5 * x * (1 + F.tanh(np.sqrt(2 / np.pi) * (x + 0.044715 * torch.pow(x,3))))def drop_path(x, drop_prob: float = 0., training: bool = False):if drop_prob == 0. or not training:return xkeep_prob = 1 - drop_probshape = (x.shape[0],) + (1,) * (x.ndim - 1)random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)random_tensor.floor_() output = x.div(keep_prob) * random_tensorreturn outputclass DropPath(nn.Module):def __init__(self, drop_prob=None):super(DropPath, self).__init__()self.drop_prob = drop_probdef forward(self, x):return drop_path(x, self.drop_prob, self.training)class PatchEmbed(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_features=768, norm_layer=None, flatten=True):super().__init__()self.num_patches = (input_shape[0] // patch_size) * (input_shape[1] // patch_size)self.flatten = flattenself.proj = nn.Conv2d(in_chans, num_features, kernel_size=patch_size, stride=patch_size)self.norm = norm_layer(num_features) if norm_layer else nn.Identity()def forward(self, x):x = self.proj(x)if self.flatten:x = x.flatten(2).transpose(1, 2) # BCHW -> BNCx = self.norm(x)return x#--------------------------------------------------------------------------------------------------------------------#
# Attention机制
# 将输入的特征qkv特征进行划分,首先生成query, key, value。query是查询向量、key是键向量、v是值向量。
# 然后利用 查询向量query 叉乘 转置后的键向量key,这一步可以通俗的理解为,利用查询向量去查询序列的特征,获得序列每个部分的重要程度score。
# 然后利用 score 叉乘 value,这一步可以通俗的理解为,将序列每个部分的重要程度重新施加到序列的值上去。
#--------------------------------------------------------------------------------------------------------------------#
class Attention(nn.Module):def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.):super().__init__()self.num_heads = num_headsself.scale = (dim // num_heads) ** -0.5self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)self.attn_drop = nn.Dropout(attn_drop)self.proj = nn.Linear(dim, dim)self.proj_drop = nn.Dropout(proj_drop)def forward(self, x):B, N, C = x.shapeqkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)q, k, v = qkv[0], qkv[1], qkv[2]attn = (q @ k.transpose(-2, -1)) * self.scaleattn = attn.softmax(dim=-1)attn = self.attn_drop(attn)x = (attn @ v).transpose(1, 2).reshape(B, N, C)x = self.proj(x)x = self.proj_drop(x)return xclass Mlp(nn.Module):""" MLP as used in Vision Transformer, MLP-Mixer and related networks"""def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=GELU, drop=0.):super().__init__()out_features = out_features or in_featureshidden_features = hidden_features or in_featuresdrop_probs = (drop, drop)self.fc1 = nn.Linear(in_features, hidden_features)self.act = act_layer()self.drop1 = nn.Dropout(drop_probs[0])self.fc2 = nn.Linear(hidden_features, out_features)self.drop2 = nn.Dropout(drop_probs[1])def forward(self, x):x = self.fc1(x)x = self.act(x)x = self.drop1(x)x = self.fc2(x)x = self.drop2(x)return xclass Block(nn.Module):def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,drop_path=0., act_layer=GELU, norm_layer=nn.LayerNorm):super().__init__()self.norm1 = norm_layer(dim)self.attn = Attention(dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)self.norm2 = norm_layer(dim)self.mlp = Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), act_layer=act_layer, drop=drop)self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()def forward(self, x):x = x + self.drop_path(self.attn(self.norm1(x)))x = x + self.drop_path(self.mlp(self.norm2(x)))return xclass VisionTransformer(nn.Module):def __init__(self, input_shape=[224, 224], patch_size=16, in_chans=3, num_classes=1000, num_features=768,depth=12, num_heads=12, mlp_ratio=4., qkv_bias=True, drop_rate=0.1, attn_drop_rate=0.1, drop_path_rate=0.1,norm_layer=partial(nn.LayerNorm, eps=1e-6), act_layer=GELU):super().__init__()#-----------------------------------------------## 224, 224, 3 -> 196, 768#-----------------------------------------------#self.patch_embed = PatchEmbed(input_shape=input_shape, patch_size=patch_size, in_chans=in_chans, num_features=num_features)num_patches = (224 // patch_size) * (224 // patch_size)self.num_features = num_featuresself.new_feature_shape = [int(input_shape[0] // patch_size), int(input_shape[1] // patch_size)]self.old_feature_shape = [int(224 // patch_size), int(224 // patch_size)]#--------------------------------------------------------------------------------------------------------------------## classtoken部分是transformer的分类特征。用于堆叠到序列化后的图片特征中,作为一个单位的序列特征进行特征提取。## 在利用步长为16x16的卷积将输入图片划分成14x14的部分后,将14x14部分的特征平铺,一幅图片会存在序列长度为196的特征。# 此时生成一个classtoken,将classtoken堆叠到序列长度为196的特征上,获得一个序列长度为197的特征。# 在特征提取的过程中,classtoken会与图片特征进行特征的交互。最终分类时,我们取出classtoken的特征,利用全连接分类。#--------------------------------------------------------------------------------------------------------------------## 196, 768 -> 197, 768self.cls_token = nn.Parameter(torch.zeros(1, 1, num_features))#--------------------------------------------------------------------------------------------------------------------## 为网络提取到的特征添加上位置信息。# 以输入图片为224, 224, 3为例,我们获得的序列化后的图片特征为196, 768。加上classtoken后就是197, 768# 此时生成的pos_Embedding的shape也为197, 768,代表每一个特征的位置信息。#--------------------------------------------------------------------------------------------------------------------## 197, 768 -> 197, 768self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + 1, num_features))self.pos_drop = nn.Dropout(p=drop_rate)#-----------------------------------------------## 197, 768 -> 197, 768 12次#-----------------------------------------------#dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]self.blocks = nn.Sequential(*[Block(dim = num_features, num_heads = num_heads, mlp_ratio = mlp_ratio, qkv_bias = qkv_bias, drop = drop_rate,attn_drop = attn_drop_rate, drop_path = dpr[i], norm_layer = norm_layer, act_layer = act_layer)for i in range(depth)])self.norm = norm_layer(num_features)self.head = nn.Linear(num_features, num_classes) if num_classes > 0 else nn.Identity()def forward_features(self, x):x = self.patch_embed(x)cls_token = self.cls_token.expand(x.shape[0], -1, -1) x = torch.cat((cls_token, x), dim=1)cls_token_pe = self.pos_embed[:, 0:1, :]img_token_pe = self.pos_embed[:, 1: , :]img_token_pe = img_token_pe.view(1, *self.old_feature_shape, -1).permute(0, 3, 1, 2)img_token_pe = F.interpolate(img_token_pe, size=self.new_feature_shape, mode='bicubic', align_corners=False)img_token_pe = img_token_pe.permute(0, 2, 3, 1).flatten(1, 2)pos_embed = torch.cat([cls_token_pe, img_token_pe], dim=1)x = self.pos_drop(x + pos_embed)x = self.blocks(x)x = self.norm(x)return x[:, 0]def forward(self, x):x = self.forward_features(x)x = self.head(x)return xdef freeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Falseexcept:module.requires_grad = Falsedef Unfreeze_backbone(self):backbone = [self.patch_embed, self.cls_token, self.pos_embed, self.pos_drop, self.blocks[:8]]for module in backbone:try:for param in module.parameters():param.requires_grad = Trueexcept:module.requires_grad = Truedef vit(input_shape=[224, 224], pretrained=False, num_classes=1000):model = VisionTransformer(input_shape)if pretrained:model.load_state_dict(torch.load("model_data/vit-patch_16.pth"))if num_classes!=1000:model.head = nn.Linear(model.num_features, num_classes)return model
参考:Vision Transformer详解
神经网络学习小记录67——Pytorch版 Vision Transformer(VIT)模型的复现详解
相关文章:
VIT用于图像分类 学习笔记(附代码)
论文地址:https://arxiv.org/abs/2010.11929 代码地址:https://github.com/bubbliiiing/classification-pytorch 1.是什么? Vision Transformer(VIT)是一种基于Transformer架构的图像分类模型。它将图像分割成一系列…...
MongoDB Certified Associate Developer 认证考试心得
介绍 前段时间通过了 MongoDB Associate Developer 考试,也记下了一些心得,结果忘记发出来了,现在重新整理下。通过考试后证书是这样的: MongoDB 目前有两个认证证书 1. MongoDB Associate Developer 认证掌握使用MongoDB 来构建现代应用…...
基于Java车间工时管理系统(源码+部署文档)
博主介绍: ✌至今服务客户已经1000、专注于Java技术领域、项目定制、技术答疑、开发工具、毕业项目实战 ✌ 🍅 文末获取源码联系 🍅 👇🏻 精彩专栏 推荐订阅 👇🏻 不然下次找不到 Java项目精品实…...
2024.1.5
今天真是狂学了一天的C,什么期末考试,滚tmd(就一门政治,不能影响我c的脚步),今天还是指针,主要是函数指针和函数指针数组,将简单的两位数计算器程序用此方式更加简单的实现了&#x…...
水库大坝安全监测设计与施工经验
随着我国的科技水平不断上升,带动了我国的水电建设向更高层次发展。目前,我国的水电站大坝已有上百座,并且大坝安全检测仪器质量与先进技术不断更新发展,如今水电站大坝数据信息采集与观测资料分析,能够有效提高水库大…...
媒体捕捉-拍照
引言 在项目开发中,从媒体库中选择图片或使用相机拍摄图片是一个极为普遍的需求。通常,我们使用UIImagePickerController来实现单张图片选择或启动相机拍照。整个拍照过程由UIImagePickerController内部实现,无需我们关心细节,只…...
Typora+PicGo+Gitee构建云存储图片
创建Gitee仓库 首先,打开工作台 - Gitee.com,自行注册一个账户 注册完后,新建一个仓库(记得仓库要开源) 然后创建完仓库后,鼠标移动到右上角头像位置,选择设置,并点击ÿ…...
【话题】ChatGPT等大语言模型为什么没有智能2
我们接着上一次的讨论,继续探索大模型的存在的问题。正巧CSDN最近在搞文章活动,我们来看看大模型“幻觉”。当然,本文可能有很多我自己的“幻觉”,欢迎批评指正。如果这么说的话,其实很容易得出一个小结论——大模型如…...
通过大量生物、地球、农业、气象、生态、环境科学领域中案例,一起探索如何优雅地使用大模型吧!
以ChatGPT、LLaMA、Gemini、DALLE、Midjourney、Stable Diffusion、星火大模型、文心一言、千问为代表AI大语言模型带来了新一波人工智能浪潮,可以面向科研选题、思维导图、数据清洗、统计分析、高级编程、代码调试、算法学习、论文检索、写作、翻译、润色、文献辅助…...
slf4j+logback源码加载流程解析
slf4j绑定logback源码解析 Logger log LoggerFactory.getLogger(LogbackDemo.class);如上述代码所示,在项目中通常会这样创建一个Logger对象去打印日志。 然后点进去,会走到LoggerFactory的getILoggerFactory()方法,如下代码所示。 public …...
KVM虚拟机部署K8S重启后/etc/hosts内容丢失
前言 使用KVM开了虚拟机部署K8S,部署完成后重启,节点的pod等信息无法获取到,查看报错初步推测为域名解析失效,查看/etc/hosts后发现安装k8s时添加的内容全部消失 网上搜索一番之后发现了 如果直接修改 /etc/hosts 文件࿰…...
Redis使用场景(五)
Redis实战精讲-13小时彻底学会Redis 1.计数器 可以对 String 进行自增自减运算,从而实现计数器功能。 Redis 这种内存型数据库的读写性能非常高,很适合存储频繁读写的计数量。 2.缓存 将热点数据放到内存中,设置内存的最大使用量以及淘汰策略…...
【UnityShader入门精要学习笔记】(2)GPU流水线
本系列为作者学习UnityShader入门精要而作的笔记,内容将包括: 书本中句子照抄 个人批注项目源码一堆新手会犯的错误潜在的太监断更,有始无终 总之适用于同样开始学习Shader的同学们进行有取舍的参考。 文章目录 上节复习GPU流水线顶点着色…...
CSS免费在线字体格式转换器 CSS @font-face 生成器
今天竟意外发现的一款免费的“网页字体生成器”,功能强大又好用~ 工具地址:https://transfonter.org/ 根据你设置生成后的文件预览: 支持TTF、OTF、WOFF、WOFF2 或 SVG字体格式转换生成,每个文件最大15MB。转换完成以后还会生成一…...
Codeium在IDEA里的3个坑
转载自Codeium在IDEA里的3个坑:无法log in,downloading language server和中文乱码_downloading codeium language server...-CSDN博客文章浏览阅读1.7w次,点赞26次,收藏47次。Codeium安装IDEA插件的3个常见坑_downloading codeiu…...
C-C++ 项目构建指南:如何使用 Makefile 提高开发效率
Makefile是一个常用的自动化构建工具,它可以为开发人员提供方便的项目构建方式。在C/C项目中,Makefile可以用来编译、链接和生成可执行文件。使用Makefile的好处是可以自动执行一系列命令,从而减少手动操作的复杂性和出错的可能性。此外&…...
基于SpringBoot的图书管理系统
文章目录 项目介绍主要功能截图:部分代码展示设计总结项目获取方式🍅 作者主页:超级无敌暴龙战士塔塔开 🍅 简介:Java领域优质创作者🏆、 简历模板、学习资料、面试题库【关注我,都给你】 🍅文末获取源码联系🍅 项目介绍 🚀🚀🚀SpringBoot 阿博图书管理系…...
矩阵对角线遍历
Diagonal 2614. 对角线上的质数 class Solution {public int diagonalPrime(int[][] nums) {int n = nums....
【教程】Typecho Joe主题开启并修复壁纸相册不显示问题
转载请注明出处:小锋学长生活大爆炸[xfxuezhang.cn] 背景说明 Joe主题本身支持“壁纸”功能,其实就是相册。当时还在网上找了好久相册部署的开源项目,太傻了。 但是网上教程很少,一没说如何开启壁纸功能,二没说开启后为…...
MR混合现实情景实训教学系统在法律专业课堂上的应用
MR混合现实情景实训教学系统是一种将虚拟现实(VR)、增强现实(AR)相结合的先进技术。在法律教学课堂上,MR教学系统为学生模拟模拟法庭、案例分析等多种形式,让学生在实践中掌握法律知识,提高法律…...
车载 Android之 核心服务 - CarPropertyService 的VehicleHAL
前言: 本文是车载Android之核心服务-CarPropertyService的第二篇,了解一下CarPropertyService的VehicleHAL, 第一篇在车载 Android之 核心服务 - CarPropertyService 解析-CSDN博客,有兴趣的 朋友可以去看下。 本节介绍 AndroidAutomotiveOS中对于 Veh…...
年底了,准备跳槽的可以看看...
前两天跟朋友感慨,今年的铜九铁十、裁员、疫情导致好多人都没拿到offer!现在已经1月了,具体明年的金三银四只剩下两个月。 对于想跳槽的职场人来说,绝对要从现在开始做准备了。这时候,很多高薪技术岗、管理岗的缺口和市场需求也出…...
Bagging算法_随机森林Random_Forest
Bagging B a g g i n g Bagging Bagging是并行式集成学习方法最著名的代表,这个名字是由 B o o t s t r a p A G G r e g a t I N G Bootstrap AGGregatING BootstrapAGGregatING而来,顾名思义,该算法由 B o o s t s t r a p Booststrap Boos…...
物理与网络安全
物流环境安全 场地选择考虑抗震、承重、防火、防水、供电、空气调节、电磁防护、雷击及静电 场地因素: 自然灾害,社会因素(加油站、化工厂),配套条件(消防,交通,电力,…...
torch.meshgrid和np.meshgrid的区别
numpy中meshgrid: 把数组a当作一行,再根据数组b的长度扩充行。 把数组b当作一列,再根据数组a的长度扩充列。 torch中meshgrid: 把数组a当作一列,再根据数组b的长度扩充列。 把数组b当作一行,再根据数组a的…...
【PostgreSQL】约束-唯一约束
【PostgreSQL】约束链接 检查 唯一 主键 外键 排他 唯一约束 唯一约束是数据库中的一种约束,用于确保某个列或字段的值在该列或字段中是唯一的。唯一约束可用于确保数据库表中的某个列中的值是唯一的,也可用于确保多个列的组合值是唯一的。 在创建表…...
学习使用js/jquery获取指定class名称的三种方式
学习使用js/jquery获取指定class名称的三种方式 简介一、获取元素的class名称1、通过原生JS获取元素的class名称2、通过Jquery获取元素的class名称 二、应用1、样式修改2、动画效果实现 简介 在开发网页时,我们经常需要通过JS获取元素的class名称进行一些操作&…...
latex数学公式
写于:2024年1月5日 晚 修改: 摘要:数学公式根据其位置可以分为行内公式和行间公式。行内公式更加紧凑,而行间公式富于变化,可以为其编号、引用、换行等操作。本文对数学公式的 LaTex 做简单记录和整理。 行内公式 行内…...
frp配置内网穿透访问家里的nas
frp配置内网穿透访问家里的nas 需求 家里局域网内有台nas,在去公司的路上想访问它 其内网地址为: http://192.168.50.8:6002 工具 1.frp版本v0.53.2 下载地址: https://github.com/fatedier/frp/releases/download/v0.53.2/frp_0.53.2_li…...
C语言-蓝桥杯2023年第十四届省赛真题-砍树
题目描述 给定一棵由 n 个结点组成的树以及 m 个不重复的无序数对 (a1, b1), (a2, b2), . . . , (am, bm),其中 ai 互不相同,bi 互不相同,ai ≠ bj(1 ≤ i, j ≤ m)。 小明想知道是否能够选择一条树上的边砍断,使得对于每个 (a…...
网站建设 网站/厦门seo屈兴东
[TOC]**会员相关函数全部位于 framework/model/mc.mod.php 文件内。****注意:该文件内所有函数使用前必须加载文件: load()→model(mc);**## mc_check 检测会员信息是否存在(邮箱和手机号)> 如果会员不存在,返回 true,否则返回注册信息。~…...
效果好的网站建设公司/青岛网站制作公司
剑指offer题目描述: 给定一个数组A[0,1,...,n-1],请构建一个数组B[0,1,...,n-1],其中B中的元素B[i]A[0]*A[1]*...*A[i-1]*A[i1]*...*A[n-1]。不能使用除法。(注意:规定B[0] A[1] * A[2] * ... * A[n-1],B[n-1] A[0] * A[1] * .…...
wordpress 登录没反应/宁波seo专员
单链表(single-linked list)链表结构应用实例分析数据结构算法类方法对象代码实现插入向尾部直接插入节点思路分析算法实现按照顺序插入指定位置思路分析算法实现修改思路分析代码实现删除思路分析代码实现查找思路分析代码实现面试题有效元素的个数代码…...
建筑中级职称查询网站/网络广告营销的特点
本教程演示如何在 torchtext 中使用文本分类数据集,包括- AG_NEWS,- SogouNews,- DBpedia,- YelpReviewPolarity,- YelpReviewFull,- YahooAnswers,- AmazonReviewPolarity,- AmazonReviewFull此示例演示如何使用 TextClassification 数据集中的一个训练用于分类文本…...
wordpress升级后乱码/新闻稿在线
昨日#Pandownload#登上微博热搜榜,因其开发者已被抓。扬州网警巡查执法发布净网战报:宝应网安破获一起黑客攻击计算机信息系统案件今年2月,受害人刘某报案称其下载的“Pandownload"软件会在未授权的情况下,将自己百度网盘的数…...
php mysql做网站登录/营销软件网站
网上看到好多朋友都通过了软考,我也好想参加软考,听说这个证含金量挺高的,加油,打算下半年参加考试,现在赶紧的看书哦,下面是希赛的考试资料网站,有时间上来多学习了,也希望参加过和准备参加的朋…...