Resnet模型详解
1、Resnet是什么?
Resnet是一种深度神经网络架构,被广泛用于计算机视觉任务,特别是图像分类。它是由微软研究院的研究员于2015年提出的,是深度学习领域的重要里程碑之一。
2、网络退化问题
理论上来讲,随着网络的层数的增加,网络能够进行更加复杂的特征提取,可以取得更好的结果。但是实验发现深度网络出现了退化问题,如下图所示。网络深度增加时,网络准确度出现饱和,之后甚至还快速下降。而且这种下降不是因为过拟合引起的,而是因为在适当的深度模型上添加更多的层会导致了更高的训练误差,从而使其下降。
图1 网络深度对比(来源:Resnet的论文)
当你使用深度神经网络进行训练时,网络层可以被看作是一系列的函数堆叠,每个函数代表一个网络层的操作,这里我们就记作。在反向传播过程中,梯度是通过链式法则逐层计算得出的。假设每个操作的梯度都小于1,因为多个小于1的数相乘可能会导致结果变得更小。在神经网络中,随着反向传播的逐层传递,梯度可能会逐渐变得非常小,甚至接近于零,这就是梯度消失问题。
而如果经过网络层操作后的输出值大于1,那么反向传播时梯度可能会相应地增大。这种情况下,梯度爆炸问题可能会出现。梯度爆炸问题指的是在深度神经网络中,梯度逐渐放大,导致底层网络的参数更新过大,甚至可能导致数值溢出。
3、残差结构
在ResNet提出之前,所有的神经网络都是通过卷积层和池化层的叠加组成的。所以,Resnet对后面计算机视觉的发展影响是巨大的。
图2 残差结构(来源:Resnet的论文)
它这里完成的一个很简单的过程,我先举一个例子:
想象一张经过神经网络处理后的低分辨率图像。为了提高图像的质量,我们引入了一个创新的思想:将原始高分辨率图像与低分辨率图像之间的差异提取出来,形成了一个残差图像。这个残差图像代表了低分辨率图像与目标高分辨率图像之间的差异或缺失的细节。
图3 残差图像
然后,我们将这个残差图像与低分辨率图像相加,得到一个结合了低分辨率信息和残差细节的新图像。这个新图像作为下一个神经网络层的输入,使网络能够同时利用原始低分辨率信息和残差细节信息进行更精确的学习。
图4 残差+低分辨率图像
通过这种方式,我们的神经网络能够逐步地从低分辨率图像中提取信息,并通过残差图像的相加操作将遗漏的细节加回来。这使得网络能够更有效地进行图像恢复或其他任务,提高了模型的性能和准确性。
我相信我已经成功表达了残差结构的思想和操作过程。其实这个思想也并非是resnet创新的,在我们过去的其他领域中早已有这种思想,ResNet将这一思想引入了计算机视觉领域,并在深度神经网络中的训练中取得了重要突破。这种创新在一定程度上解决了深层神经网络训练中的梯度消失和梯度爆炸问题,使得网络能够更深更准确地学习特征和表示。
4、Resnet网络结构
(1)对于相同的输出特征图尺寸,层具有相同数量的滤波器
(2)当feature map大小降低一半时,feature map的数量增加一倍【过滤器(可以看作是卷积核的集合)的数量增加一倍】,这保持了网络层的复杂度。然后通过步长为2的卷积层直接执行下采样。
网络结构具体如下图所示:
图5 左为VGG-19,中为34个参数层的简单网络,右为34个参数层的残差网络
ResNet相比普通网络每两层间增加了短路机制,这就形成了残差学习,其中实线表示快捷连接,虚线表示feature map数量发生了改变。
有两种情况需要考虑:
(1)输入和输出具有相同的维度时(对应实线部分):
直接使用恒等快捷连接
(2)维度增加(当快捷连接跨越两种尺寸的特征图时,它们执行时步长为2):
①快捷连接仍然执行恒等映射,额外填充零输入以增加维度。这样就不会引入额外的参数。
②用下面公式的投影快捷连接用于匹配维度(由1×1卷积完成)
论文中也提供了更详细的结构,如下图所示:
5、使用Pytorch实现Resnet
本来是按照论文手写的代码,但用的时候发现维度不匹配,用不了预训练权重,所以这里就照着pytorch源码进行了修改。
import torch
import torchvision
import torch.nn as nn
import torchsummary
from torch.hub import load_state_dict_from_urlmodel_urls = {"resnet18" : "https://download.pytorch.org/models/resnet18-f37072fd.pth","resnet34" : "https://download.pytorch.org/models/resnet34-b627a593.pth","resnet50" : "https://download.pytorch.org/models/resnet50-0676ba61.pth","resnet101" : "https://download.pytorch.org/models/resnet101-63fe2227.pth","resnet152" : "https://download.pytorch.org/models/resnet152-394f9c45.pth",
}cfgs = {"resnet18": [2, 2, 2, 2],"resnet34": [3, 4, 6, 3],"resnet50": [3, 4, 6, 3],"resnet101": [3, 4, 23, 3],"resnet152": [3, 8, 36, 3],
}def conv1x1(in_planes, out_planes, stride = 1):"""1x1 convolution"""return nn.Conv2d(in_planes, out_planes, kernel_size=(1,1), stride=(stride,stride), bias=False)
def conv3x3(in_planes, out_planes, stride= 1, groups=1, dilation=1):"""3x3 convolution with padding"""return nn.Conv2d(in_planes,out_planes,kernel_size=(3,3),stride=(stride,stride),padding=dilation,groups=groups,bias=False,dilation=(dilation,dilation))
class BasicBlock(nn.Module):expansion = 1def __init__(self,inplanes: int,planes,stride = 1,downsample = None,groups =1,base_width = 64,dilation= 1,norm_layer= None,):super(BasicBlock,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dif groups != 1 or base_width != 64:raise ValueError("BasicBlock only supports groups=1 and base_width=64")if dilation > 1:raise NotImplementedError("Dilation > 1 not supported in BasicBlock")self.conv1 = conv3x3(inplanes, planes, stride)self.bn1 = norm_layer(planes)self.relu = nn.ReLU(inplace=True)self.conv2 = conv3x3(planes, planes)self.bn2 = norm_layer(planes)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass Bottleneck(nn.Module):expansion = 4def __init__(self,inplanes,planes,stride = 1,downsample = None,groups = 1,base_width = 64,dilation = 1,norm_layer = None,):super(Bottleneck,self).__init__()if norm_layer is None:norm_layer = nn.BatchNorm2dwidth = int(planes * (base_width / 64.0)) * groupsself.conv1 = conv1x1(inplanes, width)self.bn1 = norm_layer(width)self.conv2 = conv3x3(width, width, stride, groups, dilation)self.bn2 = norm_layer(width)self.conv3 = conv1x1(width, planes * self.expansion)self.bn3 = norm_layer(planes * self.expansion)self.relu = nn.ReLU(inplace=True)self.downsample = downsampleself.stride = stridedef forward(self, x):identity = xout = self.conv1(x)out = self.bn1(out)out = self.relu(out)out = self.conv2(out)out = self.bn2(out)out = self.relu(out)out = self.conv3(out)out = self.bn3(out)if self.downsample is not None:identity = self.downsample(x)out += identityout = self.relu(out)return outclass ResNet(nn.Module):def __init__(self,block,layers,in_channels=3,num_classes = 1000,zero_init_residual = False,groups = 1,width_per_group = 64,replace_stride_with_dilation = None,):super(ResNet,self).__init__()norm_layer = nn.BatchNorm2dself._norm_layer = norm_layerself.num=num_classesself.inplanes = 64self.dilation = 1self.groups = groupsself.base_width = width_per_groupself.conv1 = nn.Conv2d(in_channels, self.inplanes, kernel_size=(7,7), stride=(2,2), padding=3, bias=False)self.bn1 = norm_layer(self.inplanes)self.relu = nn.ReLU(inplace=True)self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)self.layer1 = self._make_layer(block, 64, layers[0])self.layer2 = self._make_layer(block, 128, layers[1], stride=2, dilate=False)self.layer3 = self._make_layer(block, 256, layers[2], stride=2, dilate=False)self.layer4 = self._make_layer(block, 512, layers[3], stride=2, dilate=False)self.avgpool = nn.AdaptiveAvgPool2d((1, 1))self.fc = nn.Linear(512 * block.expansion, self.num)for m in self.modules():if isinstance(m, nn.Conv2d):nn.init.kaiming_normal_(m.weight, mode="fan_out", nonlinearity="relu")elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):nn.init.constant_(m.weight, 1)nn.init.constant_(m.bias, 0)if zero_init_residual:for m in self.modules():if isinstance(m, Bottleneck) and m.bn3.weight is not None:nn.init.constant_(m.bn3.weight, 0) # type: ignore[arg-type]elif isinstance(m, BasicBlock) and m.bn2.weight is not None:nn.init.constant_(m.bn2.weight, 0) # type: ignore[arg-type]def _make_layer(self,block, planes, blocks, stride = 1,dilate = False,):norm_layer = self._norm_layerdownsample = Noneprevious_dilation = self.dilationif dilate:self.dilation *= stridestride = 1if stride != 1 or self.inplanes != planes * block.expansion:downsample = nn.Sequential(conv1x1(self.inplanes, planes * block.expansion, stride),norm_layer(planes * block.expansion),)layers = []layers.append(block(self.inplanes, planes, stride, downsample, self.groups, self.base_width, previous_dilation))self.inplanes = planes * block.expansionfor _ in range(1, blocks):layers.append(block(self.inplanes,planes,groups=self.groups,base_width=self.base_width,dilation=self.dilation,))return nn.Sequential(*layers)def forward(self, x):x = self.conv1(x)x = self.bn1(x)x = self.relu(x)x = self.maxpool(x)x = self.layer1(x)x = self.layer2(x)x = self.layer3(x)x = self.layer4(x)x = self.avgpool(x)x = torch.flatten(x, 1)x = self.fc(x)return xdef resnet(in_channels, num_classes, mode='resnet50', pretrained=False):if mode == "resnet18" or mode == "resnet34":block = BasicBlockelse:block = Bottleneckmodel = ResNet(block, cfgs[mode], in_channels=in_channels, num_classes=num_classes)if pretrained:state_dict = load_state_dict_from_url(model_urls[mode], model_dir='./model', progress=True) # 预训练模型地址if num_classes != 1000:num_new_classes = num_classesfc_weight = state_dict['fc.weight']fc_bias = state_dict['fc.bias']fc_weight_new = fc_weight[:num_new_classes, :]fc_bias_new = fc_bias[:num_new_classes]state_dict['fc.weight'] = fc_weight_newstate_dict['fc.bias'] = fc_bias_newmodel.load_state_dict(state_dict)return model
这种写法是按照先前VGG那样写的,这样有助于使用同一个model_urls和cfgs。
BasicBlock类中的init()函数是先定义网络架构,forward()的函数是前向传播,实现的功能就是残差块:
Bottleneck类是另一种blcok类型,同上,init()函数是预定义网络架构,forward函数是进行前向传播。该block中有三个卷积,分别是1x1,3x3,1x1,分别完成的功能就是维度压缩,卷积,恢复维度,所以bottleneck实现的功能就是对通道数进行压缩,再放大。
注意:这里的plane不再是输出的通道数,输出通道数应该就是plane*expansion,即4*plane。
- resnet18: BasicBlock, [2, 2, 2, 2]
- resnet34: BasicBlock, [3, 4, 6, 3]
- resnet50: Bottleneck, [3, 4, 6, 3]
- resnet101: Bottleneck, [3, 4, 23, 3]
- resnet152: Bottleneck, [3, 8, 36, 3]
这个后面的结构是作者自己挑的一个参数,所以不用管它为什么。BasicBlock主要用于resnet18和34,Bottleneck用于resnet50,101和152。
相关文章:

Resnet模型详解
1、Resnet是什么? Resnet是一种深度神经网络架构,被广泛用于计算机视觉任务,特别是图像分类。它是由微软研究院的研究员于2015年提出的,是深度学习领域的重要里程碑之一。 2、网络退化问题 理论上来讲,随着网络的层…...

AI 绘画Stable Diffusion 研究(十六)SD Hypernetwork详解
大家好,我是风雨无阻。 本期内容: 什么是 Hypernetwork?Hypernetwork 与其他模型的区别?Hypernetwork 原理Hypernetwork 如何下载安装?Hypernetwork 如何使用? 在上一篇文章中,我们详细介绍了 …...

2023.8 -java - 继承
继承就是子类继承父类的特征和行为,使得子类对象(实例)具有父类的实例域和方法,或子类从父类继承方法,使得子类具有父类相同的行为。 继承的特性 子类拥有父类非 private 的属性、方法。 子类可以拥有自己的属性和方法…...
前端面试:【移动端开发】PWA、Hybrid App和Native App的比较
在移动端开发中,开发者有多种选择,包括渐进式Web应用(PWA),混合应用(Hybrid App)和原生应用(Native App)。每种方法都有其独特的优势和适用场景。本文将对它们进行比较&a…...

picGo+gitee+typora设置图床
picGogiteetypora设置图床 picGogitee设置图床下载picGo软件安装picGo软件gitee操作在gitee中创建仓库在gitee中配置私人令牌 配置picGo在插件设置中搜索gitee插件并进行下载 TyporapicGo设置Typora 下载Typora进行图像设置 picGogitee设置图床 当我了解picGogitee可以设置图床…...

[JavaWeb]【十三】web后端开发-原理篇
目录 一、SpringBoot配置优先级 1.1 配置优先级比较 1.2 java系统属性和命令行参数 1.3 打包运行jar 1.4 综合优先级编辑 二、Bean管理 2.1 获取bean 2.2 bean作用域 2.2.1 五种作用域 2.2.2 配置作用域 2.3 第三方bean 2.3.1 编写公共配置类 三、SpringBoot原理 …...

服务注册中心 Eureka
服务注册中心 Eureka Spring Cloud Eureka 是 Netflix 公司开发的注册发现组件,本身是一个基于 REST 的服务。提供注册与发现,同时还提供了负载均衡、故障转移等能力。 Eureka 有 3 个角色 服务中心(Eureka Server):…...

SpringIoC组件的高级特性
目录 一、Bean组件的周期与作用域 二、FactoryBean接口 一、Bean组件的周期与作用域 1.1 Bean组件的生命周期 什么是Bean的周期方法 我们可以在组件类中定义方法,然后当IoC容器实例化和销毁组件对象的时候进行调用!这两个方法我们成为生命周期方法&a…...

Linux--进程地址空间
1.线程地址空间 所谓进程地址空间(process address space),就是从进程的视角看到的地址空间,是进程运行时所用到的虚拟地址的集合。 简单地说,进程就是内核数据结构和代码和本身的代码和数据,进程本身不能…...

ISIS路由协议
骨干区域与非骨干区域 凡是由级别2组建起来的邻居形成骨干区域;级别1就在非骨干区域,骨干区域有且只有一个,并且需要连续,ISIS在IP环境下目前不支持虚链路。 路由器级别 L1路由器只能建立L1的邻居;L2路由器只能建立L…...
论文解读:Bert原理深入浅出
摘取于https://www.jianshu.com/p/810ca25c4502 任务1:Masked Language Model Maked LM 是为了解决单向信息问题,现有的语言模型的问题在于,没有同时利用双向信息,如 ELMO 号称是双向LM,但实际上是两个单向 RNN 构成的…...

共享内存 windows和linux
服务端,即写入端 #include <iostream> #include <string.h> #define BUF_SIZE 1024 #ifdef _WIN32 #include <windows.h> #define SHARENAME L"shareMemory" HANDLE g_MapFIle; LPVOID g_baseBuffer; #else #define SHARENAME "sh…...
一个mongodb问题分析
mongodb问题分析 现状 表的个数: 生产上常用的表就10来个。 sharding cluster replica set方式部署: 9个shard server, 每个shard server 1主2从, 大量数据写入时或对大表创建索引时,可能有主从复制延迟问题。实…...

Vue3.0极速入门- 目录和文件说明
目录结构 以下文件均为npm create helloworld自动生成的文件目录结构 目录截图 目录说明 目录/文件说明node_modulesnpm 加载的项目依赖模块src这里是我们要开发的目录,基本上要做的事情都在这个目录里assets放置一些图片,如logo等。componentsvue组件…...

RabbitMQ---订阅模型-Direct
1、 订阅模型-Direct • 有选择性的接收消息 • 在订阅模式中,生产者发布消息,所有消费者都可以获取所有消息。 • 在路由模式中,我们将添加一个功能 - 我们将只能订阅一部分消息。 例如,我们只能将重要的错误消息引导到日志文件…...

Django REST framework实现api接口
drf 是Django REST framework的简称,drf 是基于django的一个api 接口实现框架,REST是接口设计的一种风格。 一、 安装drf pip install djangorestframework pip install markdown # Markdown support for the browsable API. pip install …...

4.19 20
服务端没有 listen,客户端发起连接建立,会发生什么? 服务端如果只 bind 了 IP 地址和端口,而没有调用 listen 的话,然后客户端对服务端发起了连接建立,服务端会回 RST 报文。 没有 listen&#x…...

(动态规划) 剑指 Offer 10- II. 青蛙跳台阶问题 ——【Leetcode每日一题】
❓剑指 Offer 10- II. 青蛙跳台阶问题 难度:简单 一只青蛙一次可以跳上1级台阶,也可以跳上2级台阶。求该青蛙跳上一个 n 级的台阶总共有多少种跳法。 答案需要取模 1e97(1000000007),如计算初始结果为:1…...
物联网WIFI 模块AT指令版本七大元凶
前言 目前我们讨论的这个问题,并不是说WIFI方案不具备以应的功能。而是指在同一个AT固件下可能存在的问题。由于各厂商AT指令的开发深度不同,导致各厂商之间的AT指令差异很大。我总结了一些问题,可能是导致目前AT指令不好用元凶。 底层库问题…...
Qt 正则(数据格式校验、替换指定格式数据、获取匹配数据)
头文件引用 #include <QRegExp>初始化QRegExp实列 QRegExp re("^\\d{1,3},\\d{1,3}$");数据格式验证 QRegExp re("^\\d{1,3},\\d{1,3}$"); QString msg "12,33"; if(re.exactMatch()){// 验证通过 }else{//验证不通过 }替换数…...

智慧医疗能源事业线深度画像分析(上)
引言 医疗行业作为现代社会的关键基础设施,其能源消耗与环境影响正日益受到关注。随着全球"双碳"目标的推进和可持续发展理念的深入,智慧医疗能源事业线应运而生,致力于通过创新技术与管理方案,重构医疗领域的能源使用模式。这一事业线融合了能源管理、可持续发…...
电脑插入多块移动硬盘后经常出现卡顿和蓝屏
当电脑在插入多块移动硬盘后频繁出现卡顿和蓝屏问题时,可能涉及硬件资源冲突、驱动兼容性、供电不足或系统设置等多方面原因。以下是逐步排查和解决方案: 1. 检查电源供电问题 问题原因:多块移动硬盘同时运行可能导致USB接口供电不足&#x…...

基于当前项目通过npm包形式暴露公共组件
1.package.sjon文件配置 其中xh-flowable就是暴露出去的npm包名 2.创建tpyes文件夹,并新增内容 3.创建package文件夹...
【AI学习】三、AI算法中的向量
在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...

让AI看见世界:MCP协议与服务器的工作原理
让AI看见世界:MCP协议与服务器的工作原理 MCP(Model Context Protocol)是一种创新的通信协议,旨在让大型语言模型能够安全、高效地与外部资源进行交互。在AI技术快速发展的今天,MCP正成为连接AI与现实世界的重要桥梁。…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)
文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...

短视频矩阵系统文案创作功能开发实践,定制化开发
在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...
代码随想录刷题day30
1、零钱兑换II 给你一个整数数组 coins 表示不同面额的硬币,另给一个整数 amount 表示总金额。 请你计算并返回可以凑成总金额的硬币组合数。如果任何硬币组合都无法凑出总金额,返回 0 。 假设每一种面额的硬币有无限个。 题目数据保证结果符合 32 位带…...

20个超级好用的 CSS 动画库
分享 20 个最佳 CSS 动画库。 它们中的大多数将生成纯 CSS 代码,而不需要任何外部库。 1.Animate.css 一个开箱即用型的跨浏览器动画库,可供你在项目中使用。 2.Magic Animations CSS3 一组简单的动画,可以包含在你的网页或应用项目中。 3.An…...
DAY 26 函数专题1
函数定义与参数知识点回顾:1. 函数的定义2. 变量作用域:局部变量和全局变量3. 函数的参数类型:位置参数、默认参数、不定参数4. 传递参数的手段:关键词参数5 题目1:计算圆的面积 任务: 编写一…...