当前位置: 首页 > news >正文

使用 DPO 微调 Llama 2

简介

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

本文介绍了直接偏好优化 (Direct Preference Optimization,DPO) 法,该方法现已集成至 TRL 库 中。同时,我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型, stack-exchange preference 数据集中包含了各个 stack-exchange 门户上的各种问题及其排序后的回答。

DPO 与 PPO

在通过 RL 优化人类衍生偏好时,一直以来的传统做法是使用一个辅助奖励模型来微调目标模型,以通过 RL 机制最大化目标模型所能获得的奖励。直观上,我们使用奖励模型向待优化模型提供反馈,以促使它多生成高奖励输出,少生成低奖励输出。同时,我们使用冻结的参考模型来确保输出偏差不会太大,且继续保持输出的多样性。这通常需要在目标函数设计时,除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项,这样做有助于防止模型学习作弊或钻营奖励模型。

DPO 绕过了建模奖励函数这一步,这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它,作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失,从而直接在偏好数据上优化语言模型!因此,DPO 从寻找最小化 RLHF 损失的最佳方案开始,通过改变参量的方式推导出一个 仅需 参考模型的损失!

有了它,我们可以直接优化该似然目标,而不需要奖励模型或繁琐的强化学习优化过程。

如何使用 TRL 进行训练

如前所述,一个典型的 RLHF 流水线通常包含以下几个环节:

  1. 有监督微调 (supervised fine-tuning,SFT)

  2. 用偏好标签标注数据

  3. 基于偏好数据训练奖励模型

  4. RL 优化

TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4),直接根据标注好的偏好数据优化 DPO 目标。

使用 DPO,我们仍然需要执行环节 1,但我们仅需在 TRL 中向 DPOTrainer 提供环节 2 准备好的偏好数据,而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式,它是一个含有以下 3 个键的字典:

  • prompt : 即推理时输入给模型的提示

  • chosen : 即针对给定提示的较优回答

  • rejected :  即针对给定提示的较劣回答或非给定提示的回答

例如,对于 stack-exchange preference 数据集,我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:return {"prompt": ["Question: " + question + "\n\nAnswer: "for question in samples["question"]],"chosen": samples["response_j"], # rated better than k"rejected": samples["response_k"], # rated worse than j}dataset = load_dataset("lvwerra/stack-exchange-paired",split="train",data_dir="data/rl"
)
original_columns = dataset.column_namesdataset.map(return_prompt_and_responses,batched=True,remove_columns=original_columns
)

一旦有了排序数据集,DPO 损失其实本质上就是一种有监督损失,其经由参考模型获得隐式奖励。因此,从上层来看,DPOTrainer 需要我们输入待优化的基础模型以及参考模型:

dpo_trainer = DPOTrainer(model, # 经 SFT 的基础模型model_ref, # 一般为经 SFT 的基础模型的一个拷贝beta=0.1, # DPO 的温度超参train_dataset=dataset, # 上文准备好的数据集tokenizer=tokenizer, # 分词器args=training_args, # 训练参数,如: batch size, 学习率等
)

其中,超参 beta 是 DPO 损失的温度,通常在 0.10.5 之间。它控制了我们对参考模型的关注程度,beta 越小,我们就越忽略参考模型。对训练器初始化后,我们就可以简单调用以下方法,使用给定的 training_args 在给定数据集上进行训练了:

dpo_trainer.train()

基于 Llama v2 进行实验

在 TRL 中实现 DPO 训练器的好处是,人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库,我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。

有监督微调

如上文所述,我们先用 TRL 的 SFTTrainer 在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,
)base_model = AutoModelForCausalLM.from_pretrained(script_args.model_name, # "meta-llama/Llama-2-7b-hf"quantization_config=bnb_config,device_map={"": 0},trust_remote_code=True,use_auth_token=True,
)
base_model.config.use_cache = False# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(r=script_args.lora_r,lora_alpha=script_args.lora_alpha,lora_dropout=script_args.lora_dropout,target_modules=["q_proj", "v_proj"],bias="none",task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(model=base_model,train_dataset=train_dataset,eval_dataset=eval_dataset,peft_config=peft_config,packing=True,max_seq_length=None,tokenizer=tokenizer,args=training_args, # HF Trainer arguments
)
trainer.train()

DPO 训练

SFT 结束后,我们保存好生成的模型。接着,我们继续进行 DPO 训练,我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型,并在上文生成的 stack-exchange preference 数据上,以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调,因此我们使用 Peft 的 AutoPeftModelForCausalLM 函数加载模型:

model = AutoPeftModelForCausalLM.from_pretrained(script_args.model_name_or_path, # location of saved SFT modellow_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(script_args.model_name_or_path, # same model as the main onelow_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(model,model_ref,args=training_args,beta=script_args.beta,train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

可以看出,我们以 4 比特的方式加载模型,然后通过 peft_config 参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度,并报告一些关键指标,例如可以选择通过 WandB 记录并显示隐式奖励。最后,我们可以将训练好的模型推送到 HuggingFace Hub。

总结

SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到,训好的已合并模型也已上传至 HF Hub (见 此处)。

你可以在 这儿 找到我们的模型在训练过程的 WandB 日志,其中包含了 DPOTrainer 在训练和评估期间记录下来的以下奖励指标:

  • rewards/chosen (较优回答的奖励) : 针对较优回答,策略模型与参考模型的对数概率二者之差的均值,按 beta 缩放。

  • rewards/rejected (较劣回答的奖励) : 针对较劣回答,策略模型与参考模型的对数概率二者之差的均值,按 beta 缩放。

  • rewards/accuracy (奖励准确率) : 较优回答的奖励大于相应较劣回答的奖励的频率的均值

  • rewards/margins (奖励余裕值) : 较优回答的奖励与相应较劣回答的奖励二者之差的均值。

直观上讲,在训练过程中,我们希望余裕值增加并且准确率达到 1.0,换句话说,较优回答的奖励高于较劣回答的奖励 (或余裕值大于零)。随后,我们还可以在评估数据集上计算这些指标。

我们希望我们代码的发布可以降低读者的入门门槛,让大家可以在自己的数据集上尝试这种大语言模型对齐方法,我们迫不及待地想看到你会用它做哪些事情!如果你想试试我们训练出来的模型,可以玩玩这个 space: trl-lib/stack-llama。

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/dpo-trl

原文作者: Kashif Rasul, Younes Belkada, Leandro von Werra

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

审校/排版: zhongdongy (阿东)

相关文章:

使用 DPO 微调 Llama 2

简介 基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关…...

数据库——事务,事务隔离级别

文章目录 什么是事务?事务的特性(ACID)并发事务带来的问题事务隔离级别实际情况演示脏读(读未提交)避免脏读(读已提交)不可重复读可重复读防止幻读(可串行化) 什么是事务? 事务是逻辑上的一组操作,要么都执行,要么都不执行。 事务最经典也经常被拿出…...

对《VB.NET通过VB6 ActiveX DLL调用PowerBasic及FreeBasic动态库》的改进

《VB.NET通过VB6 ActiveX DLL调用PowerBasic及FreeBasic动态库》使用的Activex DLL公共对象是需要先注册的。https://blog.csdn.net/weixin_45707491/article/details/132437502?spm1001.2014.3001.5501 Activex DLL事前注册,一次多用说起来也不是啥大问题&#x…...

【PHP】数据类型运算符位运算

文章目录 数据类型简单(基本)数据类型:4个小类复合数据类型:2个小类特殊数据类型:2个小类类型转换类型判断整数类型浮点类型布尔类型 运算符赋值运算符算术运算符比较运算符逻辑运算符连接运算符错误抑制符三目运算符自…...

使用 Nacos 作为 Spring Boot 配置中心

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…...

微服务 Eureka

Eureka Eureka是Netflix开源的一个用于构建基于微服务架构的服务发现和注册中心技术。在微服务架构中,系统被拆分成多个小型、自治的服务,每个服务负责特定的业务功能。这些服务需要能够相互发现和通信,这就是Eureka所提供的功能。 Eureka主…...

Spring Boot 事务和事务传播机制

1. 为什么需要事务? 事务定义 将一组操作封装成一个执行单元 (封装到一起),这一组的执行具备原子性, 那么就要么全部成功,要么全部失败. 为什么要用事务? 比如转账分为两个操作: 第一步操作:A 账户-100 元。 第二步操作:B账户 100 元。 如果没有事务&a…...

计算机组成原理(巨巨巨基础篇)

有关《计算机组成原理》课本中有关 内存计算换算(字,位,字节) 个人理解 前面知识点搭建框架,最后两道例题是直观理解体会 主存储器的基本概念 位:存储信息的最小单位,称为存储位或存储元。 背…...

C语言:选择+编程(每日一练Day7)

目录 选择题: 题一: 题二: 题三: 题四: 题五: 编程题: 题一:图片整理 思路一: 思路二: 题二:寻找数组的中心下标 思路一&#xff1…...

leetcode做题笔记93. 复原 IP 地址

有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址,但是 "0.011.255.2…...

HTTPS 中间人攻击

HTTPS 中间人攻击 中间人攻击过程 通讯过程 客户端——中间人——服务器 过程如下 服务器向客户端发送公钥攻击者截获公钥,保留在自己手上然后攻击者自己生成一个【伪造的】公钥,发给客户端客户端收到【伪造的】公钥后,利用【伪造的】公…...

MATLAB打开excel读取写入操作例程

本文使用素材含代码测试用例等 MATLAB读写excel文件历程含,内含有测试代码资源-CSDN文库 打开文件 使用uigetfile函数过滤非xlsx文件,找到需要读取的文件,首先判断文件是否存在,如果文件不存在,程序直接返回&#x…...

[C语言]分支与循环

导言: 在人生中我们总会有选择,**如下一顿吃啥?**又或者每天都是在重复,吃饭!!!!,当然在C语言中也有选择和重复那就是分支语句与循环语句 文章目录 分支循环循环中的关键…...

绘制区块链之链:解码去中心化、安全性和透明性的奇迹

区块链技术以其去中心化、安全性和透明性等特点在全球范围内引起了广泛的关注和兴趣。区块链是一种分布式账本技术,通过将数据以不可篡改的方式链接在一起,创建了一个安全可靠的数据库。这种革命性的技术正在许多领域中发挥作用,包括加密货币…...

4G工业路由器的功能与选型!详解工作原理、关键参数、典型品牌

随着工业互联网的发展,4G工业路由器得到越来越广泛的应用。但是如何根据实际需求选择合适的4G工业路由器,是许多用户关心的问题。为此,本文将深入剖析4G工业路由器的工作原理、重要参数及选型要点,并推荐优质的品牌及产品,以提供选型参考。 一、4G工业路由器的工作原理 4G工业…...

c与c++中struct的主要区别和c++中的struct与class的主要区别

1、c和c中struct的主要区别 c中的struct不可以含有成员函数,而c中的struct可以。 C语言 c中struct 是一种用于组合多个不同数据类型的数据成员的方式。struct 声明中的成员默认是公共的,并且不支持成员函数、访问控制和继承等概念。C中的struct通常被用…...

mysql中char_length()和length()

MySQL中计算字符串长度有两个函数分别为char_length和length。 char_length char_length函数可以计算unicode字符,包括中文等字符集的长度 char_length(‘string’)/char_length(column_name) 1、返回值为字符串string或者对应字段长度,长度的单位为字…...

Numpy学习笔记

科学计算库(Numpy) 通常数据都能转换成矩阵,行就是每一条样本数据,列就是每个字段的特征,Numpy在矩阵运算上非常高效,可以快速处理数据并进行数据计算。 Numpy基本操作 先导入 import numpy as nparray…...

LAMP配置与应用

目录 一、LAMP架构的组成 1、WEB资源类型 2、LAMP架构的组成 二、编译安装LAMP 编译安装apache 1、环境准备 2、导入apache相关压缩安装包,然后安装编译环境 3、解压软件包,并移动apr包与apr-util包到安装目录中,并切换到http解压出…...

Dockerfile搭建LNMP运行Wordpress平台

Dockerfile搭建LNMP运行Wordpress平台 一、项目1.1 项目环境1.2 服务器环境1.3 任务需求 二、Linux 系统基础镜像三、Nginx1、建立工作目录2、编写 Dockerfile 脚本3、准备 nginx.conf 配置文件4、生成镜像5、创建自定义网络6、启动镜像容器7、验证 nginx 四、Mysql1、建立工作…...

MPNet:旋转机械轻量化故障诊断模型详解python代码复现

目录 一、问题背景与挑战 二、MPNet核心架构 2.1 多分支特征融合模块(MBFM) 2.2 残差注意力金字塔模块(RAPM) 2.2.1 空间金字塔注意力(SPA) 2.2.2 金字塔残差块(PRBlock) 2.3 分类器设计 三、关键技术突破 3.1 多尺度特征融合 3.2 轻量化设计策略 3.3 抗噪声…...

conda相比python好处

Conda 作为 Python 的环境和包管理工具,相比原生 Python 生态(如 pip 虚拟环境)有许多独特优势,尤其在多项目管理、依赖处理和跨平台兼容性等方面表现更优。以下是 Conda 的核心好处: 一、一站式环境管理&#xff1a…...

Vue记事本应用实现教程

文章目录 1. 项目介绍2. 开发环境准备3. 设计应用界面4. 创建Vue实例和数据模型5. 实现记事本功能5.1 添加新记事项5.2 删除记事项5.3 清空所有记事 6. 添加样式7. 功能扩展:显示创建时间8. 功能扩展:记事项搜索9. 完整代码10. Vue知识点解析10.1 数据绑…...

论文解读:交大港大上海AI Lab开源论文 | 宇树机器人多姿态起立控制强化学习框架(二)

HoST框架核心实现方法详解 - 论文深度解读(第二部分) 《Learning Humanoid Standing-up Control across Diverse Postures》 系列文章: 论文深度解读 + 算法与代码分析(二) 作者机构: 上海AI Lab, 上海交通大学, 香港大学, 浙江大学, 香港中文大学 论文主题: 人形机器人…...

Docker 运行 Kafka 带 SASL 认证教程

Docker 运行 Kafka 带 SASL 认证教程 Docker 运行 Kafka 带 SASL 认证教程一、说明二、环境准备三、编写 Docker Compose 和 jaas文件docker-compose.yml代码说明:server_jaas.conf 四、启动服务五、验证服务六、连接kafka服务七、总结 Docker 运行 Kafka 带 SASL 认…...

【论文笔记】若干矿井粉尘检测算法概述

总的来说,传统机器学习、传统机器学习与深度学习的结合、LSTM等算法所需要的数据集来源于矿井传感器测量的粉尘浓度,通过建立回归模型来预测未来矿井的粉尘浓度。传统机器学习算法性能易受数据中极端值的影响。YOLO等计算机视觉算法所需要的数据集来源于…...

RNN避坑指南:从数学推导到LSTM/GRU工业级部署实战流程

本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文全面剖析RNN核心原理,深入讲解梯度消失/爆炸问题,并通过LSTM/GRU结构实现解决方案,提供时间序列预测和文本生成…...

JVM虚拟机:内存结构、垃圾回收、性能优化

1、JVM虚拟机的简介 Java 虚拟机(Java Virtual Machine 简称:JVM)是运行所有 Java 程序的抽象计算机,是 Java 语言的运行环境,实现了 Java 程序的跨平台特性。JVM 屏蔽了与具体操作系统平台相关的信息,使得 Java 程序只需生成在 JVM 上运行的目标代码(字节码),就可以…...

iOS性能调优实战:借助克魔(KeyMob)与常用工具深度洞察App瓶颈

在日常iOS开发过程中,性能问题往往是最令人头疼的一类Bug。尤其是在App上线前的压测阶段或是处理用户反馈的高发期,开发者往往需要面对卡顿、崩溃、能耗异常、日志混乱等一系列问题。这些问题表面上看似偶发,但背后往往隐藏着系统资源调度不当…...

算法:模拟

1.替换所有的问号 1576. 替换所有的问号 - 力扣(LeetCode) ​遍历字符串​:通过外层循环逐一检查每个字符。​遇到 ? 时处理​: 内层循环遍历小写字母(a 到 z)。对每个字母检查是否满足: ​与…...