神经网络实战--使用迁移学习完成猫狗分类

前言:Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~
本文目录:
- 一、加载数据集
- 1.调用库函数
- 2.加载数据集
- 3.数据集管理
- 二、猫狗数据集介绍
- 1.猫狗数据集介绍:
- 2.图片展示
- 三、MobileNetV2网络介绍
- 1.加载tensorflow提供的预训练模型
- 2.轻量级网络——MobileNetV2
- 3.MobileNetV2的网络模块
- 四、搭建迁移学习
- 1.训练
- 2.训练结果可视化
- 3.输出训练的准确率
- 4.用cnn工具可视化一批数据的预测结果
- 5.数据输出
- 6.用cnn工具可视化一个数据样本的各层输出
- 7.输出结果图像
- 五、源码获取
说明:在此试验下,我们使用的是使用tf2.x版本,在jupyter环境下完成
在本文中,我们将主要完成以下任务:
-
实现基于tensorflow和keras的迁移学习
-
加载tensorflow提供的数据集(不得使用cifar10)
-
需要使用markdown单元格对数据集进行说明
-
加载tensorflow提供的预训练模型(不得使用vgg16)
-
需要使用markdown单元格对原始模型进行说明
-
网络末端连接任意结构的输出端网络
-
用图表显示准确率和损失函数
-
用cnn工具可视化一批数据的预测结果
-
用cnn工具可视化一个数据样本的各层输出
一、加载数据集
1.调用库函数
import matplotlib.pyplot as plt
import numpy as np
import os
import tensorflow as tf
import cnn_utils
from tensorflow.keras.preprocessing import image_dataset_from_directory
from tensorflow.keras.layers import GlobalAveragePooling2D,Dense,Input,Dropout
2.加载数据集
数据集加载,数据是通过这个网站下载的猫狗数据集:http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip,实验中为了训练方便,我们取了一个较小的数据集。
path_to_zip = tf.keras.utils.get_file('data.zip',origin='http://aimaksen.bslience.cn/cats_and_dogs_filtered.zip',extract=True,
)
PATH = os.path.join(os.path.dirname(path_to_zip), 'cats_and_dogs_filtered')train_dir = os.path.join(PATH, 'train')
validation_dir = os.path.join(PATH, 'validation')BATCH_SIZE = 32
IMG_SIZE = (160, 160)
3.数据集管理
使用image_dataset_from_director进行数据集管理,使用ImageDataGenerator训练过程中会出现错误,不知道是什么原因,就使用了原始的image_dataset_from_director方法进行数据集管理。
train_dataset = image_dataset_from_directory(train_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)validation_dataset = image_dataset_from_directory(validation_dir,shuffle=True,batch_size=BATCH_SIZE,image_size=IMG_SIZE)
二、猫狗数据集介绍
1.猫狗数据集介绍:
猫狗数据集包括25000张训练图片,12500张测试图片,包括猫和狗两种图片。在此次实验中为了训练方便,我们取了一个较小的数据集。 数据解压之后会有两个文件夹,一个是 “train”,一个是 “test”,顾名思义一个是用来训练的,另一个是作为检验正确性的数据。

在train文件夹里边是一些已经命名好的图像,有猫也有狗。而在test文件夹中是只有编号名的图像。

2.图片展示
下面是数据集中的图片展示:
class_names = ['cats', 'dogs']plt.figure(figsize=(10, 10))
for images, labels in train_dataset.take(1):for i in range(9):ax = plt.subplot(3, 3, i + 1)plt.imshow(images[i].numpy().astype("uint8"))plt.title(class_names[labels[i]])plt.axis("off")
🌟🌟🌟 这里是输出的结果:✨✨✨

三、MobileNetV2网络介绍
1.加载tensorflow提供的预训练模型
val_batches = tf.data.experimental.cardinality(validation_dataset)
test_dataset = validation_dataset.take(val_batches // 5)
validation_dataset = validation_dataset.skip(val_batches // 5)
2.轻量级网络——MobileNetV2
使用轻量级网络——MobileNetV2进行数据预处理 说明: MobileNetV2是基于倒置的残差结构,普通的残差结构是先经过 1x1 的卷积核把 feature map的通道数压下来,然后经过 3x3 的卷积核,最后再用 1x1 的卷积核将通道数扩张回去,即先压缩后扩张,而MobileNetV2的倒置残差结构是先扩张后压缩。

3.MobileNetV2的网络模块
MobileNetV2的网络模块样子是这样的:

MobileNetV2是基于深度级可分离卷积构建的网络,它是将标准卷积拆分为了两个操作:深度卷积 和 逐点卷积,深度卷积和标准卷积不同,对于标准卷积其卷积核是用在所有的输入通道上,而深度卷积针对每个输入通道采用不同的卷积核,就是说一个卷积核对应一个输入通道,所以说深度卷积是depth级别的操作。而逐点卷积其实就是普通的卷积,只不过其采用1x1的卷积核。
MobileNetV2的模型如下图所示,其中t为Bottleneck内部升维的倍数,c为通道数,n为该bottleneck重复的次数,s为sride:

其中,当stride=1时,才会使用elementwise 的sum将输入和输出特征连接(如下图左侧);stride=2时,无short cut连接输入和输出特征(下图右侧):

四、搭建迁移学习
1.训练
inital_input = tf.keras.applications.mobilenet_v2.preprocess_input
IMG_SHAPE = IMG_SIZE + (3,)
base_model = tf.keras.applications.MobileNetV2(input_shape=IMG_SHAPE,include_top=False,weights='imagenet')
base_model.trainable = False
base_model.summary()
🌟🌟🌟 这里是输出的结果:✨✨✨

2.训练结果可视化
用图表显示准确率和损失函数
# 训练结果可视化,用图表显示准确率和损失函数
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']loss = history.history['loss']
val_loss = history.history['val_loss']epochs_range=range(initial_epochs)
plt.figure(figsize=(8,8))
plt.subplot(2,1,1)
plt.plot(epochs_range, acc, label="Training Accuracy")
plt.plot(epochs_range, val_acc,label="Validation Accuracy")
plt.legend()
plt.title("Training and Validation Accuracy")plt.subplot(2,1,2)
plt.plot(epochs_range, loss, label="Training Loss")
plt.plot(epochs_range, val_loss,label="Validation Loss")
plt.legend()
plt.title("Training and Validation Loss")
plt.show()
🌟🌟🌟 这里是输出的结果:✨✨✨

3.输出训练的准确率
# 输出训练的准确率
test_loss, test_accuracy = model.evaluate(test_dataset)
print('test accuracy: {:.2f}'.format(test_accuracy))
🌟🌟🌟 这里是输出的结果:✨✨✨

4.用cnn工具可视化一批数据的预测结果
label_dict = {0: 'cat',1: 'dog'
}test_image_batch, test_label_batch = test_dataset.as_numpy_iterator().next()
# 编码成uint8 以图片形式输出
test_image_batch = test_image_batch.astype('uint8')cnn_utils.plot_predictions(model, test_image_batch, test_label_batch, label_dict, 32, 5, 5)
🌟🌟🌟 这里是输出的结果:✨✨✨

5.数据输出
# 数据输出,数字化特征图
test_image_batch, test_label_batch = train_dataset.as_numpy_iterator().next()img_idx = 0
random_batch = np.random.permutation(np.arange(0,len(test_image_batch)))[:BATCH_SIZE]
image_activation = test_image_batch[random_batch[img_idx]:random_batch[img_idx]+1]cnn_utils.get_activations(base_model, image_activation[0])
🌟🌟🌟 这里是输出的结果:✨✨✨

6.用cnn工具可视化一个数据样本的各层输出
cnn_utils.display_activations(cnn_utils.get_activations(base_model, image_activation[0]))
🌟🌟🌟 这里是输出的结果:✨✨✨

7.输出结果图像
🌟🌟🌟 这里是输出的结果:✨✨✨



五、源码获取
关注此公众号:人生苦短我用Pythons,回复 神经网络源码获取源码,快点击我吧
🌲🌲🌲 好啦,这就是今天要分享给大家的全部内容了,我们下期再见!
❤️❤️❤️如果你喜欢的话,就不要吝惜你的一键三连了~


最后,有任何问题,欢迎关注下面的公众号,获取第一时间消息、作者联系方式及每周抽奖等多重好礼! ↓↓↓
相关文章:
神经网络实战--使用迁移学习完成猫狗分类
前言: Hello大家好,我是Dream。 今天来学习一下如何使用基于tensorflow和keras的迁移学习完成猫狗分类,欢迎大家一起前来探讨学习~ 本文目录:一、加载数据集1.调用库函数2.加载数据集3.数据集管理二、猫狗数据集介绍1.猫狗数据集介…...
Attention机制 学习笔记
学习自https://easyai.tech/ai-definition/attention/ Attention本质 Attention(注意力)机制如果浅层的理解,跟他的名字非常匹配。他的核心逻辑就是“从关注全部到关注重点”。 比如我们人在看图片时,对图片的不同地方的注意力…...
数据类型与运算符
1.字符型作用: 字符型变量用于显示单个字符语法: char cc a ;注意1: 在显示字符型变量时,用单引号将字符括起来,不要用双引号注意2: 单引号内只能有一个字符,不可以是字符串C和C中字符型变量只占用1个字节。字符型变是并不是把字符本身放到内存中存储&am…...
算法刷题-二叉树的锯齿形层序遍历、用栈实现队列 栈设计、买卖股票的最佳时机 IV
文章目录二叉树的锯齿形层序遍历(树、广度优先搜索)用栈实现队列(栈、设计)买卖股票的最佳时机 IV(数组、动态规划)二叉树的锯齿形层序遍历(树、广度优先搜索) 给定一个二叉树&…...
华为OD机试 - 最小传递延迟(Python)| 代码编写思路+核心知识点
最小传递延迟 题目 通讯网络中有 N 个网络节点 用 1 ~ N 进行标识 网络通过一个有向无环图进行表示 其中图的边的值,表示节点之间的消息传递延迟 现给定相连节点之间的延时列表 times[i]={u,v,w} 其中 u 表示源节点,v 表示目的节点,w 表示 u 和 v 之间的消息传递延时 请计…...
集中供热调度系统天然气仪表内网仪表图像识别案例
一、项目需求 出于能耗采集与冬季集中供暖工作的节能和能耗分析需要,要采集现场的6块天然气表计,并存储进入客户的mySQL数据库中,现场采集的表计不允许接线,且网络环境为内网环境,需要采集表计数据并存入数据库&#…...
笔试题-2023-复旦微-数字IC设计【纯净题目版】
回到首页:2023 数字IC设计秋招复盘——数十家公司笔试题、面试实录 推荐内容:数字IC设计学习比较实用的资料推荐 题目背景 笔试时间:2022.07.26应聘岗位:数字前端工程师笔试时长:120min笔试平台:赛码题目类型:基础题(10道)、选做题(10道)、验证题(5道)主观评价 难…...
【Linux】冯诺依曼体系结构和操作系统概念
文章目录🎪 冯诺依曼体系结构🚀1.体系概述🚀2.CPU和内存的数据交换🚀3.体系结构中数据的流动🎪 操作系统概念理解🚀1.简述🚀2.设计目的🚀3.定位🚀4.理解🚀5.管…...
HTML5之HTML基础学习笔记
列表标签 列表的应用场景 场景:在网页中按照行展示关联性的内容,如:新闻列表、排行榜、账单等特点:按照行的方式,整齐显示内容种类:无序列表、有序列表、自定义列表 这是老师PPT上的内容, 列表…...
FreeRTOS信号量 | FreeRTOS十
目录 说明: 一、信号量 1.1、信号量简介 1.2、信号量特点 二、二值信号量 2.1、二值信号量简介 2.2、获取与释放二值信号量函数 2.3、二值信号量使用过程与相关API函数 2.4、创建二值信号量函数了解 2.5、释放二值信号量了解 2.6、获取二值信号量了解 三…...
【SpringBoot】SpringBoot常用注解
一、前言首先这里说的SpringBoot常用注解是指在我们开发项目过程中,我们经常使用的注解,包含Spring、SpringBoot、SpringCloud、SpringMVC等这些框架中的注解,而不仅仅是SpringBoot中的注解。这里只是作一个注解列举,每个注解具体…...
数据一致性
目录一、AOP 动态代理切入方法(1) Aspect Oriented Programming(2) 切入点表达式二、SpringBoot 项目扫描类(1) ResourceLoader 扫描类(2) Map 的 computeIfAbsent 方法(3) 反射几个常用 api① 创建一个测试注解② 创建测试 PO 类③ 反射 api 获取指定类的指定注解信息(4) 返回…...
Docker不做虚拟化内核,对.NET有什么影响?
引子前两天刷抖音,看见了这样一个问题。问题:容器化不做虚拟内核,会有什么弊端?Java很多方法会跟CPU的核数有关,这个时候调用系统函数,读到的是宿主机信息,而不是我们限制资源的大小。思考&…...
HTML总结
CSS代码风格 空格规范: 1. 属性值前面,冒号后面,保留一个空格; 2. 选择器(标签)和大括号中间保留空格。 基本语法概述: 1.HTML标签是由尖括号包围的关键词,如<html> 2.HTM…...
ByteHouse:基于ClickHouse的实时数仓能力升级解读
更多技术交流、求职机会,欢迎关注字节跳动数据平台微信公众号,回复【1】进入官方交流群 ByteHouse是火山引擎上的一款云原生数据仓库,为用户带来极速分析体验,能够支撑实时数据分析和海量数据离线分析。便捷的弹性扩缩容能力&…...
[SSD固态硬盘技术 15] FTL映射表的神秘面纱
为什么需要映射表?固态硬盘的存储器件采用的是闪存[5],具有以下几个特点: (1)读写基本单位是以页(Page)为单位,擦除是以块(Block)为单位。...
浅析依赖注入框架的生命周期(以 InversifyJS 为例)
在上一篇介绍了 VSCode 的依赖注入设计,并且实现了一个简单的 IOC 框架。但是距离成为一个生产环境可用的框架还差的很远。 行业内已经有许多非常优秀的开源 IOC 框架,它们划分了更为清晰地模块来应对复杂情况下依赖注入运行的正确性。 这里我将以 Inv…...
HER2靶向药物研发进展-销售数据-上市药品前景分析
HER2长期作为肿瘤领域的热门靶点之一,其原因是它在多部位、多种形式的癌症中均有异常的表达,据研究表明HER2除了在胃癌、胆道癌、胆管癌、乳腺癌、卵巢癌、结肠癌、膀胱癌、肺癌、子宫颈癌、子宫浆液性子宫内膜癌、头颈癌、食道癌中的异常表达还存在于多…...
【第38天】不同路径数问题 | 网格 dp 入门
本文已收录于专栏🌸《Java入门一百例》🌸学习指引序、专栏前言一、网格模型二、【例题1】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、【例题2】1、题目描述2、解题思路3、模板代码4、代码解析5.原题链接三、推荐专栏四、课后习题序、专…...
LINUX之链接命令
链接命令学习目标能够说出软链接的创建方式能够说出硬链接的创建方式1. 链接命令的介绍链接命令是创建链接文件,链接文件分为:软链接硬链接命令说明ln -s创建软链接ln创建硬链接2. 软链接类似于Windows下的快捷方式,当一个源文件的目录层级比较深&#x…...
利用ngx_stream_return_module构建简易 TCP/UDP 响应网关
一、模块概述 ngx_stream_return_module 提供了一个极简的指令: return <value>;在收到客户端连接后,立即将 <value> 写回并关闭连接。<value> 支持内嵌文本和内置变量(如 $time_iso8601、$remote_addr 等)&a…...
在 Nginx Stream 层“改写”MQTT ngx_stream_mqtt_filter_module
1、为什么要修改 CONNECT 报文? 多租户隔离:自动为接入设备追加租户前缀,后端按 ClientID 拆分队列。零代码鉴权:将入站用户名替换为 OAuth Access-Token,后端 Broker 统一校验。灰度发布:根据 IP/地理位写…...
Python实现prophet 理论及参数优化
文章目录 Prophet理论及模型参数介绍Python代码完整实现prophet 添加外部数据进行模型优化 之前初步学习prophet的时候,写过一篇简单实现,后期随着对该模型的深入研究,本次记录涉及到prophet 的公式以及参数调优,从公式可以更直观…...
SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现
摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...
零基础设计模式——行为型模式 - 责任链模式
第四部分:行为型模式 - 责任链模式 (Chain of Responsibility Pattern) 欢迎来到行为型模式的学习!行为型模式关注对象之间的职责分配、算法封装和对象间的交互。我们将学习的第一个行为型模式是责任链模式。 核心思想:使多个对象都有机会处…...
css3笔记 (1) 自用
outline: none 用于移除元素获得焦点时默认的轮廓线 broder:0 用于移除边框 font-size:0 用于设置字体不显示 list-style: none 消除<li> 标签默认样式 margin: xx auto 版心居中 width:100% 通栏 vertical-align 作用于行内元素 / 表格单元格ÿ…...
项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)
Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败,具体原因是客户端发送了密码认证请求,但Redis服务器未设置密码 1.为Redis设置密码(匹配客户端配置) 步骤: 1).修…...
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材)
推荐 github 项目:GeminiImageApp(图片生成方向,可以做一定的素材) 这个项目能干嘛? 使用 gemini 2.0 的 api 和 google 其他的 api 来做衍生处理 简化和优化了文生图和图生图的行为(我的最主要) 并且有一些目标检测和切割(我用不到) 视频和 imagefx 因为没 a…...
华为OD机考-机房布局
import java.util.*;public class DemoTest5 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseSystem.out.println(solve(in.nextLine()));}}priv…...
R 语言科研绘图第 55 期 --- 网络图-聚类
在发表科研论文的过程中,科研绘图是必不可少的,一张好看的图形会是文章很大的加分项。 为了便于使用,本系列文章介绍的所有绘图都已收录到了 sciRplot 项目中,获取方式: R 语言科研绘图模板 --- sciRplothttps://mp.…...
