当前位置: 首页 > news >正文

【CV学习笔记】tensorrtx-yolov5 逐行代码解析

1、前言

TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因:

  • Flexible 很容易修改模型的任意一层,删除、增加、替换等操作。
  • Debuggable 可以容易获得模型中间某一层的结果
  • Chance to learn 可以对模型结构有进一步的了解

尽管onnx2trt的方式目前已经在绝大部分情况下都不会出现问题,但在trtx下,我们能够掌握更底层的原理和代码,有利于我们对模型的部署以及优化。下文将会以yolov5s在trtx框架下的例子,来逐行解析是trtx是如何工作的。

TensorRTx项目链接:https://github.com/wang-xinyu/tensorrtx。

2、步骤解析

在trtx中,对一个模型加速的过程可以分为两个步骤

  • 提取pytorch模型参数 wts
  • 利用trt底层API搭建网络结构,并将wts中的参数填充到网络中
2.1、get_wts.py

首先需要将pytorch中的模型参数提取出来,pytorch中的模型参数是以caffe中blob的格式存在的,每个操作都有对应的名字、数据长度、数据.

for k, v in model.state_dict().items():# k-> blob的名字vr = v.reshape(-1).cpu().numpy() # vr -> 数据长度f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex()) # 将数据转化到16进制f.write('\n')

通过上get_wts.py,就可以得到包含yolov5s.pth的模型参数,打开yolov5s.wts如下图所示:

在这里插入图片描述

其中第一行的351为总的blob数量,第二行的model.0.conv.weight为第一个blob的名字,3456表示为该blob的数据长度,3a198000 3ca58000…为实际参数。

得到了上述的参数之后,就可以以trtx的方式进行加速了。

2.2、构造engine

在利用wts转engine的之前,需要十分清楚模型的网络结构,不太清楚的同学可以参考太阳花的小绿豆关于yolov5的网络结构图。了解完yolov5的网络结构后,就可以着手利用trt的api来搭建网络模型了。搭建模型的代码在 model.cpp中的build_det_engine函数,本文将其中的代码过程直接画到yolov5的网络结构图中了,可以直接对照代码和图来进行查看。
在这里插入图片描述

//yolov5_det.cpp
viod serialize_engine(...){if (is_p6) {...} else {// 以yolov5s为例engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);}// 序列化IHostMemory* serialized_engine = engine->serialize();std::ofstream p(engine_name, std::ios::binary);// 写到文件中p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());}

model.cpp

// 解析get_wts.py
static std::map<std::string, Weights> loadWeights(const std::string file) {int32_t count;  // wts文件第一行,共有351个blobinput >> count;//每一行是一个blob,模型名称 + 数据长度 + 参数while (count--) {// 一个blob的参数Weights wt{ DataType::kFLOAT, nullptr, 0 };uint32_t size;  //blob 数据长度std::string name; // blob 数据名字for (uint32_t x = 0, y = size; x < y; ++x) {input >> std::hex >> val[x];  // 将数据转化成十进制,并放到val中}// 每个blob名字对应一个wtweightMap[name] = wt;}
}ICudaEngine* build_det_engine(){// 初始化网络结构INetworkDefinition* network = builder->createNetworkV2(0U);// 定义模型输入ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW });// 加载pytorch模型中的参数std::map<std::string, Weights> weightMap = loadWeights(wts_name);// 逐步添加网络结构,已将代码与网络结构一一对应 ,具体过程见上图// 增加yolo后处理decode模块,使用了pluginauto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});network->markOutput(*yolo->getOutput(0));  //将plugin的输出设置为模型的最后输出(decode)#if defined(USE_FP16)// FP16config->setFlag(BuilderFlag::kFP16);#elif defined(USE_INT8)// INT8 量化std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;assert(builder->platformHasFastInt8());config->setFlag(BuilderFlag::kINT8);Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName);config->setInt8Calibrator(calibrator);#endif// 根据网络结构来生成engineICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);return engine;
}
3、plugin

本人对plugin也在学习当中,下面是我在学习trtx-yolo5代码中对plugin浅显的认知。原作者在模型后面增加了一个模型解码的plugin,用于获得每个特征层上的bbox,调用代码在model.cpp中

auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});static IPluginV2Layer* addYoLoLayer(...){// 注册一个名为 "YoloLayer_TRT"的插件,如果找不到插件,就会报错auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");// plugin的数据PluginField plugin_fields[2];int netinfo[5] = {kNumClass, kInputW, kInputH, kMaxNumOutputBbox, (int)is_segmentation};  //维度数据plugin_fields[0].data = netinfo;  plugin_fields[0].length = 5; plugin_fields[0].name = "netinfo";plugin_fields[0].type = PluginFieldType::kFLOAT32;// 所有plugin的参数PluginFieldCollection plugin_data;plugin_data.nbFields = 2;plugin_data.fields = plugin_fields;// 创建plugin的对象 IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data);
}

实现代码在yololayer.h/cu中

class API YoloLayerPlugin : public IPluginV2IOExt {// 设置插件名称,在注册插件时会寻找对应的插件const char* getPluginType() const TRT_NOEXCEPT override{return "YoloLayer_TRT";}//插件构造函数YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel){/*classCount:类别数量netWidth:输入宽netHeight:输入高maxOut:最大检测数量is_segmentation:是否含有实例分割vYoloKernel:anchors参数*/}}// 插件运行时调用的代码
void YoloLayerPlugin::forwardGpu(...){// 输出结果 1+ 是在第一个位置记录解码的数量int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);// 将存放结果的内存置为0for (int idx = 0; idx < batchSize; ++idx) {CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));// 遍历三种不同尺度的anchorfor (unsigned int i = 0; i < mYoloKernel.size(); ++i) {// 调用核函数进行解码CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >(...)}}__global__ void CalDetection(...){// input:模型输出结果// output:decode存放地址// 当前线程的的全局索引IDint idx = threadIdx.x + blockDim.x * blockIdx.x;// yoloWidth * yoloHeightint total_grid = yoloWidth * yoloHeight; // 在当前特征层上要处理的总框数int bnIdx = idx / total_grid;    // 第n个batch    // x,y,w,h,score + 80int info_len_i = 5 + classes;// 如果带有实例分割分析,需要再加上32个分割系数if (is_segmentation) info_len_i += 32;// 第n个batch的推理结果开始地址const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);// 遍历三种不同尺寸的anchorfor (int k = 0; k < kNumAnchor; ++k) {//每个框的置信度float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);if (box_prob < kIgnoreThresh) continue;for (int i = 5; i < 5 + classes; ++i) {// 每个类别的概率float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);// 提取最大概率以及类别IDif (p > max_cls_prob) {max_cls_prob = p;class_id = i - 5;}}// float *res_count = output + bnIdx * outputElem;// 统计decode框的数量	int count = (int)atomicAdd(res_count, 1);// 下面是按照论文的公式将预测的宽和高恢复到原图大小...}
}
4、总结

通过本次对trtx开源代码的深入学习,知道了如何利用trt的api对模型进行加速,同时还了解到plugin的实现,后续还会继续学习trtx里面的知识点。

相关文章:

【CV学习笔记】tensorrtx-yolov5 逐行代码解析

1、前言 TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库&#xff0c;作者提到为什么不用ONNX parser的方式来进行trt加速&#xff0c;而用最底层的API来搭建trt加速的方式有如下原因: Flexible 很容易修改模型的任意一层&#xff0c;删…...

微信管理系统可以解决什么问题?

微信作为一款社交通讯软件&#xff0c;已经成为人们日常生活中不可缺少的工具。不仅个人&#xff0c;很多企业都用微信来联系客户、维护客户和营销&#xff0c;这自然而然就会有很多微信账号、手机也多&#xff0c;那管理起来就会带来很多的不便&#xff0c;而微信管理系统正好…...

mysql事务测试

mysql的事务处理主要有两种方法1、用begin,rollback,commit来实现 begin; -- 开始一个事务 rollback; -- 事务回滚 commit; -- 事务提交 2、直接用set来改变mysql的自动提交模式 mysql默认是自动提交的&#xff0c;也就是你提交一个sql&#xff0c;它就直接执行&#xff01;我…...

Spring面试题14:Spring中什么是Spring Beans? 包含哪些?Spring容器提供几种方式配置元数据?Spring中怎样定义类的作用域?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:Spring中什么是Spring Beans? 包含哪些? 在Spring中,Spring Beans是指由Spring容器管理的对象。Spring Beans包含以下内容: 类定义(Class De…...

Tomcat部署、优化、以及操作练习

一.Tomcat的基本介绍 1.1.Tomcat是什么&#xff1f; Tomcat服务器是一个免费的开放源代码的Web应用服务器&#xff0c;属于轻量级应用服务器&#xff0c;在中小型系统和并发访问用户不是很多的场合下被普遍使用&#xff0c;是开发和调试JSP程序的首选。一般来说&#xff0c;T…...

服务器假死日志按时间统计排查

文章目录 场景解决方案排查过程根据cost时间来筛选 场景 服务器假死&#xff0c;进程还在&#xff0c;但是已经接不到请求了。因此有客户报事&#xff0c;发现服务假死了。 解决方案 这种假死问题一般不太好排查&#xff0c;常规来说有几种可能。 1、慢sql导致卡死。 2、大数…...

CSS——grid网格布局的基本使用

网格布局在实现页面自适应&#xff0c;大屏可视化中常常使用&#xff0c;在这篇博客里&#xff0c;记录一下网格布局的基本使用。 参考文档&#xff1a;网格布局_菜鸟教程 文章目录 1. 体会grid的自适应性2. grid-template-arr配置网格行列3. 网格单位fr与repeat()简写属性值4…...

【python】使用Nuitka打包python项目-demo示例

文章目录 写在前面参考准备工作Quick Start参数说明使用打包程序输出目录结构日志2023.09.20 写在前面 本文的demo示例的代码/数据可从笔者的GitCode获取: HelloWorld 参考 Nuitka官网: https://github.com/Nuitka/NuitkaNuitka使用: https://daobook.github.io/nuitka-doc/…...

Java多线程篇(5)——cas和atomic原子类

文章目录 CASAtomic 原子类一般原子类针对aba问题 —— AtomicStampedReference针对大量自旋问题 —— LongAdder CAS 原理大致如下&#xff1a; 在java的 Unsafe 类里封装了一些 cas 的api。以 compareAndSetInt 为例&#xff0c;来看看其底层实现。 可以发现&#xff0c;最…...

数据结构---栈和队列

栈(Stack) 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈 顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First Out&#xff09;的原则。 压栈&#xff1…...

2023-9-23 合并果子

题目链接&#xff1a;合并果子 #include <iostream> #include <algorithm> #include <queue>using namespace std;int main() {int n;cin >> n;priority_queue<int, vector<int>, greater<int>> heap;for(int i 0; i < n; i){in…...

基于QT和UDP实现一个实时RTP数据包的接收,并将数据包转化成文件

简单介绍&#xff1a;代码写的比较详细&#xff0c;需要留意的地方看结尾介绍 头文件 #ifndef RTPRECEIVER_H #define RTPRECEIVER_H#include <QDialog> #include <QUdpSocket> #include <QFile> #include <QTextStream> #include <httpclient.h&g…...

云原生安全性:保护现代应用免受威胁

文章目录 引言云原生安全性的挑战云原生安全性的关键实践1. 安全的镜像构建2. 网络策略3. 漏洞扫描和漏洞管理4. 认证和授权5. 日志和监控 云原生安全工具结论 &#x1f389;欢迎来到云计算技术应用专栏~云原生安全性&#xff1a;保护现代应用免受威胁 ☆* o(≧▽≦)o *☆嗨~我…...

R语言绘图-3-Circular-barplot图

0. 参考&#xff1a; https://r-graph-gallery.com/web-circular-barplot-with-R-and-ggplot2.html 1. 说明&#xff1a; 利用 ggplot 绘制 环状的条形图 (circular barplot)&#xff0c;并且每个条带按照数值大小进行排列。 2 绘图代码: 注意&#xff1a;绘图代码中的字体…...

解决Keil5下载没有对应芯片Flash的问题

问题描述 例如芯片是STM32F103ZET6&#xff0c;但是选项中并没有对应型号的芯片导致下载失败。 解决方法 1、寻找芯片安装包的具体位置&#xff0c;芯片安装包路径在软件安装过程中会有&#xff08;如图1所示&#xff09;。如果没有记录可以双击一下芯片安装包会直接提示。…...

深拷贝与浅拷贝(对象的引用)

可以用赋值 1.对象的引用 代码&#xff1a; <!-- 1.对象的引用 --><script>const info{name:"lucy",age:20}const objinfo;info.name"sam"console.log(obj.name) //sam</script>图解&#xff1a; 等于号的赋值&#xff0c;对象info…...

重新认识架构—不只是软件设计

前言 什么是架构&#xff1f; 通常情况下&#xff0c;人们对架构的认知仅限于在软件工程中的定义&#xff1a;架构主要指软件系统的结构设计&#xff0c;比如常见的SOLID准则、DDD架构。一个良好的软件架构可以帮助团队更有效地进行软件开发&#xff0c;降低维护成本&#xff0…...

我的创业笔记:困境与思索

现在是2023年9月22日傍晚&#xff0c;我一个人走在广州的珠江边&#xff0c;静静地思索着当前个人创业面临的困境&#xff0c;不由自主地想将这些想法记录下来。 故事需要从两个月前说起。2023年7月31号&#xff0c;我从金山办公离职后&#xff0c;就满心欢喜地开启了自己的个…...

minio文件上传

1.代码 大佬仓库&#xff1a;https://gitee.com/Gary2016/minio-upload?_fromgitee_search 关于这个代码的讲解&#xff1a;来自b站 2.准备minio 参考&#xff1a;[1]、[2] 2.1 下载 官网&#xff1a;https://min.io/download#/windows 2.2 启动 ①准备一个data文件夹…...

IDEA .iml文件及.idea文件夹详解

.iml文件 idea 对module 配置信息之意&#xff0c; infomation of module。每个模块都有一个iml文件。 IDEA中的.iml文件是项目标识文件&#xff0c;缺少了这个文件&#xff0c;IDEA就无法识别项目。跟Eclipse的.project文件性质是一样的。并且这些文件不同的设备上的内容也会…...

多云管理“拦路虎”:深入解析网络互联、身份同步与成本可视化的技术复杂度​

一、引言&#xff1a;多云环境的技术复杂性本质​​ 企业采用多云策略已从技术选型升维至生存刚需。当业务系统分散部署在多个云平台时&#xff0c;​​基础设施的技术债呈现指数级积累​​。网络连接、身份认证、成本管理这三大核心挑战相互嵌套&#xff1a;跨云网络构建数据…...

进程地址空间(比特课总结)

一、进程地址空间 1. 环境变量 1 &#xff09;⽤户级环境变量与系统级环境变量 全局属性&#xff1a;环境变量具有全局属性&#xff0c;会被⼦进程继承。例如当bash启动⼦进程时&#xff0c;环 境变量会⾃动传递给⼦进程。 本地变量限制&#xff1a;本地变量只在当前进程(ba…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明&#xff1a;server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

dedecms 织梦自定义表单留言增加ajax验证码功能

增加ajax功能模块&#xff0c;用户不点击提交按钮&#xff0c;只要输入框失去焦点&#xff0c;就会提前提示验证码是否正确。 一&#xff0c;模板上增加验证码 <input name"vdcode"id"vdcode" placeholder"请输入验证码" type"text&quo…...

鸿蒙中用HarmonyOS SDK应用服务 HarmonyOS5开发一个医院查看报告小程序

一、开发环境准备 ​​工具安装​​&#xff1a; 下载安装DevEco Studio 4.0&#xff08;支持HarmonyOS 5&#xff09;配置HarmonyOS SDK 5.0确保Node.js版本≥14 ​​项目初始化​​&#xff1a; ohpm init harmony/hospital-report-app 二、核心功能模块实现 1. 报告列表…...

C# SqlSugar:依赖注入与仓储模式实践

C# SqlSugar&#xff1a;依赖注入与仓储模式实践 在 C# 的应用开发中&#xff0c;数据库操作是必不可少的环节。为了让数据访问层更加简洁、高效且易于维护&#xff0c;许多开发者会选择成熟的 ORM&#xff08;对象关系映射&#xff09;框架&#xff0c;SqlSugar 就是其中备受…...

高防服务器能够抵御哪些网络攻击呢?

高防服务器作为一种有着高度防御能力的服务器&#xff0c;可以帮助网站应对分布式拒绝服务攻击&#xff0c;有效识别和清理一些恶意的网络流量&#xff0c;为用户提供安全且稳定的网络环境&#xff0c;那么&#xff0c;高防服务器一般都可以抵御哪些网络攻击呢&#xff1f;下面…...

Map相关知识

数据结构 二叉树 二叉树&#xff0c;顾名思义&#xff0c;每个节点最多有两个“叉”&#xff0c;也就是两个子节点&#xff0c;分别是左子 节点和右子节点。不过&#xff0c;二叉树并不要求每个节点都有两个子节点&#xff0c;有的节点只 有左子节点&#xff0c;有的节点只有…...

重启Eureka集群中的节点,对已经注册的服务有什么影响

先看答案&#xff0c;如果正确地操作&#xff0c;重启Eureka集群中的节点&#xff0c;对已经注册的服务影响非常小&#xff0c;甚至可以做到无感知。 但如果操作不当&#xff0c;可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...

JS设计模式(4):观察者模式

JS设计模式(4):观察者模式 一、引入 在开发中&#xff0c;我们经常会遇到这样的场景&#xff1a;一个对象的状态变化需要自动通知其他对象&#xff0c;比如&#xff1a; 电商平台中&#xff0c;商品库存变化时需要通知所有订阅该商品的用户&#xff1b;新闻网站中&#xff0…...