当前位置: 首页 > news >正文

模型优化之剪枝

文章目录

  • 什么是神经网络剪枝
  • 剪枝的好处
  • 不同粒度的剪枝
  • 剪枝的分类
    • 非结构化剪枝
    • 结构化剪枝
  • 哪些层的参数更容易被剪掉
  • 剪枝效果

什么是神经网络剪枝

神经网络剪枝

  • 在训练期间删除连接
  • 密集张量将变得稀疏(用零填充)
  • 可以通过结构化块( n m nm nm)或( 11 11 11)删除连接

在这里插入图片描述

剪枝的好处

  • 减少过拟合
  • 稀疏性优势
  • 文件中有大量的0,如果有适当的稀疏张量表示方法,模型二进制文件尺寸减小。
  • 模型更小,可以减少内存带宽消耗量。
  • 对于特定模式的稀疏模型,可以开发优化算子,实现加速推理。

不同粒度的剪枝

在这里插入图片描述
什么时候做剪枝?

one-shot pruning : 一次性修剪,包括三个步骤训练模型、剪枝、再训练
剪枝:通常根据某种标准(如权重的大小、梯度的大小等)一次性去除大量权重。
再训练:剪枝后,模型通常需要进行一定数量的额外训练(称为fine-tuning或再训练)来恢复剪枝过程中可能损失的性能。

iterative pruning: 迭代式训练,特点如下:
初始训练:首先,对未剪枝的完整模型进行训练,直到达到满意的性能水平。
剪枝:然后,根据某种剪枝策略(例如基于权重的大小或敏感度)剪除模型的部分组件(如权重、神经元或通道)。
再训练:剪枝后,重新训练模型以恢复因剪枝而丢失的性能。
迭代:重复剪枝和再训练的过程,直到达到所需的剪枝率或性能标准。

automated gradual pruning: 自动化渐进剪枝,特点如下:
剪枝策略:采用一种预定义的剪枝策略,例如基于权重阈值、敏感度分析等,该策略在整个剪枝过程中保持一致。
渐进剪枝:在整个训练过程中逐渐增加剪枝率,通常从较低的剪枝率开始,逐步增加到目标剪枝率。
无需再训练:在整个剪枝过程中,模型持续被训练,而不是在剪枝后重新训练。
自动化:整个过程高度自动化,可以减少人为干预的需求

在这里插入图片描述

剪枝的分类

结构化剪枝(Structured Pruning)和非结构化剪枝(Unstructured Pruning)是两种常见的神经网络剪枝方法,它们的主要区别在于剪枝后网络结构的变化以及剪枝操作的粒度。

非结构化剪枝

不改变网络结构或者参数数量,把连接上的参数置0即为剪枝。
基于某种度量(如权重的绝对值大小)对所有权重进行排序,然后根据预先设定的剪枝比例(例如去除50%的最小权重)来决定哪些权重被设置为零。这种剪枝方法不会考虑权重在模型中的位置或结构,只关注权重本身的价值。示例代码:

# 导入剪枝函数
prune_low_magnitude = tfmot.sparsity.keras.prune_low_magnitude# 计算两轮之后完成剪枝时对应的迭代次数end_step
batch_size = 128
epochs = 2
validation_split = 0.1  # 10% of training set will be used for validation set.num_images = train_images.shape[0] * (1 - validation_split)
end_step = np.ceil(num_images / batch_size).astype(np.int32) * epochs# 定义剪枝模型参数,开始模型从50%稀疏度(权重为0的参数数量百分比),到80%稀疏度
pruning_params = {'pruning_schedule': tfmot.sparsity.keras.PolynomialDecay(initial_sparsity=0.50,final_sparsity=0.80,begin_step=0,end_step=end_step)
}model_for_pruning = prune_low_magnitude(model, **pruning_params)# 当使用函数`prune_low_magnitude`包装了一下模型后,需要重新编译一下
model_for_pruning.compile(optimizer='adam',loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),metrics=['accuracy'])model_for_pruning.summary()logdir = "./logs/mnist_pruning"callbacks = [tfmot.sparsity.keras.UpdatePruningStep(),tfmot.sparsity.keras.PruningSummaries(log_dir=logdir),
]model_for_pruning.fit(train_images, train_labels,batch_size=batch_size, epochs=epochs, validation_split=validation_split,callbacks=callbacks)
# --------------------------------------------------
# 评估模型,对比剪枝前后模型的准确率变化
# 经过剪枝,这里有一个小的准确率下降,和没有进行剪枝相比的话
# --------------------------------------------------_, model_for_pruning_accuracy = model_for_pruning.evaluate(test_images, test_labels, verbose=0)print('Baseline test accuracy:', baseline_model_accuracy)
print('Pruned test accuracy:', model_for_pruning_accuracy)

结构化剪枝

结构化剪枝改变了网络结构,即网络层输出元素个数,比如卷积核的减少会影响特征图数量。
在下面的例子中是基于选择的模型层做剪枝,所以需要指出哪些层去做结构化剪枝。比如剪枝第二个卷积层和第一个全连接层,剪枝策略为pruning_params_2_by_4,表示该层剪枝比例为2 / 4,即该层保留一半(2/4)的权重,而将另一半设为零。
注意:第一个卷积层不能被结构化剪枝。要是结构化剪枝的话,应该至少大于一个input channels(本例所用图片为单通道灰度图),所以我们对第一个卷积层使用随机剪枝。

model = keras.Sequential([prune_low_magnitude(keras.layers.Conv2D(32, 5, padding='same', activation='relu',input_shape=(28, 28, 1),name="pruning_sparsity_0_5"),**pruning_params_sparsity_0_5),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),prune_low_magnitude(keras.layers.Conv2D(64, 5, padding='same',name="structural_pruning"),**pruning_params_2_by_4),keras.layers.BatchNormalization(),keras.layers.ReLU(),keras.layers.MaxPooling2D((2, 2), (2, 2), padding='same'),keras.layers.Flatten(),prune_low_magnitude(keras.layers.Dense(1024, activation='relu',name="structural_pruning_dense"),**pruning_params_2_by_4),keras.layers.Dropout(0.4),keras.layers.Dense(10)
])

哪些层的参数更容易被剪掉

因为卷积层(conv)中的参数相比全连接层(fc)来说参数量少,所以卷积层参数的压缩比没有全连接层参数的压缩比大。换句话说,就是卷积层参数更加敏感,剪掉对准确率影响相对更大。越靠后的卷积层或卷积层之后的那些全连接层往往参数越容易被剪掉。

剪枝效果

  • 一般50%-70%左右的稀疏性,准确率降低幅度并不大
  • 剪枝是独立于量化技巧,通常与量化配合效果不错
  • 可以通过微调尝试不同的参数组合

相关文章:

模型优化之剪枝

文章目录 什么是神经网络剪枝剪枝的好处不同粒度的剪枝剪枝的分类非结构化剪枝结构化剪枝 哪些层的参数更容易被剪掉剪枝效果 什么是神经网络剪枝 神经网络剪枝 在训练期间删除连接密集张量将变得稀疏(用零填充)可以通过结构化块( n m nm nm&…...

JVM的组成

JVM 运行在操作系统之上 java二进制字节码文件的运行环境 JVM的组成部分 java代码在编写完成后编译成字节码文件通过类加载器 来到运行数据区,主要作用是加载字节码到内存 包含 方法区/元空间 堆 程序计数器,虚拟机栈,本地方法栈等等 随后来到执行引擎,主要作用是翻译字…...

快速上手 iOS Protocol Buffer

快速上手 iOS Protocol Buffer | 来自缤纷多彩的灰 本文主要介绍在 iOS 开发中如何快速上手使用 Protobuf。更多关于 Protobuf 的介绍和相关的功能 api,读者可自行查阅官网。 Protocol Buffer(简称 Protobuf)是一种由Google开发的语言中立、…...

每天一个数据分析题(四百八十)- 线性回归建模

关于线性回归建模,线性回归模型假设说法不正确的是? A. 因变量和自变量要有因果关系 B. 残差均值为0 C. 残差服从正态分布 D. 自变量不存在共线性 数据分析认证考试介绍:点击进入 题目来源于CDA模拟题库 点击此处获取答案 数据分析专…...

电动汽车和混动汽车DC-DC转换器的创新设计与测试方法

汽车 DC-DC 转换器市场规模将达到187亿美元,年复合增长率为10%。 DC-DC 转换器是汽车的重要组成部分,它可以通过电压转换为各种车载系统供电,例如日益复杂的车载信息娱乐系统、使用驾驶辅助系统(ADAS)实现的增强安全功…...

OriginPro快速上手指南:数据可视化与分析的利器

目录 OriginLab - Origin and OriginPro - Data Analysis and Graphing Softwarehttps://www.originlab.com/​编辑 一、安装与界面概览 安装 界面概览 二、基础操作 数据输入 创建图表 三、高级功能 数据分析 自动化与脚本 Origin 提供了几个小工具 四、技巧与提示…...

缓存学习

缓存基本概念 概念 对于缓存,最普遍的理解是能让打开某些页面速度更快的工具。从技术角度来看,其本质上是因为缓存是基于内存建立的,而内存的读写速度相比之于硬盘快了xx倍,因此用内存来代替硬盘作为读写的介质当然能大大提高访…...

亚世光电:消费电子年度表演

机圈风云再起,消费电子乘风而起? 今天我们来聊——亚世光电 最近,华为mate60突然降价,被大家怀疑是为新品上市做准备,算算时间,下半年的消费电子大战也即将拉开帷幕,而亚世光电所在的光电显示领…...

AI 工程应用 建筑表面检测及修复

文章目录 1 项目概述(必写):2 技术方案与实施步骤2.1 模型选择(必写):2.2 数据的构建:2.3 功能整合(进阶): 3 实施步骤:3.1 环境搭建(…...

Qt-Qt中的小事项(7)

目录 命名风格 快捷键 查询文档 坐标系 代码理解 move 命名风格 这个也是老生常谈的问题了,入乡随俗就好啦 快捷键 这里是一些常用的快捷键,用多了自然就熟悉了 • 注释:ctrl/ • 运行:ctrlR • 编译:ctrlB …...

Android MediaRecorder 视频录制及报错解决

目录 一、start failed: -19 二、使用MediaRecorder录制视频 2.1 申请权限 2.2 布局文件 2.3 MediaRecordActivity 2.4 运行结果 三、拓展 3.1 录制视频模糊(解决) 3.2 阿里云OSS上传文件 3.2.1 权限(刚需) 3.2.2 安装SDK 3.2.3 使用 相关链接 一、start failed…...

HarmonyOS应用程序访问控制探究

关于作者 白晓明 宁夏图尔科技有限公司董事长兼CEO、坚果派联合创始人 华为HDE、润和软件HiHope社区专家、鸿蒙KOL、仓颉KOL 华为开发者学堂/51CTO学堂/CSDN学堂认证讲师 开放原子开源基金会2023开源贡献之星 一、引言 随着信息技术的飞速发展,移动应用程序已经成为…...

董卫民赴考拉悠然等企业调研,强调加快发展人工智能产业

8月14日,按照省政府重点产业链协同推进机制有关工作安排,省委常委、常务副省长董卫民在成都市调研人工智能产业发展情况,并召开座谈会。他强调,要坚决落实党的二十届三中全会精神和省委省政府决策部署,充分把握人工智能…...

MFC将类A中的事件在类B中处理采用回调函数实现

需求: 在类A的界面上有一个tab控件。tab控件上面有那个页面。在MFC编程中一个tab的一个页面就应该是一个新的类。在tab的一个页面上有一个list控件。现在需要将list控件的点击事件,双击事件等在类A里面处理。 解决: 在类B里面给控件list添加…...

公众号 微信登录

export function getWxCode(that, localhostUrl) { // localhostUrl 当前页面的路径 传这个也可以this.$route.fullPath// console.log(that.$store.state.wxSessionData)// console.log(that.$store.state.wxSessionData.openId)//openId为undefine执行获取openid判断是否没有…...

sanic + webSocket:股票实时行情推送服务实现

💝💝💝欢迎莅临我的博客,很高兴能够在这里和您见面!希望您在这里可以感受到一份轻松愉快的氛围,不仅可以获得有趣的内容和知识,也可以畅所欲言、分享您的想法和见解。 推荐:「storm…...

Unity动态给按钮各个状态下的图片赋值

Unity动态给按钮各个状态下的图片赋值 using UnityEngine; using UnityEngine.UI; public class ButtonOnClickTest : MonoBehaviour {public Button btn;public Sprite _highlighterSprite;public Sprite _pressedSprite;public Sprite _selectesdSprite;public Sprite _disa…...

xiaomi pad 6PRO 小米平板6 pro hyperOS降级 澎湃os 降级MIUI 14 教程 免解锁BL 降级,168小时解锁绑定

小米平板 6 Pro 机型代号 :liuqin 降级MIUI 14 小米澎湃 OS 正式版 澎湃OS安卓发布日期卡刷包线刷包OS1.0.7.0.UMYCNXM14.02024-07-13miui_LIUQIN_OS1.0.7.0.UMYCNXM_d618a5c980_14.0.zipliuqin_images_OS1.0.7.0.UMYCNXM_20240705.0000.00_14.0_cn_8cbf5920be.…...

MySQL 备份一个表

语法(创建一个与table1结构相同的新表table2,并且将table1的数据复制到table2): create table table2 as select * from table1 举例(备份tb_log表到tb_log_20240815中去): create table tb_log_20240815 as select * from tb_log...

鸿蒙开发入门day10-组件导航

(创作不易,感谢有你,你的支持,就是我前行的最大动力,如果看完对你有帮助,还请三连支持一波哇ヾ(@^∇^@)ノ) 目录 组件导航 (Navigation) 设置页面显示模式 设置标题栏模式 设置菜…...

铭豹扩展坞 USB转网口 突然无法识别解决方法

当 USB 转网口扩展坞在一台笔记本上无法识别,但在其他电脑上正常工作时,问题通常出在笔记本自身或其与扩展坞的兼容性上。以下是系统化的定位思路和排查步骤,帮助你快速找到故障原因: 背景: 一个M-pard(铭豹)扩展坞的网卡突然无法识别了,扩展出来的三个USB接口正常。…...

C++实现分布式网络通信框架RPC(3)--rpc调用端

目录 一、前言 二、UserServiceRpc_Stub 三、 CallMethod方法的重写 头文件 实现 四、rpc调用端的调用 实现 五、 google::protobuf::RpcController *controller 头文件 实现 六、总结 一、前言 在前边的文章中,我们已经大致实现了rpc服务端的各项功能代…...

golang循环变量捕获问题​​

在 Go 语言中,当在循环中启动协程(goroutine)时,如果在协程闭包中直接引用循环变量,可能会遇到一个常见的陷阱 - ​​循环变量捕获问题​​。让我详细解释一下: 问题背景 看这个代码片段: fo…...

Frozen-Flask :将 Flask 应用“冻结”为静态文件

Frozen-Flask 是一个用于将 Flask 应用“冻结”为静态文件的 Python 扩展。它的核心用途是:将一个 Flask Web 应用生成成纯静态 HTML 文件,从而可以部署到静态网站托管服务上,如 GitHub Pages、Netlify 或任何支持静态文件的网站服务器。 &am…...

SpringBoot+uniapp 的 Champion 俱乐部微信小程序设计与实现,论文初版实现

摘要 本论文旨在设计并实现基于 SpringBoot 和 uniapp 的 Champion 俱乐部微信小程序,以满足俱乐部线上活动推广、会员管理、社交互动等需求。通过 SpringBoot 搭建后端服务,提供稳定高效的数据处理与业务逻辑支持;利用 uniapp 实现跨平台前…...

【python异步多线程】异步多线程爬虫代码示例

claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战

在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...

【HarmonyOS 5 开发速记】如何获取用户信息(头像/昵称/手机号)

1.获取 authorizationCode: 2.利用 authorizationCode 获取 accessToken:文档中心 3.获取手机:文档中心 4.获取昵称头像:文档中心 首先创建 request 若要获取手机号,scope必填 phone,permissions 必填 …...

Web 架构之 CDN 加速原理与落地实践

文章目录 一、思维导图二、正文内容(一)CDN 基础概念1. 定义2. 组成部分 (二)CDN 加速原理1. 请求路由2. 内容缓存3. 内容更新 (三)CDN 落地实践1. 选择 CDN 服务商2. 配置 CDN3. 集成到 Web 架构 &#xf…...

短视频矩阵系统文案创作功能开发实践,定制化开发

在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...