当前位置: 首页 > news >正文

模型转换 PyTorch转ONNX 入门

前言

本文主要介绍如何将PyTorch模型转换为ONNX模型,为后面的模型部署做准备。转换后的xxx.onnx模型,进行加载和测试。最后介绍使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署。

目录

前言

一、PyTorch模型转ONNX模型

1.1 转换为ONNX模型且加载权重

1.2 转换为ONNX模型但不加载权重

1.3 torch.onnx.export() 函数

二、加载ONNX模型

三、可视化ONNX模型


一、PyTorch模型转ONNX模型

将PyTorch模型转换为ONNX模型,通常是使用torch.onnx.export( )函数来转换的,基本的思路是:

  1. 加载PyTorch模型,可以选择只加载模型结构;也可以选择加载模型结构和权重。
  2. 然后定义PyTorch模型的输入维度,比如(1, 3, 224, 224),这是一个三通道的彩色图,分辨率为224x224。
  3. 最后使用torch.onnx.export( )函数来转换,生产xxx.onnx模型。

下面有一个简单的例子:

import torch
import torch.onnx# 加载 PyTorch 模型
model = ...# 设置模型输入,包括:通道数,分辨率等
dummy_input = torch.randn(1, 3, 224, 224, device='cpu')# 转换为ONNX模型
torch.onnx.export(model, dummy_input, "model.onnx", export_params=True)

1.1 转换为ONNX模型且加载权重

这里举一个resnet18的例子,基本思路是:

  1. 首先加载了一个预训练的 ResNet18 模型;
  2. 然后将其设置为评估模式。接下来定义一个与模型输入张量形状相同的输入张量,并使用 torch.randn() 函数生成了一个随机张量。
  3. 最后,使用 onnx.export() 函数将 PyTorch 模型转换为 ONNX 格式,并将其保存到指定的输出文件中。

程序如下:

import torch
import torchvision.models as models# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)# 将模型设置为评估模式
model.eval()# 定义输入张量,需要与模型的输入张量形状相同
input_shape = (1, 3, 224, 224)
x = torch.randn(input_shape)# 需要指定输入张量,输出文件路径和运行设备
# 默认情况下,输出张量的名称将基于模型中的名称自动分配
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 将 PyTorch 模型转换为 ONNX 格式
output_file = "resnet18.onnx"
torch.onnx.export(model, x.to(device), output_file, export_params=True)

1.2 转换为ONNX模型但不加载权重

举一个resnet18的例子:基本思路是:

  1. 首先加载了一个预训练的 ResNet18 模型;
  2. 然后使用 onnx.export() 函数将 PyTorch 模型转换为 ONNX 格式;指定参数do_constant_folding=False,不加载模型的权重。
import torch
import torchvision.models as models# 加载 PyTorch 模型
model = models.resnet18()# 将模型转换为 ONNX 格式但不加载权重
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet18.onnx", do_constant_folding=False)

下面构建一个简单网络结构,并转换为ONNX

import torch
import torchvision
import numpy as np# 定义一个简单的PyTorch 模型
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)self.relu = torch.nn.ReLU()self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.flatten = torch.nn.Flatten()self.fc1 = torch.nn.Linear(64 * 8 * 8, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool(x)x = self.conv2(x)x = self.relu(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc1(x)return x# 创建模型实例
model = MyModel()# 指定模型输入尺寸
dummy_input = torch.randn(1, 3, 32, 32)# 将PyTorch模型转为ONNX模型
torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

1.3 torch.onnx.export() 函数

看一下这个函数的参数

torch.onnx.export(model, args, f, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=None, verbose=False, example_outputs=None, keep_initializers_as_inputs=None)
  • model:需要导出的 PyTorch 模型
  • args:PyTorch模型输入数据的尺寸,指定通道数、长和宽。可以是单个 Tensor 或元组,也可以是元组列表。
  • f:导出的 ONNX 文件路径和名称,mymodel.onnx。
  • export_params:是否导出模型参数。如果设置为 False,则不导出模型参数。
  • opset_version:导出的 ONNX 版本。默认值为 10。
  • do_constant_folding:是否对模型进行常量折叠。如果设置为 True,不加载模型的权重。
  • input_names:模型输入数据的名称。默认为 'input'。
  • output_names:模型输出数据的名称。默认为 'output'。
  • dynamic_axes:动态轴的列表,允许在导出的 ONNX 模型中创建变化的维度。
  • verbose:是否输出详细的导出信息。
  • example_outputs:用于确定导出 ONNX 模型输出形状的样本输出。
  • keep_initializers_as_inputs:是否将模型的初始化器作为输入导出。如果设置为 True,则模型初始化器将被作为输入的一部分导出。

下面是只是一个常用的模板

import torch.onnx # 转为ONNX
def Convert_ONNX(model): # 设置模型为推理模式model.eval() # 设置模型输入的尺寸dummy_input = torch.randn(1, input_size, requires_grad=True)  # 导出ONNX模型  torch.onnx.export(model,         # model being run dummy_input,       # model input (or a tuple for multiple inputs) "xxx.onnx",       # where to save the model  export_params=True,  # store the trained parameter weights inside the model file opset_version=10,    # the ONNX version to export the model to do_constant_folding=True,  # whether to execute constant folding for optimization input_names = ['modelInput'],   # the model's input names output_names = ['modelOutput'], # the model's output names dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 'modelOutput' : {0 : 'batch_size'}}) print(" ") print('Model has been converted to ONNX')if __name__ == "__main__": # 构建模型并训练# xxxxxxxxxxxx# 测试模型精度#testAccuracy() # 加载模型结构与权重model = Network() path = "myFirstModel.pth" model.load_state_dict(torch.load(path)) # 转换为ONNX Convert_ONNX(model)

二、加载ONNX模型

加载ONNX模型,通常需要用到ONNX、ONNX Runtime,所以需要先安装。

pip install onnx
pip install onnxruntime

加载ONNX模型可以使用ONNX Runtime库,以下是一个加载ONNX模型的示例代码:

import onnxruntime as ort# 加载 ONNX 模型
ort_session = ort.InferenceSession("model.onnx")# 准备输入信息
input_info = ort_session.get_inputs()[0]
input_name = input_info.name
input_shape = input_info.shape
input_type = input_info.type# 运行ONNX模型
outputs = ort_session.run(input_name, input_data)# 获取输出信息
output_info = ort_session.get_outputs()[0]
output_name = output_info.name
output_shape = output_info.shape
output_data = outputs[0]print("outputs:", outputs)
print("output_info :", output_info )
print("output_name :", output_name )
print("output_shape :", output_shape )
print("output_data :", output_data )

以下是一个示例程序,将 resnet18 模型从 PyTorch 转换为 ONNX 格式,然后加载和测试 ONNX 模型的过程:

import torch
import torchvision.models as models
import onnx
import onnxruntime# 加载 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()# 定义输入和输出张量的名称和形状
input_names = ["input"]
output_names = ["output"]
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)# 将 PyTorch 模型转换为 ONNX 格式
torch.onnx.export(model,  # 要转换的 PyTorch 模型torch.randn(input_shape),  # 模型输入的随机张量"resnet18.onnx",  # 保存的 ONNX 模型的文件名input_names=input_names,  # 输入张量的名称output_names=output_names,  # 输出张量的名称dynamic_axes={input_names[0]: {0: "batch_size"}, output_names[0]: {0: "batch_size"}}  # 动态轴,即输入和输出张量可以具有不同的批次大小
)# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx_model_graph = onnx_model.graph
onnx_session = onnxruntime.InferenceSession(onnx_model.SerializeToString())# 使用随机张量测试 ONNX 模型
x = torch.randn(input_shape).numpy()
onnx_output = onnx_session.run(output_names, {input_names[0]: x})[0]print(f"PyTorch output: {model(torch.from_numpy(x)).detach().numpy()[0, :5]}")
print(f"ONNX output: {onnx_output[0, :5]}")

上述代码中,首先加载预训练的 resnet18 模型,并定义了输入和输出张量的名称和形状。

然后,使用 torch.onnx.export() 函数将模型转换为 ONNX 格式,并保存为 resnet18.onnx 文件。

接着,使用 onnxruntime.InferenceSession() 函数加载 ONNX 模型,并使用随机张量进行测试。

最后,将 PyTorch 模型和 ONNX 模型的输出进行比较,以确保它们具有相似的输出。

三、可视化ONNX模型

使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署。

这里简单介绍一下

Netron是一个轻量级、跨平台的模型可视化工具,支持多种深度学习框架的模型可视化,包括TensorFlow、PyTorch、ONNX、Keras、Caffe等等。它提供了可视化网络结构、层次关系、输出尺寸、权重等信息,并且可以通过鼠标移动和缩放来浏览模型。Netron还支持模型的导出和导入,方便模型的分享和交流。

 Netron的网页在线版本,直接在网页中打开和查看ONNX模型Netron

开源地址:GitHub - lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models

支持多种操作系统:

macOS: Download 

Linux: Download 

Windows: Download 

Browser: Start 

Python Server: Run pip install netron and netron [FILE] or netron.start('[FILE]').

 

下面是可视化模型截图:

还能查看某个节点(运算操作)的信息,比如下面MaxPool,点击一下,能看到使用的3x3的池化核,是否有填充pads,步长strides等参数。

 

分享完毕~

 

相关文章:

模型转换 PyTorch转ONNX 入门

前言 本文主要介绍如何将PyTorch模型转换为ONNX模型,为后面的模型部署做准备。转换后的xxx.onnx模型,进行加载和测试。最后介绍使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署…...

【深度学习】激活函数

上一章——认识神经网络 新课P54介绍了强人工智能概念,P55到P58解读了矩阵乘法在代码中的应用,P59,P60介绍了在Tensflow中实现神经网络的代码及细节,详细的内容可以自行观看2022吴恩达机器学习Deeplearning.ai课程,专…...

【新2023】华为OD机试 - 数字的排列(Python)

华为 OD 清单查看地址:blog.csdn.net/hihell/category_12199275.html 数字的排列 题目 小华是个很有对数字很敏感的小朋友, 他觉得数字的不同排列方式有特殊的美感。 某天,小华突发奇想,如果数字多行排列, 第一行1个数, 第二行2个, 第三行3个, 即第n行n个数字,并且…...

[oeasy]python0085_ASCII之父_Bemer_COBOL_数据交换网络

编码进化 回忆上次内容 上次 回顾了 字符编码的 进化过程 IBM 在数字化过程中 作用 非常大IBM 的 BCDIC 有 黑历史 😄 6-bit的 BCDIC 直接进化成 8-bit的 EBCDIC补全了 小写字母 和 控制字符 在ibm就是信息产业的年代 ibm的标准 怎么最终 没有成为 行业的标准 呢…...

volatile,内存屏障

volatile的特性可见性: 对于其他线程是可见,假设线程1修改了volatile修饰的变量,那么线程2是可见的,并且是线程安全的重排序: 由于CPU执行的时候,指令在后面的会先执行,在指令层级的时候我们晓得volatile的特性后,我们就要去volatile是如何实现的,这个很重要!&#…...

【ESP 保姆级教程】玩转emqx MQTT篇① —— 系统主题、延迟发布、服务器配置预算、常见问题

忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2023-02-18 ❤️❤️ 本篇更新记录 2023-02-18 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...

第48讲:SQL优化之ORDER BY排序查询的优化

文章目录1.ORDEY BY排序查询优化方面的概念2.ORDER BY排序的优化原则3.ORDER BY排序优化的案例3.1.准备排序优化的表以及索引3.2.同时对nl和lxfs字段使用升序排序3.3.同时对nl和lxfs字段使用降序排序3.4.排序时调整联合索引中字段的位置顺序3.5.排序时一个字段使用升序一个字段…...

[Datawhale][CS224W]图机器学习(三)

目录一、简介与准备二、教程2.1 下载安装2.2 创建图2.2.1 常用图创建(自定义图创建)1.创建图对象2.添加图节点3.创建连接2.2.2 经典图结构1.全连接无向图2.全连接有向图3.环状图4.梯状图5.线性串珠图6.星状图7.轮辐图8.二项树2.2.3 栅格图1.二维矩形栅格…...

2023版最新最强大数据面试宝典

此套面试题来自于各大厂的真实面试题及常问的知识点,如果能理解吃透这些问题,你的大数据能力将会大大提升,进入大厂指日可待!目前已经更新到第4版,广受好评!复习大数据面试题,看这一套就够了&am…...

CSS 中的 BFC 是什么,有什么作用?

BFC,即“块级格式化上下文”(Block Formatting Context),是 CSS 中一个重要的概念,它指的是一个独立的渲染区域,让块级盒子在布局时遵循一些特定的规则。BFC 的存在使得我们可以更好地控制文档流&#xff0…...

总结在使用 Git 踩过的坑

问题一: 原因 git 有两种拉代码的方式,一个是 HTTP,另一个是 ssh。git 的 HTTP 底层是通过 curl 的。HTTP 底层基于 TCP,而 TCP 协议的实现是有缓冲区的。 所以这个报错大致意思就是说,连接已经关闭,但是此时有未处理…...

从 HTTP 到 gRPC:APISIX 中 etcd 操作的迁移之路

罗泽轩,API7.ai 技术专家/技术工程师,Apache APISIX PMC 成员。 原文链接 Apache APISIX 现有基于 HTTP 的 etcd 操作的局限性 etcd 在 2.x 版本的时候,对外暴露的是 HTTP 1 (以下简称 HTTP)的接口。etcd 升级到 3.x…...

【C语言每日一题】——倒置字符串

【C语言每日一题】——倒置字符串😎前言🙌倒置字符串🙌总结撒花💞😎博客昵称:博客小梦 😊最喜欢的座右铭:全神贯注的上吧!!! 😊作者简…...

Native扩展开发的一般流程(类似开发一个插件)

文章目录大致开发流程1、编写对应的java类服务2、将jar包放到对应位置3、配置文件中进行服务配置4、在代码中调用5、如何查看服务调用成功大致开发流程 1、编写服务,打包为jar包2、将jar包放到指定的位置3、在配置文件中进行配置,调用对应的服务 1、编…...

【新解法】华为OD机试 - 任务调度 | 备考思路,刷题要点,答疑,od Base 提供

华为 OD 清单查看地址:blog.csdn.net/hihell/category_12199275.html 任务调度 题目 现有一个 CPU 和一些任务需要处理,已提前获知每个任务的任务 ID、优先级、所需执行时间和到达时间。 CPU 同时只能运行一个任务,请编写一个任务调度程序,采用“可抢占优先权调度”调度…...

Spring3定时任务

简介 Spring 内部有一个 task 是 Spring 自带的一个设定时间自动任务调度,提供了两种方式进行配置,一种是注解的方式,而另外一种就是 XML 配置方式了;注解方式比较简洁,XML 配置方式相对而言有些繁琐,但是应用场景的不…...

数据库版本管理工具Flyway应用研究

目录1 为什么使用数据库版本控制2 数据库版本管理工具选型:Flyway、Liquibase、Bytebase、阿里 DMSFlywayLiquibaseBytebase阿里 DMS3 Flyway数据库版本管理研究3.1 参考资料3.2 Flyway概述3.3 Flyway原理3.4 Flyway版本和功能3.5 Flyway概念3.5.1 版本迁移&#xf…...

更换 Ubuntu 系统 apt 命令安装软件源

更换 Ubuntu 系统 apt 命令安装软件源清华大学开源软件镜像站 https://mirrors.tuna.tsinghua.edu.cn/ 1. Ubuntu 的软件源配置文件 /etc/apt/sources.list MIRRORS -> 使用帮助 -> ubuntu https://mirrors.tuna.tsinghua.edu.cn/help/ubuntu/ Ubuntu 系统 apt 命令安…...

2023年可见光通信(LiFi)研究新进展

可见光无线通信Light Fidelity(LiFi)又称“光保真技术”,是一种利用可见光进行数据传输的全新无线传输技术。LiFi是一种以半导体光源作为信号发射源,利用无需授权的自由光谱实现无线连接的新型无线通信技术,支持高密度…...

Greenplum的两阶段提交

注:本文章引自终于把分布式事务讲明白了! 在前面的文章中,我们了解了单机库中的事务一致性实现以及分布式事务中的两阶段提交协议。大多数分布式系统都是采用了两阶段提交塄来保证事务的原子性,Greenplum也是采用了两阶段提交&am…...

多元回归分析 | CNN-BiLSTM卷积双向长短期记忆神经网络多输入单输出预测(Matlab完整程序)

多元回归分析 | CNN-BiLSTM卷积双向长短期记忆神经网络多输入单输出预测(Matlab完整程序) 目录 多元回归分析 | CNN-BiLSTM卷积双向长短期记忆神经网络多输入单输出预测(Matlab完整程序)预测结果评价指标基本介绍程序设计参考资料预测结果 评价指标 训练结束: 已完成最大轮…...

git命令行推送本地分支到远程仓库

之前说过Git与IDEA强强联合(HTTPS协议连接)那么如何使用命令行来推送代码呢? 如下图所示为一个基于layui的前端代码: 目录工作区文件: 本地内容就是将这些内容推送到远程仓库 首先使用git命令初始化git本地仓库&…...

在vscode中使用Typescript并运行

首先呢,我们在学习ts之前,需要先安装ts 1、安装 typescript npm install -g typescript //检查是否安装tsc -v ​ 2、生成配置文件,cd进入该文件夹,在控制台输 tsc --init ​ 此时我们就可以看到在ts文件夹下面出现了 一个tsco…...

【C++提高编程】C++全栈体系(十九)

C提高编程 第三章 STL - 常用容器 一、string容器 1. string基本概念 本质: string是C风格的字符串,而string本质上是一个类 string和char * 区别: char * 是一个指针string是一个类,类内部封装了char*,管理这个…...

Java版电能表协议解析源码(DL/T645-2007)、Modbus串口虚拟工具、网络串口调试工具分享

什么是Modbus通信协议Modbus串口调试工具Java版协议解析源码 网络与串口二合一调试助手TCPCOM: https://download.csdn.net/download/liuyuan_java/87454762 Modbus调试工具,模拟串口调试工具 https://download.csdn.net/download/liuyuan_java/874274…...

2023美赛选题建议 美国大学生数学建模竞赛ABCDEF题

选题建议和粗略思路已更新完毕 对于没有基础的同学来说CD两题上手难度较高,大家可以根据自己的实际情况选择最适合自己的题目,团队将持续更新各题后续内容,Q群322297051 A题主要难度就是建立第一问的模型,综合来看难度不大&…...

2023,想跳槽的可以再等等

今天讲讲跳槽。 金三银四即将开启,一些不满现状,被外界的“高薪”“好福利”吸引的人,一般就在这时候毅然决然地跳槽了。 跳槽是为了寻求更好的发展,但在跳槽前我们也不能确定下家就是更好的归宿,这就更加需要我们审…...

Java之滑动窗口详解

目录 一.滑动窗口 1.什么滑动窗口 2.滑动窗口的三要素 二.找到字符串中所有字母异位词 1.题目描述 2.问题分析 3.代码实现 三.字符串的排列 1.题目描述 2.问题分析 3.代码实现 四.考试的最大困扰度 1.题目描述 2.问题分析 3.代码实现 五.替换后的最长重复字符 …...

Webpack(应用一:基本使用,只需六步骤)

前言 上一篇文章已经说明了webpack的定义以及需求 本偏文章主要讲解webpack的基本使用 tips:现在以vscode编辑器来展示,只需要几个步骤就可以实现webpack的基本使用。 一、首先要安装node.js 1、不会安装node.js的,可以在网上自己找教程来…...

【Python小游戏】智商爆棚,推荐一款益智类亲子娱乐首选—某程序员老爸:成语编成填空“游戏”,贪玩女儿1天牢记500词(厉害了我的Python)

前言 成语填空想必大家都是十分熟悉的了,特别是有在上小学的家长肯定都有十分深刻的印象。 在我们的认知里看图猜成语不就是一些小儿科的东西吗? 当然了你也别小看了成语调控小游戏,有的时候知识储备不够,你还真的不一定猜得出…...

网站推广有哪些优势/百度关键词推广公司

得流量者得天下,有了流量,客户和生意才会跟着来,所以流量的获取一直都是市场部门的核心目标之一。这几年,SEO的重要性一点都没衰减,甚至有很多小型的创业公司,也会给专门的HC来提升SEO效果。 在基础的SEO工…...

做愛视频网站/设计网站排行榜前十名

参考 文章目录前言小程序应用生命周期App(Object object)小程序页面生命周期Page(Object object)页面页面配置onLoadonHideonPullDownRefresh:onReachBottomonShareAppMessageonPageScrollonResize在手机启用屏幕旋转页面跳转Navigator方法一 可返回方法二 原地销毁…...

有源码怎么在本地网站搭建/bt磁力搜索器

求职经验分享者Crydi:猴子数据分析学员成功入职金融行业数据分析师跳槽薪资提升30%你将收获:1、如何学习准备才能跳槽成功提升薪资?2、如何做一份面试官喜欢的简历?3、如何准备面试,才能让面试官对你刮目相看&#xff…...

网站后台做完文章不显示/四种营销策略

有scp命令,传输文件却显示报错无此命令今天下午在一台服务器上使用scp命令向另外一台服务器传文件的时候,报此错误 bash: scp: command not found ,lost connection,以为是该服务器没有安装此命令,w ...学习mongo系列(五) AND,$or,&…...

满城网站建设/如何制作网站教程

三子棋代码实现思路: 一.菜单(是否开始游戏)—->打印棋盘—–>玩家下(电脑下)—->判断输赢*** 二.我们需要创建头文件game.h, 源文件game.c 和test.h*** 1.game.h用于函数的定义 #define _CRT_SECURE_NO_…...

潍坊网站建设怎样/网络营销专业是干嘛的

做奢侈品回收这么久,遇见最多的问题就是大多数的顾客都觉得自己包包贬值太快,舍不得卖!那究竟是什么原因让你的包包贬值这么快呢?这就要说说包包的保值率了~影响包包保值率的原因有很多,主要有:品牌、款式、…...