当前位置: 首页 > news >正文

常用组件详解(十):保存与加载模型、检查点机制的使用

文章目录

  • 1.保存、加载模型
  • 2.torch.nn.Module.state_dict()
    • 2.1基本使用
    • 2.2保存和加载状态字典
  • 3.创建Checkpoint
    • 3.1基本使用
    • 3.2完整案例


1.保存、加载模型

  torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量和字典等(内部使用pickle模块实现对象的序列化)。数据会被保存为.pt.pth格式,可通过torch.load()从磁盘加载被保存的序列化对象,加载时会重新构造出原来的对象。
  torch.save()有两种保存模型的方式:

  • 1.保存整个模型(继承了torch.nn.Module的类),不推荐使用。
    • torch.load():利用pickle将保存的序列化对象反序列化,得到原始数据。可用于加载完整模型或状态字典。
#保存整个模型
torch.save(model, PATH)
#加载模型
model = torch.load(PATH)
  • 2.仅保存模型的参数(状态字典state_dict),推荐使用。
    • torch.nn.Module.load_state_dict():通过反序列化得到模型的state_dict()(状态字典)来加载模型,传入的参数是状态字典,而非.pt.pth文件。
#只保存模型参数
torch.save(model.state_dict(), PATH)
#加载模型
model=Model()
model.load_state_dict(torch.load(PATH))

  在实际使用中推荐第二种方式,第一种方式往往容易产生各种错误:

  • 设备错误。若在cuda:0上训练好一个模型并保存,则读取出来的模型也是默认在cuda:0上,如果训练过程的其他数据被放到了cuda:1上,那么就会发生错误:
RuntimeError: arguments are located on different GPUs at /opt/conda/conda-bld/pytorch_1503966894950/work/torch/lib/THC/generated/../generic/THCTensorMathPointwise.cu:215

此时需要将其他其他数据都保存在cuda:0上,或加载模型时指定使用cuda:1

device = torch.device("cuda:1")
model = torch.load(PATH, map_location=device)
  • 版本错误:比如使用pytorch1.0训练并保存CNN模型,再用pytorch1.1读取模型,则会出现错误:
AttributeError: 'Conv2d' object has no attribute 'padding_mode'

此时只能通过获取该模型的参数来加载新的模型:

#加载模型参数
model_state = torch.load(model_path).state_dict()
#初始化新模型并加载参数
model = Model()
model.load_state_dict(model_state)

2.torch.nn.Module.state_dict()

2.1基本使用

  torch.nn.Module.state_dict()用于返回模型的状态字典,其中保存了模型的可学习参数。其中,只有可学习参数的层(卷积层、全连接层等)和注册缓冲区(batchnorm’s running_mean)才会作为模型参数保存(优化器也有状态字典,也可进行保存)。
【例子】

import torch
import torch.nn as nn
import torch.optim as optim# 定义模型
class TheModelClass(nn.Module):def __init__(self):super(TheModelClass, self).__init__()self.conv1 = nn.Conv2d(3, 6, 5)self.pool = nn.MaxPool2d(2, 2)self.conv2 = nn.Conv2d(6, 16, 5)self.fc1 = nn.Linear(16 * 5 * 5, 120)self.fc2 = nn.Linear(120, 84)self.fc3 = nn.Linear(84, 10)def forward(self, x):x = self.pool(F.relu(self.conv1(x)))x = self.pool(F.relu(self.conv2(x)))x = x.view(-1, 16 * 5 * 5)x = F.relu(self.fc1(x))x = F.relu(self.fc2(x))x = self.fc3(x)return x# 初始化模型
model = TheModelClass()# 初始化优化器
optimizer = optim.SGD(model.parameters(), lr=0.001, momentum=0.9)# 打印模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())# 打印优化器的状态字典
print("Optimizer's state_dict:")
for var_name in optimizer.state_dict():print(var_name, "\t", optimizer.state_dict()[var_name])

  查看模型与优化器的状态字典:
在这里插入图片描述

2.2保存和加载状态字典

  通过torch.save()来保存模型的状态字典(state_dict),即只保存学习到的模型参数,并通过torch.nn.Module.load_state_dict()来加载并恢复模型参数。PyTorch中最常见的模型保存扩展名为.pt.pth

#保存模型状态字典
PATH = './test_state_dict.pth'
torch.save(model.state_dict(), PATH)
#根据状态字典加载模型
model = TheModelClass()
model.load_state_dict(torch.load(PATH))
model.eval()
#打印新模型的状态字典
print("Model's state_dict:")
for param_tensor in model.state_dict():print(param_tensor, "\t", model.state_dict()[param_tensor].size())

  注意,模型推理之前,需要调用model.eval()函数将dropoutbatch normalization层设置为评估模式,否则会导致模型推理结果不一致。
在这里插入图片描述

3.创建Checkpoint

3.1基本使用

  模型检查点(checkpoint)是指模型训练过程中保存的模型状态,包括模型参数(权重与偏置)、优化器状态等其他相关的训练信息。通过保存检查点,可以实现在训练过程中定期保存模型的当前状态,以便在需要时恢复训练或用于模型评估和推理。模型检查点常见的保存信息如下:

  • 1.模型权重:模型的状态字典。
  • 2.优化器状态:优化器的状态字典。
  • 3.训练状态:当前的训练轮数(epoch)、批次(batch)等。
  • 4.其他数据:如学习率调度器的状态、自定义指标等。

例如:
【保存检查点】

#将模型参数和优化器状态的状态字典保存到检查点中
checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'loss': loss.item(),'epoch':epoch
}#保存检查点
torch.save(checkpoint, 'checkpoint.pth')

【加载检查点】

# 加载检查点
checkpoint = torch.load('checkpoint.pth')# 恢复模型和优化器状态
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])# 恢复训练状态
epoch = checkpoint['epoch']
loss = checkpoint['loss']# 如果是恢复训练,可以从保存的epoch继续
for epoch in range(epoch, num_epochs):# 继续训练

3.2完整案例

import torch
import torch.nn as nn
import torch.optim as optim# 假设有一个简单的模型
class SimpleModel(nn.Module):def __init__(self):super(SimpleModel, self).__init__()self.fc = nn.Linear(10, 1)def forward(self, x):return self.fc(x)model = SimpleModel()
optimizer = optim.SGD(model.parameters(), lr=0.01)
loss_fn = nn.MSELoss()# 训练循环
num_epochs = 100
for epoch in range(num_epochs):# 假设有输入x和目标yx = torch.randn(64, 10)y = torch.randn(64, 1)optimizer.zero_grad()output = model(x)loss = loss_fn(output, y)loss.backward()optimizer.step()# 每10个epoch保存一次检查点if epoch % 10 == 0:checkpoint = {'model_state_dict': model.state_dict(),'optimizer_state_dict': optimizer.state_dict(),'epoch': epoch,'loss': loss.item()}torch.save(checkpoint, f'checkpoint_epoch_{epoch}.pth')# 加载检查点并继续训练
checkpoint = torch.load('checkpoint_epoch_10.pth')
model.load_state_dict(checkpoint['model_state_dict'])
optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
start_epoch = checkpoint['epoch']
loss = checkpoint['loss']# 从第11个epoch开始继续训练
for epoch in range(start_epoch + 1, num_epochs):# 继续训练pass

相关文章:

常用组件详解(十):保存与加载模型、检查点机制的使用

文章目录 1.保存、加载模型2.torch.nn.Module.state_dict()2.1基本使用2.2保存和加载状态字典 3.创建Checkpoint3.1基本使用3.2完整案例 1.保存、加载模型 torch.save()用于保存一个序列化对象到磁盘上,该序列化对象可以是任何类型的对象,包括模型、张量…...

基于SpringBoot+Vue+MySQL的在线学习交流平台

系统展示 用户前台界面 管理员后台界面 系统背景 随着互联网技术的飞速发展,在线学习已成为现代教育的重要组成部分。传统的面对面教学方式已无法满足广大学习者的需求,特别是在时间、地点上受限的学习者。因此,构建一个基于SpringBoot、Vue.…...

前端开发在AI时代如何保持核心竞争力

随着人工智能(AI)技术的迅猛发展,前端开发领域正经历着前所未有的变革。AI辅助开发工具、自动化测试框架、智能代码补全等技术的出现,极大地提高了开发效率,同时也对前端开发人员的技能和角色提出了新的要求。在这个背…...

ffmpeg面向对象——拉流协议匹配机制探索

目录 1.URLProtocol类2.协议匹配的核心接口3. URLContext类4. 综合调用流程图5.rtsp拉流协议匹配流程图及对象图5.1 rtsp拉流协议调用流程图5.2 rtsp拉流协议对象图 6.本地文件调用流程图及对象图6.1 本地文件调用流程图6.2 本地文件对象图 7.内存数据调用流程图及对象图7.1 内…...

R语言绘制柱状图

柱状图是一种数据可视化工具。由 x 轴和 y 轴构成,x 轴表示类别,y 轴为数据数值。以矩形柱子展示数据大小,便于直观比较不同类别数据差异及了解分布。广泛应用于销售分析、统计、项目管理、科学研究等领域。可定制颜色、宽度等属性&#xff0…...

GNU/Linux - tarball文件介绍介绍

Linux 中的 tarball 文件是将多个文件和目录归档到一个文件中的常用方法,通常用于备份、分发或打包目的。术语 “tarball ”来源于 “tar”(磁带归档的缩写)命令的使用,该命令最初设计用于将数据写入磁带等顺序存储设备。如今&…...

AppointmentController

目录 1、 AppointmentController 1.1、 删除预约单据信息 1.2、 反审核预约单 1.3、 SelectToMainten AppointmentController using QXQPS.Models; using QXQPS.Vo; using System; using System.Collections; using System.Collections.Generic; using System.L…...

网站建设完成后,切勿让公司官网成为摆设

在当今这个数字化时代,公司官网已经成为企业展示形象、传递信息、吸引客户的重要平台。然而,许多企业在网站建设完成后,往往忽视了对官网的持续运营和维护,导致官网逐渐沦为摆设,无法发挥其应有的作用。为了确保公司官…...

独孤思维:闲得蛋疼才去做副业

独孤现实中玩的要好的朋友。 他们都只在自己的社交圈,工作圈链接。 没有人知道,副业可以这么玩。 所以他们很好奇,问我,独孤,你最开始是怎么知道这些副业的? 其实,独孤最开始接触副业&#…...

vulnhub靶场之hackablell

一.环境搭建 1.靶场描述 difficulty: easy This works better with VirtualBox rather than VMware 2.靶场下载 https://download.vulnhub.com/hackable/hackableII.ova 3.靶场启动 二.信息收集 1.寻找靶场的真实ip nmap -SP 192.168.246.0/24 arp-scan -l 根据上面两个…...

《浔川社团官方通报 —— 为何明确 10 月 2 日上线的浔川 AI 翻译 v3.0 再次被告知延迟上线》

《浔川社团官方通报 —— 为何明确 10 月 2 日上线的浔川 AI 翻译 v3.0 再次被告知延迟上线》 各位关注浔川社团的朋友们: 大家好!首先,我们要向一直期待浔川 AI 翻译 v3.0 上线的朋友们致以最诚挚的歉意。原定于 10 月 2 日上线的浔川 AI 翻…...

加密与安全_HOTP一次性密码生成算法

文章目录 HOTP 的基础原理HOTP 的工作流程HOTP 的应用场景HOTP 的安全性安全性增强措施Code生成HOTP可配置项校验HOTP可拓展功能计数器(counter)计数器在客户端和服务端的作用计数器的同步机制客户端和服务端中的计数器表现服务端如何处理计数器不同步计…...

ResNet18果蔬图像识别分类

关于深度实战社区 我们是一个深度学习领域的独立工作室。团队成员有:中科大硕士、纽约大学硕士、浙江大学硕士、华东理工博士等,曾在腾讯、百度、德勤等担任算法工程师/产品经理。全网20多万粉丝,拥有2篇国家级人工智能发明专利。 社区特色…...

深度强化学习中收敛图的横坐标是steps还是episode?

在深度强化学习(Deep Reinforcement Learning, DRL)的收敛图中,横坐标选择 steps 或者 episodes 主要取决于算法的设计和实验的需求,两者的差异和使用场景如下: Steps(步数): 定义&a…...

一个真实可用的登录界面!

需要工具: MySQL数据库、vscode上的php插件PHP Server等 项目结构: login | --backend | --database.sql |--login.php |--welcome.php |--index.html |--script.js |--style.css 项目开展 index.html: 首先需要一个静态网页&#x…...

Vue中watch监听属性的一些应用总结

【1】vue2中watch的应用 ① 简单监视 在 Vue 2 中,如果你不需要深度监视,即只需监听顶层属性的变化,可以使用简写形式来定义 watch。这种方式更加简洁,适用于大多数基本场景。 示例代码 假设你有一个 Vue 组件,其中…...

MongoDB-aggregate流式计算:带条件的关联查询使用案例分析

在数据库的查询中,是一定会遇到表关联查询的。当两张大表关联时,时常会遇到性能和资源问题。这篇文章就是用一个例子来分享MongoDB带条件的关联查询发挥的作用。 假设工作环境中有两张MongoDB集合:SC_DATA(学生基本信息集合&…...

Redis数据库与GO(一):安装,string,hash

安装包地址:https://github.com/tporadowski/redis/releases 建议下载zip版本,解压即可使用。解压后,依次打开目录下的redis-server.exe和redis-cli.exe,redis-cli.exe用于输入指令。 一、基本结构 如图,redis对外有个…...

expressjs,实现上传图片,返回图片链接

在 Express.js 中实现图片上传并返回图片链接,你通常需要使用一个中间件来处理文件上传,比如 multer。multer 是一个 node.js 的中间件,用于处理 multipart/form-data 类型的表单数据,主要用于上传文件。 以下是一个简单的示例&a…...

爬虫——XPath基本用法

第一章XML 一、xml简介 1.什么是XML? 1,XML指可扩展标记语言 2,XML是一种标记语言,类似于HTML 3,XML的设计宗旨是传输数据,而非显示数据 4,XML标签需要我们自己自定义 5,XML被…...

变量 varablie 声明- Rust 变量 let mut 声明与 C/C++ 变量声明对比分析

一、变量声明设计:let 与 mut 的哲学解析 Rust 采用 let 声明变量并通过 mut 显式标记可变性,这种设计体现了语言的核心哲学。以下是深度解析: 1.1 设计理念剖析 安全优先原则:默认不可变强制开发者明确声明意图 let x 5; …...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

Qt/C++开发监控GB28181系统/取流协议/同时支持udp/tcp被动/tcp主动

一、前言说明 在2011版本的gb28181协议中,拉取视频流只要求udp方式,从2016开始要求新增支持tcp被动和tcp主动两种方式,udp理论上会丢包的,所以实际使用过程可能会出现画面花屏的情况,而tcp肯定不丢包,起码…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​: 下载安装DevEco Studio 4.0(支持HarmonyOS 5)配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​: ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

Java入门学习详细版(一)

大家好,Java 学习是一个系统学习的过程,核心原则就是“理论 实践 坚持”,并且需循序渐进,不可过于着急,本篇文章推出的这份详细入门学习资料将带大家从零基础开始,逐步掌握 Java 的核心概念和编程技能。 …...

如何理解 IP 数据报中的 TTL?

目录 前言理解 前言 面试灵魂一问:说说对 IP 数据报中 TTL 的理解?我们都知道,IP 数据报由首部和数据两部分组成,首部又分为两部分:固定部分和可变部分,共占 20 字节,而即将讨论的 TTL 就位于首…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

安卓基础(Java 和 Gradle 版本)

1. 设置项目的 JDK 版本 方法1:通过 Project Structure File → Project Structure... (或按 CtrlAltShiftS) 左侧选择 SDK Location 在 Gradle Settings 部分,设置 Gradle JDK 方法2:通过 Settings File → Settings... (或 CtrlAltS)…...

Python 高效图像帧提取与视频编码:实战指南

Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...

小木的算法日记-多叉树的递归/层序遍历

🌲 从二叉树到森林:一文彻底搞懂多叉树遍历的艺术 🚀 引言 你好,未来的算法大神! 在数据结构的世界里,“树”无疑是最核心、最迷人的概念之一。我们中的大多数人都是从 二叉树 开始入门的,它…...