Resnet模型详解
1、Resnet是什么?
Resnet是一种深度神经网络架构,被广泛用于计算机视觉任务,特别是图像分类。它是由微软研究院的研究员于2015年提出的,是深度学习领域的重要里程碑之一。
2、网络退化问题
理论上来讲,随着网络的层数的增加,网络能够进行更加复杂的特征提取,可以取得更好的结果。但是实验发现深度网络出现了退化问题,如下图所示。网络深度增加时,网络准确度出现饱和,之后甚至还快速下降。而且这种下降不是因为过拟合引起的,而是因为在适当的深度模型上添加更多的层会导致了更高的训练误差,从而使其下降。
图1 网络深度对比(来源:Resnet的论文)
当你使用深度神经网络进行训练时,网络层可以被看作是一系列的函数堆叠,每个函数代表一个网络层的操作,这里我们就记作。在反向传播过程中,梯度是通过链式法则逐层计算得出的。假设每个操作的梯度都小于1,因为多个小于1的数相乘可能会导致结果变得更小。在神经网络中,随着反向传播的逐层传递,梯度可能会逐渐变得非常小,甚至接近于零,这就是梯度消失问题。
而如果经过网络层操作后的输出值大于1,那么反向传播时梯度可能会相应地增大。这种情况下,梯度爆炸问题可能会出现。梯度爆炸问题指的是在深度神经网络中,梯度逐渐放大,导致底层网络的参数更新过大,甚至可能导致数值溢出。
3、残差结构
在ResNet提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。所以,Resnet对后面计算机视觉的发展影响是巨大的。
图2 残差结构(来源:Resnet的论文)
它这里完成的一个很简单的过程,我先举一个例子:
想象一张经过神经网络处理后的低分辨率图像。为了提高图像的质量,我们引入了一个创新的思想:将原始高分辨率图像与低分辨率图像之间的差异提取出来,形成了一个残差图像。这个残差图像代表了低分辨率图像与目标高分辨率图像之间的差异或缺失的细节。
图3 残差图像
然后,我们将这个残差图像与低分辨率图像相加,得到一个结合了低分辨率信息和残差细节的新图像。这个新图像作为下一个神经网络层的输入,使网络能够同时利用原始低分辨率信息和残差细节信息进行更精确的学习。
图4 残差+低分辨率图像
通过这种方式,我们的神经网络能够逐步地从低分辨率图像中提取信息,并通过残差图像的相加操作将遗漏的细节加回来。这使得网络能够更有效地进行图像恢复或其他任务,提高了模型的性能和准确性。
我相信我已经成功表达了残差结构的思想和操作过程。其实这个思想也并非是resnet创新的,在我们过去的其他领域中早已有这种思想,ResNet将这一思想引入了计算机视觉领域,并在深度神经网络中的训练中取得了重要突破。这种创新在一定程度上解决了深层神经网络训练中的梯度消失和梯度爆炸问题,使得网络能够更深更准确地学习特征和表示。
4、Resnet网络结构
(1)对于相同的输出特征图尺寸,层具有相同数量的滤波器
(2)当feature map大小降低一半时,feature map的数量增加一倍【过滤器(可以看作是卷积核的集合)的数量增加一倍】,这保持了网络层的复杂度。然后通过步长为2的卷积层直接执行下采样。
网络结构具体如下图所示:
图5 左为VGG-19,中为34个参数层的简单网络,右为34个参数层的残差网络
ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中实线表示快捷连接,虚线表示feature map数量发生了改变。
有两种情况需要考虑:
(1)输入和输出具有相同的维度时(对应实线部分):
直接使用恒等快捷连接
(2)维度增加(当快捷连接跨越两种尺寸的特征图时,它们执行时步长为2):
①快捷连接仍然执行恒等映射,额外填充零输入以增加维度。这样就不会引入额外的参数。
②用下面公式的投影快捷连接用于匹配维度(由1×1卷积完成)
论文中也提供了更详细的结构,如下图所示:
5、使用Pytorch实现Resnet
本来是按照论文手写的代码,但用的时候发现维度不匹配,用不了预训练权重,所以这里就照着pytorch源码进行了修改。
import torch
import torchvision
import torch.nn as nn
import torchsummary
from torch.hub import load_state_dict_from_urlmodel_urls = {"resnet18" : "https://download.pytorch.org/models/resnet18-f37072fd.pth","resnet34" : "https://download.pytorch.org/models/resnet34-b627a593.pth","resnet50" : "https://download.pytorch.org/models/resnet50-0676ba61.pth","resnet101" : "https://download.pytorch.org/models/resnet101-63fe2227.pth","resnet152" : "https://download.pytorch.org/models/resnet152-394f9c45.pth",
}cfgs = {"resnet18": [2, 2, 2, 2],"resnet34": [3, 4, 6, 3],"resnet50": [3, 4, 6, 3],"resnet101": [3, 4, 23, 3],"resnet152": [3, 8, 36, 3],
}def conv1x1(in_planes, out_planes, stride = 1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=(1,1), stride=(stride,stride), bias=False)
def conv3x3(in_planes, out_planes, stride= 1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes,out_planes,kernel_size=(3,3),stride=(stride,stride),padding=dilation,groups=groups,bias=False,dilation=(dilation,dilation))
class BasicBlock(nn.Module):expansion = 1def __init__(self,inplanes: int,planes,stride = 1,downsample = None,groups =1,base_width = 64,dilation= 1,norm_layer= None,):super(BasicBlock,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError("BasicBlock only supports groups=1 and base_width=64")if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self,inplanes,planes,stride = 1,downsample = None,groups = 1,base_width = 64,dilation = 1,norm_layer = None,):super(Bottleneck,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.0)) * groupsself.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,layers,in_channels=3,num_classes = 1000,zero_init_residual = False,groups = 1,width_per_group = 64,replace_stride_with_dilation = None,):super(ResNet,self).__init__()norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.num=num_classesself.inplanes = 64self.dilation = 1self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=(7,7), stride=(2,2), padding=3, bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=False)self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=False)self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=False)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, self.num)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck) and m.bn3.weight is not None:nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]elif isinstance(m, BasicBlock) and m.bn2.weight is not None:nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]def _make_layer(self,block, planes, blocks, stride = 1,dilate = False,):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes,planes,groups=self.groups,base_width=self.base_width,dilation=self.dilation,))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet(in_channels, num_classes, mode='resnet50', pretrained=False):if mode == "resnet18" or mode == "resnet34":block = BasicBlockelse:block = Bottleneckmodel = ResNet(block, cfgs[mode], in_channels=in_channels, num_classes=num_classes)if pretrained:state_dict = load_state_dict_from_url(model_urls[mode], model_dir='./model', progress=True) # 预训练模型地址if num_classes != 1000:num_new_classes = num_classesfc_weight = state_dict['fc.weight']fc_bias = state_dict['fc.bias']fc_weight_new = fc_weight[:num_new_classes, :]fc_bias_new = fc_bias[:num_new_classes]state_dict['fc.weight'] = fc_weight_newstate_dict['fc.bias'] = fc_bias_newmodel.load_state_dict(state_dict)return model
这种写法是按照先前VGG那样写的,这样有助于使用同一个model_urls和cfgs。
BasicBlock类中的init()函数是先定义网络架构,forward()的函数是前向传播,实现的功能就是残差块:
Bottleneck类是另一种blcok类型,同上,init()函数是预定义网络架构,forward函数是进行前向传播。该block中有三个卷积,分别是1x1,3x3,1x1,分别完成的功能就是维度压缩,卷积,恢复维度,所以bottleneck实现的功能就是对通道数进行压缩,再放大。
注意:这里的plane不再是输出的通道数,输出通道数应该就是plane*expansion,即4*plane。
- resnet18: BasicBlock, [2, 2, 2, 2]
- resnet34: BasicBlock, [3, 4, 6, 3]
- resnet50: Bottleneck, [3, 4, 6, 3]
- resnet101: Bottleneck, [3, 4, 23, 3]
- resnet152: Bottleneck, [3, 8, 36, 3]
这个后面的结构是作者自己挑的一个参数,所以不用管它为什么。BasicBlock主要用于resnet18和34,Bottleneck用于resnet50,101和152。
相关文章:
Resnet模型详解
1、Resnet是什么? Resnet是一种深度神经网络架构,被广泛用于计算机视觉任务,特别是图像分类。它是由微软研究院的研究员于2015年提出的,是深度学习领域的重要里程碑之一。 2、网络退化问题 理论上来讲,随着网络的层…...
AI 绘画Stable Diffusion 研究(十六)SD Hypernetwork详解
大家好,我是风雨无阻。 本期内容: 什么是 Hypernetwork?Hypernetwork 与其他模型的区别?Hypernetwork 原理Hypernetwork 如何下载安装?Hypernetwork 如何使用? 在上一篇文章中,我们详细介绍了 …...
2023.8 -java - 继承
继承就是子类继承父类的特征和行为,使得子类对象(实例)具有父类的实例域和方法,或子类从父类继承方法,使得子类具有父类相同的行为。 继承的特性 子类拥有父类非 private 的属性、方法。 子类可以拥有自己的属性和方法…...
前端面试:【移动端开发】PWA、Hybrid App和Native App的比较
在移动端开发中,开发者有多种选择,包括渐进式Web应用(PWA),混合应用(Hybrid App)和原生应用(Native App)。每种方法都有其独特的优势和适用场景。本文将对它们进行比较&a…...
picGo+gitee+typora设置图床
picGogiteetypora设置图床 picGogitee设置图床下载picGo软件安装picGo软件gitee操作在gitee中创建仓库在gitee中配置私人令牌 配置picGo在插件设置中搜索gitee插件并进行下载 TyporapicGo设置Typora 下载Typora进行图像设置 picGogitee设置图床 当我了解picGogitee可以设置图床…...
[JavaWeb]【十三】web后端开发-原理篇
目录 一、SpringBoot配置优先级 1.1 配置优先级比较 1.2 java系统属性和命令行参数 1.3 打包运行jar 1.4 综合优先级编辑 二、Bean管理 2.1 获取bean 2.2 bean作用域 2.2.1 五种作用域 2.2.2 配置作用域 2.3 第三方bean 2.3.1 编写公共配置类 三、SpringBoot原理 …...
服务注册中心 Eureka
服务注册中心 Eureka Spring Cloud Eureka 是 Netflix 公司开发的注册发现组件,本身是一个基于 REST 的服务。提供注册与发现,同时还提供了负载均衡、故障转移等能力。 Eureka 有 3 个角色 服务中心(Eureka Server):…...
SpringIoC组件的高级特性
目录 一、Bean组件的周期与作用域 二、FactoryBean接口 一、Bean组件的周期与作用域 1.1 Bean组件的生命周期 什么是Bean的周期方法 我们可以在组件类中定义方法,然后当IoC容器实例化和销毁组件对象的时候进行调用!这两个方法我们成为生命周期方法&a…...
Linux--进程地址空间
1.线程地址空间 所谓进程地址空间(process address space),就是从进程的视角看到的地址空间,是进程运行时所用到的虚拟地址的集合。 简单地说,进程就是内核数据结构和代码和本身的代码和数据,进程本身不能…...
ISIS路由协议
骨干区域与非骨干区域 凡是由级别2组建起来的邻居形成骨干区域;级别1就在非骨干区域,骨干区域有且只有一个,并且需要连续,ISIS在IP环境下目前不支持虚链路。 路由器级别 L1路由器只能建立L1的邻居;L2路由器只能建立L…...
论文解读:Bert原理深入浅出
摘取于https://www.jianshu.com/p/810ca25c4502 任务1:Masked Language Model Maked LM 是为了解决单向信息问题,现有的语言模型的问题在于,没有同时利用双向信息,如 ELMO 号称是双向LM,但实际上是两个单向 RNN 构成的…...
共享内存 windows和linux
服务端,即写入端 #include <iostream> #include <string.h> #define BUF_SIZE 1024 #ifdef _WIN32 #include <windows.h> #define SHARENAME L"shareMemory" HANDLE g_MapFIle; LPVOID g_baseBuffer; #else #define SHARENAME "sh…...
一个mongodb问题分析
mongodb问题分析 现状 表的个数: 生产上常用的表就10来个。 sharding cluster replica set方式部署: 9个shard server, 每个shard server 1主2从, 大量数据写入时或对大表创建索引时,可能有主从复制延迟问题。实…...
Vue3.0极速入门- 目录和文件说明
目录结构 以下文件均为npm create helloworld自动生成的文件目录结构 目录截图 目录说明 目录/文件说明node_modulesnpm 加载的项目依赖模块src这里是我们要开发的目录,基本上要做的事情都在这个目录里assets放置一些图片,如logo等。componentsvue组件…...
RabbitMQ---订阅模型-Direct
1、 订阅模型-Direct • 有选择性的接收消息 • 在订阅模式中,生产者发布消息,所有消费者都可以获取所有消息。 • 在路由模式中,我们将添加一个功能 - 我们将只能订阅一部分消息。 例如,我们只能将重要的错误消息引导到日志文件…...
Django REST framework实现api接口
drf 是Django REST framework的简称,drf 是基于django的一个api 接口实现框架,REST是接口设计的一种风格。 一、 安装drf pip install djangorestframework pip install markdown # Markdown support for the browsable API. pip install …...
4.19 20
服务端没有 listen,客户端发起连接建立,会发生什么? 服务端如果只 bind 了 IP 地址和端口,而没有调用 listen 的话,然后客户端对服务端发起了连接建立,服务端会回 RST 报文。 没有 listen&#x…...
(动态规划) 剑指 Offer 10- II. 青蛙跳台阶问题 ——【Leetcode每日一题】
❓剑指 Offer 10- II. 青蛙跳台阶问题 难度:简单 一只青蛙一次可以跳上1级台阶,也可以跳上2级台阶。求该青蛙跳上一个 n 级的台阶总共有多少种跳法。 答案需要取模 1e97(1000000007),如计算初始结果为:1…...
物联网WIFI 模块AT指令版本七大元凶
前言 目前我们讨论的这个问题,并不是说WIFI方案不具备以应的功能。而是指在同一个AT固件下可能存在的问题。由于各厂商AT指令的开发深度不同,导致各厂商之间的AT指令差异很大。我总结了一些问题,可能是导致目前AT指令不好用元凶。 底层库问题…...
Qt 正则(数据格式校验、替换指定格式数据、获取匹配数据)
头文件引用 #include <QRegExp>初始化QRegExp实列 QRegExp re("^\\d{1,3},\\d{1,3}$");数据格式验证 QRegExp re("^\\d{1,3},\\d{1,3}$"); QString msg "12,33"; if(re.exactMatch()){// 验证通过 }else{//验证不通过 }替换数…...
网络层协议——ip
文章目录 1. 网络层2. IP协议2.1 协议头格式 3. 网段划分3.1 特殊的IP地址3.2 IP地址的数量限制 4. 私有IP地址和公网IP地址 1. 网络层 在应用层解决了如何读取完整报文、序列化反序列化、协议处理问题。在传输层解决了可靠性问题。那么网络层IP的作用是在复杂的网络环境中确定…...
Qt6和Rust结合构建桌面应用
桌面应用程序是原生的、快速的、安全的,并提供Web应用程序无法比拟的体验。 Rust 是一种低级静态类型多范式编程语言,专注于安全性和性能,解决了 C/C 长期以来一直在努力解决的问题,例如内存错误和构建并发程序。 在桌面应用程序开…...
Kubernetes(K8S)简介
Kubernetes (K8S) 是什么 它是一个为 容器化 应用提供集群部署和管理的开源工具,由 Google 开发。Kubernetes 这个名字源于希腊语,意为“舵手”或“飞行员”。k8s 这个缩写是因为 k 和 s 之间有八个字符的关系。 Google 在 2014 年开源了 Kubernetes 项…...
面试中问:React中函数组件和class组件的区别,hooks模拟生命周期
React中函数组件和class组件的区别,hooks模拟生命周期 React中函数组件和class组件的区别hooks模拟生命周期 React中函数组件和class组件的区别 函数组件: 定义:函数组件是使用纯函数定义的组件,它接受 props 作为参数并返回 JSX。简洁性&am…...
Python高光谱遥感数据处理与高光谱遥感机器学习方法应用
本文提供一套基于Python编程工具的高光谱数据处理方法和应用案例。 本文涵盖高光谱遥感的基础、方法和实践。基础篇以学员为中心,用通俗易懂的语言解释高光谱的基本概念和理论,旨在帮助学员深入理解科学原理。方法篇结合Python编程工具,专注…...
Java实现接收xml格式数据并解析,返回xml格式数据
需求描述:后端接受xml格式数据,解析出相应数据,并返回xml格式数据。 <!--XML解析--><dependency><groupId>com.fasterxml.jackson.dataformat</groupId><artifactId>jackson-dataformat-xml</artifactId>…...
【C++】初步认识模板
🏖️作者:malloc不出对象 ⛺专栏:C的学习之路 👦个人简介:一名双非本科院校大二在读的科班编程菜鸟,努力编程只为赶上各位大佬的步伐🙈🙈 目录 前言一、泛型编程二、函数模板2.1 函…...
Ansible 临时命令搭建安装仓库
创建一个名为/ansible/yum.sh 的 shell 脚本,该脚本将使用 Ansible 临时命令在各个受管节点上安装 yum 存储库. 存储库1: 存储库的名称为 EX294_BASE 描述为 EX294 base software 基础 URL 为 http://content/rhel8.0/x86_64/dvd/BaseOS GPG 签名检查为…...
phpstorm动态调试
首先在phpstudy搭建好网站,在管理拓展开启xdebug拓展 查看php.ini配置已经更改 需要增添修改一下设置 [Xdebug] zend_extensionD:/phpstudy_pro/Extensions/php/php5.6.9nts/ext/php_xdebug.dll xdebug.collect_params1 xdebug.collect_return1 xdebug.auto_trace…...
二叉树的层序遍历及完全二叉树的判断
文章目录 1.二叉树层序遍历 2.完全二叉树的判断 文章内容 1.二叉树层序遍历 二叉树的层序遍历需要一个队列来帮助实现。 我们在队列中存储的是节点的地址,所以我们要对队列结构体的数据域重定义, 以上代码 从逻辑上来讲就是1入队,1出队&am…...
网站模板下载模板下载安装/什么是网络营销渠道
在myeclipse中deploy:选择了一个工程,添加一个新的deploy工程时,不能正常出现deploy Location,可能的原因是没有在mymatadata中添加context-root"/",另外webrootdir属性也要设置正确。一个常见的配置如下&am…...
买做指甲的材料在哪个网站/最新足球消息
说多了都是泪。。。 最后,我们还是要用一首“500的歌”来共勉:来来来,加完今天,还有三天! 文章转载自 开源中国社区 [http://www.oschina.net]...
乐云seo商城网站建设/男生最喜欢的浏览器
声明: 以下内容为阅读由周志明编著的《深入理解Java虚拟机JVM高级特性与最佳实战》(第二版)自行总结记录,算不上完全解读了大神的意思,但也没有瞎写。如果写的不清楚的地方,还望能够自行阅读原著。这里写记…...
wordpress 提示要安装/soe搜索优化
作为2017世界物联网博览会的重要活动之一,由中国经济信息社江苏中心研撰的《2016-2017中国物联网发展年度报告》(下称《年报》)近日在无锡发布。《年报》认为,我国智慧城市步入实质发展阶段,企业广泛参与、营收能力增强。 2016年以来,我国智慧城市建设开放合作特征进一步凸显,B…...
wordpress主题检测/搜索引擎优化的核心及内容
参考文章1:Ubuntu20.04安装Mysql 参考文章2:ubuntu在安装MySQL常见问题和mysql_secure_installation向导记录详解 会遇到这个插件的问题:mysql1193 HY000,MySQL ERROR 1193 (HY000): Unknown system variable ‘validate_password_policy’ …...
做电影网站算侵权吗/aso优化什么意思
测试环境:ubuntu18.04driver450cuda11.0cudnn8.0.5opencv4.4.0 1、ubuntu显卡驱动下载安装 2、cuda及cudnn安装 3、opencv4编译配置 4、darknet源码编译测试...