当前位置: 首页 > news >正文

【图像分类】【深度学习】【Pytorch版本】 ResNeXt模型算法详解

【图像分类】【深度学习】【Pytorch版本】 ResNeXt模型算法详解

文章目录

  • 【图像分类】【深度学习】【Pytorch版本】 ResNeXt模型算法详解
  • 前言
  • ResNeXt讲解
    • 分组卷积(Group Converlution)
    • 分割-变换-合并策略(split-transform-merge)
    • ResNeXt模型结构
  • ResNeXt Pytorch代码
  • 完整代码
  • 总结


前言

ResNeXt是加利福尼亚大学圣迭戈分校的Xie, Saining等人在《Aggregated Residual Transformations for Deep Neural Networks【CVPR-2017】》【论文地址】一文中提出的模型,结合ResNet【参考】的卷积块堆叠的思想以及Inception【参考】的分割-变换-合并的策略,在不明显增加参数量级的情况下提升了模型的准确率。


ResNeXt讲解

Inception系列模型则证明精心设计的拓扑结构(采用分割-转换-合并策略),在拥有不错的表示能力同时计算复杂度大大降低:首先通过1×1的卷积将输入分割成多个低维度的嵌入,然后通过一组专门的过滤器(3×3,5×5等)分别进行转换,最后通过串联进行合并。
但是Inception系列的实现一直伴随着一系列复杂的因素:卷积核的数量和大小是为每个变换单独定制的,网络中的Inception模块也是逐个定制的。随着网络深度的增加,网络的超参数(卷积核个数、大小和步长等)也在增加,设计更好的网络架构以学习表征变得越来越困难。ResNets继承了VGGNet简单而有效的方法,采用相同拓扑结构的模块堆叠构建深度网络,不需要每层都单独设置超参数,减少了超参数的自由选择。
因此在论文中,ResNeXt提出了一个简单的架构,它以一种简单、可扩展的方式采用了ResNets的重复层策略,同时利用了Inception的分割-变换-合并策略。

分组卷积(Group Converlution)

在分组卷积中,将输入特征图的通道分成多个组,每个组内的通道只与相应组内的卷积核进行卷积运算,最后将各个组的输出特征图连接在一起,形成最终的输出特征图。
以下是博主绘制的普通卷积和分组卷积的示意图:

实际上无论普通卷积还是分组卷积,卷积核的数量没有发生改变,只不过分组卷积的卷积核的通道数变小了。

分组卷积的主要目的是减少卷积操作的计算量,特别适用于在计算资源有限的情况下进行模型设计。

分割-变换-合并策略(split-transform-merge)

注意:这个小节比较考验读者的对卷积过程的认知功底,建议大家好好理解下,有助于大家夯实基本功。
先说结论,下图是原论文中给出的结构示意图,a图结构是分割-变换-合并策略的体现,c图结构则是使用分组卷积后的对a图结构的等价替换。

接下来博主就将详细讲解分割-变换-合并策略中每一个步骤的过程和作用,为了方便大家理解,博主采用了a图的结构进行讲解。
ResNeXt通过将输入数据分割成多个子集,每个子集进行独立的变换操作,网络可以学习到更多不同的特征表示。而通过合并操作,网络可以将这些不同的特征表示进行组合,从而得到更丰富的特征表达能力。

  1. split:分割输入数据。

    分割可以理解为将多个卷积核划分到不同组,每个组的卷积核个数一致。如示意图所示,将一层大卷积层拆分成多个小卷积层后处理同一个输入,假设将多个小卷积层的输出(子集)拼接成一起就等价于大卷积层的输出,因此俩者是等效的。

    个人理解:其实可以只用一个卷积层进行卷积,将输出的特征图按照组进行拆分即可,不需要对多个小卷积层单独分组。

  2. transform:子集独立变换。

    每个小卷积层的输出(子集)再经过一层各自的卷积层进行卷积变换。如示意图所示,等价于分组卷积。

    个人理解:早期深度学习框架不支持分组卷积,因此分组卷积的实现,需要在分组卷积事先将输入按照分组进行拆分,也是就split过程,然后对分组后的输入子集再进行小组内卷积。

  3. merge:合并特征图。
    合并可以理解为将一个大卷积核划分成多个小卷积核,每个小卷积核拥有大卷积核的一部分通道,每个小卷积核的通道数量一致。如示意图所示,大卷积核通道数和拆分后的小卷积核的总通道数是一致的。回顾以下,传统的卷积运算(大卷积核)的输出特征是由每个通道的权重与对应输入特征进行运算和相加而来,即1到12一次性相加,那么小卷积就是将这个过程进行了拆分,即先是1到4、5到8和9到12分别相加,然后再对三个相加结果再进行相加。

    个人理解:其实先将多组输入的特征图进行拼接,只用一个大卷积核组成的卷积层进行卷积即可,不需要用多个小卷积核组成的卷积层。

ResNeXt模型结构

ResNeXt对ResNet进行了改进,采用了多分支的策略,在论文中作者提出了三种等价的模型结构,最后的ResNeXt采用了图c的结构来构建ResNeXt,因为c结构比较简洁而且速度更快。

ResNeXt通过增加cardinality(group)参数,可以灵活地控制子集的数量,增加基数可以提高模型的性能,提高特征提取的能力,且要比增加宽度和深度更有效。
下图是原论文给出的关于ResNeXt模型结构的详细示意图:

ResNeXt与ResNet一样也是构建基于两个准则:1.同阶段中的残差块使用相同的卷积核个数和卷积核尺寸;2.特征图减小时增加卷积核个数。基于上述准则,在ResNet-50模型的基础上,提出了ResNeXt-50模型。
ResNeXt在图像分类中分为两部分:backbone部分: 主要由残差结构、卷积层和池化层(汇聚层)组成,分类器部分:由全局平均池化层和全连接层组成 。

ResNeXt只能在残差块的深度超过2层时使用,所以ResNeXt不在ResNet18和34进行修改的原因。


ResNeXt Pytorch代码

分组卷积层:

# 3×3分组卷积
nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)

残差结构Bottleneck: 卷积层(或分组卷积层)+BN层+激活函数

class Bottleneck(nn.Module):expansion = 4# 残差结构参考了resnet的残差结构def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()# 是为了保证卷积核个数能被组数整除,每组的卷积核个数不出现小数width = int(out_channel * (width_per_group / 64.)) * groups# 第一层(降维)self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# 第二层(分组卷积)self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# 第三层(升维)self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return out

完整代码

import torch.nn as nn
import torch
from torchsummary import summaryclass Bottleneck(nn.Module):expansion = 4# 残差结构参考了resnet的残差结构def __init__(self, in_channel, out_channel, stride=1, downsample=None,groups=1, width_per_group=64):super(Bottleneck, self).__init__()# 是为了保证卷积核个数能被组数整除,每组的卷积核个数不出现小数width = int(out_channel * (width_per_group / 64.)) * groups# 第一层(降维)self.conv1 = nn.Conv2d(in_channels=in_channel, out_channels=width,kernel_size=1, stride=1, bias=False)  # squeeze channelsself.bn1 = nn.BatchNorm2d(width)# 第二层(分组卷积)self.conv2 = nn.Conv2d(in_channels=width, out_channels=width, groups=groups,kernel_size=3, stride=stride, bias=False, padding=1)self.bn2 = nn.BatchNorm2d(width)# 第三层(升维)self.conv3 = nn.Conv2d(in_channels=width, out_channels=out_channel*self.expansion,kernel_size=1, stride=1, bias=False)  # unsqueeze channelsself.bn3 = nn.BatchNorm2d(out_channel*self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampledef forward(self, x):identity = xif self.downsample is not None:identity = self.downsample(x)out = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)out += identityout = self.relu(out)return outclass ResNeXt(nn.Module):def __init__(self,blocks_num,num_classes=1000,groups=1,width_per_group=64):super(ResNeXt, self).__init__()self.in_channel = 64# 组数self.groups = groups# 每组包含的卷积个数self.width_per_group = width_per_groupself.conv1 = nn.Conv2d(3, self.in_channel, kernel_size=7, stride=2,padding=3, bias=False)self.bn1 = nn.BatchNorm2d(self.in_channel)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)# 第一组残差块组self.layer1 = self._make_layer(Bottleneck, 64, blocks_num[0])# 第二组残差块组self.layer2 = self._make_layer(Bottleneck, 128, blocks_num[1], stride=2)# 第三组残差块组self.layer3 = self._make_layer(Bottleneck, 256, blocks_num[2], stride=2)# 第四组残差块组self.layer4 = self._make_layer(Bottleneck, 512, blocks_num[3], stride=2)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))  # output size = (1, 1)self.fc = nn.Linear(512 * Bottleneck.expansion, num_classes)# 权重初始化for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')def _make_layer(self, block, channel, block_num, stride=1):downsample = Noneif stride != 1 or self.in_channel != channel * block.expansion:downsample = nn.Sequential(nn.Conv2d(self.in_channel, channel * block.expansion, kernel_size=1, stride=stride, bias=False),nn.BatchNorm2d(channel * block.expansion))layers = []layers.append(block(self.in_channel,channel,downsample=downsample,stride=stride,groups=self.groups,width_per_group=self.width_per_group))self.in_channel = channel * block.expansionfor _ in range(1, block_num):layers.append(block(self.in_channel,channel,groups=self.groups,width_per_group=self.width_per_group))return nn.Sequential(*layers)def forward(self, x):# backbone主干网络部分# resnext50为例# N x 3 x 224 x 224x = self.conv1(x)# N x 64 x 112 x 112x = self.bn1(x)# N x 64 x 112 x 112x = self.relu(x)# N x 64 x 112 x 112x = self.maxpool(x)# N x 64 x 56 x 56x = self.layer1(x)# N x 256 x 56 x 56x = self.layer2(x)# N x 512 x 28 x 28x = self.layer3(x)# N x 1024 x 14 x 14x = self.layer4(x)# N x 2048 x 7 x 7x = self.avgpool(x)# N x 2048 x 1 x 1x = torch.flatten(x, 1)# N x 2048x = self.fc(x)# N x 1000return xdef resnext50_32x4d(num_classes=1000):# https://download.pytorch.org/models/resnext50_32x4d-7cdf4587.pthgroups = 32width_per_group = 4return ResNeXt([3, 4, 6, 3],num_classes=num_classes,groups=groups,width_per_group=width_per_group)def resnext101_32x8d(num_classes=1000):# https://download.pytorch.org/models/resnext101_32x8d-8ba56ff5.pthgroups = 32width_per_group = 8return ResNeXt([3, 4, 23, 3],num_classes=num_classes,groups=groups,width_per_group=width_per_group)if __name__ == '__main__':device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")model = resnext50_32x4d().to(device)summary(model, input_size=(3, 224, 224))

summary可以打印网络结构和参数,方便查看搭建好的网络结构。


总结

尽可能简单、详细的介绍了分组卷积的原理和在卷积神经网络中的作用,讲解了ResNeXt模型的结构和pytorch代码。

相关文章:

【图像分类】【深度学习】【Pytorch版本】 ResNeXt模型算法详解

【图像分类】【深度学习】【Pytorch版本】 ResNeXt模型算法详解 文章目录 【图像分类】【深度学习】【Pytorch版本】 ResNeXt模型算法详解前言ResNeXt讲解分组卷积(Group Converlution)分割-变换-合并策略(split-transform-merge)ResNeXt模型结构 ResNeXt Pytorch代码完整代码总…...

Android 14 应用适配指南

Android 14 应用适配指南:https://dev.mi.com/distribute/doc/details?pId1718 Android 14 功能和变更列表 | Android 开发者 | Android Developers 1.获取Android 14 1.1 谷歌发布时间表 https://developer.android.com/about/versions/14/overview#timeli…...

【AI美图提示词】第07期效果图,AI人工智能自动绘画,精选绝美版美图欣赏

AI诗配画 山水画中景如画,云雾缭绕峰峦间。桥畔流水潺潺响,诗意盎然山水间。上面的诗句和图片全部来自AI自动化完成,这就是技术的力量,接下来我们进行模型生成学习: 先上原始底图: 下面是模型生成效果图&a…...

前端知识(十三)——JavaScript监听按键,禁止F12,禁止右键,禁止保存网页【Ctrl+s】等操作

禁止右键 document.oncontextmenu new Function("event.returnValuefalse;") //禁用右键禁止按键 // 监听按键 document.onkeydown function () {// f12if (window.event && window.event.keyCode 123) {alert("F12被禁用");event.keyCode 0…...

面向对象设计与分析(28)单例模式的奇异递归模板CRTP实现

前面我们介绍了单例模式的两种实现:懒汉模式和饿汉模式,今天我们以新的方式来实现可复用的单例模式。 奇异递归模板是指父类是个模板类,模板类型是子类类型,即父类通过模板参数可以知道子类的类型。 // brief: a singleton base…...

微信小程序 - 龙骨图集拆分

微信小程序 - 龙骨图集拆分 注意目录结构演示动画废话一下业务逻辑注意点龙骨JSON图集结构 源码分享dragonbones-split.jsdragonbones-split.jsondragonbones-split.wxmldragonbones-split.wxssimgUtil.js 参考资料 注意 只支持了JSON版本 目录结构 演示动画 Spine播放器1.5.…...

使用React 18和WebSocket构建实时通信功能

1. 引言 WebSocket是一种在Web应用中实现双向通信的协议。它允许服务器主动向客户端推送数据,而不需要客户端发起请求。在现代的实时应用中,WebSocket经常用于实时数据传输、聊天功能、实时通知和多人协作等场景。在本篇博客中,我们将探索如…...

vue3使用vue-router嵌套路由(多级路由)

文章目录 1、Vue3 嵌套路由2、项目结构3、编写相关页面代码3.1、编写route文件下 index.ts文件3.2、main.ts文件代码:3.3、App.vue文件代码:3.4、views文件夹下的Home文件夹下的index.vue文件代码:3.5、views文件夹下的Home文件夹下的Tigerhh…...

openGauss学习笔记-164 openGauss 数据库运维-备份与恢复-导入数据-使用COPY FROM STDIN导入数据-处理错误表

文章目录 openGauss学习笔记-164 openGauss 数据库运维-备份与恢复-导入数据-使用COPY FROM STDIN导入数据-处理错误表164.1 操作场景164.2 查询错误信息164.3 处理数据导入错误 openGauss学习笔记-164 openGauss 数据库运维-备份与恢复-导入数据-使用COPY FROM STDIN导入数据-…...

QT Widget - 随便画个圆

简介 实现在界面中画一个圆, 其实目的是想画一个LED效果的圆。代码 #include <QApplication> #include <QWidget> #include <QPainter> #include <QColor> #include <QPen>class LEDWidget : public QWidget { public:LEDWidget(QWidget *pare…...

js输入框部分内容不可编辑,其余正常输入,el-input和el-select输入框和多个下拉框联动后的内容不可修改

<tr>//格式// required自定义指令<e-td :required"!read" label><span>地区&#xff1a;</span></e-td><td>//v-if"!read && this.data.nationCode 148"显示逻辑<divclass"table-cell-flex"sty…...

分布式文件存储系统minio了解下

什么是minio minio 是一个基于 Apache License v2.0 开源协议的对象存储服务。非常适合于存储大容量非结构化的数据&#xff0c;例如图片、视频、日志文件、备份数据和容器/虚拟机镜像等&#xff0c;而一个对象文件可以是任意大小。 是一种海量、安全、低成本、高可靠的云存储…...

迅为RK3568开发板使用OpenCV处理图像-ROI区域-位置提取ROI

在图像处理过程中&#xff0c;我们可能会对图像的某一个特定区域感兴趣&#xff0c;该区域被称为感兴趣区域&#xff08;Region of Interest, ROI&#xff09;。在设定感兴趣区域 ROI 后&#xff0c;就可以对该区域进行整体操作。 位置提取 ROI 本小节代码在配套资料“iTOP-3…...

重新认识Word——尾注

重新认识Word——尾注 参考文献格式文献自动生成器插入尾注将数字带上方括号将参考文献中的标号改为非上标 多处引用一篇文献多篇文献被一处引用插入尾注有横线怎么删除&#xff1f;删除尾注 前面我们学习了如何给图片&#xff0c;公式自动添加编号&#xff0c;今天我们来看看毕…...

所有学前教育专业,一定要刷到这篇啊

我是真的希望所有学前教育的宝子都能刷到这篇啊啊&#xff0c;只要输入需求&#xff0c;几秒它就给你写出来了&#xff0c;而且不满意还可以重新写多&#xff0c;每次都是不一样的内容。重复率真的不高&#xff0c;需求越多&#xff0c;生成的文字内容越精准&#xff01;&#…...

colmap三维重建核心逻辑梳理

colmap三维重建核心逻辑梳理 1. 算法流程束流2. 初始化3. 重建主流程 1. 算法流程束流 重建核心逻辑见 incremental_mapper.cc 中 IncrementMapperController 中 Reconstruct 初始化变量和对象判断是否有初始重建模型&#xff0c;若有&#xff0c;则获取初始重建模型数量&am…...

查询某个类是在哪个JAR的什么版本开始出现的方法

背景 我们在依赖第三方JAR时&#xff0c;同时也会间接的依赖第三方JAR引用的依赖&#xff0c;而当我们项目中某个依赖的版本与第三方JAR依赖的版本不一致时&#xff0c;可能会导致第三方JAR的在运行时无法找到某些方法或类&#xff0c;从而无法正常使用。 如我正在开发的一个…...

Linux本地搭建StackEdit Markdown编辑器结合内网穿透实现远程访问

文章目录 1. docker部署Stackedit2. 本地访问3. Linux 安装cpolar4. 配置Stackedit公网访问地址5. 公网远程访问Stackedit6. 固定Stackedit公网地址 StackEdit是一个受欢迎的Markdown编辑器&#xff0c;在GitHub上拥有20.7k Star&#xff01;&#xff0c;它支持将Markdown笔记保…...

k8s中ConfigMap、Secret创建使用演示、配置文件存储介绍

目录 一.ConfigMap&#xff08;cm&#xff09; 1.适用场景 2.创建并验证configmap &#xff08;1&#xff09;以yaml配置文件创建configmap&#xff0c;验证变化是是否同步 &#xff08;2&#xff09;--from-file以目录或文件 3.如何使用configmap &#xff08;1&#x…...

Linux服务器性能优化小结

文章目录 生产环境监测常见专业名词扫盲服务器平均负载服务器平均负载的定义如何判断平均负载值以及好坏情况如果依据平均负载来判断服务器当前状况系统平均负载和CPU使用率的区别 CPU上下文切换基本概念3种上下文切换进程上下文切换线程上下文切换中断上下文切换 查看上下文切…...

在软件开发中正确使用MySQL日期时间类型的深度解析

在日常软件开发场景中&#xff0c;时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志&#xff0c;到供应链系统的物流节点时间戳&#xff0c;时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库&#xff0c;其日期时间类型的…...

JavaSec-RCE

简介 RCE(Remote Code Execution)&#xff0c;可以分为:命令注入(Command Injection)、代码注入(Code Injection) 代码注入 1.漏洞场景&#xff1a;Groovy代码注入 Groovy是一种基于JVM的动态语言&#xff0c;语法简洁&#xff0c;支持闭包、动态类型和Java互操作性&#xff0c…...

什么是库存周转?如何用进销存系统提高库存周转率?

你可能听说过这样一句话&#xff1a; “利润不是赚出来的&#xff0c;是管出来的。” 尤其是在制造业、批发零售、电商这类“货堆成山”的行业&#xff0c;很多企业看着销售不错&#xff0c;账上却没钱、利润也不见了&#xff0c;一翻库存才发现&#xff1a; 一堆卖不动的旧货…...

工业自动化时代的精准装配革新:迁移科技3D视觉系统如何重塑机器人定位装配

AI3D视觉的工业赋能者 迁移科技成立于2017年&#xff0c;作为行业领先的3D工业相机及视觉系统供应商&#xff0c;累计完成数亿元融资。其核心技术覆盖硬件设计、算法优化及软件集成&#xff0c;通过稳定、易用、高回报的AI3D视觉系统&#xff0c;为汽车、新能源、金属制造等行…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

【VLNs篇】07:NavRL—在动态环境中学习安全飞行

项目内容论文标题NavRL: 在动态环境中学习安全飞行 (NavRL: Learning Safe Flight in Dynamic Environments)核心问题解决无人机在包含静态和动态障碍物的复杂环境中进行安全、高效自主导航的挑战&#xff0c;克服传统方法和现有强化学习方法的局限性。核心算法基于近端策略优化…...

A2A JS SDK 完整教程:快速入门指南

目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库&#xff…...

pycharm 设置环境出错

pycharm 设置环境出错 pycharm 新建项目&#xff0c;设置虚拟环境&#xff0c;出错 pycharm 出错 Cannot open Local Failed to start [powershell.exe, -NoExit, -ExecutionPolicy, Bypass, -File, C:\Program Files\JetBrains\PyCharm 2024.1.3\plugins\terminal\shell-int…...

基于单片机的宠物屋智能系统设计与实现(论文+源码)

本设计基于单片机的宠物屋智能系统核心是实现对宠物生活环境及状态的智能管理。系统以单片机为中枢&#xff0c;连接红外测温传感器&#xff0c;可实时精准捕捉宠物体温变化&#xff0c;以便及时发现健康异常&#xff1b;水位检测传感器时刻监测饮用水余量&#xff0c;防止宠物…...

el-amap-bezier-curve运用及线弧度设置

文章目录 简介示例线弧度属性主要弧度相关属性其他相关样式属性完整示例链接简介 ‌el-amap-bezier-curve 是 Vue-Amap 组件库中的一个组件,用于在 高德地图 上绘制贝塞尔曲线。‌ 基本用法属性path定义曲线的路径,可以是多个弧线段的组合。stroke-weight线条的宽度。stroke…...