当前位置: 首页 > news >正文

【CV学习笔记】tensorrtx-yolov5 逐行代码解析

1、前言

TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库,作者提到为什么不用ONNX parser的方式来进行trt加速,而用最底层的API来搭建trt加速的方式有如下原因:

  • Flexible 很容易修改模型的任意一层,删除、增加、替换等操作。
  • Debuggable 可以容易获得模型中间某一层的结果
  • Chance to learn 可以对模型结构有进一步的了解

尽管onnx2trt的方式目前已经在绝大部分情况下都不会出现问题,但在trtx下,我们能够掌握更底层的原理和代码,有利于我们对模型的部署以及优化。下文将会以yolov5s在trtx框架下的例子,来逐行解析是trtx是如何工作的。

TensorRTx项目链接:https://github.com/wang-xinyu/tensorrtx。

2、步骤解析

在trtx中,对一个模型加速的过程可以分为两个步骤

  • 提取pytorch模型参数 wts
  • 利用trt底层API搭建网络结构,并将wts中的参数填充到网络中
2.1、get_wts.py

首先需要将pytorch中的模型参数提取出来,pytorch中的模型参数是以caffe中blob的格式存在的,每个操作都有对应的名字、数据长度、数据.

for k, v in model.state_dict().items():# k-> blob的名字vr = v.reshape(-1).cpu().numpy() # vr -> 数据长度f.write('{} {} '.format(k, len(vr)))for vv in vr:f.write(' ')f.write(struct.pack('>f', float(vv)).hex()) # 将数据转化到16进制f.write('\n')

通过上get_wts.py,就可以得到包含yolov5s.pth的模型参数,打开yolov5s.wts如下图所示:

在这里插入图片描述

其中第一行的351为总的blob数量,第二行的model.0.conv.weight为第一个blob的名字,3456表示为该blob的数据长度,3a198000 3ca58000…为实际参数。

得到了上述的参数之后,就可以以trtx的方式进行加速了。

2.2、构造engine

在利用wts转engine的之前,需要十分清楚模型的网络结构,不太清楚的同学可以参考太阳花的小绿豆关于yolov5的网络结构图。了解完yolov5的网络结构后,就可以着手利用trt的api来搭建网络模型了。搭建模型的代码在 model.cpp中的build_det_engine函数,本文将其中的代码过程直接画到yolov5的网络结构图中了,可以直接对照代码和图来进行查看。
在这里插入图片描述

//yolov5_det.cpp
viod serialize_engine(...){if (is_p6) {...} else {// 以yolov5s为例engine = build_det_engine(max_batchsize, builder, config, DataType::kFLOAT, gd, gw, wts_name);}// 序列化IHostMemory* serialized_engine = engine->serialize();std::ofstream p(engine_name, std::ios::binary);// 写到文件中p.write(reinterpret_cast<const char*>(serialized_engine->data()), serialized_engine->size());}

model.cpp

// 解析get_wts.py
static std::map<std::string, Weights> loadWeights(const std::string file) {int32_t count;  // wts文件第一行,共有351个blobinput >> count;//每一行是一个blob,模型名称 + 数据长度 + 参数while (count--) {// 一个blob的参数Weights wt{ DataType::kFLOAT, nullptr, 0 };uint32_t size;  //blob 数据长度std::string name; // blob 数据名字for (uint32_t x = 0, y = size; x < y; ++x) {input >> std::hex >> val[x];  // 将数据转化成十进制,并放到val中}// 每个blob名字对应一个wtweightMap[name] = wt;}
}ICudaEngine* build_det_engine(){// 初始化网络结构INetworkDefinition* network = builder->createNetworkV2(0U);// 定义模型输入ITensor* data = network->addInput(kInputTensorName, dt, Dims3{ 3, kInputH, kInputW });// 加载pytorch模型中的参数std::map<std::string, Weights> weightMap = loadWeights(wts_name);// 逐步添加网络结构,已将代码与网络结构一一对应 ,具体过程见上图// 增加yolo后处理decode模块,使用了pluginauto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});network->markOutput(*yolo->getOutput(0));  //将plugin的输出设置为模型的最后输出(decode)#if defined(USE_FP16)// FP16config->setFlag(BuilderFlag::kFP16);#elif defined(USE_INT8)// INT8 量化std::cout << "Your platform support int8: " << (builder->platformHasFastInt8() ? "true" : "false") << std::endl;assert(builder->platformHasFastInt8());config->setFlag(BuilderFlag::kINT8);Int8EntropyCalibrator2* calibrator = new Int8EntropyCalibrator2(1, kInputW, kInputH, "./coco_calib/", "int8calib.table", kInputTensorName);config->setInt8Calibrator(calibrator);#endif// 根据网络结构来生成engineICudaEngine* engine = builder->buildEngineWithConfig(*network, *config);return engine;
}
3、plugin

本人对plugin也在学习当中,下面是我在学习trtx-yolo5代码中对plugin浅显的认知。原作者在模型后面增加了一个模型解码的plugin,用于获得每个特征层上的bbox,调用代码在model.cpp中

auto yolo = addYoLoLayer(network, weightMap, "model.24", std::vector<IConvolutionLayer*>{det0, det1, det2});static IPluginV2Layer* addYoLoLayer(...){// 注册一个名为 "YoloLayer_TRT"的插件,如果找不到插件,就会报错auto creator = getPluginRegistry()->getPluginCreator("YoloLayer_TRT", "1");// plugin的数据PluginField plugin_fields[2];int netinfo[5] = {kNumClass, kInputW, kInputH, kMaxNumOutputBbox, (int)is_segmentation};  //维度数据plugin_fields[0].data = netinfo;  plugin_fields[0].length = 5; plugin_fields[0].name = "netinfo";plugin_fields[0].type = PluginFieldType::kFLOAT32;// 所有plugin的参数PluginFieldCollection plugin_data;plugin_data.nbFields = 2;plugin_data.fields = plugin_fields;// 创建plugin的对象 IPluginV2 *plugin_obj = creator->createPlugin("yololayer", &plugin_data);
}

实现代码在yololayer.h/cu中

class API YoloLayerPlugin : public IPluginV2IOExt {// 设置插件名称,在注册插件时会寻找对应的插件const char* getPluginType() const TRT_NOEXCEPT override{return "YoloLayer_TRT";}//插件构造函数YoloLayerPlugin(int classCount, int netWidth, int netHeight, int maxOut, bool is_segmentation, const std::vector<YoloKernel>& vYoloKernel){/*classCount:类别数量netWidth:输入宽netHeight:输入高maxOut:最大检测数量is_segmentation:是否含有实例分割vYoloKernel:anchors参数*/}}// 插件运行时调用的代码
void YoloLayerPlugin::forwardGpu(...){// 输出结果 1+ 是在第一个位置记录解码的数量int outputElem = 1 + mMaxOutObject * sizeof(Detection) / sizeof(float);// 将存放结果的内存置为0for (int idx = 0; idx < batchSize; ++idx) {CUDA_CHECK(cudaMemsetAsync(output + idx * outputElem, 0, sizeof(float), stream));// 遍历三种不同尺度的anchorfor (unsigned int i = 0; i < mYoloKernel.size(); ++i) {// 调用核函数进行解码CalDetection << < (numElem + mThreadCount - 1) / mThreadCount, mThreadCount, 0, stream >> >(...)}}__global__ void CalDetection(...){// input:模型输出结果// output:decode存放地址// 当前线程的的全局索引IDint idx = threadIdx.x + blockDim.x * blockIdx.x;// yoloWidth * yoloHeightint total_grid = yoloWidth * yoloHeight; // 在当前特征层上要处理的总框数int bnIdx = idx / total_grid;    // 第n个batch    // x,y,w,h,score + 80int info_len_i = 5 + classes;// 如果带有实例分割分析,需要再加上32个分割系数if (is_segmentation) info_len_i += 32;// 第n个batch的推理结果开始地址const float* curInput = input + bnIdx * (info_len_i * total_grid * kNumAnchor);// 遍历三种不同尺寸的anchorfor (int k = 0; k < kNumAnchor; ++k) {//每个框的置信度float box_prob = Logist(curInput[idx + k * info_len_i * total_grid + 4 * total_grid]);if (box_prob < kIgnoreThresh) continue;for (int i = 5; i < 5 + classes; ++i) {// 每个类别的概率float p = Logist(curInput[idx + k * info_len_i * total_grid + i * total_grid]);// 提取最大概率以及类别IDif (p > max_cls_prob) {max_cls_prob = p;class_id = i - 5;}}// float *res_count = output + bnIdx * outputElem;// 统计decode框的数量	int count = (int)atomicAdd(res_count, 1);// 下面是按照论文的公式将预测的宽和高恢复到原图大小...}
}
4、总结

通过本次对trtx开源代码的深入学习,知道了如何利用trt的api对模型进行加速,同时还了解到plugin的实现,后续还会继续学习trtx里面的知识点。

相关文章:

【CV学习笔记】tensorrtx-yolov5 逐行代码解析

1、前言 TensorRTx(下文简称为trtx)是一个十分流行的利用API来搭建网络结构实现trt加速的开源库&#xff0c;作者提到为什么不用ONNX parser的方式来进行trt加速&#xff0c;而用最底层的API来搭建trt加速的方式有如下原因: Flexible 很容易修改模型的任意一层&#xff0c;删…...

微信管理系统可以解决什么问题?

微信作为一款社交通讯软件&#xff0c;已经成为人们日常生活中不可缺少的工具。不仅个人&#xff0c;很多企业都用微信来联系客户、维护客户和营销&#xff0c;这自然而然就会有很多微信账号、手机也多&#xff0c;那管理起来就会带来很多的不便&#xff0c;而微信管理系统正好…...

mysql事务测试

mysql的事务处理主要有两种方法1、用begin,rollback,commit来实现 begin; -- 开始一个事务 rollback; -- 事务回滚 commit; -- 事务提交 2、直接用set来改变mysql的自动提交模式 mysql默认是自动提交的&#xff0c;也就是你提交一个sql&#xff0c;它就直接执行&#xff01;我…...

Spring面试题14:Spring中什么是Spring Beans? 包含哪些?Spring容器提供几种方式配置元数据?Spring中怎样定义类的作用域?

该文章专注于面试,面试只要回答关键点即可,不需要对框架有非常深入的回答,如果你想应付面试,是足够了,抓住关键点 面试官:Spring中什么是Spring Beans? 包含哪些? 在Spring中,Spring Beans是指由Spring容器管理的对象。Spring Beans包含以下内容: 类定义(Class De…...

Tomcat部署、优化、以及操作练习

一.Tomcat的基本介绍 1.1.Tomcat是什么&#xff1f; Tomcat服务器是一个免费的开放源代码的Web应用服务器&#xff0c;属于轻量级应用服务器&#xff0c;在中小型系统和并发访问用户不是很多的场合下被普遍使用&#xff0c;是开发和调试JSP程序的首选。一般来说&#xff0c;T…...

服务器假死日志按时间统计排查

文章目录 场景解决方案排查过程根据cost时间来筛选 场景 服务器假死&#xff0c;进程还在&#xff0c;但是已经接不到请求了。因此有客户报事&#xff0c;发现服务假死了。 解决方案 这种假死问题一般不太好排查&#xff0c;常规来说有几种可能。 1、慢sql导致卡死。 2、大数…...

CSS——grid网格布局的基本使用

网格布局在实现页面自适应&#xff0c;大屏可视化中常常使用&#xff0c;在这篇博客里&#xff0c;记录一下网格布局的基本使用。 参考文档&#xff1a;网格布局_菜鸟教程 文章目录 1. 体会grid的自适应性2. grid-template-arr配置网格行列3. 网格单位fr与repeat()简写属性值4…...

【python】使用Nuitka打包python项目-demo示例

文章目录 写在前面参考准备工作Quick Start参数说明使用打包程序输出目录结构日志2023.09.20 写在前面 本文的demo示例的代码/数据可从笔者的GitCode获取: HelloWorld 参考 Nuitka官网: https://github.com/Nuitka/NuitkaNuitka使用: https://daobook.github.io/nuitka-doc/…...

Java多线程篇(5)——cas和atomic原子类

文章目录 CASAtomic 原子类一般原子类针对aba问题 —— AtomicStampedReference针对大量自旋问题 —— LongAdder CAS 原理大致如下&#xff1a; 在java的 Unsafe 类里封装了一些 cas 的api。以 compareAndSetInt 为例&#xff0c;来看看其底层实现。 可以发现&#xff0c;最…...

数据结构---栈和队列

栈(Stack) 栈&#xff1a;一种特殊的线性表&#xff0c;其只允许在固定的一端进行插入和删除元素操作。进行数据插入和删除操作的一端称为栈 顶&#xff0c;另一端称为栈底。栈中的数据元素遵守后进先出LIFO&#xff08;Last In First Out&#xff09;的原则。 压栈&#xff1…...

2023-9-23 合并果子

题目链接&#xff1a;合并果子 #include <iostream> #include <algorithm> #include <queue>using namespace std;int main() {int n;cin >> n;priority_queue<int, vector<int>, greater<int>> heap;for(int i 0; i < n; i){in…...

基于QT和UDP实现一个实时RTP数据包的接收,并将数据包转化成文件

简单介绍&#xff1a;代码写的比较详细&#xff0c;需要留意的地方看结尾介绍 头文件 #ifndef RTPRECEIVER_H #define RTPRECEIVER_H#include <QDialog> #include <QUdpSocket> #include <QFile> #include <QTextStream> #include <httpclient.h&g…...

云原生安全性:保护现代应用免受威胁

文章目录 引言云原生安全性的挑战云原生安全性的关键实践1. 安全的镜像构建2. 网络策略3. 漏洞扫描和漏洞管理4. 认证和授权5. 日志和监控 云原生安全工具结论 &#x1f389;欢迎来到云计算技术应用专栏~云原生安全性&#xff1a;保护现代应用免受威胁 ☆* o(≧▽≦)o *☆嗨~我…...

R语言绘图-3-Circular-barplot图

0. 参考&#xff1a; https://r-graph-gallery.com/web-circular-barplot-with-R-and-ggplot2.html 1. 说明&#xff1a; 利用 ggplot 绘制 环状的条形图 (circular barplot)&#xff0c;并且每个条带按照数值大小进行排列。 2 绘图代码: 注意&#xff1a;绘图代码中的字体…...

解决Keil5下载没有对应芯片Flash的问题

问题描述 例如芯片是STM32F103ZET6&#xff0c;但是选项中并没有对应型号的芯片导致下载失败。 解决方法 1、寻找芯片安装包的具体位置&#xff0c;芯片安装包路径在软件安装过程中会有&#xff08;如图1所示&#xff09;。如果没有记录可以双击一下芯片安装包会直接提示。…...

深拷贝与浅拷贝(对象的引用)

可以用赋值 1.对象的引用 代码&#xff1a; <!-- 1.对象的引用 --><script>const info{name:"lucy",age:20}const objinfo;info.name"sam"console.log(obj.name) //sam</script>图解&#xff1a; 等于号的赋值&#xff0c;对象info…...

重新认识架构—不只是软件设计

前言 什么是架构&#xff1f; 通常情况下&#xff0c;人们对架构的认知仅限于在软件工程中的定义&#xff1a;架构主要指软件系统的结构设计&#xff0c;比如常见的SOLID准则、DDD架构。一个良好的软件架构可以帮助团队更有效地进行软件开发&#xff0c;降低维护成本&#xff0…...

我的创业笔记:困境与思索

现在是2023年9月22日傍晚&#xff0c;我一个人走在广州的珠江边&#xff0c;静静地思索着当前个人创业面临的困境&#xff0c;不由自主地想将这些想法记录下来。 故事需要从两个月前说起。2023年7月31号&#xff0c;我从金山办公离职后&#xff0c;就满心欢喜地开启了自己的个…...

minio文件上传

1.代码 大佬仓库&#xff1a;https://gitee.com/Gary2016/minio-upload?_fromgitee_search 关于这个代码的讲解&#xff1a;来自b站 2.准备minio 参考&#xff1a;[1]、[2] 2.1 下载 官网&#xff1a;https://min.io/download#/windows 2.2 启动 ①准备一个data文件夹…...

IDEA .iml文件及.idea文件夹详解

.iml文件 idea 对module 配置信息之意&#xff0c; infomation of module。每个模块都有一个iml文件。 IDEA中的.iml文件是项目标识文件&#xff0c;缺少了这个文件&#xff0c;IDEA就无法识别项目。跟Eclipse的.project文件性质是一样的。并且这些文件不同的设备上的内容也会…...

十五五具身智能规划纲要解读:政策领航打造中国具身未来

摘要&#xff1a;本报告解读“十五五”规划对具身智能的战略布局&#xff0c;其首次被系统写入国家未来产业&#xff0c;明确实训场、核心技术攻关等落地抓手。我国在政策支持、工业供应链、市场需求上具备领先优势&#xff0c;2025年人形机器人出货量占全球84.7%&#xff0c;宇…...

OpCore Simplify:自动化配置黑苹果系统部署的创新方法——从配置困境到高效部署的转变

OpCore Simplify&#xff1a;自动化配置黑苹果系统部署的创新方法——从配置困境到高效部署的转变 【免费下载链接】OpCore-Simplify A tool designed to simplify the creation of OpenCore EFI 项目地址: https://gitcode.com/GitHub_Trending/op/OpCore-Simplify 作为…...

第十五届蓝桥杯c++B组:宝石组合

蓝桥杯真题&#xff1a;宝石组合#include<bits/stdc.h> // 万能头文件&#xff0c;包含了C所有标准库 using namespace std; // 自定义函数&#xff1a;求三个数的最小公倍数&#xff08;LCM&#xff09; int LCM(int x, int y, int z) {int maxx max(…...

MySql自用

一、语法 1.左连接 left join ...on... left左边的表的行全保留 2.子嵌套需要给别名 3.基础函数框架 Create Function 函数名(N INT) Returns Int 函数返回值类型 BeginReturn(--函数体); End N INT&#xff1a;入参&#xff0c;参数名为 N&#xff0c;类型为整数 INT&a…...

Qwen2.5-VL-7B-Instruct新手入门:从安装到第一个图文对话

Qwen2.5-VL-7B-Instruct新手入门&#xff1a;从安装到第一个图文对话 1. 环境准备与快速部署 1.1 硬件要求 Qwen2.5-VL-7B-Instruct是专为RTX 4090显卡优化的多模态大模型&#xff0c;需要满足以下硬件条件&#xff1a; 显卡&#xff1a;NVIDIA RTX 4090&#xff08;24GB显…...

Flutter环境搭建全攻略:从安装到解决常见问题

1. Flutter开发环境搭建前的准备 在开始Flutter开发之前&#xff0c;我们需要做好一些基础准备工作。首先确保你的电脑满足以下最低配置要求&#xff1a; 操作系统&#xff1a;Windows 10或更高版本&#xff08;64位&#xff09;磁盘空间&#xff1a;至少5GB可用空间内存&#…...

strace命令实战指南:从基础到高级的系统调用跟踪技巧

1. strace命令基础入门&#xff1a;你的第一个系统调用跟踪 第一次接触strace时&#xff0c;我盯着屏幕上飞速滚动的系统调用记录完全摸不着头脑。直到有次服务器上的Python脚本莫名其妙卡死&#xff0c;老工程师用三行strace命令就定位到是文件权限问题&#xff0c;我才真正理…...

手把手教你用Qwen2.5-0.5B-Instruct快速搭建多语言聊天机器人

手把手教你用Qwen2.5-0.5B-Instruct快速搭建多语言聊天机器人 1. 为什么选择这个模型&#xff1f; 在当今全球化环境中&#xff0c;能够支持多种语言的智能助手变得越来越重要。Qwen2.5-0.5B-Instruct作为阿里云开源的最新轻量级大语言模型&#xff0c;特别适合需要快速部署多…...

Locale-Emulator完全指南:突破区域限制的7个实战技巧

Locale-Emulator完全指南&#xff1a;突破区域限制的7个实战技巧 【免费下载链接】Locale-Emulator Yet Another System Region and Language Simulator 项目地址: https://gitcode.com/gh_mirrors/lo/Locale-Emulator 副标题&#xff1a;如何让你的软件不再受系统区域设…...

FreeRTOS vs 裸机开发:何时该用RTOS?项目实战对比分析

FreeRTOS vs 裸机开发&#xff1a;何时该用RTOS&#xff1f;项目实战对比分析 在嵌入式开发的世界里&#xff0c;开发者常常面临一个关键选择&#xff1a;是采用裸机开发&#xff08;Bare Metal&#xff09;还是引入实时操作系统&#xff08;RTOS&#xff09;&#xff1f;这个问…...