当前位置: 首页 > news >正文

12.6深度学习_模型优化和迁移_模型移植

八、模型移植

1. 认识ONNX

​ https://onnx.ai/

​ Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。

​ ONNX的规范及代码主要由微软,亚马逊 ,Face book 和 IBM等公司共同开发,以开放源代码的方式托管在Github上。目前官方支持加载ONNX模型并进行推理的深度学习框架有: Caffe2, PyTorch, PaddlePaddle, TensorFlow等。

2. 导出ONNX

2.1 安装依赖包

pip install onnx
pip install onnxruntime

2.2 导出ONNX模型

import os
import torch
import torch.nn as nn
from torchvision.models import resnet18if __name__ == "__main__":dir = os.path.dirname(__file__)weightpath = os.path.join(os.path.dirname(__file__), "pth", "resnet18_default_weight.pth")onnxpath = os.path.join(os.path.dirname(__file__), "pth", "resnet18_default_weight.onnx")device = torch.device("cuda" if torch.cuda.is_available() else "cpu")model = resnet18(pretrained=False)model.conv1 = nn.Conv2d(#in_channels=3,out_channels=64,kernel_size=3,stride=1,padding=0,bias=False,)# 删除池化层model.maxpool = nn.MaxPool2d(kernel_size=1, stride=1, padding=0)# 修改全连接层in_feature = model.fc.in_featuresmodel.fc = nn.Linear(in_feature, 10)model.load_state_dict(torch.load(weightpath, map_location=device))model.to(device)# 创建一个实例输入x = torch.randn(1, 3, 224, 224, device=device)# 导出onnxtorch.onnx.export(model,x,onnxpath,#verbose=True, # 输出转换过程input_names=["input"],output_names=["output"],)print("onnx导出成功")

2.3 ONNX结构可视化

可以直接在线查看:https://netron.app/

也可以下载桌面版:https://github.com/lutzroeder/netron

3. ONNX推理

ONNX在做推理时不再需要导入网络,且适用于Python、JAVA、PyQT等各种语言,不再依赖于PyTorch框架;

3.1 简单推理

import onnxruntime as ort
import torchvision.transforms as transforms
import cv2 as cv
import os
import numpy as npimg_size = 224
transformtest = transforms.Compose([transforms.ToPILImage(),  # 将numpy数组转换为PIL图像transforms.Resize((img_size, img_size)),transforms.ToTensor(),transforms.Normalize(# 均值和标准差mean=[0.4914, 0.4822, 0.4465],std=[0.2471, 0.2435, 0.2616],),]
)def softmax(x):e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=1, keepdims=True)def cv_imread(file_path):cv_img = cv.imdecode(np.fromfile(file_path, dtype=np.uint8), cv.IMREAD_COLOR)return cv_imglablename = "飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船和卡车".split("、")if __name__ == "__main__":dir = os.path.dirname(__file__)weightpath = os.path.join(os.path.dirname(__file__), "pth", "resnet18_default_weight.pth")onnxpath = os.path.join(os.path.dirname(__file__), "pth", "resnet18_default_weight.onnx")# 读取图片img_path = os.path.join(dir, "test", "5.jpg")img = cv_imread(img_path)img = cv.cvtColor(img, cv.COLOR_BGR2RGB)img_tensor = transformtest(img)# 将图片转换为ONNX运行时所需的格式img_numpy = img_tensor.numpy()img_numpy = np.expand_dims(img_numpy, axis=0)  # 增加batch_size维度# 加载onnx模型sess = ort.InferenceSession(onnxpath)# 运行onnx模型outputs = sess.run(None, {"input": img_numpy})output = outputs[0]# 应用softmaxprobabilities = softmax(output)print(probabilities)# 获得预测结果pred_index = np.argmax(probabilities, axis=1)pred_value = probabilities[0][pred_index[0]]print(pred_index)print("预测目标:",lablename[pred_index[0]],"预测概率:",str(pred_value * 100)[:5] + "%",)

输出结果:

[[6.7321511e-05 9.7113671e-11 7.6417709e-05 2.8661249e-02 7.0206769e-043.9052707e-04 9.7010124e-01 6.8206714e-07 4.1351362e-07 5.7089373e-09]]
[6]
预测目标: 青蛙 预测概率: 97.01%

3.2 使用GPU推理

需要安装依赖包:

pip install onnxruntime-gpu

代码:

# 导入FileSystemStorage
import time
import random
import os# 人工智能推理用到的模块
import onnxruntime as ort
import torchvision.transforms as transforms
import numpy as np
import PIL.Image as Imageimg_size = 32
transformtest = transforms.Compose([transforms.Resize((img_size, img_size)),transforms.ToTensor(),transforms.Normalize(# 均值和标准差mean=[0.4914, 0.4822, 0.4465],std=[0.2471, 0.2435, 0.2616],),]
)def softmax(x):e_x = np.exp(x - np.max(x))return e_x / e_x.sum(axis=1, keepdims=True)def imgclass():# AI推理# 读取图片imgpath = os.path.join(os.path.dirname(__file__), "..", "static/ai", filename)# 加载并预处理图像image = Image.open(imgpath)input_tensor = transformtest(image)input_tensor = input_tensor.unsqueeze(0)  # 添加批量维度# 将图片转换为ONNX运行时所需的格式img_numpy = input_tensor.numpy()# 加载模型onnxPath = os.path.join(#os.path.dirname(__file__),"..","onnx","resnet18_default_weight_1.onnx",)# 设置 ONNX Runtime 使用 GPUproviders = ["CUDAExecutionProvider"]sess = ort.InferenceSession(onnxPath, providers=providers)# 使用模型对图片进行推理运算output = sess.run(None, {"input": img_numpy})output = softmax(output[0])print(output)ind = np.argmax(output, axis=1)print(ind)lablename = "飞机、汽车、鸟类、猫、鹿、狗、青蛙、马、船、卡车".split("、")res = {"code": 200, "msg": "处理成功", "url": img, "class": lablename[ind[0]]}

相关文章:

12.6深度学习_模型优化和迁移_模型移植

八、模型移植 1. 认识ONNX ​ https://onnx.ai/ ​ Open Neural Network Exchange(ONNX,开放神经网络交换)格式,是一个用于表示深度学习模型的标准,可使模型在不同框架之间进行转移。 ​ ONNX的规范及代码主要由微软…...

Grid++Report:自定义模板设计(自由表格使用),详细教程

实现效果 步骤 一、新建空白 初始状态都是空白页,如果不是,点击右上角->文件->新建空白 二、页面设置 右击页面灰色部分->页面设置 根据需求自定义页面 三、报表头设计 1、新增报表头 右击屏幕->新增->报表节->报表头 点击报表头…...

[Collection与数据结构] 位图与布隆过滤器

🌸个人主页:https://blog.csdn.net/2301_80050796?spm1000.2115.3001.5343 🏵️热门专栏: 🧊 Java基本语法(97平均质量分)https://blog.csdn.net/2301_80050796/category_12615970.html?spm1001.2014.3001.5482 🍕 Collection与…...

idea中新建一个空项目

目的,为了在同一个目录下有多个小的项目:使用IDE为idea2022。 步骤: 点击新建项目,点击创建空项目,这里选择空项目是将其作为其他项目的一个容器,如图所示: 然后点击文件->项目结构&#xf…...

【Python】【Conda 】Conda 与 venv 虚拟环境优缺点全解:如何做出明智选择

目录 引言一、基本概念1.1 Conda 虚拟环境1.2 Python venv 虚拟环境 二、主要区别对比三、优缺点分析3.1 Conda 虚拟环境的优缺点3.2 Python venv 虚拟环境的优缺点 四、使用场景推荐4.1 使用 Conda 虚拟环境的场景4.2 使用 Python venv 虚拟环境的场景 五、虚拟环境管理工具对…...

深度学习在故障检测中的应用:从理论到实践

随着工业设备和信息系统的复杂性增加,故障检测成为企业运维的重要任务。然而,传统的基于规则或统计学的故障检测方法难以应对复杂多变的故障模式。深度学习作为一种强大的数据分析工具,为故障检测提供了新的解决思路。本文将介绍深度学习模型…...

自然语言处理与人工智能

自然语言处理(NLP)与人工智能(AI) 自然语言处理(NLP)是人工智能(AI)领域的一个重要分支,旨在使计算机能够理解、解释和生成自然语言。随着深度学习技术的进步&#xff0…...

量化交易系统开发-实时行情自动化交易-8.15.Ptrade/恒生平台

19年创业做过一年的量化交易但没有成功,作为交易系统的开发人员积累了一些经验,最近想重新研究交易系统,一边整理一边写出来一些思考供大家参考,也希望跟做量化的朋友有更多的交流和合作。 接下来会对于Ptrade/恒生平台介绍。 P…...

非常简单实用的前后端分离项目-仓库管理系统(Springboot+Vue)part 4

三十三、出入库管理 Header.vue导一下,RecordController加一个 //将入库数据和原有数据相加吧//新增PostMapping("/save")public Result save(RequestBody Record record) {return recordService.save(record) ? Result.success() : Result.fail();} GoodsManage.v…...

基于MATLAB的信号处理工具:信号分析器

信号(或时间序列)是与特定时间相关的一系列数字或测量值,不同的行业和学科将这一与时间相关的数字序列称为信号或时间序列。生物医学或电气工程师会将其称为信号,而统计学家或金融定量分析师会使用时间序列这一术语。例如&#xf…...

Codeforces Round 784 (Div. 4)

题目链接 A. Division? 题意 思路 模拟即可 示例代码 void solve() {int n;cin >> n;int ans;if(n > 1900) ans 1;else if(n > 1600) ans 2;else if(n > 1400) ans 3;else ans 4;cout << "Division " << ans << \n;}B. T…...

OpenNebula 开源虚拟平台,对标 VMware

Beeks Group 主要为金融服务提供商提供虚拟专用服务器和裸机服务器。该公司表示&#xff0c;转向 OpenNebula 不仅大幅降低了成本&#xff0c;还使其虚拟机效率提升了 200%&#xff0c;并将更多裸机服务器资源用于客户端负载&#xff0c;而非像以往使用 VMware 时那样用于虚拟机…...

软件项目标书参考,合同拟制,开发合同制定,开发协议,标书整体技术方案,实施方案,通用套用方案,业务流程,技术架构,数据库架构全资料下载(原件)

1、终止合同协议书 2、项目合作协议 3、合同交底纪要 4、合同管理台账 软件资料清单列表部分文档清单&#xff1a;工作安排任务书&#xff0c;可行性分析报告&#xff0c;立项申请审批表&#xff0c;产品需求规格说明书&#xff0c;需求调研计划&#xff0c;用户需求调查单&…...

Jenkins环境一站式教程:从安装到配置,打造高效CI/CD流水线环境-Ubuntu 22.04.5 环境离线安装配置 Jenkins 2.479.1

文章目录 Jenkins环境一站式教程&#xff1a;从安装到配置&#xff0c;打造高效CI/CD流水线环境-Ubuntu 22.04.5 环境离线安装配置 Jenkins 2.479.1一、环境准备1.1 机器规划1.2 环境配置1.2.1 设置主机名1.2.2 停止和禁用防火墙1.2.3 更新系统 二、安装配置Jenkins2.1 安装JDK…...

【Android】ARouter源码解析

本篇文章主要讲解了 ARouter 框架的源码分析&#xff0c;包括其初始化过程、核心方法等。 初始化 在使用ARouter的时候我们都会先进行初始化&#xff1a; ARouter.init(this);我们看下 init() 源码&#xff1a; public static void init(Application application) {// 检查…...

计算直线的交点数

主要实现思路 整体流程思路&#xff1a; 程序旨在解决给定平面上不同数量的直线&#xff08;无三线共点&#xff09;&#xff0c;求出每种直线数量下所有可能的交点数量&#xff0c;并按要求格式输出的问题。整体通过初始化一个二维数组来存储不同直线数量与交点数量对应的存在…...

STM32基于HAL库的串口接收中断触发机制和适用场景

1. HAL_UART_Receive_DMA函数 基本功能 作用&#xff1a;启动一个固定长度的 DMA 数据接收。特点&#xff1a; 需要预先指定接收数据的长度&#xff08;Size 参数&#xff09;。DMA 会一直工作直到接收到指定数量的数据&#xff0c;接收完成后触发 HAL_UART_RxCpltCallback 回…...

java面试宝典

本文只摘抄部分宝典内容&#xff0c;完整宝典可以在打开下方链接&#xff0c;在网盘获取 ^ _ ^ 链接:java面试宝典 提取码: wxy1 复制这段内容后打开百度网盘手机App&#xff0c;操作更方便哦 链接: java前端面试宝典 提取码: wxy1 复制这段内容后打开百度网盘手机App&#xff…...

Scala—Slice(提取子序列)方法详解

Scala—Slice&#xff08;提取子序列&#xff09;方法详解 在 Scala 中&#xff0c;slice 方法用于从集合中提取一个连续的子序列&#xff08;切片&#xff09;。可以应用于多种集合类型&#xff0c;如 List、Array、Seq 等。 一、slice 方法的定义 slice 根据提供的起始索引…...

【电子通识】案例:USB Type-C USB 3.0线缆做直通连接器TX/RX反向

【电子通识】案例:连接器接线顺序评估为什么新人总是评估不到位?-CSDN博客这个文章的后续。最近在做一个工装项目,需要用到USB Type-C线缆做连接。 此前已经做好了线序规划,结果新人做成实物后发现有的USB Type-C线缆可用,有的不行。其中发现USB3.0的TX-RX信号与自己的板卡…...

XCTF-web-easyupload

试了试php&#xff0c;php7&#xff0c;pht&#xff0c;phtml等&#xff0c;都没有用 尝试.user.ini 抓包修改将.user.ini修改为jpg图片 在上传一个123.jpg 用蚁剑连接&#xff0c;得到flag...

iOS 26 携众系统重磅更新,但“苹果智能”仍与国行无缘

美国西海岸的夏天&#xff0c;再次被苹果点燃。一年一度的全球开发者大会 WWDC25 如期而至&#xff0c;这不仅是开发者的盛宴&#xff0c;更是全球数亿苹果用户翘首以盼的科技春晚。今年&#xff0c;苹果依旧为我们带来了全家桶式的系统更新&#xff0c;包括 iOS 26、iPadOS 26…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明&#xff1a;server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

Go 语言接口详解

Go 语言接口详解 核心概念 接口定义 在 Go 语言中&#xff0c;接口是一种抽象类型&#xff0c;它定义了一组方法的集合&#xff1a; // 定义接口 type Shape interface {Area() float64Perimeter() float64 } 接口实现 Go 接口的实现是隐式的&#xff1a; // 矩形结构体…...

视频字幕质量评估的大规模细粒度基准

大家读完觉得有帮助记得关注和点赞&#xff01;&#xff01;&#xff01; 摘要 视频字幕在文本到视频生成任务中起着至关重要的作用&#xff0c;因为它们的质量直接影响所生成视频的语义连贯性和视觉保真度。尽管大型视觉-语言模型&#xff08;VLMs&#xff09;在字幕生成方面…...

3403. 从盒子中找出字典序最大的字符串 I

3403. 从盒子中找出字典序最大的字符串 I 题目链接&#xff1a;3403. 从盒子中找出字典序最大的字符串 I 代码如下&#xff1a; class Solution { public:string answerString(string word, int numFriends) {if (numFriends 1) {return word;}string res;for (int i 0;i &…...

Caliper 配置文件解析:config.yaml

Caliper 是一个区块链性能基准测试工具,用于评估不同区块链平台的性能。下面我将详细解释你提供的 fisco-bcos.json 文件结构,并说明它与 config.yaml 文件的关系。 fisco-bcos.json 文件解析 这个文件是针对 FISCO-BCOS 区块链网络的 Caliper 配置文件,主要包含以下几个部…...

QT: `long long` 类型转换为 `QString` 2025.6.5

在 Qt 中&#xff0c;将 long long 类型转换为 QString 可以通过以下两种常用方法实现&#xff1a; 方法 1&#xff1a;使用 QString::number() 直接调用 QString 的静态方法 number()&#xff0c;将数值转换为字符串&#xff1a; long long value 1234567890123456789LL; …...

在WSL2的Ubuntu镜像中安装Docker

Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包&#xff1a; for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...

GC1808高性能24位立体声音频ADC芯片解析

1. 芯片概述 GC1808是一款24位立体声音频模数转换器&#xff08;ADC&#xff09;&#xff0c;支持8kHz~96kHz采样率&#xff0c;集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器&#xff0c;适用于高保真音频采集场景。 2. 核心特性 高精度&#xff1a;24位分辨率&#xff0c…...