Pytorch-day07-模型保存与读取
PyTorch 模型保存&读取
- 模型存储
- 模型单卡存储&多卡存储
- 模型单卡读取&多卡读取
1、模型存储
- PyTorch存储模型主要采用pkl,pt,pth三种格式,就使用层面来说没有区别
- PyTorch模型主要包含两个部分:模型结构和权重。其中模型是继承nn.Module的类,权重的数据结构是一个字典(key是层名,value是权重向量)
- 存储也由此分为两种形式:存储整个模型(包括结构和权重)和只存储模型权重(推荐)。
import torch
from torchvision import models
model = models.resnet50(pretrained=True)
save_dir = './resnet50.pth'# 保存整个 模型结构+权重
torch.save(model, save_dir)
# 保存 模型权重
torch.save(model.state_dict, save_dir)# pt, pth和pkl三种数据格式均支持模型权重和整个模型的存储
2、模型单卡存储&多卡存储
- PyTorch中将模型和数据放到GPU上有两种方式——.cuda()和.to(device)
- 注:如果要使用多卡训练的话,需要对模型使用torch.nn.DataParallel
2.1、nn.DataParrallel
<CLASS torch.nn.DataParallel(module, device_ids=None, output_device=None, dim=0)>
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-puyISgkD-1692613764220)(attachment:image.png)]
- module即表示你定义的模型
- device_ids表示你训练的device
- output_device这个参数表示输出结果的device,而这最后一个参数output_device一般情况下是省略不写的,那么默认就是在device_ids[0]
注:因此一般情况下第一张显卡的内存使用占比会更多
import os
import torch
from torchvision import models
#单卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 如果是多卡改成类似0,1,2
model = model.cuda() # 单卡
#print(model)
---------------------------------------------------------------------------RuntimeError Traceback (most recent call last)~\AppData\Local\Temp/ipykernel_7460/77570021.py in <module>1 import os2 os.environ['CUDA_VISIBLE_DEVICES'] = '0' # 如果是多卡改成类似0,1,2
----> 3 model = model.cuda() # 单卡D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)903 Module: self904 """
--> 905 return self._apply(lambda t: t.cuda(device))906 907 def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)795 def _apply(self, fn):796 for module in self.children():
--> 797 module._apply(fn)798 799 def compute_should_use_set_data(tensor, tensor_applied):D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)818 # `with torch.no_grad():`819 with torch.no_grad():
--> 820 param_applied = fn(param)821 should_use_set_data = compute_should_use_set_data(param, param_applied)822 if should_use_set_data:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)903 Module: self904 """
--> 905 return self._apply(lambda t: t.cuda(device))906 907 def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\cuda\__init__.py in _lazy_init()245 if 'CUDA_MODULE_LOADING' not in os.environ:246 os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
--> 247 torch._C._cuda_init()248 # Some of the queued calls may reentrantly call _lazy_init();249 # we need to just return without initializing in that case.RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-0G4NTv1z-1692613764220)(attachment:ed8eb711294e4c6e3e43690ddb2bf66.png)]
#多卡
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1'
model = torch.nn.DataParallel(model).cuda() # 多卡
#print(model)
[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-eHt1Dn8t-1692613764221)(attachment:image.png)]
2.3、单卡保存+单卡加载
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #这里替换成希望使用的GPU编号
model = models.resnet50(pretrained=True)
model.cuda()save_dir = 'resnet50.pt' #保存路径# 保存+读取整个模型
torch.save(model, save_dir)
loaded_model = torch.load(save_dir)
loaded_model.cuda()# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)
# 先加载模型结构
loaded_model = models.resnet50()
# 在加载模型权重
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model.cuda()
---------------------------------------------------------------------------RuntimeError Traceback (most recent call last)~\AppData\Local\Temp/ipykernel_7460/585340704.py in <module>5 os.environ['CUDA_VISIBLE_DEVICES'] = '0' #这里替换成希望使用的GPU编号6 model = models.resnet50(pretrained=True)
----> 7 model.cuda()8 9 save_dir = 'resnet50.pt' #保存路径D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in cuda(self, device)903 Module: self904 """
--> 905 return self._apply(lambda t: t.cuda(device))906 907 def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)795 def _apply(self, fn):796 for module in self.children():
--> 797 module._apply(fn)798 799 def compute_should_use_set_data(tensor, tensor_applied):D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in _apply(self, fn)818 # `with torch.no_grad():`819 with torch.no_grad():
--> 820 param_applied = fn(param)821 should_use_set_data = compute_should_use_set_data(param, param_applied)822 if should_use_set_data:D:\Users\xulele\Anaconda3\lib\site-packages\torch\nn\modules\module.py in <lambda>(t)903 Module: self904 """
--> 905 return self._apply(lambda t: t.cuda(device))906 907 def ipu(self: T, device: Optional[Union[int, device]] = None) -> T:D:\Users\xulele\Anaconda3\lib\site-packages\torch\cuda\__init__.py in _lazy_init()245 if 'CUDA_MODULE_LOADING' not in os.environ:246 os.environ['CUDA_MODULE_LOADING'] = 'LAZY'
--> 247 torch._C._cuda_init()248 # Some of the queued calls may reentrantly call _lazy_init();249 # we need to just return without initializing in that case.RuntimeError: Found no NVIDIA driver on your system. Please check that you have an NVIDIA GPU and installed a driver from http://www.nvidia.com/Download/index.aspx
2.4、单卡保存+多卡加载
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #这里替换成希望使用的GPU编号
model = models.resnet50(pretrained=True)
model.cuda()# 保存+读取整个模型
torch.save(model, save_dir)os.environ['CUDA_VISIBLE_DEVICES'] = '1,2' #这里替换成希望使用的GPU编号
loaded_model = torch.load(save_dir)
loaded_model = nn.DataParallel(loaded_model).cuda()# 保存+读取模型权重
torch.save(model.state_dict(), save_dir)os.environ['CUDA_VISIBLE_DEVICES'] = '1,2' #这里替换成希望使用的GPU编号
loaded_model = models.resnet50() #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model = nn.DataParallel(loaded_model).cuda()
2.5、多卡保存+单卡加载
核心问题:如何去掉权重字典键名中的"module",以保证模型的统一性
- 对于加载整个模型,直接提取模型的module属性即可
- 对于加载模型权重,保存模型时保存模型的module属性对应的权重
os.environ['CUDA_VISIBLE_DEVICES'] = '1,2' #这里替换成希望使用的GPU编号model = models.resnet50(pretrained=True)
model = nn.DataParallel(model).cuda()# 保存+读取整个模型
torch.save(model, save_dir)os.environ['CUDA_VISIBLE_DEVICES'] = '0' #这里替换成希望使用的GPU编号
loaded_model = torch.load(save_dir).module
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' #这里替换成希望使用的GPU编号model = models.resnet50(pretrained=True)
model = nn.DataParallel(model).cuda()# 保存权重
torch.save(model.module.state_dict(), save_dir)#加载模型权重
os.environ['CUDA_VISIBLE_DEVICES'] = '0' #这里替换成希望使用的GPU编号
loaded_model = models.resnet50() #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir))
loaded_model.cuda()
2.6、多卡保存+多卡加载
保存整个模型时会同时保存所使用的GPU id等信息,读取时若这些信息和当前使用的GPU信息不符则可能会报错或者程序不按预定状态运行。可能出现以下2个问题:
- 1、读取整个模型再使用nn.DataParallel进行分布式训练设置,这种情况很可能会造成保存的整个模型中GPU id和读取环境下设置的GPU id不符,训练时数据所在device和模型所在device不一致而报错
- 2、读取整个模型而不使用nn.DataParallel进行分布式训练设置,发现程序会自动使用设备的前n个GPU进行训练(n是保存的模型使用的GPU个数)。此时如果指定的GPU个数少于n,则会报错
建议方案:
- 只模型权重,之后再使用nn.DataParallel进行分布式训练设置则没有问题
- 因此多卡模式下建议使用权重的方式存储和读取模型
os.environ['CUDA_VISIBLE_DEVICES'] = '0,1,2' #这里替换成希望使用的GPU编号model = models.resnet50(pretrained=True)
model = nn.DataParallel(model).cuda()# 保存+读取模型权重,强烈建议!!
torch.save(model.state_dict(), save_dir)
#加载模型 权重
loaded_model = models.resnet50() #注意这里需要对模型结构有定义
loaded_model.load_state_dict(torch.load(save_dir)))
loaded_model = nn.DataParallel(loaded_model).cuda()
建议
- 不管是单卡保存还是多卡保存,建议以保存模型权重为主
- 不管是单卡还是多卡,先load模型权重,再指定是多卡加载(nn.DataParallel)或单卡(cuda)
# 使用案例(截取片段代码)My_model.eval()
test_total_loss = 0
test_total_correct = 0
test_total_num = 0past_test_loss = 0 #上一轮的loss
save_model_step = 10 # 每10步保存一次modelfor iter,(images,labels) in enumerate(test_loader):images = images.to(device)labels = labels.to(device)outputs = My_model(images)loss = criterion(outputs,labels)test_total_correct += (outputs.argmax(1) == labels).sum().item()test_total_loss += loss.item()test_total_num += labels.shape[0]test_loss = test_total_loss / test_total_numprint("Epoch [{}/{}], train_loss:{:.4f}, train_acc:{:.4f}%, test_loss:{:.4f}, test_acc:{:.4f}%".format(i+1, epoch, train_total_loss / train_total_num, train_total_correct / train_total_num * 100, test_total_loss / test_total_num, test_total_correct / test_total_num * 100))# model saveif test_loss<past_test_loss:#保存模型权重torch.save(model.state_dict(), save_dir)#保存 模型权重+模型结构#torch.save(model, save_dir)if iter % save_model_step == 0:#保存模型权重torch.save(model.state_dict(), save_dir)#保存 模型权重+模型结构#torch.save(model, save_dir)past_test_loss = test_loss
单卡保存&单卡读取 案例
Google Colab:https://colab.research.google.com/drive/1hEOeqXYm4BfulY6d30QCI4HrFmCmmTQu?usp=sharing
相关文章:
Pytorch-day07-模型保存与读取
PyTorch 模型保存&读取 模型存储模型单卡存储&多卡存储模型单卡读取&多卡读取 1、模型存储 PyTorch存储模型主要采用pkl,pt,pth三种格式,就使用层面来说没有区别PyTorch模型主要包含两个部分:模型结构和权重。其中模型是继承n…...
【C语言每日一题】01. Hello, World!
题目来源:http://noi.openjudge.cn/ch0101/01/ 01. Hello, World! 总时间限制: 1000ms 内存限制: 65536kB 问题描述 对于大部分编程语言来说,编写一个能够输出“Hello, World!”的程序往往是最基本、最简单的。因此,这个程序常常作为一个初…...
arm: day8
1.中断实验:按键控制led灯 流程: key.h /*************************************************************************> File Name: include/key.h> Created Time: 2023年08月21日 星期一 17时03分20秒***************************************…...
k8s容器加入host解析字段
一、通过edit或path来修改 kubectl edit deploy /xxxxx. x-n cattle-system xxxxx为你的资源对象名称 二、添加字段 三、code hostAliases:- hostnames:- www.rancher.localip: 10.10.2.180...
浅谈开发过程中完善的注释的重要性
第一部分:引言 1.1 简述编程注释的定义和功能 编程注释是一种在源代码中添加的辅助性文字,它不参与编译或执行,但对于理解源代码起着至关重要的作用。注释可以简单地描述代码的功能,也可以详细地解释算法的工作原理、设计决策的…...
Docker 微服务实战
1. 通过IDEA新建一个普通微服务模块 1.1 建Module docker_boot 1.2 改写pom <?xml version"1.0" encoding"UTF-8"?><project xmlns"http://maven.apache.org/POM/4.0.0" xmlns:xsi"http://www.w3.org/2001/XMLSchema-instance&…...
JupyterHub实战应用
一、JupyerHub jupyter notebook 是一个非常有用的工具,我们可以在浏览器中任意编辑调试我们的python代码,并且支持markdown 语法,可以说是科研利器。但是这种情况适合个人使用,也就是jupyter notebook以我们自己的主机作为服务器…...
【MySQL】视图
目录 一、什么是视图 二、视图的操作 2.1 创建视图 2.2 删除视图 三、视图规则和限制 一、什么是视图 视图是一个虚拟表,其内容由查询定义。同真实的表一样,视图包含一系列带有名称的列和行数据。视图的数据变化会影响到基表(创建视图所…...
基于 Android 剧院购票APP的开发与设计
摘要:近年来,随着社会的发展和科技方面的创新,越来越多的人选择使用手机应用程序来购买剧场票。本文将探讨基于 Android 平台的剧院购票应用程序的开发和设计。该应用程序将为用户提供浏览剧场列表、查看剧场详情、选择座位并购买剧场票的功能…...
反转链表II
江湖一笑浪滔滔,红尘尽忘了 题目 示例 思路 链表这部分的题,不少都离不开单链表的反转,参考:反转一个单链表 这道题加上哨兵位的话会简单很多,如果不加的话,还需要分情况一下,像是从头节点开始…...
HTML 和 CSS 来实现毛玻璃效果(Glassmorphism)
毛玻璃效果简介 它的主要特征就是半透明的背景,以及阴影和边框。 同时还要为背景加上模糊效果,使得背景之后的元素根据自身内容产生漂亮的“变形”效果,示例: 代码实现 首先,创建一个 HTML 文件,写入如下…...
【技术】国标GB28181视频平台EasyGBS通过对应密钥上传到其他平台展示的详细步骤
国标GB28181协议视频平台EasyGBS是基于国标GB28181协议的视频云服务平台,支持多路设备同时接入,并对多平台、多终端分发出RTSP、RTMP、FLV、HLS、WebRTC等格式的视频流。平台可提供视频监控直播、云端录像、云存储、检索回放、智能告警、语音对讲、平台级…...
SpeedBI数据可视化工具:浏览器上做分析
SpeedBI数据分析云是一种在浏览器上进行数据可视化分析的工具,它能够将数据以可视化的形式呈现出来,并支持多种数据源和图表类型。 所有操作,均在浏览器上进行 在浏览器中打开SpeedBI数据分析云官网,点击【免费使用】进入&#…...
8.21笔记
Deeplab-MSc-LargrFOC 此图除了主输出之外,还有五个支线输出,他们池化层与VGG网络不同,其中卷积核大小是3,而VGG中卷积核大小为2(这个网络一开始是基于VGG网络提出的,因为那时候提出比较早,没有…...
MyBatis-Plus中公共字段的统一处理
数据库中一些表的公共字段,例如修改时间、修改人、创建时间、创建人,我们一般都是这样来处理的: employee.setCreateTime(LocalDateTime.now()); employee.setUpdateTime(LocalDateTime.now()); employee.setCreateUser(UserHolder.get()); …...
SQL的导出与导入
1、导入 使用命令行导入 1.登录sql界面; 2.create database Demo新建一个库; 3.选中数据库use Demo;选中导入路径source D:Demo.sql; 4.查看表show tables; 2、导出 整个sql mysqldump -u username -ppassword dbname > dbname.sq…...
记录一次wordpress项目的发布过程
背景:发布一套已完成的代码到线上,有完整的代码包,sql文件,环境是linux 宝塔。无wordpress相关经验。 过程:正常的发布代码 问题1:访问自己的域名后跳转到别的域名。 解决: 修改数据表wp_optio…...
HTML详解连载(8)
HTML详解连载(8) 专栏链接 [link](http://t.csdn.cn/xF0H3)下面进行专栏介绍 开始喽浮动-产品区域布局场景 解决方法清除浮动方法一:额外标签发方法二:单伪元素法方法三:双伪元素法方法四:overflow浮动-总结…...
Linux系统之安装OneNav个人书签管理器
Linux系统之安装OneNav个人书签管理器 一、OneNav介绍1.OneNav简介2.OneNav特点 二、本地环境介绍2.1 本地环境规划2.2 本次实践介绍 三、检查本地环境3.1 检查本地操作系统版本3.2 检查系统内核版本3.3 检查本地yum仓库状态 四、安装httpd服务4.1 安装httpd4.2 启动httpd服务4…...
主程技术分享: 游戏项目帧同步,状态同步如何选
网络游戏开发项目中帧同步,状态同步如何选? 网络游戏的核心技术之一就是玩家的网络同步,主流的网络同步有”帧同步”与”状态同步”。今天我们来分析一下这两种同步模式。同时教大家如何在自己的项目中采用最合适的同步方式。接下来从以下3个方面来阐述: 对啦&…...
ChatGPT-4: 半年的深度使用思考
几个月的时间一直在使用 ChatGpt-4,以口述语音转文字的形式说一下自己的体会。 1、选择版本 大前提:我使用的都是 GPT4 的版本。也就是说至少每个月要付费20$。 因为 3.5 的版本,实际上使用体验是非常差的,主要体现在答非所问上。…...
【健康医疗】Axure用药提醒小程序原型图,健康管理用药助手原型模板
作品概况 页面数量:共 20 页 兼容软件:Axure RP 9/10,不支持低版本 应用领域:健康管理,用药助手 作品申明:页面内容仅用于功能演示,无实际功能 作品特色 本作品为「用药提醒」小程序原型图…...
ERROR in static/js/xxx.js from UglifyJs
老项目用的webpack3,打包的时候遇到**ERROR in static/js/xxx.js from UglifyJs**这个报错, UglifyJS是个包含JS解释器、代码最小化、压缩、美化的工具集,是前端开发打包的最常用工具之一,只支持ES5,不支持ES6&#x…...
阿里云ECS服务器安装PostgreSQL
1. 概述 PostgreSQL是一个功能强大的开源数据库,它支持丰富的数据类型和自定义类型,其提供了丰富的接口,可以自行扩展其功能,支持使用流行的编程语言编写自定义函数 PostgreSQL数据库有如下优势: PostgreSQL数据库时…...
【核磁共振成像】傅里叶重建
目录 一、傅里叶重建二、填零三、移相四、数据窗函数五、矩形视野六、多线圈数据重建七、图像变形校正八、缩放比例九、基线校准 长TR,长TE,是T2加权像; 短TR,短TE,是T1加权像; 长TR,短TE&#…...
Camunda 工作流节点跳转 - 多实例节点判断和跳转
在多种工作流引擎中,Camunda框架对流程的处理控制更为强大、灵活。 在应对流程节点按业务需要进行自由跨节点跳转的需求时,通过代码自由控制节点的跳转在Camunda中是支持的,并且提供了编码方法,其中多实例的处理上有一些区别要特…...
MySQL不停重启问题
MySQL不停的自动杀掉自动重启 看一下log日志 my.cnf 里配置的 log_error /var/log/mysqld.log vim /var/log/mysqld.log 报的错误只是 [ERROR] Cant start server: Bind on TCP/IP port: Address already in use [ERROR] Do you already have another mysqld server …...
ol-cesium 暴露 Cesium viewer 对象以及二三维切换、viewer 添加点功能示例
ol-cesium 暴露 Cesium viewer 对象以及二三维切换、viewer 添加点功能示例 核心代码完整代码在线示例 二三维一体化的概念一直都比较火热,虽然大多数都是狭义的概念,但是很多需求方也想要这样的功能。 Openlayers 官方出了一个二三维一体化的工具&…...
国产化-达梦数据库安装2
目录 DM8数据库下载地址 安装一路狂飙next 启动服务 随着国家政府的推广、越来越多的政府项目、在系统部署需要采购国产服务器、数据库等 DM8数据库下载地址 https://eco.dameng.com/download/ 安装一路狂飙next windos安装比较简单直接next即可 仅仅记录几个关键疑问地方k…...
延长OLED透明屏的使用寿命:关键因素与有效方法分享
OLED透明屏作为一项创新的显示技术,具备透明度和高清晰度的特点,在各个领域得到了广泛应用。 然而,为了确保OLED透明屏的持久性和稳定性,延长其使用寿命是至关重要的。根据最新的研究和数据报告, 在这篇文章中&#…...
网站管理工作是具体应该怎么做/微信推广引流方法
2019独角兽企业重金招聘Python工程师标准>>> NDoc 可以将 C#.NET 编译生成的程序集和对应的 /doc XML 文档,自动转换成如 .NET Framework SDK 类库文档或者 MSDN Library 在线 .NET 类库文档形式的代码文档,让您快速拥有专业级的类库API 文档…...
在线建设房屋设计网站/360搜索引擎首页
2019独角兽企业重金招聘Python工程师标准>>> 如果说用“永存、曲折、已死、重生”来形容Java,笔者以为一点也不为过。 1991年,James Gosling带领着名为“Green Team”的团队着手研发一种新的语言以及专为下一代数字设备和计算机使用的网络系统…...
中国机械加工网站官网/星巴克seo网络推广
国际经济与贸易就业 转载于:https://blog.51cto.com/439927/88599...
浙江省两学一做网站/线下宣传渠道和宣传方式
默认的select标签比较难看,UI比较漂亮,如果想要实现UI上的下拉样式,好像必须用js写select,从网上拷贝而且修改了一个下拉框,为了方便以后引用所以记录下来。 /* diy_select */ .diy_select{height:30px;width:90px; po…...
南京做网站费用/济南网络推广网络营销
处理移动端click事件300毫秒延迟。FastClick 是一个简单,易于使用的js库用于消除在移动浏览器上触发click事件与一个物理Tap(敲击)之间的300延迟。 1、为什么会延迟? 从点击屏幕上的元素到触发元素的 click 事件,移动浏览器会有大约 300 毫秒…...
深圳汇网网站建设/搜索引擎哪个最好用
转载“共享博客” 原文 http://www.sharedblog.cn/?post120 容器溢出 语法: overflow: visible | hidden | scroll | auto | inherit; visible: 默认值,容器溢出不裁剪,正常显示; hidden: 溢出部分隐藏不可见; scroll: 当容器没有溢…...