yolov8蒸馏(附代码-免费)
首先蒸馏是什么?
模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。
为什么需要蒸馏?
- 在不增加模型计算量和参数量的情况下提升精度,也即是可以无损提高精度。
- 配合剪枝一起使用,可以尽量达到无损降低模型参数量、计算量,提高FPS的情况下,还能保持模型精度没有下降甚至上升,这是改进网络结构无法达到的高度。
- 论文中的保底手段,因为剪枝和蒸馏的特殊性,其都不会增加参数量和计算量,可以在最后一个点上大幅度增加实验和工作量,因为本身蒸馏也需要做大量实验。
目录
一.代码前提
(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s
(2)本文使用的蒸馏方法包括mgd,cwd
(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。
二.蒸馏步骤
(1) 训练教师模型
(2) 训练学生模型
(3) 蒸馏训练
三.模型剪枝+蒸馏
(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝
(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。
(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:
一.代码前提
(1)本文选取的老师模型为yolov8s,学生为剪枝完的yolov8s
(2)本文使用的蒸馏方法包括mgd,cwd
(3)使用前下载必须的包,并且把数据集放在datasets文件夹中,最后替换data.yaml中分类。
本文代码已经上传到GitHub,链接:yolov8_蒸馏
使用不妨加个关注,后续还会加入Vit(vision transformer),替换loss等提升精度的方法。
二.蒸馏步骤
(1) 训练教师模型
打开文件中train.py,替换模型文件位置。开始训练,达到理想目标就停止。
import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model = YOLO("yolov8s.pt")model.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()
(2) 训练学生模型
打开文件中train.py,替换模型文件位置。我这边使用的是剪枝后的yolov8s模型,具体轻量化剪枝步骤可见本文最后。
import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()
(3) 蒸馏训练
打开文件中train_distillation.py,替换老师与学生模型文件位置。两种蒸馏方法可以选择:cwd和mgd。
import os
from ultralytics import YOLO
import torchos.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():model_t = YOLO('runs/detect/yolov8s/weights/best.pt') # the teacher modelmodel_s = YOLO('runs/detect/prune/weights/best.pt') # the student model"""Attributes:Distillation: the distillation modelloss_type: mgd, cwdamp: Automatic Mixed Precision"""model_s.train(data="data.yaml", Distillation=model_t.model, loss_type='mgd', amp=False, imgsz=640, epochs=100,batch=20, device=0, workers=0, lr0=0.001)if __name__ == '__main__':main()
现在先不进行训练,打开文件夹yolo_project_distillation\ultralytics\engine\trainer.py
在类FeatureLoss中,函数forward大概162行处打一个断点,进行调试。代码位置:
def forward(self, y_s, y_t):assert len(y_s) == len(y_t)tea_feats = []stu_feats = []for idx, (s, t) in enumerate(zip(y_s, y_t)):# change ---if self.distiller == 'cwd':s = self.align_module[idx](s)s = self.norm[idx](s)else:s = self.norm1[idx](s)t = self.norm[idx](t)tea_feats.append(t)stu_feats.append(s)loss = self.feature_loss(stu_feats, tea_feats)return self.loss_weight * loss
调试运行,查看变量中学生模型y_s和老师模型y_t的张量大小。把通道数记下来,写在类Distillation_loss的
channels_s = [256, 480, 256, 64, 143, 229][-le:]channels_t = [256, 512, 256, 128, 256, 512][-le:]
这边总共有六个,刚好对应模型的六个层的通道数。
替换完成后,应该就可以进行训练了。训练不好的话,再来评论区找我吧。
三.模型剪枝+蒸馏
(1)约束训练在我上一篇文章中提到,链接:yolov8剪枝
(2)约束训练后,先进行剪枝,使用prune.py。替换模型位置,直接运行。
from ultralytics import YOLO
import torch
from ultralytics.nn.modules import Bottleneck, Conv, C2f, SPPF, Detect
from copy import deepcopy# Load a model
yolo = YOLO("./runs/detect/yolov8s/weights/last.pt")
# Save model address
res_dir = "./runs/detect/prune/weights/prune.pt"
# Pruning rate
factor = 0.75yolo.info()
model = yolo.model
ws = []
bs = []for name, m in model.named_modules():if isinstance(m, torch.nn.BatchNorm2d):w = m.weight.abs().detach()b = m.bias.abs().detach()ws.append(w)bs.append(b)# print(name, w.max().item(), w.min().item(), b.max().item(), b.min().item())# keepws = torch.cat(ws)
threshold = torch.sort(ws, descending=True)[0][int(len(ws) * factor)]
print(threshold)def prune_conv(conv1: Conv, conv2: Conv):gamma = conv1.bn.weight.data.detach()beta = conv1.bn.bias.data.detach()keep_idxs = []local_threshold = thresholdwhile len(keep_idxs) < 8:keep_idxs = torch.where(gamma.abs() >= local_threshold)[0]local_threshold = local_threshold * 0.5n = len(keep_idxs)# n = max(int(len(idxs) * 0.8), p)# print(n / len(gamma) * 100)# scale = len(idxs) / nconv1.bn.weight.data = gamma[keep_idxs]conv1.bn.bias.data = beta[keep_idxs]conv1.bn.running_var.data = conv1.bn.running_var.data[keep_idxs]conv1.bn.running_mean.data = conv1.bn.running_mean.data[keep_idxs]conv1.bn.num_features = nconv1.conv.weight.data = conv1.conv.weight.data[keep_idxs]conv1.conv.out_channels = nif conv1.conv.bias is not None:conv1.conv.bias.data = conv1.conv.bias.data[keep_idxs]if not isinstance(conv2, list):conv2 = [conv2]for item in conv2:if item is not None:if isinstance(item, Conv):conv = item.convelse:conv = itemconv.in_channels = nconv.weight.data = conv.weight.data[:, keep_idxs]def prune(m1, m2):if isinstance(m1, C2f): # C2f as a top convm1 = m1.cv2if not isinstance(m2, list): # m2 is just one modulem2 = [m2]for i, item in enumerate(m2):if isinstance(item, C2f) or isinstance(item, SPPF):m2[i] = item.cv1prune_conv(m1, m2)for name, m in model.named_modules():if isinstance(m, Bottleneck):prune_conv(m.cv1, m.cv2)seq = model.model
for i in range(3, 9):if i in [6, 4, 9]: continueprune(seq[i], seq[i + 1])detect: Detect = seq[-1]
last_inputs = [seq[15], seq[18], seq[21]]
colasts = [seq[16], seq[19], None]
for last_input, colast, cv2, cv3 in zip(last_inputs, colasts, detect.cv2, detect.cv3):prune(last_input, [colast, cv2[0], cv3[0]])prune(cv2[0], cv2[1])prune(cv2[1], cv2[2])prune(cv3[0], cv3[1])prune(cv3[1], cv3[2])for name, p in yolo.model.named_parameters():p.requires_grad = True#yolo.val(workers=0) # 剪枝模型进行验证 yolo.val(workers=0)
yolo.info()
# yolo.export(format="onnx") # 导出为onnx文件
# yolo.train(data="./data/data_nc5/data_nc5.yaml", epochs=100) # 剪枝后直接训练微调
ckpt = {'epoch': -1,'best_fitness': None,'model': yolo.ckpt['ema'],'ema': None,'updates': None,'optimizer': None,'train_args': yolo.ckpt["train_args"], # save as dict'date': None,'version': '8.0.142'}torch.save(yolo.ckpt, res_dir)
(3)剪完枝后,效果不一定好,所以使用剪枝完后的模型,继续训练:
import os
from ultralytics import YOLO
import torch
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'def main():# model = YOLO(r'ultralytics/cfg/models/v8/yolov8s.yaml').load('runs/detect/yolov8s/weights/best.pt')model_s = YOLO("./runs/detect/prune/weights/prune.pt")model_s.train(data="data.yaml", Distillation = None, loss_type='None', amp=False, imgsz=640, epochs=50, batch=20, device=0, workers=0)if __name__ == '__main__':main()
------------------------------------------over!!!!!!!!!!!!!!!!!------------------------------
相关文章:
yolov8蒸馏(附代码-免费)
首先蒸馏是什么? 模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。 为什么需要蒸馏? 在不增加模型计算…...
Flink-StarRocks详解:第五部分查询数据湖(第55天)
系列文章目录 4.查询数据湖 4.1 Catalog 4.1.1 概述 4.1.1.1 基本概念 4.1.1.2 Catalog 4.1.1.3 访问Catalog 4.1.2 Default catalog 4.1.3 External Catalog 4.2 文件外部表 4.2.1 使用限制 4.2.2 开源版本语法 4.2.3 阿里云版本 5. 查询及优化 文章目录 系列文章目录前言4.查…...
【MySQL】常用数据类型
目录 数据类型 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 float decimal 字符串类型 char varchar 日期和时间类型 enum和set 数据类型 数据类型分类 数值类型 tinyint类型 tinyint类型只占用一个字节类似于编程语言中的字符char。有带符号和无符号两…...
创建第一个rust tauri项目
安装nodejs curl -sL https://deb.nodesource.com/setup_20.x | sudo bash node -vproxychains4 npm create tauri-applatest✔ Project name tauri-app ✔ Choose which language to use for your frontend TypeScript / JavaScript - (pnpm, yarn, npm, bun) ✔ Choose yo…...
【课程总结】day19(中):Transformer架构及注意力机制了解
前言 本章内容,我们将从注意力的基础概念入手,结合Transformer架构,由宏观理解其运行流程,然后逐步深入了解多头注意力、多头掩码注意力、融合注意力等概念及作用。 注意力机制(Attension) 背景 深度学…...
4.4 标准正交基和格拉姆-施密特正交化
本节的两个目标就是为什么和怎么做(why and how)。首先是知道为什么正交性很好:因为它们的点积为零; A T A A^TA ATA 是对角矩阵;在求 x ^ \boldsymbol{\hat x} x^ 和 p A x ^ \boldsymbol pA\boldsymbol{\hat x} pAx^ 时也会很简单。第二…...
spring事务的8种失效的场景,7种传播行为
Spring事务大部分都是通过AOP实现的,所以事务失效的场景大部分都是因为AOP失效,AOP基于动态代理实现的 1.方法没有被public修饰 原因:Spring会为方法创建代理、AOP添加事务通知前提条件是该方法时public的。 2.类没有被Spring容器所托管 …...
进程的虚拟内存地址(C++程序的内存分区)
严谨的说法: 一个C、C程序实际就是一个进程,那么C的内存分区,实际上就是一个进程的内存分区,这样的话就可以分为两个大模块,从上往下,也就是0地址一直往下,假如是x86的32位Linux系统,…...
英特尔移除超线程与AMD多线程性能对比
#### 英特尔Lunar Lake架构取消超线程 在英特尔宣布Lunar Lake架构时,一个令人惊讶的消息是下一代轻薄优化架构将移除Hyper-Threading(超线程,简称SMT)。而AMD最新的Zen 5/Zen5C多线程基准测试结果显示,该特性依然为A…...
定期自动巡检,及时发现机房运维管理中的潜在问题
随着信息化技术的迅猛发展,机房作为企业数据处理与存储的核心场所,其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性,运维团队必须定期进行全面的巡检。然而,传统的手工巡检方式不仅效率低下&#…...
八股文(一)
1. 为什么不使用本地缓存,而使用Redis? Redis相比于本地缓存(如JVM中的缓存)有以下几个显著优势: 高性能与低延迟:Redis是一个基于内存的数据库,其读写性能非常高,通常可以达到几万…...
灵茶八题 - 子数组 ^w^
灵茶八题 - 子数组 w 题目描述 给你一个长为 n n n 的数组 a a a,输出它的所有连续子数组的异或和的异或和。 例如 a [ 1 , 3 ] a[1,3] a[1,3] 有三个连续子数组 [ 1 ] , [ 3 ] , [ 1 , 3 ] [1],[3],[1,3] [1],[3],[1,3],异或和分别为 1 , 3 , …...
git clone private repo
Create personal access token Clone repo $ git clone https://<user_name>:<personal_access_tokens>github.com/<user_name>/<repo_name>.git...
vue3+ts+pinia+vant-项目搭建
1.pnpm介绍 npm和pnpm都是JavaScript的包管理工具,用于自动化安装、配置、更新和卸载npm包依赖。 pnpm节省了大量的磁盘空间并提高了安装速度:使用一个内容寻址的文件存储方式,如果多个项目使用相同的包版本,pnpm会存储单个副本…...
自动化测试概念篇
目录 一、自动化 1.1 自动化概念 1.2 自动化分类 1.3 自动化测试金字塔 二、web自动化测试 2.1 驱动 2.2 安装驱动管理 三、selenium 3.1 ⼀个简单的web自动化示例 3.2 selenium驱动浏览器的工作原理 一、自动化 1.1 自动化概念 在生活中: 自动洒水机&am…...
Mojo值的生命周期(Life of a value)详解
到目前为止,我们已经解释了 Mojo 如何允许您使用 Mojo 的所有权模型构建内存安全的高性能代码而无需手动管理内存。但是,Mojo 是为 系统编程而设计的,这通常需要对自定义数据类型进行手动内存管理。因此,Mojo 允许您根据需要执行此操作。需要明确的是,Mojo 没有引用计数器…...
java对接kimi详细说明,附完整项目
需求: 使用java封装kimi接口为http接口,并把调用kimi时的传参和返回数据,保存到mysql数据库中 自己记录一下,以做备忘。 具体步骤如下: 1.申请apiKey 访问:Moonshot AI - 开放平台使用手机号手机号验证…...
鸿蒙媒体开发【基于AVCodec能力的视频编解码】音频和视频
基于AVCodec能力的视频编解码 介绍 本实例基于AVCodec能力,提供基于视频编解码的视频播放和录制的功能。 视频播放的主要流程是将视频文件通过解封装->解码->送显/播放。视频录制的主要流程是相机采集->编码->封装成mp4文件。 播放支持的原子能力规…...
django集成pytest进行自动化单元测试实战
文章目录 一、引入pytest相关的包二、配置pytest1、将django的配置区分测试环境、开发环境和生产环境2、配置pytest 三、编写测试用例1、业务测试2、接口测试 四、进行测试 在Django项目中集成Pytest进行单元测试可以提高测试的灵活性和效率,相比于Django自带的测试…...
48天笔试训练错题——day40
目录 选择题 1. 2. 3. 4. 5. 6. 7. 8. 9. 10. 编程题 1. 发邮件 2. 最长上升子序列 选择题 1. DNS 劫持又称域名劫持,是指在劫持的网络范围内拦截域名解析的请求,分析请求的域名,把审查范围以外的请求放行,否则返回…...
LabVIEW在DCS中的优势
DCS(Distributed Control System,分布式控制系统)是一种用于工业过程控制的自动化系统。它将控制任务分散到多个控制单元中,通过网络连接和协调这些单元来实现对整个过程的监控和控制。DCS通常用于大型工业设施,如化工…...
英特尔:从硅谷创业到全球科技巨头
在科技行业,英特尔不仅是一个品牌,更是一种精神的象征。自1968年成立以来,英特尔经历了从初创企业到全球半导体产业领导者的华丽转变,其发展历程是科技创新与市场战略完美结合的典范。本文将深入探讨英特尔的发展历程,…...
生物计算与纳米技术:交汇前沿的科学领域
在当今科技迅猛发展的时代,生物计算和纳米技术作为前沿科技领域的两个重要方向,正在逐渐融合并带来深远的影响。生物计算涉及使用生物系统进行计算和数据存储,而纳米技术则关注制造极小尺度的电子器件和材料科学。本文将深入探讨这两个领域的…...
C#中栈和队列
在C#中,Stack和Queue是两种不同的集合类型,它们用于实现后进先出(LIFO)和先进先出(FIFO)的数据结构。 Stack(堆栈) Stack是一个后进先出的集合,这意味着最后一个添加到堆…...
技战法丨攻防演练防御——纵深、联动、诱捕(可搬运、可cv)
演习活动经过近几年的发展,攻击方的专业水平已大幅提高,逐渐呈现出隐秘化、APT化的趋势。其利用渗透技术对目标系统做深入探测,不断挖掘防守方网络系统的薄弱环节,这就要求防守方构建立体式纵深防护体系来抵御入侵。同时ÿ…...
1、 window平台opencv下载编译, 基于cmake和QT工具链
1. 环境准备,源码下载 1.1 前置环境 qt 下载安装cmake 安装,可参考: https://blog.csdn.net/qq_51355375/article/details/139186681 1.2 opencv 源码下载 官网地址: https://opencv.org/releases/ 下载源码: 2 …...
C++20三向比较运算符详解
三向比较运算符可以用于确定两个值的大小顺序,也被称为太空飞船操作符。使用单个表达式,它可以告诉一个值是否等于,小于或大于另一个值。 它返回的是类枚举(enumeration-like)类型,定义在 <compare> …...
监听机制与耗电量
一、监听机制与耗电量的关系 监听机制通常涉及对特定事件、状态或数据的持续监测。在移动设备和嵌入式系统中,这种监听可能由多种组件和传感器实现,如GPS、传感器(如加速度计、陀螺仪)、网络连接等。监听的频率越高,意…...
C++ //练习 16.29 修改你的Blob类,用你自己的shared_ptr代替标准库中的版本。
C Primer(第5版) 练习 16.29 练习 16.29 修改你的Blob类,用你自己的shared_ptr代替标准库中的版本。 环境:Linux Ubuntu(云服务器) 工具:vim 代码块 template <typename> class BlobP…...
【Mode Management】CanNm处于PBS状态下接收到一帧诊断报文DCM会响应吗
目录 前言 正文 1.CanNm从RSS状态切换到PBS状态行为分析 1.1.CanNm动作 1.2.ComM动作 1.3.DCM动作 1.4 小结 2.CanNM在PBS状态下收到一帧诊断报文行为分析 2.1.DCM动作1 2.2. ComM动作 2.3. DCM动作2 2.3. CanNm动作 2.4 问题 2.5 分析 3.总结 前言 我们知道EC…...
shopify可以做企业网站嘛/trinseo公司
一、下载 移步到阿里的github仓库地址:https://github.com/alibaba/nacos/releases 目前最新版是1.3.2 在release log下方有下载链接 选择对应的版本进行下载。我机器的操作系统是Windows10,所以下载的是zip格式,如图: 二、解…...
小程序开发公司网站源码下载/百度平台app
在使用U盘的过程中,不知道大家有没有碰到过这样的情况,将文件拷贝到U盘中后,在另一台电脑上打开U盘却找不到,设置新复制进去的也看不到,U盘空空如也。这些看不见的文件能恢复吗?答案是肯定的。1、先将杀毒软件或360之…...
微信小程序开发技术介绍/seo优化方法
1. 命令功能 cp --copy files and directories。复制文件或目录。 2. 语法格式 cp [option] source des cp [option] source directory cp [option] -t directory source 参数 参数说明 -a --archive 相当于drp结合使用 -d 如果文件为链接文件,复制链…...
长沙网站制作哪家好/厦门seo代理商
在通常工程项目中,基线库的创建通常包含以下一些目录: 1 创建一个工程目录 2 再该工程目录中创建Document, Source, test&Demo, tools, Resource目录 3 Document目录用于存放项目各个阶段所产生的文档,如:规格阶段,…...
海外网站建设推广/泸州网站优化推广
背景介绍: 最近公司启动了一个新的版本,我负责的一个的模块中有一个很复杂的新建的页面,表格里嵌套表格,三层数据,数据级联,组件较多.交互复杂, 下面是我做的一个简略图,为了保密我已将需求细节隐藏.(PS:没有table组件的墨刀,掩面苦笑,真难用) 从整个页面上分析,整体分为二部…...
南京网站建设设计/企业营销策划合同
GTP-U 5GS用户面GTP协议解析 GTP-U消息:GTP-U(用户平面)消息是用户平面消息或信令消息。用户平面消息用于在GTP-U实体之间承载用户数据分组。在网络节点之间发送的信令消息用于用于路径管理和隧道管理。 GTP-U peer:实现任何基于…...