当前位置: 首页 > news >正文

30分钟吃掉 Pytorch 转 onnx

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学.

针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。

汇总合集:

《大模型面试宝典》(2024版) 发布!

圈粉无数!《PyTorch 实战宝典》火了!!!


PyTorch 是一个用于机器学习的开源深度学习框架,而ONNX(Open Neural Network Exchange)是一个用于表示深度学习模型的开放式格式。

将 PyTorch 模型转换为ONNX格式有几个原因和优势:

  1. 跨平台部署: ONNX是一个跨平台的格式,支持多种深度学习框架,包括PyTorch、TensorFlow等。将模型转换为ONNX格式可以使模型在不同框架和设备上进行部署和运行。

  2. 性能优化: ONNX格式可以在不同框架之间实现性能优化。例如,可以在PyTorch中训练模型,然后转换为ONNX格式,并在性能更高的框架(如TensorRT)中进行推理。

  3. 模型压缩: ONNX格式可以实现模型的压缩和优化,从而减小模型的体积并提高推理速度。这对于在资源受限的设备上部署模型尤为重要。

pytorch 模型线上部署最常见的方式是转换成onnx,然后再转成tensorRT 在cuda上进行部署推理。

技术交流群

前沿技术资讯、算法交流、求职内推、算法竞赛、面试交流(校招、社招、实习)等、与 10000+来自港科大、北大、清华、中科院、CMU、腾讯、百度等名校名企开发者互动交流~

我们建了Pytorch 技术与面试交流群, 想要获取最新面试题、了解最新面试动态的、需要源码&资料、提升技术的同学,可以直接加微信号:mlc2040。加的时候备注一下:研究方向 +学校/公司+CSDN,即可。然后就可以拉你进群了。

方式①、微信搜索公众号:机器学习社区,后台回复:加群
方式②、添加微信号:mlc2040,备注:技术交流

本文介绍将pytorch模型转换成onnx模型并进行推理的方法。

#!pip install onnx 
#!pip install onnxruntime
#!pip install torchvision

一,准备pytorch模型

我们先导入torchvision中的resnet18模型,演示它的推理效果。

以便和onnx的结果进行对比。

import torch
import torchvision.models as models
import numpy as np
import torchvision
import torchvision.transforms as Tfrom PIL import Imagedef create_net():net = models.resnet18(weights=torchvision.models.ResNet18_Weights.IMAGENET1K_V1)return net net = create_net()torch.save(net.state_dict(),'resnet18.pt')
net.eval();
def get_test_transform():return T.Compose([T.Resize([320, 320]),T.ToTensor(),T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])image = Image.open("dog.png") # 289
img = get_test_transform()(image)
img = img.unsqueeze_(0) 
output = net(img)
score, indice = torch.max(torch.softmax(output,axis=-1),1)
info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}def show_image(image, title):import matplotlib.pyplot as plt ax=plt.subplot()ax.imshow(image)ax.set_title(title)ax.set_xticks([])ax.set_yticks([]) plt.show()show_image(image, title = info)

图片

二,pytorch模型转换成onnx模型

1, 简化版本

import onnxruntime
import onnxbatch_size = 1  
input_shape = (3, 320, 320)   x = torch.randn(batch_size, *input_shape)
onnx_file = "resnet18.onnx"
torch.onnx.export(net,x,onnx_file,opset_version=10,do_constant_folding=True,  # 是否执行常量折叠优化input_names=["input"],output_names=["output"],dynamic_axes={"input":{0:"batch_size"},  "output":{0:"batch_size"}})
!du -s -h resnet18.pt
 45M	resnet18.pt
!du -s -h resnet18.onnx 
 45M	resnet18.onnx

可以在 https://netron.app/ 中拖入 resnet18.onnx 文件查看模型结构

2,全面版本

下面的代码包括了设置输入输出尺寸,以及动态可以变batch等等。

import argparse
from argparse import Namespace
import time
import sys
import os
import torch
import torch.nn as nn
import torchvision.models as models
import onnx
import onnxruntimefrom io import BytesIOROOT = os.getcwd()
if str(ROOT) not in sys.path:sys.path.append(str(ROOT))params = Namespace(weights='resnet18.pt',img_size=[320,320],batch_size=1,half=False,dynamic_batch=True)parser = argparse.ArgumentParser()
parser.add_argument('--weights', type=str, default='checkpoint.pt', help='weights path')
parser.add_argument('--img-size', nargs='+', type=int, default=[320, 320], help='image size')  # height, width
parser.add_argument('--batch-size', type=int, default=1, help='batch size')
parser.add_argument('--half', action='store_true', help='FP16 half-precision export')
parser.add_argument('--inplace', action='store_true', help='set Detect() inplace=True')
parser.add_argument('--simplify', action='store_true', help='simplify onnx model')
parser.add_argument('--dynamic-batch', action='store_true', help='export dynamic batch onnx model')
parser.add_argument('--trt-version', type=int, default=8, help='tensorrt version')
parser.add_argument('--device', default='cpu', help='cuda device, i.e. 0 or 0,1,2,3 or cpu')args = parser.parse_args(args='',namespace=params)args.img_size *= 2 if len(args.img_size) == 1 else 1  # expand
print(args)t = time.time()# Check device
cuda = args.device != 'cpu' and torch.cuda.is_available()
device = torch.device(f'cuda:{args.device}' if cuda else 'cpu')
assert not (device.type == 'cpu' and args.half), '--half only compatible with GPU export, i.e. use --device 0'# Load PyTorch model
model = create_net()
model.to(device)
model.load_state_dict(torch.load(args.weights)) # pytorch模型加载# Input
img = torch.zeros(args.batch_size, 3, *args.img_size).to(device)  # image size(1,3,320,192) iDetection# Update model
if args.half:img, model = img.half(), model.half()  # to FP16
model.eval()prediction = model(img)  # dry run# ONNX export
print('\nStarting to export ONNX...')
export_file = args.weights.replace('.pt', '.onnx')  # filename
with BytesIO() as f:dynamic_axes = {"input":{0:"batch_size"}, "output":{0:"batch_size"} } if args.dynamic_batch else Nonetorch.onnx.export(model, img, f, verbose=False, opset_version=13,training=torch.onnx.TrainingMode.EVAL,do_constant_folding=True,input_names=['input'],output_names=['output'],dynamic_axes=dynamic_axes)f.seek(0)# Checksonnx_model = onnx.load(f)  # load onnx modelonnx.checker.check_model(onnx_model)  # check onnx modelif args.simplify:try:import onnxsimprint('\nStarting to simplify ONNX...')onnx_model, check = onnxsim.simplify(onnx_model)assert check, 'assert check failed'except Exception as e:print(f'Simplifier failure: {e}')onnx.save(onnx_model, export_file)print(f'ONNX export success, saved as {export_file}')# Finish
print('\nExport complete (%.2fs)' % (time.time() - t))
Namespace(weights='resnet18.pt', img_size=[320, 320], batch_size=1, half=False, dynamic_batch=True, inplace=False, simplify=False, trt_version=8, device='cpu')Starting to export ONNX...
ONNX export success, saved as resnet18.onnxExport complete (0.57s)

三,使用onnx模型进行推理

1,函数风格

onnx_sesstion = onnxruntime.InferenceSession(export_file)
def pipe(img_path,onnx_sesstion = onnx_sesstion):image = Image.open(img_path) img = get_test_transform()(image)img = img.unsqueeze_(0) to_numpy = lambda tensor: tensor.data.cpu().numpy()inputs = {onnx_sesstion.get_inputs()[0].name: to_numpy(img)}outs = onnx_sesstion.run(None, inputs)[0]score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)info = {'score':score.tolist()[0],'indice':indice.tolist()[0]}return info
img_path = 'dog.png'image = Image.open(img_path)info = pipe(img_path)show_image(image,info)

图片

2,对象风格

import os, sysimport onnxruntime
import onnxclass ONNXModel():def __init__(self, onnx_path):self.onnx_session = onnxruntime.InferenceSession(onnx_path)self.input_names = [node.name for node in self.onnx_session.get_inputs()]self.output_names = [node.name for node in self.onnx_session.get_outputs()]print("input_name:{}".format(self.input_names))print("output_name:{}".format(self.output_names))def forward(self, x):if isinstance(x,np.ndarray):assert len(self.input_names)==1input_feed = {self.input_names[0]:x}elif isinstance(x,(tuple,list)):assert len(self.input_names)==len(x)input_feed = {k:v for k,v in zip(self.input_names,x)}else:assert isinstance(x,dict)input_feed = xouts = self.onnx_session.run(self.output_names, input_feed=input_feed)return outsdef predict(self,img_path):image = Image.open(img_path) img = get_test_transform()(image)img = img.unsqueeze_(0) to_numpy = lambda tensor: tensor.data.cpu().numpy()outs = self.forward(to_numpy(img))[0]score, indice = torch.max(torch.softmax(torch.as_tensor(outs),axis=-1),1)return {'score':score[0].data.numpy().tolist(),'indice':indice[0].data.numpy().tolist()}
onnx_model = ONNXModel(export_file)
info = onnx_model.predict(img_path)
show_image(image, title = info)
input_name:['input']
output_name:['output']

图片

相关文章:

30分钟吃掉 Pytorch 转 onnx

节前,我们星球组织了一场算法岗技术&面试讨论会,邀请了一些互联网大厂朋友、参加社招和校招面试的同学. 针对算法岗技术趋势、大模型落地项目经验分享、新手如何入门算法岗、该如何准备、面试常考点分享等热门话题进行了深入的讨论。 汇总合集&…...

KEIL5如何打开KEIL4的GD工程

GD官方提供的很多KEIL例程为KIEL4的版本,读者使用的时候可能会碰到使用KEIL5打开KEIL4的工程会报错以及无法找到芯片选型的问题,具体表现如下图所示。 我们该怎么办呢? 下面为大家介绍两种方法: 第一种方法是在keil4的工程后缀u…...

大前端技术分类

1 基础 2 语言 3 类库 4 框架 5 跨栈 6 架构 7 领域 7.1 中后台 7.2 跨平台 7.3 可视化 7.4 智能化 7.5 工程化 7.5.1 规范化 7.5.2 流程化 —— 前端工程化工具系列 7.5.3 模板化 7.5.4 自动化 7.5.5 平台化 7.6 其他 7.6.1 音视频 7.6.2 Web3 7.6.3 区块…...

Android AAudio——C API控制音频流(四)

上一篇文章我们介绍了 C API 中音频流的创建流程,以及打开音频流操作,这里我们再来看一下音频流的其他操作流程 一、音频流操作介绍 1、操作流程图 下图是状态变化流程图,虚线框表示瞬时状态,实线框表示稳定状态。 2、操作函数 上图中主要包含下面几个操作函数: aaudio…...

万能嗅探:视频号下载神器

万能嗅探是一款比较好用资源嗅探软件,界面干净,可以抓取浏览器的网页,不过想必各位主要用来抓取视频号,下面是使用方法。 使用方法 打开万能嗅探客户端,然后打开浏览器,产生网络请求即可,看看…...

python数据分析-ZET财务数据分析

一、公司背景 中兴通讯股份有限公司是一家总部位于中国深圳的跨国公司,致力于为全球客户提供通信设备和解决方案。公司成立于1985年,自成立以来一直致力于为客户提供创新的通信技术和服务。中兴通讯的业务涵盖多个领域,包括但不限于高端路由…...

Leetcode学习

回文数 反转一半数字 第一个想法是将数字转换为字符串,并检查字符串是否为回文。 但是,这需要额外的非常量空间来创建问题描述中所不允许的字符串。 第二个想法是将数字本身反转,然后将反转的数字与原始数字比较,如果它们是相同…...

python 列出面板数据所有变量名

在Python中,处理面板数据(Panel Data)通常使用pandas库,特别是当数据以DataFrame或Panel(尽管Panel在较新版本的pandas中已被弃用)的形式存在时。然而,由于Panel的弃用,现代做法通常…...

知乎网站只让知乎用户看文章,普通人看不了

知乎默认不显示全部文章,需要点击展开阅读全文 然而点击后却要登录,这意味着普通人看不了博主写的文章,只有成为知乎用户才有权力查看文章。我想这不是知乎创作者希望的情况,他们写文章肯定是希望所有人都能看到。 这个网站篡改…...

web前端的实习记录:探索、挑战与成长

web前端的实习记录:探索、挑战与成长 踏入web前端实习的旅程,我怀揣着对未知的好奇与对技术的渴望,开始了一段全新的学习与实践。在这个过程中,我经历了四个方面的技术探索,五个方面的挑战应对,六个方面的…...

正则表达式的详解带你认识正则表达式的意义

前言 ​ 我们都知道协议通常通过添加固定的字符、报头、特定的数字等来定义数据的结构和格式。将正确的信息提取出来是十分重要的,而正则表达式可以用来描述和匹配这些固定的结构,从而提取出所需的信息。并且正则表达式还可以处理大量复杂的字符串。这篇…...

中国现在最厉害的书法家颜廷利:东方伟大思想家哲学家教育家

中国书法界名人颜廷利教授,一位在21世纪东方哲学、科学界及当代中国教育领域内具有深远影响力的泰斗级人物,不仅以其深厚的国学修为和对易经姓名学的独到见解著称,还因其选择在济南市历城区的龙泉大街以及天桥区的凤凰山庄与泉星小区等地设立…...

OS常用操作

目录 1 文件和目录操作 1. 1 创建目录 1.2 删除目录 1.3 列出目录内容 1.4 删除文件 1.5 打开和关闭文件描述符 1.6 修改文件权限 1.7 获取和设置文件属性 2 路径操作 2.1 获取当前工作目录 2.2 改变工作目录 2.3 路径操作 2.4 添加 Python 的模块搜索路径列表 3 …...

【IC验证】03 UVM

...

Jira的原理及应用详解(六)

本系列文章简介: 在当今快速发展的软件开发和项目管理领域,有效的团队协作和精确的项目进度追踪是确保项目成功的关键。Jira作为一款广受欢迎的项目和问题追踪工具,以其强大的功能、灵活的定制性以及卓越的用户体验,赢得了全球众多…...

Linux进程间通信之System V

目录 认识system V: system V共享内存: 共享内存的基本原理: 共享内存的数据结构: 共享内存的建立与释放: 共享内存的建立: 共享内存的释放: 共享内存的关联: 共享内存的去关联…...

力扣hot100:394. 字符串解码(递归/括号匹配,字符串之间相对顺序)

LeetCode:394. 字符串解码 本题容易想到用递归处理,在写递归时主要是需要明确自己的递归函数的定义。 不过我们也可以利用括号匹配的方式使用栈进行处理。 1、递归 定义递归函数string GetString(string & s,int & i); 表示处理处理整个numbe…...

【C++11】多线程常用知识

知识体系 thread C++ thread中最常用的两个函数是join和detach,怎么选择呢,简单来说,如果希望等待线程结束,用join,如果希望异步执行,且不等待执行结果,那么就用detach;thread_local可以简单理解为一个线程级别的全局变量;线程id在调试多线程程序时是非常有用的东西;…...

详解linux设备下的/dev/null

/dev/zero是一个特殊的设备文件,它在Linux系统中通常被用来生成无限数量的零数据流。 这个设备文件位于/dev目录下,它不代表任何实际的硬件设备,而是一个虚拟设备。 当从/dev/zero设备中读取数据时,会得到无限数量的零字节&…...

GPT-4 Turbo 和 GPT-4 的区别

引言 人工智能(AI)领域的发展日新月异,OpenAI 的 GPT 系列模型一直是这一领域的佼佼者。GPT-4 和 GPT-4 Turbo 是目前市场上最先进的语言模型之一。本文将详细探讨 GPT-4 和 GPT-4 Turbo 之间的区别,以帮助用户更好地理解和选择适…...

基于小波多分辨分析的一维时间序列信号趋势检测与去除(MATLAB R2018a)

小波最开始是数学上提出的概念,并且在纯数学的王国里存在了一个世纪之久。最开始是为了弥补傅里叶分析的缺陷,即傅里叶级数发散的问题,并寻找出能够代替傅里叶分析的方法。从最早的一些艰难的探索开始直到慢慢发展成为一套完整系统的小波分析…...

Linux RedHat7.6操作系统的xfs格式化后,mount不生效

Linux RedHat7.6操作系统的xfs格式化后,mount不生效 问题现象 最近在准备测试环境的过程中,当对xfs文件系统格式化后,mount磁盘,通过df -h命令查看,未显示挂载磁盘信息 [rootZHZXLxjspo0db003 ~]# mount /dev/datavg/datavg-lv_data /data…...

高并发ping多台主机IP

简介 社区或者是大型公司往往有成千上万或者几百台设备,保持设备始终在线对网络运维人员来说至关重要,然而一个一个登录检查,或者一个一个ping并不明智,累人且效率极低,并出错率高。花钱买检测服务当我没说。 shell编…...

03 Linux 内核数据结构

Linux kernel 有四种重要的数据结构:链表、队列、映射、二叉树。普通驱动开发者只需要掌握链表和队列即可。 链表和队列 Linux 内核都有完整的实现,我们不需要深究其实现原理,只需要会使用 API 接口即可。 1、链表 链表是 Linux 内核中最简单、最普通的数据结构。链表是一…...

关于软件调用独显配置指引【笔记】

关于笔记本电脑不支持独显直连的,bios下也是没有切换独显直连的选项的,处理方法 简单的来说按照图片指引可配置让软件调用独显: 1、进入系统→屏幕→显示卡界面; 2、【添加应用】浏览需要调用独显的软件安装目录,并打开…...

正大国际期货:什么是主力合约?

一个期货品种,在同一时间段,会上市多个月份的合约, 由于主力合约交易量大,流动性高,一般建议新手交易主力合约。 主力合约通常指交易集中,流动性好的合约 ,即在一段时间内交易量和持仓量最大的…...

codeforces round 949 div2

A Turtle and Piggy Are Playing a Game 题目&#xff1a; 思路&#xff1a;输出2的幂次b使得2^b为最大的不超过x的数 代码&#xff1a; #include <iostream>using namespace std;const int N 2e5 10;void solve() {int l, r;cin >> l >> r;if(r % 2) …...

分享美好,高清无阻 - 直播极简联网解决方案

1、需求背景 随着移动互联网、UGC模式和直播平台的发展&#xff0c;网络直播的门槛日益降低&#xff0c;越来越多的人希望成为直播的主角。基于物联网的户外直播无线联网解决方案应运而生&#xff0c;满足直播者的需求。 户外直播无线联网解决方案提供了无处不在的直播体验&a…...

贪心算法-加油站

一、题目描述 二、解题思路 1.运动过程分析 这里需要一个油箱剩余油量的变量resGas&#xff0c;初始化resGas0&#xff1b;还需要一个标记从什么位置当做初始位置的startIdx&#xff0c;初始化startIdx0。 我们从数组下标idx0处开始向后遍历&#xff0c;初始时startIdx0&#…...

【ArcGIS微课1000例】0116:将度-分-秒值转换为十进制度值(字段计算器VBA)

相关阅读:【ArcGIS微课1000例】0087:经纬度格式转换(度分秒转度、度转度分秒) 文章目录 一、计算方法二、计算案例一、计算方法 将度分秒转换为十进制度的简单等式: DD = (Seconds/3600) + (Minutes/60) + Degrees如果角度值是负数,则转换方法不同。其中一种方法是: …...

wordpress 添加代码/软文营销的写作技巧有哪些

pageContext对象的用法 利用EL函数库在jsp页面中对字符串处理&#xff1a; 由于在JSP页面中显示数据时&#xff0c;经常需要对显示的字符串进行处理&#xff0c;SUN公司针对于一些常见处理定义了一套EL函数库供开发者使用。 准备工作&#xff1a; 1&#xff09;导入jar包&…...

20种推广方式/厦门seo网站推广

Linux下使用最广泛的C/C编译器是GCC&#xff0c;大多数的Linux发行版本都默认安装&#xff0c;不管是开发人员还是初学者&#xff0c;一般都将GCC作为Linux下首选的编译工具。本教程毫不犹豫地使用GCC来编译C程序。 首先检查系统是否安装GCC 以及make gcc -v make -v运行结果…...

天津网站建设zmad/最近最新的新闻

现在微信已经关闭了QQ号和邮箱注册的功能&#xff0c;那如何才能用邮箱登陆微信呢&#xff1f;下面小编以TOM邮箱为例&#xff0c;分享一下邮箱登录微信的具体操作~ 1、打开微信&#xff0c;点击注册&#xff0c;默认为手机号注册。输入手机号注册一个账号&#xff08;只用手机…...

什么网站可以做图赚钱/百度快照有什么用

配置文件配置属性如下&#xff1a; 启动时报&#xff1a;**java.lang.IllegalStateException: Failed to load property source from location ‘classpath:/application.yml’**异常 接着查看具体报错信息&#xff1a; 这是格式配置出错了 我的处理方式是出错的配置直接删…...

做网站选什么系统/域名注册流程

本篇教程通过PHPstudy安装Mysql数据库。什么是phpstudy&#xff1f;phpStudy是一个PHP调试环境的程序集成包。该程序包集成最新ApachePHPMySQLphpMyAdminZendOptimizer&#xff0c;一次性安装&#xff0c;无须配置即可使用&#xff0c;是非常方便、好用的PHP调试环境。该程序不…...

软件开发工具包sdk/太原seo网站优化

首先看一下epoll的几个函数的介绍。1、epoll_create函数/*** brief 该函数生成一个epoll专用的文件描述符。它其实是在内核申请一空间&#xff0c;用来存* 放你想关注的socket fd上是否发生以及发生了什么事件。* param size: size就是你在这个epoll fd上能关注的最大socket fd…...