当前位置: 首页 > news >正文

模型转换 PyTorch转ONNX 入门

前言

本文主要介绍如何将PyTorch模型转换为ONNX模型,为后面的模型部署做准备。转换后的xxx.onnx模型,进行加载和测试。最后介绍使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署。

目录

前言

一、PyTorch模型转ONNX模型

1.1 转换为ONNX模型且加载权重

1.2 转换为ONNX模型但不加载权重

1.3 torch.onnx.export() 函数

二、加载ONNX模型

三、可视化ONNX模型


一、PyTorch模型转ONNX模型

将PyTorch模型转换为ONNX模型,通常是使用torch.onnx.export( )函数来转换的,基本的思路是:

  1. 加载PyTorch模型,可以选择只加载模型结构;也可以选择加载模型结构和权重。
  2. 然后定义PyTorch模型的输入维度,比如(1, 3, 224, 224),这是一个三通道的彩色图,分辨率为224x224。
  3. 最后使用torch.onnx.export( )函数来转换,生产xxx.onnx模型。

下面有一个简单的例子:

import torch
import torch.onnx# 加载 PyTorch 模型
model = ...# 设置模型输入,包括:通道数,分辨率等
dummy_input = torch.randn(1, 3, 224, 224, device='cpu')# 转换为ONNX模型
torch.onnx.export(model, dummy_input, "model.onnx", export_params=True)

1.1 转换为ONNX模型且加载权重

这里举一个resnet18的例子,基本思路是:

  1. 首先加载了一个预训练的 ResNet18 模型;
  2. 然后将其设置为评估模式。接下来定义一个与模型输入张量形状相同的输入张量,并使用 torch.randn() 函数生成了一个随机张量。
  3. 最后,使用 onnx.export() 函数将 PyTorch 模型转换为 ONNX 格式,并将其保存到指定的输出文件中。

程序如下:

import torch
import torchvision.models as models# 加载预训练的 ResNet18 模型
model = models.resnet18(pretrained=True)# 将模型设置为评估模式
model.eval()# 定义输入张量,需要与模型的输入张量形状相同
input_shape = (1, 3, 224, 224)
x = torch.randn(input_shape)# 需要指定输入张量,输出文件路径和运行设备
# 默认情况下,输出张量的名称将基于模型中的名称自动分配
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")# 将 PyTorch 模型转换为 ONNX 格式
output_file = "resnet18.onnx"
torch.onnx.export(model, x.to(device), output_file, export_params=True)

1.2 转换为ONNX模型但不加载权重

举一个resnet18的例子:基本思路是:

  1. 首先加载了一个预训练的 ResNet18 模型;
  2. 然后使用 onnx.export() 函数将 PyTorch 模型转换为 ONNX 格式;指定参数do_constant_folding=False,不加载模型的权重。
import torch
import torchvision.models as models# 加载 PyTorch 模型
model = models.resnet18()# 将模型转换为 ONNX 格式但不加载权重
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "resnet18.onnx", do_constant_folding=False)

下面构建一个简单网络结构,并转换为ONNX

import torch
import torchvision
import numpy as np# 定义一个简单的PyTorch 模型
class MyModel(torch.nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = torch.nn.Conv2d(3, 32, kernel_size=3, stride=1, padding=1)self.relu = torch.nn.ReLU()self.maxpool = torch.nn.MaxPool2d(kernel_size=2, stride=2)self.conv2 = torch.nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)self.flatten = torch.nn.Flatten()self.fc1 = torch.nn.Linear(64 * 8 * 8, 10)def forward(self, x):x = self.conv1(x)x = self.relu(x)x = self.maxpool(x)x = self.conv2(x)x = self.relu(x)x = self.maxpool(x)x = self.flatten(x)x = self.fc1(x)return x# 创建模型实例
model = MyModel()# 指定模型输入尺寸
dummy_input = torch.randn(1, 3, 32, 32)# 将PyTorch模型转为ONNX模型
torch.onnx.export(model, dummy_input, 'mymodel.onnx',  do_constant_folding=False)

1.3 torch.onnx.export() 函数

看一下这个函数的参数

torch.onnx.export(model, args, f, export_params=True, opset_version=10, do_constant_folding=True, input_names=['input'], output_names=['output'], dynamic_axes=None, verbose=False, example_outputs=None, keep_initializers_as_inputs=None)
  • model:需要导出的 PyTorch 模型
  • args:PyTorch模型输入数据的尺寸,指定通道数、长和宽。可以是单个 Tensor 或元组,也可以是元组列表。
  • f:导出的 ONNX 文件路径和名称,mymodel.onnx。
  • export_params:是否导出模型参数。如果设置为 False,则不导出模型参数。
  • opset_version:导出的 ONNX 版本。默认值为 10。
  • do_constant_folding:是否对模型进行常量折叠。如果设置为 True,不加载模型的权重。
  • input_names:模型输入数据的名称。默认为 'input'。
  • output_names:模型输出数据的名称。默认为 'output'。
  • dynamic_axes:动态轴的列表,允许在导出的 ONNX 模型中创建变化的维度。
  • verbose:是否输出详细的导出信息。
  • example_outputs:用于确定导出 ONNX 模型输出形状的样本输出。
  • keep_initializers_as_inputs:是否将模型的初始化器作为输入导出。如果设置为 True,则模型初始化器将被作为输入的一部分导出。

下面是只是一个常用的模板

import torch.onnx # 转为ONNX
def Convert_ONNX(model): # 设置模型为推理模式model.eval() # 设置模型输入的尺寸dummy_input = torch.randn(1, input_size, requires_grad=True)  # 导出ONNX模型  torch.onnx.export(model,         # model being run dummy_input,       # model input (or a tuple for multiple inputs) "xxx.onnx",       # where to save the model  export_params=True,  # store the trained parameter weights inside the model file opset_version=10,    # the ONNX version to export the model to do_constant_folding=True,  # whether to execute constant folding for optimization input_names = ['modelInput'],   # the model's input names output_names = ['modelOutput'], # the model's output names dynamic_axes={'modelInput' : {0 : 'batch_size'},    # variable length axes 'modelOutput' : {0 : 'batch_size'}}) print(" ") print('Model has been converted to ONNX')if __name__ == "__main__": # 构建模型并训练# xxxxxxxxxxxx# 测试模型精度#testAccuracy() # 加载模型结构与权重model = Network() path = "myFirstModel.pth" model.load_state_dict(torch.load(path)) # 转换为ONNX Convert_ONNX(model)

二、加载ONNX模型

加载ONNX模型,通常需要用到ONNX、ONNX Runtime,所以需要先安装。

pip install onnx
pip install onnxruntime

加载ONNX模型可以使用ONNX Runtime库,以下是一个加载ONNX模型的示例代码:

import onnxruntime as ort# 加载 ONNX 模型
ort_session = ort.InferenceSession("model.onnx")# 准备输入信息
input_info = ort_session.get_inputs()[0]
input_name = input_info.name
input_shape = input_info.shape
input_type = input_info.type# 运行ONNX模型
outputs = ort_session.run(input_name, input_data)# 获取输出信息
output_info = ort_session.get_outputs()[0]
output_name = output_info.name
output_shape = output_info.shape
output_data = outputs[0]print("outputs:", outputs)
print("output_info :", output_info )
print("output_name :", output_name )
print("output_shape :", output_shape )
print("output_data :", output_data )

以下是一个示例程序,将 resnet18 模型从 PyTorch 转换为 ONNX 格式,然后加载和测试 ONNX 模型的过程:

import torch
import torchvision.models as models
import onnx
import onnxruntime# 加载 PyTorch 模型
model = models.resnet18(pretrained=True)
model.eval()# 定义输入和输出张量的名称和形状
input_names = ["input"]
output_names = ["output"]
batch_size = 1
input_shape = (batch_size, 3, 224, 224)
output_shape = (batch_size, 1000)# 将 PyTorch 模型转换为 ONNX 格式
torch.onnx.export(model,  # 要转换的 PyTorch 模型torch.randn(input_shape),  # 模型输入的随机张量"resnet18.onnx",  # 保存的 ONNX 模型的文件名input_names=input_names,  # 输入张量的名称output_names=output_names,  # 输出张量的名称dynamic_axes={input_names[0]: {0: "batch_size"}, output_names[0]: {0: "batch_size"}}  # 动态轴,即输入和输出张量可以具有不同的批次大小
)# 加载 ONNX 模型
onnx_model = onnx.load("resnet18.onnx")
onnx_model_graph = onnx_model.graph
onnx_session = onnxruntime.InferenceSession(onnx_model.SerializeToString())# 使用随机张量测试 ONNX 模型
x = torch.randn(input_shape).numpy()
onnx_output = onnx_session.run(output_names, {input_names[0]: x})[0]print(f"PyTorch output: {model(torch.from_numpy(x)).detach().numpy()[0, :5]}")
print(f"ONNX output: {onnx_output[0, :5]}")

上述代码中,首先加载预训练的 resnet18 模型,并定义了输入和输出张量的名称和形状。

然后,使用 torch.onnx.export() 函数将模型转换为 ONNX 格式,并保存为 resnet18.onnx 文件。

接着,使用 onnxruntime.InferenceSession() 函数加载 ONNX 模型,并使用随机张量进行测试。

最后,将 PyTorch 模型和 ONNX 模型的输出进行比较,以确保它们具有相似的输出。

三、可视化ONNX模型

使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署。

这里简单介绍一下

Netron是一个轻量级、跨平台的模型可视化工具,支持多种深度学习框架的模型可视化,包括TensorFlow、PyTorch、ONNX、Keras、Caffe等等。它提供了可视化网络结构、层次关系、输出尺寸、权重等信息,并且可以通过鼠标移动和缩放来浏览模型。Netron还支持模型的导出和导入,方便模型的分享和交流。

 Netron的网页在线版本,直接在网页中打开和查看ONNX模型Netron

开源地址:GitHub - lutzroeder/netron: Visualizer for neural network, deep learning, and machine learning models

支持多种操作系统:

macOS: Download 

Linux: Download 

Windows: Download 

Browser: Start 

Python Server: Run pip install netron and netron [FILE] or netron.start('[FILE]').

 

下面是可视化模型截图:

还能查看某个节点(运算操作)的信息,比如下面MaxPool,点击一下,能看到使用的3x3的池化核,是否有填充pads,步长strides等参数。

 

分享完毕~

 

相关文章:

模型转换 PyTorch转ONNX 入门

前言 本文主要介绍如何将PyTorch模型转换为ONNX模型,为后面的模型部署做准备。转换后的xxx.onnx模型,进行加载和测试。最后介绍使用Netron,可视化ONNX模型,看一下网络结构;查看使用了那些算子,以便开发部署…...

【深度学习】激活函数

上一章——认识神经网络 新课P54介绍了强人工智能概念,P55到P58解读了矩阵乘法在代码中的应用,P59,P60介绍了在Tensflow中实现神经网络的代码及细节,详细的内容可以自行观看2022吴恩达机器学习Deeplearning.ai课程,专…...

【新2023】华为OD机试 - 数字的排列(Python)

华为 OD 清单查看地址:blog.csdn.net/hihell/category_12199275.html 数字的排列 题目 小华是个很有对数字很敏感的小朋友, 他觉得数字的不同排列方式有特殊的美感。 某天,小华突发奇想,如果数字多行排列, 第一行1个数, 第二行2个, 第三行3个, 即第n行n个数字,并且…...

[oeasy]python0085_ASCII之父_Bemer_COBOL_数据交换网络

编码进化 回忆上次内容 上次 回顾了 字符编码的 进化过程 IBM 在数字化过程中 作用 非常大IBM 的 BCDIC 有 黑历史 😄 6-bit的 BCDIC 直接进化成 8-bit的 EBCDIC补全了 小写字母 和 控制字符 在ibm就是信息产业的年代 ibm的标准 怎么最终 没有成为 行业的标准 呢…...

volatile,内存屏障

volatile的特性可见性: 对于其他线程是可见,假设线程1修改了volatile修饰的变量,那么线程2是可见的,并且是线程安全的重排序: 由于CPU执行的时候,指令在后面的会先执行,在指令层级的时候我们晓得volatile的特性后,我们就要去volatile是如何实现的,这个很重要!&#…...

【ESP 保姆级教程】玩转emqx MQTT篇① —— 系统主题、延迟发布、服务器配置预算、常见问题

忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2023-02-18 ❤️❤️ 本篇更新记录 2023-02-18 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...

第48讲:SQL优化之ORDER BY排序查询的优化

文章目录1.ORDEY BY排序查询优化方面的概念2.ORDER BY排序的优化原则3.ORDER BY排序优化的案例3.1.准备排序优化的表以及索引3.2.同时对nl和lxfs字段使用升序排序3.3.同时对nl和lxfs字段使用降序排序3.4.排序时调整联合索引中字段的位置顺序3.5.排序时一个字段使用升序一个字段…...

[Datawhale][CS224W]图机器学习(三)

目录一、简介与准备二、教程2.1 下载安装2.2 创建图2.2.1 常用图创建(自定义图创建)1.创建图对象2.添加图节点3.创建连接2.2.2 经典图结构1.全连接无向图2.全连接有向图3.环状图4.梯状图5.线性串珠图6.星状图7.轮辐图8.二项树2.2.3 栅格图1.二维矩形栅格…...

2023版最新最强大数据面试宝典

此套面试题来自于各大厂的真实面试题及常问的知识点,如果能理解吃透这些问题,你的大数据能力将会大大提升,进入大厂指日可待!目前已经更新到第4版,广受好评!复习大数据面试题,看这一套就够了&am…...

CSS 中的 BFC 是什么,有什么作用?

BFC,即“块级格式化上下文”(Block Formatting Context),是 CSS 中一个重要的概念,它指的是一个独立的渲染区域,让块级盒子在布局时遵循一些特定的规则。BFC 的存在使得我们可以更好地控制文档流&#xff0…...

总结在使用 Git 踩过的坑

问题一: 原因 git 有两种拉代码的方式,一个是 HTTP,另一个是 ssh。git 的 HTTP 底层是通过 curl 的。HTTP 底层基于 TCP,而 TCP 协议的实现是有缓冲区的。 所以这个报错大致意思就是说,连接已经关闭,但是此时有未处理…...

从 HTTP 到 gRPC:APISIX 中 etcd 操作的迁移之路

罗泽轩,API7.ai 技术专家/技术工程师,Apache APISIX PMC 成员。 原文链接 Apache APISIX 现有基于 HTTP 的 etcd 操作的局限性 etcd 在 2.x 版本的时候,对外暴露的是 HTTP 1 (以下简称 HTTP)的接口。etcd 升级到 3.x…...

【C语言每日一题】——倒置字符串

【C语言每日一题】——倒置字符串😎前言🙌倒置字符串🙌总结撒花💞😎博客昵称:博客小梦 😊最喜欢的座右铭:全神贯注的上吧!!! 😊作者简…...

Native扩展开发的一般流程(类似开发一个插件)

文章目录大致开发流程1、编写对应的java类服务2、将jar包放到对应位置3、配置文件中进行服务配置4、在代码中调用5、如何查看服务调用成功大致开发流程 1、编写服务,打包为jar包2、将jar包放到指定的位置3、在配置文件中进行配置,调用对应的服务 1、编…...

【新解法】华为OD机试 - 任务调度 | 备考思路,刷题要点,答疑,od Base 提供

华为 OD 清单查看地址:blog.csdn.net/hihell/category_12199275.html 任务调度 题目 现有一个 CPU 和一些任务需要处理,已提前获知每个任务的任务 ID、优先级、所需执行时间和到达时间。 CPU 同时只能运行一个任务,请编写一个任务调度程序,采用“可抢占优先权调度”调度…...

Spring3定时任务

简介 Spring 内部有一个 task 是 Spring 自带的一个设定时间自动任务调度,提供了两种方式进行配置,一种是注解的方式,而另外一种就是 XML 配置方式了;注解方式比较简洁,XML 配置方式相对而言有些繁琐,但是应用场景的不…...

数据库版本管理工具Flyway应用研究

目录1 为什么使用数据库版本控制2 数据库版本管理工具选型:Flyway、Liquibase、Bytebase、阿里 DMSFlywayLiquibaseBytebase阿里 DMS3 Flyway数据库版本管理研究3.1 参考资料3.2 Flyway概述3.3 Flyway原理3.4 Flyway版本和功能3.5 Flyway概念3.5.1 版本迁移&#xf…...

更换 Ubuntu 系统 apt 命令安装软件源

更换 Ubuntu 系统 apt 命令安装软件源清华大学开源软件镜像站 https://mirrors.tuna.tsinghua.edu.cn/ 1. Ubuntu 的软件源配置文件 /etc/apt/sources.list MIRRORS -> 使用帮助 -> ubuntu https://mirrors.tuna.tsinghua.edu.cn/help/ubuntu/ Ubuntu 系统 apt 命令安…...

2023年可见光通信(LiFi)研究新进展

可见光无线通信Light Fidelity(LiFi)又称“光保真技术”,是一种利用可见光进行数据传输的全新无线传输技术。LiFi是一种以半导体光源作为信号发射源,利用无需授权的自由光谱实现无线连接的新型无线通信技术,支持高密度…...

Greenplum的两阶段提交

注:本文章引自终于把分布式事务讲明白了! 在前面的文章中,我们了解了单机库中的事务一致性实现以及分布式事务中的两阶段提交协议。大多数分布式系统都是采用了两阶段提交塄来保证事务的原子性,Greenplum也是采用了两阶段提交&am…...

第19节 Node.js Express 框架

Express 是一个为Node.js设计的web开发框架,它基于nodejs平台。 Express 简介 Express是一个简洁而灵活的node.js Web应用框架, 提供了一系列强大特性帮助你创建各种Web应用,和丰富的HTTP工具。 使用Express可以快速地搭建一个完整功能的网站。 Expre…...

TDengine 快速体验(Docker 镜像方式)

简介 TDengine 可以通过安装包、Docker 镜像 及云服务快速体验 TDengine 的功能,本节首先介绍如何通过 Docker 快速体验 TDengine,然后介绍如何在 Docker 环境下体验 TDengine 的写入和查询功能。如果你不熟悉 Docker,请使用 安装包的方式快…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽,大家好,我是左手python! Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库,用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

如何在最短时间内提升打ctf(web)的水平?

刚刚刷完2遍 bugku 的 web 题,前来答题。 每个人对刷题理解是不同,有的人是看了writeup就等于刷了,有的人是收藏了writeup就等于刷了,有的人是跟着writeup做了一遍就等于刷了,还有的人是独立思考做了一遍就等于刷了。…...

push [特殊字符] present

push 🆚 present 前言present和dismiss特点代码演示 push和pop特点代码演示 前言 在 iOS 开发中,push 和 present 是两种不同的视图控制器切换方式,它们有着显著的区别。 present和dismiss 特点 在当前控制器上方新建视图层级需要手动调用…...

免费数学几何作图web平台

光锐软件免费数学工具,maths,数学制图,数学作图,几何作图,几何,AR开发,AR教育,增强现实,软件公司,XR,MR,VR,虚拟仿真,虚拟现实,混合现实,教育科技产品,职业模拟培训,高保真VR场景,结构互动课件,元宇宙http://xaglare.c…...

FFmpeg:Windows系统小白安装及其使用

一、安装 1.访问官网 Download FFmpeg 2.点击版本目录 3.选择版本点击安装 注意这里选择的是【release buids】,注意左上角标题 例如我安装在目录 F:\FFmpeg 4.解压 5.添加环境变量 把你解压后的bin目录(即exe所在文件夹)加入系统变量…...

论文阅读笔记——Muffin: Testing Deep Learning Libraries via Neural Architecture Fuzzing

Muffin 论文 现有方法 CRADLE 和 LEMON,依赖模型推理阶段输出进行差分测试,但在训练阶段是不可行的,因为训练阶段直到最后才有固定输出,中间过程是不断变化的。API 库覆盖低,因为各个 API 都是在各种具体场景下使用。…...

密码学基础——SM4算法

博客主页:christine-rr-CSDN博客 ​​​​专栏主页:密码学 📌 【今日更新】📌 对称密码算法——SM4 目录 一、国密SM系列算法概述 二、SM4算法 2.1算法背景 2.2算法特点 2.3 基本部件 2.3.1 S盒 2.3.2 非线性变换 ​编辑…...