当前位置: 首页 > news >正文

加载预训练模型,模型微调,在自己的数据集上快速出效果

  • 针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune)。 PyTorch里面提供的经典的网络模型都是官方通过Imagenet的数据集与训练好的数据,如果我们的数据训练数据不够,这些数据是可以作为基础模型来使用的。(Fine tuning 模型微调)

  • Fine tuning 模型微调的好处

    • 对于数据集本身很小(几千张图片)的情况,从头开始训练具有几千万参数的大型神经网络是不现实的,因为越大的模型对数据量的要求越大,过拟合无法避免。这时候如果还想用上大型神经网络的超强特征提取能力,只能靠微调已经训练好的模型。

    • 可以降低训练成本:如果使用导出特征向量的方法进行迁移学习,后期的训练成本非常低,用 CPU 都完全无压力,没有深度学习机器也可以做。

    • 前人花很大精力训练出来的模型在大概率上会比你自己从零开始搭的模型要强悍,没有必要重复造轮子

  • 迁移学习初衷是节省人工标注样本的时间,让模型可以通过一个已有的标记数据的领域向未标记数据领域进行迁移从而训练出适用于该领域的模型,直接对目标域从头开始学习成本太高,我们故而转向运用已有的相关知识来辅助尽快地学习新知识。把统一的概念抽象出来,只学习不同的内容。迁移学习按照学习方式可以分为基于样本的迁移,基于特征的迁移,基于模型的迁移,以及基于关系的迁移。

  • 微调应该是迁移学习中的一部分。微调只能说是一个trick,一种技术;迁移学习是一个更宏大的概念

  • Pytorch模型保存、加载与预训练

  • 保存和加载整个模型和参数:这种方式会保存整个模型的结构以及参数,会占用较大的磁盘空间, 通常不采用这种方式

  • torch.save(model, 'model.pkl')  #保存
    model = torch.load('model.pkl') # 加载
    
  • 保存和加载模型的参数, 优点是速度快,占用的磁盘空间少, 是最常用的模型保存方法。load_state_dict有一个strict参数,该参数默认是True, 表示预训练模型的网络结构与自定义的网络结构严格相同(包括名字和维度)。 如果自定义网络和预训练网络不严格相同时, 需要将不属于自定义网络的key去掉

  • torch.save(model.state_dict(), 'model_state_dict.pkl')
    model = model.load_state_dict(torch.load(model_state_dict.pkl))
    
  • 在实际场景中, 我们往往需要保存更多的信息,如优化器的参数, 那么可以通过字典的方式进行存储

  • # 保存
    torch.save({'epoch': epochId,'state_dict': model.state_dict,'best_acc': best_acc,'optimizer': optimizer.state_dict()}, checkpoint_path + "/m-" + timestamp + str("%.4f" % best_acc) + ".pth.tar")
    # 加载
    def load_model(model, checkpoint, optimizer):model_CKPT = torch.load(checkpoint)model.load_state_dict(model_CKPT['state_dict'])optimizer.load_state_dict(model_CKPT['optimizer'])return model, optimizer
    
  • 加载部分预训练模型: 如果我们修改了网络, 那么就需要将这部分参数过滤掉:(值得注意的是,当两个网络的结构相同, 但是结构的命名不同时, 直接加载会报错。因此需要修改结构的key值)

  • def load_model(model, chinkpoint, optimizer):model_CKPT = torch.load(checkpoint)model_dict = model.state_dict()pretrained_dict = model_CKPT['state_dict']# 将不在model中的参数过滤掉new_dict = {k, v for k, v in pretrained_dict.items() if k in model_dict.keys()}model_dict.update(new_dict)model.load_state_dict(model_dict)# 加载优化器参数optimizer.load_state_dict(model_CKPT['optimizer'])return model, optimizer
    
  • 冻结网络的部分参数, 训练另一部分参数(注意,必须同时在优化器中将这些参数过滤掉, 否则会报错。因为optimizer里面的参数要求required_grad为Ture)

    • 当输入给模型的数据集形式相似或者相同时,常见的是利用现有的经典模型(如Residual Network、 GoogleNet等)作为backbone来提取特征,那么这些经典模型已经训练好的模型参数可以直接拿过来使用。通常情况下, 我们希望将这些经典网络模型的参数固定下来, 不进行训练,只训练后面我们添加的和具体任务相关的网络参数。

      • 新数据集和原始数据集合类似,那么直接可以微调一个最后的FC层或者重新指定一个新的分类器

      • 新数据集比较小和原始数据集合差异性比较大,那么可以使用从模型的中部开始训练,只对最后几层进行fine-tuning

      • 新数据集比较小和原始数据集合差异性比较大,如果上面方法还是不行的化那么最好是重新训练,只将预训练的模型作为一个新模型初始化的数据

      • 新数据集的大小一定要与原始数据集相同,比如CNN中输入的图片大小一定要相同,才不会报错

      • 对于不同的层可以设置不同的学习率,一般情况下建议,对于使用的原始数据做初始化的层设置的学习率要小于(一般可设置小于10倍)初始化的学习率,这样保证对于已经初始化的数据不会扭曲的过快,而使用初始化学习率的新层可以快速的收敛。

  • # 以ResNet网络为例
    # 当我们加载ResNet预训练模型之后,在ResNet的基础上连接了新的网络模块, ResNet那部分网络参数先冻结不更新
    # 只更新新引入网络结构的参数
    class Net(torch.nn.Module):def __init__(self, model, pretrained):super(Net, self).__init__()self.resnet = model(pretained)for p in self.parameters():p.requires_grad = Falseself.conv1 = torch.nn.Conv2d(2048, 1024, 1)self.conv2 = torch.nn.Conv2d(1024, 1024, 1)
    
  • 参数修改: resnet网络的最后一层对应1000个类别, 如果我们自己的数据只有10个类别, 那么可以进行如下修改

  • import torch
    import torchvision.models as models
    model = models.resnet50(pretrained=True)
    fc_inDim = model.fc.in_features
    # 修改为10个类别
    model.fc = torch.nn.Linear(fc_inDim, 10)
    
  • Pytorch有很多方便易用的包,今天要谈的是torchvision包,它包括3个子包,分别是: torchvison.datasets ,torchvision.models ,torchvision.transforms ,分别是预定义好的数据集(比如MNIST、CIFAR10等)、预定义好的经典网络结构(比如AlexNet、VGG、ResNet等)和预定义好的数据增强方法(比如Resize、ToTensor等)。这些方法可以直接调用,简化我们建模的过程,也可以作为我们学习或构建新的模型的参考。

相关文章:

加载预训练模型,模型微调,在自己的数据集上快速出效果

针对于某个任务,自己的训练数据不多,先找到一个同类的别人训练好的模型,把别人现成的训练好了的模型拿过来,换成自己的数据,调整一下参数,再训练一遍,这就是微调(fine-tune&#xff…...

VScode远程连接服务器-过程试图写入的管道不存在-could not establist connection to【已解决】

问题描述 使用服务器的过程中突然与服务器断连,报错如下:could not establist connection to [20:23:39.487] > ssh: connect to host 10.201.0.131 port 22: Connection timed out > [20:23:39.495] > 过程试图写入的管道不存在。 > [20…...

电子技术——B类输出阶

电子技术——B类输出阶 下图展示了一个B类输出阶的原理图,B类输出阶由两个互补的BJT组成,不同时导通。 原理 当输入电压 vI0v_I 0vI​0 的时候,两个晶体管都截止输出电压为零。当 vIv_IvI​ 上升至超过0.5V的时候,此时 QNQ_NQN…...

【老卫搬砖】034期:HarmonyOS 3.1 Beta 1初体验,我在本地模拟器里面刷短视频

今天啊打开这个DevEco Studio的话,已经提示有3.1Beta1版本的一个更新啊。然后看一下它的一些特性。本文也演示了如何在本地模拟器里面运行HarmonyOS版短视频。 主要特性 新特性包括: Added support for Windows 11 64-bit and macOS 13.x OSs, as well…...

Day901.内部临时表 -MySQL实战

内部临时表 Hi,我是阿昌,今天学习记录的是关于内部临时表的内容。 sort buffer、内存临时表和 join buffer。这三个数据结构都是用来存放语句执行过程中的中间数据,以辅助 SQL 语句的执行的。 其中,在排序的时候用到了 sort bu…...

jstatd的启动方式与关闭方式

启动方式与注意事项: 启动方式: 前台启动不打印日志: jstatd -J-Djava.security.policyjstatd.all.policy -J-Djava.rmi.server.hostname服务器IP 前台启动并打印日志: ./jstatd -J-Djava.security.policyjstatd.all.policy -…...

_improve-3

createElement过程 React.createElement(): 根据指定的第一个参数创建一个React元素 React.createElement(type,[props],[...children] )第一个参数是必填,传入的是似HTML标签名称,eg: ul, li第二个参数是选填,表示的是属性&#…...

C++——异常

目录 C语言传统的处理错误的方式 C异常概念 异常的使用 异常的抛出和匹配原则 在函数调用链中异常栈展开匹配原则 自定义异常体系 异常的重新抛出 ​编辑 异常安全 异常规范 C标准库的异常体系 异常的优缺点 C语言传统的处理错误的方式 传统的错误处理机制: …...

MVVM 架构进阶:MVI 架构详解

前言Android开发发展到今天已经相当成熟了,各种架构大家也都耳熟能详,如MVC,MVP,MVVM等,其中MVVM更是被官方推荐,成为Android开发中的显学。不过软件开发中没有银弹,MVVM架构也不是尽善尽美的,在使用过程中…...

有没有必要考PMP证书?

其实针对有没有必要考试吗,这个可以根本不同行业的人来决定的。 1.高等教育项目管理专业科班出身的人员。 在我国本科学历和硕士研究生学历中,项目管理也有开设。不管以后从事的工作是否为项目管理或其他管理,作为本专业的同学,…...

1 机器学习基础

1 机器学习概述 1.1 数据驱动的问题求解 大数据-Big Data 大数据的多面性 1.2 数据分析 机器学习:海量的数据,获取有用的信息 专门研究计算机怎样模拟或实现人类的学习行为,以获取新的知识或技能,重新组织已有的知识结构使之…...

java基础系列(六) sleep()和wait() 区别

一.前言 关于并发编程这块, 线程的一些基础知识我们得搞明白, 本篇文章来说一下这两个方法的区别,对Android中的HandlerThread机制原理可以有更深的理解, HandlerThread源码理解,请查看笔者的这篇博客: HandlerThread源码理解_handlerthread 源码_broadview_java的博客-CSDN博…...

Urho3D序列化

从Serializable派生的类可以通过定义属性将其自动序列化为二进制或XML格式。属性存储到每个类的上下文中。场景加载/保存和网络复制都是通过从Serializable派生Node和Component类来实现的。 支持的属性类型是Variant支持的所有属性类型,不包括指针和自定义值。 属性…...

企业级信息系统开发学习1.3——利用注解配置取代Spring配置文件

文章目录一、利用注解配置类取代Spring配置文件(一)打开项目(二)创建新包(三)拷贝类与接口(四)创建注解配置类(五)创建测试类(六)运行…...

VUE DIFF算法之快速DIFF

VUE DIFF算法系列讲解 VUE 简单DIFF算法 VUE 双端DIFF算法 文章目录VUE DIFF算法系列讲解前言一、快速DIFF的代码实现二、实践练习1练习2总结前言 本节我们来写一下VUE3中新的DIFF算法-快速DIFF,顾名思义,也就是目前最快的DIFF算法(在VUE中&…...

一文掌握如何轻松稿定项目风险管理【静说】

风险管理对于每个项目经理和PMO都非常重要,如果管理不当会出现很多问题,咱们以前分享过很多风险管理的内容: 风险无处不在,一旦发生,会对一个或多个项目目标产生积极或消极影响的确定事件或条件。那么接下来介绍下五大…...

操作系统权限提升(十四)之绕过UAC提权-基于白名单AutoElevate绕过UAC提权

系列文章 操作系统权限提升(十二)之绕过UAC提权-Windows UAC概述 操作系统权限提升(十三)之绕过UAC提权-MSF和CS绕过UAC提权 注:阅读本编文章前,请先阅读系列文章,以免造成看不懂的情况!! 基于白名单AutoElevate绕过…...

ecology9-谷歌浏览器下-pdf.js在渲染时部分发票丢失文字 问题定位及解决

问题 问题描述 : 在谷歌浏览器下,pdf.js在渲染时部分发票丢失文字;360浏览器兼容模式不存在此问题 排查思路:1、对比谷歌浏览器的css样式和360浏览器兼容模式下的样式,没有发现关键差别 2、✔使用Fiddler修改网页js D…...

JavaScript Window Navigator

文章目录JavaScript Window NavigatorWindow Navigator警告!!!浏览器检测JavaScript Window Navigator window.navigator 对象包含有关访问者浏览器的信息。 Window Navigator window.navigator 对象在编写时可不使用 window 这个前缀。 实例 <div id"example"…...

Linux基础命令-du查看文件的大小

文章目录 du 命令介绍 语法格式 基本参数 参考实例 1&#xff09;以人类可读形式显示指定的文件大小 2&#xff09;显示当前目录下所有文件大小 3&#xff09;只显示目录的大小 4&#xff09;显示根下哪个目录文件最大 5&#xff09;显示所有文件的大小 6&#xff0…...

日语AI面试高效通关秘籍:专业解读与青柚面试智能助攻

在如今就业市场竞争日益激烈的背景下&#xff0c;越来越多的求职者将目光投向了日本及中日双语岗位。但是&#xff0c;一场日语面试往往让许多人感到步履维艰。你是否也曾因为面试官抛出的“刁钻问题”而心生畏惧&#xff1f;面对生疏的日语交流环境&#xff0c;即便提前恶补了…...

零门槛NAS搭建:WinNAS如何让普通电脑秒变私有云?

一、核心优势&#xff1a;专为Windows用户设计的极简NAS WinNAS由深圳耘想存储科技开发&#xff0c;是一款收费低廉但功能全面的Windows NAS工具&#xff0c;主打“无学习成本部署” 。与其他NAS软件相比&#xff0c;其优势在于&#xff1a; 无需硬件改造&#xff1a;将任意W…...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

Java如何权衡是使用无序的数组还是有序的数组

在 Java 中,选择有序数组还是无序数组取决于具体场景的性能需求与操作特点。以下是关键权衡因素及决策指南: ⚖️ 核心权衡维度 维度有序数组无序数组查询性能二分查找 O(log n) ✅线性扫描 O(n) ❌插入/删除需移位维护顺序 O(n) ❌直接操作尾部 O(1) ✅内存开销与无序数组相…...

江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命

在华东塑料包装行业面临限塑令深度调整的背景下&#xff0c;江苏艾立泰以一场跨国资源接力的创新实践&#xff0c;重新定义了绿色供应链的边界。 跨国回收网络&#xff1a;废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点&#xff0c;将海外废弃包装箱通过标准…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中&#xff0c;电磁频谱已成为继陆、海、空、天之后的 “第五维战场”&#xff0c;雷达作为电磁频谱领域的关键装备&#xff0c;其干扰与抗干扰能力的较量&#xff0c;直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器&#xff0c;凭借数字射…...

Caliper 配置文件解析:config.yaml

Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...

项目部署到Linux上时遇到的错误(Redis,MySQL,无法正确连接,地址占用问题)

Redis无法正确连接 在运行jar包时出现了这样的错误 查询得知问题核心在于Redis连接失败&#xff0c;具体原因是客户端发送了密码认证请求&#xff0c;但Redis服务器未设置密码 1.为Redis设置密码&#xff08;匹配客户端配置&#xff09; 步骤&#xff1a; 1&#xff09;.修…...

【Java学习笔记】BigInteger 和 BigDecimal 类

BigInteger 和 BigDecimal 类 二者共有的常见方法 方法功能add加subtract减multiply乘divide除 注意点&#xff1a;传参类型必须是类对象 一、BigInteger 1. 作用&#xff1a;适合保存比较大的整型数 2. 使用说明 创建BigInteger对象 传入字符串 3. 代码示例 import j…...

使用LangGraph和LangSmith构建多智能体人工智能系统

现在&#xff0c;通过组合几个较小的子智能体来创建一个强大的人工智能智能体正成为一种趋势。但这也带来了一些挑战&#xff0c;比如减少幻觉、管理对话流程、在测试期间留意智能体的工作方式、允许人工介入以及评估其性能。你需要进行大量的反复试验。 在这篇博客〔原作者&a…...