当前位置: 首页 > news >正文

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

    • 0. 前言
    • 1. ResNet 架构
    • 2. 基于预训练 ResNet 模型实现猫狗分类
    • 相关链接

0. 前言

VGG11VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非仅增加网络层数,就可以获得更准确的结果,随着网络层数的增加可能会出现以下问题:

  • 梯度消失和爆炸:在网络层次过深的情况下,反向传播可能会面临梯度消失和爆炸的问题,导致训练网络时无法收敛
  • 过拟合:增加网络深度会带来更多的参数,如果数据样本过少或网络过于复杂,会导致网络过拟合,降低模型的泛化能力

总之,在构建的神经网络过深时,有两个问题:前向传播中,网络的最后几层几乎没有学习到有关原始图像的任何信息;在反向传播中,由于梯度消失(梯度值几乎为零),靠近输入的前几层几乎没有任何梯度更新。
深度残差网络 (ResNet) 的提出就是为了解决上述问题。在 ResNet 中,如果模型没有什么要学习的,那么卷积层可以什么也不做,只是将上一层的输出传递给下一层。但是,如果模型需要学习其他一些特征,则卷积层将前一层的输出作为输入,并学习完成目标任务所需的其它特征。

1. ResNet 架构

ResNet 通过残差结构解决网络过深时出现的问题,让模型能够训练得更深。经典的 ResNet 架构如下所示:

ResNet架构
残差结构的基本思想是:每一个残差块都不是直接映射输入信号到输出信号,而是通过学习残差映射来实现:
F ( x ) = H ( x ) − x F(x)=H(x)−x F(x)=H(x)x

其中, x x x 是输入, H ( x ) H(x) H(x) 是一个表示所需映射的基本块,而 F ( x ) F(x) F(x) 是残差块学习到的映射。换句话说,输入 x x x 通过卷积层,得到特征变换后的输出 F ( x ) F(x) F(x),与输入 x x x 进行逐元素的相加运算,得到最终输出 H ( x ) H(x) H(x)

H ( x ) = x + F ( x ) H(x) = x + F(x) H(x)=x+F(x)

如果某个基本块为恒等映射,则残差块的学习目标就变为学习 F ( x ) = 0 F(x)=0 F(x)=0,也就是让输入信号直接到达残差块的输出层。这样就可以解决梯度消失的问题,可以训练更深的神经网络。
实现过程中 ResNet 中使用 Shortcut Connection (也称跳跃连接, Skip Connection )在残差块中实现跨层连接,从而实现信息的直接传递,跨层连接可以绕过一个或多个卷积层,直接将网络中的浅层信息传递到深层中。
ResNet 的残差块中,Shortcut Connection 经常与卷积层或批归一化 (Batch Normalization) 相结合。通过该连接,残差块的激活张量可以直接和下一层的输出相加,理论上,即使是最后一层可能拥有原始图像的全部信息,并且反向传播过程中梯度将可以在几乎没有修改的情况下自由地流向浅层。典型的残差块如下所示:

残差模块

在传统中顺序堆叠的神经网络中,神经网络通常直接学习 F ( x ) F(x) F(x),其中 x 是来自前一层的输出值,而在残差网络中,利用跳跃连接,将残差信号 F ( x ) F(x) F(x) 加上恒等映射 x x x 得到最终的输出 H ( x ) = F ( x ) + x H(x)=F(x)+x H(x)=F(x)+x。接下来,我们通过在 PyTorch 中构建残差块来深入了解残差网络。

2. 基于预训练 ResNet 模型实现猫狗分类

(1)__init__ 方法中定义一个带有卷积操作的类:

from torch import nnclass ResLayer(nn.Module):def __init__(self,ni,no,kernel_size,stride=1):super(ResLayer, self).__init__()padding = kernel_size - 2self.conv = nn.Sequential(nn.Conv2d(ni, no, kernel_size, stride, padding=padding),nn.ReLU())

在以上代码中,为了确保通过卷积后输出的尺寸保持不变,以便于将输入与卷结果相加,我们通过 padding 控制卷积时输出的尺寸。

(2) 定义 forward 方法:

    def forward(self, x):return self.conv(x) + x

在以上代码中,得到的输出是通过卷积操作的输入和原始输入之和。

PyTorch 中预训练的基于残差块的 ResNet18 架构如下:

请添加图片描述
该架构有 18 个可训练网络层,因此被称为 ResNet18 架构。此外,需要注意的是,ResNet18 并不是每个卷积层都会添加跳跃连接,而是在每两层之后使用跳跃连接。
了解了 ResNet 架构之后,构建一个基于预训练 ResNet18 架构的模型来执行狗猫分类任务。构建分类器的流程可以参考在迁移学习中使用预训练 VGG16 模型构建的猫狗分类器。

(3) 加载预训练 ResNet18 模型并检查模型中的模块:

model = models.resnet18(pretrained=True)

ResNet18 模型架构包含以下组件:

  • 卷积层
  • 批归一化
  • ReLU 激活
  • 最大池化层
  • 4ResNet
  • 平均池化 (avgpool) 层
  • 全连接层 (fc) 层

冻结特征提取模块的网络权重,仅替换 avgpoolfc 层并更新其中的参数。

(4) 定义模型架构、损失函数和优化器:

def get_model():model = models.resnet18(pretrained=True)for param in model.parameters():param.requires_grad = Falsemodel.avgpool = nn.AdaptiveAvgPool2d(output_size=(1,1))model.fc = nn.Sequential(nn.Flatten(),nn.Linear(512, 128),nn.ReLU(),nn.Dropout(0.2),nn.Linear(128, 1),nn.Sigmoid())loss_fn = nn.BCELoss()optimizer = torch.optim.Adam(model.parameters(), lr= 1e-3)return model.to(device), loss_fn, optimizer

在模型中,fc 模块的输入形状为 512,因为 avgpool 的输出形状为 batch size x 512 x 1 x 1。定义了模型后,训练模型,随着 epoch 的增加,模型训练和验证准确率的变化(对应模型分别为 ResNet18ResNet34ResNet50ResNet101ResNet152) 如下:

模型训练和验证准确率
仅对 1000 张图像进行训练时,模型的准确率就可以达到 98% 左右,且准确率随着 ResNet 层数的增加而增加。

相关链接

PyTorch深度学习实战(1)——神经网络与模型训练过程详解
PyTorch深度学习实战(2)——PyTorch基础
PyTorch深度学习实战(3)——使用PyTorch构建神经网络
PyTorch深度学习实战(4)——常用激活函数和损失函数详解
PyTorch深度学习实战(5)——计算机视觉基础
PyTorch深度学习实战(6)——神经网络性能优化技术
PyTorch深度学习实战(7)——批大小对神经网络训练的影响
PyTorch深度学习实战(8)——批归一化
PyTorch深度学习实战(9)——学习率优化
PyTorch深度学习实战(10)——过拟合及其解决方法
PyTorch深度学习实战(11)——卷积神经网络
PyTorch深度学习实战(12)——数据增强
PyTorch深度学习实战(13)——可视化神经网络中间层输出
PyTorch深度学习实战(14)——类激活图
PyTorch深度学习实战(15)——迁移学习

相关文章:

PyTorch深度学习实战——基于ResNet模型实现猫狗分类

PyTorch深度学习实战——基于ResNet模型实现猫狗分类 0. 前言1. ResNet 架构2. 基于预训练 ResNet 模型实现猫狗分类相关链接 0. 前言 从 VGG11 到 VGG19,不同之处仅在于网络层数,一般来说,神经网络越深,它的准确率就越高。但并非…...

机器学习第六课--朴素贝叶斯

朴素贝叶斯广泛地应用在文本分类任务中,其中最为经典的场景为垃圾文本分类(如垃圾邮件分类:给定一个邮件,把它自动分类为垃圾或者正常邮件)。这个任务本身是属于文本分析任务,因为对应的数据均为文本类型,所以对于此类任务我们首先…...

基于Java+SpringBoot+Vue的图书借还小程序的设计与实现(亮点:多角色、点赞评论、借书还书、在线支付)

图书借还管理小程序 一、前言二、我的优势2.1 自己的网站2.2 自己的小程序(小蔡coding)2.3 有保障的售后2.4 福利 三、开发环境与技术3.1 MySQL数据库3.2 Vue前端技术3.3 Spring Boot框架3.4 微信小程序 四、功能设计4.1 主要功能描述 五、系统实现5.1 小…...

【校招VIP】前端计算机网络之UDP相关

考点介绍 UDP是一个简单的面向消息的传输层协议,尽管UDP提供标头和有效负载的完整性验证(通过校验和),但它不保证向上层协议提供消息传递,并且UDP层在发送后不会保留UDP 消息的状态。因此,UDP有时被称为不可…...

前缀和实例4(和可被k整除的子数组)

题目: 给定一个整数数组 nums 和一个整数 k ,返回其中元素之和可被 k 整除的(连续、非空) 子数组 的数目。 子数组 是数组的 连续 部分。 示例 1: 输入:nums [4,5,0,-2,-3,1], k 5 输出:7 …...

Android获取系统读取权限

第一步在Androidifest.xml文件中加上授权语句 <uses-permission android:name"android.permission.WRITE_EXTERNAL_STORAGE"/><uses-permission android:name"android.permission.READ_EXTERNAL_STORAGE"/>并且在Application标签下添加 androi…...

输入学生成绩(最多不超过40),输入为负值时表示输入结束,统计成绩高于平均成绩的学生人数

#include<stdio.h> #define N 40 int scanfscore(int score[N]) {int i -1;do {i;printf("输入学生成绩:");scanf("%d", &score[i]);} while (score[i] > 0);return i; } int average(int score[N], int n) {int j 0;int k 0;double sum …...

【力扣周赛】第 363 场周赛(完全平方数和质因数分解)

文章目录 竞赛链接Q1&#xff1a;100031. 计算 K 置位下标对应元素的和竞赛时代码写法2——手写二进制中1的数量 Q2&#xff1a;100040. 让所有学生保持开心的分组方法数&#xff08;排序后枚举分界&#xff09;竞赛时代码 Q3&#xff1a;100033. 最大合金数&#xff08;二分答…...

RocketMQ的介绍和环境搭建

一、介绍 我也不知道是啥&#xff0c;知道有什么用、怎么用就行了&#xff0c;说到mq&#xff08;MessageQueue&#xff09;就是消息队列&#xff0c;队列是先进先出的一种数据结构&#xff0c;但是RocketMQ不一定是这样&#xff0c;简单的理解一下&#xff0c;就是临时存储的…...

【web开发】7、Django(2)

提示&#xff1a;文章写完后&#xff0c;目录可以自动生成&#xff0c;如何生成可参考右边的帮助文档 文章目录 一、部门列表二、部门管理&#xff08;增删改&#xff09;三、用户管理过渡到modelform组件四、modelform实例&#xff1a;靓号操作五、自定义分页组件六、datepick…...

Prometheus+Grafana可视化监控【Nginx状态】

文章目录 一、安装Docker二、安装Nginx(Docker容器方式)三、安装Prometheus四、安装Grafana五、Pronetheus和Grafana相关联六、安装nginx_exporter七、Grafana添加Nginx监控模板 一、安装Docker 注意&#xff1a;我这里使用之前写好脚本进行安装Docker&#xff0c;如果已经有D…...

R 语言的安装教程

一、下载相关软件 1、R 下载 官网&#xff1a;R: The R Project for Statistical Computing 找到中国镜像&#xff0c;下载快 历史版本点击这里 2、Rtools 下载 进入镜像后&#xff0c;点击这里 然后选择与上面下载的R版本相对应的版本即可 3、Rstudio 下载 官网&#xff1…...

uniapp-提现功能(demo)

页面布局 提现页面 有一个输入框 一个提现按钮 一段提现全部的文字 首先用v-model 和data内的数据双向绑定 输入框逻辑分析 输入框的逻辑 为了符合日常输出 所以要对输入框加一些条件限制 因为是提现 所以对输入的字符做筛选,只允许出现小数点和数字 这里用正则实现的小数点…...

Spring 篇

1、什么是 Spring&#xff1f; Spring是一个轻量级的IOC和AOP容器框架。是为Java应用程序提供基础性服务的一套框架&#xff0c;目的是用于简化企业应用程序的开发&#xff0c;它使得开发者只需要关心业务需求。常见的配置方式有三种&#xff1a;基于XML的配置、基于注解的配置…...

three.js简单3D图形的使用

npm init vitelatest //创建一个vite的脚手架 选择 Vanilla 之后自己处理一下 在main.js中写入 // 导入three.js import * as THREE from three// 创建场景 const scene new THREE.Scene();// 创建相机 const camera new THREE.PerspectiveCamera(45, //视角window.inner…...

spark withColumn的使用(笔记)

目录 前言&#xff1a; spark withColumn的语法及使用&#xff1a; 准备源数据演示&#xff1a; 完整实例代码&#xff1a; 前言&#xff1a; withColumn()&#xff1a;是Apache Spark中用于DataFrame操作的函数之一&#xff0c;它的作用是在DataFrame中添加或替换列&#xff…...

PTA:7-1 线性表的合并

线性表的合并 题目输入样例输出样例 代码解析 题目 输入样例 4 7 5 3 11 3 2 6 3输出样例 7 5 3 11 2 6 代码 #include<iostream> #include<vector> using namespace std;bool checkrep(const vector<int>& arr, int x) {for (int element : arr) {i…...

Spring 的创建和日志框架的整合

目录 一、第一个 Spring 项目 1、配置环境 2、Spring 的 jar 包 Maven 项目导入 jar 包和设置国内源的方法&#xff1a; 3、Spring 的配置文件 4、Spring 的核心 API ApplicationContext 4、程序开发 5、细节分析 &#xff08;1&#xff09;名词解释 &#xff08;2&…...

11-集合和学生管理系统

1.ArrayList 集合和数组的优势对比&#xff1a; 长度可变添加数据的时候不需要考虑索引&#xff0c;默认将数据添加到末尾 1.1 ArrayList类概述 什么是集合 ​ 提供一种存储空间可变的存储模型&#xff0c;存储的数据容量可以发生改变 ArrayList集合的特点 ​ 长度可以变化…...

C语言进阶指针(3) ——qsort的实现

大家好&#xff0c;我们今天来学习回调函数qsort的实现。 首先让我们打开cplusplus.com找到qsort函数。 我们看到这个函数就可以看到它的头文件和参数信息。 #include<stdlib.h> void qsort (void* base, size_t num, size_t size, int (*compar)(const void*,const voi…...

Rust源码分析——Rc 和 Weak 源码详解

Rc 和 Weak 源码详解 一个值需要被多个所有者拥有 rust中所有权机制在图这种数据结构中&#xff0c;一个节点可能被多个其它节点所指向。那么如何表示图这种数据结构&#xff1f;在多线程中&#xff0c;多个线程可能会持有同一个数据&#xff1f;如何解决这个问题。 Rc rus…...

【网络编程】深入理解TCP协议二(连接管理机制、WAIT_TIME、滑动窗口、流量控制、拥塞控制)

TCP协议 1.连接管理机制2.再谈WAIT_TIME状态2.1理解WAIT_TIME状态2.2解决TIME_WAIT状态引起的bind失败的方法2.3监听套接字listen第二个参数介绍 3.滑动窗口3.1介绍3.2丢包情况分析 4.流量控制5.拥塞控制5.1介绍5.2慢启动 6.捎带应答、延时应答 1.连接管理机制 正常情况下&…...

社区团购商城小程序v18.1开源独立版+前端

新增后台清理缓存功能 修复定位权限 修复无法删除手机端管理员 11月新登录接口修复&#xff01; 修复商家付款到零钱&#xff0c; 修复会员登陆不显示头像&#xff0c; 修复无法修改会员开添加绑定...

MATLAB入门-字符串操作

MATLAB入门-字符串操作 注&#xff1a;本篇文章是学习笔记&#xff0c;课程链接是&#xff1a;link MATLAB中的字符串特性&#xff1a; 无论是字符还是字符串&#xff0c;都要使用单引号来‘’表示&#xff1b;在MATLAB中&#xff0c;字符都是在矩阵中存储的&#xff0c;无论…...

Kong Learning

一、Kong Kong是由Mashape公司开源的可扩展的Api GateWay项目。它运行在调用Api之前&#xff0c;以插件的扩展方式为Api提供了管理。比如&#xff0c;鉴权、限流、监控、健康检查等&#xff0c;Kong是基于lua语言、nginx以及openResty开发的&#xff0c;所有拥有动态路由、负载…...

Python怎样写桌面程序

要编写Python桌面应用程序&#xff0c;可以使用以下几种方法&#xff1a; 1.使用Tkinter模块&#xff1a;Tkinter是Python自带的GUI工具包之一&#xff0c;可以使用它来创建基本的GUI界面。例如&#xff0c;可以创建一个简单的窗口&#xff0c;添加按钮、文本框等控件&#xf…...

蓝桥杯2023年第十四届省赛真题-平方差--题解

蓝桥杯2023年第十四届省赛真题-平方差 时间限制: 3s 内存限制: 320MB 提交: 2379 解决: 469 题目描述 给定 L, R&#xff0c;问 L ≤ x ≤ R 中有多少个数 x 满足存在整数 y,z 使得 x y2 − z2。 输入格式 输入一行包含两个整数 L, R&#xff0c;用一个空格分隔。 输出格…...

iText实战--根据绝对位置添加内容

3.1 direct content 概念简介 pdf内容的4个层级 层级1&#xff1a;在text和graphics底下&#xff0c;PdfWriter.getDirectContentUnder() 层级2&#xff1a;graphics层&#xff0c;Chunk, Images背景&#xff0c;PdfPCell的边界等 层级3&#xff1a;text层&#xff0c;Chun…...

使用navicat for mongodb连接mongodb

使用navicat for mongodb连接mongodb 安装navicat for mongodb连接mongodb 安装navicat for mongodb 上文mongodb7.0安装全过程详解我们说过&#xff0c;在安装的时候并没有勾选install mongodb compass 我们使用navicat去进行可视化的数据库管理 navicat for mongodb下载地址…...

Qt ffmpeg音视频转换工具

Qt ffmpeg音视频转换工具&#xff0c;QProcess方式调用ffmpeg&#xff0c;对音视频文件进行格式转换&#xff0c;支持常见的音视频格式&#xff0c;主要在于QProcess的输出处理以及转换的文件名和后缀的处理&#xff0c;可以进一步加上音视频剪切合并和音视频文件属性查询修改的…...

外贸网站建设维护/补习班

http://symphony.b3log.org/article/1381403388981 #正向、反向代理解释一、什么是squid&#xff1f;squid可以做代理也可以做缓存。squid缓存不仅可以节省宝贵的带宽资源, 也可以大大降低服务器的I/O。squid不仅可以做正向代理, 又可以做反向代理。正向代理, squid后面是客…...

wordpress英文网赚站/北京seo公司wyhseo

文章目录写在开头的话ZStack产品与解决方案介绍00. Zstack 愿景0.1 云计算在企业应用中的价值与意义0.2 企业建设云平台过程中的典型问题0.3 Zstack愿景01. 解决方案1.1 ZStack 产品图谱1.2 ZStack云平台1.3 ZStack 云平台应用场景1.4 ZStack Mini一体机1.5 ZStack Mini 适用场…...

网站套利怎么做/中国十大关键词

题库来源&#xff1a;安全生产模拟考试一点通公众号小程序 安全生产模拟考试一点通&#xff1a;P气瓶充装考试题考前必练&#xff01;安全生产模拟考试一点通每个月更新P气瓶充装模拟考试题目及答案&#xff01;多做几遍&#xff0c;其实通过P气瓶充装证考试很简单。 1、【多选…...

青岛网站建设找/网站优化外包推荐

首先确保机器上安装了openssl和openssl-devel #yum install openssl #yum install openssl-devel 然后就是自己颁发证书给自己 #cd /usr/local/nginx/conf #openssl genrsa -des3 -out server.key 1024 #生成私钥&#xff0c;需要输入密码&#xff0c;记住就…...

注册网站云空间/建站系统推荐

1、数据字典怎么理解&#xff1f;数据字典是指对数据的数据项、数据结构、数据流、数据存储、处理逻辑、外部实体等进行定义和描述&#xff0c;其目的是对数据流程图中的各个元素做出详细的说明。数据字典(Data dictionary)是一种用户可以访问的记录数据库和应用程序源数据的目…...

普陀区网站开发/外包公司为什么没人去

今天做图像旋转练习的时候&#xff0c;要根据鼠标的移动轨迹来确定转过的角度&#xff0c;于是就很自然的想到通过三个点来确定这个转过的角度&#xff1a;图像的中心&#xff0c;鼠标按下的点&#xff0c;鼠标拖到的点。想到使用斜率来计算角度&#xff0c;于是联想到数学公式…...