大语言模型微调实践——LoRA 微调细节
1. 引言
近年来人工智能领域不断进步,大语言模型的崛起引领了自然语言处理的革命。这些参数量巨大的预训练模型,凭借其在大规模数据上学习到的丰富语言表示,为我们带来了前所未有的文本理解和生成能力。然而,要使这些通用模型在特定任务上发挥出色,还需借助微调技术。大语言模型的微调技术已经成为自然语言处理领域的一个焦点,其不断的演化和创新正引领着我们进入一个更加精细、个性化的文本处理时代。
在本文中,我们将选取目前大语言模型热点任务——代码生成,结合 StarCoder 模型微调实践介绍高效微调方法——LoRA。
2. LoRA 微调原理
论文:LoRA: Low-Rank Adaptation of Large Language Models
LoRA 基于大模型的内在低秩特性,增加旁路矩阵来模拟全参数微调,是目前最通用、效果最好的微调方法之一,而且能和其它参数高效微调方法有效结合。利用该方法对 175B GPT-3 微调,需要训练更新的参数量可以小到全量微调参数量的 0.01%。
图1. LoRA原理
上图为 LoRA 的实现原理,其实现流程为:
在原始预训练语言模型旁边增加一个旁路,做降维再升维的操作来模拟内在秩;
用随机高斯分布初始化 A,用零矩阵初始化B,训练时固定预训练模型的参数,只训练矩阵 A 与矩阵 B ;
训练完成后,将 B 矩阵与 A 矩阵相乘后合并预训练模型参数作为微调后的模型参数。
研究表明,Transformer 等神经网络包含许多执行矩阵乘法的密集层,这些权重通常具有满秩。预训练的语言模型具有较低的“本征维度(Instrinsic Dimension)”,并且可以和完整参数空间一样进行有效学习。受此启发,本文在微调过程中假设权重的更新也具有较低的“本征维度”。对于预训练模型的权重矩阵 ,通过低秩分解(Low-Rank Decomposition)来表示约束其更新。训练过程中 被固定不再进行梯度更新,只训练 和 ,其中 。训练结束后,更新参数为 。对于输入 ,模型的前向传播过程更新为 。
由于模型整体参数量不变,所以不会降低推理时的性能。作者通过实验比较了在内容理解任务、生成任务上的效果,相比全量微调参数量显著降低,性能上持平甚至超过,相比其他高效微调方法,增加参数量不会导致性能下降。需要注意的是此方法对低秩矩阵的秩数和目标模块的选择比较敏感,可能影响模型的性能和稳定性。使用LoRA微调有以下几个细节:
对哪些参数进行微调:基于 Transformer 结构,LoRA 只对每层的 Self-Attention 的部分进行微调,有 四个映射层参数可以进行微调。需要注意不同模型参数名称不同,像 StarCoder 模型 Multi-query 结构的 attention 层对应的参数名称是
attn.c_attn
,attn.c_proj
Rank(r) 的选取:Rank 的取值作者对比了 1-64,效果上 Rank 在 4-8 之间最好,再高并没有效果提升。不过论文的实验是面向下游单一监督任务的,因此在指令微调上根据指令分布的广度,Rank选择还是需要在 8 以上的取值进行测试。
alpha 参数选取:alpha 其实是个缩放参数,训练后权重 merge 时的比例为
alpha/r
初始化:矩阵A是 Uniform 初始化,B 是零初始化,这样最初的 lora 权重为 0,所以 lora 参数是从头学起,并没有那么容易收敛。
3. LoRA 微调实践
本节以 StarCoder 微调为例,介绍使用 LoRA 微调的实践过程。
首先,StarCoder 是使用 86 种编程语言的 1 万亿个 token 训练,并在另外 35billion Python token 上微调出的模型,专注于解决编程问题,模型结构为:"GPTBigCodeForCausalLM",40层 decoder-only Transformer,Attention 层结构为 Multi-query,参数量约 15.5B。
3.1 环境配置
实例环境:A800 + python3.8 + torch2.0 + CUDA11.6
python环境:主要坑在 transforemrs 和 peft,这两个包建议使用"Development Mode"安装
环境中主要包的版本:
tqdm==4.65.0
transformers=4.31.0.dev0
peft=0.4.0.dev0
datasets==2.11.0
huggingface-hub==0.13.4
accelerate==0.18.0
3.2 模型加载
以下代码主要整合自 alpaca-lora 项目和 StarCoder 的 finetune 项目。其实 LoRA 微调的代码本身并不复杂,但是对于如何加速大模型训练,如何以时间换空间的降低显存占用处理值得学习。模型初始化代码如下,get_peft_model 会初始化 PeftModel 把原模型作为 base 模型,并在各个 self-attention 层加入 LoRA 层,同时改写模型 forward 的计算方式。主要说下 load_in_8bit
,prepare_model_for_int8_training
和 get_peft_model
分别做了哪些操作。
from accelerate import Accelerator
from peft import LoraConfig, get_peft_model, prepare_model_for_int8_training
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, Trainermodel = AutoModelForCausalLM.from_pretrained(args.model_path,use_auth_token=True,use_cache=True,load_in_8bit=True,device_map={"": Accelerator().process_index},)model = prepare_model_for_int8_training(model)lora_config = LoraConfig(r=16,lora_alpha=32,lora_dropout=0.05,bias="none",task_type="CAUSAL_LM",target_modules = ["attn.c_proj", "attn.c_attn"]
)model = get_peft_model(model, lora_config)
模型加载时,load_in_8bit=True
的 8bit 量化优化的是静态显存,是 bitsandbytes 库赋予的能力,会把加载模型转化成混合 8bit 的量化模型。模型量化本质是对浮点参数进行压缩的同时,降低压缩带来的误差。8bit quantization是把原始 fp32(4字节)压缩到 int8(1字节)也就是 1/4 的显存占用。我们主要关注 attention
层的情况:
Parameter name: transformer.h.0.ln_1.weight
Data type: torch.float16Parameter name: transformer.h.0.ln_1.bias
Data type: torch.float16Parameter name: transformer.h.0.attn.c_attn.weight
Data type: torch.int8Parameter name: transformer.h.0.attn.c_attn.bias
Data type: torch.float16Parameter name: transformer.h.0.attn.c_proj.weight
Data type: torch.int8Parameter name: transformer.h.0.attn.c_proj.bias
Data type: torch.float16
通过第一层模型可以看出,这一步,attention 层 c_attn 和 c_proj 的 weight 设为 int8,其他为 fp16。
下面,prepare_model_for_int8_training
是对在 LoRA 微调中使用 LLM.int8() 进行了适配用来提高训练的稳定性,主要包括
layer norm 层保留 fp32 精度
输出层保留 fp32 精度保证解码时随机 sample 的差异性
操作后区别如下:
Parameter name: transformer.h.0.ln_1.weight
Data type: torch.float32Parameter name: transformer.h.0.ln_1.bias
Data type: torch.float32Parameter name: transformer.h.0.attn.c_attn.weight
Data type: torch.int8Parameter name: transformer.h.0.attn.c_attn.bias
Data type: torch.float32Parameter name: transformer.h.0.attn.c_proj.weight
Data type: torch.int8Parameter name: transformer.h.0.attn.c_proj.bias
Data type: torch.float32
prepare_model_for_int8_training
还设置了 gradient_checkpointing=True
,这是一个时间换空间的技巧。gradient checkpoint
的实现是在前向传播的过程中使用 torch.no_grad()
不存储中间激活值,降低动态显存的占用,而只保存输入和激活函数,当进行反向传播的时候,会重新获取输入并计算激活值用于梯度计算。因此前向传播会计算两遍,所以需要更多的训练时间。
第三步 get_peft_model
的操作后,区别如下:
Parameter name: base_model.model.transformer.h.0.attn.c_attn.lora_A.default.weight
Data type: torch.float32
Require grads: TrueParameter name: base_model.model.transformer.h.0.attn.c_attn.lora_B.default.weight
Data type: torch.float32
Require grads: TrueParameter name: base_model.model.transformer.h.0.attn.c_proj.lora_A.default.weight
Data type: torch.float32
Require grads: TrueParameter name: base_model.model.transformer.h.0.attn.c_proj.lora_B.default.weight
Data type: torch.float32
Require grads: True
在 attention 层的 c_attn 和 c_proj 添加 LoRA 层,数据类型为 fp32,并且需要梯度计算。
3.3 模型训练
模型训练的代码如下,和常规训练基本相同,需要注意模型存储和混合精度训练。StarCoder 项目推荐使用的数据集是 stack-exchange-instruction。Stack Exchange 是一个著名的问答网站,涉及不同领域的主题,用户可以在这里提出问题并从其他用户那里获得答案。这些答案根据其质量进行评分和排名。此数据集构建的即为问答对集合。可以在该数据集上微调语言模型,激活模型的问答技能。
train_dataset, eval_dataset = create_datasets(tokenizer, args)training_args = TrainingArguments(output_dir=args.output_dir,evaluation_strategy="steps",max_steps=args.max_steps,eval_steps=100,save_steps=100,per_device_train_batch_size=1,learning_rate=5e-6,gradient_accumulation_steps=16,fp16=True,report_to="wandb",)trainer = Trainer(model=model, args=training_args, train_dataset=train_data, eval_dataset=val_data, callbacks=[SavePeftModelCallback, LoadBestPeftModelCallback)trainer.train()model.save_pretrained(os.path.join(args.output_dir, "final_checkpoint/"))
(1)模型存储
需要注意 PeftModel
重写了原始 model 的 save_pretrained
函数,只把 LoRA 层的权重进行存储,因此 model.save_pretrained
只会存储 LoRA 权重。
(2)混合精度训练
实现原理是并非所有变量都需要全精度存储,如果把部分中间变量转化成半精度,则计算效率会大幅提升,加上一些 GPU 对 fp16 计算做了优化,吞吐上比全精度会快 2~5 倍。不过只使用半精度训练会带来量化误差,所以混合精度训练前向传播时,模型权重、激活值和梯度都使用 fp16 进行存储,同时会拷贝一份模型权重以 fp32 存储,反向传播 optimizer 更新时会更新 fp32 的参数。因此其并不会节省内存,只会提高模型训练速度,同时使精度不下降太多。
3.4 merge 权重
LoRA 权重和原始模型权重进行合并,合并后的参数会存储成新的 bin 文件,然后和加载常规模型一样加载合并后的模型参数进行推理,不会引入推理延迟,权重合并关键代码如下:
base_model = AutoModelForCausalLM.from_pretrained(args.base_model_name_or_path,load_in_8bit=False,torch_dtype=torch.float16,device_map={"": "cuda:0"},
)
lora_model = PeftModel.from_pretrained(base_model,args.peft_model_path,device_map={"": "cuda:0"},torch_dtype=torch.float16,
)model = lora_model.merge_and_unload()lora_model.train(False)tokenizer = AutoTokenizer.from_pretrained(args.base_model_name_or_path)
model.save_pretrained(f"{args.save_model_path}-merged")
tokenizer.save_pretrained(f"{args.save_model_path}-merged")
至此,StarCoder 模型的 LoRA 微调过程就完成了。
3.5 模型推理
下面我们可以看一下激活问答能力后,模型的表现。
问题:Based on the torch framework, please write a handwritten digital processing model.
回答:
Answer: I think you can use the torch.nn.Conv2d module.
import torch
import torch.nn as nnclass MyModel(nn.Module):def __init__(self):super(MyModel, self).__init__()self.conv1 = nn.Conv2d(1, 1, 3, padding=1)self.conv2 = nn.Conv2d(1, 1, 3, padding=1)def forward(self, x):x = self.conv1(x)x = self.conv2(x)return xmodel = MyModel()
This is a simple model with two convolutional layers.
input = torch.randn(1, 1, 10, 10)
output = model(input)
This is how you can use it.
通过以上回答,我们可以看到精调后的模型激活了问答能力。
4. 总结
在本文中,我们探讨了 LoRA 微调方法,并以 StarCoder 模型的微调为例介绍了实践过程。通过实践过程的经验来为大家展示一些细节及需要注意的点,希望大家也能通过这种低资源高效微调方法微调出符合自己需求的模型。
参考
[1] LoRA: Low-Rank Adaptation of Large Language Models
[2] https://github.com/bigcode-project/starcoder
[3] https://github.com/tloen/alpaca-lora
[4] 苏剑林,梯度视角下的LoRA:简介、分析、猜测及推广
相关文章:
大语言模型微调实践——LoRA 微调细节
1. 引言 近年来人工智能领域不断进步,大语言模型的崛起引领了自然语言处理的革命。这些参数量巨大的预训练模型,凭借其在大规模数据上学习到的丰富语言表示,为我们带来了前所未有的文本理解和生成能力。然而,要使这些通用模型在特…...
国内ChatGPT对比与最佳方案
很久没写内容了,主要还是工作占据了太多时间。简单分享下我这段时间的研究吧,由于时间仓促,有很多内容没有具体写,请自行到我分享的网站体验查看。 前言 ChatGPT 的出现确实在很大程度上改变了世界。许多人已经亲身体验到了ChatGPT作为一个…...
绝美的古诗词AI作画,惊艳到我了!
前言 时光荏苒,科技的飞速发展催生出了许多令人惊叹的创新成果。近年来,人工智能技术在艺术领域的应用日益引人注目,其中最为引人瞩目的莫过于AI作画。这项技术将传统的古诗词与现代的人工智能相结合,创造出一幅幅令人叹为观止的…...
数据结构—排序
8.排序 8.1排序的概念 什么是排序? 排序:将一组杂乱无章的数据按一定规律顺序排列起来。即,将无序序列排成一个有序序列(由小到大或由大到小)的运算。 如果参加排序的数据结点包含多个数据域,那么排序往…...
GraphScope,开源图数据分析引擎的领航者
文章首发地址 GraphScope是一个开源的大规模图数据分析引擎,由Aliyun、阿里巴巴集团和华为公司共同开发。GraphScope旨在为大规模图数据处理和分析提供高性能、高效率的解决方案。 Github地址: https://github.com/alibaba/GraphScope GraphScope 的重…...
【Linux】邮件服务器搭建 postfix+dovecot+mysql (终极版 超详细 亲测多遍无问题)
🍁博主简介 🏅云计算领域优质创作者 🏅华为云开发者社区专家博主 🏅阿里云开发者社区专家博主 💊交流社区:运维交流社区 欢迎大家的加入! 文章目录 前言基础原理准备工作一 、安装关于权…...
GitLab与GitLab Runner安装(RPM与Docker方式),CI/CD初体验
背景 GitLab 是一个强大的版本控制系统和协作平台,记录一下在实际工作中关于 GitLab 的安装使用记录。 一开始使用 GitLab 时,是在 CentOS7 上直接以 rpm 包的方式进行安装,仅作为代码托管工具来使用,版本: 14.10.4 …...
vue3+element下拉多选框组件
<!-- 下拉多选 --> <template><div class"select-checked"><el-select v-model"selected" :class"{ all: optionsAll, hidden: selectedOptions.data.length < 2 }" multipleplaceholder"请选择" :popper-app…...
Python科研绘图--Task02
目录 图形元素 画布 (fifigure)。 坐标图形 (axes),也称为子图。 轴 (axis) :数据轴对象,即坐标轴线。 刻度 (tick),即刻度对象。 图层顺序 轴比例和刻度 轴比例 刻度位置和刻度格式 坐标系 直角坐标系 极坐标系 地理…...
[保研/考研机试] KY11 二叉树遍历 清华大学复试上机题 C++实现
题目链接: 二叉树遍历_牛客题霸_牛客网编一个程序,读入用户输入的一串先序遍历字符串,根据此字符串建立一个二叉树(以指针方式存储)。题目来自【牛客题霸】https://www.nowcoder.com/share/jump/43719512169254700747…...
【官方中文文档】Mybatis-Spring #简介
简介 什么是 MyBatis-Spring? MyBatis-Spring 会帮助你将 MyBatis 代码无缝地整合到 Spring 中。它将允许 MyBatis 参与到 Spring 的事务管理之中,创建映射器 mapper 和 SqlSession 并注入到 bean 中,以及将 Mybatis 的异常转换为 Spring 的…...
稳定扩散ControlNet v1.1 权威指南
ControlNet 是一种稳定扩散模型,可让你从参考图像中复制构图或人体姿势。 经验丰富的稳定扩散用户知道生成想要的确切成分有多难。图像有点随机。你所能做的就是玩数字游戏:生成大量图像并选择你喜欢的图片。 借助 ControlNet,稳定扩散用户…...
【golang】结构体及其方法的使用(struct)
函数是独立的程序实体。我们可以声明有名字的函数,也可以声明没名字的函数,还可以把它们当做普通的值传来传去。我们能把具有相同签名的函数抽象成独立的函数类型,以作为一组输入、输出(或者说一类逻辑组件)的代表。 …...
【数据结构】-- 排序算法习题总结
排序 时间复杂度 空间复杂度 稳定性 冒泡排序 O(n^2) 优化后O(n) O(1) 稳定 快速排序 最好O(n*logn) 最坏O(n^2) 最好O(logn) 最坏O(n) 不稳定直接插入排序…...
第十章 CUDA流(stream)实战篇
cuda教程目录 第一章 指针篇 第二章 CUDA原理篇 第三章 CUDA编译器环境配置篇 第四章 kernel函数基础篇 第五章 kernel索引(index)篇 第六章 kenel矩阵计算实战篇 第七章 kenel实战强化篇 第八章 CUDA内存应用与性能优化篇 第九章 CUDA原子(atomic)实战篇 第十章 CUDA流(strea…...
如何进行电脑文件夹分类与整理?
本科电脑用了四年,毕业后发现空间很满,但是真正有用的东西仿佛就一点。好像是在学开发的时候,听到一个老师说,根目录不要放太多文件夹,不然就相当于没有根目录了。刚好研究生有了新的台式电脑,开始有规划的…...
kafka-python 消费者消费不到消息
排除步骤1: 使用group_id”consumer_group_id_001“ 和 auto_offset_reset"earliest" from kafka import KafkaConsumerconsumer KafkaConsumer(bootstrap_servers["dev-kafka01.test.xxx.cloud:9092"],enable_auto_commitTrue, auto_commit…...
穿起“新架构”的舞鞋,跳一支金融数字化转型的华尔兹
华尔兹,是男女两位舞者,通过形体的控制,舞步技巧的发挥,完美配合呈现而出的一种舞蹈形式。华尔兹舞姿,如行云流水、潇洒自如、飘逸优美,素有“舞中皇后”的美称。 在跳华尔兹的时候,如果舞者双…...
SpringBoot 常用注解
随着Spring及Spring Boot的发展,基于Java的配置已经慢慢替代了基于xml的配置形式。本篇文章为大家整理和简介Spring Boot中常用的注解及其功能。 SpringBoot注解 SpringBootApplication:开启Spring Boot自动配置的核心注解,相关等同于Configu…...
k8s deployment创建pod流程图
参考 k8s 创建pod和deployment的流程 - SoulChild随笔记...
C++ 逗号运算符
使用逗号运算符是为了把几个表达式放在一起。 整个逗号表达式的值为系列中最后一个表达式的值。 从本质上讲,逗号的作用是将一系列运算按顺序执行。 表达式1, 表达式2求解过程是:先求解表达式 1,再求解表达式 2。整个逗号表达式的值是表达…...
jdbc集成phoneix hbase
为什么使用jdbc集成 需求简单,只是往phoneix存储数据原本项目已经有mysql的mybatis plus集成,如果采用dataSource方式就需要采用多数据源的方式,造成架构复杂化,使用复杂化,并且修改地方过多。 Qualifier("phoe…...
16.遍历二叉树,线索二叉树
目录 一. 遍历二叉树 (1)三种遍历方式 (2)递归遍历算法 (3)非递归遍历算法 (4)层次遍历算法 二. 基于递归遍历算法的二叉树有关算法 (1)二叉树的建立 …...
电商平台按关键字搜索商品淘宝京东拼多多api接口PHP示例
关键词搜索商品接口的作用是通过调用接口来实现在电商平台中进行商品搜索。具体而言,该接口可以提供以下功能和作用: 商品搜索:用户可以通过输入关键词,在电商平台上进行商品搜索。接口可以根据关键词对商品的名称、描述、标签等…...
胖小酱之恰恰是什么
意思是:指所指的事物截然不同,正好相反。 恰恰相反的近义词:事与愿违、适得其反 一、事与愿违 [ sh yǔ yun wi ] 【解释】:事实与愿望相反。指原来打算做的事没能做到。 【出自】:茅盾《子夜》十六:不…...
豪越科技受邀出席2023中国算力大会
2023年8月17日-8月20日,“算汇银川 数创未来”创新中国行走进银川暨2023中国算力大会在银川中关村创新中心召开。政府领导、行业领袖、专家学者、以及大型科技企业负责人齐聚大会现场,围绕算力基础设施建设、创新应用和产业发展成果等方面开展广泛交流与…...
python脚本——批量将word文件转换成多张图片
前提:有时候需要快速查看word文档的内容是否自己需要的,或者就是单纯需要将word文档转换成一张张图片。 思路:word文档直接生成图片比较蛮烦,可能会引起格式变化,就先将word文档转换成PDF,然后将PDF文档转…...
FairyGUI编辑器的弹窗操作【插件】
之前在FairyGUI编辑器菜单扩展中,我使用了App.Alert("复制失败")来提示操作是否成功。这篇则会说一下我们可以使用的弹窗提示,以及做到类似资源发布成功时的“发布成功”飘窗。 打开APP的API脚本,可以看到有很多公开方法ÿ…...
Elasticsearch(十三)搜索---搜索匹配功能④--Constant Score查询、Function Score查询
一、前言 之前我们学习了布尔查询,知道了filter查询只在乎查询条件和文档的匹配程度,但不会根据匹配程度对文档进行打分,而对于must、should这两个布尔查询会对文档进行打分,那如果我想在查询的时候同时不去在乎文档的打分&#…...
直播系统源码协议探索篇(二):网络套接字协议WebSocket
上一篇我们分析了直播平台的会话初始化协议SIP,他关乎着直播平台的实时通信和多方互动技术的实现,今天我们来讲另一个协议,叫网络套接字协议WebSocket,WebSocket基于TCP在客户端与服务器建立双向通信的网络协议,并且可…...
手机网站制作教程视频教程/肇庆网络推广
Kubernetes 全面拥抱微服务架构,其具备良好的横向扩容能力,并构建在 Google 15 年生产环境经验、每周运行数 10 亿个容器的目标基础之上。Kubernetes 很好的结合了来自社区的创意和最佳实践。Kubernetes 是目前唯一被业界广泛认可的 Docker 分布式解决方…...
石家庄网站建设seo优化营销/网站建设需要多少钱
2017第五届CCF大数据与计算智能大赛(BDCI)启动仪式于本9月24日正式启动。 CCF大数据与计算智能大赛(Big Data & Computing Intelligence Contest,简称“BDCI”)是由中国计算机学会(CCF)主办,中国计算机学会大数据专家委员会(CCF-TFBD)、中国计算机学会高性能计算…...
做百科需要用什么网站做参考/种子库
微友提问您好,凌老师!有没有专门针对教学中PPT课件制作的书推荐或是网课,如果有光盘更好。因为每次做的课件都要搜索好多网页才弄出一个效果,想暑假好好学习一下。下面是我看到过印象比较深刻的一些效果,不知道是如何实…...
b2c网站特点/seo排名工具给您好的建议
摘要:了解 TreeView Web 控件,并学习如何在 ASP.NET Web 应用程序中使用 TreeView Web 控件。除了标准的 ASP.NET Web 控件(例如 TextBox、DropDownList、DataGrid、DataList 等)之外,Microsoft 还发布了附加的 Web 控…...
软件园二期做网站的公司有哪些/注册一个网站
最近做了一个小工具,可以将XML和Excel之前互转。 里面用到的XML读写库是tinyxml,在Excel2010上运行,请先确保装了Excel,而不是WPS。 代码写的比较挫,一大坨,最近忙也懒得去做优化了。 github地址&#x…...
菏泽做网站多少钱/属于网络营销的特点是
Tomcat版本问题,servlet乱码问题 我在学习的时候,老师用的是Tomcat1.7版本,在jsp发送get请求的时候,Servlet中还要对get请求传递过来的参数进行解码编码,因为tomcat1.7版本之前的内部编码为ISO8859-1,然而在…...