当前位置: 首页 > news >正文

加载预训练模型,模型微调,在自己的数据集上快速出效果

  • 针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune)。 PyTorch里面提供的经典的网络模型都是官方通过Imagenet的数据集与训练好的数据,如果我们的数据训练数据不够,这些数据是可以作为基础模型来使用的。(Fine tuning 模型微调)

  • Fine tuning 模型微调的好处

    • 对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠微调已经训练好的模型。

    • 可以降低训练成本:如果使用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用 CPU 都完全无压力,没有深度学习机器也可以做。

    • 前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的模型要强悍,没有必要重复造轮子

  • 迁移学习初衷是节省人工标注样本的时间,让模型可以通过一个已有的标记数据的领域向未标记数据领域进行迁移从而训练出适用于该领域的模型,直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽快地学习新知识。把统一的概念抽象出来,只学习不同的内容。迁移学习按照学习方式可以分为基于样本的迁移,基于特征的迁移,基于模型的迁移,以及基于关系的迁移。

  • 微调应该是迁移学习中的一部分。微调只能说是一个trick,一种技术;迁移学习是一个更宏大的概念

  • Pytorch模型保存、加载与预训练

  • 保存和加载整个模型和参数:这种方式会保存整个模型的结构以及参数,会占用较大的磁盘空间, 通常不采用这种方式

  • torch.save(model, 'model.pkl')  #保存
    model = torch.load('model.pkl') # 加载
    
  • 保存和加载模型的参数, 优点是速度快,占用的磁盘空间少, 是最常用的模型保存方法。load_state_dict有一个strict参数,该参数默认是True, 表示预训练模型的网络结构与自定义的网络结构严格相同(包括名字和维度)。 如果自定义网络和预训练网络不严格相同时, 需要将不属于自定义网络的key去掉

  • torch.save(model.state_dict(), 'model_state_dict.pkl')
    model = model.load_state_dict(torch.load(model_state_dict.pkl))
    
  • 在实际场景中, 我们往往需要保存更多的信息,如优化器的参数, 那么可以通过字典的方式进行存储

  • # 保存
    torch.save({'epoch': epochId,'state_dict': model.state_dict,'best_acc': best_acc,'optimizer': optimizer.state_dict()}, checkpoint_path + "/m-" + timestamp + str("%.4f" % best_acc) + ".pth.tar")
    # 加载
    def load_model(model, checkpoint, optimizer):model_CKPT = torch.load(checkpoint)model.load_state_dict(model_CKPT['state_dict'])optimizer.load_state_dict(model_CKPT['optimizer'])return model, optimizer
    
  • 加载部分预训练模型: 如果我们修改了网络, 那么就需要将这部分参数过滤掉:(值得注意的是,当两个网络的结构相同, 但是结构的命名不同时, 直接加载会报错。因此需要修改结构的key值)

  • def load_model(model, chinkpoint, optimizer):model_CKPT = torch.load(checkpoint)model_dict = model.state_dict()pretrained_dict = model_CKPT['state_dict']# 将不在model中的参数过滤掉new_dict = {k, v for k, v in pretrained_dict.items() if k in model_dict.keys()}model_dict.update(new_dict)model.load_state_dict(model_dict)# 加载优化器参数optimizer.load_state_dict(model_CKPT['optimizer'])return model, optimizer
    
  • 冻结网络的部分参数, 训练另一部分参数(注意,必须同时在优化器中将这些参数过滤掉, 否则会报错。因为optimizer里面的参数要求required_grad为Ture)

    • 当输入给模型的数据集形式相似或者相同时,常见的是利用现有的经典模型(如Residual Network、 GoogleNet等)作为backbone来提取特征,那么这些经典模型已经训练好的模型参数可以直接拿过来使用。通常情况下, 我们希望将这些经典网络模型的参数固定下来, 不进行训练,只训练后面我们添加的和具体任务相关的网络参数。

      • 新数据集和原始数据集合类似,那么直接可以微调一个最后的FC层或者重新指定一个新的分类器

      • 新数据集比较小和原始数据集合差异性比较大,那么可以使用从模型的中部开始训练,只对最后几层进行fine-tuning

      • 新数据集比较小和原始数据集合差异性比较大,如果上面方法还是不行的化那么最好是重新训练,只将预训练的模型作为一个新模型初始化的数据

      • 新数据集的大小一定要与原始数据集相同,比如CNN中输入的图片大小一定要相同,才不会报错

      • 对于不同的层可以设置不同的学习率,一般情况下建议,对于使用的原始数据做初始化的层设置的学习率要小于(一般可设置小于10倍)初始化的学习率,这样保证对于已经初始化的数据不会扭曲的过快,而使用初始化学习率的新层可以快速的收敛。

  • # 以ResNet网络为例
    # 当我们加载ResNet预训练模型之后,在ResNet的基础上连接了新的网络模块, ResNet那部分网络参数先冻结不更新
    # 只更新新引入网络结构的参数
    class Net(torch.nn.Module):def __init__(self, model, pretrained):super(Net, self).__init__()self.resnet = model(pretained)for p in self.parameters():p.requires_grad = Falseself.conv1 = torch.nn.Conv2d(2048, 1024, 1)self.conv2 = torch.nn.Conv2d(1024, 1024, 1)
    
  • 参数修改: resnet网络的最后一层对应1000个类别, 如果我们自己的数据只有10个类别, 那么可以进行如下修改

  • import torch
    import torchvision.models as models
    model = models.resnet50(pretrained=True)
    fc_inDim = model.fc.in_features
    # 修改为10个类别
    model.fc = torch.nn.Linear(fc_inDim, 10)
    
  • Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

相关文章:

加载预训练模型,模型微调,在自己的数据集上快速出效果

针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune&#xff…...

VScode远程连接服务器-过程试图写入的管道不存在-could not establist connection to【已解决】

问题描述 使用服务器的过程中突然与服务器断连,报错如下:could not establist connection to [20:23:39.487] > ssh: connect to host 10.201.0.131 port 22: Connection timed out > [20:23:39.495] > 过程试图写入的管道不存在。 > [20…...

电子技术——B类输出阶

电子技术——B类输出阶 下图展示了一个B类输出阶的原理图,B类输出阶由两个互补的BJT组成,不同时导通。 原理 当输入电压 vI0v_I 0vI​0 的时候,两个晶体管都截止输出电压为零。当 vIv_IvI​ 上升至超过0.5V的时候,此时 QNQ_NQN…...

【老卫搬砖】034期:HarmonyOS 3.1 Beta 1初体验,我在本地模拟器里面刷短视频

今天啊打开这个DevEco Studio的话,已经提示有3.1Beta1版本的一个更新啊。然后看一下它的一些特性。本文也演示了如何在本地模拟器里面运行HarmonyOS版短视频。 主要特性 新特性包括: Added support for Windows 11 64-bit and macOS 13.x OSs, as well…...

Day901.内部临时表 -MySQL实战

内部临时表 Hi,我是阿昌,今天学习记录的是关于内部临时表的内容。 sort buffer、内存临时表和 join buffer。这三个数据结构都是用来存放语句执行过程中的中间数据,以辅助 SQL 语句的执行的。 其中,在排序的时候用到了 sort bu…...

jstatd的启动方式与关闭方式

启动方式与注意事项: 启动方式: 前台启动不打印日志: jstatd -J-Djava.security.policyjstatd.all.policy -J-Djava.rmi.server.hostname服务器IP 前台启动并打印日志: ./jstatd -J-Djava.security.policyjstatd.all.policy -…...

_improve-3

createElement过程 React.createElement(): 根据指定的第一个参数创建一个React元素 React.createElement(type,[props],[...children] )第一个参数是必填,传入的是似HTML标签名称,eg: ul, li第二个参数是选填,表示的是属性&#…...

C++——异常

目录 C语言传统的处理错误的方式 C异常概念 异常的使用 异常的抛出和匹配原则 在函数调用链中异常栈展开匹配原则 自定义异常体系 异常的重新抛出 ​编辑 异常安全 异常规范 C标准库的异常体系 异常的优缺点 C语言传统的处理错误的方式 传统的错误处理机制: …...

MVVM 架构进阶:MVI 架构详解

前言Android开发发展到今天已经相当成熟了,各种架构大家也都耳熟能详,如MVC,MVP,MVVM等,其中MVVM更是被官方推荐,成为Android开发中的显学。不过软件开发中没有银弹,MVVM架构也不是尽善尽美的,在使用过程中…...

有没有必要考PMP证书?

其实针对有没有必要考试吗,这个可以根本不同行业的人来决定的。 1.高等教育项目管理专业科班出身的人员。 在我国本科学历和硕士研究生学历中,项目管理也有开设。不管以后从事的工作是否为项目管理或其他管理,作为本专业的同学,…...

1 机器学习基础

1 机器学习概述 1.1 数据驱动的问题求解 大数据-Big Data 大数据的多面性 1.2 数据分析 机器学习:海量的数据,获取有用的信息 专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之…...

java基础系列(六) sleep()和wait() 区别

一.前言 关于并发编程这块, 线程的一些基础知识我们得搞明白, 本篇文章来说一下这两个方法的区别,对Android中的HandlerThread机制原理可以有更深的理解, HandlerThread源码理解,请查看笔者的这篇博客: HandlerThread源码理解_handlerthread 源码_broadview_java的博客-CSDN博…...

Urho3D序列化

从Serializable派生的类可以通过定义属性将其自动序列化为二进制或XML格式。属性存储到每个类的上下文中。场景加载/保存和网络复制都是通过从Serializable派生Node和Component类来实现的。 支持的属性类型是Variant支持的所有属性类型,不包括指针和自定义值。 属性…...

企业级信息系统开发学习1.3——利用注解配置取代Spring配置文件

文章目录一、利用注解配置类取代Spring配置文件(一)打开项目(二)创建新包(三)拷贝类与接口(四)创建注解配置类(五)创建测试类(六)运行…...

VUE DIFF算法之快速DIFF

VUE DIFF算法系列讲解 VUE 简单DIFF算法 VUE 双端DIFF算法 文章目录VUE DIFF算法系列讲解前言一、快速DIFF的代码实现二、实践练习1练习2总结前言 本节我们来写一下VUE3中新的DIFF算法-快速DIFF,顾名思义,也就是目前最快的DIFF算法(在VUE中&…...

一文掌握如何轻松稿定项目风险管理【静说】

风险管理对于每个项目经理和PMO都非常重要,如果管理不当会出现很多问题,咱们以前分享过很多风险管理的内容: 风险无处不在,一旦发生,会对一个或多个项目目标产生积极或消极影响的确定事件或条件。那么接下来介绍下五大…...

操作系统权限提升(十四)之绕过UAC提权-基于白名单AutoElevate绕过UAC提权

系列文章 操作系统权限提升(十二)之绕过UAC提权-Windows UAC概述 操作系统权限提升(十三)之绕过UAC提权-MSF和CS绕过UAC提权 注:阅读本编文章前,请先阅读系列文章,以免造成看不懂的情况!! 基于白名单AutoElevate绕过…...

ecology9-谷歌浏览器下-pdf.js在渲染时部分发票丢失文字 问题定位及解决

问题 问题描述 : 在谷歌浏览器下,pdf.js在渲染时部分发票丢失文字;360浏览器兼容模式不存在此问题 排查思路:1、对比谷歌浏览器的css样式和360浏览器兼容模式下的样式,没有发现关键差别 2、✔使用Fiddler修改网页js D…...

JavaScript Window Navigator

文章目录JavaScript Window NavigatorWindow Navigator警告!!!浏览器检测JavaScript Window Navigator window.navigator 对象包含有关访问者浏览器的信息。 Window Navigator window.navigator 对象在编写时可不使用 window 这个前缀。 实例 <div id"example"…...

Linux基础命令-du查看文件的大小

文章目录 du 命令介绍 语法格式 基本参数 参考实例 1&#xff09;以人类可读形式显示指定的文件大小 2&#xff09;显示当前目录下所有文件大小 3&#xff09;只显示目录的大小 4&#xff09;显示根下哪个目录文件最大 5&#xff09;显示所有文件的大小 6&#xff0…...

文献计量分析方法:Citespace安装教程

Citespace是一款由陈超美教授开发的可用于海量文献可视化分析的软件&#xff0c;可对Web of Science&#xff0c;Scopus&#xff0c;Pubmed&#xff0c;CNKI等数据库的海量文献进行主题、关键词&#xff0c;作者单位、合作网络&#xff0c;期刊、发表时间&#xff0c;文献被引等…...

MVI 架构更佳实践:支持 LiveData 属性监听

前言MVI架构为了解决MVVM在逻辑复杂时需要写多个LiveData(可变不可变)的问题,使用ViewState对State集中管理&#xff0c;只需要订阅一个 ViewState 便可获取页面的所有状态通过集中管理ViewState&#xff0c;只需对外暴露一个LiveData&#xff0c;解决了MVVM模式下LiveData膨胀…...

LeetCode438 找到字符串中所有字母异位词 带输入和输出

题目&#xff1a; 给定两个字符串 s 和 p&#xff0c;找到 s 中所有 p 的 异位词 的子串&#xff0c;返回这些子串的起始索引。不考虑答案输出的顺序。 异位词 指由相同字母重排列形成的字符串&#xff08;包括相同的字符串&#xff09;。 示例 1: 输入: s “cbaebabacd”, …...

ACSC 2023 比赛复现

Admin Dashboard 在 index.php 中可以看到需要访问者是 admin 权限&#xff0c;才可以看到 flag。 report.php 中可以让 admin bot 访问我们输入的 url&#xff0c;那么也就是说可以访问 addadmin.php 添加用户。 在 addadmin.php 中可以添加 admin 用户&#xff0c;但是需…...

【Linux驱动开发100问】什么是模块?如何编写和使用模块?

&#x1f947;今日学习目标&#xff1a;什么是Linux内核&#xff1f; &#x1f935;‍♂️ 创作者&#xff1a;JamesBin ⏰预计时间&#xff1a;10分钟 &#x1f389;个人主页&#xff1a;嵌入式悦翔园个人主页 &#x1f341;专栏介绍&#xff1a;Linux驱动开发100问 什么是模块…...

Android 9.0 Recent列表不显示某个app

1.概述 在9.0的系统产品rom定制化开发中,在一些产品定制化需求中,也是有很多重要的功能实现的,比如在某些app的开发中 由于不想被杀掉,所以就不想出现在recent的列表中,因此就需要从recent的列表中,去掉这个app的显示,然后这里有 两种方法实现这个功能,一种是在app中就…...

深度学习之卷积神经网络学习笔记一

1. 引言深度学习是一系列算法的统称&#xff0c;包括卷积神经网络&#xff08;CNN&#xff09;&#xff0c;循环神经网络&#xff08;RNN&#xff09;&#xff0c;自编码器&#xff08;AE&#xff09;&#xff0c;深度置信网络&#xff08;DBN&#xff09;&#xff0c;生成对抗…...

黑盒测试的常用方法

这里我们先设置一个示例,后面的文章中会根据示例来进行讲解 假设有一个程序是判断一个整形数字是否属于1-100 目录 1.等价类法 2.边界值法 3.判定表法 4.场景设计法 5.错误猜测法 6.正交法 1.等价类法 概念:系统性的确定要输入的测试条件的方法可以看出概念非常抽象,那…...

操作系统笔记-第一章

文章目录操作系统概述1. 操作系统的概念1.1 操作系统的地位1.2 操作系统的作用1.3 操作系统的定义2. 操作系统的历史2.1 操作系统的产生2.1.1 手动操作阶段&#xff08;20世纪40年代&#xff09;2.1.2 批处理阶段&#xff08;20世纪50年代&#xff09;2.1.3 执行系统阶段&#…...

daillist

daillist #重要说明&#xff1a; #[1]任意两个配置参数之间必须以空格隔开&#xff0c;否则&#xff0c;拨号脚本无法识别。 #[2]Info格式说明:厂商名简称_制式_频段 #VID #PID #PORT_M #PORT_A #PORT_G #script_*99# #script_#777 #Info 05c6 9025 /dev/ttyUSB1 /dev/ttyUSB2 …...

wordpress去除相册样式/关键词搜索热度查询

Linux软RAID的实现方式前提软RAID说明软RAID实现mdadm配置示例软RAID测试和修复软RAID管理练习示例创建一个10G可用空间的RAID5。第1步&#xff1a;准备3块 5G 的硬盘。并对其进行分区创建&#xff0c;分区格式为 fd &#xff0c;大小 5G第2步&#xff1a;再增加一块空闲盘第3步…...

一个互联网公司可以做几个网站/百度提交网址多久才会收录

1. 首先明白什么是类加载器&#xff1f; 顾名思义&#xff0c;类加载器&#xff08;class loader&#xff09;用来加载 Java 类到 Java 虚拟机中。一般来说&#xff0c;Java 虚拟机使用 Java 类的方式如下&#xff1a;Java 源程序&#xff08;.java 文件&#xff09;在经过 Jav…...

做网站表格/seo优化关键词排名优化

文章目录项目整体结构依赖openfeign的一些配置order-service-apiorder-servicepay-service测试源码分析源码下载项目整体结构 说明&#xff1a; 所有公共依赖都放在了父pom中&#xff0c;API接口抽离放在单独模块 依赖 <properties><java.version>1.8</java.…...

微网站建设开发/上海seo搜索优化

这是一个iPhone Menu JSON文件的示例&#xff0c;您可能会看到该文件用于存储菜单配置设置以在移动设备上设置网站。 使用简单的JSON格式&#xff0c;可以轻松地在移动和Web组件之间共享它。 另外&#xff1a; 请参阅更多JSON示例。 {"menu": {"header": &…...

珠海建设网站首页/百度搜索指数入口

ASP代码审计学习笔记 -5.文件下载漏洞 文件下载漏洞 漏洞代码&#xff1a; <% function download(f,n) on error resume next Set SCreateObject("Adodb.Stream") S.Mode3 S.Type1 S.Open S.LoadFromFile(f) if Err.Number>0 then …...

四川建设网有限责任公司招聘/宁波网站seo诊断工具

说明&#xff1a;这里提供了简单的压测与高可用集群思路&#xff0c;因为时间问题&#xff0c;笔者并没有详细测试并搭建高可用集群。 rabbitMq压测方案 rabbitmq压测性能代码 public class Send2 {//消息队列名称private final static String QUEUE_NAME "helloword2…...