当前位置: 首页 > news >正文

Pytorch微调深度学习模型

在公开数据训练了模型,有时候需要拿到自己的数据上微调。今天正好做了一下微调,在此记录一下微调的方法。用Pytorch还是比较容易实现的。

网上找了很多方法,以及Chatgpt也给了很多方法,但是不够简洁和容易理解。

大体步骤是:

1、加载训练好的模型。

2、冻结不想微调的层,设置想训练的层。(这里可以新建一个层替换原有层,也可以不新建层,直接微调原有层)

3、训练即可。

1、先加载一个模型

我这里是训练好的一个SqueezeNet模型,所有模型都适用。

## 加载要微调的模型
# 环境里必须有模型的框架,才能torch.load
from Model.main_SqueezeNet import SqueezeNet,Firemodel = torch.load("Model/SqueezeNet.pth").to(device)
print(model)
# 输出结果
SqueezeNet((stem): Sequential((0): Conv2d(1, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)(3): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False))(fire2): Fire((squeeze): Sequential((0): Conv2d(8, 4, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(4, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_1x1): Sequential((0): Conv2d(4, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_3x3): Sequential((0): Conv2d(4, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)))(fire3): Fire((squeeze): Sequential((0): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_1x1): Sequential((0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_3x3): Sequential((0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)))(fire4): Fire((squeeze): Sequential((0): Conv2d(16, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_1x1): Sequential((0): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True))(expand_3x3): Sequential((0): Conv2d(8, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))(1): BatchNorm2d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(2): ReLU(inplace=True)))(conv10): Conv2d(16, 2, kernel_size=(1, 1), stride=(1, 1))(avg): AdaptiveAvgPool2d(output_size=1)(maxpool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
)

print(model)时会显示模型每个层的名字。这里我想对conv10层进行微调,因为它是最后一个具有参数可以微调的层了。当然,如果最后一层是全连接的话,也建议微调最后全连接层。 

2、冻结不想训练的层。

这里就有两种不同的方法了:一是新建一个conv10层,替换掉原来的层。二是不新建,直接微调原来的层。

新建:

model.conv10 = nn.Conv2d(model.conv10.in_channels, model.conv10.out_channels, model.conv10.kernel_size, model.conv10.stride)
print(model)

可以直接用model.conv10.in_channels等加载原来层的各种参数。这样就定义好了一个新的conv10层,并且已经替换进了模型中。

然后先冻结所有层(requires_grad = False),再放开conv10层(requires_grad = True)。

# 先冻结所有层
for param in model.parameters():param.requires_grad = False# 仅对conv10层进行微调,如果在冻结后新定义了conv10层,这两行可以不写,默认有梯度
for param in model.conv10.parameters():param.requires_grad = True

如果不新建层,则不需要运行model.conv10 = nn.Conv2d那一行即可。直接开始冻结就可以。

 3、训练

这里一定要注意,optimizer里要设置参数 model.conv10.parameters(),而不是model.parameters()。这是让模型知道它将要训练哪些参数。

optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)

虽然上面已经冻结了不想训练的参数,但是这里最好还是写上model.conv10.parameters()。大家也可以试试不写行不行。

# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()
# 只优化conv10层的参数
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)
# 将模型移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 设置模型为训练模式
model.train()num_epochs = 10
for epoch in range(num_epochs):# model.train()running_loss = 0.0correct = 0for x_train, y_train in data_loader:x_train, y_train = x_train.to(device), y_train.to(device)print(x_train.shape, y_train.shape)# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item() * x_train.size(0)# 统计训练集的准确率_, predicted = torch.max(outputs, 1)correct += (predicted == y_train).sum().item()# 计算每个 epoch 的训练损失和准确率epoch_loss = running_loss / len(dataset)epoch_accuracy = 100 * correct / len(dataset)# if epoch % 5 == 0 or epoch == num_epochs-1 :print(f'Epoch [{epoch+1}/{num_epochs}]')print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%')

输出显示Loss下降说明模型有在学习。 模型准确率从0变成100,还是非常有成就感的!当然我这里就用了一个样本来微调hhhh。

Epoch [1/10]
Train Loss: 0.8185, Train Accuracy: 0.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [2/10]
Train Loss: 0.7063, Train Accuracy: 0.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [3/10]
Train Loss: 0.6141, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [4/10]
Train Loss: 0.5385, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [5/10]
Train Loss: 0.4761, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [6/10]
Train Loss: 0.4244, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [7/10]
Train Loss: 0.3812, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [8/10]
Train Loss: 0.3449, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [9/10]
Train Loss: 0.3140, Train Accuracy: 100.00%
torch.Size([1, 1, 32, 16]) torch.Size([1])
Epoch [10/10]
Train Loss: 0.2876, Train Accuracy: 100.00%

4、验证一下确实是只有这个层参数变化了,而其他层参数没变。

在训练模型之前,看一下这个层的参数:

raw_parm = model.conv10.weight
print(raw_parm)
# 部分输出为
Parameter containing:
tensor([[[[-0.1621]],[[ 0.0288]],[[ 0.1275]],[[ 0.1584]],[[ 0.0248]],[[-0.2013]],[[-0.2086]],[[ 0.1460]],[[ 0.0566]],[[ 0.2897]],[[ 0.2898]],[[ 0.0610]],[[ 0.2172]],[[ 0.0860]],[[ 0.2730]],[[-0.1053]]],

训练后,也输出一下这个层的参数:

## 查看微调后模型的参数
tuned_parm = model.conv10.weight
print(tuned_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.1446]],[[ 0.0365]],[[ 0.1490]],[[ 0.1783]],[[ 0.0424]],[[-0.1826]],[[-0.1903]],[[ 0.1636]],[[ 0.0755]],[[ 0.3092]],[[ 0.3093]],[[ 0.0833]],[[ 0.2405]],[[ 0.1049]],[[ 0.2925]],[[-0.0866]]],

可见这个层的参数确实是变了。

然后检查一下别的随便一个层:

训练前:

# 训练前
raw_parm = model.stem[0].weight
print(raw_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.0723, -0.2151,  0.1123],[-0.2114,  0.0173, -0.1322],[-0.0819,  0.0748, -0.2790]]],[[[-0.0918, -0.2783, -0.3193],[ 0.0359,  0.2993, -0.3422],[ 0.1979,  0.2499, -0.0528]]],

训练后:

## 查看微调后模型的参数
tuned_parm = model.stem[0].weight
print(tuned_parm)
# 部分输出为:
Parameter containing:
tensor([[[[-0.0723, -0.2151,  0.1123],[-0.2114,  0.0173, -0.1322],[-0.0819,  0.0748, -0.2790]]],[[[-0.0918, -0.2783, -0.3193],[ 0.0359,  0.2993, -0.3422],[ 0.1979,  0.2499, -0.0528]]],

可见参数没有变化。说明这层没有进行学习。

5、为了让大家更容易全面理解,完整代码如下。

import torch
import numpy as np
import torch.optim as optim
import torch.nn as nn
from torchinfo import summary
from torch.utils.data import DataLoader, Dataset,TensorDataset
from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, f1_score
import matplotlib.pyplot as plt
from imblearn.under_sampling import RandomUnderSampler # 多数样本下采样device = torch.device("cuda" if torch.cuda.is_available() else "cpu")## 加载微调数据
feats = np.load("feats_jn105.npy")
labels = np.array([0])
print(feats.shape)
print(labels.shape)# 将data和labels转换为 PyTorch 张量
data_tensor = torch.tensor(feats, dtype = torch.float32, requires_grad=True)
labels_tensor = torch.tensor(labels, dtype = torch.long)# 添加通道维度
# data_tensor = data_tensor.unsqueeze(1)  # 变为(num, 1, 32, 16)
batch_size = 15# 创建 TensorDataset
dataset = TensorDataset(data_tensor, labels_tensor)
data_loader = DataLoader(dataset, batch_size = batch_size, shuffle = False)
input, label = next(iter(data_loader))
print(input.shape,label.shape)
# upyter nbconvert --to script ./Model/main_SqueezeNet.ipynb # 终端运行,ipynb转py## 加载要微调的模型
# 环境里必须有模型的框架,才能torch.load
from Model.main_SqueezeNet import SqueezeNet,Firemodel = torch.load("Model/SqueezeNet.pth").to(device)
print(model)# 为模型写一个新的层
# model.fc = nn.Linear(in_features = model.fc.in_features, out_features = model.fc.out_features)
model.conv10 = nn.Conv2d(model.conv10.in_channels, model.conv10.out_channels, model.conv10.kernel_size, model.conv10.stride)
print(model)# 先冻结所有层
for param in model.parameters():param.requires_grad = False# 仅对conv10层进行微调,如果在冻结后新定义了conv10层,这两行可以不写,默认有梯度
for param in model.conv10.parameters():param.requires_grad = Trueraw_parm = model.stem[0].weight
print(raw_parm)
for name, param in model.named_parameters():print(name, param.requires_grad)# 使用交叉熵损失函数
criterion = nn.CrossEntropyLoss()# 只优化c10层的参数
optimizer = optim.SGD(model.conv10.parameters(), lr=1e-2)# 将模型移到GPU(如果可用)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)# 设置模型为训练模式
model.train()num_epochs = 10
for epoch in range(num_epochs):# model.train()running_loss = 0.0correct = 0for x_train, y_train in data_loader:x_train, y_train = x_train.to(device), y_train.to(device)print(x_train.shape, y_train.shape)# 前向传播outputs = model(x_train)loss = criterion(outputs, y_train)# 反向传播和优化optimizer.zero_grad()loss.backward()optimizer.step()running_loss += loss.item() * x_train.size(0)# 统计训练集的准确率_, predicted = torch.max(outputs, 1)correct += (predicted == y_train).sum().item()# 计算每个 epoch 的训练损失和准确率epoch_loss = running_loss / len(dataset)epoch_accuracy = 100 * correct / len(dataset)# if epoch % 5 == 0 or epoch == num_epochs-1 :print(f'Epoch [{epoch+1}/{num_epochs}]')print(f'Train Loss: {epoch_loss:.4f}, Train Accuracy: {epoch_accuracy:.2f}%')## 查看微调后模型的参数
tuned_parm = model.stem[0].weight
print(tuned_parm)

如有更好的方法,欢迎大家分享~

相关文章:

Pytorch微调深度学习模型

在公开数据训练了模型,有时候需要拿到自己的数据上微调。今天正好做了一下微调,在此记录一下微调的方法。用Pytorch还是比较容易实现的。 网上找了很多方法,以及Chatgpt也给了很多方法,但是不够简洁和容易理解。 大体步骤是&…...

springboot 使用笔记

1.springboot 快速启动项目 注意:该启动只是临时启动,不能关闭终端面板 cd /www/wwwroot java -jar admin.jar2.脚本启动 linux shell脚本启动springboot服务 3.java一键部署springboot 第5条 https://blog.csdn.net/qq_30272167/article/details/1…...

网络安全基础——网络安全法

填空题 1.根据**《中华人民共和国网络安全法》**第二十条(第二款),任何组织和个人试用网路应当遵守宪法法律,遵守公共秩序,遵守社会公德,不危害网络安全,不得利用网络从事危害国家安全、荣誉和利益,煽动颠…...

SCAU软件体系结构实验四 组合模式

目录 一、题目 二、源码 一、题目 个人(Person)与团队(Team)可以形成一个组织(Organization):组织有两种:个人组织和团队组织,多个个人可以组合成一个团队,不同的个人与团队可以组合成一个更大的团队。 使用控制台或者JavaFx界面…...

Amazon商品详情API接口:电商创新与用户体验的驱动力

在电子商务蓬勃发展的今天,作为全球最大的电商平台之一,亚马逊(Amazon)凭借其强大的技术实力和丰富的商品资源,为全球用户提供了优质的购物体验。其中,Amazon商品详情API接口在电商创新与用户体验提升方面扮…...

手机无法连接服务器1302什么意思?

你有没有遇到过手机无法连接服务器,屏幕上显示“1302”这样的错误代码?尤其是在急需使用手机进行工作或联系朋友时,突然出现的连接问题无疑会带来不少麻烦。那么,什么是1302错误,它又意味着什么呢? 1302错…...

Android adb shell dumpsys audio 信息查看分析详解

Android adb shell dumpsys audio 信息查看分析详解 一、前言 Android 如果要分析当前设备的声音通道相关日志, 仅仅看AudioService的日志是看不到啥日志的,但是看整个audio关键字的日志又太多太乱了, 所以可以看一下系统提供的一个调试指令…...

Python 网络爬虫操作指南

网络爬虫是自动化获取互联网上信息的一种工具。它广泛应用于数据采集、分析以及实现信息聚合等众多领域。本文将为你提供一个完整的Python网络爬虫操作指南,帮助你从零开始学习并实现简单的网络爬虫。我们将涵盖基本的爬虫概念、Python环境配置、常用库介绍。 上传…...

基于FPGA的2FSK调制-串口收发-带tb仿真文件-实际上板验证成功

基于FPGA的2FSK调制 前言一、2FSK储备知识二、代码分析1.模块分析2.波形分析 总结 前言 设计实现连续相位 2FSK 调制器,2FSK 的两个频率为:fI15KHz,f23KHz,波特率为 1500 bps,比特0映射为f 载波,比特1映射为 载波。 1&#xff09…...

JavaScript的基础数据类型

一、JavaScript中的数组 定义 数组是一种特殊的对象,用于存储多个值。在JavaScript中,数组可以包含不同的数据类型,如数字、字符串、对象、甚至其他数组。数组的创建有两种常见方式: 字面量表示法:let fruits [apple…...

第三讲 架构详解:“隐语”可信隐私计算开源框架

目录 隐语架构 隐语架构拆解 产品层 算法层 计算层 资源层 互联互通 跨域管控 本文主要是记录参加隐语开源社区推出的第四期隐私计算实训营学习到的相关内容。 隐语架构 隐语架构拆解 产品层 产品定位: 通过可视化产品,降低终端用户的体验和演…...

JDBC编程---Java

目录 一、数据库编程的前置 二、Java的数据库编程----JDBC 1.概念 2.JDBC编程的优点 三.导入MySQL驱动包 四、JDBC编程的实战 1.创造数据源,并设置数据库所在的位置,三条固定写法 2.建立和数据库服务器之间的连接,连接好了后&#xff…...

Python绘制太极八卦

文章目录 系列目录写在前面技术需求1. 图形绘制库的支持2. 图形绘制功能3. 参数化设计4. 绘制控制5. 数据处理6. 用户界面 完整代码代码分析1. rset() 函数2. offset() 函数3. taiji() 函数4. bagua() 函数5. 绘制过程6. 技术亮点 写在后面 系列目录 序号直达链接爱心系列1Pyth…...

Spring框架特性及包下载(Java EE 学习笔记04)

1 Spring 5的新特性 Spring 5是Spring当前最新的版本,与历史版本对比,Spring 5对Spring核心框架进行了修订和更新,增加了很多新特性,如支持响应式编程等。 更新JDK基线 因为Spring 5代码库运行于JDK 8之上,所以Spri…...

Linux关于vim的笔记

Linux关于vim的笔记:(vimtutor打开vim 教程) --------------------------------------------------------------------------------------------------------------------------------- 1. 光标在屏幕文本中的移动既可以用箭头键,也可以使用 hjkl 字母键…...

linux mount nfs开机自动挂载远程目录

要在Linux系统中实现开机自动挂载NFS共享目录,你需要编辑/etc/fstab文件。以下是具体步骤和示例: 确保你的系统已经安装了NFS客户端。如果没有安装,可以使用以下命令安装: sudo apt-install nfs-common 编辑/etc/fstab文件&#…...

【vue】导航守卫

什么是导航守卫 在vue路由切换过程中对行为做个限制 全局前置守卫 route.beforeEach((to, from, next)) > {// to是切换到的路由// from是正要离开的路由// next控制是否允许进入目标路由next(false); //不允许 }路由级别的导航守卫 const routes [{path: /User,name: U…...

基于Matlab实现LDPC编码

在无线通信和数据存储领域,LDPC(低密度奇偶校验码)编码是一种高效、纠错能力强大的错误校正技术。本MATLAB仿真程序全面地展示了如何在AWGN(加性高斯白噪声)信道下应用LDPC编码与BPSK(二进制相移键控&#…...

PostgreSQL 中约束Constraints

在 PostgreSQL 中,约束(Constraints)是用于限制进入数据库表中数据的规则。它们确保数据的准确性和可靠性,通过定义规则来防止无效数据的插入或更新。PostgreSQL 支持多种类型的约束,每种约束都有特定的用途和语法。以…...

✨系统设计时应时刻考虑设计模式基础原则

目录 💫单一职责原则 (Single Responsibility Principle, SRP)💫开放-封闭原则 (Open-Closed Principle, OCP)💫依赖倒转原则 (Dependency Inversion Principle, DIP)💫里氏代换原则 (Liskov Substitution Principle, LSP)&#x…...

UE5 学习系列(二)用户操作界面及介绍

这篇博客是 UE5 学习系列博客的第二篇,在第一篇的基础上展开这篇内容。博客参考的 B 站视频资料和第一篇的链接如下: 【Note】:如果你已经完成安装等操作,可以只执行第一篇博客中 2. 新建一个空白游戏项目 章节操作,重…...

第一篇:Agent2Agent (A2A) 协议——协作式人工智能的黎明

AI 领域的快速发展正在催生一个新时代,智能代理(agents)不再是孤立的个体,而是能够像一个数字团队一样协作。然而,当前 AI 生态系统的碎片化阻碍了这一愿景的实现,导致了“AI 巴别塔问题”——不同代理之间…...

Python 包管理器 uv 介绍

Python 包管理器 uv 全面介绍 uv 是由 Astral(热门工具 Ruff 的开发者)推出的下一代高性能 Python 包管理器和构建工具,用 Rust 编写。它旨在解决传统工具(如 pip、virtualenv、pip-tools)的性能瓶颈,同时…...

NXP S32K146 T-Box 携手 SD NAND(贴片式TF卡):驱动汽车智能革新的黄金组合

在汽车智能化的汹涌浪潮中,车辆不再仅仅是传统的交通工具,而是逐步演变为高度智能的移动终端。这一转变的核心支撑,来自于车内关键技术的深度融合与协同创新。车载远程信息处理盒(T-Box)方案:NXP S32K146 与…...

关于uniapp展示PDF的解决方案

在 UniApp 的 H5 环境中使用 pdf-vue3 组件可以实现完整的 PDF 预览功能。以下是详细实现步骤和注意事项&#xff1a; 一、安装依赖 安装 pdf-vue3 和 PDF.js 核心库&#xff1a; npm install pdf-vue3 pdfjs-dist二、基本使用示例 <template><view class"con…...

WPF八大法则:告别模态窗口卡顿

⚙️ 核心问题&#xff1a;阻塞式模态窗口的缺陷 原始代码中ShowDialog()会阻塞UI线程&#xff0c;导致后续逻辑无法执行&#xff1a; var result modalWindow.ShowDialog(); // 线程阻塞 ProcessResult(result); // 必须等待窗口关闭根本问题&#xff1a…...

消防一体化安全管控平台:构建消防“一张图”和APP统一管理

在城市的某个角落&#xff0c;一场突如其来的火灾打破了平静。熊熊烈火迅速蔓延&#xff0c;滚滚浓烟弥漫开来&#xff0c;周围群众的生命财产安全受到严重威胁。就在这千钧一发之际&#xff0c;消防救援队伍迅速行动&#xff0c;而豪越科技消防一体化安全管控平台构建的消防“…...

【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统

Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...

jdbc查询mysql数据库时,出现id顺序错误的情况

我在repository中的查询语句如下所示&#xff0c;即传入一个List<intager>的数据&#xff0c;返回这些id的问题列表。但是由于数据库查询时ID列表的顺序与预期不一致&#xff0c;会导致返回的id是从小到大排列的&#xff0c;但我不希望这样。 Query("SELECT NEW com…...

土建施工员考试:建筑施工技术重点知识有哪些?

《管理实务》是土建施工员考试中侧重实操应用与管理能力的科目&#xff0c;核心考查施工组织、质量安全、进度成本等现场管理要点。以下是结合考试大纲与高频考点整理的重点内容&#xff0c;附学习方向和应试技巧&#xff1a; 一、施工组织与进度管理 核心目标&#xff1a; 规…...