当前位置: 首页 > news >正文

基于 PyTorch 的模型瘦身三部曲:量化、剪枝和蒸馏,让模型更短小精悍!

基于 PyTorch 的模型量化、剪枝和蒸馏

    • 1. 模型量化
      • 1.1 原理介绍
      • 1.2 PyTorch 实现
    • 2. 模型剪枝
      • 2.1 原理介绍
      • 2.2 PyTorch 实现
    • 3. 模型蒸馏
      • 3.1 原理介绍
      • 3.2 PyTorch 实现
    • 参考文献

在这里插入图片描述

1. 模型量化

1.1 原理介绍

模型量化是将模型参数从高精度(通常是 float32)转换为低精度(如 int8 或更低)的过程。这种技术可以显著减少模型大小、降低计算复杂度,并加快推理速度,同时尽可能保持模型的性能。
在这里插入图片描述
量化的主要方法包括:

  1. 动态量化

    • 在推理时动态地将权重从 float32 量化为 int8。
    • 激活值在计算过程中保持为浮点数。
    • 适用于 RNN 和变换器等模型。
  2. 静态量化

    • 在推理之前,预先将权重从 float32 量化为 int8。
    • 在推理过程中,激活值也被量化。
    • 需要校准数据来确定激活值的量化参数。
  3. 量化感知训练(QAT)

    • 在训练过程中模拟量化操作。
    • 允许模型适应量化带来的精度损失。
    • 通常能够获得比后量化更高的精度。

1.2 PyTorch 实现

import torch# 1. 动态量化
model_fp32 = MyModel()
model_int8 = torch.quantization.quantize_dynamic(model_fp32,  # 原始模型{torch.nn.Linear, torch.nn.LSTM},  # 要量化的层类型dtype=torch.qint8  # 量化后的数据类型
)# 2. 静态量化
model_fp32 = MyModel()
model_fp32.eval()  # 设置为评估模式# 设置量化配置
model_fp32.qconfig = torch.quantization.get_default_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare(model_fp32)# 使用校准数据进行校准
with torch.no_grad():for batch in calibration_data:model_fp32_prepared(batch)# 转换模型
model_int8 = torch.quantization.convert(model_fp32_prepared)# 3. 量化感知训练
model_fp32 = MyModel()
model_fp32.train()  # 设置为训练模式# 设置量化感知训练配置
model_fp32.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
model_fp32_prepared = torch.quantization.prepare_qat(model_fp32)# 训练循环
for epoch in range(num_epochs):for batch in train_data:output = model_fp32_prepared(batch)loss = criterion(output, target)loss.backward()optimizer.step()# 转换模型
model_int8 = torch.quantization.convert(model_fp32_prepared)

2. 模型剪枝

2.1 原理介绍

模型剪枝是一种通过移除模型中不重要的权重或神经元来减少模型复杂度的技术。剪枝可以减少模型大小、降低计算复杂度,并可能改善模型的泛化能力。
在这里插入图片描述

主要的剪枝方法包括:

  1. 权重剪枝

    • 移除绝对值小于某个阈值的单个权重。
    • 可以大幅减少模型参数数量,但可能导致非结构化稀疏性。
  2. 结构化剪枝

    • 移除整个卷积核、神经元或通道。
    • 产生更加规则的稀疏结构,有利于硬件加速。
  3. 重要性剪枝

    • 基于权重或激活值的重要性评分来决定剪枝对象。
    • 常用的重要性度量包括权重幅度、激活值、梯度等。

2.2 PyTorch 实现

import torch
import torch.nn.utils.prune as prunemodel = MyModel()# 1. 权重剪枝
prune.l1_unstructured(model.conv1, name='weight', amount=0.3)# 2. 结构化剪枝
prune.ln_structured(model.conv1, name='weight', amount=0.5, n=2, dim=0)# 3. 全局剪枝
parameters_to_prune = ((model.conv1, 'weight'),(model.conv2, 'weight'),(model.fc1, 'weight'),
)
prune.global_unstructured(parameters_to_prune,pruning_method=prune.L1Unstructured,amount=0.2
)# 4. 移除剪枝
for module in model.modules():if isinstance(module, torch.nn.Conv2d):prune.remove(module, 'weight')

3. 模型蒸馏

3.1 原理介绍

模型蒸馏是一种将复杂模型(教师模型)的知识转移到简单模型(学生模型)的技术。这种方法可以在保持性能的同时,大幅减少模型的复杂度和计算需求。
在这里插入图片描述

主要的蒸馏方法包括:

  1. 响应蒸馏

    • 学生模型学习教师模型的最终输出(软标签)。
    • 软标签包含了教师模型对不同类别的置信度信息。
  2. 特征蒸馏

    • 学生模型学习教师模型的中间层特征。
    • 可以传递更丰富的知识,但需要设计合适的映射函数。
  3. 关系蒸馏

    • 学习样本之间的关系,如相似度或排序。
    • 有助于保持教师模型学到的数据结构。

3.2 PyTorch 实现

import torch
import torch.nn as nn
import torch.nn.functional as Fclass DistillationLoss(nn.Module):def __init__(self, alpha=0.5, temperature=2.0):super().__init__()self.alpha = alphaself.T = temperaturedef forward(self, student_outputs, teacher_outputs, labels):# 硬标签损失hard_loss = F.cross_entropy(student_outputs, labels)# 软标签损失soft_loss = F.kl_div(F.log_softmax(student_outputs / self.T, dim=1),F.softmax(teacher_outputs / self.T, dim=1),reduction='batchmean') * (self.T * self.T)# 总损失loss = (1 - self.alpha) * hard_loss + self.alpha * soft_lossreturn loss# 训练循环
teacher_model = TeacherModel().eval()
student_model = StudentModel().train()
distillation_loss = DistillationLoss(alpha=0.5, temperature=2.0)for epoch in range(num_epochs):for batch, labels in train_loader:optimizer.zero_grad()with torch.no_grad():teacher_outputs = teacher_model(batch)student_outputs = student_model(batch)loss = distillation_loss(student_outputs, teacher_outputs, labels)loss.backward()optimizer.step()

通过这些技术的组合使用,可以显著减小模型大小、提高推理速度,同时尽可能保持模型性能。在实际应用中,可能需要根据具体任务和硬件限制来选择和调整这些方法。

参考文献

[1]Jacob, B., Kligys, S., Chen, B., Zhu, M., Tang, M., Howard, A., Adam, H., & Kalenichenko, D. (2018). Quantization and Training of Neural Networks for Efficient Integer-Arithmetic-Only Inference. In Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition (CVPR) (pp. 2704-2713).[2]Krishnamoorthi, R. (2018). Quantizing deep convolutional networks for efficient inference: A whitepaper. arXiv preprint arXiv:1806.08342.[3]Han, S., Pool, J., Tran, J., & Dally, W. (2015). Learning both Weights and Connections for Efficient Neural Network. In Advances in Neural Information Processing Systems (NeurIPS) (pp. 1135-1143).[4]Li, H., Kadav, A., Durdanovic, I., Samet, H., & Graf, H. P. (2016). Pruning Filters for Efficient ConvNets. arXiv preprint arXiv:1608.08710.[5]Hinton, G., Vinyals, O., & Dean, J. (2015). Distilling the Knowledge in a Neural Network. arXiv preprint arXiv:1503.02531.[6]Romero, A., Ballas, N., Kahou, S. E., Chassang, A., Gatta, C., & Bengio, Y. (2014). FitNets: Hints for Thin Deep Nets. arXiv preprint arXiv:1412.6550.

创作不易,烦请各位观众老爷给个三连,小编在这里跪谢了!
在这里插入图片描述

相关文章:

基于 PyTorch 的模型瘦身三部曲:量化、剪枝和蒸馏,让模型更短小精悍!

基于 PyTorch 的模型量化、剪枝和蒸馏 1. 模型量化1.1 原理介绍1.2 PyTorch 实现 2. 模型剪枝2.1 原理介绍2.2 PyTorch 实现 3. 模型蒸馏3.1 原理介绍3.2 PyTorch 实现 参考文献 1. 模型量化 1.1 原理介绍 模型量化是将模型参数从高精度(通常是 float32&#xff0…...

二、原型模式

文章目录 1 基本介绍2 实现方式深浅拷贝目标2.1 使用 Object 的 clone() 方法2.1.1 代码2.1.2 特性2.1.3 实现深拷贝 2.2 在 clone() 方法中使用序列化2.2.1 代码 2.2.2 特性 3 实现的要点4 Spring 中的原型模式5 原型模式的类图及角色5.1 类图5.1.1 不限制语言5.1.2 在 Java 中…...

【目标检测】Anaconda+PyTorch(GPU)+PyCharm(Yolo5)配置

前言 本文主要介绍在windows系统上的Anaconda、PyTorch、PyCharm、Yolov5关键步骤安装,为使用yolo所需的环境配置完善。同时也算是记录下我的配置流程,为以后用到的时候能笔记查阅。 Anaconda 软件安装 Anaconda官网:https://www.anaconda…...

Django实战项目之进销存数据分析报表——第二天:项目创建和 PyCharm 配置

在上一篇博客中,我们讨论了如何搭建一个全栈 Web 应用的开发环境,包括 Python 环境的创建、Django 和 MySQL 的安装以及前端技术栈的选择。现在,让我们继续深入,学习如何在 PyCharm 中创建一个新的 Django 项目并进行配置。 一…...

静态路由实验

1.实验拓扑图 二、实验要求 1.R6为ISP,接口IP地址均为公有地址,该设备只能配置IP地址,之后不能再对其进行任何配置; 2.R1-R5为局域网,私有IP地址192.168.1.0/24,请合理分配; 3.R1、R2、R4&…...

VSCode STM32嵌入式开发插件记录

要卸载之前搭建的VSCode嵌入式开发环境了,记录一下用的插件。 1.Cortex-Debug https://github.com/Marus/cortex-debug 2.Embedded IDE https://github.com/github0null/eide 3.Keil uVision Assistant https://github.com/jacksonjim/keil-assistant/ 4.RTO…...

linux cpu 占用超100% 分析。

感谢: https://www.cnblogs.com/wolfstark/p/16450131.html 总结&#xff1a; 查看进程中各个线程占用百分比 top -H -p <pid> 某线程100%了 说明 任务处理不过来 会卡 但是永远不可能超100% 系统监视器里面看到的是 所有线程占用的 总和会超100%。 所以最好的情况是&…...

自然学习法和科学学习法

一、自然学习法 自然学习法&#xff1a;什么事自然学习法&#xff0c;特意让kimi来回答了一下。所谓的自然学习法说的俗一点就是野路子学习方法。这种学习方法的特点是“慢”“没有系统性”&#xff0c;学完之后感觉都会了&#xff0c;但是又感觉什么都不会。 二、科学学习法 …...

力扣第二十四题——两两交换链表中的节点

内容介绍 给你一个链表&#xff0c;两两交换其中相邻的节点&#xff0c;并返回交换后链表的头节点。你必须在不修改节点内部的值的情况下完成本题&#xff08;即&#xff0c;只能进行节点交换&#xff09;。 示例 1&#xff1a; 输入&#xff1a;head [1,2,3,4] 输出&#xff…...

C语言柔性数组详解

目录 1.柔性数组 2.柔性数组的特点 3.柔性数组的使用 4.柔性数组的优势 1.柔性数组 C99 中&#xff0c;结构体中的最后一个元素允许是未知大小的数组&#xff0c;这就叫做『柔性数组』成员。 例如&#xff1a; struct S {char c;int n;int arr[];//柔性数组 }; struct …...

自动驾驶---视觉Transformer的应用

1 背景 在过去的几年&#xff0c;随着自动驾驶技术的不断发展&#xff0c;神经网络逐渐进入人们的视野。Transformer的应用也越来越广泛&#xff0c;逐步走向自动驾驶技术的前沿。笔者也在博客《人工智能---什么是Transformer?》中大概介绍了Transformer的一些内容&#xff1a…...

预训练语言模型实践笔记

Roberta output_hidden_statesTrue和last_hidden_states和pooler_output 在使用像BERT或RoBERTa这样的transformer模型时&#xff0c;output_hidden_states和last_hidden_state是两个不同的概念。 output_hidden_states: 这是一个布尔值&#xff0c;决定了模型是否应该返回所…...

Perl 哈希

Perl 哈希 Perl 哈希是一种强大的数据结构&#xff0c;用于存储键值对集合。它是 Perl 语言的核心特性之一&#xff0c;广泛应用于各种编程任务中。本文将详细介绍 Perl 哈希的概念、用法和最佳实践。 什么是 Perl 哈希&#xff1f; Perl 哈希是一种关联数组&#xff0c;其中…...

Linux之Mysql索引和优化

一、MySQL 索引 索引作为一种数据结构,其用途是用于提升数据的检索效率。 1、索引分类 - 普通索引(INDEX):索引列值可重复 - 唯一索引(UNIQUE):索引列值必须唯一,可以为NULL - 主键索引(PRIMARY KEY):索引列值必须唯一,不能为NULL,一个表只能有一个主键索引 - 全…...

springboot业务逻辑写在controller层吗

Spring Boot中的业务逻辑不应该直接写在Controller层。‌ 在Spring Boot项目中&#xff0c;‌通常将业务逻辑分为几个层次&#xff0c;‌包括Controller层、‌Service层、‌Mapper层和Entity层。‌ 1.其中&#xff0c;‌Controller层主要负责处理HTTP请求&#xff0c;‌通过注…...

Ubuntu 24.04 LTS 桌面安装MT4或MT5 (MetaTrader)教程

运行脚本即可在 Ubuntu 24.04 LTS Noble Linux 上轻松安装 MetaTrader 5 或 4 应用程序&#xff0c;使用 WineHQ 进行外汇交易。 MetaTrader 4 (MT4) 或 MetaTrader 5 是用于交易外汇对和商品的流行平台。它支持各种外汇经纪商、内置价格分析工具以及通过专家顾问 (EA) 进行自…...

Go基础编程 - 12 -流程控制

流程控制 1. 条件语句1.1. if...else 语句1.2. switch 语句1.3. select 语句1.3.1. select 语句的通信表达式1.3.2. select 的基特性1.3.3. select 的实现原理1.3.4. 经典用法1.3.4.1 超时控制1.3.4.2 多任务并发控制1.3.4.3 监听多通道消息1.3.4.4 default 实现非堵塞读写 2. …...

汽车信息安全--TLS,OpenSSL

目录 TLS相关知识 加密技术 对称加密 非对称加密 数字签名和CA 信任链 根身份证和自签名 双方TLS认证 加密和解密的性能 TLS相关知识 加密技术 TLS依赖两种加密技术 1. 对称加密&#xff08;symmetric encryption&#xff09; 2. 非对称加密&#xff08;asymmetri…...

深入探索 SQL 中的 LIKE 右模糊匹配(LIKE RIGHT)与左模糊匹配(LIKE LEFT)

引言 在数据库操作中&#xff0c;LIKE 子句是执行模糊搜索的强大工具&#xff0c;用于匹配列中的数据与指定的模式。本文将详细介绍 LIKE 子句中的两种常用模式&#xff1a;右模糊匹配&#xff08;LIKE RIGHT&#xff09;和左模糊匹配&#xff08;LIKE LEFT&#xff09;&#…...

mybatis 多数据源 TDataSource required a single bean, but 2 were found

情况说明&#xff1a; 项目中本来就有一个数据源了&#xff0c;运行的好好的后来又合并了另一个项目&#xff0c;另一个项目也配置了数据源。 于是出现了如下错误&#xff1a; mybatis 多数据源 TDataSource required a single bean, but 2 were found 解决方法&#xff1a…...

Dubbo SPI 之路由器

1. 背景介绍 Dubbo 是一个高性能的 Java RPC 框架&#xff0c;由阿里巴巴开源并广泛应用于分布式系统中。在 Dubbo 的架构中&#xff0c;SPI&#xff08;Service Provider Interface&#xff09;是一个关键组件&#xff0c;允许在运行时动态加载不同的服务实现。SPI 机制提供了…...

Python深度学习环境配置(Pytorch、CUDA、cuDNN),包括Anaconda搭配Pycharm的环境搭建以及基础使用教程(保姆级教程,适合小白、深度学习零基础入门)

全流程导览 一、前言二、基本介绍2.1全过程软件基本介绍2.1.1 Pytorch2.1.2 Anaconda2.1.3 Pycharm2.1.4 显卡GPU及其相关概念2.1.5 CUDA和cuDNN 2.2 各部分相互间的联系和安装逻辑关系 三、Anaconda安装3.1安装Anaconda3.2配置环境变量3.3检验是否安装成功 四、Pycharm安装五、…...

月影护眼大路灯怎么样?书客|月影|霍尼韦尔超硬核实力性能测评pk!

月影护眼大路灯怎么样&#xff1f;选到专业优质的护眼大路灯是真的可以使我们在用眼时减少疲劳感&#xff0c;达到护眼效果&#xff0c;但如果不慎买到劣质的护眼灯产品&#xff0c;不仅达不到健康的环境光&#xff0c;还越用越觉得眼睛疲劳感加重&#xff0c;在水深的护眼灯市…...

邮件安全篇:邮件传输加密(SSL/TLS or STATRTTLS)

1. 前言 使用过邮件客户端的同学一定见过下面这张图。这是客户端账号配置界面&#xff0c;里面有SSL、STARTTLS选项。刚接触邮件客户端的同学肯定会有这些疑问&#xff1a;什么是SSL&#xff1f;什么是STARTTLS&#xff1f;两者有什么区别&#xff1f;具体该如何选择呢&#x…...

【系统架构设计 每日一问】三 Redis支持事务么,Redis的事务如何保证

实际上&#xff0c;关于Redis事务的说法“Redis 的事务只能保证隔离性和一致性&#xff08;I 和 C&#xff09;&#xff0c;无法保证原子性和持久性&#xff08;A 和 D&#xff09;”并不完全准确。下面我将分别解释Redis事务的四个特性&#xff1a;原子性&#xff08;Atomicit…...

【中项】系统集成项目管理工程师-第4章 信息系统架构-4.3应用架构

前言&#xff1a;系统集成项目管理工程师专业&#xff0c;现分享一些教材知识点。觉得文章还不错的喜欢点赞收藏的同时帮忙点点关注。 软考同样是国家人社部和工信部组织的国家级考试&#xff0c;全称为“全国计算机与软件专业技术资格&#xff08;水平&#xff09;考试”&…...

DasViewer打开Revit输出的fbx格式的模型,为啥一团黑?

答:这个应该是没有读取到贴图文件。贴图文件和obj文件需要在同级目录下面。 DasViewer是由大势智慧自主研发的免费的实景三维模型浏览器,采用多细节层次模型逐步自适应加载技术,让用户在极低的电脑配置下,也能流畅的加载较大规模实景三维模型,提供方便快捷的数据浏览操作。 免…...

【05】LLaMA-Factory微调大模型——初尝微调模型

上文【04】LLaMA-Factory微调大模型——数据准备介绍了如何准备指令监督微调数据&#xff0c;为后续的微调模型提供高质量、格式规范的数据支撑。本文将正式进入模型微调阶段&#xff0c;构建法律垂直应用大模型。 一、硬件依赖 LLaMA-Factory框架对硬件和软件的依赖可见以下…...

Training for Stable Diffusion

1.Training for Stable Diffusion 笔记来源&#xff1a; 1.Denoising Diffusion Probabilistic Models 2.最大似然估计(Maximum likelihood estimation) 3.Understanding Maximum Likelihood Estimation 4.How to Solve ‘CUDA out of memory’ in PyTorch 5.pytorch-stable-d…...

初学51单片机之指针基础与串口通信应用

开始之前推荐一个电路学习软件&#xff0c;这个软件笔者也刚接触。名字是Circuit有在线版本和不在线版本&#xff0c;这是笔者在B站看视频翻到的。 Paul Falstadhttps://www.falstad.com/这是地址。 离线版本在网站内点这个进去 根据你的系统下载你需要的版本红线的是windows…...

网站设计计划书模板/班级优化大师网页版

为什么80%的码农都做不了架构师&#xff1f;>>> UIView *cutView self.view.window.rootViewController.view; //截图 //开启上下文 UIGraphicsBeginImageContextWithOptions(cutView.frame.size, YES, 0.0); //把cutView渲染到上下文中 [c…...

好的网站推荐一个/上海关键词排名搜索

关注云报洞察深一度一年一度的“618”大促脚步越来越近了。据说&#xff0c;今年的“618”将是史上折扣力度最大的一次。你是不是已经跃跃欲试了&#xff1f;不过&#xff0c;在享受“剁手”快感的同时&#xff0c;你是否会因那些暗流涌动的欺诈风险而畏首畏尾呢&#xff1f;找…...

做淘客网站用什么服务器好/关键词搜索工具有哪些

什么是HTML?HTML 是用来描述网页的一种语言。HTML 指的是超文本标记语言: HyperText Markup LanguageHTML 不是一种编程语言&#xff0c;而是一种标记语言标记语言是一套标记标签 (markup tag)HTML 使用标记标签来描述网页HTML 文档包含了HTML标签及文本内容HTML文档也叫做web…...

网站制作天津/种子搜索神器

1.相关资料 mybatis开发文档:https://mybatis.org/mybatis-3/zh/getting-started.htmlmybatis源码地址:https://github.com/mybatis/mybatis-3/releases 2.搭建步骤 &#xff08;1&#xff09; 创建 mysql 数据库表 CREATE TABLE t_user (id bigint(20) NOT NULL,user_nam…...

做彩票网站要什么接口/网站优化推广方案

自从1995年Java语言出世以来&#xff0c;从未跌落过神坛&#xff0c;但最近Python语言被抄的太火&#xff0c;看下面这张最新的语言排行榜&#xff0c;独受宠爱的Java似乎要失宠了。从上面这张图我们可以看出&#xff0c;Python一直处于上升趋势&#xff0c;而Java明显有下滑的…...

关于网站建设的意见/中国行业数据分析网

图片尺寸的分布一致&#xff0c; 图片类别的分布一致&#xff0c; 图片分辨率分布一致...