PyTorch深度学习实战——基于ResNet模型实现猫狗分类
PyTorch深度学习实战——基于ResNet模型实现猫狗分类
- 0. 前言
- 1. ResNet 架构
- 2. 基于预训练 ResNet 模型实现猫狗分类
- 相关链接
0. 前言
从 VGG11 到 VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非仅增加网络层数,就可以获得更准确的结果,随着网络层数的增加可能会出现以下问题:
- 梯度消失和爆炸:在网络层次过深的情况下,反向传播可能会面临梯度消失和爆炸的问题,导致训练网络时无法收敛
- 过拟合:增加网络深度会带来更多的参数,如果数据样本过少或网络过于复杂,会导致网络过拟合,降低模型的泛化能力
总之,在构建的神经网络过深时,有两个问题:前向传播中,网络的最后几层几乎没有学习到有关原始图像的任何信息;在反向传播中,由于梯度消失(梯度值几乎为零),靠近输入的前几层几乎没有任何梯度更新。
深度残差网络 (ResNet) 的提出就是为了解决上述问题。在 ResNet 中,如果模型没有什么要学习的,那么卷积层可以什么也不做,只是将上一层的输出传递给下一层。但是,如果模型需要学习其他一些特征,则卷积层将前一层的输出作为输入,并学习完成目标任务所需的其它特征。
1. ResNet 架构
ResNet 通过残差结构解决网络过深时出现的问题,让模型能够训练得更深。经典的 ResNet 架构如下所示:

残差结构的基本思想是:每一个残差块都不是直接映射输入信号到输出信号,而是通过学习残差映射来实现:
F ( x ) = H ( x ) − x F(x)=H(x)−x F(x)=H(x)−x
其中, x x x 是输入, H ( x ) H(x) H(x) 是一个表示所需映射的基本块,而 F ( x ) F(x) F(x) 是残差块学习到的映射。换句话说,输入 x x x 通过卷积层,得到特征变换后的输出 F ( x ) F(x) F(x),与输入 x x x 进行逐元素的相加运算,得到最终输出 H ( x ) H(x) H(x):
H ( x ) = x + F ( x ) H(x) = x + F(x) H(x)=x+F(x)
如果某个基本块为恒等映射,则残差块的学习目标就变为学习 F ( x ) = 0 F(x)=0 F(x)=0,也就是让输入信号直接到达残差块的输出层。这样就可以解决梯度消失的问题,可以训练更深的神经网络。
实现过程中 ResNet 中使用 Shortcut Connection (也称跳跃连接, Skip Connection )在残差块中实现跨层连接,从而实现信息的直接传递,跨层连接可以绕过一个或多个卷积层,直接将网络中的浅层信息传递到深层中。
在 ResNet 的残差块中,Shortcut Connection 经常与卷积层或批归一化 (Batch Normalization) 相结合。通过该连接,残差块的激活张量可以直接和下一层的输出相加,理论上,即使是最后一层可能拥有原始图像的全部信息,并且反向传播过程中梯度将可以在几乎没有修改的情况下自由地流向浅层。典型的残差块如下所示:

在传统中顺序堆叠的神经网络中,神经网络通常直接学习 F ( x ) F(x) F(x),其中 x 是来自前一层的输出值,而在残差网络中,利用跳跃连接,将残差信号 F ( x ) F(x) F(x) 加上恒等映射 x x x 得到最终的输出 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x。接下来,我们通过在 PyTorch 中构建残差块来深入了解残差网络。
2. 基于预训练 ResNet 模型实现猫狗分类
(1) 在 __init__ 方法中定义一个带有卷积操作的类:
from torch import nnclass ResLayer(nn.Module):def __init__(self,ni,no,kernel_size,stride=1):super(ResLayer, self).__init__()padding = kernel_size - 2self.conv = nn.Sequential(nn.Conv2d(ni, no, kernel_size, stride, padding=padding),nn.ReLU())
在以上代码中,为了确保通过卷积后输出的尺寸保持不变,以便于将输入与卷结果相加,我们通过 padding 控制卷积时输出的尺寸。
(2) 定义 forward 方法:
def forward(self, x):return self.conv(x) + x
在以上代码中,得到的输出是通过卷积操作的输入和原始输入之和。
在 PyTorch 中预训练的基于残差块的 ResNet18 架构如下:

该架构有 18 个可训练网络层,因此被称为 ResNet18 架构。此外,需要注意的是,ResNet18 并不是每个卷积层都会添加跳跃连接,而是在每两层之后使用跳跃连接。
了解了 ResNet 架构之后,构建一个基于预训练 ResNet18 架构的模型来执行狗猫分类任务。构建分类器的流程可以参考在迁移学习中使用预训练 VGG16 模型构建的猫狗分类器。
(3) 加载预训练 ResNet18 模型并检查模型中的模块:
model = models.resnet18(pretrained=True)
ResNet18 模型架构包含以下组件:
- 卷积层
- 批归一化
ReLU激活- 最大池化层
4个ResNet块- 平均池化 (
avgpool) 层 - 全连接层 (
fc) 层
冻结特征提取模块的网络权重,仅替换 avgpool 和 fc 层并更新其中的参数。
(4) 定义模型架构、损失函数和优化器:
def get_model():model = models.resnet18(pretrained=True)for param in model.parameters():param.requires_grad = Falsemodel.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))model.fc = nn.Sequential(nn.Flatten(),nn.Linear(512, 128),nn.ReLU(),nn.Dropout(0.2),nn.Linear(128, 1),nn.Sigmoid())loss_fn = nn.BCELoss()optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3)return model.to(device), loss_fn, optimizer
在模型中,fc 模块的输入形状为 512,因为 avgpool 的输出形状为 batch size x 512 x 1 x 1。定义了模型后,训练模型,随着 epoch 的增加,模型训练和验证准确率的变化(对应模型分别为 ResNet18、ResNet34、ResNet50、ResNet101 和 ResNet152) 如下:

仅对 1000 张图像进行训练时,模型的准确率就可以达到 98% 左右,且准确率随着 ResNet 层数的增加而增加。
相关链接
PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习
相关文章:
PyTorch深度学习实战——基于ResNet模型实现猫狗分类
PyTorch深度学习实战——基于ResNet模型实现猫狗分类 0. 前言1. ResNet 架构2. 基于预训练 ResNet 模型实现猫狗分类相关链接 0. 前言 从 VGG11 到 VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非…...
机器学习第六课--朴素贝叶斯
朴素贝叶斯广泛地应用在文本分类任务中,其中最为经典的场景为垃圾文本分类(如垃圾邮件分类:给定一个邮件,把它自动分类为垃圾或者正常邮件)。这个任务本身是属于文本分析任务,因为对应的数据均为文本类型,所以对于此类任务我们首先…...
基于Java+SpringBoot+Vue的图书借还小程序的设计与实现(亮点:多角色、点赞评论、借书还书、在线支付)
图书借还管理小程序 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统实现5.1 小…...
【校招VIP】前端计算机网络之UDP相关
考点介绍 UDP是一个简单的面向消息的传输层协议,尽管UDP提供标头和有效负载的完整性验证(通过校验和),但它不保证向上层协议提供消息传递,并且UDP层在发送后不会保留UDP 消息的状态。因此,UDP有时被称为不可…...
前缀和实例4(和可被k整除的子数组)
题目: 给定一个整数数组 nums 和一个整数 k ,返回其中元素之和可被 k 整除的(连续、非空) 子数组 的数目。 子数组 是数组的 连续 部分。 示例 1: 输入:nums [4,5,0,-2,-3,1], k 5 输出:7 …...
Android获取系统读取权限
第一步在Androidifest.xml文件中加上授权语句 <uses-permission android:name"android.permission.WRITE_EXTERNAL_STORAGE"/><uses-permission android:name"android.permission.READ_EXTERNAL_STORAGE"/>并且在Application标签下添加 androi…...
输入学生成绩(最多不超过40),输入为负值时表示输入结束,统计成绩高于平均成绩的学生人数
#include<stdio.h> #define N 40 int scanfscore(int score[N]) {int i -1;do {i;printf("输入学生成绩:");scanf("%d", &score[i]);} while (score[i] > 0);return i; } int average(int score[N], int n) {int j 0;int k 0;double sum …...
【力扣周赛】第 363 场周赛(完全平方数和质因数分解)
文章目录 竞赛链接Q1:100031. 计算 K 置位下标对应元素的和竞赛时代码写法2——手写二进制中1的数量 Q2:100040. 让所有学生保持开心的分组方法数(排序后枚举分界)竞赛时代码 Q3:100033. 最大合金数(二分答…...
RocketMQ的介绍和环境搭建
一、介绍 我也不知道是啥,知道有什么用、怎么用就行了,说到mq(MessageQueue)就是消息队列,队列是先进先出的一种数据结构,但是RocketMQ不一定是这样,简单的理解一下,就是临时存储的…...
【web开发】7、Django(2)
提示:文章写完后,目录可以自动生成,如何生成可参考右边的帮助文档 文章目录 一、部门列表二、部门管理(增删改)三、用户管理过渡到modelform组件四、modelform实例:靓号操作五、自定义分页组件六、datepick…...
Prometheus+Grafana可视化监控【Nginx状态】
文章目录 一、安装Docker二、安装Nginx(Docker容器方式)三、安装Prometheus四、安装Grafana五、Pronetheus和Grafana相关联六、安装nginx_exporter七、Grafana添加Nginx监控模板 一、安装Docker 注意:我这里使用之前写好脚本进行安装Docker,如果已经有D…...
R 语言的安装教程
一、下载相关软件 1、R 下载 官网:R: The R Project for Statistical Computing 找到中国镜像,下载快 历史版本点击这里 2、Rtools 下载 进入镜像后,点击这里 然后选择与上面下载的R版本相对应的版本即可 3、Rstudio 下载 官网࿱…...
uniapp-提现功能(demo)
页面布局 提现页面 有一个输入框 一个提现按钮 一段提现全部的文字 首先用v-model 和data内的数据双向绑定 输入框逻辑分析 输入框的逻辑 为了符合日常输出 所以要对输入框加一些条件限制 因为是提现 所以对输入的字符做筛选,只允许出现小数点和数字 这里用正则实现的小数点…...
Spring 篇
1、什么是 Spring? Spring是一个轻量级的IOC和AOP容器框架。是为Java应用程序提供基础性服务的一套框架,目的是用于简化企业应用程序的开发,它使得开发者只需要关心业务需求。常见的配置方式有三种:基于XML的配置、基于注解的配置…...
three.js简单3D图形的使用
npm init vitelatest //创建一个vite的脚手架 选择 Vanilla 之后自己处理一下 在main.js中写入 // 导入three.js import * as THREE from three// 创建场景 const scene new THREE.Scene();// 创建相机 const camera new THREE.PerspectiveCamera(45, //视角window.inner…...
spark withColumn的使用(笔记)
目录 前言: spark withColumn的语法及使用: 准备源数据演示: 完整实例代码: 前言: withColumn():是Apache Spark中用于DataFrame操作的函数之一,它的作用是在DataFrame中添加或替换列ÿ…...
PTA:7-1 线性表的合并
线性表的合并 题目输入样例输出样例 代码解析 题目 输入样例 4 7 5 3 11 3 2 6 3输出样例 7 5 3 11 2 6 代码 #include<iostream> #include<vector> using namespace std;bool checkrep(const vector<int>& arr, int x) {for (int element : arr) {i…...
Spring 的创建和日志框架的整合
目录 一、第一个 Spring 项目 1、配置环境 2、Spring 的 jar 包 Maven 项目导入 jar 包和设置国内源的方法: 3、Spring 的配置文件 4、Spring 的核心 API ApplicationContext 4、程序开发 5、细节分析 (1)名词解释 (2&…...
11-集合和学生管理系统
1.ArrayList 集合和数组的优势对比: 长度可变添加数据的时候不需要考虑索引,默认将数据添加到末尾 1.1 ArrayList类概述 什么是集合 提供一种存储空间可变的存储模型,存储的数据容量可以发生改变 ArrayList集合的特点 长度可以变化…...
C语言进阶指针(3) ——qsort的实现
大家好,我们今天来学习回调函数qsort的实现。 首先让我们打开cplusplus.com找到qsort函数。 我们看到这个函数就可以看到它的头文件和参数信息。 #include<stdlib.h> void qsort (void* base, size_t num, size_t size, int (*compar)(const void*,const voi…...
XCTF-web-easyupload
试了试php,php7,pht,phtml等,都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接,得到flag...
Spark 之 入门讲解详细版(1)
1、简介 1.1 Spark简介 Spark是加州大学伯克利分校AMP实验室(Algorithms, Machines, and People Lab)开发通用内存并行计算框架。Spark在2013年6月进入Apache成为孵化项目,8个月后成为Apache顶级项目,速度之快足见过人之处&…...
【Linux】C语言执行shell指令
在C语言中执行Shell指令 在C语言中,有几种方法可以执行Shell指令: 1. 使用system()函数 这是最简单的方法,包含在stdlib.h头文件中: #include <stdlib.h>int main() {system("ls -l"); // 执行ls -l命令retu…...
java 实现excel文件转pdf | 无水印 | 无限制
文章目录 目录 文章目录 前言 1.项目远程仓库配置 2.pom文件引入相关依赖 3.代码破解 二、Excel转PDF 1.代码实现 2.Aspose.License.xml 授权文件 总结 前言 java处理excel转pdf一直没找到什么好用的免费jar包工具,自己手写的难度,恐怕高级程序员花费一年的事件,也…...
LeetCode - 394. 字符串解码
题目 394. 字符串解码 - 力扣(LeetCode) 思路 使用两个栈:一个存储重复次数,一个存储字符串 遍历输入字符串: 数字处理:遇到数字时,累积计算重复次数左括号处理:保存当前状态&a…...
抖音增长新引擎:品融电商,一站式全案代运营领跑者
抖音增长新引擎:品融电商,一站式全案代运营领跑者 在抖音这个日活超7亿的流量汪洋中,品牌如何破浪前行?自建团队成本高、效果难控;碎片化运营又难成合力——这正是许多企业面临的增长困局。品融电商以「抖音全案代运营…...
质量体系的重要
质量体系是为确保产品、服务或过程质量满足规定要求,由相互关联的要素构成的有机整体。其核心内容可归纳为以下五个方面: 🏛️ 一、组织架构与职责 质量体系明确组织内各部门、岗位的职责与权限,形成层级清晰的管理网络…...
sqlserver 根据指定字符 解析拼接字符串
DECLARE LotNo NVARCHAR(50)A,B,C DECLARE xml XML ( SELECT <x> REPLACE(LotNo, ,, </x><x>) </x> ) DECLARE ErrorCode NVARCHAR(50) -- 提取 XML 中的值 SELECT value x.value(., VARCHAR(MAX))…...
新能源汽车智慧充电桩管理方案:新能源充电桩散热问题及消防安全监管方案
随着新能源汽车的快速普及,充电桩作为核心配套设施,其安全性与可靠性备受关注。然而,在高温、高负荷运行环境下,充电桩的散热问题与消防安全隐患日益凸显,成为制约行业发展的关键瓶颈。 如何通过智慧化管理手段优化散…...
多种风格导航菜单 HTML 实现(附源码)
下面我将为您展示 6 种不同风格的导航菜单实现,每种都包含完整 HTML、CSS 和 JavaScript 代码。 1. 简约水平导航栏 <!DOCTYPE html> <html lang"zh-CN"> <head><meta charset"UTF-8"><meta name"viewport&qu…...
