当前位置: 首页 > news >正文

Pytorch-day10-模型部署推理-checkpoint

模型部署&推理

  • 模型部署
  • 模型推理

我们会将PyTorch训练好的模型转换为ONNX 格式,然后使用ONNX Runtime运行它进行推理

1、ONNX

ONNX( Open Neural Network Exchange) 是 Facebook (现Meta) 和微软在2017年共同发布的,用于标准描述计算图的一种格式。ONNX通过定义一组与环境和平台无关的标准格式,使AI模型可以在不同框架和环境下交互使用,ONNX可以看作深度学习框架和部署端的桥梁,就像编译器的中间语言一样

由于各框架兼容性不一,我们通常只用 ONNX 表示更容易部署的静态图。硬件和软件厂商只需要基于ONNX标准优化模型性能,让所有兼容ONNX标准的框架受益

ONNX主要关注在模型预测方面,使用不同框架训练的模型,转化为ONNX格式后,可以很容易的部署在兼容ONNX的运行环境中

  • ONNX官网:https://onnx.ai/
  • ONNX GitHub:https://github.com/onnx/onnx

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-4hoUBZ88-1692614464568)(attachment:image-2.png)]

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-PlCTmLyk-1692614464569)(attachment:image.png)]

2、ONNX Runtime

  • ONNX Runtime官网:https://www.onnxruntime.ai/
  • ONNX Runtime GitHub:https://github.com/microsoft/onnxruntime

ONNX Runtime 是由微软维护的一个跨平台机器学习推理加速器,它直接对接ONNX,可以直接读取.onnx文件并实现推理,不需要再把 .onnx 格式的文件转换成其他格式的文件

PyTorch借助ONNX Runtime也完成了部署的最后一公里,构建了 PyTorch --> ONNX --> ONNX Runtime 部署流水线

安装onnx

pip install onnx

安装onnx runtime

pip install onnxruntime # 使用CPU进行推理

pip install onnxruntime-gpu # 使用GPU进行推理

注意:ONNX和ONNX Runtime之间的适配关系。我们可以访问ONNX Runtime的Github进行查看

网址:https://github.com/microsoft/onnxruntime/blob/master/docs/Versioning.md

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-NVBVlhGG-1692614464569)(attachment:image.png)]

ONNX Runtime和CUDA之间的适配关系

网址:https://onnxruntime.ai/docs/execution-providers/CUDA-ExecutionProvider.html

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-6x0xvNMn-1692614464569)(attachment:image-2.png)]

ONNX Runtime、TensorRT和CUDA的匹配关系:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-G7NPCXmY-1692614464569)(attachment:image-3.png)]

3、模型转换为ONNX格式

  • 用torch.onnx.export()把模型转换成 ONNX 格式的函数
  • 模型导成onnx格式前,我们必须调用model.eval()或者model.train(False)以确保我们的模型处在推理模式下
import torch.onnx 
# 转换的onnx格式的名称,文件后缀需为.onnx
onnx_file_name = "resnet50.onnx"
# 我们需要转换的模型,将torch_model设置为自己的模型
model = torchvision.models.resnet50(pretrained=True)
# 加载权重,将model.pth转换为自己的模型权重
model = model.load_state_dict(torch.load("resnet50.pt"))
# 导出模型前,必须调用model.eval()或者model.train(False)
model.eval()
# dummy_input就是一个输入的实例,仅提供输入shape、type等信息 
batch_size = 1 # 随机的取值,当设置dynamic_axes后影响不大
dummy_input = torch.randn(batch_size, 3, 224, 224, requires_grad=True) 
# 这组输入对应的模型输出
output = model(dummy_input)
# 导出模型
torch.onnx.export(model,        # 模型的名称dummy_input,   # 一组实例化输入onnx_file_name,   # 文件保存路径/名称export_params=True,        #  如果指定为True或默认, 参数也会被导出. 如果你要导出一个没训练过的就设为 False.opset_version=10,          # ONNX 算子集的版本,当前已更新到15do_constant_folding=True,  # 是否执行常量折叠优化input_names = ['conv1'],   # 输入模型的张量的名称output_names = ['fc'], # 输出模型的张量的名称# dynamic_axes将batch_size的维度指定为动态,# 后续进行推理的数据可以与导出的dummy_input的batch_size不同dynamic_axes={'conv1' : {0 : 'batch_size'},    'fc' : {0 : 'batch_size'}})

注:
算子版本对照文档:https://github.com/onnx/onnx/blob/main/docs/Operators.md

ONNX模型的检验

我们需要检测下我们的模型文件是否可用,我们将通过onnx.checker.check_model()进行检验

import onnx
# 我们可以使用异常处理的方法进行检验
try:# 当我们的模型不可用时,将会报出异常onnx.checker.check_model(self.onnx_model)
except onnx.checker.ValidationError as e:print("The model is invalid: %s"%e)
else:# 模型可用时,将不会报出异常,并会输出“The model is valid!”print("The model is valid!")

ONNX模型可视化

使用netron做可视化。下载地址:https://netron.app/

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-iEgN86DI-1692614464569)(attachment:image.png)]

模型的输入&输出信息:

[外链图片转存失败,源站可能有防盗链机制,建议将图片保存下来直接上传(img-qzyKV8ba-1692614464570)(attachment:image-2.png)]

使用ONNX Runtime进行推理


import onnxruntime
# 需要进行推理的onnx模型文件名称
onnx_file_name = "xxxxxx.onnx"# onnxruntime.InferenceSession用于获取一个 ONNX Runtime 推理器
ort_session = onnxruntime.InferenceSession(onnx_file_name, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])# 构建字典的输入数据,字典的key需要与我们构建onnx模型时的input_names相同
# 输入的input_img 也需要改变为ndarray格式
# ort_inputs = {'conv_1': input_img} 
#建议使用下面这种方法,因为避免了手动输入key
ort_inputs = {ort_session.get_inputs()[0].name:input_img}# run是进行模型的推理,第一个参数为输出张量名的列表,一般情况可以设置为None
# 第二个参数为构建的输入值的字典
# 由于返回的结果被列表嵌套,因此我们需要进行[0]的索引
ort_output = ort_session.run(None,ort_inputs)[0]
# output = {ort_session.get_outputs()[0].name}
# ort_output = ort_session.run([output], ort_inputs)[0]

注意:

  • PyTorch模型的输入为tensor,而ONNX的输入为array,因此我们需要对张量进行变换或者直接将数据读取为array格式
  • 输入的array的shape应该和我们导出模型的dummy_input的shape相同,如果图片大小不一样,我们应该先进行resize操作
  • run的结果是一个列表,我们需要进行索引操作才能获得array格式的结果
  • 在构建输入的字典时,我们需要注意字典的key应与导出ONNX格式设置的input_name相同

完整代码

1. 安装&下载

#!pip install onnx -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install onnxruntime -i https://pypi.tuna.tsinghua.edu.cn/simple
#!pip install torch -i https://pypi.tuna.tsinghua.edu.cn/simple
# Download ImageNet labels
#!wget https://raw.githubusercontent.com/pytorch/hub/master/imagenet_classes.txt

2、定义模型

import torch
import io
import time
from PIL import Image
import torchvision.transforms as transforms
from torchvision import datasets
import onnx
import onnxruntime
import torchvision
import numpy as np
from torch import nn
import torch.nn.init as init
onnx_file = 'resnet50.onnx'
save_dir = './resnet50.pt'

# 下载预训练模型
Resnet50 = torchvision.models.resnet50(pretrained=True)# 保存 模型权重
torch.save(Resnet50.state_dict(), save_dir)print(Resnet50)
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:208: UserWarning: The parameter 'pretrained' is deprecated since 0.13 and may be removed in the future, please use 'weights' instead.warnings.warn(
D:\Users\xulele\Anaconda3\lib\site-packages\torchvision\models\_utils.py:223: UserWarning: Arguments other than a weight enum or `None` for 'weights' are deprecated since 0.13 and may be removed in the future. The current behavior is equivalent to passing `weights=ResNet50_Weights.IMAGENET1K_V1`. You can also use `weights=ResNet50_Weights.DEFAULT` to get the most up-to-date weights.warnings.warn(msg)ResNet((conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)(layer1): Sequential((0): Bottleneck((conv1): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(256, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(64, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer2): Sequential((0): Bottleneck((conv1): Conv2d(256, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(256, 512, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(512, 128, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(128, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(128, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer3): Sequential((0): Bottleneck((conv1): Conv2d(512, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(512, 1024, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(3): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(4): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(5): Bottleneck((conv1): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(256, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(256, 1024, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(1024, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(layer4): Sequential((0): Bottleneck((conv1): Conv2d(1024, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)(downsample): Sequential((0): Conv2d(1024, 2048, kernel_size=(1, 1), stride=(2, 2), bias=False)(1): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)))(1): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True))(2): Bottleneck((conv1): Conv2d(2048, 512, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn1): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv2): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)(bn2): BatchNorm2d(512, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(conv3): Conv2d(512, 2048, kernel_size=(1, 1), stride=(1, 1), bias=False)(bn3): BatchNorm2d(2048, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)(relu): ReLU(inplace=True)))(avgpool): AdaptiveAvgPool2d(output_size=(1, 1))(fc): Linear(in_features=2048, out_features=1000, bias=True)
)

3. 模型导出为ONNX格式


batch_size = 1    # just a random number
# 先加载模型结构
loaded_model = torchvision.models.resnet50()   
# 在加载模型权重
loaded_model.load_state_dict(torch.load(save_dir))
#单卡GPU
# loaded_model.cuda()# 将模型设置为推理模式
loaded_model.eval()
# Input to the model
x = torch.randn(batch_size, 3, 224, 224, requires_grad=True)
torch_out = loaded_model(x)
torch_out
tensor([[-5.8050e-01,  7.5065e-02,  1.9404e-01, -9.1107e-01,  9.9716e-01,-1.2941e+00, -1.3402e-01, -6.4496e-01,  6.0434e-01, -1.6355e+00,-1.5187e-01,  1.0285e+00, -9.0719e-02, -2.6877e-01, -1.2656e+00,-7.9748e-01, -1.3802e+00, -9.6179e-01,  5.3512e-01,  8.3388e-02,-6.2868e-01,  1.5385e-01, -2.5405e-01,  4.3549e-01, -3.2834e-02,-8.9873e-01, -1.7059e+00, -8.5661e-01, -1.4386e+00, -2.0589e+00,-2.3464e+00, -3.6227e-01, -3.5712e+00, -1.6644e+00, -3.0064e-01,-1.8671e+00,  7.5745e-01, -2.3606e+00,  1.2460e-01,  2.7504e-01,-2.1071e-01, -2.6051e+00,  4.9932e-02, -3.0857e-01, -1.5757e-02,5.6365e-02,  1.0149e-01, -2.4776e+00,  1.7863e+00, -2.1650e+00,1.8615e+00, -2.8109e+00, -2.0084e+00, -5.4413e-01,  8.8444e-01,-8.8331e-01,  7.3980e-02, -2.0061e+00,  5.5653e-01,  7.1335e-01,4.6456e-01,  1.0112e+00,  4.2683e-01, -1.8685e-01, -1.1910e+00,1.6901e-01, -7.3501e-01, -2.4989e-01, -2.7711e-01,  1.8286e+00,-1.1317e+00,  1.9985e+00,  4.0941e-01,  2.7733e-01, -5.1216e-02,3.1703e-01, -2.1450e-01,  1.5035e+00,  1.2469e+00,  3.6729e+00,-1.2205e+00, -2.9484e-01, -3.2170e-01, -2.1006e+00, -1.2326e-01,3.9842e-01, -3.5075e-01,  1.5957e-01, -4.8100e-01,  1.2830e+00,-1.1557e+00,  2.9266e-01,  6.7955e-01,  1.2951e+00, -1.7461e-01,-3.4974e+00,  9.8954e-01, -1.1453e+00, -1.5246e+00,  7.6012e-01,-2.7971e-01, -1.0384e-01, -1.3282e+00,  3.7075e-01, -1.0879e+00,-2.2167e+00, -1.6805e+00,  1.5793e-01, -1.2778e+00, -3.4896e-01,6.2826e-01,  1.7638e+00, -8.2627e-01,  6.5328e-01,  5.1948e-01,-1.5375e+00, -2.7378e+00, -6.8703e-02, -1.5729e+00, -2.1919e+00,-1.0581e+00, -2.9345e+00, -3.2737e+00, -2.5095e+00, -2.5462e+00,-3.4298e+00,  1.0801e+00, -4.6679e-02, -7.1422e-01, -1.1388e+00,-2.2512e+00, -9.3222e-01,  2.7792e-01, -2.4730e-01, -1.3677e+00,-1.1018e+00, -2.3430e+00,  1.1828e+00,  1.5632e+00, -2.6486e+00,-2.2285e+00, -8.2680e-01, -1.9754e+00, -1.5034e+00, -2.1048e+00,1.0566e+00, -6.0091e-01, -2.2394e+00, -1.0461e+00, -1.4851e+00,9.9063e-02,  4.5648e-01, -3.0590e+00, -5.1038e-02, -2.2756e+00,-1.5584e+00, -2.6344e+00, -1.3177e+00, -2.4749e+00,  1.3347e-01,-1.8447e+00, -1.9380e+00, -1.1397e+00, -9.6618e-01, -4.7473e-01,-8.1531e-01, -2.0591e+00, -2.2707e+00, -2.1579e+00, -8.4820e-01,-1.8621e+00, -1.0359e+00, -1.7589e+00, -5.1326e-01, -1.9336e+00,-2.4361e+00, -3.0598e+00, -1.5690e+00,  7.9418e-01, -2.0329e+00,-1.4686e+00, -1.3989e+00, -1.2050e+00, -4.6212e-01, -2.1246e+00,3.9028e-02, -1.3888e+00, -8.1794e-01, -3.2460e+00, -2.9345e-01,-1.5963e+00, -1.4708e+00, -1.7513e+00, -1.0326e+00, -2.5880e+00,-3.5845e-02, -1.8802e+00, -2.0279e+00, -2.2119e+00, -5.6981e-01,-1.4423e+00, -5.3841e-01, -2.4736e-01,  1.4031e-01, -1.1382e+00,-1.3424e+00, -1.5412e-01, -1.5119e+00, -8.1195e-01, -2.3688e+00,-3.1494e+00, -1.2997e+00, -2.0867e+00, -1.5811e+00, -1.1873e+00,-1.4610e+00,  4.6883e-01, -1.3841e+00, -2.3627e+00, -5.0272e-01,-2.2311e+00,  2.8236e-01, -1.4063e+00, -6.1543e-01,  2.2254e-01,-1.8209e+00, -2.2796e+00, -1.4799e+00, -9.3366e-01, -4.5269e-01,-1.5885e+00, -3.5685e-01, -7.9922e-01, -1.7434e+00, -1.3543e+00,-5.9424e-01, -7.4004e-02, -4.8574e-01, -9.4252e-01, -1.1784e+00,-1.0762e+00, -7.0929e-01, -2.3507e+00, -1.5668e+00, -2.8629e+00,-9.7854e-01, -7.7075e-01, -2.1660e+00, -2.3006e-01, -6.7149e-01,-8.6158e-01, -1.7104e-02, -1.9825e+00, -7.7517e-01, -3.8014e-01,-2.1186e+00, -9.2220e-01, -9.2850e-01, -1.2418e+00,  9.7522e-02,-3.6667e-03, -2.1291e+00, -2.8809e+00, -1.3699e+00, -1.5959e+00,-6.5653e-01, -1.2664e+00, -2.8341e-01, -1.5526e+00, -7.1795e-01,-4.8103e-01, -1.6648e+00, -8.2810e-01, -1.6934e+00, -1.3563e+00,-1.6123e+00, -1.1855e+00, -1.2475e+00, -1.3781e+00, -9.8912e-01,-1.3062e-03,  1.2144e+00,  2.8563e+00,  1.7405e+00,  3.0779e-01,8.2037e-01, -4.7336e-01, -2.7651e+00,  4.0167e-01,  2.1637e-01,-5.0109e-01, -1.0902e+00, -2.6263e-01,  5.9031e-01, -5.2879e-01,1.0321e+00,  1.2048e+00,  1.6882e-01,  4.2126e-02, -3.8657e-01,-1.3633e+00,  2.0077e+00, -9.9282e-01, -1.6829e-01, -1.5846e+00,-2.1892e+00, -6.6651e-01,  9.6200e-01,  1.1047e+00, -3.3428e-01,2.7981e+00,  7.2582e-01,  3.4494e-01,  8.2232e-01,  1.7219e+00,1.0106e+00, -2.3200e-01,  4.9711e-02,  1.6123e+00,  8.3826e-01,-1.4559e+00, -2.4328e+00, -2.8555e+00, -2.6156e+00, -1.9900e+00,-2.4778e+00, -1.9356e+00, -1.5563e+00, -2.5033e+00, -3.5848e+00,-2.4205e-01, -5.5758e-01,  2.3322e-01, -1.1810e+00, -8.3212e-01,-4.8195e-02, -4.9411e-01, -3.0698e-03, -1.6134e+00, -1.5790e+00,-5.8626e-01, -1.8875e+00, -1.5670e+00, -2.0681e+00, -1.7590e+00,-3.9325e-01, -2.0172e+00, -1.3237e+00, -1.7693e-01, -8.5266e-01,-2.0535e+00, -2.7916e+00, -1.7173e+00,  5.3713e-02, -1.9363e-01,-3.1787e-01,  7.0567e-01,  5.3067e-01,  1.0458e+00,  1.2243e+00,-3.9257e-01, -3.9865e-01,  3.8122e-01,  3.4527e-01, -1.6836e+00,6.8797e-01,  1.2213e+00,  1.0733e+00,  1.1278e+00,  6.7682e-01,1.2179e+00, -8.0824e-01,  2.7535e-03, -8.5098e-01, -9.4244e-02,-3.7395e-01, -5.9386e-01, -8.1263e-02, -5.8865e-01, -8.3479e-01,-7.2452e-01, -1.6460e-01,  7.2182e-01,  1.2066e+00, -1.8087e+00,-4.4841e-01, -3.2795e-01, -3.0482e-01, -3.3302e-01, -2.4936e+00,-5.7049e-01, -2.0744e-02, -7.5551e-01, -2.4757e+00, -1.7799e+00,-1.1292e+00, -1.0917e+00,  6.8229e-01,  8.7337e-01,  3.1813e+00,-1.5752e+00,  1.0542e-01,  2.5594e+00, -1.0048e+00, -2.2436e+00,4.9551e-01, -2.0745e+00, -9.9214e-01, -2.5501e+00,  2.7392e+00,6.4982e-01,  3.5795e+00,  2.0882e+00,  1.0579e+00,  2.3663e+00,-1.1029e+00, -6.6217e-01, -4.8396e-01,  3.6624e+00,  2.3802e+00,8.2251e-01,  2.5061e+00, -1.8793e+00,  1.6354e+00,  1.9349e+00,7.7006e-01,  2.4251e-01,  1.7568e+00, -9.3206e-01,  1.2631e+00,1.0240e+00, -3.5013e-01,  7.5377e-03,  5.0503e-01, -9.5431e-01,1.5458e+00, -2.5770e+00,  5.7188e-01,  9.7471e-01, -3.1393e-01,1.0891e+00,  2.3057e+00, -7.5324e-01,  3.2789e+00, -8.1716e-01,-1.9879e+00,  5.5330e+00,  6.3507e-01, -1.1635e+00, -1.1235e+00,-3.4298e+00,  7.5610e-01, -3.1293e-02, -9.6185e-01, -8.1488e-02,1.1240e+00, -6.9891e-02,  2.5587e+00,  2.2736e+00,  1.7838e-01,-6.9245e-01,  2.4419e+00,  2.0427e+00,  1.1029e+00,  4.1609e+00,3.5126e+00, -1.8192e+00, -3.3070e+00,  7.6861e-01,  1.2807e+00,2.1298e-01, -8.7622e-01, -2.1935e+00,  1.0431e+00,  1.9949e+00,-3.2491e-01, -3.1093e+00, -1.0409e+00,  1.2334e+00, -1.7676e-01,3.0567e+00,  2.6081e+00,  2.7356e-01,  6.0596e-02, -1.3262e+00,-3.5291e-01, -4.7318e-01,  2.1949e+00,  5.3661e+00,  4.2932e+00,8.3733e+00,  4.1425e-01,  2.4924e-01, -1.3689e+00,  7.1289e-02,-9.8287e-01, -1.2412e+00,  1.3910e+00,  1.9533e+00,  3.3525e+00,1.7242e+00,  1.7637e+00,  1.0108e+00,  1.2255e+00,  1.7504e+00,5.4399e-01,  2.2958e+00,  1.9387e+00,  2.4723e+00, -1.1986e+00,-1.5123e+00, -1.9842e+00,  1.8934e+00,  1.3407e+00,  4.6350e-01,2.6674e+00,  1.0492e+00,  1.0988e+00, -1.4208e-02,  3.9129e-01,-4.7343e-01, -1.7139e+00, -7.8037e-01,  1.3938e+00,  2.4655e+00,-9.8006e-01, -5.5273e-01,  1.1947e+00,  1.5285e+00,  2.2214e-01,2.2346e+00,  1.3524e+00, -3.2841e-01,  2.1160e+00,  4.4156e+00,-2.7112e+00, -9.0547e-01, -1.4378e+00,  1.5687e+00,  3.1633e+00,-2.9853e-01,  1.2451e+00,  2.5149e+00,  1.0312e+00, -6.9518e-01,1.1537e+00,  9.6612e-01, -3.5077e+00, -7.9979e-02,  4.3770e+00,-6.3443e-01, -5.2904e-01,  1.5411e+00,  1.2678e+00, -1.2136e+00,-2.1303e+00,  5.5227e+00,  3.5111e-01,  1.5474e+00,  2.1807e+00,1.4828e+00, -1.4299e+00,  1.9229e+00,  2.4931e+00, -2.5156e+00,-1.7203e+00, -4.2708e-01,  1.6891e+00,  1.5878e+00, -3.3333e+00,2.1083e+00, -1.7954e-02,  3.9262e-01, -1.8340e+00,  7.8696e-01,-2.9308e+00, -2.3592e+00,  1.0347e+00,  8.9930e-01,  1.2392e+00,5.4734e-01,  6.6852e-01, -2.6781e+00,  2.2405e-01, -9.0210e-01,1.0648e+00, -2.3832e+00,  1.7305e+00,  1.6958e+00,  1.0681e+00,8.2608e-01,  2.5071e+00, -2.3054e-01,  3.9594e-01, -1.4630e-01,-2.1682e+00,  3.0358e+00,  1.5096e+00,  7.6303e-01,  4.4392e+00,3.2750e+00,  2.6279e+00,  4.3440e-01, -3.9379e+00,  1.0872e+00,1.7172e+00,  2.8548e+00, -1.0287e+00,  4.9895e+00, -2.0666e+00,4.8006e+00,  2.0120e+00, -1.5181e+00,  8.6181e-01, -3.4666e-01,2.2120e+00,  3.0910e+00,  5.9223e-01,  2.2166e+00,  3.9417e+00,3.5241e+00, -5.3305e-01,  3.5832e+00,  2.5654e+00, -1.5450e+00,-2.6835e+00,  3.1550e+00, -2.6302e+00,  2.3621e-01,  2.1758e+00,1.2487e+00, -1.0268e-01,  3.6262e+00,  3.6049e+00, -2.3248e+00,2.3213e-01,  3.2931e+00, -1.0058e+00,  4.5938e-01, -4.2993e-01,1.3951e+00, -2.8811e-01, -5.2850e-01,  1.0776e+00,  4.6138e+00,-7.1348e-01,  5.8099e-01,  4.4438e-01, -6.0801e-01,  7.0509e-01,3.5084e+00,  3.0626e+00,  7.0831e-01,  1.5073e+00, -2.1074e+00,3.2849e+00, -2.7267e+00,  2.9387e-01,  5.1394e-01,  1.4031e-01,-1.0694e+00, -2.5526e+00,  1.6833e+00, -1.3013e+00,  3.0083e+00,-1.9390e+00,  4.4978e-01, -1.5059e-01, -2.4490e+00,  1.6431e+00,-4.6816e-01, -1.6293e+00, -7.9092e-01,  1.1116e+00,  2.1265e+00,-3.0442e+00,  9.5523e-02,  2.8034e+00,  1.3312e+00,  3.4422e+00,4.4743e-01,  1.7062e+00,  1.8941e-01,  1.2406e+00, -9.8100e-01,-9.7636e-01, -3.9718e-01, -5.6298e-01,  2.1325e+00,  1.4298e+00,-4.6180e+00, -5.8675e-01,  1.7124e+00, -7.3919e-02, -2.9715e+00,2.9501e+00,  1.4472e+00, -1.3756e+00, -1.0018e+00, -1.1162e-01,1.2214e+00, -5.2164e-01, -8.7681e-01,  6.0252e-01,  2.7381e-01,-2.9817e+00, -1.3999e+00,  1.8137e+00, -3.4810e-02,  1.2475e+00,-5.1820e-01,  3.4469e+00,  2.8484e+00,  5.9049e-01,  2.2143e+00,-1.9403e-01,  1.5231e+00, -4.1188e+00,  5.6471e-01, -1.4212e+00,1.1938e+00,  2.8821e+00,  2.4709e+00, -1.6792e+00, -4.7604e-01,1.7501e+00, -2.2566e+00,  7.4556e-01,  2.5034e+00, -3.6194e-01,-1.1058e+00,  2.2076e+00, -6.0705e-03,  2.5470e+00, -1.9637e+00,2.7231e+00,  2.4390e+00,  1.1190e+00, -9.0371e-01, -4.4400e-01,8.6673e-01,  2.8887e+00, -6.5289e-01,  1.6986e+00,  6.0122e-01,-1.1510e+00,  1.9672e+00,  3.6989e+00,  1.3653e-01,  9.0087e-01,1.8489e+00, -2.7983e+00,  1.5802e+00,  2.6502e+00,  1.1414e+00,-5.3817e-01,  1.1085e+00, -2.1715e+00, -7.2016e-01,  1.5999e+00,4.9543e+00,  1.9814e+00, -1.1679e+00,  2.8527e+00,  2.1758e+00,7.5756e-01, -1.0221e+00,  1.2118e+00, -2.4591e-01,  1.4493e+00,3.4529e-01,  1.6389e+00,  4.0479e+00,  1.2619e+00,  4.2199e-01,-1.2010e+00,  2.7446e+00,  3.2914e+00,  1.6454e+00, -4.8627e-01,-3.6592e-01,  1.1508e+00,  4.4760e+00,  3.3516e+00,  2.9289e+00,1.6571e+00, -6.9271e-02,  1.5371e+00, -1.6635e-01,  2.8581e+00,1.0374e+00,  1.1429e+00,  2.1297e+00,  1.0264e+00,  4.7174e+00,-8.5201e-01,  1.7106e+00,  7.4727e-01,  6.5346e-01,  1.6801e+00,-3.7609e-01, -1.5926e+00, -2.6283e+00, -1.6866e+00,  5.5250e-02,-6.2809e-02,  5.9573e-01, -7.4590e-01,  5.3049e-01, -1.5091e+00,-8.0366e-01,  3.3241e+00,  2.3141e+00,  1.1193e+00, -1.6830e+00,3.3035e+00,  2.9134e-01, -2.9930e+00,  2.4471e+00,  9.8725e-01,-2.7953e+00, -1.7308e+00, -9.4977e-01,  1.6247e-01,  2.5793e+00,2.9449e-01,  2.1876e+00,  1.3091e-01,  6.2929e+00, -5.5488e-01,1.2929e+00, -9.5095e-03, -1.1349e+00, -1.0178e-01,  2.3317e+00,-4.3678e-01,  2.3839e+00,  2.6191e+00, -2.0215e+00,  1.5188e+00,3.1490e+00,  3.1997e+00, -2.2047e-01, -1.2029e-01,  2.7171e+00,3.1623e+00,  7.7251e-01, -1.8028e+00, -7.3017e-01,  1.5781e+00,7.6143e-01,  4.7296e+00,  1.7691e+00,  1.4732e+00,  2.0614e+00,2.2509e+00, -4.4578e+00,  1.1764e+00,  2.2630e+00,  5.7318e-01,4.3310e-01,  1.6570e+00, -1.4352e+00, -1.2535e+00, -4.0429e+00,-5.1775e-01, -1.5580e+00, -1.8145e+00,  2.4469e+00,  1.9574e+00,-2.0032e-01, -2.0393e+00,  3.3668e+00, -5.2449e-01, -4.5653e+00,4.8361e-01,  4.8011e-01,  8.3248e-01, -1.4842e-01,  2.5230e+00,-3.1912e-01,  1.1091e+00,  1.9290e+00,  6.5501e-01,  7.5642e-01,1.3678e+00,  1.6187e+00, -2.2867e+00, -1.3338e+00,  7.0305e-01,-2.6969e+00, -3.4848e-01,  3.5779e+00,  2.5296e+00,  1.2646e+00,-8.2202e-01,  1.5727e+00,  2.0048e+00,  1.9939e+00,  3.6664e-01,-3.7189e-01,  6.5360e-02,  2.5970e+00,  1.9509e+00,  7.9060e+00,4.1564e+00,  1.9750e+00,  1.3692e+00,  7.0074e-01,  1.3194e+00,1.5737e+00,  3.1158e+00,  2.8220e-01, -1.1930e+00, -2.9132e+00,3.6715e-01,  2.0554e+00, -4.5951e-01,  1.4659e+00,  1.6097e-01,3.5082e-01,  1.9813e+00,  2.3234e+00, -1.6767e+00, -1.9703e+00,-4.2028e-01, -2.6262e+00, -1.3928e+00, -7.6662e-01,  4.5116e-01,2.6828e-01, -2.8156e-01,  7.0492e-02, -2.3663e+00, -5.0179e-01,-1.6241e-01, -2.5555e+00, -9.8973e-02, -2.2130e+00, -2.3067e+00,-1.8250e+00, -1.8571e+00, -2.4779e+00, -2.7528e+00, -2.9528e+00,-9.4892e-01, -2.8599e+00, -6.0309e-01, -1.4899e-01, -9.7413e-01,9.2476e-01,  1.2974e+00, -8.6647e-01, -1.4522e-01,  1.5039e+00,1.5240e-01, -1.9550e+00, -1.3404e+00,  5.6667e-01, -1.2009e+00,-9.4940e-01,  1.0278e+00, -2.9112e+00, -6.9027e-01, -8.4326e-01,-1.5937e+00,  1.6618e+00,  3.1860e+00,  3.0757e+00,  4.0690e-01,-1.1017e+00,  3.6284e+00, -6.9720e-01, -1.3498e+00,  1.4283e-01,-4.1820e-01, -1.6470e+00,  4.1369e-01,  1.7120e-01, -1.7615e+00,7.3642e-01,  1.7452e+00,  4.3359e-01, -2.8788e-01, -6.6571e-02,-1.4325e-02, -2.2441e+00,  1.2690e+00, -7.3996e-01, -1.1551e+00,-1.4367e+00, -1.5546e+00, -2.9878e+00, -3.5215e+00, -4.2169e+00,-3.7416e+00, -2.0244e+00, -2.6461e+00, -1.1108e+00,  1.1864e+00]],grad_fn=<AddmmBackward0>)
torch_out.size()
torch.Size([1, 1000])

# 导出模型
torch.onnx.export(loaded_model,               # model being runx,             # model input (or a tuple for multiple inputs)onnx_file,   # where to save the model (can be a file or file-like object)export_params=True,        # store the trained parameter weights inside the model fileopset_version=10,   # the ONNX version to export the model todo_constant_folding=True,  # whether to execute constant folding for optimizationinput_names = ['conv1'],   # the model's input namesoutput_names = ['fc'], # the model's output names# variable length axesdynamic_axes={'conv1' : {0 : 'batch_size'},    'fc' : {0 : 'batch_size'}})
============= Diagnostic Run torch.onnx.export version 2.0.0+cu117 =============
verbose: False, log level: Level.ERROR
======================= 0 NONE 0 NOTE 0 WARNING 0 ERROR ========================

4、检验ONNX模型

# 我们可以使用异常处理的方法进行检验
try:# 当我们的模型不可用时,将会报出异常onnx.checker.check_model(onnx_file)
except onnx.checker.ValidationError as e:print("The model is invalid: %s"%e)
else:# 模型可用时,将不会报出异常,并会输出“The model is valid!”print("The model is valid!")
The model is valid!

5. 使用ONNX Runtime进行推理

import onnxruntime
import numpy as nport_session = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])# 将张量转化为ndarray格式
def to_numpy(tensor):return tensor.detach().cpu().numpy() if tensor.requires_grad else tensor.cpu().numpy()# 构建输入的字典和计算输出结果
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(x)}
ort_outs = ort_session.run(None, ort_inputs)# 比较使用PyTorch和ONNX Runtime得出的精度
np.testing.assert_allclose(to_numpy(torch_out), ort_outs[0], rtol=1e-03, atol=1e-05)print("Exported model has been tested with ONNXRuntime, and the result looks good!")
Exported model has been tested with ONNXRuntime, and the result looks good!

6. 进行实际预测并可视化

# 推理数据
from PIL import Image
from torchvision.transforms import transforms# 生成推理图片
image = Image.open('./images/cat.jpg')# 将图像调整为指定大小
image = image.resize((224, 224))# 将图像转换为 RGB 模式
image = image.convert('RGB')image.save('./images/cat_224.jpg')
categories = []
# Read the categories
with open("./imagenet/imagenet_classes.txt", "r") as f:categories = [s.strip() for s in f.readlines()]def get_class_name(probabilities):# Show top categories per imagetop5_prob, top5_catid = torch.topk(probabilities, 5)for i in range(top5_prob.size(0)):print(categories[top5_catid[i]], top5_prob[i].item())
#预处理
def pre_image(image_file):input_image = Image.open(image_file)preprocess = transforms.Compose([transforms.Resize(256),transforms.CenterCrop(224),transforms.ToTensor(),transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),])input_tensor = preprocess(input_image)inputs = input_tensor.unsqueeze(0) # create a mini-batch as expected by the model# input_arr = inputs.cpu().detach().numpy()return inputs 
#inference with model# 先加载模型结构
resnet50 = torchvision.models.resnet50()   
# 在加载模型权重
resnet50.load_state_dict(torch.load(save_dir))resnet50.eval()  
#推理
input_batch = pre_image('./images/cat_224.jpg')# move the input and model to GPU for speed if available
print("GPU Availability: ", torch.cuda.is_available())
if torch.cuda.is_available():input_batch = input_batch.to('cuda')resnet50.to('cuda')with torch.no_grad():output = resnet50(input_batch)
# Tensor of shape 1000, with confidence scores over Imagenet's 1000 classes
# print(output[0])
# The output has unnormalized scores. To get probabilities, you can run a softmax on it.
probabilities = torch.nn.functional.softmax(output[0], dim=0)
get_class_name(probabilities)
GPU Availability:  False
Persian cat 0.6668420433998108
lynx 0.023987364023923874
bow tie 0.016234245151281357
hair slide 0.013150070793926716
Japanese spaniel 0.012279157526791096
input_batch.size()
torch.Size([1, 3, 224, 224])
#benchmark 性能
latency = []
for i in range(10):with torch.no_grad():start = time.time()output = resnet50(input_batch)probabilities = torch.nn.functional.softmax(output[0], dim=0)top5_prob, top5_catid = torch.topk(probabilities, 5)# for catid in range(top5_catid.size(0)):#     print(categories[catid])latency.append(time.time() - start)print("{} model inference CPU time:cost {} ms".format(str(i),format(sum(latency) * 1000 / len(latency), '.2f')))
0 model inference CPU time:cost 149.59 ms
1 model inference CPU time:cost 130.74 ms
2 model inference CPU time:cost 133.76 ms
3 model inference CPU time:cost 130.64 ms
4 model inference CPU time:cost 131.72 ms
5 model inference CPU time:cost 130.88 ms
6 model inference CPU time:cost 136.31 ms
7 model inference CPU time:cost 139.95 ms
8 model inference CPU time:cost 141.90 ms
9 model inference CPU time:cost 140.96 ms
# Inference with ONNX Runtime
import onnxruntime
from onnx import numpy_helper
import time
onnx_file = 'resnet50.onnx'
session_fp32 = onnxruntime.InferenceSession(onnx_file, providers=['CPUExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['CUDAExecutionProvider'])
# session_fp32 = onnxruntime.InferenceSession("resnet50.onnx", providers=['OpenVINOExecutionProvider'])def softmax(x):"""Compute softmax values for each sets of scores in x."""e_x = np.exp(x - np.max(x))return e_x / e_x.sum()latency = []
def run_sample(session, categories, inputs):start = time.time()input_arr = inputsort_outputs = session.run([], {'conv1':input_arr})[0]output = ort_outputs.flatten()output = softmax(output) # this is optionaltop5_catid = np.argsort(-output)[:5]# for catid in top5_catid:#     print(categories[catid])latency.append(time.time() - start)return ort_outputs

input_tensor = pre_image('./images/cat_224.jpg')
input_arr = input_tensor.cpu().detach().numpy()
for i in range(10):ort_output = run_sample(session_fp32, categories, input_arr)print("{} ONNX Runtime CPU Inference time = {} ms".format(str(i),format(sum(latency) * 1000 / len(latency), '.2f')))
0 ONNX Runtime CPU Inference time = 67.66 ms
1 ONNX Runtime CPU Inference time = 56.30 ms
2 ONNX Runtime CPU Inference time = 53.90 ms
3 ONNX Runtime CPU Inference time = 58.18 ms
4 ONNX Runtime CPU Inference time = 64.53 ms
5 ONNX Runtime CPU Inference time = 62.79 ms
6 ONNX Runtime CPU Inference time = 61.75 ms
7 ONNX Runtime CPU Inference time = 60.51 ms
8 ONNX Runtime CPU Inference time = 59.35 ms
9 ONNX Runtime CPU Inference time = 57.57 ms

4、扩展知识

  • 模型量化
  • 模型剪裁
  • 工程优化
  • 算子优化

相关文章:

Pytorch-day10-模型部署推理-checkpoint

模型部署&推理 模型部署模型推理 我们会将PyTorch训练好的模型转换为ONNX 格式&#xff0c;然后使用ONNX Runtime运行它进行推理 1、ONNX ONNX( Open Neural Network Exchange) 是 Facebook (现Meta) 和微软在2017年共同发布的&#xff0c;用于标准描述计算图的一种格式…...

vue使用websocket

建立websocket.js // 信息提示 import { Message } from element-ui // 引入用户id import { getTenantId, getAccessToken } from /utils/auth// websocket地址 var url ws://192.168.2.20:48081/websocket/message // websocket实例 var ws // 重连定时器实例 var tt // w…...

jmeter入门:接口压力测试全解析

一.对接口压力测试 1.配置 1.添加线程组&#xff08;参数上文有解释 这里不介绍&#xff09; 2.添加取样器 不用解释一看就知道填什么。。。 3.添加头信息&#xff08;否则请求头对不上&#xff09; 也不用解释。。。 4.配置监听器 可以尝试使用这几个监听器。 2.聚合结果…...

go、java、.net、C#、nodejs、vue、react、python程序问题进群咨询

1、面试辅导 2、程序辅导 3、一对一腾讯会议辅导 3、业务逻辑辅导 4、各种bug帮你解决。 5、培训小白 6、顺利拿到offer...

树莓派4B最新系统Bullseye 64 bit使用xrdp远程桌面黑屏卡顿问题

1、树莓派换源 打开源文件 sudo nano /etc/apt/sources.list注释原来的&#xff0c;更换为清华源 deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye main contrib non-free deb https://mirrors.tuna.tsinghua.edu.cn/debian/ bullseye-updates main contrib no…...

EasyExcel入门介绍及工具类,网络下载excel

前言&#xff1a;在这里分享自己第一次使用EasyExcel并且编写工具类&#xff0c;且在接口中支持excel文件下载的一系列流程&#xff0c;包含所有前后端&#xff08;JSJAVA&#xff09;完整代码&#xff0c;可以根据自己需要自行提取&#xff0c;仅供参考。 一.引入EasyExcel依赖…...

【HarmonyOS北向开发】-04 ArkTS开发语言-ArkTS基础知识

飞书原文档&#xff1a;Docs...

【Alibaba中间件技术系列】「RocketMQ技术专题」小白专区之领略一下RocketMQ基础之最!

应一些小伙伴们的私信&#xff0c;希望可以介绍一下RocketMQ的基础&#xff0c;那么我们现在就从0开始&#xff0c;进入RocketMQ的基础学习及概念介绍&#xff0c;为学习和使用RocketMQ打好基础&#xff01; RocketMQ是一款快速地、可靠地、分布式、容易使用的消息中间件&#…...

营销活动:提升小程序的用户活跃度的关键

在现今竞争激烈的商业环境中&#xff0c;小程序已成为企业私域营销的重要工具之一。然而&#xff0c;拥有一个小程序并不足以保证用户的活跃度。营销活动作为推动用户参与的有效方式&#xff0c;对于提升小程序的用户活跃度起着至关重要的作用。本文将深入探讨营销活动在提升小…...

Neo4j之CALL基础

CALL 语句用于调用 Neo4j 数据库中预定义的函数、过程或者自定义的函数。它是用来执行一些特定操作或计算的重要工具。以下是一些常用的 CALL 语句示例和解释&#xff1a; 调用内置函数&#xff1a; CALL db.labels()这个示例中&#xff0c;调用了内置函数 db.labels() 来获取…...

【TypeScript】元组

元组&#xff08;Tuple&#xff09;是 TypeScript 中的一种特殊数据类型&#xff0c;它允许你定义一个固定数量和类型的元素组合。元组可以包含不同类型的数据&#xff0c;每个数据的类型在元组中都是固定的。以下是 TypeScript 中元组的基本用法和特点&#xff1a; // 声明一…...

数据仓库一分钟

数据分层 一、数据运营层&#xff1a;ODS&#xff08;Operational Data Store&#xff09; “面向主题的”数据运营层&#xff0c;也叫ODS层&#xff0c;是最接近数据源中数据的一层&#xff0c;数据源中的数据&#xff0c;经过抽取、洗净、传输&#xff0c;也就说传说中的 ETL…...

提升Python代理程序性能的终极解决方案:缓存、连接池和并发

在开发Python代理程序时&#xff0c;优化性能是至关重要的。本文将为你介绍一套终极解决方案&#xff0c;通过缓存、连接池和并发处理等技术&#xff0c;极大地提升Python代理程序的效率和稳定性。 游戏国内地更换虚拟含ip地址数据库地区 1.缓存技术 缓存是 .0-*-696ES2 0一…...

CSS和AJAX阶段学习记录

1、AJAX的工作原理&#xff1a; 如图所示&#xff0c;工作原理可以分为以下几步&#xff1a; 网页中发生一个事件&#xff08;页面加载、按钮点击&#xff09; 由 JavaScript 创建 XMLHttpRequest 对象 XMLHttpRequest 对象向 web 服务器发送请求 服务器处理该请求 服务器将响应…...

Android自定义View知识体系

View的概念、作用和基本属性 View是Android中的基本UI组件&#xff0c;用于构建用户界面。它可以是按钮、文本框、图像等可见元素&#xff0c;也可以是容器&#xff0c;用于组织其他View。View的作用是展示数据和接收用户的输入。它可以显示文本、图片、动画等内容&#xff0c…...

Springboot 自定义 Mybatis拦截器,实现 动态查询条件SQL自动组装拼接(玩具)

前言 ps&#xff1a;最近在参与3100保卫战&#xff0c;战况很激烈&#xff0c;刚刚打完仗&#xff0c;来更新一下之前写了一半的博客。 该篇针对日常写查询的时候&#xff0c;那些动态条件sql 做个简单的封装&#xff0c;自动生成&#xff08;抛砖引玉&#xff0c;搞个小玩具&a…...

Go 1.21新增的 slices 包详解(三)

Go 1.21新增的 slices 包提供了很多和切片相关的函数&#xff0c;可以用于任何类型的切片。 slices.Max 定义如下&#xff1a; func Max[S ~[]E, E cmp.Ordered](x S) E 返回 x 中的最大值&#xff0c;如果 x 为空&#xff0c;则 panic。对于浮点数 E, 如果有元素为 NaN&am…...

Python 在logging.config.dictConfig()日志配置方式下,使用自定义的Handler处理程序

文章目录 一、基于 RotatingFileHandler 的自定义处理程序二、基于 TimedRotatingFileHandler 的自定义处理程序 Python logging模块的基本使用、进阶使用详解 Python logging.handlers模块&#xff0c;RotatingFileHandler、TimedRotatingFileHandler 处理器各参数详细介绍 …...

Anaconda, Python, Jupyter和PyCharm介绍

目录 1 Anaconda, Python, Jupyter和PyCharm介绍 2 macOS通过Anaconda安装Python, Jupyter和PyCharm 3 使用终端创建虚拟环境并安装PyTorch 4 安装PyCharm并导入Anaconda虚拟环境 5 Windows操作系统下Anaconda与PyCharm安装 6 通过 Anaconda Navigator 创建 TensorFlow 虚…...

axios 各种方式的请求 示例

GET请求 示例一&#xff1a; 服务端代码 GetMapping("/f11") public String f11(Integer pageNum, Integer pageSize) {return pageNum " : " pageSize; }前端代码 <template><div class"home"><button click"getFun1…...

C++_核心编程_多态案例二-制作饮品

#include <iostream> #include <string> using namespace std;/*制作饮品的大致流程为&#xff1a;煮水 - 冲泡 - 倒入杯中 - 加入辅料 利用多态技术实现本案例&#xff0c;提供抽象制作饮品基类&#xff0c;提供子类制作咖啡和茶叶*//*基类*/ class AbstractDr…...

【CSS position 属性】static、relative、fixed、absolute 、sticky详细介绍,多层嵌套定位示例

文章目录 ★ position 的五种类型及基本用法 ★ 一、position 属性概述 二、position 的五种类型详解(初学者版) 1. static(默认值) 2. relative(相对定位) 3. absolute(绝对定位) 4. fixed(固定定位) 5. sticky(粘性定位) 三、定位元素的层级关系(z-i…...

uniapp微信小程序视频实时流+pc端预览方案

方案类型技术实现是否免费优点缺点适用场景延迟范围开发复杂度​WebSocket图片帧​定时拍照Base64传输✅ 完全免费无需服务器 纯前端实现高延迟高流量 帧率极低个人demo测试 超低频监控500ms-2s⭐⭐​RTMP推流​TRTC/即构SDK推流❌ 付费方案 &#xff08;部分有免费额度&#x…...

dify打造数据可视化图表

一、概述 在日常工作和学习中&#xff0c;我们经常需要和数据打交道。无论是分析报告、项目展示&#xff0c;还是简单的数据洞察&#xff0c;一个清晰直观的图表&#xff0c;往往能胜过千言万语。 一款能让数据可视化变得超级简单的 MCP Server&#xff0c;由蚂蚁集团 AntV 团队…...

MySQL 8.0 事务全面讲解

以下是一个结合两次回答的 MySQL 8.0 事务全面讲解&#xff0c;涵盖了事务的核心概念、操作示例、失败回滚、隔离级别、事务性 DDL 和 XA 事务等内容&#xff0c;并修正了查看隔离级别的命令。 MySQL 8.0 事务全面讲解 一、事务的核心概念&#xff08;ACID&#xff09; 事务是…...

云原生周刊:k0s 成为 CNCF 沙箱项目

开源项目推荐 HAMi HAMi&#xff08;原名 k8s‑vGPU‑scheduler&#xff09;是一款 CNCF Sandbox 级别的开源 K8s 中间件&#xff0c;通过虚拟化 GPU/NPU 等异构设备并支持内存、计算核心时间片隔离及共享调度&#xff0c;为容器提供统一接口&#xff0c;实现细粒度资源配额…...

鸿蒙(HarmonyOS5)实现跳一跳小游戏

下面我将介绍如何使用鸿蒙的ArkUI框架&#xff0c;实现一个简单的跳一跳小游戏。 1. 项目结构 src/main/ets/ ├── MainAbility │ ├── pages │ │ ├── Index.ets // 主页面 │ │ └── GamePage.ets // 游戏页面 │ └── model │ …...

全面解析数据库:从基础概念到前沿应用​

在数字化时代&#xff0c;数据已成为企业和社会发展的核心资产&#xff0c;而数据库作为存储、管理和处理数据的关键工具&#xff0c;在各个领域发挥着举足轻重的作用。从电商平台的商品信息管理&#xff0c;到社交网络的用户数据存储&#xff0c;再到金融行业的交易记录处理&a…...

用 Rust 重写 Linux 内核模块实战:迈向安全内核的新篇章

用 Rust 重写 Linux 内核模块实战&#xff1a;迈向安全内核的新篇章 ​​摘要&#xff1a;​​ 操作系统内核的安全性、稳定性至关重要。传统 Linux 内核模块开发长期依赖于 C 语言&#xff0c;受限于 C 语言本身的内存安全和并发安全问题&#xff0c;开发复杂模块极易引入难以…...

从实验室到产业:IndexTTS 在六大核心场景的落地实践

一、内容创作&#xff1a;重构数字内容生产范式 在短视频创作领域&#xff0c;IndexTTS 的语音克隆技术彻底改变了配音流程。B 站 UP 主通过 5 秒参考音频即可克隆出郭老师音色&#xff0c;生成的 “各位吴彦祖们大家好” 语音相似度达 97%&#xff0c;单条视频播放量突破百万…...