[大模型]GLM4-9B-chat Lora 微调
本节我们简要介绍如何基于 transformers、peft 等框架,对 LLaMA3-8B-Instruct 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。
这个教程会在同目录下给大家提供一个 nodebook 文件,来让大家更好的学习。
环境准备
在 Autodl 平台中租赁一个 3090 等 24G 显存的显卡机器,如下图所示镜像选择 PyTorch-->2.1.0-->3.10(ubuntu22.04)-->12.1
。
接下来打开刚刚租用服务器的 JupyterLab,并且打开其中的终端开始环境配置、模型下载和运行演示。
环境配置
在完成基本环境配置和本地模型部署的情况下,你还需要安装一些第三方库,可以使用以下命令:
python -m pip install --upgrade pip
# 更换 pypi 源加速库的安装
pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simplepip install modelscope==1.9.5
pip install "transformers>=4.40.0"
pip install streamlit==1.24.0
pip install sentencepiece==0.1.99
pip install accelerate==0.29.3
pip install datasets==2.19.0
pip install peft==0.10.0
pip install tiktoken==0.7.0MAX_JOBS=8 pip install flash-attn --no-build-isolation
注意:flash-attn 安装会比较慢,大概需要十几分钟。
考虑到部分同学配置环境可能会遇到一些问题,我们在 AutoDL 平台准备了 GLM-4 的环境镜像,该镜像适用于本教程需要 GLM-4 的部署环境。点击下方链接并直接创建 AutoDL 示例即可。(vLLM 对 torch 版本要求较高,且越高的版本对模型的支持更全,效果更好,所以新建一个全新的镜像。) https://www.codewithgpu.com/i/datawhalechina/self-llm/GLM-4
在本节教程里,我们将微调数据集放置在根目录 /dataset。
模型下载
使用 modelscope 中的 snapshot_download 函数下载模型,第一个参数为模型名称,参数 cache_dir 为模型的下载路径。
在 /root/autodl-tmp 路径下新建 model_download.py 文件并在其中输入以下内容,粘贴代码后记得保存文件,如下图所示。并运行 python /root/autodl-tmp/model_download.py
执行下载。
import torch
from modelscope import snapshot_download, AutoModel, AutoTokenizer
import osmodel_dir = snapshot_download('ZhipuAI/glm-4-9b-chat', cache_dir='/root/autodl-tmp/glm-4-9b-chat', revision='master')
指令集构建
LLM 的微调一般指指令微调过程。所谓指令微调,是说我们使用的微调数据形如:
{"instruction": "回答以下用户问题,仅输出答案。","input": "1+1等于几?","output": "2"
}
其中,instruction
是用户指令,告知模型其需要完成的任务;input
是用户输入,是完成用户指令所必须的输入内容;output
是模型应该给出的输出。
即我们的核心训练目标是让模型具有理解并遵循用户指令的能力。因此,在指令集构建时,我们应针对我们的目标任务,针对性构建任务指令集。例如,在本节我们使用由笔者合作开源的 Chat-甄嬛 项目作为示例,我们的目标是构建一个能够模拟甄嬛对话风格的个性化 LLM,因此我们构造的指令形如:
{"instruction": "你是谁?","input": "","output": "家父是大理寺少卿甄远道。"
}
我们所构造的全部指令数据集在根目录下。
数据格式化
Lora
训练的数据是需要经过格式化、编码之后再输入给模型进行训练的,如果是熟悉 Pytorch
模型训练流程的同学会知道,我们一般需要将输入文本编码为 input_ids,将输出文本编码为 labels
,编码之后的结果都是多维的向量。我们首先定义一个预处理函数,这个函数用于对每一个样本,编码其输入、输出文本并返回一个编码后的字典:
def process_func(example):MAX_LENGTH = 384input_ids, attention_mask, labels = [], [], []instruction = tokenizer((f"[gMASK]<sop><|system|>\n假设你是皇帝身边的女人--甄嬛。<|user|>\n"f"{example['instruction']+example['input']}<|assistant|>\n"), add_special_tokens=False)response = tokenizer(f"{example['output']}", add_special_tokens=False)input_ids = instruction["input_ids"] + response["input_ids"] + [tokenizer.pad_token_id]attention_mask = instruction["attention_mask"] + response["attention_mask"] + [1] # 因为eos token咱们也是要关注的所以 补充为1labels = [-100] * len(instruction["input_ids"]) + response["input_ids"] + [tokenizer.pad_token_id] if len(input_ids) > MAX_LENGTH: # 做一个截断input_ids = input_ids[:MAX_LENGTH]attention_mask = attention_mask[:MAX_LENGTH]labels = labels[:MAX_LENGTH]return {"input_ids": input_ids,"attention_mask": attention_mask,"labels": labels}
GLM4-9B-chat
采用的Prompt Template
格式如下:
[gMASK]<sop><|system|>
假设你是皇帝身边的女人--甄嬛。<|user|>
小姐,别的秀女都在求中选,唯有咱们小姐想被撂牌子,菩萨一定记得真真儿的——<|assistant|>
嘘——都说许愿说破是不灵的。<|endoftext|>
加载 tokenizer 和半精度模型
模型以半精度形式加载,如果你的显卡比较新的话,可以用torch.bfolat
形式加载。对于自定义的模型一定要指定trust_remote_code
参数为True
。
tokenizer = AutoTokenizer.from_pretrained('/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat', use_fast=False, trust_remote_code=True)model = AutoModelForCausalLM.from_pretrained('/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat', device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True)
定义 LoraConfig
LoraConfig
这个类中可以设置很多参数,但主要的参数没多少,简单讲一讲,感兴趣的同学可以直接看源码。
task_type
:模型类型target_modules
:需要训练的模型层的名字,主要就是attention
部分的层,不同的模型对应的层的名字不同,可以传入数组,也可以字符串,也可以正则表达式。r
:lora
的秩,具体可以看Lora
原理lora_alpha
:Lora alaph
,具体作用参见Lora
原理
Lora
的缩放是啥嘞?当然不是r
(秩),这个缩放就是lora_alpha/r
, 在这个LoraConfig
中缩放就是 4 倍。
config = LoraConfig(task_type=TaskType.CAUSAL_LM, target_modules=["query_key_value", "dense", "dense_h_to_4h", "dense_4h_to_h"], # 现存问题只微调部分演示即可inference_mode=False, # 训练模式r=8, # Lora 秩lora_alpha=32, # Lora alaph,具体作用参见 Lora 原理lora_dropout=0.1# Dropout 比例
)
自定义 TrainingArguments 参数
TrainingArguments
这个类的源码也介绍了每个参数的具体作用,当然大家可以来自行探索,这里就简单说几个常用的。
output_dir
:模型的输出路径per_device_train_batch_size
:顾名思义batch_size
gradient_accumulation_steps
: 梯度累加,如果你的显存比较小,那可以把batch_size
设置小一点,梯度累加增大一些。logging_steps
:多少步,输出一次log
num_train_epochs
:顾名思义epoch
gradient_checkpointing
:梯度检查,这个一旦开启,模型就必须执行model.enable_input_require_grads()
,这个原理大家可以自行探索,这里就不细说了。
args = TrainingArguments(output_dir="./output/GLM4",per_device_train_batch_size=1,gradient_accumulation_steps=8,logging_steps=50,num_train_epochs=2,save_steps=100,learning_rate=1e-5,save_on_each_node=True,gradient_checkpointing=True
)
使用 Trainer 训练
trainer = Trainer(model=model,args=args,train_dataset=tokenized_id,data_collator=DataCollatorForSeq2Seq(tokenizer=tokenizer, padding=True),
)
trainer.train()
保存 lora 权重
lora_path='./GLM4'
trainer.model.save_pretrained(lora_path)
tokenizer.save_pretrained(lora_path)
加载 lora 权重推理
训练好了之后可以使用如下方式加载lora
权重进行推理:
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
from peft import PeftModelmode_path = '/root/autodl-tmp/glm-4-9b-chat/ZhipuAI/glm-4-9b-chat'
lora_path = './GLM4_lora'# 加载tokenizer
tokenizer = AutoTokenizer.from_pretrained(mode_path, trust_remote_code=True)# 加载模型
model = AutoModelForCausalLM.from_pretrained(mode_path, device_map="auto",torch_dtype=torch.bfloat16, trust_remote_code=True).eval()# 加载lora权重
model = PeftModel.from_pretrained(model, model_id=lora_path)prompt = "你是谁?"
inputs = tokenizer.apply_chat_template([{"role": "user", "content": "假设你是皇帝身边的女人--甄嬛。"},{"role": "user", "content": prompt}],add_generation_prompt=True,tokenize=True,return_tensors="pt",return_dict=True).to('cuda')gen_kwargs = {"max_length": 2500, "do_sample": True, "top_k": 1}
with torch.no_grad():outputs = model.generate(**inputs, **gen_kwargs)outputs = outputs[:, inputs['input_ids'].shape[1]:]print(tokenizer.decode(outputs[0], skip_special_tokens=True))
完整脚本参考:
https://github.com/datawhalechina/self-llm/blob/master/GLM-4/05-GLM-4-9B-chat%20Lora%20%E5%BE%AE%E8%B0%83.ipynb
相关文章:
[大模型]GLM4-9B-chat Lora 微调
本节我们简要介绍如何基于 transformers、peft 等框架,对 LLaMA3-8B-Instruct 模型进行 Lora 微调。Lora 是一种高效微调方法,深入了解其原理可参见博客:知乎|深入浅出 Lora。 这个教程会在同目录下给大家提供一个 nodebook 文件,…...
目标检测算法YOLOv9简介
YOLOv9由Chien-Yao Wang等人于2024年提出,论文名为:《YOLOv9: Learning What You Want to Learn Using Programmable Gradient Information》,论文见:https://arxiv.org/pdf/2402.13616 ;源码见: https://github.com/W…...
达梦数据库搭建守护集群
前言 DM 数据守护(Data Watch)是一种集成化的高可用、高性能数据库解决方案,是数据库异地容灾的首选方案。通过部署 DM 数据守护,可以在硬件故障(如磁盘损坏)、自然灾害(地震、火灾)…...
OpenGL-ES 学习(6)---- Ubuntu OES 环境搭建
OpenGL-ES Ubuntu 环境搭建 此的方法在 ubuntu 和 deepin 上验证都可以成功搭建 目录 OpenGL-ES Ubuntu 环境搭建软件包安装第一个三角形基于 glfw 实现基于 X11 实现 软件包安装 sudo apt install libx11-dev sudo apt install libglfw3 libglfw3-dev sudo apt-get install…...
Django学习二:配置mysql,创建model实例,自动创建数据库表,对mysql数据库表已经创建好的进行直接操作和实验。
文章目录 前言一、项目初始化搭建1、创建项目:test_models_django2、创建应用app01 二、配置mysql三、创建model实例,自动创建数据库表1、创建对象User类2、执行命令 四、思考问题(****)1、是否会生成新表呢(答案报错&…...
对象创建的4种模式
1. 工厂模式 这种模式抽象了创建具体对象的过程,用函数来封装以特定接口创建对象的细节 缺点:没有解决对象识别的问题(即怎样知道一个对象的类型) function createPerson(name, age, job) {var o new Object();o.name name;o.ag…...
如何判断 是否 需要 CSS 中的媒体查询
以下是一些常见的使用媒体查询的场景: 响应式布局:当设备的屏幕尺寸变化时,我们可以使用媒体查询来调整布局,以适应不同的屏幕尺寸。 设备特性适配:我们可以使用媒体查询来检测设备的特性,如设备方向、分辨…...
设计模式-装饰器模式(结构型)
装饰器模式 装饰器模式是一种结构模式,通过装饰器模式可以在不改变原有类结构的情况下向一个新对象添加新功能,是现有类的包装。 图解 角色 抽象组件:定义组件的抽象方法具体组件:实现组件的抽象方法抽象装饰器:实现…...
升级HarmonyOS 4.2,开启健康生活篇章
夏日来临,华为智能手表携 HarmonyOS 4.2 版本邀您体验,它不仅可以作为时尚单品搭配夏日绚丽服饰,还能充当你的健康管家,从而更了解自己的身体,开启智能健康生活篇章。 高血糖风险评估优化,健康监测更精准 …...
给gRPC增加负载均衡功能
在现代的分布式系统中,负载均衡是确保服务高可用性和性能的关键技术之一。而gRPC作为一种高性能的RPC框架,自然也支持负载均衡功能。本文将探讨如何为gRPC服务增加负载均衡功能,从而提高系统的性能和可扩展性。 什么是负载均衡? …...
【优选算法】详解target类求和问题(附总结)
目录 1.两数求和 题目: 算法思路: 代码: 2.!!!三数之和 题目 算法思路: 代码: 3.四数字和 题目: 算法思路: 代码: 总结&易错点&…...
【数据结构】图论入门
引入 数据的逻辑结构: 集合:数据元素间除“同属于一个集合”外,无其他关系线性结构:一个对一个,例如:线性表、栈、队列树形结构:一个对多个,例如:树图形结构࿱…...
11_1 Linux NFS服务与触发挂载autofs
11_1 Linux NFS服务与触发挂载服务 文章目录 11_1 Linux NFS服务与触发挂载服务[toc]1. NFS服务基础1.1 示例 2. 触发挂载autofs2.1 触发挂载基础2.2 触发挂载进阶autofs与NFS 文件共享服务:scp、FTP、web(httpd)、NFS 1. NFS服务基础 Netwo…...
开发uniapp 小程序时遇到的问题
1、【微信开发者工具报错】routeDone with a webviewId XXX that is not the current page 解决方案: 在app.json 中添加 “lazyCodeLoading”: “requiredComponents” uniapp的话加到manifest.json下的mp-weixin 外部链接文章:解决方案文章1 解决方案文章2 &qu…...
怎样快速获取Vmware VCP 证书,线上考试,voucher报名优惠
之前考一个VCP证书,要花大一万的费用,可贵了,考试费不贵,贵就贵在培训费,要拿到证书,必须交培训费,即使vmware你玩的很溜,不需要再培训了,但是一笔贵到肉疼的培训费你得拿…...
LeetCode 1141, 134, 142
目录 1141. 查询近30天活跃用户数题目链接表要求知识点思路代码 134. 加油站题目链接标签普通版思路代码 简化版思路代码 142. 环形链表 II题目链接标签思路代码 1141. 查询近30天活跃用户数 题目链接 1141. 查询近30天活跃用户数 表 表Activity的字段为user_id,…...
华为FPGA工程师面试题
FPGA工程师面试会涉及多个方面,包括基础知识、项目经验、编程能力、硬件调试和分析等。以下是一些必问的面试题: 基础知识题: 请解释FPGA的基本组成和工作原理。描述FPGA中的可编程互联资源以及它们在构建复杂数字电路中的作用。请解释嵌入式多用途块(如BRAM、DSP slices、…...
Windows11上安装docker(WSL2后端)和使用docker安装MySQL和达梦数据库
Windows11上安装docker(WSL2后端)和使用docker安装MySQL和达梦数据库 1. 操作系统环境2. 首先安装wsl2.1 关于wsl2.2 安装wsl2.3 查看可用的wsl2.4 安装ubuntu-22.042.5 查看、启动ubuntu-22.04应用2.6 上面安装开了daili2.7 wsl的更多参考 3. 下载Docke…...
UnityXR Interactable Toolkit如何实现Climb爬梯子
前言 在VR中,通常会有一些交互需要我们做爬梯子,爬墙的操作,之前用VRTK3时,里面是还有这个Demo的,最近看XRI,发现也除了一个爬的示例,今天我们就来讲解一下 如何在Unity中使用XR Interaction Toolkit实现爬行(Climb)操作 环境配置 步骤 1:设置XR环境 确保你的Uni…...
sqli-labs 靶场 less-11~14 第十一关、第十二关、第十三关、第十四关详解:联合注入、错误注入
SQLi-Labs是一个用于学习和练习SQL注入漏洞的开源应用程序。通过它,我们可以学习如何识别和利用不同类型的SQL注入漏洞,并了解如何修复和防范这些漏洞。Less 11 SQLI DUMB SERIES-11判断注入点 尝试在用户名这个字段实施注入,且试出SQL语句闭合方式为单…...
国内外网络安全现状分析
一、国内网络安全现状 1.1 国内网络安全威胁 国内的网络安全威胁主要表现在以下几个方面: 恶意软件:包括计算机病毒、蠕虫、木马和间谍软件等,它们能感染计算机系统、窃取敏感信息或破坏系统功能。网络钓鱼:通过伪装成可信任的…...
vscode copilot git commit 生成效果太差,用其他模型替换
问题 众所周知,copilot git commit 就像在随机生成 git commit 这种较为复杂的内容还是交给大模型做比较合适 方法 刚好,gitlens 最近开发了 AI commit的功能,其提供配置url api可以实现自定义模型 gitlens 只有3种模型可用:…...
计算机毕业设计hadoop+spark+hive舆情分析系统 微博数据分析可视化大屏 微博情感分析 微博爬虫 微博大数据 微博推荐系统 微博预测系统
本 科 毕 业 论 文 论文题目:基于Hadoop的热点舆情数据分析与可视化 姓名: 金泓羽 学号: 20200804050115 导师: 关英 职称&…...
【MySQL】(基础篇二) —— MySQL初始用
MySQL初始用 目录 MySQL初始用基本语法约定选择数据库查看数据库和表其它的SHOW 在Navicat中,大部分数据库管理相关的操作都可以通过图形界面完成,这个很简单,大家可以自行探索。虽然Navicat等图形化数据库管理工具为操作和管理数据库提供了非…...
计算机网络 期末复习(谢希仁版本)第4章
路由器:查找转发表,转发分组。 IP网的意义:当互联网上的主机进行通信时,就好像在一个网络上通信一样,看不见互连的各具体的网络异构细节。如果在这种覆盖全球的 IP 网的上层使用 TCP 协议,那么就…...
如何使用Pandas处理数据?
一、技术难点 Pandas是Python中一个强大的数据处理和分析库,它提供了高效、灵活且易于使用的数据结构,主要用于数据清洗、转换、聚合和可视化等任务。然而,在使用Pandas处理数据时,也会遇到一些技术难点。 数据导入与导出&#…...
Error: spawn xdg-open ENOENT
报错:The CJS build of Vite’s Node API is deprecated. See https://vitejs.dev/guide/troubleshooting.html#vite-cjs-node-api-deprecated for more details. VITE v5.1.4 ready in 2298 ms ➜ Local: http://localhost:80/ ➜ Network: http://10.0.4.13:80/ ➜…...
写给大数据开发,如何去掌握数据分析
这篇文章源于自己一个大数据开发,天天要做分析的事情,发现数据分析实在高大上很多,写代码和做汇报可真比不了。。。。 文章目录 1. 引言2. 数据分析的重要性2.1 技能对比2.2 业务理解的差距 3. 提升数据分析能力的方向4. 数据分析的系统过程4…...
大数据湖一体化运营管理建设方案(49页PPT)
方案介绍: 本大数据湖一体化运营管理建设方案通过构建统一存储、高效处理、智能分析和安全管控的大数据湖平台,实现了企业数据的集中管理、快速处理和智能分析。该方案具有可扩展性、高性能、智能化、安全性和易用性等特点,能够为企业数字化…...
大模型训练的艺术:从预训练到增强学习的四阶段之旅
文章目录 大模型训练的艺术:从预训练到增强学习的四阶段之旅1. 预训练阶段(Pretraining)2. 监督微调阶段(Supervised Finetuning, SFT)3. 奖励模型训练阶段(Reward Modeling)4. 增强学习微调阶段…...
创建网页用什么软件/深圳网站设计专业乐云seo
在这个例子中,主要会用到python内置的和OS模块的几个函数:os.walk() : 该方法用来遍历指定的文件目录,返回一个三元tuple(dirpath, dirnames, filenames) ,其中dirpath为当前目录路径,dirnames为当前路径下…...
网站制作用到什么技术/哪个平台可以免费打广告
注册一个npm账户 1.项目生成package.json文件 npm init --y 2.获取关联的镜像 npm config get registry2.1 配置全局使用npm官方镜像源 npm config set registry https://registry.npmjs.org很多人因下包方便都配置为了https://registry.npm.taobao.org 3.发布 npm publ…...
视频发布播放网站建设/中山网站seo
例21题目:猴子吃桃问题:猴子第一天摘下若干个桃子,当即吃了一半,还不瘾,又多吃了一个第二天早上又将剩下的桃子吃掉一半,又多吃了一个。以后每天早上都吃了前一天剩下的一半零一个。到第10天早上想再吃时&a…...
wordpress 限制游客/企业网站策划
调度算法一、先来先服务FCFS (First Come First Serve)1.思想:选择最先进入后备/就绪队列的作业/进程,入主存/分配CPU2.优缺点优点:对所有作业/进程公平,算法简单稳定缺点:不够灵活,对紧急进程的优先处理权…...
建立免费网站/百度关键词优化软件排名
intellij 开发webservice 最近项目中有用到WebService,于是就研究了一下,但是关于intellij 开发 WebService 的文章极少,要不就是多年以前,于是研究一下,写这篇博文。纯属记录,分享,中间有不对的…...
寻找网站建设 网站外包/嘉兴网站建设
第一种方法: 1、详细查询命令: 查看cpu最大进程,或者内存最大进程。 #CPU ps aux|head -1;ps aux|grep -v PID|sort -rn -k 3|head #内存 ps aux|head -1;ps aux|grep -v PID|sort -rn -k 4|head显示如下: ubuntuubuntu:~$ ps…...