【CV学习笔记】tensorrtx-yolov5 逐行代码解析
1、前言
TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因:
- Flexible 很容易修改模型的任意一层,删除、增加、替换等操作。
- Debuggable 可以容易获得模型中间某一层的结果
- Chance to learn 可以对模型结构有进一步的了解
尽管onnx2trt的方式目前已经在绝大部分情况下都不会出现问题,但在trtx下,我们能够掌握更底层的原理和代码,有利于我们对模型的部署以及优化。下文将会以yolov5s在trtx框架下的例子,来逐行解析是trtx是如何工作的。
TensorRTx项目链接:https://github.com/wang-xinyu/tensorrtx。
2、步骤解析
在trtx中,对一个模型加速的过程可以分为两个步骤
- 提取pytorch模型参数 wts
- 利用trt底层API搭建网络结构,并将wts中的参数填充到网络中
2.1、get_wts.py
首先需要将pytorch中的模型参数提取出来,pytorch中的模型参数是以caffe中blob的格式存在的,每个操作都有对应的名字、数据长度、数据.
for k, v in model.state_dict().items():# k-> blob的名字vr = v.reshape(-1).cpu().numpy() # vr -> 数据长度f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex()) # 将数据转化到16进制f.write('\n')
通过上get_wts.py,就可以得到包含yolov5s.pth的模型参数,打开yolov5s.wts如下图所示:
其中第一行的351为总的blob数量,第二行的model.0.conv.weight为第一个blob的名字,3456表示为该blob的数据长度,3a198000 3ca58000…为实际参数。
得到了上述的参数之后,就可以以trtx的方式进行加速了。
2.2、构造engine
在利用wts转engine的之前,需要十分清楚模型的网络结构,不太清楚的同学可以参考太阳花的小绿豆关于yolov5的网络结构图。了解完yolov5的网络结构后,就可以着手利用trt的api来搭建网络模型了。搭建模型的代码在 model.cpp中的build_det_engine函数,本文将其中的代码过程直接画到yolov5的网络结构图中了,可以直接对照代码和图来进行查看。
//yolov5_det.cpp
viod serialize_engine(...){if (is_p6) {...} else {// 以yolov5s为例engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);}// 序列化IHostMemory* serialized_engine = engine->serialize();std::ofstream p(engine_name, std::ios::binary);// 写到文件中p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());}
model.cpp
// 解析get_wts.py
static std::map<std::string, Weights> loadWeights(const std::string file) {int32_t count; // wts文件第一行,共有351个blobinput >> count;//每一行是一个blob,模型名称 + 数据长度 + 参数while (count--) {// 一个blob的参数Weights wt{ DataType::kFLOAT, nullptr, 0 };uint32_t size; //blob 数据长度std::string name; // blob 数据名字for (uint32_t x = 0, y = size; x < y; ++x) {input >> std::hex >> val[x]; // 将数据转化成十进制,并放到val中}// 每个blob名字对应一个wtweightMap[name] = wt;}
}ICudaEngine* build_det_engine(){// 初始化网络结构INetworkDefinition* network = builder->createNetworkV2(0U);// 定义模型输入ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW });// 加载pytorch模型中的参数std::map<std::string, Weights> weightMap = loadWeights(wts_name);// 逐步添加网络结构,已将代码与网络结构一一对应 ,具体过程见上图// 增加yolo后处理decode模块,使用了pluginauto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});network->markOutput(*yolo->getOutput(0)); //将plugin的输出设置为模型的最后输出(decode)#if defined(USE_FP16)// FP16config->setFlag(BuilderFlag::kFP16);#elif defined(USE_INT8)// INT8 量化std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;assert(builder->platformHasFastInt8());config->setFlag(BuilderFlag::kINT8);Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName);config->setInt8Calibrator(calibrator);#endif// 根据网络结构来生成engineICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);return engine;
}
3、plugin
本人对plugin也在学习当中,下面是我在学习trtx-yolo5代码中对plugin浅显的认知。原作者在模型后面增加了一个模型解码的plugin,用于获得每个特征层上的bbox,调用代码在model.cpp中
auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});static IPluginV2Layer* addYoLoLayer(...){// 注册一个名为 "YoloLayer_TRT"的插件,如果找不到插件,就会报错auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");// plugin的数据PluginField plugin_fields[2];int netinfo[5] = {kNumClass, kInputW, kInputH, kMaxNumOutputBbox, (int)is_segmentation}; //维度数据plugin_fields[0].data = netinfo; plugin_fields[0].length = 5; plugin_fields[0].name = "netinfo";plugin_fields[0].type = PluginFieldType::kFLOAT32;// 所有plugin的参数PluginFieldCollection plugin_data;plugin_data.nbFields = 2;plugin_data.fields = plugin_fields;// 创建plugin的对象 IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data);
}
实现代码在yololayer.h/cu中
class API YoloLayerPlugin : public IPluginV2IOExt {// 设置插件名称,在注册插件时会寻找对应的插件const char* getPluginType() const TRT_NOEXCEPT override{return "YoloLayer_TRT";}//插件构造函数YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel){/*classCount:类别数量netWidth:输入宽netHeight:输入高maxOut:最大检测数量is_segmentation:是否含有实例分割vYoloKernel:anchors参数*/}}// 插件运行时调用的代码
void YoloLayerPlugin::forwardGpu(...){// 输出结果 1+ 是在第一个位置记录解码的数量int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);// 将存放结果的内存置为0for (int idx = 0; idx < batchSize; ++idx) {CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));// 遍历三种不同尺度的anchorfor (unsigned int i = 0; i < mYoloKernel.size(); ++i) {// 调用核函数进行解码CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >(...)}}__global__ void CalDetection(...){// input:模型输出结果// output:decode存放地址// 当前线程的的全局索引IDint idx = threadIdx.x + blockDim.x * blockIdx.x;// yoloWidth * yoloHeightint total_grid = yoloWidth * yoloHeight; // 在当前特征层上要处理的总框数int bnIdx = idx / total_grid; // 第n个batch // x,y,w,h,score + 80int info_len_i = 5 + classes;// 如果带有实例分割分析,需要再加上32个分割系数if (is_segmentation) info_len_i += 32;// 第n个batch的推理结果开始地址const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);// 遍历三种不同尺寸的anchorfor (int k = 0; k < kNumAnchor; ++k) {//每个框的置信度float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);if (box_prob < kIgnoreThresh) continue;for (int i = 5; i < 5 + classes; ++i) {// 每个类别的概率float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);// 提取最大概率以及类别IDif (p > max_cls_prob) {max_cls_prob = p;class_id = i - 5;}}// float *res_count = output + bnIdx * outputElem;// 统计decode框的数量 int count = (int)atomicAdd(res_count, 1);// 下面是按照论文的公式将预测的宽和高恢复到原图大小...}
}
4、总结
通过本次对trtx开源代码的深入学习,知道了如何利用trt的api对模型进行加速,同时还了解到plugin的实现,后续还会继续学习trtx里面的知识点。
相关文章:
【CV学习笔记】tensorrtx-yolov5 逐行代码解析
1、前言 TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因: Flexible 很容易修改模型的任意一层,删…...
微信管理系统可以解决什么问题?
微信作为一款社交通讯软件,已经成为人们日常生活中不可缺少的工具。不仅个人,很多企业都用微信来联系客户、维护客户和营销,这自然而然就会有很多微信账号、手机也多,那管理起来就会带来很多的不便,而微信管理系统正好…...
mysql事务测试
mysql的事务处理主要有两种方法1、用begin,rollback,commit来实现 begin; -- 开始一个事务 rollback; -- 事务回滚 commit; -- 事务提交 2、直接用set来改变mysql的自动提交模式 mysql默认是自动提交的,也就是你提交一个sql,它就直接执行!我…...
Spring面试题14:Spring中什么是Spring Beans? 包含哪些?Spring容器提供几种方式配置元数据?Spring中怎样定义类的作用域?
该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:Spring中什么是Spring Beans? 包含哪些? 在Spring中,Spring Beans是指由Spring容器管理的对象。Spring Beans包含以下内容: 类定义(Class De…...
Tomcat部署、优化、以及操作练习
一.Tomcat的基本介绍 1.1.Tomcat是什么? Tomcat服务器是一个免费的开放源代码的Web应用服务器,属于轻量级应用服务器,在中小型系统和并发访问用户不是很多的场合下被普遍使用,是开发和调试JSP程序的首选。一般来说,T…...
服务器假死日志按时间统计排查
文章目录 场景解决方案排查过程根据cost时间来筛选 场景 服务器假死,进程还在,但是已经接不到请求了。因此有客户报事,发现服务假死了。 解决方案 这种假死问题一般不太好排查,常规来说有几种可能。 1、慢sql导致卡死。 2、大数…...
CSS——grid网格布局的基本使用
网格布局在实现页面自适应,大屏可视化中常常使用,在这篇博客里,记录一下网格布局的基本使用。 参考文档:网格布局_菜鸟教程 文章目录 1. 体会grid的自适应性2. grid-template-arr配置网格行列3. 网格单位fr与repeat()简写属性值4…...
【python】使用Nuitka打包python项目-demo示例
文章目录 写在前面参考准备工作Quick Start参数说明使用打包程序输出目录结构日志2023.09.20 写在前面 本文的demo示例的代码/数据可从笔者的GitCode获取: HelloWorld 参考 Nuitka官网: https://github.com/Nuitka/NuitkaNuitka使用: https://daobook.github.io/nuitka-doc/…...
Java多线程篇(5)——cas和atomic原子类
文章目录 CASAtomic 原子类一般原子类针对aba问题 —— AtomicStampedReference针对大量自旋问题 —— LongAdder CAS 原理大致如下: 在java的 Unsafe 类里封装了一些 cas 的api。以 compareAndSetInt 为例,来看看其底层实现。 可以发现,最…...
数据结构---栈和队列
栈(Stack) 栈:一种特殊的线性表,其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈 顶,另一端称为栈底。栈中的数据元素遵守后进先出LIFO(Last In First Out)的原则。 压栈࿱…...
2023-9-23 合并果子
题目链接:合并果子 #include <iostream> #include <algorithm> #include <queue>using namespace std;int main() {int n;cin >> n;priority_queue<int, vector<int>, greater<int>> heap;for(int i 0; i < n; i){in…...
基于QT和UDP实现一个实时RTP数据包的接收,并将数据包转化成文件
简单介绍:代码写的比较详细,需要留意的地方看结尾介绍 头文件 #ifndef RTPRECEIVER_H #define RTPRECEIVER_H#include <QDialog> #include <QUdpSocket> #include <QFile> #include <QTextStream> #include <httpclient.h&g…...
云原生安全性:保护现代应用免受威胁
文章目录 引言云原生安全性的挑战云原生安全性的关键实践1. 安全的镜像构建2. 网络策略3. 漏洞扫描和漏洞管理4. 认证和授权5. 日志和监控 云原生安全工具结论 🎉欢迎来到云计算技术应用专栏~云原生安全性:保护现代应用免受威胁 ☆* o(≧▽≦)o *☆嗨~我…...
R语言绘图-3-Circular-barplot图
0. 参考: https://r-graph-gallery.com/web-circular-barplot-with-R-and-ggplot2.html 1. 说明: 利用 ggplot 绘制 环状的条形图 (circular barplot),并且每个条带按照数值大小进行排列。 2 绘图代码: 注意:绘图代码中的字体…...
解决Keil5下载没有对应芯片Flash的问题
问题描述 例如芯片是STM32F103ZET6,但是选项中并没有对应型号的芯片导致下载失败。 解决方法 1、寻找芯片安装包的具体位置,芯片安装包路径在软件安装过程中会有(如图1所示)。如果没有记录可以双击一下芯片安装包会直接提示。…...
深拷贝与浅拷贝(对象的引用)
可以用赋值 1.对象的引用 代码: <!-- 1.对象的引用 --><script>const info{name:"lucy",age:20}const objinfo;info.name"sam"console.log(obj.name) //sam</script>图解: 等于号的赋值,对象info…...
重新认识架构—不只是软件设计
前言 什么是架构? 通常情况下,人们对架构的认知仅限于在软件工程中的定义:架构主要指软件系统的结构设计,比如常见的SOLID准则、DDD架构。一个良好的软件架构可以帮助团队更有效地进行软件开发,降低维护成本࿰…...
我的创业笔记:困境与思索
现在是2023年9月22日傍晚,我一个人走在广州的珠江边,静静地思索着当前个人创业面临的困境,不由自主地想将这些想法记录下来。 故事需要从两个月前说起。2023年7月31号,我从金山办公离职后,就满心欢喜地开启了自己的个…...
minio文件上传
1.代码 大佬仓库:https://gitee.com/Gary2016/minio-upload?_fromgitee_search 关于这个代码的讲解:来自b站 2.准备minio 参考:[1]、[2] 2.1 下载 官网:https://min.io/download#/windows 2.2 启动 ①准备一个data文件夹…...
IDEA .iml文件及.idea文件夹详解
.iml文件 idea 对module 配置信息之意, infomation of module。每个模块都有一个iml文件。 IDEA中的.iml文件是项目标识文件,缺少了这个文件,IDEA就无法识别项目。跟Eclipse的.project文件性质是一样的。并且这些文件不同的设备上的内容也会…...
使用Python做一个微信机器人
介绍 简介 该程序将微信的内部功能提取出来,然后在程序里加载Python,接着将这些功能导出成库函数,就可以在Python里使用这些函数 程序启动的时候会执行py_code目录下的main.py,类似于你在命令行使用python main.py。 现在会以…...
云计算战略:选择适合你业务的云平台
文章目录 云计算的概述选择云平台的关键因素1. 业务需求2. 预算3. 性能要求4. 数据隐私和合规性 示例:选择适合的云平台业务需求预算性能要求数据隐私和合规性 代码示例:使用云平台服务结论 🎉欢迎来到云计算技术应用专栏~云计算战略…...
Python:打印目录下每层的文件总数
代码如下: import osclass FileCount(object):def __init__(self,root_path: str):self.root_path root_pathself._count Noneself._file_count Noneself.children []def get_count(self):if self._count is None:self._count 0self._file_count 0for child_…...
LVS-NAT模式
LVS负载均衡群集 群集的定义 Cluster,集群(也称群集)由多台主机构成,但对外只表现为一一个整体,只提供一-个访问入口(域名或IP地址), 相当于一台大型计算机。 群集的作用 对于企业服务的的性能提升一般…...
【神印王座】龙皓晨竟然上了头版头条!内容违背,新闻真实性原则
Hello,小伙伴们,我是小郑继续为大家深度解析神印王座国漫。 大家有没有发现,当龙皓晨他们从驱魔关回到圣城时,有这么一幕,一个卖报小孩边走边说:驱魔关大捷,少年英雄龙皓晨操控守护与怜悯之神印王座&#x…...
C++之类和函数权限访问总结(二百二十七)
简介: CSDN博客专家,专注Android/Linux系统,分享多mic语音方案、音视频、编解码等技术,与大家一起成长! 优质专栏:Audio工程师进阶系列【原创干货持续更新中……】🚀 人生格言: 人生…...
手动部署 OceanBase 集群
手动部署一个 OB 单副本集群,包括一个 OBProxy 节点 部署环境 服务器信息 IP地址 192.168.0.26 网卡名 ifcfg-enp1s0 OS Kylin Linux Advanced Server release V10 CPU 8C 内存 32G 磁盘1 本地盘 /data/1 磁盘2 本地盘 /data/log1 机器和角色划分 …...
【操作系统笔记十二】Linux常用基础命令
Linux 常用快捷键 Tab 命令或路径等的补全键,特别常用的快捷键Ctrl insert 复制命令行内容(常用可提高效率)Shift insert 粘贴命令行内容(常用可提高效率)Ctrl C 中断当前任务(退出)Ctrl Z…...
Compose LazyColumn 对比 RecyclerView ,谁的性能更好?
LazyColumn 是 compose 中用来实现类似 RecyclerView 效果的控件 ,但是大家都说LazyColumn性能比RecyclerView差太多,毕竟 RecyclerView google优化了十多年了,比RecyclerView差一点也正常,今天我们就用实际数据来对比LazyColumn和…...
[python 刷题] 49 Group Anagrams
[python 刷题] 49 Group Anagrams 题目: Given an array of strings strs, group the anagrams together. You can return the answer in any order. An Anagram is a word or phrase formed by rearranging the letters of a different word or phrase, typically…...
网络有哪些广告推广方式/360手机优化大师安卓版
题库来源:安全生产模拟考试一点通公众号小程序 2020年熔化焊接与热切割考试及熔化焊接与热切割考试软件,包含熔化焊接与热切割考试答案和解析及熔化焊接与热切割考试软件练习。由安全生产模拟考试一点通公众号结合国家熔化焊接与热切割考试最新大纲及熔…...
贵港网站建设公司/九个关键词感悟中国理念
https://www.cnblogs.com/xyhuangjinfu/p/5429644.html 转载于:https://www.cnblogs.com/wangc04/p/9580796.html...
做渔具最大的外贸网站/南宁网站建设公司
ubuntu环境下安装pyinstaller。 pyinstaller的官网:https://pythonhosted.org/PyInstaller/installation.html 一、安装 直接使用pip安装,终端输入指令:pip install pyinstaller 二、验证 输入指令pyinstaller --version,如果输出…...
沧州网站建设价格/北京网站制作
人生苦短,我用 Python 前文传送门: 小白学 Python 数据分析(1):数据分析基础 小白学 Python 数据分析(2):Pandas (一)概述 小白学 Python 数据分析&#x…...
学做网站推广要多久时间/网络营销策划是什么
数据结构系列是我学习做的笔记,会持续更新,源码分享在github:数据结构,当然你也可以从下面的代码片中获取 注:github代码更新会有延迟,关注不迷路😄 本篇博文简单介绍邻接表与其存储图的特点,并…...
第一媒体app最新版本/优化落实疫情防控
1.Android 获取、移除 View 的 OnClickListener https://blog.csdn.net/lv_fq/article/details/82314241 https://github.com/lvfaqiang/AndroidTestCode/blob/master/app/src/main/java/com/lvfq/code/view/click/ViewClickActivity.kt...