一文详解大模型蒸馏工具TextBrewer
原文:https://zhuanlan.zhihu.com/p/648674584
本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。
TextBrewer是一个基于PyTorch的、为实现NLP中的知识蒸馏任务而设计的工具包,
融合并改进了NLP和CV中的多种知识蒸馏技术,提供便捷快速的知识蒸馏框架,用于以较低的性能损失压缩神经网络模型的大小,提升模型的推理速度,减少内存占用。
1.简介
TextBrewer 为NLP中的知识蒸馏任务设计,融合了多种知识蒸馏技术,提供方便快捷的知识蒸馏框架。
主要特点:
- 模型无关:适用于多种模型结构(主要面向Transfomer结构)
- 方便灵活:可自由组合多种蒸馏方法;可方便增加自定义损失等模块
- 非侵入式:无需对教师与学生模型本身结构进行修改
- 支持典型的NLP任务:文本分类、阅读理解、序列标注等
TextBrewer目前支持的知识蒸馏技术有:
- 软标签与硬标签混合训练
- 动态损失权重调整与蒸馏温度调整
- 多种蒸馏损失函数: hidden states MSE, attention-based loss, neuron selectivity transfer, …
- 任意构建中间层特征匹配方案
- 多教师知识蒸馏
- …
TextBrewer的主要功能与模块分为3块:
- Distillers:进行蒸馏的核心部件,不同的distiller提供不同的蒸馏模式。目前包含GeneralDistiller, MultiTeacherDistiller, MultiTaskDistiller等
- Configurations and Presets:训练与蒸馏方法的配置,并提供预定义的蒸馏策略以及多种知识蒸馏损失函数
- Utilities:模型参数分析显示等辅助工具
用户需要准备:
- 已训练好的教师模型, 待蒸馏的学生模型
- 训练数据与必要的实验配置, 即可开始蒸馏
在多个典型NLP任务上,TextBrewer都能取得较好的压缩效果。相关实验见蒸馏效果。
2.TextBrewer结构
2.1 安装要求
- Python >= 3.6
- PyTorch >= 1.1.0
- TensorboardX or Tensorboard
- NumPy
- tqdm
- Transformers >= 2.0 (可选, Transformer相关示例需要用到)
- Apex == 0.1.0 (可选,用于混合精度训练)
- 从PyPI自动下载安装包安装:
pip install textbrewer
- 从源码文件夹安装:
git clone https://github.com/airaria/TextBrewer.git
pip install ./textbrewer
2.2工作流程
- Stage 1 : 蒸馏之前的准备工作:
- 训练教师模型
- 定义与初始化学生模型(随机初始化,或载入预训练权重)
- 构造蒸馏用数据集的dataloader,训练学生模型用的optimizer和learning rate scheduler
- Stage 2 : 使用TextBrewer蒸馏:
- 构造训练配置(
TrainingConfig
)和蒸馏配置(DistillationConfig
),初始化distiller - 定义adaptor 和 callback ,分别用于适配模型输入输出和训练过程中的回调
- 调用distiller的train方法开始蒸馏
2.3 以蒸馏BERT-base到3层BERT为例展示TextBrewer用法
在开始蒸馏之前准备:
- 训练好的教师模型
teacher_model
(BERT-base),待训练学生模型student_model
(3-layer BERT) - 数据集
dataloader
,优化器optimizer
,学习率调节器类或者构造函数scheduler_class
和构造用的参数字典scheduler_args
使用TextBrewer蒸馏:
import textbrewer
from textbrewer import GeneralDistiller
from textbrewer import TrainingConfig, DistillationConfig#展示模型参数量的统计
print("\nteacher_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(teacher_model,max_level=3)
print (result)print("student_model's parametrers:")
result, _ = textbrewer.utils.display_parameters(student_model,max_level=3)
print (result)#定义adaptor用于解释模型的输出
def simple_adaptor(batch, model_outputs):# model输出的第二、三个元素分别是logits和hidden statesreturn {'logits': model_outputs[1], 'hidden': model_outputs[2]}#蒸馏与训练配置
# 匹配教师和学生的embedding层;同时匹配教师的第8层和学生的第2层
distill_config = DistillationConfig(intermediate_matches=[ {'layer_T':0, 'layer_S':0, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1},{'layer_T':8, 'layer_S':2, 'feature':'hidden', 'loss': 'hidden_mse','weight' : 1}])
train_config = TrainingConfig()#初始化distiller
distiller = GeneralDistiller(train_config=train_config, distill_config = distill_config,model_T = teacher_model, model_S = student_model, adaptor_T = simple_adaptor, adaptor_S = simple_adaptor)#开始蒸馏
with distiller:distiller.train(optimizer, dataloader, num_epochs=1, scheduler_class=scheduler_class, scheduler_args = scheduler_args, callback=None)
2.4蒸馏任务示例
- Transformers 4示例
- examples/notebook_examples/sst2.ipynb (英文): SST-2文本分类任务上的BERT模型训练与蒸馏。
- examples/notebook_examples/msra_ner.ipynb (中文): MSRA NER中文命名实体识别任务上的BERT模型训练与蒸馏。
- examples/notebook_examples/sqaudv1.1.ipynb (英文): SQuAD 1.1英文阅读理解任务上的BERT模型训练与蒸馏。
- examples/random_token_example: 一个可运行的简单示例,在文本分类任务上以随机文本为输入,演示TextBrewer用法。
- examples/cmrc2018_example (中文): CMRC 2018上的中文阅读理解任务蒸馏,并使用DRCD数据集做数据增强。
- examples/mnli_example (英文): MNLI任务上的英文句对分类任务蒸馏,并展示如何使用多教师蒸馏。
- examples/conll2003_example (英文): CoNLL-2003英文实体识别任务上的序列标注任务蒸馏。
- examples/msra_ner_example (中文): MSRA NER(中文命名实体识别)任务上,使用分布式数据并行训练的Chinese-ELECTRA-base模型蒸馏。
2.4.1蒸馏效果
我们在多个中英文文本分类、阅读理解、序列标注数据集上进行了蒸馏实验。实验的配置和效果如下。
- 模型
- 对于英文任务,教师模型为BERT-base-cased
- 对于中文任务,教师模型为HFL发布的RoBERTa-wwm-ext 与 Electra-base
我们测试了不同的学生模型,为了与已有公开结果相比较,除了BiGRU都是和BERT一样的多层Transformer结构。模型的参数如下表所示。需要注意的是,参数量的统计包括了embedding层,但不包括最终适配各个任务的输出层。
- 英文模型
- 中文模型
- T6的结构与DistilBERT[1], BERT6-PKD[2], BERT-of-Theseus[3] 相同。
- T4-tiny的结构与 TinyBERT[4] 相同。
- T3的结构与BERT3-PKD[2] 相同。
2.4.2 蒸馏配置
distill_config = DistillationConfig(temperature = 8, intermediate_matches = matches) #其他参数为默认值
不同的模型用的matches
我们采用了以下配置:
各种matches的定义在examples/matches/matches.py中。均使用GeneralDistiller进行蒸馏。
2.4.3训练配置
蒸馏用的学习率 lr=1e-4(除非特殊说明)。训练30~60轮。
2.4.4英文实验结果
在英文实验中,我们使用了如下三个典型数据集。
我们在下面两表中列出了DistilBERT, BERT-PKD, BERT-of-Theseus, TinyBERT 等公开的蒸馏结果,并与我们的结果做对比。
Public results:
Our results:
说明:
- 公开模型的名称后括号内是其等价的模型结构
- 蒸馏到T4-tiny的实验中,SQuAD任务上使用了NewsQA作为增强数据;CoNLL-2003上使用了HotpotQA的篇章作为增强数据
- 蒸馏到T12-nano的实验中,CoNLL-2003上使用了HotpotQA的篇章作为增强数据
2.4.5中文实验结果
在中文实验中,我们使用了如下典型数据集。
实验结果如下表所示。
说明:
- 以RoBERTa-wwm-ext为教师模型蒸馏CMRC 2018和DRCD时,不采用学习率衰减
- CMRC 2018和DRCD两个任务上蒸馏时他们互作为增强数据
- Electra-base的教师模型训练设置参考自Chinese-ELECTRA
- Electra-small学生模型采用预训练权重初始化
3.核心概念
3.1Configurations
TrainingConfig
和DistillationConfig
:训练和蒸馏相关的配置。
3.2Distillers
Distiller负责执行实际的蒸馏过程。目前实现了以下的distillers:
BasicDistiller
: 提供单模型单任务蒸馏方式。可用作测试或简单实验。GeneralDistiller
(常用): 提供单模型单任务蒸馏方式,并且支持中间层特征匹配,一般情况下推荐使用。MultiTeacherDistiller
: 多教师蒸馏。将多个(同任务)教师模型蒸馏到一个学生模型上。暂不支持中间层特征匹配。MultiTaskDistiller
:多任务蒸馏。将多个(不同任务)单任务教师模型蒸馏到一个多任务学生模型。BasicTrainer
:用于单个模型的有监督训练,而非蒸馏。可用于训练教师模型。
3.3用户定义函数
蒸馏实验中,有两个组件需要由用户提供,分别是callback 和 adaptor :
3.3.1Callback
回调函数。在每个checkpoint,保存模型后会被distiller
调用,并传入当前模型。可以借由回调函数在每个checkpoint评测模型效果。
3.3.2Adaptor
将模型的输入和输出转换为指定的格式,向distiller
解释模型的输入和输出,以便distiller
根据不同的策略进行不同的计算。在每个训练步,batch
和模型的输出model_outputs
会作为参数传递给adaptor
,adaptor
负责重新组织这些数据,返回一个字典。
更多细节可参见完整文档中的说明。
4.FAQ
Q: 学生模型该如何初始化?
A: 知识蒸馏本质上是“老师教学生”的过程。在初始化学生模型时,可以采用随机初始化的形式(即完全不包含任何先验知识),也可以载入已训练好的模型权重。例如,从BERT-base模型蒸馏到3层BERT时,可以预先载入RBT3模型权重(中文任务)或BERT的前三层权重(英文任务),然后进一步进行蒸馏,避免了蒸馏过程的“冷启动”问题。我们建议用户在使用时尽量采用已预训练过的学生模型,以充分利用大规模数据预训练所带来的优势。
Q: 如何设置蒸馏的训练参数以达到一个较好的效果?
A: 知识蒸馏的比有标签数据上的训练需要更多的训练轮数与更大的学习率。比如,BERT-base上训练SQuAD一般以lr=3e-5训练3轮左右即可达到较好的效果;而蒸馏时需要以lr=1e-4训练30~50轮。当然具体到各个任务上肯定还有区别,我们的建议仅是基于我们的经验得出的,仅供参考。
Q: 我的教师模型和学生模型的输入不同(比如词表不同导致input_ids不兼容),该如何进行蒸馏?
A: 需要分别为教师模型和学生模型提供不同的batch,参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
Q: 我缓存了教师模型的输出,它们可以用于加速蒸馏吗?
A: 可以, 参见完整文档中的 Feed Different batches to Student and Teacher, Feed Cached Values 章节。
相关文章:

一文详解大模型蒸馏工具TextBrewer
原文:https://zhuanlan.zhihu.com/p/648674584 本文分享自华为云社区《TextBrewer:融合并改进了NLP和CV中的多种知识蒸馏技术、提供便捷快速的知识蒸馏框架、提升模型的推理速度,减少内存占用》,作者:汀丶。 TextBre…...

Go语言加Vue3零基础入门全栈班10 Go语言+gRPC用户微服务项目实战 2024年07月31日 课程笔记
概述 如果您没有Golang的基础,应该学习如下前置课程。 Golang零基础入门Golang面向对象编程Go Web 基础Go语言开发REST API接口_20240728Go语言操作MySQL开发用户管理系统API教程_20240729Redis零基础快速入门_20231227GoRedis开发用户管理系统API实战_20240730Mo…...

ChatGPT能代替网络作家吗?
最强AI视频生成:小说文案智能分镜智能识别角色和场景批量Ai绘图自动配音添加音乐一键合成视频百万播放量https://aitools.jurilu.com/ 当然可以!只要你玩写作AI玩得6,甚至可以达到某些大神的水平! 看看大神、小白、AI输出内容的区…...

Http自定义Header导致的跨域问题
最近写一个小项目,前后端分离,在调试过程中访问远程接口,出现了CORS问题,接口使用的laravel框架,于是添加了解决跨域的中间件,但是前端显示仍存在跨域问题,以为自己写的有问题,检查了…...
python 中 file.read(), file.readline()和file.readlines()区别和用法
python 中 file.read(), file.readline()和file.readlines()区别和用法 文章目录 python 中 file.read(), file.readline()和file.readlines()区别和用法1. file.read()2. file.readline()3. file.readlines()4. 总结5. 注意事项 file.read(), file.readline(), 和 file.readli…...
python 学习: np.pad
在NumPy中,np.pad函数用于对数组进行填充(padding),即在数组的边界处添加额外的值。这在图像处理、信号处理或任何需要扩展数据边界的场景中非常有用。 以下是np.pad函数的一些关键参数和使用示例: array:…...

等保2.0 | 人大金仓数据库测评
人大金仓数据库,全称为金仓数据库管理系统KingbaseES(简称:金仓数据库或KingbaseES),是北京人大金仓信息技术股份有限公司自主研制开发的具有自主知识产权的通用关系型数据库管理系统。以下是关于人大金仓数据库的详细…...

AIGC赋能智慧农业:用AI技术绘就作物生长新蓝图
( 于景鑫 国家农业信息化工程技术研究中心)随着人工智能技术的日新月异,AIGC(AI-Generated Content,AI生成内容)正在各行各业掀起一场革命性的浪潮。而在智慧农业领域,AIGC技术的应用也正迸发出耀眼的火花。特别是在作物生长管理方面,AIGC有望彻底改变传…...
yolov8蒸馏(附代码-免费)
首先蒸馏是什么? 模型蒸馏(Model Distillation)是一种用于在计算机视觉中提高模型性能和效率的技术。在模型蒸馏中,通常存在两个模型,即“教师模型”和“学生模型”。 为什么需要蒸馏? 在不增加模型计算…...

Flink-StarRocks详解:第五部分查询数据湖(第55天)
系列文章目录 4.查询数据湖 4.1 Catalog 4.1.1 概述 4.1.1.1 基本概念 4.1.1.2 Catalog 4.1.1.3 访问Catalog 4.1.2 Default catalog 4.1.3 External Catalog 4.2 文件外部表 4.2.1 使用限制 4.2.2 开源版本语法 4.2.3 阿里云版本 5. 查询及优化 文章目录 系列文章目录前言4.查…...

【MySQL】常用数据类型
目录 数据类型 数据类型分类 数值类型 tinyint类型 bit类型 小数类型 float decimal 字符串类型 char varchar 日期和时间类型 enum和set 数据类型 数据类型分类 数值类型 tinyint类型 tinyint类型只占用一个字节类似于编程语言中的字符char。有带符号和无符号两…...
创建第一个rust tauri项目
安装nodejs curl -sL https://deb.nodesource.com/setup_20.x | sudo bash node -vproxychains4 npm create tauri-applatest✔ Project name tauri-app ✔ Choose which language to use for your frontend TypeScript / JavaScript - (pnpm, yarn, npm, bun) ✔ Choose yo…...

【课程总结】day19(中):Transformer架构及注意力机制了解
前言 本章内容,我们将从注意力的基础概念入手,结合Transformer架构,由宏观理解其运行流程,然后逐步深入了解多头注意力、多头掩码注意力、融合注意力等概念及作用。 注意力机制(Attension) 背景 深度学…...

4.4 标准正交基和格拉姆-施密特正交化
本节的两个目标就是为什么和怎么做(why and how)。首先是知道为什么正交性很好:因为它们的点积为零; A T A A^TA ATA 是对角矩阵;在求 x ^ \boldsymbol{\hat x} x^ 和 p A x ^ \boldsymbol pA\boldsymbol{\hat x} pAx^ 时也会很简单。第二…...
spring事务的8种失效的场景,7种传播行为
Spring事务大部分都是通过AOP实现的,所以事务失效的场景大部分都是因为AOP失效,AOP基于动态代理实现的 1.方法没有被public修饰 原因:Spring会为方法创建代理、AOP添加事务通知前提条件是该方法时public的。 2.类没有被Spring容器所托管 …...

进程的虚拟内存地址(C++程序的内存分区)
严谨的说法: 一个C、C程序实际就是一个进程,那么C的内存分区,实际上就是一个进程的内存分区,这样的话就可以分为两个大模块,从上往下,也就是0地址一直往下,假如是x86的32位Linux系统,…...

英特尔移除超线程与AMD多线程性能对比
#### 英特尔Lunar Lake架构取消超线程 在英特尔宣布Lunar Lake架构时,一个令人惊讶的消息是下一代轻薄优化架构将移除Hyper-Threading(超线程,简称SMT)。而AMD最新的Zen 5/Zen5C多线程基准测试结果显示,该特性依然为A…...

定期自动巡检,及时发现机房运维管理中的潜在问题
随着信息化技术的迅猛发展,机房作为企业数据处理与存储的核心场所,其运维管理的复杂性和挑战性也与日俱增。为确保机房设备的稳定运行和业务的连续性,运维团队必须定期进行全面的巡检。然而,传统的手工巡检方式不仅效率低下&#…...
八股文(一)
1. 为什么不使用本地缓存,而使用Redis? Redis相比于本地缓存(如JVM中的缓存)有以下几个显著优势: 高性能与低延迟:Redis是一个基于内存的数据库,其读写性能非常高,通常可以达到几万…...
灵茶八题 - 子数组 ^w^
灵茶八题 - 子数组 w 题目描述 给你一个长为 n n n 的数组 a a a,输出它的所有连续子数组的异或和的异或和。 例如 a [ 1 , 3 ] a[1,3] a[1,3] 有三个连续子数组 [ 1 ] , [ 3 ] , [ 1 , 3 ] [1],[3],[1,3] [1],[3],[1,3],异或和分别为 1 , 3 , …...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...
R语言AI模型部署方案:精准离线运行详解
R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...

云启出海,智联未来|阿里云网络「企业出海」系列客户沙龙上海站圆满落地
借阿里云中企出海大会的东风,以**「云启出海,智联未来|打造安全可靠的出海云网络引擎」为主题的阿里云企业出海客户沙龙云网络&安全专场于5.28日下午在上海顺利举办,现场吸引了来自携程、小红书、米哈游、哔哩哔哩、波克城市、…...

【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...

微信小程序 - 手机震动
一、界面 <button type"primary" bindtap"shortVibrate">短震动</button> <button type"primary" bindtap"longVibrate">长震动</button> 二、js逻辑代码 注:文档 https://developers.weixin.qq…...
ffmpeg(四):滤镜命令
FFmpeg 的滤镜命令是用于音视频处理中的强大工具,可以完成剪裁、缩放、加水印、调色、合成、旋转、模糊、叠加字幕等复杂的操作。其核心语法格式一般如下: ffmpeg -i input.mp4 -vf "滤镜参数" output.mp4或者带音频滤镜: ffmpeg…...
【python异步多线程】异步多线程爬虫代码示例
claude生成的python多线程、异步代码示例,模拟20个网页的爬取,每个网页假设要0.5-2秒完成。 代码 Python多线程爬虫教程 核心概念 多线程:允许程序同时执行多个任务,提高IO密集型任务(如网络请求)的效率…...

多模态大语言模型arxiv论文略读(108)
CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文标题:CROME: Cross-Modal Adapters for Efficient Multimodal LLM ➡️ 论文作者:Sayna Ebrahimi, Sercan O. Arik, Tejas Nama, Tomas Pfister ➡️ 研究机构: Google Cloud AI Re…...

ABAP设计模式之---“简单设计原则(Simple Design)”
“Simple Design”(简单设计)是软件开发中的一个重要理念,倡导以最简单的方式实现软件功能,以确保代码清晰易懂、易维护,并在项目需求变化时能够快速适应。 其核心目标是避免复杂和过度设计,遵循“让事情保…...
Java线上CPU飙高问题排查全指南
一、引言 在Java应用的线上运行环境中,CPU飙高是一个常见且棘手的性能问题。当系统出现CPU飙高时,通常会导致应用响应缓慢,甚至服务不可用,严重影响用户体验和业务运行。因此,掌握一套科学有效的CPU飙高问题排查方法&…...