GPT实战系列-ChatGLM2模型的微调训练参数解读
GPT实战系列-ChatGLM2模型的微调训练参数解读
目录
- GPT实战系列-ChatGLM2模型的微调训练参数解读
- ChatGLM2模型
- 1、P-Tuning模型微调
- 2、微调训练配置参数
- train.sh中配置参数
- 训练配置信息
- 模型配置信息
- 附录:训练正常运行打印信息
ChatGLM2模型
ChatGLM-6B是开源的文本生成式对话模型,基于General Language Model(GLM)框架,具有62亿参数,FP16 半精度下,ChatGLM-6B 需要 13GB 左右的显存进行推理。
ChatGLM-6B另一个突出优点是 可以部署在消费级显卡上,结合模型蒸馏技术和模型量化技术,可以进一步降低到 10GB(INT8) 和 6GB(INT4)。实测在2080ti显卡预测上(INT4)显存占用6G左右。
1、P-Tuning模型微调
P-Tuning的全称是Prefix-tuning,意为“前缀调优”。它通过在模型输入前添加小段Discrete prompt(类似填空句),并只优化这个prompt来实现模型微调。P-tuning-v2是基于Prompt-tuning方法的NLP模型微调技术。总体来说,P-tuning-v2是Prompt tuning技术的升级版本,使得Prompt的表示能力更强,应用也更灵活广泛。它被认为是Prompt tuning类方法中效果最优且易用性最好的版本。
代码实现对于 ChatGLM2-6B 模型基于 P-Tuning v2 的微调。P-Tuning v2 将需要微调的参数量,减少到原来的 0.1%,再通过模型量化、Gradient Checkpoint 等方法,预测最低只需要 7GB 显存即可运行。
将训练和测试数据解压后的 AdvertiseGen 目录放到ptuning目录下。
ChatGLM2的训练源代码:https://github.com/THUDM/ChatGLM2-6B
文件目录结构:
├── FAQ.md
├── MODEL_LICENSE
├── README.md 说明文档
├── README_EN.md
├── api.py
├── cli_demo.py
├── evaluation
│ ├── README.md
│ └── evaluate_ceval.py
├── openai_api.py
├── ptuning
│ ├── README.md 说明文档
│ ├── arguments.py
│ ├── deepspeed.json
│ ├── ds_train_finetune.sh
│ ├── evaluate.sh
│ ├── evaluate_finetune.sh
│ ├── main.py
│ ├── train.sh 训练脚本
│ ├── train_chat.sh
│ ├── trainer.py
│ ├── trainer_seq2seq.py
│ ├── web_demo.py
│ └── web_demo.sh 测试脚本
├── requirements.txt 环境依赖文件
├── resources
│ ├── WECHAT.md
│ ├── cli-demo.png
│ ├── knowledge.png
│ ├── long-context.png
│ ├── math.png
│ ├── web-demo.gif
│ ├── web-demo2.gif
│ └── wechat.jpg
├── utils.py
├── web_demo.py
└── web_demo2.py
2、微调训练配置参数
训练之前,需要根据自己的训练需求,训练数据和机器配置情况修改配置参数。
train.sh中配置参数
PRE_SEQ_LEN=128 # soft prompt 长度
LR=2e-2 # 训练学习率
NUM_GPUS=2 # GPU卡的数量torchrun --standalone --nnodes=1 --nproc-per-node=$NUM_GPUS main.py \--do_train \ # 执行训练功能,还可以执行评估功能--train_file AdvertiseGen/train.json \ # 训练文件目录--validation_file AdvertiseGen/fval.json \ # 验证文件目录--prompt_column content \ # 训练集中prompt提示名称,对应训练文件,测试文件的"content"--response_column summary \ # 训练集中答案名称,对应训练文件,测试文件的"summary"--overwrite_cache \ # 缓存,重复训练一次的时候可删除--model_name_or_path THUDM/chatglm-6b \ # 加载模型文件目录,也可修改为本地模型的路径--output_dir output/adgen-chatglm-6b-pt-$PRE_SEQ_LEN-$LR \ # 保存训练模型文件目录--overwrite_output_dir \ # 覆盖训练文件目录--max_source_length 64 \ # 最大输入文本的长度--max_target_length 128 \--per_device_train_batch_size 1 \ # batch_size 训练批次根据显存调节--per_device_eval_batch_size 1 \ # 验证批次--gradient_accumulation_steps 16 \ # 梯度累加的步数--predict_with_generate \--max_steps 3000 \ # 最大训练模型的步数--logging_steps 10 \ # 多少步打印日志一次--save_steps 1000 \ # 多少步保存模型一次--learning_rate $LR \ # 学习率--pre_seq_len $PRE_SEQ_LEN \--quantization_bit 4 # 量化,也可修改为int8
训练配置参数具体解释
# 长度和 学习率
–PRE_SEQ_LEN 是 soft prompt 长度,可以进行调节以取得最佳的效果。
–LR 是训练的学习率
# 本地数据,训练集和测试集的路径
–train_file AdvertiseGen/train.json
–validation_file AdvertiseGen/dev.json \
# 模型目录。
如果你想要从本地加载模型,可以将THUDM/chatglm2-6b 改为你本地的模型路径。
–model_name_or_path THUDM/chatglm-6b
# 最大训练步数
–max_steps 3000
# 模型量化,可通过调整 quantization_bit 来被原始模型的量化等级,不加此选项则为 FP16 精度加载。在默认配置 quantization_bit=4
–quantization_bit 4 # 量化,也可修改为int8
# 批次,迭代参数,在默认配置 per_device_train_batch_size=1、gradient_accumulation_steps=16 下,一次训练迭代会以 1 的批处理大小进行 16 次累加的前后向传播,等效为 16 的总批处理大小,此时最低只需 6.7G 显存。若想在同等批处理大小下提升训练效率,可在二者乘积不变的情况下,加大 per_device_train_batch_size 的值,但也会带来更多的显存消耗,请根据实际情况酌情调整。
–per_device_train_batch_size 1 \ # batch_size 训练批次根据显存调节
–per_device_eval_batch_size 1 \ # 验证批次
–gradient_accumulation_steps 16 \ # 梯度累加的步数
训练配置信息
执行训练脚本后,训练过程中屏幕打印信息如下:
***** Running training *****
Num examples = 114,599Num Epochs = 100 Instantaneous batch size per device = 4Total train batch size (w. parallel, distributed & accumulation) = 16Gradient Accumulation steps = 4Total optimization steps = 3,000Number of trainable parameters = 1,949,696
这些参数是深度学习模型训练过程中的一些关键设置。
Num examples = 114,599: 表示在训练集中有114,599个样本,即114,599个独立的训练数据用于训练。
Num Epochs = 100: 一个epoch指的是模型在训练过程中遍历整个训练集一次。因此,Num Epochs = 100意味着模型会遍历整个训练集100次。
Instantaneous batch size per device = 4: 在深度学习中,通常不会同时处理所有的训练样本,而是将它们分成“批次”进行处理。每个批次的大小就是每次模型训练的样本数量。在这个例子中,每个设备上的即时批量大小为4,意味着每个设备一次处理4个样本。
Total train batch size (w. parallel, distributed & accumulation) = 16: 表示在并行、分布式和积累情况下,总的训练批次大小为16。这可能意味着在多个设备上同时进行训练,每个设备处理一部分批次,然后把这些批次加起来,总和为16。
Gradient Accumulation steps = 4: 梯度累积是一种在内存不足的情况下训练大模型的技巧。它的工作原理是:在进行反向传播并更新模型权重之前,先计算并累积一定步数的梯度。在该例中,每4个批次后进行一次权重更新。
Total optimization steps = 3,000: 优化步数是模型训练过程中权重更新的总次数。在这个例子中,模型权重将被更新3000次。
Number of trainable parameters = 1,949,696: 这是模型中可以通过训练改变的参数的数量。深度学习模型的性能通常与其可训练参数的数量有关。但是,更多的参数并不总是意味着更好的性能,因为过多的参数可能导致过拟合,即模型过于复杂,不能很好地泛化到训练集之外的新数据。
模型配置信息
"add_bias_linear": false,"add_qkv_bias": true,"apply_query_key_layer_scaling": true,"apply_residual_connection_post_layernorm": false,"architectures": ["ChatGLMModel"],"attention_dropout": 0.0,"attention_softmax_in_fp32": true,"auto_map": {"AutoConfig": "configuration_chatglm.ChatGLMConfig","AutoModel": "modeling_chatglm.ChatGLMForConditionalGeneration","AutoModelForSeq2SeqLM": "modeling_chatglm.ChatGLMForConditionalGeneration"},"bias_dropout_fusion": true,"eos_token_id": 2,"ffn_hidden_size": 13696,"fp32_residual_connection": false,"hidden_dropout": 0.0,"hidden_size": 4096,"kv_channels": 128,"layernorm_epsilon": 1e-05,"model_type": "chatglm","multi_query_attention": true,"multi_query_group_num": 2,"num_attention_heads": 32,"num_layers": 28,"original_rope": true,"pad_token_id": 2,"padded_vocab_size": 65024,"post_layer_norm": true,"quantization_bit": 0,"rmsnorm": true,"seq_length": 32768,"tie_word_embeddings": false,"torch_dtype": "float16","transformers_version": "4.30.2","use_cache": true
这些参数是深度学习模型配置的详细设置,特别是对于ChatGLM的模型。以下是每个参数的含义:
“add_bias_linear”: false: 表示是否在线性层中添加偏置项。
“add_qkv_bias”: true: 表示是否在注意力机制的查询(Q)、键(K)和值(V)计算中添加偏置项。
“apply_query_key_layer_scaling”: true: 表示是否对注意力机制中的查询和键进行缩放处理。
“apply_residual_connection_post_layernorm”: false: 表示是否在层归一化后应用残差连接。
“architectures”: [“ChatGLMModel”]: 表示该配置用于的模型架构。
“attention_dropout”: 0.0: 表示在注意力计算中应用的dropout的比率。Dropout是一种防止模型过拟合的技术。
“attention_softmax_in_fp32”: true: 表示是否在单精度浮点格式(FP32)中执行注意力机制的Softmax计算。
“auto_map”: 这部分将自动配置,模型映射到ChatGLM的配置和模型。
“bias_dropout_fusion”: true: 表示是否融合偏置和dropout。这通常用于优化和提高训练速度。
“eos_token_id”: 2: 定义结束符(End of Sentence)的标识符。
“ffn_hidden_size”: 13696: 表示前馈神经网络(Feedforward Neural Network,FFN)的隐藏层的大小。
“fp32_residual_connection”: false: 表示是否在单精度浮点格式(FP32)中应用残差连接。
“hidden_dropout”: 0.0: 隐藏层的dropout率。
“hidden_size”: 4096: 隐藏层的大小。
“kv_channels”: 128: 键值(Key-Value)的通道数。
“layernorm_epsilon”: 1e-05: 层归一化的epsilon值,为了防止除数为零。
“model_type”: “chatglm”: 模型类型。
“multi_query_attention”: true: 表示是否使用多查询注意力。
“multi_query_group_num”: 2: 在多查询注意力中的查询组数。
“num_attention_heads”: 32: 注意力机制的头数。
“num_layers”: 28: 模型的层数。
“original_rope”: true: 是否使用原始的ROPE模式。
“pad_token_id”: 2: 定义填充符的标识符。
“padded_vocab_size”: 65024: 表示经过填充后的词汇表大小。
“post_layer_norm”: true: 是否在层后应用层归一化。
“quantization_bit”: 0: 表示量化的位数。
“rmsnorm”: true: 表示是否使用RMS归一化。
“seq_length”: 32768: 序列长度。
“tie_word_embeddings”: false: 是否绑定输入和输出的词嵌入。
“torch_dtype”: “float16”: 使用的数据类型,这里是半精度浮点数。
“transformers_version”: “4.30.2”: 使用的Transformers库版本。
“use_cache”: true: 是否使用缓存以加快计算速度。
请注意,上述是训练开始前打印信息的解释,针对ChatGLM模型配置的解释。
附录:训练正常运行打印信息
执行以下指令进行训练:
./train.sh
当出现以下信息后,模型训练迭代开始后屏幕输出如下:
{'loss': 3.0614, 'learning_rate': 0.018000000000000002, 'epoch': 4.21}
{'loss': 2.2158, 'learning_rate': 0.016, 'epoch': 8.42}
训练完成后,屏幕将打印信息:
***** train metrics *****epoch = xxtrain_loss = xxtrain_runtime = xxtrain_samples = xxtrain_samples_per_second = xxtrain_steps_per_second = xx
End
相关文章:
GPT实战系列-如何用自己数据微调ChatGLM2模型训练
GPT实战系列-ChatGLM2部署Ubuntu+Cuda11+显存24G实战方案
GPT实战系列-Baichuan2本地化部署实战方案
相关文章:
GPT实战系列-ChatGLM2模型的微调训练参数解读
GPT实战系列-ChatGLM2模型的微调训练参数解读 目录 GPT实战系列-ChatGLM2模型的微调训练参数解读ChatGLM2模型1、P-Tuning模型微调2、微调训练配置参数train.sh中配置参数训练配置信息模型配置信息附录:训练正常运行打印信息 ChatGLM2模型 ChatGLM-6B是开源的文本生…...
RabbitMQ入门到实战教程,消息队列实战,改造配置MQ
RabbitMQ入门到实战教程,MQ消息中间件,消息队列实战-CSDN博客 3.7.Topic交换机 3.7.1.说明 Topic类型的Exchange与Direct相比,都是可以根据RoutingKey把消息路由到不同的队列。 只不过Topic类型Exchange可以让队列在绑定BindingKey 的时候…...
phar反序列化学习
PHP反序列化常见的是使用unserilize()进行反序列化,除此之外还有其它的反序列化方法,不需要用到unserilize()。就是用到phar反序列化。 Phar phar文件 Phar是将php文件打包而成的一种压缩文档,类似于Java中的jar包。它有一个特性就是phar文…...
十年回望 -- JAVA
十年 十年时间,弹指一挥,好像一直都是在为工作奔波,匆匆忙忙的十年。 一、个人介绍 本人毕业于一所很普通的公办专科院校(全日制统招大专),专业是软件技术,当初能进入计算机这一行业࿰…...
Linux 环境下 安装 Elasticsearch 7.13.2
Linux 环境下 安装 Elasticsearch 7.13.2 前言镜像下载(国内镜像地址)解压安装包修改配置文件用 Es 自带Jdk 运行配置 Es 可被远程访问然后启动接着启动本地测试一下能不能连 Es 前言 借公司的 centos 7 服务器,搭建一个 Es,正好熟…...
心理咨询预约小程序
随着微信小程序的日益普及,越来越多的人开始关注如何利用小程序来提供便捷的服务。对于心理咨询行业来说,搭建一个心理咨询预约小程序可以大大提高服务的效率和用户体验。本文以乔拓云平台为例,详细介绍如何轻松搭建一个心理咨询预约小程序。…...
常用排序算法的理解
1.插入排序 插入排序的思想是将一个记录插入到已经排好序的有序表中,从而形成一个新的、记录数加1的有序表。在其实现过程使用双层循环,外层循环是进行插入的次数(也可以理解为比较的轮数),内层循环是当前记录查找插入…...
Python小程序 - 文件解析
1. 目录下文件解析:特定文件、文件列表、文件数 Windows文件目录分格使用“ / ” 或 “ \\ ”文件目录路径包含空格的,绝对路径使用“双引号”,保证文件路径的可识别性保存和读取结果时,使用 encodingUTF-8可以添加对文件目录的过…...
.mxdown-V-XXXXXXXX勒索病毒的最新威胁:如何恢复您的数据?
导言: 在数字时代,网络安全威胁层出不穷,其中.mxdown-V-XXXXXXXX、.vollhavhelp-V-XXXXXXXX、.arricklu-V-XXXXXXXX勒索病毒已成为备受关注的问题。这种病毒以其高级加密技术和威胁勒索金的方式,严重危害用户和企业的数据安全。本…...
audio 标签动态src 且src是http无法播放问题
<audioref"audio" :src"src"alt"加载失败"controls/>src是动态传参的 无法播放因为动态src需要在赋值后对audio进行重载 this.$refs.audio.load()注意如果,src跟本项目地址IP端口协议不同,会出现跨域问题。audio标…...
Leetcode—485.最大连续1的个数【中等】明天修改
2023每日刷题(十五) Leetcode—2.两数相加 迭代法实现代码 /*** Definition for singly-linked list.* struct ListNode {* int val;* struct ListNode *next;* };*/ struct ListNode* addTwoNumbers(struct ListNode* l1, struct ListNode* l…...
JavaWeb 怎么在servlet向页面输出Html元素?
service()方法里面的方法体: resp.setContentType("text/html;charsetutf-8");//获得输出流PrintWriter对象PrintWriter outresp.getWriter();out.println("<html>");out.println("<head><title>a servlet</title>…...
Spring及SpringBoot中AOP的使用
Spring中AOP示例 <dependencies><!--Spring核心包--><dependency><groupId>org.springframework</groupId><artifactId>spring-core</artifactId><version>5.3.6</version></dependency><!--引入SpringBean--&…...
cmake多目录构建初步成功
目录和代码和 首次cmake 多目录构建失败 此文一样; 只有一个CMakeLists.txt; cmake_minimum_required(VERSION 3.10) project(mytest3 VERSION 1.0) include_directories("${PROJECT_SOURCE_DIR}/include") add_executable(mytest3 src/main…...
idea插件(一)-- SequenceDiagram(UML自动生成工具)
目录 1. 安装 2. 默认快捷键 3. 操作说明 4. 导出为图片与UML类图 4.1 导出为图片: 4.2 导出 UML 类图 SequenceDiagram是从java、kotlin、scala(Beta)和groovy(limited)代码生成简单序列图(UML&…...
STM32 APP跳转到Bootloader
stm32 app跳转到bootloade 【STM32】串口IAP功能的实现,BootLoader与App相互跳转 STM32 从APP跳入BootLoader问题...
[RISC-V]verilog
小明教IC-1天学会verilog(7)_哔哩哔哩_bilibili task不可综合,function可以综合...
Log4j-tag丢失
一、引言 最近有个线上日志丢失tag的问题,是组内封装了后置请求的拦截器把请求的响应结果存到ClickHouse里面去,但是日志总有一些tag丢失。 作者提出父级线程的threadlocal被清空,同事认为可能是threadlocal的弱引用在gc的时候被回收。两种想…...
代码随想录算法训练营第五十六天|1143.最长公共子序列 ● 1035.不相交的线 ● 53. 最大子序和 动态规划
1143. 最长公共子序列 int longestCommonSubsequence(char * text1, char * text2){int len1 strlen(text1);int len2 strlen(text2);int dp[len11][len21];for (int i 0; i < len1; i){for (int j 0; j < len2; j){dp[i][j] 0;}}for (int i 1; i < len1; i){f…...
虚拟机和Windows的文件传输
拖拽/复制粘贴 直接将虚拟机linux系统的文件拖曳到windows桌面,或者直接将windows的文件拖曳到虚拟机linux系统当中,可以实现文件传输。当然复制粘贴方式也可以,但是前提是需要下载安装好VMware tools。 共享文件夹 概念:在Win…...
服务器硬防的应用场景都有哪些?
服务器硬防是指一种通过硬件设备层面的安全措施来防御服务器系统受到网络攻击的方式,避免服务器受到各种恶意攻击和网络威胁,那么,服务器硬防通常都会应用在哪些场景当中呢? 硬防服务器中一般会配备入侵检测系统和预防系统&#x…...
测试markdown--肇兴
day1: 1、去程:7:04 --11:32高铁 高铁右转上售票大厅2楼,穿过候车厅下一楼,上大巴车 ¥10/人 **2、到达:**12点多到达寨子,买门票,美团/抖音:¥78人 3、中饭&a…...
渲染学进阶内容——模型
最近在写模组的时候发现渲染器里面离不开模型的定义,在渲染的第二篇文章中简单的讲解了一下关于模型部分的内容,其实不管是方块还是方块实体,都离不开模型的内容 🧱 一、CubeListBuilder 功能解析 CubeListBuilder 是 Minecraft Java 版模型系统的核心构建器,用于动态创…...
【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验
系列回顾: 在上一篇中,我们成功地为应用集成了数据库,并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了!但是,如果你仔细审视那些 API,会发现它们还很“粗糙”:有…...
Axios请求超时重发机制
Axios 超时重新请求实现方案 在 Axios 中实现超时重新请求可以通过以下几种方式: 1. 使用拦截器实现自动重试 import axios from axios;// 创建axios实例 const instance axios.create();// 设置超时时间 instance.defaults.timeout 5000;// 最大重试次数 cons…...
【C++从零实现Json-Rpc框架】第六弹 —— 服务端模块划分
一、项目背景回顾 前五弹完成了Json-Rpc协议解析、请求处理、客户端调用等基础模块搭建。 本弹重点聚焦于服务端的模块划分与架构设计,提升代码结构的可维护性与扩展性。 二、服务端模块设计目标 高内聚低耦合:各模块职责清晰,便于独立开发…...
C# 求圆面积的程序(Program to find area of a circle)
给定半径r,求圆的面积。圆的面积应精确到小数点后5位。 例子: 输入:r 5 输出:78.53982 解释:由于面积 PI * r * r 3.14159265358979323846 * 5 * 5 78.53982,因为我们只保留小数点后 5 位数字。 输…...
力扣-35.搜索插入位置
题目描述 给定一个排序数组和一个目标值,在数组中找到目标值,并返回其索引。如果目标值不存在于数组中,返回它将会被按顺序插入的位置。 请必须使用时间复杂度为 O(log n) 的算法。 class Solution {public int searchInsert(int[] nums, …...
A2A JS SDK 完整教程:快速入门指南
目录 什么是 A2A JS SDK?A2A JS 安装与设置A2A JS 核心概念创建你的第一个 A2A JS 代理A2A JS 服务端开发A2A JS 客户端使用A2A JS 高级特性A2A JS 最佳实践A2A JS 故障排除 什么是 A2A JS SDK? A2A JS SDK 是一个专为 JavaScript/TypeScript 开发者设计的强大库ÿ…...
Java毕业设计:WML信息查询与后端信息发布系统开发
JAVAWML信息查询与后端信息发布系统实现 一、系统概述 本系统基于Java和WML(无线标记语言)技术开发,实现了移动设备上的信息查询与后端信息发布功能。系统采用B/S架构,服务器端使用Java Servlet处理请求,数据库采用MySQL存储信息࿰…...
