通过解读yolov5_gpu_optimization学习如何使用onnx_surgon
onnx实战一: 解析yolov5 gpu的onnx优化案例:
这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的。 通过这个案例理解原版的onnx的导出流程然后我们看英伟达是怎么拿gs来优化这个onnx
原版的export_onnx函数
先看torch.onnx.export
函数的参数解释:
-
model: 要导出的PyTorch模型, 在工程中这里输入的是训练好的pt文件
-
im: 这里对应torch.onnx.export的args, 这个是用作模型输入的示例张量。这帮助ONNX确定输入的形状和类型。
-
f: 输出ONNX模型的文件名或文件对象, 用来指定导出模型的路径和文件名。
-
verbose (默认为
False
): 如果设置为True
,则会打印出模型导出时的详细日志。 -
opset_version: 导出的ONNX模型的操作集版本。不同的版本可能支持不同的操作。
-
training:
torch.onnx.TrainingMode.TRAINING
: 表示模型处于训练模式。torch.onnx.TrainingMode.EVAL
: 表示模型处于评估模式。
-
do_constant_folding (默认为
True
): 当设置为True
,导出过程中会尝试简化模型,将常量子图折叠为一个常量节点。 -
input_names: 为模型的输入提供名称, 参数规定是数组
-
output_names: 为模型的输出提供名称, 参数规定是数组
-
dynamic_axes: 为模型的输入/输出定义动态轴。对于那些维度在推理时可能会发生变化的情况(例如,批处理大小),此参数允许指定哪些轴是动态的。这里images是输入, 本来是1x3x640x640, 这里通过指定把0, 2, 3维度变成了动态轴的输入, 第二个维度是3这个还是固定的。如果使用动态, 可以输入任意数量和任意大小的图片而不是规定的单张640x640
'images'
: 对应的张量名称。0: 'batch'
: 表示第0个维度(即批处理维度)是动态的,并命名为’batch’。2: 'height'
: 表示第2个维度(即图像的高度)是动态的。3: 'width'
: 表示第3个维度(即图像的宽度)是动态的。
'output'
: 对应的张量名称。0: 'batch'
: 表示第0个维度(即批处理维度)是动态的。1: 'anchors'
: 表示第1个维度是动态的。
- dynamic (没有在给定的函数调用中明确给出,但可以从上下文推断):
True
: 如果你想让某些轴动态,你可以设置此参数为True
。False
: 表示不使用动态轴。
导出了onnx之后开始做onnxsim
-
model_onnx, check = onnxsim.simplify(...):
使用onnxsim的simplify方法简化模型。它返回简化后的onnx模型和一个布尔值check,表示简化是否成功。 -
在对动态输入的onnx导出的时候,
dynamic_input_shape=dynamic
是不够的,还要把输入给他,让onnxsim更加谨慎的优化onnx, 确保满足我们给他的输出,所以这里多了一个input_shapes={'images': list(im.shape)} if dynamic else None
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):# YOLOv5 ONNX exporttry:check_requirements(('onnx',))import onnxLOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')f = file.with_suffix('.onnx')torch.onnx.export(model,im,f,verbose=False,opset_version=opset,training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,do_constant_folding=not train,input_names=['images'],output_names=['output'],dynamic_axes={'images': {0: 'batch',2: 'height',3: 'width'}, # shape(1,3,640,640)'output': {0: 'batch',1: 'anchors'} # shape(1,25200,85)} if dynamic else None)# Checksmodel_onnx = onnx.load(f) # load onnx modelonnx.checker.check_model(model_onnx) # check onnx model# Metadatad = {'stride': int(max(model.stride)), 'names': model.names}for k, v in d.items():meta = model_onnx.metadata_props.add()meta.key, meta.value = k, str(v)onnx.save(model_onnx, f)# Simplifyif simplify:try:check_requirements(('onnx-simplifier',))import onnxsimLOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')model_onnx, check = onnxsim.simplify(model_onnx,dynamic_input_shape=dynamic,input_shapes={'images': list(im.shape)} if dynamic else None)assert check, 'assert check failed'onnx.save(model_onnx, f)except Exception as e:LOGGER.info(f'{prefix} simplifier failure: {e}')LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')return fexcept Exception as e:LOGGER.info(f'{prefix} export failure: {e}')
更改过的export_onnx函数
- 首先是把onnx的输出由一个改成了3个, 然后指定动态输出, 因为有多个输出,全部都把他们的batch, width, height指定为动态的,满足不同的输入输出。 不过这边的问题是看起来是只改了最后的输出,但是前面在yolo.py的地方已经把sigmoid后面的计算都干掉了, 因为后面的计算映射了一堆的算子导致了在计算图太冗余
这一坨全部不要了就保留sigmoid就可以了,然后就是直接硬编码t就是int32
diff --git a/models/yolo.py b/models/yolo.py
index 02660e6..c810745 100644
--- a/models/yolo.py
+++ b/models/yolo.py
@@ -55,29 +55,15 @@ class Detect(nn.Module):z = [] # inference outputfor i in range(self.nl):x[i] = self.m[i](x[i]) # conv
- bs, _, ny, nx = x[i].shape # x(bs,255,20,20) to x(bs,3,20,20,85)
- x[i] = x[i].view(bs, self.na, self.no, ny, nx).permute(0, 1, 3, 4, 2).contiguous()
-
- if not self.training: # inference
- if self.onnx_dynamic or self.grid[i].shape[2:4] != x[i].shape[2:4]:
- self.grid[i], self.anchor_grid[i] = self._make_grid(nx, ny, i)
-
- y = x[i].sigmoid()
- if self.inplace:
- y[..., 0:2] = (y[..., 0:2] * 2 + self.grid[i]) * self.stride[i] # xy
- y[..., 2:4] = (y[..., 2:4] * 2) ** 2 * self.anchor_grid[i] # wh
- else: # for YOLOv5 on AWS Inferentia https://github.com/ultralytics/yolov5/pull/2953
- xy, wh, conf = y.split((2, 2, self.nc + 1), 4) # y.tensor_split((2, 4, 5), 4) # torch 1.8.0
- xy = (xy * 2 + self.grid[i]) * self.stride[i] # xy
- wh = (wh * 2) ** 2 * self.anchor_grid[i] # wh
- y = torch.cat((xy, wh, conf), 4)
- z.append(y.view(bs, -1, self.no))
-
- return x if self.training else (torch.cat(z, 1),) if self.export else (torch.cat(z, 1), x)
+ y = x[i].sigmoid()
+ z.append(y)
+ return zdef _make_grid(self, nx=20, ny=20, i=0):d = self.anchors[i].device
- t = self.anchors[i].dtype
+ # t = self.anchors[i].dtype
+ # TODO(tylerz) hard-code data type to int
+ t = torch.int32shape = 1, self.na, ny, nx, 2 # grid shapey, x = torch.arange(ny, device=d, dtype=t), torch.arange(nx, device=d, dtype=t)if check_version(torch.__version__, '1.10.0'): # torch>=1.10.0 meshgrid workaround for torch>=0.7 compatibility
--
2.36.0
-
onnxsim这里跟之前是一样的, 也是直接onnxsim, 如果动态的要给输出给onnxsim然后让它更加的谨慎,满足需求
-
这后面的重点是增加了用onnx-surgon来更改onnx, 先把整个onnx导入进来,然后使用然后用Variable做模型的输出, 这里做四个模型输出, 分别是
DecodeNumDetection
,DecodeDetectionBoxes
,DecodeDetectionScores
,DecodeDetectionClasses
-
然后设置一个attrs, gs设置的attrs使用字典的格式弄的。这里设置max_stride, num_classes, anchors, prenms_score_threshold四个属性,这些属性的操作会在TensorRT中实现的
-
decode_plugin是中间的节点,这个节点上面是inputs, 下面是四个不同的decodes, 这里就是把这个nodes做出来了
-
然后在整体的网络上添加这个节点,然后再把输出改成这个节点的输出保持一致,在计算图中把其他的节点claenup()清洁掉, 最后导出
def export_onnx(model, im, file, opset, train, dynamic, simplify, prefix=colorstr('ONNX:')):# YOLOv5 ONNX export# try:check_requirements(('onnx',))import onnxLOGGER.info(f'\n{prefix} starting export with onnx {onnx.__version__}...')f = file.with_suffix('.onnx')print(train)torch.onnx.export(model,im,f,verbose=False,opset_version=opset,training=torch.onnx.TrainingMode.TRAINING if train else torch.onnx.TrainingMode.EVAL,do_constant_folding=not train,input_names=['images'],output_names=['p3', 'p4', 'p5'],dynamic_axes={'images': {0: 'batch',2: 'height',3: 'width'}, # shape(1,3,640,640)'p3': {0: 'batch',2: 'height',3: 'width'}, # shape(1,25200,4)'p4': {0: 'batch',2: 'height',3: 'width'},'p5': {0: 'batch',2: 'height',3: 'width'}} if dynamic else None)# Checksmodel_onnx = onnx.load(f) # load onnx modelonnx.checker.check_model(model_onnx) # check onnx model# Simplifyif simplify:# try:check_requirements(('onnx-simplifier',))import onnxsimLOGGER.info(f'{prefix} simplifying with onnx-simplifier {onnxsim.__version__}...')model_onnx, check = onnxsim.simplify(model_onnx,dynamic_input_shape=dynamic,input_shapes={'images': list(im.shape)} if dynamic else None)assert check, 'assert check failed'onnx.save(model_onnx, f)# except Exception as e:# LOGGER.info(f'{prefix} simplifier failure: {e}')# add yolov5_decoding:import onnx_graphsurgeon as onnx_gsimport numpy as npyolo_graph = onnx_gs.import_onnx(model_onnx)p3 = yolo_graph.outputs[0]p4 = yolo_graph.outputs[1]p5 = yolo_graph.outputs[2]decode_out_0 = onnx_gs.Variable("DecodeNumDetection",dtype=np.int32)decode_out_1 = onnx_gs.Variable("DecodeDetectionBoxes",dtype=np.float32)decode_out_2 = onnx_gs.Variable("DecodeDetectionScores",dtype=np.float32)decode_out_3 = onnx_gs.Variable("DecodeDetectionClasses",dtype=np.int32)decode_attrs = dict()decode_attrs["max_stride"] = int(max(model.stride))decode_attrs["num_classes"] = model.model[-1].ncdecode_attrs["anchors"] = [float(v) for v in [10,13, 16,30, 33,23, 30,61, 62,45, 59,119, 116,90, 156,198, 373,326]]decode_attrs["prenms_score_threshold"] = 0.25decode_plugin = onnx_gs.Node(op="YoloLayer_TRT",name="YoloLayer",inputs=[p3, p4, p5],outputs=[decode_out_0, decode_out_1, decode_out_2, decode_out_3],attrs=decode_attrs)yolo_graph.nodes.append(decode_plugin)yolo_graph.outputs = decode_plugin.outputsyolo_graph.cleanup().toposort()model_onnx = onnx_gs.export_onnx(yolo_graph)d = {'stride': int(max(model.stride)), 'names': model.names}for k, v in d.items():meta = model_onnx.metadata_props.add()meta.key, meta.value = k, str(v)onnx.save(model_onnx, f)LOGGER.info(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB)')return f# except Exception as e:# LOGGER.info(f'{prefix} export failure: {e}')
相关文章:
通过解读yolov5_gpu_optimization学习如何使用onnx_surgon
onnx实战一: 解析yolov5 gpu的onnx优化案例: 这是一个英伟达的仓库, 这个仓库的做法就是通过用gs对onnx进行修改减少算子然后最后使用TensorRT插件实现算子, 左边是优化过的, 右边是原版的。 通过这个案例理解原版的onnx的导出流程然后我们看英伟达是怎么拿gs来优化…...
图像复原与重建,解决噪声的几种空间域复原方法(数字图像处理概念 P4)
文章目录 图像复原模型噪声模型只存在噪声的空间域复原 图像复原模型 噪声模型 只存在噪声的空间域复原...
Android 启动优化案例:WebView非预期初始化排查
去年年底做启动优化时,有个比较好玩的 case 给大家分享下,希望大家能从我的分享里 get 到我在做一些问题排查修复时是怎么看上去又low又土又高效的。 1. 现象 在我们使用 Perfetto 进行app 启动过程性能观测时,在 UI 线程发现了一段 几十毫…...
20230919后台面经整理
1.你认为什么是操作系统,操作系统有哪些功能 os是:管理资源、向用户提供服务、硬件机器的扩展 1.进程线程管理:状态、控制、通信等 2.存储管理:分配回收、地址转换 3.文件管理:目录、操作、磁盘、存取 4.设备管理&…...
画一个时钟(html+css+js)
这是一个很简约的时钟。。。。。。。 效果: 代码: <template><div class"demo-box"><div class"clock"><ul class"mark"><liv-for"(rotate, index) in rotatedAngles":key"i…...
红 黑 树
文章目录 一、红黑树的概念二、红黑树的实现1. 红黑树的存储结构2. 红黑树的插入 一、红黑树的概念 在 AVL 树中删除一个结点,旋转可能要持续到根结点,此时效率较低 红黑树也是一种二叉搜索树,通过在每个结点中增加一个位置来存储红色或黑色…...
掷骰子的多线程应用程序1(复现《Qt C++6.0》)
说明:复现的代码来自《Qt C6.0》P496-P500。在复现时完全按照代码,出现了两处报错: (1)ui指针(2)按钮的响应函数。下面程序对以上问题进行了修改。除了图片、清空、关闭功能外,其他…...
【vue2第十八章】VueRouter 路由嵌套 与 keep-alive缓存组件(activated,deactivated)
VueRouter 路由嵌套 在使用vue开发中,可能会碰到使用多层级别的路由。比如: 其中就包含了两个主要页面,首页,详情,但是首页的下面又包含了列表,喜欢,收藏,我的四个子路由。 此时就…...
如何确保亚马逊、速卖通等平台测评补单的环境稳定性和安全性?
做亚马逊、速卖通等平台测评补单时,确保环境的安全性和稳定性是非常重要的。稳定的环境是测评的基础,如果无法解决安全性问题,那么测评就不值得进行。为了确保稳定的环境系统,需要考虑物理环境和IP环境两个方面。 首先࿰…...
echarts图表 实现高度按照 内容撑起来或者超出部分滚动展示效果
背景:因为数据不固定 高度写死导致数据显示不全,所以图表高度要根据内容计算 实现代码如下: <divv-if"showCharts"id"business-bars"class"chart":style"{ height: chartHeight px }"></d…...
【论文阅读】检索增强发展历程及相关文章总结
文章目录 前言Knn-LMInsightMethodResultsDomain AdaptionTuning Nearest Neighbor Search Analysis REALMInsightsMethodKnowledge RetrieverKnowledge-Augmented Encoder ExpResultAblation StudyCase Study DPRInsightMethodExperimentsResults RAGInsightRAG-Sequence Mode…...
【漏洞复现系列】二、weblogic-cve_2020_2883(RCE/反序列化)
Key words:T3协议,weblogic Server,反序列化 2.1、漏洞原理 cve_2020_2883 远程代码执行漏洞存在于 WebLogic Server 核心组件中,允许未经身份验证的攻击者通过 T3 协议网络访问并破坏易受攻击的 WebLogic Server,成功的漏洞利…...
算法通关村-----LRU的设计与实现
LRU 缓存 问题描述 请你设计并实现一个满足 LRU (最近最少使用) 缓存 约束的数据结构。 实现 LRUCache 类: LRUCache(int capacity) 以 正整数 作为容量 capacity 初始化 LRU 缓存。int get(int key) 如果关键字 key 存在于缓存中,则返回关键字的值&…...
王江涛十天搞定考研词汇
学习目标: 考研词汇 学习内容: 2023-9-17 第一天考研词汇 学习时间: 开始:2023-9-17 结束:进行中 学习产出: 2023-9-17intellect智力;知识分子intellectual智力的;聪明的intellectualize使...理智化&a…...
算法(二)——数组章节和链表章节
数组章节 (1)二分查找 给定一个 n 个元素有序的(升序)整型数组 nums 和一个目标值 target ,写一个函数搜索 nums 中的 target,如果目标值存在返回下标,否则返回 -1。 class Solution {public i…...
Android:ListView在Fragment中的使用
一、前言: 因为工作一直在用mvvm框架,因此这篇文章是基于mvvm框架写的。在Fragment复制之前一定要谨记项目可以跑起来。确保能跑起来之后直接复制就行。 二、代码展示: 页面布局 ?xml version"1.0" encoding"utf-8"…...
BIGEMAP在土地规划中的应用
工具 Bigemap gis office地图软件 BIGEMAP GIS Office-全能版 Bigemap APP_卫星地图APP_高清卫星地图APP 1.使用软件一般都用于套坐标,比如我们常见的(kml shp CAD等土建规划图纸)以及一些项目厂区红线,方便于项目选址和居民建…...
软件测试常见术语和名词解释
1. Unit testing (单元测试):指一段代码的基本测试,其实际大小是未定的,通常是一个函数或子程序,一般由开发者执行。 2. Integration testing (集成测试):被测试系统的所有组件都集成在一起,找出被测试系统…...
prometheus+process_exporter进程监控
一、需要监控进程的服务器上配置 1、进入到临时工作目录,传入process_exporter包 [root Nginx1 ~]# cd work/ [root Nginx1 work]# rz 2、解压,并移动至/usr/local/目录下 [root Nginx1 work]# tar xzf process-exporter-0.7.5.linux-amd64.tar.gz [root…...
四川玖璨电子商务有限公司专注抖音电商运营
四川玖璨电商是一个靠谱的抖音培训公司,在电商行业内有着广泛的知名度和良好的口碑。该公司通过多年的发展,形成了独特的运营理念和有效的运营策略,为商家提供了一站式的抖音电商运营服务。 首先,四川玖璨电子商务有限公司注重与…...
python LeetCode 刷题记录 83
题目 给定一个已排序的链表的头 head , 删除所有重复的元素,使每个元素只出现一次 。返回 已排序的链表 。 代码 class Solution:def deleteDuplicates(self, head: Optional[ListNode]) -> Optional[ListNode]:if head:# 判断非空链表current he…...
Grom 如何解决 SQL 注入问题
什么是 SQL 注入 SQL 注入是一种常见的数据库攻击手段, SQL 注入漏洞也是网络世界中最普遍的漏洞之一。 SQL 注入就是恶意用户通过在表单中填写包含 SQL 关键字的数据来使数据库执行非常规代码的过程。 这个问题的来源就是, SQL 数据库的操作是通过 SQ…...
腾讯mini项目-【指标监控服务重构】2023-07-19
今日已办 OpenTelemetry Logs 通过日志记录 API 支持日志收集 集成现有的日志记录库和日志收集工具 Overview 日志记录 API - Logging API,允许您检测应用程序并生成结构化日志旨在与其他 telemerty data(例如metric和trace)配合使用&am…...
抖音矩阵系统源代码开发部署--SaaS开源技术开发文档
一、概述 抖音SEO矩阵系统源代码是一套针对抖音平台的搜索引擎优化工具,它可以帮助用户提高抖音视频在搜索结果中的排名,增加曝光率和流量。本开发文档旨在提供系统的功能框架、技术要求和开发示例,以便开发者进行二次开发和优化。 二、功能…...
CLIP模型资料学习
clip资料 links https://blog.csdn.net/wzk4869/article/details/129680734?ops_request_misc&request_id&biz_id102&utm_termCLIP&utm_mediumdistribute.pc_search_result.none-task-blog-2allsobaiduweb~default-4-129680734.142v94insert_down1&spm10…...
【c语言】贪吃蛇
当我们不想学习新知识的时候,并且特别无聊,就会突然先看看别人怎么写游戏的,今天给大家分享的是贪吃蛇,所需要的知识有结构体,枚举,以及easy-x图形库的一些基本函数就完全够用了,本来我想插入游…...
【Node.js】定时任务cron:
文章目录 一、文档:【Nodejs 插件】 二、安装与使用【1】安装【2】使用 三、cron表达式:{秒数} {分钟} {小时} {日期} {月份} {星期} {年份(可为空)}四、案例: 一、文档: 【说明文档】https://www.npmjs.com/package/cron 【Cron表…...
vue3 引入element-plus
1.首先安装element-plus npm install element-plus 2.main.js配置 ... import ElementPlus from element-plus import element-plus/theme-chalk/index.css; //导入图标 import * as ELementPlusIconsVue from "element-plus/icons-vue" ... app.use(ElementPlus) /…...
数据通信——传输层TCP(超时时间选择)
引言 TCP每一次发送报文段,就会对这个报文段设置一次计时器。如果时间到了却没有收到确认报文,那么就要重传该报文。 这个之前在TCP传输的机制中提到过,这个章节就来研究一下超时时间问题。 关于加权的概念 有必要提及一下加权的概念&#x…...
【数据库索引优化】
文章目录 数据库索引优化1. 选择合适的字段创建索引2. 限值每张表上的索引数量3. 被频繁更新的字段应该慎重建立索引4. 尽可能考虑简历联合索引而不是单列索引5. 避免冗余索引6. 字符串类型的字段使用前缀索引代替普通索引7. 避免索引失效8. 删除长期未使用的索引 数据库索引优…...
幼儿园网站建设工作总结/搜索app下载安装
介绍一款新的绘图神器:sviewgui。sviewgui介绍sviewgui是一个基于 PyQt 的 GUI,用于 csv 文件或 Pandas 的 DataFrame 的数据可视化。此 GUI 基于 matplotlib,您可以通过多种方式可视化您的 csv 文件。主要特点:Ⅰ 散点图、线图、…...
做网站需要代码么/网络小说排行榜
歌曲:无间道(电影无间道国语主题曲)歌手:梁朝伟&刘德华 专辑:风沙 我不愿意结束我还没有结束无止境的旅途看着我没停下的脚步已经忘了身在何处谁能改变人生的长度谁知道永恒有多么恐怖谁了解生存往往比命运还残酷只是没有人愿意认输我们都…...
做网站需要ftp/网络营销模式
一扫二拖,早上扫完下午脏,来回拖完地玩还得清洁拖把,心理暗悄悄想着怎么逃避这一现实,然鹅,第二天继续死循环……扫地机器人的出现可以说是一大救星,不用动手,不用弯腰,还可以坐在沙发吃着零食看…...
石家庄市市政建设总公司网站/百度百度一下官网
DML 添加数据 insert into 表名 where 条件 修改数据 update 表名 where 条件 删除数据 delete from where 条件 DQL 查询语句 select 【distinct去重】 字段列表 as ‘别名’ from 表名列表 where 条件列表 group by 分组字段列表 having 分组后条件列表 order by 排序字段…...
网站建设 短信群发模板/郑州网站建设哪家好
system.mss ----> Peripheral Drivers -----> Import Examples 也可以通过边上的文档 查看函数信息...
网站整体风格设计/一个关键词要刷多久
工作之余抽点时间出来写写博文,希望对新接触的朋友有帮助。今天在这里和大家一同学习一下引用指针 函数是C/C程序的基本功能单元,其重要性不言而喻。函数设计的纤细缺点很容易致使该函数被错用,所以光使函数的功能正确是不敷的。本章重点论述…...