SwissArmyTransformer瑞士军刀工具箱使用手册
Introduction sat(SwissArmyTransformer)是一个灵活而强大的库,用于开发您自己的Transformer变体。
sat是以“瑞士军刀”命名的,这意味着所有型号(例如BERT、GPT、T5、GLM、CogView、ViT…)共享相同的backone代码,并通过一些超轻量级的mixin满足多种用途。
sat由deepspeed ZeRO和模型并行性提供支持,旨在为大模型(100M\~20B参数)的预训练和微调提供最佳实践。
从 SwissArmyTransformer 0.2.x 迁移到 0.3.x
- 导入时将包名称从 SwissArmyTransformer 更改为 sat,例如从 sat 导入 get_args。
- 删除脚本中的所有--sandwich-ln,使用layernorm-order='sandwich'。
- 更改顺序 from_pretrained(args, name) => from_pretrained(name, args)。
- 我们可以直接使用 from sat.model import AutoModel;model, args = AutoModel.from_pretrained('roberta-base') 以 仅模型模式 加载模型,而不是先初始化 sat。
安装
pip install SwissArmyTransformer
特征
添加与模型无关的组件,例如前缀调整,只需一行!
前缀调整(或 P 调整)通过在每个注意力层中添加可训练参数来改进微调。使用我们的库可以轻松地将其应用于 GLM 分类(或任何其他)模型。
class ClassificationModel(GLMModel): # can also be BertModel, RobertaModel, etc. def __init__(self, args, transformer=None, **kwargs):super().__init__(args, transformer=transformer, **kwargs)self.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))# Arm an arbitrary model with Prefix-tuning with this line!self.add_mixin('prefix-tuning', PrefixTuningMixin(args.num_layers, args.hidden_size // args.num_attention_heads, args.num_attention_heads, args.prefix_len))
GPT 和其他自回归模型在训练和推理过程中的行为有所不同。在推理过程中,文本是逐个令牌生成的,我们需要缓存以前的状态以提高效率。使用我们的库,您只需要考虑训练期间的行为(教师强制),并通过添加 mixin 将其转换为缓存的自回归模型:
model, args = AutoModel.from_pretrained('glm-10b-chinese', args)
model.add_mixin('auto-regressive', CachedAutoregressiveMixin())
# Generate a sequence with beam search
from sat.generation.autoregressive_sampling import filling_sequence
from sat.generation.sampling_strategies import BeamSearchStrategy
output, *mems = filling_sequence(model, input_seq,batch_size=args.batch_size,strategy=BeamSearchStrategy(args.batch_size))
使用最少的代码构建基于 Transformer 的模型。我们提到了 GLM,它与标准转换器(称为 BaseModel)仅在位置嵌入(和训练损失)上有所不同。我们在编码的时候只需要关注相关的部分就可以了。
扩展整个定义:
class BlockPositionEmbeddingMixin(BaseMixin):# Here define parameters for the mixindef __init__(self, max_sequence_length, hidden_size, init_method_std=0.02):super(BlockPositionEmbeddingMixin, self).__init__()self.max_sequence_length = max_sequence_lengthself.hidden_size = hidden_sizeself.block_position_embeddings = torch.nn.Embedding(max_sequence_length, hidden_size)torch.nn.init.normal_(self.block_position_embeddings.weight, mean=0.0, std=init_method_std)# Here define the method for the mixindef position_embedding_forward(self, position_ids, **kwargs):position_ids, block_position_ids = position_ids[:, 0], position_ids[:, 1]position_embeddings = self.transformer.position_embeddings(position_ids)block_position_embeddings = self.block_position_embeddings(block_position_ids)return position_embeddings + block_position_embeddingsclass GLMModel(BaseModel):def __init__(self, args, transformer=None, parallel_output=True):super().__init__(args, transformer=transformer, parallel_output=parallel_output)self.add_mixin('block_position_embedding', BlockPositionEmbeddingMixin(args.max_sequence_length, args.hidden_size)) # Add the mixin for GLM
全方位的培训支持。 sat 旨在提供预训练和微调的最佳实践,您只需要完成forward_step 和 create_dataset_function,但可以使用超参数来更改有用的训练配置。
通过指定 --num_nodes、--num_gpus 和一个简单的主机文件,将训练扩展到多个 GPU 或节点。
DeepSpeed 和模型并行性。
ZeRO-2 和激活检查点的更好集成。
自动扩展和改组训练数据和内存映射。
成功支持CogView2和CogVideo的训练。
目前唯一支持在 GPU 上微调 T5-10B 的开源代码库。
快速浏览
在 sat 中使用 Bert(用于推理)的最典型的 python 文件如下:
# @File: inference_bert.py
from sat import get_args, get_tokenizer, AutoModel
# Parse args, initialize the environment. This is necessary.
args = get_args()
# Automatically download and load model. Will also dump model-related hyperparameters to args.
model, args = AutoModel.from_pretrained('bert-base-uncased', args)
# Get the BertTokenizer according to args.tokenizer_type (automatically set).
tokenizer = get_tokenizer(args)
# Here to use bert as you want!
# ...
然后我们可以通过以下方式运行代码
SAT_HOME=/path/to/download python inference_bert.py --mode inference
所有官方支持的模型名称都在 urls.py 中。
# @File: finetune_bert.py
from sat import get_args, get_tokenizer, AutoModel
from sat.model.mixins import MLPHeadMixindef create_dataset_function(path, args):# Here to load the dataset# ...assert isinstance(dataset, torch.utils.data.Dataset)return datasetdef forward_step(data_iterator, model, args, timers):inputs = next(data_iterator) # from the dataset of create_dataset_function.loss, *others = model(inputs)return loss# Parse args, initialize the environment. This is necessary.
args = get_args()
model, args = AutoModel.from_pretrained('bert-base-uncased', args)
tokenizer = get_tokenizer(args)
# Here to use bert as you want!
model.del_mixin('bert-final')
model.add_mixin('classification_head', MLPHeadMixin(args.hidden_size, 2048, 1))
# ONE LINE to train!
# args already includes hyperparams such as lr, train-iters, zero-stage ...
training_main(args, model_cls=model, forward_step_function=forward_step, # user definecreate_dataset_function=create_dataset_function # user define
)
然后我们可以通过以下方式运行代码
deepspeed --include localhost:0,1 finetune_bert.py \--experiment-name ftbert \--mode finetune --train-iters 1000 --save /path/to/save \--train-data /path/to/train --valid-data /path/to/valid \--lr 0.00002 --batch-size 8 --zero-stage 1 --fp16
这里我们在 GPU 0,1 上使用数据并行。我们还可以通过 --hostfile/path/to/hostfile 在许多互连的机器上启动训练。请参阅教程了解更多详细信息。
要编写自己的模型,您只需要考虑与标准 Transformer 的差异。例如,如果你有一个改进注意力操作的想法:
from sat.model import BaseMixin
class MyAttention(BaseMixin):def __init__(self, hidden_size):super(MyAttention, self).__init__()# MyAttention may needs some new params, e.g. a learnable alpha.self.learnable_alpha = torch.nn.Parameter(torch.ones(hidden_size))# This is a hook function, the name `attention_fn` is special.def attention_fn(q, k, v, mask, dropout=None, **kwargs):# Code for my attention.# ...return attention_results
这里的attention_fn是一个钩子函数,用新函数替换默认动作。所有可用的钩子都在transformer_defaults.py中。现在我们可以使用 add_mixin 将更改应用到所有转换器,例如 BERT、Vit 和 CogView。请参阅教程了解更多详细信息。
教程
- How to use pretrained models collected in sat?
- Why and how to train models in sat?
Citation
Currently we don't have a paper, so you don't need to formally cite us!~
If this project helps your research or engineering, use \footnote{https://github.com/THUDM/SwissArmyTransformer} to mention us and recommend SwissArmyTransformer to others.
The tutorial for contributing sat is on the way!
The project is based on (a user of) DeepSpeed, Megatron-LM and Huggingface transformers. Thanks for their awesome work.
训练指导
The Training API
我们提供了一个简单但功能强大的训练APItraining_main(),它不仅限于我们的Transformer模型,还适用于任何torch.nn.Module。
from sat import get_args, training_main
from sat.model import AutoModel, BaseModel
args = get_args()
# to pretrain from scratch, give a class obj
model = BaseModel
# to finetuned from a given model, give a torch.nn.Module
model = AutoModel.from_pretrained('bert-base-uncased', args)training_main(args, model_cls=model,forward_step_function=forward_step,create_dataset_function=dataset_func,handle_metrics_function=None,init_function=None
)
以上是使用 sat 的标准训练计划的(不完整)示例。 Training_main 接受 5 个参数:(必需)model_cls:继承 torch.nn.Module 的类型对象,或我们训练的 torch.nn.Module 对象。
(必需)forward_step_function:一个自定义函数,输入 data_iterator、model、args、timers、returns loss、{'metric0': m0, ...}。
(必填)create_dataset_function:返回一个torch.utils.data.Dataset用于加载。我们的库会自动将数据分配给多个worker,并将数据迭代器交给forward_step_function。
(可选)handle_metrics_function:在评估过程中处理特殊指标。
(可选)init_function:在训练之前更改模型的钩子,对于继续训练很有用。
有关完整示例,请参阅 Finetune BERT 示例。
相关文章:
SwissArmyTransformer瑞士军刀工具箱使用手册
Introduction sat(SwissArmyTransformer)是一个灵活而强大的库,用于开发您自己的Transformer变体。 sat是以“瑞士军刀”命名的,这意味着所有型号(例如BERT、GPT、T5、GLM、CogView、ViT…)共享相同的backo…...
unity【动画】脚本_角色动画控制器 c#
首先创建一个代码文件夹Scripts 从人物角色Player的基类开始 创建IPlayer类 首先我们考虑到如果不挂载MonoBehaviour需要将角色设置成预制体实例化到场景上十分麻烦, 所以我们采用继承MonoBehaviour类的角色基类方法写代码 也就是说这个脚本直接绑定在角色物体…...
Java代码如何对Excel文件进行zip压缩
1:新建 ZipUtils 工具类 package com.ly.cloud.datacollection.util;import java.io.File; import java.io.FileInputStream; import java.io.FileOutputStream; import java.io.IOException; import java.io.OutputStream; import java.net.URLEncoder; import ja…...
改进YOLO系列:12.Repulsion损失函数【遮挡】
1. RepLoss论文 物体遮挡问题可以分为类内遮挡和类间遮挡两种情况。类间遮挡产生于扎堆的同类物体,也被称为密集遮挡(crowd occlusion)。Repulsion损失函数由三个部分构成,yolov5样本匹配,得到的目标框和预测框-一对应第一部分主要作用:预测目标框吸引IOU最大的真实目标框,…...
win11网络连接正常,但是无法正常上网
前言: 这个是一个win11的bug,好多人都遇到了,在孜孜不倦的百度下,毫无收获,终于是在抖音上看到有人分享的经验而解决了这个问题。 找到internet选项,然后点击打开 选择连接 将代理服务器中,为…...
硬科技企业社区“曲率引擎”品牌正式发布
“曲率引擎”,是科幻作品中最硬核的加速系统,通过改变时空的曲率,可实现光速飞行甚至能够超越光速。11月3日,“曲率引擎(warp drive)”作为硬科技企业社区品牌,在2023全球硬科技创新大会上正式对…...
少儿编程 2023年9月中国电子学会图形化编程等级考试Scratch编程三级真题解析(判断题)
2023年9月scratch编程等级考试三级真题 判断题(共10题,每题2分,共20分) 19、运行程序后,“我的变量”的值为25 答案:对 考点分析:考查积木综合使用,重点考查变量和运算积木的使用 开始我的变量为50,执行完第二行代码我的变量变为49,条件不成立执行否则语句,所以…...
MCU常见通信总线串讲(二)—— RS232和RS485
🙌秋名山码民的主页 😂oi退役选手,Java、大数据、单片机、IoT均有所涉猎,热爱技术,技术无罪 🎉欢迎关注🔎点赞👍收藏⭐️留言📝 获取源码,添加WX 目录 前言一…...
LazyVim: 将 Neovim 升级为完整 IDE | 开源日报 No.67
curl/curl Stars: 31.5k License: NOASSERTION Curl 是一个命令行工具,用于通过 URL 语法传输数据。 核心优势和关键特点包括: 可在命令行中方便地进行数据传输支持多种协议 (HTTP、FTP 等)提供丰富的选项和参数来满足不同需求 kubernetes/ingress-n…...
想要搭建网站帮助中心,看这一篇指南就对了!
在现今互联网时代,除了让用户了解产品的功能和一些操作,很多企业都需要在网上进行信息的发布和产品销售等业务活动。而这就需要一个帮助中心,在用户遇到问题或者需要了解更多信息的时候,能够快速地解答他们的疑惑和提供响应的帮助…...
92.更新一些收藏的经验贴总结学习
一、JS相关 1.进制转换 (1)十进制转二进制 十进制数除2取余法:十进制数除2,余数为权位上的数,得到的商继续除2,直到商为0。最后余数从下往上取值。 (2)二进制转十进制 把二进制…...
mysql 问题解决 4
7、集群 7.1 日志 1、MySQL 中有哪些常见日志? MySQL 中有以下常见的日志类型: 错误日志(Error Log):记录 MySQL 服务器在运行过程中出现的错误信息。通用查询日志(General Query Log):记录所有连接到 MySQL 服务器的 SQL 查询语句。慢查询日志(Slow Query Log):…...
llama-7B、vicuna-7b-delta-v1.1和vicuna-7b-v1.3——使用体验
Chatgpt的出现给NLP领域带来了让人振奋的消息,可以很逼真的模拟人的对话,回答人们提出的问题,不过Chatgpt参数量,规模,训练代价都很昂贵。 幸运的是,出现了开源的一些相对小的模型,可以在本地或…...
深入理解JVM虚拟机第十九篇:JVM字节码中方法内部的结构和与局部变量表中变量槽的介绍
大神链接:作者有幸结识技术大神孙哥为好友,获益匪浅。现在把孙哥视频分享给大家。 孙哥链接:孙哥个人主页 作者简介:一个颜值99分,只比孙哥差一点的程序员 本专栏简介:话不多说,让我们一起干翻JVM 本文章简介:话不多说,让我们讲清楚虚拟机栈存储结构和运行原理 文章目…...
windows好玩的cmd命令
颜色 后边的数字查表吧,反正我是喜欢一个随机的数字 color 01MAC getmac /v更新主机IP地址 通过DHCP更新 ipconfig /release ipconfig /renew改标题 title code with 你想要的标题...
线扫相机DALSA--常见问题四:修改相机参数,参数保存无效情况
该问题是操作不当,未按照正常步骤保存参数所致,相机为RAM机制,参数需保存在采集卡的ROM内。 保存参数步骤: ①首先将相机参数保存至User Set1; ②然后回到Board(采集卡)参数设置区,鼠标选中Basic Timing&a…...
linux中用date命令获取昨天、明天或多天前后的日期
在实际操作中,一些脚本中会调用明天,或者昨天,或更多天前的日期,本文将叙述讲述用date命令实现时间的显示。在Linux系统中用man date -d 查询的参数说的比较模糊,以下举例进一步说明: # man date -d, --da…...
【无标题】360压缩软件怎么用?超级好用!
360压缩是一款功能强大的解压缩软件,如何用它压缩文件呢?下面给出了详细的操作步骤。 一、360压缩详细步骤 1、下载软件后,在电脑上右击需要压缩的文件,在弹出的菜单中点击【添加到压缩文件】选项。 2、在360压缩窗口中按需设置相…...
一图搞懂傅里叶变换(FT)、DTFT、DFS和DFT之间的关系
自然界中的信号都是模拟信号,计算机无法处理,因此我们会基于奈奎斯特定理对模拟信号采样得到数字信号。 但是我们发现,即便是经过采样,在时域上得到了数字信号,而在频域上还是连续信号。 因此我们可以在时域中选取N点…...
行情分析——加密货币市场大盘走势(11.7)
大饼昨日下跌过后开始有回调的迹象,现在还是在做指标修复,大饼的策略保持逢低做多。稳健的依然是不碰,目前涨不上去,跌不下来。 以太昨天给的策略,依然有效,现在以太坊开始回调。 目前来看,回踩…...
华为云AI开发平台ModelArts
华为云ModelArts:重塑AI开发流程的“智能引擎”与“创新加速器”! 在人工智能浪潮席卷全球的2025年,企业拥抱AI的意愿空前高涨,但技术门槛高、流程复杂、资源投入巨大的现实,却让许多创新构想止步于实验室。数据科学家…...
在软件开发中正确使用MySQL日期时间类型的深度解析
在日常软件开发场景中,时间信息的存储是底层且核心的需求。从金融交易的精确记账时间、用户操作的行为日志,到供应链系统的物流节点时间戳,时间数据的准确性直接决定业务逻辑的可靠性。MySQL作为主流关系型数据库,其日期时间类型的…...
大话软工笔记—需求分析概述
需求分析,就是要对需求调研收集到的资料信息逐个地进行拆分、研究,从大量的不确定“需求”中确定出哪些需求最终要转换为确定的“功能需求”。 需求分析的作用非常重要,后续设计的依据主要来自于需求分析的成果,包括: 项目的目的…...
进程地址空间(比特课总结)
一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...
python/java环境配置
环境变量放一起 python: 1.首先下载Python Python下载地址:Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个,然后自定义,全选 可以把前4个选上 3.环境配置 1)搜高级系统设置 2…...
线程同步:确保多线程程序的安全与高效!
全文目录: 开篇语前序前言第一部分:线程同步的概念与问题1.1 线程同步的概念1.2 线程同步的问题1.3 线程同步的解决方案 第二部分:synchronized关键字的使用2.1 使用 synchronized修饰方法2.2 使用 synchronized修饰代码块 第三部分ÿ…...
【机器视觉】单目测距——运动结构恢复
ps:图是随便找的,为了凑个封面 前言 在前面对光流法进行进一步改进,希望将2D光流推广至3D场景流时,发现2D转3D过程中存在尺度歧义问题,需要补全摄像头拍摄图像中缺失的深度信息,否则解空间不收敛…...
CocosCreator 之 JavaScript/TypeScript和Java的相互交互
引擎版本: 3.8.1 语言: JavaScript/TypeScript、C、Java 环境:Window 参考:Java原生反射机制 您好,我是鹤九日! 回顾 在上篇文章中:CocosCreator Android项目接入UnityAds 广告SDK。 我们简单讲…...
C# 类和继承(抽象类)
抽象类 抽象类是指设计为被继承的类。抽象类只能被用作其他类的基类。 不能创建抽象类的实例。抽象类使用abstract修饰符声明。 抽象类可以包含抽象成员或普通的非抽象成员。抽象类的成员可以是抽象成员和普通带 实现的成员的任意组合。抽象类自己可以派生自另一个抽象类。例…...
Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...
