当前位置: 首页 > news >正文

使用 DPO 微调 Llama 2

简介

基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关的复杂性: 既要构建一个好的奖励函数,并训练一个模型用以估计每个状态的价值 (value); 又要注意最终生成的 LLM 不能与原始模型相差太远,如果太远的话会使得模型容易产生乱码而非有意义的文本。该过程非常复杂,涉及到许多复杂的组件,而这些组件本身在训练过程中又是动态变化的,因此把它们料理好并不容易。

Rafailov、Sharma、Mitchell 等人最近发表了一篇论文 Direct Preference Optimization,论文提出将现有方法使用的基于强化学习的目标转换为可以通过简单的二元交叉熵损失直接优化的目标,这一做法大大简化了 LLM 的提纯过程。

本文介绍了直接偏好优化 (Direct Preference Optimization,DPO) 法,该方法现已集成至 TRL 库 中。同时,我们还展示了如何在 stack-exchange preference 数据集上微调最新的 Llama v2 7B 模型, stack-exchange preference 数据集中包含了各个 stack-exchange 门户上的各种问题及其排序后的回答。

DPO 与 PPO

在通过 RL 优化人类衍生偏好时,一直以来的传统做法是使用一个辅助奖励模型来微调目标模型,以通过 RL 机制最大化目标模型所能获得的奖励。直观上,我们使用奖励模型向待优化模型提供反馈,以促使它多生成高奖励输出,少生成低奖励输出。同时,我们使用冻结的参考模型来确保输出偏差不会太大,且继续保持输出的多样性。这通常需要在目标函数设计时,除了奖励最大化目标外再添加一个相对于参考模型的 KL 惩罚项,这样做有助于防止模型学习作弊或钻营奖励模型。

DPO 绕过了建模奖励函数这一步,这源于一个关键洞见: 从奖励函数到最优 RL 策略的分析映射。这个映射直观地度量了给定奖励函数与给定偏好数据的匹配程度。有了它,作者就可与将基于奖励和参考模型的 RL 损失直接转换为仅基于参考模型的损失,从而直接在偏好数据上优化语言模型!因此,DPO 从寻找最小化 RLHF 损失的最佳方案开始,通过改变参量的方式推导出一个 仅需 参考模型的损失!

有了它,我们可以直接优化该似然目标,而不需要奖励模型或繁琐的强化学习优化过程。

如何使用 TRL 进行训练

如前所述,一个典型的 RLHF 流水线通常包含以下几个环节:

  1. 有监督微调 (supervised fine-tuning,SFT)

  2. 用偏好标签标注数据

  3. 基于偏好数据训练奖励模型

  4. RL 优化

TRL 库包含了所有这些环节所需的工具程序。而 DPO 训练直接消灭了奖励建模和 RL 这两个环节 (环节 3 和 4),直接根据标注好的偏好数据优化 DPO 目标。

使用 DPO,我们仍然需要执行环节 1,但我们仅需在 TRL 中向 DPOTrainer 提供环节 2 准备好的偏好数据,而不再需要环节 3 和 4。标注好的偏好数据需要遵循特定的格式,它是一个含有以下 3 个键的字典:

  • prompt : 即推理时输入给模型的提示

  • chosen : 即针对给定提示的较优回答

  • rejected :  即针对给定提示的较劣回答或非给定提示的回答

例如,对于 stack-exchange preference 数据集,我们可以通过以下工具函数将数据集中的样本映射至上述字典格式并删除所有原始列:

def return_prompt_and_responses(samples) -> Dict[str, str, str]:return {"prompt": ["Question: " + question + "\n\nAnswer: "for question in samples["question"]],"chosen": samples["response_j"], # rated better than k"rejected": samples["response_k"], # rated worse than j}dataset = load_dataset("lvwerra/stack-exchange-paired",split="train",data_dir="data/rl"
)
original_columns = dataset.column_namesdataset.map(return_prompt_and_responses,batched=True,remove_columns=original_columns
)

一旦有了排序数据集,DPO 损失其实本质上就是一种有监督损失,其经由参考模型获得隐式奖励。因此,从上层来看,DPOTrainer 需要我们输入待优化的基础模型以及参考模型:

dpo_trainer = DPOTrainer(model, # 经 SFT 的基础模型model_ref, # 一般为经 SFT 的基础模型的一个拷贝beta=0.1, # DPO 的温度超参train_dataset=dataset, # 上文准备好的数据集tokenizer=tokenizer, # 分词器args=training_args, # 训练参数,如: batch size, 学习率等
)

其中,超参 beta 是 DPO 损失的温度,通常在 0.10.5 之间。它控制了我们对参考模型的关注程度,beta 越小,我们就越忽略参考模型。对训练器初始化后,我们就可以简单调用以下方法,使用给定的 training_args 在给定数据集上进行训练了:

dpo_trainer.train()

基于 Llama v2 进行实验

在 TRL 中实现 DPO 训练器的好处是,人们可以利用 TRL 及其依赖库 (如 Peft 和 Accelerate) 中已有的 LLM 相关功能。有了这些库,我们甚至可以使用 bitsandbytes 库提供的 QLoRA 技术 来训练 Llama v2 模型。

有监督微调

如上文所述,我们先用 TRL 的 SFTTrainer 在 SFT 数据子集上使用 QLoRA 对 7B Llama v2 模型进行有监督微调:

# load the base model in 4-bit quantization
bnb_config = BitsAndBytesConfig(load_in_4bit=True,bnb_4bit_quant_type="nf4",bnb_4bit_compute_dtype=torch.bfloat16,
)base_model = AutoModelForCausalLM.from_pretrained(script_args.model_name, # "meta-llama/Llama-2-7b-hf"quantization_config=bnb_config,device_map={"": 0},trust_remote_code=True,use_auth_token=True,
)
base_model.config.use_cache = False# add LoRA layers on top of the quantized base model
peft_config = LoraConfig(r=script_args.lora_r,lora_alpha=script_args.lora_alpha,lora_dropout=script_args.lora_dropout,target_modules=["q_proj", "v_proj"],bias="none",task_type="CAUSAL_LM",
)
...
trainer = SFTTrainer(model=base_model,train_dataset=train_dataset,eval_dataset=eval_dataset,peft_config=peft_config,packing=True,max_seq_length=None,tokenizer=tokenizer,args=training_args, # HF Trainer arguments
)
trainer.train()

DPO 训练

SFT 结束后,我们保存好生成的模型。接着,我们继续进行 DPO 训练,我们把 SFT 生成的模型作为 DPO 的基础模型和参考模型,并在上文生成的 stack-exchange preference 数据上,以 DPO 为目标函数训练模型。我们选择对模型进行 LoRa 微调,因此我们使用 Peft 的 AutoPeftModelForCausalLM 函数加载模型:

model = AutoPeftModelForCausalLM.from_pretrained(script_args.model_name_or_path, # location of saved SFT modellow_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,is_trainable=True,
)
model_ref = AutoPeftModelForCausalLM.from_pretrained(script_args.model_name_or_path, # same model as the main onelow_cpu_mem_usage=True,torch_dtype=torch.float16,load_in_4bit=True,
)
...
dpo_trainer = DPOTrainer(model,model_ref,args=training_args,beta=script_args.beta,train_dataset=train_dataset,eval_dataset=eval_dataset,tokenizer=tokenizer,peft_config=peft_config,
)
dpo_trainer.train()
dpo_trainer.save_model()

可以看出,我们以 4 比特的方式加载模型,然后通过 peft_config 参数选择 QLora 方法对其进行训练。训练器还会用评估数据集评估训练进度,并报告一些关键指标,例如可以选择通过 WandB 记录并显示隐式奖励。最后,我们可以将训练好的模型推送到 HuggingFace Hub。

总结

SFT 和 DPO 训练脚本的完整源代码可在该目录 examples/stack_llama_2 处找到,训好的已合并模型也已上传至 HF Hub (见 此处)。

你可以在 这儿 找到我们的模型在训练过程的 WandB 日志,其中包含了 DPOTrainer 在训练和评估期间记录下来的以下奖励指标:

  • rewards/chosen (较优回答的奖励) : 针对较优回答,策略模型与参考模型的对数概率二者之差的均值,按 beta 缩放。

  • rewards/rejected (较劣回答的奖励) : 针对较劣回答,策略模型与参考模型的对数概率二者之差的均值,按 beta 缩放。

  • rewards/accuracy (奖励准确率) : 较优回答的奖励大于相应较劣回答的奖励的频率的均值

  • rewards/margins (奖励余裕值) : 较优回答的奖励与相应较劣回答的奖励二者之差的均值。

直观上讲,在训练过程中,我们希望余裕值增加并且准确率达到 1.0,换句话说,较优回答的奖励高于较劣回答的奖励 (或余裕值大于零)。随后,我们还可以在评估数据集上计算这些指标。

我们希望我们代码的发布可以降低读者的入门门槛,让大家可以在自己的数据集上尝试这种大语言模型对齐方法,我们迫不及待地想看到你会用它做哪些事情!如果你想试试我们训练出来的模型,可以玩玩这个 space: trl-lib/stack-llama。

🤗 宝子们可以戳 阅读原文 查看文中所有的外部链接哟!


英文原文: https://hf.co/blog/dpo-trl

原文作者: Kashif Rasul, Younes Belkada, Leandro von Werra

译者: Matrix Yao (姚伟峰),英特尔深度学习工程师,工作方向为 transformer-family 模型在各模态数据上的应用及大规模模型的训练推理。

审校/排版: zhongdongy (阿东)

相关文章:

使用 DPO 微调 Llama 2

简介 基于人类反馈的强化学习 (Reinforcement Learning from Human Feedback,RLHF) 事实上已成为 GPT-4 或 Claude 等 LLM 训练的最后一步,它可以确保语言模型的输出符合人类在闲聊或安全性等方面的期望。然而,它也给 NLP 引入了一些 RL 相关…...

数据库——事务,事务隔离级别

文章目录 什么是事务?事务的特性(ACID)并发事务带来的问题事务隔离级别实际情况演示脏读(读未提交)避免脏读(读已提交)不可重复读可重复读防止幻读(可串行化) 什么是事务? 事务是逻辑上的一组操作,要么都执行,要么都不执行。 事务最经典也经常被拿出…...

对《VB.NET通过VB6 ActiveX DLL调用PowerBasic及FreeBasic动态库》的改进

《VB.NET通过VB6 ActiveX DLL调用PowerBasic及FreeBasic动态库》使用的Activex DLL公共对象是需要先注册的。https://blog.csdn.net/weixin_45707491/article/details/132437502?spm1001.2014.3001.5501 Activex DLL事前注册,一次多用说起来也不是啥大问题&#x…...

【PHP】数据类型运算符位运算

文章目录 数据类型简单(基本)数据类型:4个小类复合数据类型:2个小类特殊数据类型:2个小类类型转换类型判断整数类型浮点类型布尔类型 运算符赋值运算符算术运算符比较运算符逻辑运算符连接运算符错误抑制符三目运算符自…...

使用 Nacos 作为 Spring Boot 配置中心

🌷🍁 博主猫头虎 带您 Go to New World.✨🍁 🦄 博客首页——猫头虎的博客🎐 🐳《面试题大全专栏》 文章图文并茂🦕生动形象🦖简单易学!欢迎大家来踩踩~🌺 &a…...

微服务 Eureka

Eureka Eureka是Netflix开源的一个用于构建基于微服务架构的服务发现和注册中心技术。在微服务架构中,系统被拆分成多个小型、自治的服务,每个服务负责特定的业务功能。这些服务需要能够相互发现和通信,这就是Eureka所提供的功能。 Eureka主…...

Spring Boot 事务和事务传播机制

1. 为什么需要事务? 事务定义 将一组操作封装成一个执行单元 (封装到一起),这一组的执行具备原子性, 那么就要么全部成功,要么全部失败. 为什么要用事务? 比如转账分为两个操作: 第一步操作:A 账户-100 元。 第二步操作:B账户 100 元。 如果没有事务&a…...

计算机组成原理(巨巨巨基础篇)

有关《计算机组成原理》课本中有关 内存计算换算(字,位,字节) 个人理解 前面知识点搭建框架,最后两道例题是直观理解体会 主存储器的基本概念 位:存储信息的最小单位,称为存储位或存储元。 背…...

C语言:选择+编程(每日一练Day7)

目录 选择题: 题一: 题二: 题三: 题四: 题五: 编程题: 题一:图片整理 思路一: 思路二: 题二:寻找数组的中心下标 思路一&#xff1…...

leetcode做题笔记93. 复原 IP 地址

有效 IP 地址 正好由四个整数(每个整数位于 0 到 255 之间组成,且不能含有前导 0),整数之间用 . 分隔。 例如:"0.1.2.201" 和 "192.168.1.1" 是 有效 IP 地址,但是 "0.011.255.2…...

HTTPS 中间人攻击

HTTPS 中间人攻击 中间人攻击过程 通讯过程 客户端——中间人——服务器 过程如下 服务器向客户端发送公钥攻击者截获公钥,保留在自己手上然后攻击者自己生成一个【伪造的】公钥,发给客户端客户端收到【伪造的】公钥后,利用【伪造的】公…...

MATLAB打开excel读取写入操作例程

本文使用素材含代码测试用例等 MATLAB读写excel文件历程含,内含有测试代码资源-CSDN文库 打开文件 使用uigetfile函数过滤非xlsx文件,找到需要读取的文件,首先判断文件是否存在,如果文件不存在,程序直接返回&#x…...

[C语言]分支与循环

导言: 在人生中我们总会有选择,**如下一顿吃啥?**又或者每天都是在重复,吃饭!!!!,当然在C语言中也有选择和重复那就是分支语句与循环语句 文章目录 分支循环循环中的关键…...

绘制区块链之链:解码去中心化、安全性和透明性的奇迹

区块链技术以其去中心化、安全性和透明性等特点在全球范围内引起了广泛的关注和兴趣。区块链是一种分布式账本技术,通过将数据以不可篡改的方式链接在一起,创建了一个安全可靠的数据库。这种革命性的技术正在许多领域中发挥作用,包括加密货币…...

4G工业路由器的功能与选型!详解工作原理、关键参数、典型品牌

随着工业互联网的发展,4G工业路由器得到越来越广泛的应用。但是如何根据实际需求选择合适的4G工业路由器,是许多用户关心的问题。为此,本文将深入剖析4G工业路由器的工作原理、重要参数及选型要点,并推荐优质的品牌及产品,以提供选型参考。 一、4G工业路由器的工作原理 4G工业…...

c与c++中struct的主要区别和c++中的struct与class的主要区别

1、c和c中struct的主要区别 c中的struct不可以含有成员函数,而c中的struct可以。 C语言 c中struct 是一种用于组合多个不同数据类型的数据成员的方式。struct 声明中的成员默认是公共的,并且不支持成员函数、访问控制和继承等概念。C中的struct通常被用…...

mysql中char_length()和length()

MySQL中计算字符串长度有两个函数分别为char_length和length。 char_length char_length函数可以计算unicode字符,包括中文等字符集的长度 char_length(‘string’)/char_length(column_name) 1、返回值为字符串string或者对应字段长度,长度的单位为字…...

Numpy学习笔记

科学计算库(Numpy) 通常数据都能转换成矩阵,行就是每一条样本数据,列就是每个字段的特征,Numpy在矩阵运算上非常高效,可以快速处理数据并进行数据计算。 Numpy基本操作 先导入 import numpy as nparray…...

LAMP配置与应用

目录 一、LAMP架构的组成 1、WEB资源类型 2、LAMP架构的组成 二、编译安装LAMP 编译安装apache 1、环境准备 2、导入apache相关压缩安装包,然后安装编译环境 3、解压软件包,并移动apr包与apr-util包到安装目录中,并切换到http解压出…...

Dockerfile搭建LNMP运行Wordpress平台

Dockerfile搭建LNMP运行Wordpress平台 一、项目1.1 项目环境1.2 服务器环境1.3 任务需求 二、Linux 系统基础镜像三、Nginx1、建立工作目录2、编写 Dockerfile 脚本3、准备 nginx.conf 配置文件4、生成镜像5、创建自定义网络6、启动镜像容器7、验证 nginx 四、Mysql1、建立工作…...

数据库第十五课-------------非关系型数据库----------Redis

作者前言 🎂 ✨✨✨✨✨✨🍧🍧🍧🍧🍧🍧🍧🎂 ​🎂 作者介绍: 🎂🎂 🎂 🎉🎉&#x1f389…...

BM2 链表内指定区间反转,为什么链表要new一个结点?

链表内指定区间反转_牛客题霸_牛客网 (nowcoder.com) 思路就是&#xff0c;把需要反转的结点放入栈中&#xff0c;然后在弹出来。 /*** struct ListNode {* int val;* struct ListNode *next;* ListNode(int x) : val(x), next(nullptr) {}* };*/#include<stack> class…...

SQL阶段性优化

&#x1f61c;作 者&#xff1a;是江迪呀✒️本文关键词&#xff1a;MySQL、SQL优化、阶段性优化☀️每日 一言&#xff1a;我们要把懦弱扼杀在摇篮中。 一、前言 我们在做系统的过程中&#xff0c;难免会遇到页面查询速度慢&#xff0c;性能差的问题&#xff0c;…...

2023-08-22 Unity Shader 开发入门2 —— Shader 开发介绍

文章目录 一、必备概念1 计算机图形程序接口2 图形接口程序与其他概念的联系 二、Shader 开发1 Shader2 Shader 开发3 需掌握的内容 一、必备概念 1 计算机图形程序接口 ​ 计算机图形程序接口&#xff08;Graphics API&#xff09;是一套可编程的开放标准&#xff0c;不论 2…...

UE5 运行时捕捉外部窗口并嵌入到主窗口

UE5 运行时捕捉外部窗口并嵌入到主窗口的一种方法 创建一个Slate类用于生成一个窗口 .h// Fill out your copyright notice in the Description page of Project Settings.#pragma once#include "CoreMinimal.h" #include "Widgets/SCompoundWidget.h"/*…...

uniapp 使用permission获取录音权限

使用前&#xff0c;需要先配置权限 android.permission.RECORD_AUDIO...

基于paddleocr的文档识别

1、版面分析 使用轻量模型PP-PicoDet检测模型实现版面各种类别的检测。 数据集&#xff1a; 英文&#xff1a;publaynet数据集的训练集合中包含35万张图像&#xff0c;验证集合中包含1.1万张图像。总共包含5个类别。 中文&#xff1a;CDLA据集的训练集合中包含5000张图像&a…...

魏副业而战:闲鱼卖货赚钱策略

我是魏哥&#xff0c;与其躺平&#xff0c;不如魏副业而战&#xff01; 闲鱼卖货有人赚钱&#xff0c;有人不赚钱。 什么原因呢&#xff1f;闲鱼卖货的策略不对。 这不&#xff0c;社群成员小K找我反馈40单赚了150。 利润太低&#xff0c;不在正常范围之内。 魏哥建议继续…...

语法篇--XML数据传输格式

一、XML概述 1.1简介 XML&#xff0c;全称为Extensible Markup Language&#xff0c;即可扩展标记语言&#xff0c;是一种用于存储和传输数据的文本格式。它是由W3C&#xff08;万维网联盟&#xff09;推荐的标准&#xff0c;广泛应用于各种系统中&#xff0c;如Web服务、数据…...

【Redis】缓存雪崩、缓存击穿、缓存穿透

在使用 Redis 缓存时&#xff0c;常常会遇到三个主要的问题&#xff0c;分别是缓存雪崩、缓存击穿和缓存穿透。这些问题都可能导致缓存系统的性能下降或数据不一致性的问题。 一、缓存雪崩&#xff08;Cache Avalanche&#xff09; 缓存雪崩是指在某个时间点&#xff0c;缓存…...

订阅号做微网站需要认证吗/长沙网络推广软件

贵为新一代存储介质&#xff0c;强大IOPS处理能力&#xff0c;以及低延迟性能表现&#xff0c;让闪存足以横扫磁盘并取而代之。所谓理想丰满&#xff0c;现实骨干。如今这样的情形并没有发生&#xff0c;闪存仍然没有跻身主存储&#xff0c;现实的市场规模依然偏小&#xff0c;…...

建博客网站/百度推广怎么做免费

一、安装 MongoDB &#xff1a; Mac 下安装 MongoDB 一般有多种方法&#xff0c;本文介绍“使用 homebrew安装”和“使用安装包安装” 两种方法。 方法一&#xff1a;使用 homebrew安装 一、安装 homebrew &#xff1a; /usr/bin/ruby -e "$(curl -fsSL https://raw.githu…...

网站关键词优化方案/网站建设排名优化

imx6 KEY_ROW4的pin设置成gpio之后&#xff0c;不能够输出高电平。解决方法记录于此。 参考链接&#xff1a; https://lists.yoctoproject.org/pipermail/meta-freescale/2014-January/006271.html 解决方法 打开board-mx6dl_sabresd.h// Tony 2016-11-28,添加下面的宏 #define…...

专业网站设计联系方式/百度首页 百度一下

前面已经介绍过Entity Framework的工作单元和映射层超类型的封装&#xff0c;从本文开始&#xff0c;将逐步介绍仓储以及对查询的扩展支持。 什么是仓储 仓储表示聚合的集合。 仓储所表现出来的集合外观&#xff0c;仅仅是一…...

铁道部建设监理协会网站/百度广告联盟平台官网

http://msdn.microsoft.com/zh-cn/library/cc716729(VS.90).aspxSQL Server 数据类型映射 (ADO.NET) .NET Framework 3.5其他版本3&#xff08;共 3&#xff09;对本文的评价是有帮助 - 评价此主题更新&#xff1a;November 2007 SQL Server 和 .NET Framework 基于不同的类型系…...

上海建站模板厂家/营业推广促销方式有哪些

题目 题目链接 题解 输入为n和m&#xff0c;则输出有2*n1行和2*m1列。 讲一下下面代码中的i&1的含义&#xff0c;这是二进制判断奇偶&#xff0c;速度比i%2快点。 i&1 0表示偶数&#xff0c;i&1 1表示奇数。 注意n或m为0的情况。ljLQB&#xff01;数据总是奇…...