当前位置: 首页 > news >正文

深度学习:迁移学习

目录

一、迁移学习

1.什么是迁移学习

2.迁移学习的步骤

1、选择预训练的模型和适当的层

2、冻结预训练模型的参数

3、在新数据集上训练新增加的层

4、微调预训练模型的层

5、评估和测试

二、迁移学习实例

1.导入模型

2.冻结模型参数

3.修改参数

4.创建类,数据增强,导入数据

5.定义训练集和测试集函数

6.将模型传入GPU,并有序调整学习率

7.进行训练和测试


一、迁移学习

1.什么是迁移学习

        迁移学习是指利用已经训练好的模型,在新的任务上进行微调。迁移学习可以加快模型训练速度,提高模型性能,并且在数据稀缺的情况下也能很好地工作。

 

2.迁移学习的步骤

1、选择预训练的模型和适当的层

        通常,我们会选择在大规模图像数据集(如ImageNet)上预训练的模型,如VGG、ResNet等。然后,根据新数据集的特点,选择需要微调的模型层。对于低级特征的任务(如边缘检测),最好使用浅层模型的层,而对于高级特征的任务(如分类),则应选择更深层次的模型。

 

2、冻结预训练模型的参数

        保持预训练模型的权重不变,只训练新增加的层或者微调一些层,避免因为在数据集中过拟合导致预训练模型过度拟合。

 

3、在新数据集上训练新增加的层

        在冻结预训练模型的参数情况下,训练新增加的层。这样,可以使新模型适应新的任务,从而获得更高的性能。

 

4、微调预训练模型的层

        在新层上进行训练后,可以解冻一些已经训练过的层,并且将它们作为微调的目标。这样做可以提高模型在新数据集上的性能。

 

5、评估和测试

        在训练完成之后,使用测试集对模型进行评估。如果模型的性能仍然不够好,可以尝试调整超参数或者更改微调层。

 

二、迁移学习实例

  • 该实例使用的模型是ResNet-18残差神经网络模型

 

1.导入模型

  • 导入所要用的库,加载ResNet18模型
import torch
import torchvision.models as models
from torch import nn
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
import numpy as np"""将resnet18模型迁移到食物分类项目中"""
resent_model = models.resnet18(weights=models.ResNet18_Weights.DEFAULT)  # 既调用了resnet18网络,又使用了训练好的模型 在这里下载了模型

 

2.冻结模型参数

  • 将导入的模型参数冻结
for param in resent_model.parameters():param.requires_grad = False  # 设置每个参数的requires_grad属性为False,表示在训练过程中这些参数不需要计算梯度,也就是说它们不会在反向传播中更新。# print(param)
# 模型所有参数(即权重和偏差)的requires_grad属性设置为False,从而冻结所有模型参数
# 使得在反向传播过程中不会计算它们的梯度,以此减少模型的计算量,提高理速度。

 

3.修改参数

  • 因为我们所用的数据分类是20个,原模型分类是1000个,所以需要修改全连接层
  • 获取原模型输入层的特征个数
  • 将原模型的全连接层替换成原输入,输出为20的全连接层
  • 保存需要训练的参数,后面优化器进行优化时就可以只训练该层参数
in_features = resent_model.fc.in_features  # 获取模型原输入的特征个数
resent_model.fc = nn.Linear(in_features, 20)  # 创建一个全连接层,输入特征为in_features,输出为20param_to_update = []  # 保存需要训练的参数,仅仅包含全连接层的参数
for param in resent_model.parameters():if param.requires_grad == True:param_to_update.append(param)

 

4.创建类,数据增强,导入数据

  • 将图片从本地导入,并进行数据增强,最后进行打包
class food_dataset(Dataset):def __init__(self, file_path, transform=None):  # 类的初始化,解析数据文件txtself.file_path = file_pathself.imgs = []self.labels = []self.transform = transformwith open(self.file_path) as f:  # 是把train.txt文件中图片的路径保存在 self.imgs,train.txt文件中标签保存在self.label里samples = [x.strip().split(' ') for x in f.readlines()]  # 去掉首尾空格 再按空格分成两个元素for img_path, label in samples:self.imgs.append(img_path)  # 图像的路径self.labels.append(label)  # 标签,还不是tensor# 初始化:把图片目录加载到selfdef __len__(self):  # 类实例化对象后,可以使用len函数测量对象的个数return len(self.imgs)def __getitem__(self, idx):  # 关键,可通过索引的形式获取每一个图片数据及标签image = Image.open(self.imgs[idx])  # 读取到图片数据,还不是tensorif self.transform:# 将pil图像数据转换为tensorimage = self.transform(image)  # 图像处理为256x256,转换为tenorlabel = self.labels[idx]  # label还不是tensorlabel = torch.from_numpy(np.array(label, dtype=np.int64))  # label也转换为tensorreturn image, labeldata_transforms = {'train':transforms.Compose([transforms.Resize([300, 300]),transforms.RandomRotation(45),transforms.CenterCrop(224),transforms.RandomHorizontalFlip(p=0.5),transforms.RandomVerticalFlip(p=0.5),# transforms.ColorJitter(brightness=0.2, contrast=0.1, saturation=0.1, hue=0.1),transforms.RandomGrayscale(p=0.1),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 为 ImageNet 数据集计算的标准化参数]),'test':transforms.Compose([transforms.Resize([224, 224]),transforms.ToTensor(),transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])  # 为 ImageNet 数据集计算的标准化参数])
}train_data = food_dataset(file_path=r'trainda.txt',transform=data_transforms['train'])  # 64张图片为一个包  训练集60000张图片 打包成了938个包
test_data = food_dataset(file_path=r'testda.txt', transform=data_transforms['test'])train_dataloader = DataLoader(train_data, batch_size=64, shuffle=True)
test_dataloader = DataLoader(test_data, batch_size=64, shuffle=True)

 

5.定义训练集和测试集函数

def train(dataloader, model, loss_fn, optimizer):model.train()  # 告诉模型,我要开始训练,模型中w进行随机化操作,已经更新w.在训练过程中,w会被修改的batch_size_num = 1for x, y in dataloader:x, y = x.to(device), y.to(device)  # 把训练数据集和标签传入CPU或GPUpred = model.forward(x)  # 向前传播loss = loss_fn(pred, y)  # 通过交叉熵损失函数计算损失值lossoptimizer.zero_grad()  # 梯度值清零loss.backward()  # 反向传播计算得到每个参数的梯度值woptimizer.step()  # 根据梯度更新网络w参数loss_value = loss.item()  # 从tensor数据中提取数据出来,tensor获取损失值if batch_size_num % 40 == 0:print(f"loss:{loss_value:>7f} [number:{batch_size_num}]")batch_size_num += 1best_acc = 0def test(dataloader, model, loss_fn):global best_accsize = len(dataloader.dataset)num_batches = len(dataloader)model.eval()  # 测试,w就不能再更新。test_loss, correct = 0, 0with torch.no_grad():  # 一个上下文管理器,关闭梯度计算。当你确认不会调用Tensor.backward()的时候。这可以减少计算所占用的消耗for x, y in dataloader:x, y = x.to(device), y.to(device)pred = model.forward(x)test_loss += loss_fn(pred, y).item()  # test loss是会自动累加每一个批次的损失值correct += (pred.argmax(1) == y).type(torch.float).sum().item()test_loss /= num_batches  # 能来衡量模型测试的好坏。correct /= size  # 平均的正确率print(f"Test result: \n Accuracy: {(100 * correct)}%, Avg loss: {test_loss}\n")acc_s.append(correct)loss_s.append(test_loss)if correct > best_acc:  # 保存正确率最大的那一次的模型best_acc = correct

 

6.将模型传入GPU,并有序调整学习率

from torch import nndevice = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_avaibale() else 'cpu'
model = resent_model.to(device)  # 为什么不需要加括号,之前是model = CNN().to(device) 因为 resnet_model 是对象不是类"""有序调整学习率"""
loss_fn = nn.CrossEntropyLoss()  # 处理多分类
optimizer = torch.optim.Adam(param_to_update, lr=0.001)  # 仅训练最后一层的参数
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.5)  # 调整学习率

 

7.进行训练和测试

  • 选择训练100轮,每训练一轮,输出测试结果
epchos = 100
acc_s = []
loss_s = []
for t in range(epchos):print(f"Epoch {t + 1}\n--------------------------")train(train_dataloader, model, loss_fn, optimizer)scheduler.step()test(test_dataloader, model, loss_fn)
print('最优测试结果为:', best_acc)

输出:

相关文章:

深度学习:迁移学习

目录 一、迁移学习 1.什么是迁移学习 2.迁移学习的步骤 1、选择预训练的模型和适当的层 2、冻结预训练模型的参数 3、在新数据集上训练新增加的层 4、微调预训练模型的层 5、评估和测试 二、迁移学习实例 1.导入模型 2.冻结模型参数 3.修改参数 4.创建类&#xff…...

Footprint Growthly Quest 工具:赋能 Telegram 社区实现 Web3 飞速增长

作者:Stella L (stellafootprint.network) 在 Web3 的快节奏世界里,社区互动是关键。而众多 Web3 社区之所以能够蓬勃发展,很大程度上得益于 Telegram 平台。正因如此,Footprint Analytics 精心打造了 Growthly —— 一款专为 Tel…...

进入xwindows后挂起键盘鼠标没有响应@FreeBSD

问题: 在升级pkg包后,系统无法进入xfce等xwindows,表现为黑屏和看见鼠标,左上角有一个白字符块,键盘鼠标没有反应,整个系统卡住。但是可以ssh登录,内部的服务一切正常。 表现 处理过程&#xf…...

CentOS7.9 snmptrapd更改162端口

端口更改前: 命令: netstat -an |grep 162 [root@kibana snmp]# netstat -an | grep 162 udp 0 0 0.0.0.0:162 0.0.0.0:* unix 3 [ ] STREAM CONNECTED 45162 /run/systemd/journal/stdout u…...

模糊测试SFuzz亮相第32届中国国际信息通信展览会

9月25日,被誉为“中国ICT市场的创新基地和风向标”的第32届中国国际信息通信展在北京盛大开幕,本次展会将在为期三天的时间内,为信息通信领域创新成果、尖端技术和产品提供国家级交流平台。开源网安携模糊测试产品及相关解决方案精彩亮相&…...

CMake学习

向大佬lyf学习,先把其8服务器中所授fine 文章目录 前言一、CMakeList.txt 命令1. 最外层CMakeLists1.1 cmake_minimum_required()1.2 project()1.3 set()1.4 add_subdirectory(&…...

书生·浦语大模型全链路开源开放体系

书生浦语大模型全链路开源开放体系 大模型应用生态的发展和繁荣是建立在模型基座强大的通用基础能力之上的。上海AI实验室联合团队研究认为,大模型各项性能提升的基础在于语言建模能力的增强,对于大模型的研究应回归语言建模本质,通过更高质量…...

PHP安装swoole扩展无效,如何将文件上传至Docker容器

目录 过程 操作方式 过程 在没有使用过云服务器以前,Docker这个平台一直都很神秘。在我申请了华为云服务器,并使用WordPress镜像去搭建自己的网站以后,我不得不去把Docker平台弄清楚,原因是我使用的一个主题需要安装swoole扩展,才能够正常启用。而要将swoole.so这个扩展…...

Web3.0 应用项目

Web3.0 是下一代互联网的概念,旨在去中心化、用户拥有数据控制权和通过区块链技术实现信任的网络。Web3.0的应用项目主要集中在区块链、加密货币、去中心化应用 (DApps)、去中心化金融 (DeFi)、NFT(非同质化代币)等领域。以下是一些典型的 We…...

Linux 学习笔记(十六)—— 重定向与缓冲区

一、文件重定向 矩阵的下标,也就是文件描述符的分配规则,是从0开始空的最小的文件描述符分配给进程新打开的文件;文件输出重定向的原理是,关掉1(输出),然后打开文件,这个新打开的文…...

828华为云征文|WordPress部署

目录 前言 一、环境准备 二、远程连接 三、WordPress简介 四、WordPress安装 1. 基础环境安装 ​编辑 2. WordPress下载与解压 3. 创建站点 4. 数据库配置 总结 前言 WordPress 是一个非常流行的开源内容管理系统(Content Management System, CMS&#xf…...

华为开源自研AI框架昇思MindSpore应用案例:计算高效的卷积模型ShuffleNet

如果你对MindSpore感兴趣,可以关注昇思MindSpore社区 ShuffleNet ShuffleNet网络介绍 ShuffleNetV1是旷视科技提出的一种计算高效的CNN模型,和MobileNet, SqueezeNet等一样主要应用在移动端,所以模型的设计目标就是利用有限的计算资源来达到…...

《C++ 小游戏:简易飞机大战游戏的实现》

文章目录 《C 游戏代码解析:简易飞机大战游戏的实现》一、游戏整体结构与功能概述二、各个类和函数的功能分析(一)BK类 - 背景类(二)hero_plane类 - 玩家飞机类(三)plane_bullet类 - 玩家飞机发…...

SpringCloud源码:服务端分析(二)- EurekaServer分析

背景 从昨日的两篇文章:SpringCloud源码:客户端分析(一)- SpringBootApplication注解类加载流程、SpringCloud源码:客户端分析(二)- 客户端源码分析。 我们理解了客户端的初始化,其实…...

插槽slot在vue中的使用

介绍 在 Vue.js 中,插槽(slot)是一种用于实现组件内容分发的功能。通过插槽,可以让父组件在使用子组件时自定义子组件内部的内容。插槽提供了一种灵活的方式来组合和复用组件。 项目中有很多地方需要调用一个组件,比…...

针对考研的C语言学习(定制化快速掌握重点2)

1.C语言中字符与字符串的比较方法 在C语言中&#xff0c;单字符可以用进行比较也可以用 > , < ,但是字符串却不能用直接比较&#xff0c;需要用strcmp函数。 strcmp 函数的原型定义在 <string.h> 头文件中&#xff0c;其定义如下&#xff1a; int strcmp(const …...

[C++][IO流][流输入输出][截断理解]详细讲解

目录 1.流输入输出说明1.<<执行顺序2.>>执行顺序 2.截断(trunc)理解 1.流输入输出说明 1.<<执行顺序 链式操作的顺序&#xff1a;当使用多个<<操作符进行链式插入时&#xff0c;执行顺序是从左到右的 每个<<操作都将数据插入到前一个流的输出中…...

阿里云部署1Panel(失败版)

官网脚本部署不成功 这个不怪1panel,这个是阿里Linux 拉不到docker的下载源,懒得思考 正常部署直接打开官网 https://1panel.cn/docs/installation/online_installation/ 但是我使用的阿里云os(Alibaba Cloud Linux 3.2104 LTS 64位) 我执行不管用啊装不上docker 很烦 curl -s…...

九、设备的分配与回收

1.设备分配时应考虑的因素 ①设备的固有属性 设备的固有属性可分为三种:独占设备、共享设备、虚拟设备。 独占设备 一个时段只能分配给一个进程(如打印机) 共享设备 可同时分配给多个进程使用(如磁盘)&#xff0c;各进程往往是宏观上同时共享使用设备而微观上交替使用。 …...

单片机的原理及应用

单片机的原理及应用 1. 单片机的基本原理 1.1. 组成部分 单片机主要由以下几个部分组成&#xff1a; 中央处理器&#xff08;CPU&#xff09;&#xff1a;执行指令并控制整个系统的操作。 存储器&#xff1a; 程序存储器&#xff08;Flash&#xff09;&#xff1a;存储用户…...

Python数据分析篇--NumPy--入门

我什么也没忘&#xff0c;但是有些事只适合收藏。不能说&#xff0c;也不能想&#xff0c;却又不能忘。 -- 史铁生 《我与地坛》 NumPy相关知识 1. NumPy&#xff0c;全称是 Numerical Python&#xff0c;它是目前 Python 数值计算中最重要的基础模块。 2. NumPy 是针对多…...

OJ在线评测系统 后端 判题机模块预开发 架构分析 使用工厂模式搭建

判题机模块预开发(架构师)(工厂模式) 判题机模块 是为了把代码交个代码沙箱去处理 得到结果返回 代码沙箱 梳理判题模块和代码沙箱的关系 判题模块&#xff1a;调用代码沙箱 把代码和输入交给代码沙箱去执行 代码沙箱&#xff1a;只负责接受代码和输入 返回编译的结果 不负…...

linux 目录文件夹操作

目录 查看文件夹大小&#xff1a; Linux统计文件个数 2.统计文件夹中文件个数ls -l ./|grep "^-"|wc -l 4.统计文件夹下文件个数&#xff0c;包括子文件ls -lR | grep "^-"| wc -l 统计文件个数 移动绝对目录&#xff1a; 移动相对目录 test.py报错…...

(Linux驱动学习 - 4).Linux 下 DHT11 温湿度传感器驱动编写

DHT11的通信协议是单总线协议&#xff0c;可以用之前学习的pinctl和gpio子系统完成某IO引脚上数据的读与写。 一.在设备树下添加dht11的设备结点 1.流程图 2.设备树代码 &#xff08;1&#xff09;.在设备树的 iomuxc结点下添加 pinctl_dht11 &#xff08;2&#xff09;.在根…...

前端登录页面验证码

首先&#xff0c;在el-form-item里有两个div&#xff0c;各占一半&#xff0c;左边填验证码&#xff0c;右边生成验证码 <el-form-item prop"code"><div style"display: flex " prop"code"><el-input placeholder"请输入验证…...

【鸿蒙】HarmonyOS NEXT应用开发快速入门教程之布局篇(上)

系列文章目录 【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器&#xff08;上&#xff09; 【鸿蒙】HarmonyOS NEXT开发快速入门教程之ArkTS语法装饰器&#xff08;下&#xff09; 【鸿蒙】HarmonyOS NEXT应用开发快速入门教程之布局篇&#xff08;上&#xff09; 文…...

使用 Nginx 和 Gunicorn 部署 Flask 项目详细教程

使用 Nginx 和 Gunicorn 部署 Flask 项目详细教程 在这篇文章中&#xff0c;我们将介绍如何使用 Nginx 和 Gunicorn 来部署一个 Flask 项目。这种部署方式非常适合在生产环境中使用&#xff0c;因为它能够提供更好的性能和更高的稳定性。 目录 Flask 项目简介环境准备Gunico…...

linux中bashrc和profile环境变量在Shell编程变量的传递作用

在 Linux 系统中&#xff0c;.bashrc文件和.profile文件都是用于配置用户环境的重要文件&#xff0c;它们之间有以下关联&#xff1a; 一、作用相似性 环境设置&#xff1a;两者都用于设置用户的环境变量和启动应用程序的配置。例如&#xff0c;它们可以定义路径变量&#xf…...

数据结构-4.2.串的定义和基本操作

一.串的定义&#xff1a; 1.单/双引号不是字符串里的内容&#xff0c;他只是一个边界符&#xff0c;用来表示字符串的头和尾&#xff1b; 2.空串也是字符串的子串&#xff0c;空串长度为0&#xff1b; 3.字符的编号是从1开始&#xff0c;不是0&#xff1b; 4.空格也是字符&a…...

fastzdp_redis第一次开发, 2024年9月26日, Python操作Redis零基础快速入门

提供完整录播课 安装 pip install fastzdp_redisPython连接Redis import redis# 建立链接 r redis.Redis(hostlocalhost, port6379, db0)# 设置key r.set(foo, bar)# 获取key的值 print(r.get(foo))RESP3 支持 简单的理解: 支持更丰富的数据类型 参考文档: https://blog.c…...

外贸自建站如何收款/网站优化要做哪些

保存照片和视频到系统相册显示- http://blog.csdn.net/chendong_/article/details/52290329 Android 7.0 之拍照与图片裁剪适配-http://blog.csdn.net/yyh352091626/article/details/54908624 拍照、相册及裁剪的终极实现&#xff08;一&#xff09;——拍照及裁剪功能实现- ht…...

做网站卖东西赚钱吗/上海百度关键词搜索推广服务

删除文件时排除特定文件删除当前目录下所有 *.txt文件&#xff0c;除了test.txtrm ls *.txt|egrep -v test.txt或者rm ls *.txt|awk {if($0 ! "test.txt") print $0}排除多个文件rm ls *.txt|egrep -v (test.txt|fff.txt|ppp.txt)---------------rm -f ls *.log.1|eg…...

武汉做网站哪家专业/优化设计答案六年级

本帖最后由 uytyrt 于 2017-2-15 09:16 编辑一、mysql同步原理1、mysql主库在事务提交时将数据变更作为时间记录到二进制日志(binary log)中&#xff1b;2、slave IO线程将master的binary log events读写到它的中继(Relay log)&#xff1b;3、slave SQL进程读取Relay log&#…...

php+ajax网站开发典型实例pdf/网络营销是什么专业

http://blog.csdn.net/fakine/article/details/7517417/ 1. 使用库函数 string.h strstr函数 函数名: strstr 功 能: 在串中查找指定字符串的第一次出现 用 法: char *strstr(char *str1, char *str2); 说明&#xff1a;返回指向第一次出现str2位置的指针&#xff0c;如果…...

中国建设银行网站上不去/嘉定区整站seo十大排名

自我补充&#xff1a;ubuntu在线安装sqlite3数据库的方法&#xff1a; 系统平台&#xff1a;ubuntu12.04在ubuntu里面直接使用命令:sudo apt-get install sqlite3 &#xff0c;出现&#xff1a;……………………libsqlite3-0 ( 3.7.9-2ubuntu1) but 3.7.9-2ubuntu1.1 is to be…...

什么样的网站适合优化/武汉网络seo公司

2019独角兽企业重金招聘Python工程师标准>>> netstat -apn|grep 7021 查看端口使用情况 lsof -i :7020 查看端口进程 转载于:https://my.oschina.net/shootercn/blog/16899...