预训练 BERT 使用 Hugging Face 和 PyTorch 在 AMD GPU 上
Pre-training BERT using Hugging Face & PyTorch on an AMD GPU — ROCm Blogs

2024年1月26日,作者:Vara Lakshmi Bayanagari.
这篇博客解释了如何从头开始使用 Hugging Face 库和 PyTorch 后端在 AMD GPU 上为英文语料(WikiText-103-raw-v1)预训练 BERT 基础模型的端到端过程。
你可以在 GitHub folder中找到与这篇博客相关的文件。
BERT简介
BERT是一种在2019年提出的语言表示模型。其模型架构基于一个transformer编码器,其中自注意力层对输入的每个token对进行注意力计算,整合了来自两个方向的上下文(因此称为BERT的“双向”特性)。在此之前,像ELMo和GPT这样的模型只使用从左到右的(单向)架构,这极大地限制了模型的表现力;模型性能依赖于微调。
本博客解释了BERT所采用的预训练任务,这些任务在通用语言理解评估(GLUE)基准测试中取得了最先进的成果。在接下来的章节中,我们将展示在PyTorch中的实现。
这篇BERT paper最先提出了一种新的预训练方法,称为掩码语言建模(MLM)。MLM随机掩盖输入的某些部分,并对一批输入进行训练以预测这些被掩盖的tokens。预训练期间,在对输入进行分词之后,15%的tokens被随机挑选。其中,80%被替换为一个`[MASK]`标记,10%被替换为一个随机标记,10%则保持不变。
在下面的示例中,MLM预处理方法如下:`dog`标记保持不变,`Golden`和`years`标记被掩盖,并且`and`标记被随机替换为`paper`标记。预训练的目标是使用`CategoricalCrossEntropy`损失来预测这些标记,以便模型学习语言的语法、模式和结构。
Input sentence: My dog is a Golden Retriever and his is 5 years oldAfter MLM: My dog is a [MASK] Retriever paper his is 5 [MASK] old
此外,为了捕捉句子之间的关系,超越掩码语言建模任务,论文提出了第二个预训练任务,称为下一个句子预测(NSP)。在不改变架构的情况下,论文证明了NSP有助于提升问答(QA)和自然语言推理(NLI)任务的结果。
这个任务不直接输入token流,而是输入一对句子的token,例如`A`和`B`,以及一个前置分类标记(`[CLS]`)。分类标记指示句对是随机组合的(label=0)还是`B`是`A`的下一个句子(label=1)。因此,NSP预训练是一种二元分类任务。
_IsNext_ Pair: [1] My dog is a Golden Retriever. He is five years old.Not _IsNext_ Pair: [0] My dog is a Golden Retriever. The next chapter in the book is a biography.
总之,数据集首先进行预处理以形成一对句子,然后进行分词,并最终随机掩盖某些tokens。预处理后的输入批次要么*填充*(使用`[PAD]`标记)或*修剪*(到_max_seq_length_超参数),以便所有输入元素在加载到BERT模型中之前都统一为相同的长度。BERT模型配有两个分类头:一个用于MLM(`num_cls_heads = vocab_size),另一个用于NSP(num_cls_heads=2`)。来自两个预训练任务的分类损失之和用于训练BERT。
在多台 AMD GPU 上的实现
在开始之前,确保您已经满足以下要求:
-
在搭载 AMD GPU 的设备上安装 ROCm 兼容的 PyTorch。本实验在 ROCm 5.7.0 和 PyTorch 2.0.1 上进行了测试。
-
运行命令
pip install datasets transformers accelerate以安装 Hugging Face 的相关库。 -
运行
accelerate config命令以设置分布式训练参数,详见此处。在本实验中,我们使用了单节点上的八块 GPU 并行计算,运用了DistributedDataParallel。
实现
Hugging Face 使用 Torch 作为大多数模型的默认后端,从而实现了这两个框架的良好结合。为了简化常规训练步骤并避免样板代码,Hugging Face 提供了一个名为 Trainer 的类,该类模仿了 PyTorch 的功能。类似地,Lightning AI 提供了 Trainer 类。此外,对于分布式训练,Hugging Face 可能更方便,因为代码中没有额外的配置设置,系统会根据 accelerate config 自动检测并利用所有 GPU 设备。然而,如果你希望进一步自定义你的模型并对加载预训练检查点做出额外修改,原生的 PyTorch 是更好的选择。这篇博客解释了使用 Hugging Face 的 transformers 库对 BERT 进行端到端预训练,同时提供了简化的数据预处理管道。
使用 Hugging Face 的 Trainer 进行 BERT 预训练可以用几行代码来总结。transformer 编码器、MLM 分类头和 NSP 分类头都打包在 Hugging Face 的 BertForPreTraining 模型中,该模型返回一个累积分类损失,如我们在 介绍 中所解释的。模型使用默认的 BERT base 配置参数(`NUM_LAYERS`、`ACT_FUNC`、`BATCH_SIZE`、`HIDDEN_SIZE`、`EMBED_DIM` 等)进行初始化。你可以从 Hugging Face 的 BertConfig 中导入这些参数。
那就是全部了吗?几乎。训练最关键的部分是数据预处理。预处理分为三个步骤:
-
将你的数据集重新组织为每个文档的句子字典。这对于从随机文档中选取随机句子以进行 NSP 任务非常有用。为此,可以对整个数据集使用简单的for循环。
-
使用 Hugging Face 的
AutoTokenizer来对所有句子进行标记化。 -
使用另一个 for 循环,创建 50% 随机对和 50% 顺序对的句子对。
我已经对 WikiText-103-raw-v1 语料库(2,500 M单词)进行了上述的预处理步骤,并将生成的验证集放在这里。预处理的训练集已上传到 Hugging Face Hub。
接下来,导入 DataCollatorForLanguageModeling 收集器以运行 MLM 预处理,并获取掩码和句子分类标签。在使用 Trainer 类时,我们只需要访问 torch.utils.data.Dataset 和一个收集函数。与 TensorFlow 不同,Hugging Face 的 Trainer 会从数据集和收集器函数中创建数据加载器。为了演示,我们使用了有 3,000+ 句对的 Wikitext-103-raw-v1 验证集。
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
# tokenized_dataset = datasets.load_from_disk(args.dataset_file)
tokenized_dataset_valid = datasets.load_from_disk('./wikiTokenizedValid.hf')
创建一个 TrainerArguments 实例,并传递所有必需的参数,如以下代码所示。这部分代码有助于在训练模型时抽象样板代码。该类很灵活,因为它提供了 100 多个参数来适应不同的训练模式;有关更多信息,请参阅 Hugging face transformers 页面。
你现在可以使用 t.train() 来训练模型了。你还可以通过将 resume_from_checkpoint=True 参数传递给 t.train() 来恢复训练。trainer 类会提取 output_dir 文件夹中的最新检查点,并继续训练直到达到总共 num_train_epochs。
train_args = TrainingArguments(output_dir=args.output_dir, overwrite_output_dir =True, per_device_train_batch_size =args.BATCH_SIZE, logging_first_step=True,logging_strategy='epoch', evaluation_strategy = 'epoch', save_strategy ='epoch', num_train_epochs=args.EPOCHS,save_total_limit=50)
t = Trainer(model, args = train_args, data_collator=collater, train_dataset = tokenized_dataset, optimizers=(optimizer, None), eval_dataset = tokenized_dataset_valid)
t.train()#resume_from_checkpoint=True)
上述模型使用Adam优化器(`learning_rate=2e-5`)和`per_device_train_batch_size=8`进行了大约400个epoch的训练。在一块AMD GPU(MI210,ROCm 5.7.0,PyTorch 2.0.1)上,使用3,000+句对的验证集进行预训练仅需几个小时。训练曲线如图1所示。可以使用最佳模型检查点微调不同的数据集,并在各种NLP任务上测试其表现。

完整的代码如下:
set_seed(42)
parser = argparse.ArgumentParser()
parser.add_argument('--BATCH_SIZE', type=int, default = 8) # 32 is the global batch size, since I use 8 GPUs
parser.add_argument('--EPOCHS', type=int, default=200)
parser.add_argument('--train', action='store_true')
parser.add_argument('--dataset_file', type=str, default= './wikiTokenizedValid.hf')
parser.add_argument('--lr', default = 0.00005, type=float)
parser.add_argument('--output_dir', default = './acc_valid/')
args = parser.parse_args()accelerator = Accelerator()if args.train:args.dataset_file = './wikiTokenizedTrain.hf'args.output_dir = './acc/'
print(args)tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15)
tokenized_dataset = datasets.load_from_disk(args.dataset_file)
tokenized_dataset_valid = datasets.load_from_disk('./wikiTokenizedValid.hf')model = BertForPreTraining(BertConfig.from_pretrained("bert-base-cased"))
optimizer = torch.optim.Adam(model.parameters(), lr =args.lr)device = accelerator.device
model.to(accelerator.device)
train_args = TrainingArguments(output_dir=args.output_dir, overwrite_output_dir =True, per_device_train_batch_size =args.BATCH_SIZE, logging_first_step=True,logging_strategy='epoch', evaluation_strategy = 'epoch', save_strategy ='epoch', num_train_epochs=args.EPOCHS,save_total_limit=50)#, lr_scheduler_type=None)
t = Trainer(model, args = train_args, data_collator=collater, train_dataset = tokenized_dataset, optimizers=(optimizer, None), eval_dataset = tokenized_dataset_valid)
t.train()#resume_from_checkpoint=True)
推理
以一个示例文本为例,使用分词器将其转换为输入tokens,并通过collator生成一个掩码输入。
collater = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=True, mlm_probability=0.15, pad_to_multiple_of=128)
text="The author takes his own advice when it comes to writing: he seeks to ground his claims in clear, concrete examples. He shows specific examples of bad writing to help readers better grasp exactly what he’s critiquing"
tokens = tokenizer.convert_tokens_to_ids(tokenizer.tokenize(text))
inp = collater([tokens])
inp['attention_mask'] = torch.where(inp['input_ids']==0,0,1)
使用预训练的权重初始化模型并进行推理。你将看到模型生成的随机tokens没有上下文意义。
config = BertConfig.from_pretrained('bert-base-cased')
model = BertForPreTraining.from_pretrained('./acc_valid/checkpoint-19600/')
model.eval()
out = model(inp['input_ids'], inp['attention_mask'], labels=inp['labels'])print('Input: ', tokenizer.decode(inp['input_ids'][0][:30]), '\n')
print('Output: ', tokenizer.decode(torch.argmax(out[0], -1)[0][:30]))
输入和输出如下所示。该模型在一个非常小的数据集(3,000多句子)上进行了训练;你可以通过在更大的数据集上训练,例如`wikiText-103-raw-v1`的训练切分数据,来提高性能。
The author takes his own advice when it comes to writing : he [MASK] to ground his claims in clear, concrete examples. He shows specific examples of bad
The Churchill takes his own, when it comes to writing : he continued to ground his claims in clear, this examples. He shows is examples of bad
源代码存储在这个 GitHub 文件夹。
结论
我们所描述的预训练BERT基础模型的过程可以很容易地扩展到不同大小的BERT版本以及不同的数据集。我们使用Hugging Face Trainer和PyTorch后端在AMD GPU上训练了我们的模型。对于训练,我们使用了`wikiText-103-raw-v1`数据集的验证集,但这可以很容易地替换为训练集,只需下载我们在Hugging Face Hub上的仓库中托管的预处理和标记化的训练文件Hugging Face Hub.
在本文中,我们通过MLM和NSP预训练任务复制了BERT的预训练过程,这与许多公共平台上仅使用MLM的方法不同。此外,我们没有使用数据集的小部分,而是预处理并上传了整个数据集到Hub上供您方便使用。在未来的文章中,我们将讨论在多个AMD GPU上使用数据并行和分布式策略来训练各种机器学习应用。
相关文章:
预训练 BERT 使用 Hugging Face 和 PyTorch 在 AMD GPU 上
Pre-training BERT using Hugging Face & PyTorch on an AMD GPU — ROCm Blogs 2024年1月26日,作者:Vara Lakshmi Bayanagari. 这篇博客解释了如何从头开始使用 Hugging Face 库和 PyTorch 后端在 AMD GPU 上为英文语料(WikiText-103-raw-v1)预训练…...
鸿蒙是必经之路
少了大嘴的发布会,老实讲有点让人昏昏入睡。关于技术本身的东西,放在后面。 我想想来加把油~ 鸿蒙发布后褒贬不一,其中很多人不太看好鸿蒙,一方面是开源性、一方面是南向北向的利益问题。 不说技术的领先点,我只扯扯…...
Java项目实战II基于微信小程序的马拉松报名系统(开发文档+数据库+源码)
目录 一、前言 二、技术介绍 三、系统实现 四、文档参考 五、核心代码 六、源码获取 全栈码农以及毕业设计实战开发,CSDN平台Java领域新星创作者,专注于大学生项目实战开发、讲解和毕业答疑辅导。获取源码联系方式请查看文末 一、前言 马拉松运动…...
家用wifi的ip地址固定吗?换wifi就是换ip地址吗
在探讨家用WiFi的IP地址是否固定,以及换WiFi是否就意味着换IP地址这两个问题时,我们首先需要明确几个关键概念:IP地址、家用WiFi网络、以及它们之间的相互作用。 一、家用WiFi的IP地址固定性 家用WiFi环境中的IP地址通常涉及两类:…...
codeforces _ 补题
C. Ball in Berland 传送门:Problem - C - Codeforces 题意: 思路:容斥原理 考虑 第 i 对情侣组合 ,男生为 a ,女生为 b ,那么考虑与之匹配的情侣 必须没有 a | b ,一共有 k 对情侣&#x…...
DataSophon集成ApacheImpala的过程
注意: 本次安装操作系统环境为Anolis8.9(Centos7和Centos8应该也一样) DataSophon版本为DDP-1.2.1 整合的安装包我放网盘了: 通过网盘分享的文件:impala-4.4.1.tar.gz等2个文件 链接: https://pan.baidu.com/s/18KfkO_BEFa5gVcc16I-Yew?pwdza4k 提取码: za4k 1…...
深入探讨TCP/IP协议基础
在当今数字化的时代,计算机网络已经成为人们生活和工作中不可或缺的一部分。而 TCP/IP 协议作为计算机网络的核心协议,更是支撑着全球互联网的运行。本文将深入探讨常见的 TCP/IP 协议基础,带你了解计算机网络的奥秘。 一、计算机网络概述 计…...
《Windows PE》7.4 资源表应用
本节我们将通过两个示例程序,演示对PE文件内图标资源的置换与提取。 本节必须掌握的知识点: 更改图标 提取图标资源 7.4.1 更改图标 让我们来做一个实验,替换PE文件中现有的图标。如果手工替换,一定是先找到资源表,…...
【重生之我要苦学C语言】猜数字游戏和关机程序的整合
今天来把学过的猜数字游戏和关机程序来整合一下 如果有不明白的可以看往期的博客 废话不多说,上代码: #define _CRT_SECURE_NO_WARNINGS #include <stdio.h> #include <time.h> #include <stdlib.h> #include <string.h> void…...
基于centos7脚本一键部署gpmall商城
基于centos7脚本一键部署单节点gpmall商城,该商城可单节点,可集群,可高可用集群部署,VMware17,虚拟机IP:192.168.200.100 将软件包解压到/root目录 [rootlocalhost ~]# ls dist …...
Mac book英特尔系列?M系列?两者有什么区别呢
众所周知,Mac book有M系列,搭载的是苹果自研的M芯片,也有着英特尔系列,搭载的是英特尔的处理器,虽然从 2020 年开始,苹果公司逐步推出了自家研发的 M 系列芯片,并逐渐将 MacBook 产品线过渡到 M…...
Python unstructured库详解:partition_pdf函数完整参数深度解析
Python unstructured库详解:partition_pdf函数完整参数深度解析 1. 简介2. 基础文件处理参数2.1 文件输入参数2.2 页面处理参数 3. 文档解析策略3.1 strategy参数详解3.2 策略选择建议 4. 表格处理参数4.1 表格结构推断 5. 语言处理参数5.1 语言设置 6. 图像处理参数…...
<项目代码>YOLOv8路面病害识别<目标检测>
YOLOv8是一种单阶段(one-stage)检测算法,它将目标检测问题转化为一个回归问题,能够在一次前向传播过程中同时完成目标的分类和定位任务。相较于两阶段检测算法(如Faster R-CNN),YOLOv8具有更高的…...
广告牌和标签学习
效果: 知识学习: entities添加标签label和广告牌billboard label: text:文本添加 font:字体大小和字体类型 fillColor:字体颜色 outlineColor:字体外轮廓颜色 outlineWidth:字体外轮…...
GDB 从裸奔到穿戴整齐
无数次被问道:你在终端下怎么调试更高效?或者怎么在 Vim 里调试?好吧,今天统一回答下,我从来不在 vim 里调试,因为它还不成熟。那除了命令行 GDB 裸奔以外,终端下还有没有更高效的方法ÿ…...
WPF的触发器(Trigger)
WPF(Windows Presentation Foundation)是微软.NET框架的一部分,用于构建Windows客户端应用程序。在WPF中,触发器(Triggers)是一种强大的功能,允许开发者根据控件的状态或属性值来动态改变控件的…...
全能大模型GPT-4o体验和接入教程
GPT-4o体验和接入教程 前言一、原生API二、Python LangchainSpring AI总结 前言 Open AI发布了产品GPT-4o,o表示"omni",全能的意思。 GPT-4o可以实时对音频、视觉和文本进行推理,响应时间平均为 320 毫秒,和人类之间对…...
详解Apache版本、新功能和技术前景
文章目录 一、 版本溯源二、新功能和特性举例1. 模块化和可扩展性增强2. 多处理模块(MPMs)3. 异步支持4. 更细粒度的日志级别控制5. 通用表达式解析器6. HTTP/2支持7. Server Push8. Early Hints9. 更好的SSL/TLS支持10. 更安全的默认设置 三、 技术前景…...
Docker Redis集群3主3从模式
主从集群 docker run -d --name redis-node1 --net host --privilegedtrue -v /home/redis/node1:/data redis:7.0 --cluster-enabled yes --appendonly yes --port 9371docker run -d --name redis-node2 --net host --privilegedtrue -v /home/redis/node2:/data redis:7.0 …...
【Go语言】
type关键字的用法 定义结构体定义接口定义类型别名类型定义类型判断 别名实际上是为了更好地理解代码/ 这里要分点进行记录 使用传值的例子,当两个类型不一样需要进行类型转换 type Myint int // 自定义类型,基于已有的类型自定义一个类型type Myin…...
Prompt Tuning、P-Tuning、Prefix Tuning的区别
一、Prompt Tuning、P-Tuning、Prefix Tuning的区别 1. Prompt Tuning(提示调优) 核心思想:固定预训练模型参数,仅学习额外的连续提示向量(通常是嵌入层的一部分)。实现方式:在输入文本前添加可训练的连续向量(软提示),模型只更新这些提示参数。优势:参数量少(仅提…...
STM32F4基本定时器使用和原理详解
STM32F4基本定时器使用和原理详解 前言如何确定定时器挂载在哪条时钟线上配置及使用方法参数配置PrescalerCounter ModeCounter Periodauto-reload preloadTrigger Event Selection 中断配置生成的代码及使用方法初始化代码基本定时器触发DCA或者ADC的代码讲解中断代码定时启动…...
macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用
文章目录 问题现象问题原因解决办法 问题现象 macOS启动台(Launchpad)多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显,都是Google家的办公全家桶。这些应用并不是通过独立安装的…...
华为OD机试-食堂供餐-二分法
import java.util.Arrays; import java.util.Scanner;public class DemoTest3 {public static void main(String[] args) {Scanner in new Scanner(System.in);// 注意 hasNext 和 hasNextLine 的区别while (in.hasNextLine()) { // 注意 while 处理多个 caseint a in.nextIn…...
pikachu靶场通关笔记22-1 SQL注入05-1-insert注入(报错法)
目录 一、SQL注入 二、insert注入 三、报错型注入 四、updatexml函数 五、源码审计 六、insert渗透实战 1、渗透准备 2、获取数据库名database 3、获取表名table 4、获取列名column 5、获取字段 本系列为通过《pikachu靶场通关笔记》的SQL注入关卡(共10关࿰…...
全面解析各类VPN技术:GRE、IPsec、L2TP、SSL与MPLS VPN对比
目录 引言 VPN技术概述 GRE VPN 3.1 GRE封装结构 3.2 GRE的应用场景 GRE over IPsec 4.1 GRE over IPsec封装结构 4.2 为什么使用GRE over IPsec? IPsec VPN 5.1 IPsec传输模式(Transport Mode) 5.2 IPsec隧道模式(Tunne…...
基于matlab策略迭代和值迭代法的动态规划
经典的基于策略迭代和值迭代法的动态规划matlab代码,实现机器人的最优运输 Dynamic-Programming-master/Environment.pdf , 104724 Dynamic-Programming-master/README.md , 506 Dynamic-Programming-master/generalizedPolicyIteration.m , 1970 Dynamic-Programm…...
【无标题】路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论
路径问题的革命性重构:基于二维拓扑收缩色动力学模型的零点隧穿理论 一、传统路径模型的根本缺陷 在经典正方形路径问题中(图1): mermaid graph LR A((A)) --- B((B)) B --- C((C)) C --- D((D)) D --- A A -.- C[无直接路径] B -…...
DiscuzX3.5发帖json api
参考文章:PHP实现独立Discuz站外发帖(直连操作数据库)_discuz 发帖api-CSDN博客 简单改造了一下,适配我自己的需求 有一个站点存在多个采集站,我想通过主站拿标题,采集站拿内容 使用到的sql如下 CREATE TABLE pre_forum_post_…...
绕过 Xcode?使用 Appuploader和主流工具实现 iOS 上架自动化
iOS 应用的发布流程一直是开发链路中最“苹果味”的环节:强依赖 Xcode、必须使用 macOS、各种证书和描述文件配置……对很多跨平台开发者来说,这一套流程并不友好。 特别是当你的项目主要在 Windows 或 Linux 下开发(例如 Flutter、React Na…...
