一文详解大模型蒸馏工具TextBrewer
原文:https://zhuanlan.zhihu.com/p/648674584
本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。
TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,
融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。
1.简介
TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。
主要特点:
- 模型无关:适用于多种模型结构(主要面向Transfomer结构)
- 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
- 非侵入式:无需对教师与学生模型本身结构进行修改
- 支持典型的NLP任务:文本分类、阅读理解、序列标注等
TextBrewer目前支持的知识蒸馏技术有:
- 软标签与硬标签混合训练
- 动态损失权重调整与蒸馏温度调整
- 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, …
- 任意构建中间层特征匹配方案
- 多教师知识蒸馏
- …
TextBrewer的主要功能与模块分为3块:
- Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
- Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
- Utilities:模型参数分析显示等辅助工具
用户需要准备:
- 已训练好的教师模型, 待蒸馏的学生模型
- 训练数据与必要的实验配置, 即可开始蒸馏
在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。
2.TextBrewer结构
2.1 安装要求
- Python >= 3.6
- PyTorch >= 1.1.0
- TensorboardX or Tensorboard
- NumPy
- tqdm
- Transformers >= 2.0 (可选, Transformer相关示例需要用到)
- Apex == 0.1.0 (可选,用于混合精度训练)
- 从PyPI自动下载安装包安装:
pip install textbrewer
- 从源码文件夹安装:
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer
2.2工作流程
- Stage 1 : 蒸馏之前的准备工作:
- 训练教师模型
- 定义与初始化学生模型(随机初始化,或载入预训练权重)
- 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
- Stage 2 : 使用TextBrewer蒸馏:
- 构造训练配置(
TrainingConfig
)和蒸馏配置(DistillationConfig
),初始化distiller - 定义adaptor 和 callback ,分别用于适配模型输入输出和训练过程中的回调
- 调用distiller的train方法开始蒸馏
2.3 以蒸馏BERT-base到3层BERT为例展示TextBrewer用法
在开始蒸馏之前准备:
- 训练好的教师模型
teacher_model
(BERT-base),待训练学生模型student_model
(3-layer BERT) - 数据集
dataloader
,优化器optimizer
,学习率调节器类或者构造函数scheduler_class
和构造用的参数字典scheduler_args
使用TextBrewer蒸馏:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig#展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)#定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):# model输出的第二、三个元素分别是logits和hidden statesreturn {'logits': model_outputs[1], 'hidden': model_outputs[2]}#蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(intermediate_matches=[ {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()#初始化distiller
distiller = GeneralDistiller(train_config=train_config, distill_config = distill_config,model_T = teacher_model, model_S = student_model, adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)#开始蒸馏
with distiller:distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)
2.4蒸馏任务示例
- Transformers 4示例
- examples/notebook_examples/sst2.ipynb (英文): SST-2文本分类任务上的BERT模型训练与蒸馏。
- examples/notebook_examples/msra_ner.ipynb (中文): MSRA NER中文命名实体识别任务上的BERT模型训练与蒸馏。
- examples/notebook_examples/sqaudv1.1.ipynb (英文): SQuAD 1.1英文阅读理解任务上的BERT模型训练与蒸馏。
- examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。
- examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。
- examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。
- examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。
- examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。
2.4.1蒸馏效果
我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。
- 模型
- 对于英文任务,教师模型为BERT-base-cased
- 对于中文任务,教师模型为HFL发布的RoBERTa-wwm-ext 与 Electra-base
我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。
- 英文模型
- 中文模型
- T6的结构与DistilBERT[1], BERT6-PKD[2], BERT-of-Theseus[3] 相同。
- T4-tiny的结构与 TinyBERT[4] 相同。
- T3的结构与BERT3-PKD[2] 相同。
2.4.2 蒸馏配置
distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches) #其他参数为默认值
不同的模型用的matches
我们采用了以下配置:
各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。
2.4.3训练配置
蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。
2.4.4英文实验结果
在英文实验中,我们使用了如下三个典型数据集。
我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。
Public results:
Our results:
说明:
- 公开模型的名称后括号内是其等价的模型结构
- 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
- 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据
2.4.5中文实验结果
在中文实验中,我们使用了如下典型数据集。
实验结果如下表所示。
说明:
- 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
- CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
- Electra-base的教师模型训练设置参考自Chinese-ELECTRA
- Electra-small学生模型采用预训练权重初始化
3.核心概念
3.1Configurations
TrainingConfig
和DistillationConfig
:训练和蒸馏相关的配置。
3.2Distillers
Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:
BasicDistiller
: 提供单模型单任务蒸馏方式。可用作测试或简单实验。GeneralDistiller
(常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用。MultiTeacherDistiller
: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配。MultiTaskDistiller
:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。BasicTrainer
:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型。
3.3用户定义函数
蒸馏实验中,有两个组件需要由用户提供,分别是callback 和 adaptor :
3.3.1Callback
回调函数。在每个checkpoint,保存模型后会被distiller
调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。
3.3.2Adaptor
将模型的输入和输出转换为指定的格式,向distiller
解释模型的输入和输出,以便distiller
根据不同的策略进行不同的计算。在每个训练步,batch
和模型的输出model_outputs
会作为参数传递给adaptor
,adaptor
负责重新组织这些数据,返回一个字典。
更多细节可参见完整文档中的说明。
4.FAQ
Q: 学生模型该如何初始化?
A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。
Q: 如何设置蒸馏的训练参数以达到一个较好的效果?
A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考。
Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?
A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?
A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
相关文章:

一文详解大模型蒸馏工具TextBrewer
原文:https://zhuanlan.zhihu.com/p/648674584 本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。 TextBre…...

Go语言加Vue3零基础入门全栈班10 Go语言+gRPC用户微服务项目实战 2024年07月31日 课程笔记
概述 如果您没有Golang的基础,应该学习如下前置课程。 Golang零基础入门Golang面向对象编程Go Web 基础Go语言开发REST API接口_20240728Go语言操作MySQL开发用户管理系统API教程_20240729Redis零基础快速入门_20231227GoRedis开发用户管理系统API实战_20240730Mo…...

ChatGPT能代替网络作家吗?
最强AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量https://aitools.jurilu.com/ 当然可以!只要你玩写作AI玩得6,甚至可以达到某些大神的水平! 看看大神、小白、AI输出内容的区…...

Http自定义Header导致的跨域问题
最近写一个小项目,前后端分离,在调试过程中访问远程接口,出现了CORS问题,接口使用的laravel框架,于是添加了解决跨域的中间件,但是前端显示仍存在跨域问题,以为自己写的有问题,检查了…...
python 中 file.read(), file.readline()和file.readlines()区别和用法
python 中 file.read(), file.readline()和file.readlines()区别和用法 文章目录 python 中 file.read(), file.readline()和file.readlines()区别和用法1. file.read()2. file.readline()3. file.readlines()4. 总结5. 注意事项 file.read(), file.readline(), 和 file.readli…...
python 学习: np.pad
在NumPy中,np.pad函数用于对数组进行填充(padding),即在数组的边界处添加额外的值。这在图像处理、信号处理或任何需要扩展数据边界的场景中非常有用。 以下是np.pad函数的一些关键参数和使用示例: array:…...

等保2.0 | 人大金仓数据库测评
人大金仓数据库,全称为金仓数据库管理系统KingbaseES(简称:金仓数据库或KingbaseES),是北京人大金仓信息技术股份有限公司自主研制开发的具有自主知识产权的通用关系型数据库管理系统。以下是关于人大金仓数据库的详细…...

AIGC赋能智慧农业:用AI技术绘就作物生长新蓝图
( 于景鑫 国家农业信息化工程技术研究中心)随着人工智能技术的日新月异,AIGC(AI-Generated Content,AI生成内容)正在各行各业掀起一场革命性的浪潮。而在智慧农业领域,AIGC技术的应用也正迸发出耀眼的火花。特别是在作物生长管理方面,AIGC有望彻底改变传…...
yolov8蒸馏(附代码-免费)
首先蒸馏是什么? 模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。 为什么需要蒸馏? 在不增加模型计算…...

Flink-StarRocks详解:第五部分查询数据湖(第55天)
系列文章目录 4.查询数据湖 4.1 Catalog 4.1.1 概述 4.1.1.1 基本概念 4.1.1.2 Catalog 4.1.1.3 访问Catalog 4.1.2 Default catalog 4.1.3 External Catalog 4.2 文件外部表 4.2.1 使用限制 4.2.2 开源版本语法 4.2.3 阿里云版本 5. 查询及优化 文章目录 系列文章目录前言4.查…...

【MySQL】常用数据类型
目录 数据类型 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 float decimal 字符串类型 char varchar 日期和时间类型 enum和set 数据类型 数据类型分类 数值类型 tinyint类型 tinyint类型只占用一个字节类似于编程语言中的字符char。有带符号和无符号两…...
创建第一个rust tauri项目
安装nodejs curl -sL https://deb.nodesource.com/setup_20.x | sudo bash node -vproxychains4 npm create tauri-applatest✔ Project name tauri-app ✔ Choose which language to use for your frontend TypeScript / JavaScript - (pnpm, yarn, npm, bun) ✔ Choose yo…...

【课程总结】day19(中):Transformer架构及注意力机制了解
前言 本章内容,我们将从注意力的基础概念入手,结合Transformer架构,由宏观理解其运行流程,然后逐步深入了解多头注意力、多头掩码注意力、融合注意力等概念及作用。 注意力机制(Attension) 背景 深度学…...

4.4 标准正交基和格拉姆-施密特正交化
本节的两个目标就是为什么和怎么做(why and how)。首先是知道为什么正交性很好:因为它们的点积为零; A T A A^TA ATA 是对角矩阵;在求 x ^ \boldsymbol{\hat x} x^ 和 p A x ^ \boldsymbol pA\boldsymbol{\hat x} pAx^ 时也会很简单。第二…...
spring事务的8种失效的场景,7种传播行为
Spring事务大部分都是通过AOP实现的,所以事务失效的场景大部分都是因为AOP失效,AOP基于动态代理实现的 1.方法没有被public修饰 原因:Spring会为方法创建代理、AOP添加事务通知前提条件是该方法时public的。 2.类没有被Spring容器所托管 …...

进程的虚拟内存地址(C++程序的内存分区)
严谨的说法: 一个C、C程序实际就是一个进程,那么C的内存分区,实际上就是一个进程的内存分区,这样的话就可以分为两个大模块,从上往下,也就是0地址一直往下,假如是x86的32位Linux系统,…...

英特尔移除超线程与AMD多线程性能对比
#### 英特尔Lunar Lake架构取消超线程 在英特尔宣布Lunar Lake架构时,一个令人惊讶的消息是下一代轻薄优化架构将移除Hyper-Threading(超线程,简称SMT)。而AMD最新的Zen 5/Zen5C多线程基准测试结果显示,该特性依然为A…...

定期自动巡检,及时发现机房运维管理中的潜在问题
随着信息化技术的迅猛发展,机房作为企业数据处理与存储的核心场所,其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性,运维团队必须定期进行全面的巡检。然而,传统的手工巡检方式不仅效率低下&#…...
八股文(一)
1. 为什么不使用本地缓存,而使用Redis? Redis相比于本地缓存(如JVM中的缓存)有以下几个显著优势: 高性能与低延迟:Redis是一个基于内存的数据库,其读写性能非常高,通常可以达到几万…...
灵茶八题 - 子数组 ^w^
灵茶八题 - 子数组 w 题目描述 给你一个长为 n n n 的数组 a a a,输出它的所有连续子数组的异或和的异或和。 例如 a [ 1 , 3 ] a[1,3] a[1,3] 有三个连续子数组 [ 1 ] , [ 3 ] , [ 1 , 3 ] [1],[3],[1,3] [1],[3],[1,3],异或和分别为 1 , 3 , …...

测试微信模版消息推送
进入“开发接口管理”--“公众平台测试账号”,无需申请公众账号、可在测试账号中体验并测试微信公众平台所有高级接口。 获取access_token: 自定义模版消息: 关注测试号:扫二维码关注测试号。 发送模版消息: import requests da…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面
代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口(适配服务端返回 Token) export const login async (code, avatar) > {const res await http…...
【AI学习】三、AI算法中的向量
在人工智能(AI)算法中,向量(Vector)是一种将现实世界中的数据(如图像、文本、音频等)转化为计算机可处理的数值型特征表示的工具。它是连接人类认知(如语义、视觉特征)与…...

【OSG学习笔记】Day 16: 骨骼动画与蒙皮(osgAnimation)
骨骼动画基础 骨骼动画是 3D 计算机图形中常用的技术,它通过以下两个主要组件实现角色动画。 骨骼系统 (Skeleton):由层级结构的骨头组成,类似于人体骨骼蒙皮 (Mesh Skinning):将模型网格顶点绑定到骨骼上,使骨骼移动…...

C++使用 new 来创建动态数组
问题: 不能使用变量定义数组大小 原因: 这是因为数组在内存中是连续存储的,编译器需要在编译阶段就确定数组的大小,以便正确地分配内存空间。如果允许使用变量来定义数组的大小,那么编译器就无法在编译时确定数组的大…...

短视频矩阵系统文案创作功能开发实践,定制化开发
在短视频行业迅猛发展的当下,企业和个人创作者为了扩大影响力、提升传播效果,纷纷采用短视频矩阵运营策略,同时管理多个平台、多个账号的内容发布。然而,频繁的文案创作需求让运营者疲于应对,如何高效产出高质量文案成…...

云原生安全实战:API网关Kong的鉴权与限流详解
🔥「炎码工坊」技术弹药已装填! 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、基础概念 1. API网关(API Gateway) API网关是微服务架构中的核心组件,负责统一管理所有API的流量入口。它像一座…...

什么是VR全景技术
VR全景技术,全称为虚拟现实全景技术,是通过计算机图像模拟生成三维空间中的虚拟世界,使用户能够在该虚拟世界中进行全方位、无死角的观察和交互的技术。VR全景技术模拟人在真实空间中的视觉体验,结合图文、3D、音视频等多媒体元素…...

Kubernetes 节点自动伸缩(Cluster Autoscaler)原理与实践
在 Kubernetes 集群中,如何在保障应用高可用的同时有效地管理资源,一直是运维人员和开发者关注的重点。随着微服务架构的普及,集群内各个服务的负载波动日趋明显,传统的手动扩缩容方式已无法满足实时性和弹性需求。 Cluster Auto…...
机器学习的数学基础:线性模型
线性模型 线性模型的基本形式为: f ( x ) ω T x b f\left(\boldsymbol{x}\right)\boldsymbol{\omega}^\text{T}\boldsymbol{x}b f(x)ωTxb 回归问题 利用最小二乘法,得到 ω \boldsymbol{\omega} ω和 b b b的参数估计$ \boldsymbol{\hat{\omega}}…...