TensorFlow2实战-系列教程6:迁移学习实战
🧡💛💚TensorFlow2实战-系列教程 总目录
有任何问题欢迎在下面留言
本篇文章的代码运行界面均在Jupyter Notebook中进行
本篇文章配套的代码资源已经上传
1、迁移学习
- 用已经训练好模型的权重参数当做自己任务的模型权重初始化
- 一般全连接层需要自己训练,可以选择是否训练已经训练好的特征提取层
一般情况下根据自己的任务,选择对那些网络进行微调和重新训练:
如果预训练模型的任务和自己任务非常接近,那可能只需要把最后的全连接层重新训练即可
如果自己任务的数据量比较小,那么应该选择重新训练少数层
如果自己任务的数据量比较大,可以适当多选择几层进行训练
2、猫狗识别
import os
import warnings
warnings.filterwarnings("ignore")
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras import layers
from tensorflow.keras import Model
base_dir = './data/cats_and_dogs'
train_dir = os.path.join(base_dir, 'train')
validation_dir = os.path.join(base_dir, 'validation')train_cats_dir = os.path.join(train_dir, 'cats')
train_dogs_dir = os.path.join(train_dir, 'dogs')validation_cats_dir = os.path.join(validation_dir, 'cats')
validation_dogs_dir = os.path.join(validation_dir, 'dogs')
前面的内容和TensorFlow2实战-系列教程3:猫狗识别1完全一样
3、加载预训练模型
from tf.keras.applications.resnet import ResNet50
from tensorflow.keras.applications.resnet import ResNet101
from tensorflow.keras.applications.inception_v3 import InceptionV3
从keras中导入预训练模型,在TensorFlow的keras模块,有很多可以直接导入的预训练权重。
pre_trained_model = ResNet101(input_shape = (75, 75, 3), include_top = False, weights = 'imagenet')
- 加载导入的模型
- input_shape 为输入大小
- include_top为False就是表示不要最后的全连接层
- 这段代码执行后,会自动进行下载
downloading data from
https://storage.googleapis.com/tensorflow/kerasapplications/resnet/resnet101_weights_tf_dim_ordering_tf_kernels_notop.h5
171446536/171446536 [==============================] - 15s 0us/step
for layer in pre_trained_model.layers:layer.trainable = False
选择要进行重新训练的层
4、callback模块
在 TensorFlow 中,回调(Callbacks)是一个强大的工具,用于在训练的不同阶段(例如在每个时代的开始和结束、在每个批次的处理前后)自定义和控制模型的行为,相当于一个监视器:
4.1 callback示例
callbacks = [
# 如果连续两个epoch还没降低就停止:tf.keras.callbacks.EarlyStopping(patience=2, monitor='val_loss'),
# 可以动态改变学习率:tf.keras.callbacks.LearningRateScheduler
# 保存模型:tf.keras.callbacks.ModelCheckpoint
# 自定义方法:tf.keras.callbacks.Callback
]
上面是一个模板,继续我们的猫狗识别的迁移学习项目:
4.2 定义callback
class myCallback(tf.keras.callbacks.Callback):def on_epoch_end(self, epoch, logs={}):if(logs.get('acc')>0.95):print("\nReached 95% accuracy so cancelling training!")self.model.stop_training = True
- 定义一个类,继承Callback
- 定义一个函数,传入epoch值和日志
- 从当前epoch的日志中取出准确率,如果准确率大于95%
- 打印信息
- 停止训练
from tensorflow.keras.optimizers import Adam
x = layers.Flatten()(pre_trained_model.output)
x = layers.Dense(1024, activation='relu')(x)
x = layers.Dropout(0.2)(x)
x = layers.Dense(1, activation='sigmoid')(x)
model = Model(pre_trained_model.input, x)
model.compile(optimizer = Adam(lr=0.001), loss = 'binary_crossentropy', metrics = ['acc'])
- 导入优化器
- 将预训练模型的输出展平为一维
- 定义一个1024的全连接层
- 在这层加入dropout
- 输出全连接层
- 构建模型
- 指定优化器、损失函数、验证方法等配置训练器
5、模型训练
定义需要重新训练的层
train_datagen = ImageDataGenerator(rescale = 1./255.,rotation_range = 40,width_shift_range = 0.2,height_shift_range = 0.2,shear_range = 0.2,zoom_range = 0.2,horizontal_flip = True)test_datagen = ImageDataGenerator( rescale = 1.0/255. )train_generator = train_datagen.flow_from_directory(train_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75)) validation_generator = test_datagen.flow_from_directory( validation_dir,batch_size = 20,class_mode = 'binary', target_size = (75, 75))
前面的内容和TensorFlow2实战-系列教程3:猫狗识别1一样,制作数据
callbacks = myCallback()
history = model.fit_generator(train_generator,validation_data = validation_generator,steps_per_epoch = 100,epochs = 100,validation_steps = 50,verbose = 2,callbacks=[callbacks])
指定训练参数、数据、加入callback模块到模型中,执行训练,verbose = 2表示每次epoch记录一次日志
打印结果:
Epoch 99/100 100/100 - 76s - loss: 0.6138 - acc: 0.6655 - val_loss: 0.6570 - val_acc: 0.6900
Epoch 100/100 100/100 - 76s - loss: 0.5993 - acc: 0.6735 - val_loss: 0.7176 - val_acc: 0.6910
6、预测效果展示
import matplotlib.pyplot as plt
acc = history.history['acc']
val_acc = history.history['val_acc']
loss = history.history['loss']
val_loss = history.history['val_loss']epochs = range(len(acc))plt.plot(epochs, acc, 'b', label='Training accuracy')
plt.plot(epochs, val_acc, 'r', label='Validation accuracy')
plt.title('Training and validation accuracy')
plt.legend()plt.figure()plt.plot(epochs, loss, 'b', label='Training Loss')
plt.plot(epochs, val_loss, 'r', label='Validation Loss')
plt.title('Training and validation loss')
plt.legend()
plt.show()
相关文章:

TensorFlow2实战-系列教程6:迁移学习实战
🧡💛💚TensorFlow2实战-系列教程 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在Jupyter Notebook中进行 本篇文章配套的代码资源已经上传 1、迁移学习 用已经训练好模型的权重参数当做自己任务的模型权重初始化一般全连接层需…...

怎样开发adobe indesign插件,具体流程?
文章目录 第一.流程步骤第二.如何调试indesign插件第三.相关资源第四.总结 第一.流程步骤 开发Adobe InDesign插件通常涉及以下步骤: 获取SDK和工具: 从Adobe官方网站下载最新的Adobe InDesign SDK(Software Development Kit)&am…...

Docker 安装与基本操作
目录 一、Docker 概述 1、Docker 简述 2、Docker 的优势 3、Docker与虚拟机的区别 4、Docker 的核心概念 1)镜像 2)容器 3)仓库 二、Docker 安装 1、命令: 2、实操: 三、Docker 镜像操作 1、命令࿱…...

译文带你理解Python的dataclass装饰器
dataclass 是 Python dataclasses 模块中的一个 decorator。当使用 dataclass 装饰器时,它会自动生成一些特殊方法,包括: _ _ init _ _:用于初始化字段的构造函数_ _ repr _ _:对象的字符串表示_ _ eq _ _:…...

【C语言】实现程序的暂停
编写程序时,有时候需要让程序在某些地方暂停执行,等待用户输入或者观察程序执行结果。在 C 语言中,有多种方法可以实现程序的暂停,包括 system("pause")、getchar() 和 while ((c getchar()) ! \n && c ! EOF)…...

Hana SQL+正则表达式
目录 一、Pre 前言 二、知识点拆解 1)case when…then…else 2)json_value 函数 拓展资料 3)CAST 函数 拓展资料 4) ROUND 函数 5)occurences_regexpr 函数 拓展资料 6)正则表达式 拓展资料 三、整合分析…...

【笔记】顺利通过EMC试验(16-41)-视频笔记
目录 视频链接 P1:电子设备中有哪些主要骚扰源 P2:怎样减小DC模块的骚扰 P3:PCB上的辐射源究竟在哪里 P4:怎样控制PCB板的电磁辐射 P5:多层线路板是解决电磁兼容问题的简单方法 P6:怎样处理地线上的裂缝 P7:怎样降低时钟信号的辐射 P8:为什么IO接口的处理特别重要 P9…...

Qlik Sense 调用NPrinting生成On-Demand报表
安装 Qlik Sense On-Demand 报表控件 On-Demand 报表控件添加按钮,该按钮按需生成 Qlik NPrinting 报表。它包括在 Dashboard bundle 中。 当您希望用户能够使用应用程序中的选择作为过滤器在 Qlik Sense 中打印预定义 Qlik NPrinting 报表时,On-Deman…...

ElasticSearch重建/创建/删除索引操作 - 第501篇
历史文章(文章累计500) 《国内最全的Spring Boot系列之一》 《国内最全的Spring Boot系列之二》 《国内最全的Spring Boot系列之三》 《国内最全的Spring Boot系列之四》 《国内最全的Spring Boot系列之五》 《国内最全的Spring Boot系列之六》 E…...

数据写入HBase(scala)
package sourceimport org.apache.hadoop.hbase.{HBaseConfiguration, TableName} import org.apache.hadoop.hbase.client.{ConnectionFactory, Put} import org.apache.hadoop.hbase.util.Bytesobject ffff {def main(args: Array[String]): Unit {//hbase连接配置val conf …...
Codeforces Round 799 (Div. 4)
目录 A. Marathon B. All Distinct C. Where’s the Bishop? D. The Clock E. Binary Deque F. 3SUM G. 2^Sort H. Gambling A. Marathon 直接模拟 void solve() {int ans0;for(int i1;i<4;i) {cin>>a[i];if(i>1&&a[i]>a[1]) ans;}cout<&l…...

为什么要用云手机养tiktok账号
在拓展海外电商市场的过程中,许多用户选择采用tiktok短视频平台引流的策略,以提升在电商平台上的流量,吸引更多消费者。而要进行tiktok引流,养号是必不可少的一个环节。tiktok云手机成为实现国内跨境养号的一种有效方式࿰…...

vue pc端网页实现自适应
一、基本原理 pc端做自适应可以用rem来实现,啥是rem,自己百度 二、新建rem.ts文件 // rem等比适配配置文件 // 基准大小 const baseSize 14 // 设置 rem 函数 function setRem () {// 当前页面宽度相对于 1920宽的缩放比例,可根据自己需要…...

Android 13以上版本读写SD卡权限适配
如题,最近工作上处理的问题,把解决方案简单逻列出来,供有需要的朋友参考之 解决方案: 1、配置权限 <uses-permission android:name"android.permission.READ_MEDIA_IMAGES" /><uses-permission android:name&q…...

并查集模板:食物链详解
import java.io.BufferedReader; import java.io.IOException; import java.io.InputStreamReader;public class Main {static int N 50010;static int n,m; //n个动物,m局判断static int[] p new int[N]; //p[i]是i的根节点static int[] d new int[N]; //d[i]表示i到…...

使用WAF防御网络上的隐蔽威胁之反序列化攻击
什么是反序列化 反序列化是将数据结构或对象状态从某种格式转换回对象的过程。这种格式通常是二进制流或者字符串(如JSON、XML),它是对象序列化(即对象转换为可存储或可传输格式)的逆过程。 反序列化的安全风险 反…...

05. 交换机的基本配置
文章目录 一. 初识交换机1.1. 交换机的概述1.2. Ethernet_ll格式1.3. MAC分类1.4. 冲突域1.5. 广播域1.6. 交换机的原理1.7. 交换机的3种转发行为 二. 初识ARP2.1. ARP概述2.2. ARP报文格式2.3. ARP的分类2.4. 免费ARP的作用 三. 实验专题3.1. 实验1:交换机的基本原…...

yolo将标签数据打到原图上形成目标框
第一章 目标:为了查看自己在标注标签时是否准确,写了这段代码来将标注的框打到原图上 第二章 步骤:进行反归一化得到坐标画出矩形框 第二行是目标图片对应的txt,第三行是目标图片 第三章 全部代码如下: import cv2 import …...

002-00-02【大红ai源码】dolphinscheduler3.2.0 源码环境搭建------by孤山村头王大爷家女儿大红
【ai阅读源码-dolphinscheduler】 DolphinScheduler 开发手册1、软件要求2、克隆代码库3、编译打包4、代码风格5、新建数据库,导入元数据。6, 启动后端6.1 启动api-server 6.2 启动master-server6.3 启动worker-server 7 启动前端 DolphinScheduler 开发…...

python-自动化篇-运维-监控-如何使⽤Python处理和解析⽇志⽂件?-实操记录
文章目录 1. 选择日志文件格式: 确定要处理的日志文件的格式。不同的日志文件可能具有不同的格式,如文本日志、CSV、JSON、XML等。了解日志文件的格式对解析⾮常重要。2. 打开日志文件: 使⽤Python的文件操作功能打开日志文件,以便…...

代码随想录算法训练营DAY6 | 哈希表(1)
DAY5休息一天,今天重启~ 哈希表理论基础:代码随想录 Java hash实现 :java 哈希表-CSDN博客 一、LeetCode 242 有效的字母异位词 题目链接:242.有效的字母异位词 思路:设置字典 class Solution {public boolean isAnag…...

【嵌入式学习】C++QT-Day3-C++基础
笔记 见我的博客:https://lingjun.life/wiki/EmbeddedNote/19Cpp 作业 设计一个Per类,类中包含私有成员:姓名、年龄、指针成员身高、体重,再设计一个Stu类,类中包含私有成员:成绩、Per类对象p1,设计这两个类的构造函…...

表贴式PMSM的直接转矩控制(DTC)MATLAB仿真模型
微❤关注“电气仔推送”获得资料(专享优惠) 模型简介 表贴式PMSM的直接转矩控制(DTC),直接使用滞环控制对转矩和磁链进行控制,相对于传统的FOC控制而言,其不需要进行解耦变换,在此次的有以下几点需要注意:…...

详解OpenHarmony各部分文件在XR806上的编译顺序
大家好,今天我们来谈一谈编程时一个很有趣的话题——编译顺序。我知道,一提到编译可能大家会感到有点儿头疼,但请放心,我不会让大家头疼的。我们要明白,在开始写代码之前,了解整个程序的编译路径是十分有必…...

【美团】无人机-大数据开发工程师
更新时间:2024/01/29 工作地点:北京市 事业群:到家事业群 工作经验:3年 部门介绍 为了更好地提升城市即时配送的效率与体验,美团于2017年启动了无人机配送服务的探索,通过科技创新推动履约工具变革&#x…...

微服务系统设计:横向扩展和纵向扩展的对比
微服务扩展性:水平扩展 vs 垂直扩展 特点水平扩展垂直扩展扩展单位增加微服务实例增加单个实例的资源 (CPU,内存)方向向外,增加节点向上,增加单个节点的资源复杂性随着实例数量的增加,管理难度更大管理更简单…...

Java基于SpringBoot+Vue的网上超市管理系统
博主介绍:✌程序员徐师兄、7年大厂程序员经历。全网粉丝30W、csdn博客专家、掘金/华为云/阿里云/InfoQ等平台优质作者、专注于Java技术领域和毕业项目实战✌ 🍅文末获取源码联系🍅 👇🏻 精彩专栏推荐订阅👇…...

HTTP中POST、GET、PUT、DELETE方式的区别
GET请求会向数据库发索取数据的请求,从而来获取信息,该请求就像数据库的select操作一样,只是用来查询一下数据,不会修改、增加数据,不会影响资源的内容,即该请求不会产生副作用。无论进行多少次操作&#x…...

77.Go中interface{}判nil的正确姿势
文章目录 一:interface{}简介二、interface{}判空三:注意点四:实际案例 一:interface{}简介 在go中的nil只能赋值给指针、channel、func、interface、map或slice类型的变量 interface 是否根据是否包含有 method,底层…...

ES实战回顾
1、你用的集群节点情况? 一个ES集群,18个节点,其中3个主节点,15个数据节点,500G左右的索引数据量,没有单独的协调节点,它的每个节点都可以充当协调功能; 2、你们常用的索引有哪些&a…...