使用PyTorch导出JIT模型:C++ API与libtorch实战
PyTorch导出JIT模型并用C++ API libtorch调用
本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 C++API libtorch运行这个模型。
Step1:导出模型
首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署阶段,因此,模型的训练就不进行了,直接对 torchvision 中自带的 ResNet50 进行导出。在实际应用中,大家可以对自己训练好的模型进行导出。
# export_jit_model.py
import torch
import torchvision.models as modelsmodel = models.resnet50(pretrained=True)
model.eval()example_input = torch.rand(1, 3, 224, 224)jit_model = torch.jit.trace(model, example_input)
torch.jit.save(jit_model, 'resnet50_jit.pth')
导出 JIT 模型的方式有两种:trace 和 script。
我们采用
torch.jit.trace
的方式来导出 JIT 模型,这种方式会根据一个输入将模型跑一遍,然后记录下执行过程。这种方式的问题在于对于有分支判断的模型不能很好的应对,因为一个输入不能覆盖到所有的分支。但是在我们 ResNet50 模型中不会遇到分支判断,因此这里是合适的。关于两种导出 JIT 模型的方式各自优劣不是本文的中断,以后会再写一篇来分析。
在我们的工程目录
demo
下运行上面的
export_jit_model.py
,会得到一个 JIT 模型件:
resnet50_jit.pth
。
Step 2:安装libtorch
接下来我们要安装 PyTorch 的 C++ API:libtorch。这一步很简单,直接下载官方预编译的文件并解压即可:
wget https://download.pytorch.org/libtorch/nightly/cpu/libtorch-shared-with-deps-latest.zip
unzip libtorch-shared-with-deps-latest.zip
也解压在我们的工程目录
demo
下即可。
Step 3:安装OpenCV
用 Python 或 C++ 做图像任务,OpenCV 是经常用到的。如果还没有安装的读者可以参考如下在工程目录
demo
下进行安装,构建的过程可能会比较久。已经安装的读者可跳过此步骤,一会儿在
CMakeLists.txt
文件中正确地指定本机的 OpenCV 地址即可。
git clone --branch 3.4 --depth 1 https://github.com/opencv/opencv.git
mkdir demo/build && cd demo/build
cmake ..
make -j 6
Step 4:准备测试图像并用Python测试
我们先准备一张小猫的图像,并用 PyTorch ResNet50 模型正常跑一下,一会儿与我们 C++ 模型运行的结果对比来验证 C++ 模型是否被正确的部署。
kitten.jpg
:

写一个脚本用 PyTorch 运行一下模型:
# pytorch_test.pyimport torchvision.models as models
from torchvision.transforms import transforms
import torch
from PIL import Image# normalize = transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])
all_transforms = transforms.Compose([transforms.Resize(224),transforms.ToTensor()])# normalize])model = models.resnet50(pretrained=True)
model.eval()img = Image.open('kitten.jpg').convert('RGB')
img_tensor = all_transforms(img).unsqueeze(dim=0)
pred = model(img_tensor).squeeze(dim=0)
print(torch.argmax(pred).item())
输出结果是:282。通过查看
ImageNet 1K 类别名与索引的对应关系
,可以看到,结果为 tiger cat,模型预测正确。一会儿我们看一下部署后的 C++ 模型是否能正确输出结果 282。
Step 5:准备cpp源文件
我们下面准备一会要执行的 cpp 源文件,第一次使用 libtorch 的读者可以先借鉴下面的文件。
这里有几个点要说一下,不注意可能会犯错:
cv::imread()
默认读取为三通道BGR,需要进行B/R通道交换,这里采用
cv::cvtColor()
实现。- 图像尺寸需要调整到
224
×
224
224\times 224
2
2
4
×
2
2
4
,通过
cv::resize()
实现。
3. opencv读取的图像矩阵存储形式:H x W x C, 但是pytorch中 Tensor的存储为:N x C x H x W, 因此需要进行变换,就是
np.transpose()
操作,这里使用
tensor.permut()
实现,效果是一样的。
4. 数据归一化,采用
tensor.div(255)
实现。
// test_model.cpp
#include <vector>#include <torch/torch.h>
#include <torch/script.h>#include <opencv2/core.hpp>
#include <opencv2/imgproc/imgproc.hpp>
#include <opencv2/highgui/highgui.hpp>int main(int argc, char* argv[]) {// 加载JIT模型auto module = torch::jit::load(argv[1]);// 加载图像auto image = cv::imread(argv[2], cv::ImreadModes::IMREAD_COLOR);cv::Mat image_transfomed;cv::resize(image, image_transfomed, cv::Size(224, 224));cv::cvtColor(image_transfomed, image_transfomed, cv::COLOR_BGR2RGB);// 图像转换为Tensortorch::Tensor tensor_image = torch::from_blob(image_transfomed.data, {image_transfomed.rows, image_transfomed.cols, 3},torch::kByte);tensor_image = tensor_image.permute({2, 0, 1});// tensor_image = tensor_image.toType(torch::kFloat);tensor_image = tensor_image.div(255.);// tensor_image = tensor_image.sub(0.5);// tensor_image = tensor_image.div(0.5);tensor_image = tensor_image.unsqueeze(0);// 运行模型torch::Tensor output = module.forward({tensor_image}).toTensor();// 结果处理int result = output.argmax().item<int>();std::cout << "The classifiction index is: " << result << std::endl;return 0;
}
Step 6:构建运行验证
我们先来写一下
CMakeLists.txt
:
cmake_minimum_required(VERSION 3.0 FATAL_ERROR)
project(resnet50)find_package(Torch REQUIRED PATHS ./libtorch)
find_package(OpenCV REQUIRED)add_executable(resnet50 test_model.cpp)
target_link_libraries(resnet50 "${TORCH_LIBRARIES}" "${OpenCV_LIBS}")set_property(TARGET resnet50 PROPERTY CXX_STANDARD 14)
set(CMAKE_CXX_STANDARD_REQUIRED TRUE)set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${TORCH_CXX_FLAGS}")
现在我们的工程目录
demo
下有以下文件:
CMakeLists.txt export_jit_model.py kitten.jpg libtorch pytorch_test.py resnet50_jit.pth test_model.cpp
然后开始用 CMake 构建工程:
mkdir build && cd build
OpenCV_DIR=[YOUR_PATH_TO_OPENCV]/opencv/build cmake ..
make
整个过程没有报错的话我们就已经构建完成了,会得到一个可执行文件
resnet50
在工程目录
demo
下。
接下来我们执行,并验证运行结果是否与 PyTorch 的结果一致:
./build/resnet50 resnet50_jit.pth kitten.jpg
输出:
The classifiction index is: 282
运行成功并且结果正确。
Ref:
https://www.jianshu.com/p/7cddc09ca7a4
https://blog.csdn.net/cxx654/article/details/115916275
https://zhuanlan.zhihu.com/p/370455320
相关文章:
使用PyTorch导出JIT模型:C++ API与libtorch实战
PyTorch导出JIT模型并用C API libtorch调用 本文将介绍如何将一个 PyTorch 模型导出为 JIT 模型并用 PyTorch 的 CAPI libtorch运行这个模型。 Step1:导出模型 首先我们进行第一步,用 Python API 来导出模型,由于本文的重点是在后面的部署…...
Python——异常捕获,传递及其抛出操作
01. 异常的概念 1. 程序在运行时,如果 python解释器遇到一个错误,会停止程序的执行,并且提示一些错误信息,这就是异常。 2. 程序停止执行并且提示错误信息这个动作,我们通常称之为:抛出(raise…...
【Maven】 的继承机制
Maven是一个强大的项目管理工具,主要用于Java项目的构建和管理。它以其项目对象模型(POM)为基础,允许开发者定义项目的依赖、构建过程和插件。Maven的继承机制是其核心特性之一,它允许子项目继承和复用父项目的配置&am…...
微信小程序结合后端php发送模版消息
前端: <view class"container"><button bindtap"requestSubscribeMessage">订阅消息</button> </view> // index.js Page({data: {tmplIds: [UTgCUfsjHVESf5FjOzls0I9i_FVS1N620G2VQCg1LZ0] // 使用你的模板ID},requ…...
sqlalchemy报错sqlalchemy.orm.exc.DetachedInstanceError
解决方案: 在初始化数据库的代码中,将 maker sessionmaker(bindeng)修改为 maker sessionmaker(bindeng, expire_on_commitFalse)为什么要添加 expire_on_commitFalse 参数? expire_on_commit 可以用来更改 SQLAlchemy 的对象刷新机制&…...
华为网络模拟器eNSP安装部署教程
eNSP是图形化网络仿真平台,该平台通过对真实网络设备的仿真模拟,帮助广大ICT从业者和客户快速熟悉华为数通系列产品,了解并掌握相关产品的操作和配置、提升对企业ICT网络的规划、建设、运维能力,从而帮助企业构建更高效࿰…...
【React】详解样式控制:从基础到进阶应用的全面指南
文章目录 一、内联样式1. 什么是内联样式?2. 内联样式的定义3. 基本示例4. 动态内联样式 二、CSS模块1. 什么是CSS模块?2. CSS模块的定义3. 基本示例4. 动态应用样式 三、CSS-in-JS1. 什么是CSS-in-JS?2. styled-components的定义3. 基本示例…...
【ROS2】高级:安全-理解安全密钥库
目标:探索位于 ROS 2 安全密钥库中的文件。 教程级别:高级 时间:15 分钟 内容 背景安全工件位置 公钥材料 私钥材料域治理政策 安全飞地 参加测验! 背景 在继续之前,请确保您已完成设置安全教程。 sros2 包可以用来创…...
C语言 ——— 数组指针的定义 数组指针的使用
目录 前言 数组指针的定义 数组指针的使用 前言 之前有编写过关于 指针数组 的相关知识 C语言 ——— 指针数组 & 指针数组模拟二维整型数组-CSDN博客 指针数组 顾名思义就是 存放指针的数组 那什么是数组指针呢? 数组指针的定义 何为数组指针…...
opencascade AIS_ManipulatorOwner AIS_MediaPlayer源码学习
前言 AIS_ManipulatorOwner是OpenCascade中的一个类,主要用于操纵对象的交互控制。AIS_ManipulatorOwner结合AIS_Manipulator类,允许用户通过可视化工具(如旋转、平移、缩放等)来操纵几何对象。 以下是AIS_ManipulatorOwner的基…...
如何防止用户通过打印功能复制页面文字
简单防白嫖,要让打印出来的页面是空白,通常的做法是在打印时隐藏页面上的所有内容。这可以通过CSS的媒体查询(Media Queries)来实现,特别是针对media print的查询。 在JavaScript中,你通常不会直接控制打印…...
Python3网络爬虫开发实战(3)网页数据的解析提取
文章目录 一、XPath1. 选取节点2. 查找某个特定的节点或者包含某个指定的值的节点3. XPath 运算符4. 节点轴5. 利用 lxml 使用 XPath 二、CSS三、Beautiful Soup1. 信息提取2. 嵌套选择3. 关联选择4. 方法选择器5. css 选择器 四、PyQuery1. 初始化2. css 选择器3. 信息提取4. …...
基于 HTML+ECharts 实现监控平台数据可视化大屏(含源码)
构建监控平台数据可视化大屏:基于 HTML 和 ECharts 的实现 监控平台的数据可视化对于实时掌握系统状态、快速响应问题至关重要。通过直观的数据展示,运维团队可以迅速发现异常,优化资源配置。本文将详细介绍如何利用 HTML 和 ECharts 实现一个…...
立创梁山派--移植开源的SFUD和FATFS实现SPI-FLASH文件系统
本文主要是在sfud的基础上进行fatfs文件系统的移植,并不对sfud的移植再进行过多的讲解了哦,所以如果想了解sfud的移植过程,请参考我的另外一篇文章:传送门 正文开始咯 首先我们需要先准备资料准备好,这里对于fatfs的…...
MySQL之视图和索引实战
1.新建数据库 mysql> create database myudb5_indexstu; Query OK, 1 row affected (0.01 sec) mysql> use myudb5_indexstu; Database changed 2.新建表 1.学生表student,定义主键,姓名不能重名,性别只能输入男或女,所在…...
快速参考:用C# Selenium实现浏览器窗口缩放的步骤
背景介绍 在现代网络环境中,浏览器自动化已成为数据抓取和测试的重要工具。Selenium作为一个强大的浏览器自动化工具,能够与多种编程语言结合使用,其中C#是非常受欢迎的选择之一。在实际应用中,我们常常需要调整浏览器窗口的缩放…...
MyBatis 插件机制、分页插件如何实现的
MyBatis 插件机制允许开发者在 SQL 执行的各个阶段(如预处理、执行、结果处理等)中插入自定义逻辑,从而实现对 MyBatis 行为的扩展和增强。以下是 MyBatis 插件运行原理的详细介绍: 插件接口 MyBatis 插件通过实现 org.apache.i…...
CentOS6.0安装telnet-server启用telnet服务
CentOS6.0安装telnet-server启用telnet服务 一步到位 fp"/etc/yum.repos.d" ; cp -a ${fp} ${fp}.$(date %0y%0m%0d%0H%0M%0S).bkup echo [base] nameCentOS-$releasever - Base baseurlhttp://mirrors.163.com/centos-vault/6.0/os/$basearch/http://mirrors.a…...
H5+CSS+JS工作性价比计算器
工作性价比=平均日新x综合环境系数/35 x(工作时长+通勤时长—0.5 x摸鱼时长) x学历系数 如果代码中的公式不对,请指正 效果图 源代码 <!DOCTYPE html> <html> <head> <style> .calculator { width: 300px; padd…...
Linux:基础命令学习
目录 一、ls命令 实例:-l以长格式显示文件和目录信息 实例:-F根据文件类型在列出的文件名称后加一符号 实例: -R 递归显示目录中的所有文件和子目录。 实例: 组合使用 Home目录和工作目录 二、目录修改和查看命令 三、mkd…...
面向无人机海岸带生态系统监测的语义分割基准数据集
描述:海岸带生态系统的监测是维护生态平衡和可持续发展的重要任务。语义分割技术在遥感影像中的应用为海岸带生态系统的精准监测提供了有效手段。然而,目前该领域仍面临一个挑战,即缺乏公开的专门面向海岸带生态系统的语义分割基准数据集。受…...
MySQL 知识小结(一)
一、my.cnf配置详解 我们知道安装MySQL有两种方式来安装咱们的MySQL数据库,分别是二进制安装编译数据库或者使用三方yum来进行安装,第三方yum的安装相对于二进制压缩包的安装更快捷,但是文件存放起来数据比较冗余,用二进制能够更好管理咱们M…...
【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论
路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...
Elastic 获得 AWS 教育 ISV 合作伙伴资质,进一步增强教育解决方案产品组合
作者:来自 Elastic Udayasimha Theepireddy (Uday), Brian Bergholm, Marianna Jonsdottir 通过搜索 AI 和云创新推动教育领域的数字化转型。 我们非常高兴地宣布,Elastic 已获得 AWS 教育 ISV 合作伙伴资质。这一重要认证表明,Elastic 作为 …...
Linux-进程间的通信
1、IPC: Inter Process Communication(进程间通信): 由于每个进程在操作系统中有独立的地址空间,它们不能像线程那样直接访问彼此的内存,所以必须通过某种方式进行通信。 常见的 IPC 方式包括&#…...
IP选择注意事项
IP选择注意事项 MTP、FTP、EFUSE、EMEMORY选择时,需要考虑以下参数,然后确定后选择IP。 容量工作电压范围温度范围擦除、烧写速度/耗时读取所有bit的时间待机功耗擦写、烧写功耗面积所需要的mask layer...
【字节拥抱开源】字节团队开源视频模型 ContentV: 有限算力下的视频生成模型高效训练
本项目提出了ContentV框架,通过三项关键创新高效加速基于DiT的视频生成模型训练: 极简架构设计,最大化复用预训练图像生成模型进行视频合成系统化的多阶段训练策略,利用流匹配技术提升效率经济高效的人类反馈强化学习框架&#x…...
Heygem50系显卡合成的视频声音杂音模糊解决方案
如果你在使用50系显卡有杂音的情况,可能还是官方适配问题,可以使用以下方案进行解决: 方案一:剪映替换音色(简单适合普通玩家) 使用剪映换音色即可,口型还是对上的,没有剪映vip的&…...
Vuex:Vue.js 应用程序的状态管理模式
什么是Vuex? Vuex 是专门为 Vue.js 应用程序开发的状态管理模式 库。它采用集中式存储管理应用的所有组件的状态,并以相应的规则保证状态以一种可预测的方式发生变化。 在大型单页应用中,当多个组件共享状态时,简单的单向数据流…...
c++算法学习3——深度优先搜索
一、深度优先搜索的核心概念 DFS算法是一种通过递归或栈实现的"一条路走到底"的搜索策略,其核心思想是: 深度优先:从起点出发,选择一个方向探索到底,直到无路可走 回溯机制:遇到死路时返回最近…...
