预训练 BERT 使用 Hugging Face 和 PyTorch 在 AMD GPU 上
Pre-training BERT using Hugging Face & PyTorch on an AMD GPU — ROCm Blogs
2024年1月26日,作者:Vara Lakshmi Bayanagari.
这篇博客解释了如何从头开始使用 Hugging Face 库和 PyTorch 后端在 AMD GPU 上为英文语料(WikiText-103-raw-v1)预训练 BERT 基础模型的端到端过程。
你可以在 GitHub folder中找到与这篇博客相关的文件。
BERT简介
BERT是一种在2019年提出的语言表示模型。其模型架构基于一个transformer编码器,其中自注意力层对输入的每个token对进行注意力计算,整合了来自两个方向的上下文(因此称为BERT的“双向”特性)。在此之前,像ELMo和GPT这样的模型只使用从左到右的(单向)架构,这极大地限制了模型的表现力;模型性能依赖于微调。
本博客解释了BERT所采用的预训练任务,这些任务在通用语言理解评估(GLUE)基准测试中取得了最先进的成果。在接下来的章节中,我们将展示在PyTorch中的实现。
这篇BERT paper最先提出了一种新的预训练方法,称为掩码语言建模(MLM)。MLM随机掩盖输入的某些部分,并对一批输入进行训练以预测这些被掩盖的tokens。预训练期间,在对输入进行分词之后,15%的tokens被随机挑选。其中,80%被替换为一个`[MASK]`标记,10%被替换为一个随机标记,10%则保持不变。
在下面的示例中,MLM预处理方法如下:`dog`标记保持不变,`Golden`和`years`标记被掩盖,并且`and`标记被随机替换为`paper`标记。预训练的目标是使用`CategoricalCrossEntropy`损失来预测这些标记,以便模型学习语言的语法、模式和结构。
Input sentence: My dog is a Golden Retriever and his is 5 years oldAfter MLM: My dog is a [MASK] Retriever paper his is 5 [MASK] old
此外,为了捕捉句子之间的关系,超越掩码语言建模任务,论文提出了第二个预训练任务,称为下一个句子预测(NSP)。在不改变架构的情况下,论文证明了NSP有助于提升问答(QA)和自然语言推理(NLI)任务的结果。
这个任务不直接输入token流,而是输入一对句子的token,例如`A`和`B`,以及一个前置分类标记(`[CLS]`)。分类标记指示句对是随机组合的(label=0)还是`B`是`A`的下一个句子(label=1)。因此,NSP预训练是一种二元分类任务。
_IsNext_ Pair: [1] My dog is a Golden Retriever. He is five years old.Not _IsNext_ Pair: [0] My dog is a Golden Retriever. The next chapter in the book is a biography.
总之,数据集首先进行预处理以形成一对句子,然后进行分词,并最终随机掩盖某些tokens。预处理后的输入批次要么*填充*(使用`[PAD]`标记)或*修剪*(到_max_seq_length_超参数),以便所有输入元素在加载到BERT模型中之前都统一为相同的长度。BERT模型配有两个分类头:一个用于MLM(`num_cls_heads = vocab_size),另一个用于NSP(
num_cls_heads=2`)。来自两个预训练任务的分类损失之和用于训练BERT。
在多台 AMD GPU 上的实现
在开始之前,确保您已经满足以下要求:
-
在搭载 AMD GPU 的设备上安装 ROCm 兼容的 PyTorch。本实验在 ROCm 5.7.0 和 PyTorch 2.0.1 上进行了测试。
-
运行命令
pip install datasets transformers accelerate
以安装 Hugging Face 的相关库。 -
运行
accelerate config
命令以设置分布式训练参数,详见此处。在本实验中,我们使用了单节点上的八块 GPU 并行计算,运用了DistributedDataParallel
。
实现
Hugging Face 使用 Torch 作为大多数模型的默认后端,从而实现了这两个框架的良好结合。为了简化常规训练步骤并避免样板代码,Hugging Face 提供了一个名为 Trainer 的类,该类模仿了 PyTorch 的功能。类似地,Lightning AI 提供了 Trainer 类。此外,对于分布式训练,Hugging Face 可能更方便,因为代码中没有额外的配置设置,系统会根据 accelerate config
自动检测并利用所有 GPU 设备。然而,如果你希望进一步自定义你的模型并对加载预训练检查点做出额外修改,原生的 PyTorch 是更好的选择。这篇博客解释了使用 Hugging Face 的 transformers 库对 BERT 进行端到端预训练,同时提供了简化的数据预处理管道。
使用 Hugging Face 的 Trainer 进行 BERT 预训练可以用几行代码来总结。transformer 编码器、MLM 分类头和 NSP 分类头都打包在 Hugging Face 的 BertForPreTraining
模型中,该模型返回一个累积分类损失,如我们在 介绍 中所解释的。模型使用默认的 BERT base 配置参数(`NUM_LAYERS`、`ACT_FUNC`、`BATCH_SIZE`、`HIDDEN_SIZE`、`EMBED_DIM` 等)进行初始化。你可以从 Hugging Face 的 BertConfig
中导入这些参数。
那就是全部了吗?几乎。训练最关键的部分是数据预处理。预处理分为三个步骤:
-
将你的数据集重新组织为每个文档的句子字典。这对于从随机文档中选取随机句子以进行 NSP 任务非常有用。为此,可以对整个数据集使用简单的for循环。
-
使用 Hugging Face 的
AutoTokenizer
来对所有句子进行标记化。 -
使用另一个 for 循环,创建 50% 随机对和 50% 顺序对的句子对。
我已经对 WikiText-103-raw-v1
语料库(2,500 M单词)进行了上述的预处理步骤,并将生成的验证集放在这里。预处理的训练集已上传到 Hugging Face Hub。
接下来,导入 DataCollatorForLanguageModeling
收集器以运行 MLM 预处理,并获取掩码和句子分类标签。在使用 Trainer 类时,我们只需要访问 torch.utils.data.Dataset
和一个收集函数。与 TensorFlow 不同,Hugging Face 的 Trainer 会从数据集和收集器函数中创建数据加载器。为了演示,我们使用了有 3,000+ 句对的 Wikitext-103-raw-v1
验证集。
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
# tokenized_dataset = datasets.load_from_disk(args.dataset_file)
tokenized_dataset_valid = datasets.load_from_disk('./wikiTokenizedValid.hf')
创建一个 TrainerArguments 实例,并传递所有必需的参数,如以下代码所示。这部分代码有助于在训练模型时抽象样板代码。该类很灵活,因为它提供了 100 多个参数来适应不同的训练模式;有关更多信息,请参阅 Hugging face transformers 页面。
你现在可以使用 t.train()
来训练模型了。你还可以通过将 resume_from_checkpoint=True
参数传递给 t.train()
来恢复训练。trainer 类会提取 output_dir
文件夹中的最新检查点,并继续训练直到达到总共 num_train_epochs
。
train_args = TrainingArguments(output_dir=args.output_dir, overwrite_output_dir =True, per_device_train_batch_size =args.BATCH_SIZE, logging_first_step=True,logging_strategy='epoch', evaluation_strategy = 'epoch', save_strategy ='epoch', num_train_epochs=args.EPOCHS,save_total_limit=50)
t = Trainer(model, args = train_args, data_collator=collater, train_dataset = tokenized_dataset, optimizers=(optimizer, None), eval_dataset = tokenized_dataset_valid)
t.train()#resume_from_checkpoint=True)
上述模型使用Adam优化器(`learning_rate=2e-5`)和`per_device_train_batch_size=8`进行了大约400个epoch的训练。在一块AMD GPU(MI210,ROCm 5.7.0,PyTorch 2.0.1)上,使用3,000+句对的验证集进行预训练仅需几个小时。训练曲线如图1所示。可以使用最佳模型检查点微调不同的数据集,并在各种NLP任务上测试其表现。
完整的代码如下:
set_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument('--BATCH_SIZE', type=int, default = 8) # 32 is the global batch size, since I use 8 GPUs
parser.add_argument('--EPOCHS', type=int, default=200)
parser.add_argument('--train', action='store_true')
parser.add_argument('--dataset_file', type=str, default= './wikiTokenizedValid.hf')
parser.add_argument('--lr', default = 0.00005, type=float)
parser.add_argument('--output_dir', default = './acc_valid/')
args = parser.parse_args()accelerator = Accelerator()if args.train:args.dataset_file = './wikiTokenizedTrain.hf'args.output_dir = './acc/'
print(args)tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
tokenized_dataset = datasets.load_from_disk(args.dataset_file)
tokenized_dataset_valid = datasets.load_from_disk('./wikiTokenizedValid.hf')model = BertForPreTraining(BertConfig.from_pretrained("bert-base-cased"))
optimizer = torch.optim.Adam(model.parameters(), lr =args.lr)device = accelerator.device
model.to(accelerator.device)
train_args = TrainingArguments(output_dir=args.output_dir, overwrite_output_dir =True, per_device_train_batch_size =args.BATCH_SIZE, logging_first_step=True,logging_strategy='epoch', evaluation_strategy = 'epoch', save_strategy ='epoch', num_train_epochs=args.EPOCHS,save_total_limit=50)#, lr_scheduler_type=None)
t = Trainer(model, args = train_args, data_collator=collater, train_dataset = tokenized_dataset, optimizers=(optimizer, None), eval_dataset = tokenized_dataset_valid)
t.train()#resume_from_checkpoint=True)
推理
以一个示例文本为例,使用分词器将其转换为输入tokens,并通过collator生成一个掩码输入。
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15, pad_to_multiple_of=128)
text="The author takes his own advice when it comes to writing: he seeks to ground his claims in clear, concrete examples. He shows specific examples of bad writing to help readers better grasp exactly what he’s critiquing"
tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
inp = collater([tokens])
inp['attention_mask'] = torch.where(inp['input_ids']==0,0,1)
使用预训练的权重初始化模型并进行推理。你将看到模型生成的随机tokens没有上下文意义。
config = BertConfig.from_pretrained('bert-base-cased')
model = BertForPreTraining.from_pretrained('./acc_valid/checkpoint-19600/')
model.eval()
out = model(inp['input_ids'], inp['attention_mask'], labels=inp['labels'])print('Input: ', tokenizer.decode(inp['input_ids'][0][:30]), '\n')
print('Output: ', tokenizer.decode(torch.argmax(out[0], -1)[0][:30]))
输入和输出如下所示。该模型在一个非常小的数据集(3,000多句子)上进行了训练;你可以通过在更大的数据集上训练,例如`wikiText-103-raw-v1`的训练切分数据,来提高性能。
The author takes his own advice when it comes to writing : he [MASK] to ground his claims in clear, concrete examples. He shows specific examples of bad
The Churchill takes his own, when it comes to writing : he continued to ground his claims in clear, this examples. He shows is examples of bad
源代码存储在这个 GitHub 文件夹。
结论
我们所描述的预训练BERT基础模型的过程可以很容易地扩展到不同大小的BERT版本以及不同的数据集。我们使用Hugging Face Trainer和PyTorch后端在AMD GPU上训练了我们的模型。对于训练,我们使用了`wikiText-103-raw-v1`数据集的验证集,但这可以很容易地替换为训练集,只需下载我们在Hugging Face Hub上的仓库中托管的预处理和标记化的训练文件Hugging Face Hub.
在本文中,我们通过MLM和NSP预训练任务复制了BERT的预训练过程,这与许多公共平台上仅使用MLM的方法不同。此外,我们没有使用数据集的小部分,而是预处理并上传了整个数据集到Hub上供您方便使用。在未来的文章中,我们将讨论在多个AMD GPU上使用数据并行和分布式策略来训练各种机器学习应用。
相关文章:

预训练 BERT 使用 Hugging Face 和 PyTorch 在 AMD GPU 上
Pre-training BERT using Hugging Face & PyTorch on an AMD GPU — ROCm Blogs 2024年1月26日,作者:Vara Lakshmi Bayanagari. 这篇博客解释了如何从头开始使用 Hugging Face 库和 PyTorch 后端在 AMD GPU 上为英文语料(WikiText-103-raw-v1)预训练…...

鸿蒙是必经之路
少了大嘴的发布会,老实讲有点让人昏昏入睡。关于技术本身的东西,放在后面。 我想想来加把油~ 鸿蒙发布后褒贬不一,其中很多人不太看好鸿蒙,一方面是开源性、一方面是南向北向的利益问题。 不说技术的领先点,我只扯扯…...

Java项目实战II基于微信小程序的马拉松报名系统(开发文档+数据库+源码)
目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 马拉松运动…...

家用wifi的ip地址固定吗?换wifi就是换ip地址吗
在探讨家用WiFi的IP地址是否固定,以及换WiFi是否就意味着换IP地址这两个问题时,我们首先需要明确几个关键概念:IP地址、家用WiFi网络、以及它们之间的相互作用。 一、家用WiFi的IP地址固定性 家用WiFi环境中的IP地址通常涉及两类:…...

codeforces _ 补题
C. Ball in Berland 传送门:Problem - C - Codeforces 题意: 思路:容斥原理 考虑 第 i 对情侣组合 ,男生为 a ,女生为 b ,那么考虑与之匹配的情侣 必须没有 a | b ,一共有 k 对情侣&#x…...

DataSophon集成ApacheImpala的过程
注意: 本次安装操作系统环境为Anolis8.9(Centos7和Centos8应该也一样) DataSophon版本为DDP-1.2.1 整合的安装包我放网盘了: 通过网盘分享的文件:impala-4.4.1.tar.gz等2个文件 链接: https://pan.baidu.com/s/18KfkO_BEFa5gVcc16I-Yew?pwdza4k 提取码: za4k 1…...
深入探讨TCP/IP协议基础
在当今数字化的时代,计算机网络已经成为人们生活和工作中不可或缺的一部分。而 TCP/IP 协议作为计算机网络的核心协议,更是支撑着全球互联网的运行。本文将深入探讨常见的 TCP/IP 协议基础,带你了解计算机网络的奥秘。 一、计算机网络概述 计…...

《Windows PE》7.4 资源表应用
本节我们将通过两个示例程序,演示对PE文件内图标资源的置换与提取。 本节必须掌握的知识点: 更改图标 提取图标资源 7.4.1 更改图标 让我们来做一个实验,替换PE文件中现有的图标。如果手工替换,一定是先找到资源表,…...
【重生之我要苦学C语言】猜数字游戏和关机程序的整合
今天来把学过的猜数字游戏和关机程序来整合一下 如果有不明白的可以看往期的博客 废话不多说,上代码: #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <time.h> #include <stdlib.h> #include <string.h> void…...

基于centos7脚本一键部署gpmall商城
基于centos7脚本一键部署单节点gpmall商城,该商城可单节点,可集群,可高可用集群部署,VMware17,虚拟机IP:192.168.200.100 将软件包解压到/root目录 [rootlocalhost ~]# ls dist …...
Mac book英特尔系列?M系列?两者有什么区别呢
众所周知,Mac book有M系列,搭载的是苹果自研的M芯片,也有着英特尔系列,搭载的是英特尔的处理器,虽然从 2020 年开始,苹果公司逐步推出了自家研发的 M 系列芯片,并逐渐将 MacBook 产品线过渡到 M…...
Python unstructured库详解:partition_pdf函数完整参数深度解析
Python unstructured库详解:partition_pdf函数完整参数深度解析 1. 简介2. 基础文件处理参数2.1 文件输入参数2.2 页面处理参数 3. 文档解析策略3.1 strategy参数详解3.2 策略选择建议 4. 表格处理参数4.1 表格结构推断 5. 语言处理参数5.1 语言设置 6. 图像处理参数…...

<项目代码>YOLOv8路面病害识别<目标检测>
YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv8具有更高的…...

广告牌和标签学习
效果: 知识学习: entities添加标签label和广告牌billboard label: text:文本添加 font:字体大小和字体类型 fillColor:字体颜色 outlineColor:字体外轮廓颜色 outlineWidth:字体外轮…...

GDB 从裸奔到穿戴整齐
无数次被问道:你在终端下怎么调试更高效?或者怎么在 Vim 里调试?好吧,今天统一回答下,我从来不在 vim 里调试,因为它还不成熟。那除了命令行 GDB 裸奔以外,终端下还有没有更高效的方法ÿ…...

WPF的触发器(Trigger)
WPF(Windows Presentation Foundation)是微软.NET框架的一部分,用于构建Windows客户端应用程序。在WPF中,触发器(Triggers)是一种强大的功能,允许开发者根据控件的状态或属性值来动态改变控件的…...

全能大模型GPT-4o体验和接入教程
GPT-4o体验和接入教程 前言一、原生API二、Python LangchainSpring AI总结 前言 Open AI发布了产品GPT-4o,o表示"omni",全能的意思。 GPT-4o可以实时对音频、视觉和文本进行推理,响应时间平均为 320 毫秒,和人类之间对…...
详解Apache版本、新功能和技术前景
文章目录 一、 版本溯源二、新功能和特性举例1. 模块化和可扩展性增强2. 多处理模块(MPMs)3. 异步支持4. 更细粒度的日志级别控制5. 通用表达式解析器6. HTTP/2支持7. Server Push8. Early Hints9. 更好的SSL/TLS支持10. 更安全的默认设置 三、 技术前景…...
Docker Redis集群3主3从模式
主从集群 docker run -d --name redis-node1 --net host --privilegedtrue -v /home/redis/node1:/data redis:7.0 --cluster-enabled yes --appendonly yes --port 9371docker run -d --name redis-node2 --net host --privilegedtrue -v /home/redis/node2:/data redis:7.0 …...

【Go语言】
type关键字的用法 定义结构体定义接口定义类型别名类型定义类型判断 别名实际上是为了更好地理解代码/ 这里要分点进行记录 使用传值的例子,当两个类型不一样需要进行类型转换 type Myint int // 自定义类型,基于已有的类型自定义一个类型type Myin…...
浅谈 React Hooks
React Hooks 是 React 16.8 引入的一组 API,用于在函数组件中使用 state 和其他 React 特性(例如生命周期方法、context 等)。Hooks 通过简洁的函数接口,解决了状态与 UI 的高度解耦,通过函数式编程范式实现更灵活 Rea…...
前端倒计时误差!
提示:记录工作中遇到的需求及解决办法 文章目录 前言一、误差从何而来?二、五大解决方案1. 动态校准法(基础版)2. Web Worker 计时3. 服务器时间同步4. Performance API 高精度计时5. 页面可见性API优化三、生产环境最佳实践四、终极解决方案架构前言 前几天听说公司某个项…...

成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战
在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配
目录 一、C 内存的基本概念 1.1 内存的物理与逻辑结构 1.2 C 程序的内存区域划分 二、栈内存分配 2.1 栈内存的特点 2.2 栈内存分配示例 三、堆内存分配 3.1 new和delete操作符 4.2 内存泄漏与悬空指针问题 4.3 new和delete的重载 四、智能指针…...

Golang——6、指针和结构体
指针和结构体 1、指针1.1、指针地址和指针类型1.2、指针取值1.3、new和make 2、结构体2.1、type关键字的使用2.2、结构体的定义和初始化2.3、结构体方法和接收者2.4、给任意类型添加方法2.5、结构体的匿名字段2.6、嵌套结构体2.7、嵌套匿名结构体2.8、结构体的继承 3、结构体与…...

Git 3天2K星标:Datawhale 的 Happy-LLM 项目介绍(附教程)
引言 在人工智能飞速发展的今天,大语言模型(Large Language Models, LLMs)已成为技术领域的焦点。从智能写作到代码生成,LLM 的应用场景不断扩展,深刻改变了我们的工作和生活方式。然而,理解这些模型的内部…...

DBLP数据库是什么?
DBLP(Digital Bibliography & Library Project)Computer Science Bibliography是全球著名的计算机科学出版物的开放书目数据库。DBLP所收录的期刊和会议论文质量较高,数据库文献更新速度很快,很好地反映了国际计算机科学学术研…...

VisualXML全新升级 | 新增数据库编辑功能
VisualXML是一个功能强大的网络总线设计工具,专注于简化汽车电子系统中复杂的网络数据设计操作。它支持多种主流总线网络格式的数据编辑(如DBC、LDF、ARXML、HEX等),并能够基于Excel表格的方式生成和转换多种数据库文件。由此&…...
第八部分:阶段项目 6:构建 React 前端应用
现在,是时候将你学到的 React 基础知识付诸实践,构建一个简单的前端应用来模拟与后端 API 的交互了。在这个阶段,你可以先使用模拟数据,或者如果你的后端 API(阶段项目 5)已经搭建好,可以直接连…...
用递归算法解锁「子集」问题 —— LeetCode 78题解析
文章目录 一、题目介绍二、递归思路详解:从决策树开始理解三、解法一:二叉决策树 DFS四、解法二:组合式回溯写法(推荐)五、解法对比 递归算法是编程中一种非常强大且常见的思想,它能够优雅地解决很多复杂的…...