当前位置: 首页 > news >正文

经典神经网络(13)GPT-1、GPT-2原理及nanoGPT源码分析(GPT-2)

经典神经网络(13)GPT-1、GPT-2原理及nanoGPT源码分析(GPT-2)

  • 2022 年 11 月,ChatGPT 成功面世,成为历史上用户增长最快的消费者应用。与 Google、FaceBook等公司不同,OpenAI 从初代模型 GPT-1 开始,始终贯彻只有解码器(Decoder-only)的技术路径,不断迭代升级。
模型发布日期
GPT2018-11-14
GPT-22019-11-27
GPT-32020-6-11
InstructGPT2022-3-4
ChatGPT2022-11-30
GPT-42023-3-14
ChatGPT Plugin2023-5-12
  • 可以看到,2022 年是 GPT 系列模型围绕 GPT-3、GPT-3.5 加速版本迭代的年份;

    • 2022 年 3 月,基于 GPT-3 微调的 InstructGPT 发布,验证了人类反馈强化学习RLHF对模型输出对齐(alignment)的重要作用;
    • 2022年4-6月,基于Codex、InstructGPT,OpenAI 加速迭代形成 GPT-3.5 模型;
    • 2022 年 11 月,基于 GPT-3.5 微调的 ChatGPT 发布,成为 Instruction-tuning、RLHF、思维链等 LLM 相关技术的集大成者。
      • ChatGPT与InstructGPT的训练方法基本一致,区别在于InstructGPT、ChatGPT分别基于GPT-3、GPT-3.5进行模型微调。
      • InstructGPT具体可分为有监督微调、奖励模型训练、PPO 强化学习三个步骤。
      • ChatGPT技术原理解析可参考:https://blog.csdn.net/v_JULY_v/article/details/128579457
    • 2023年3月中旬,OpenAI正式对外发布GPT-4,增加了多模态(支持图片的输入形式),且ChatGPT底层的语言模型直接从GPT3.5升级到了GPT4。
  • ChatGPT 的发展不仅得益于 GPT 模型参数、训练数据的持续优化,也得益于各类 LLM 新技术的融会贯通,OpenAI 博采众长,加速 Instruction-tuning、RLHF、思维链等新技术在 GPT 系列模型中的深度应用,ChatGPT 是现有技术的集大成者。

  • 我们今天主要了解下GPT-2。

1 GPT-1简介

  • 2017年,Google推出了Transformer模型,这一架构因其在性能上的显著优势迅速吸引了OpenAI团队的注意。OpenAI随后将研发重点转移到Transformer架构,并在2018年发布了GPT-1模型。

  • OpenAI在2018年提出了GPT(Generative Pre-training)生成式预训练模型,采用了仅有解码器的Transformer模型,专注于预测下一个词元。

  • GPT采用了transformer的Decoder作为框架,并采用了两阶段的训练方式。首先,在大量的无标记数据集中,进行生成式训练(Generative Pre-training);然后,在在特定任务进行微调(fine-tuning)。

  • 论文:language_understanding_paper (openai.com)

1.1 GPT-1网络结构

  • 下图左半部分是transformer架构图,右半部分是GPT的架构图。与GPT相比,transformer的Decoder除了中间被隐去的cross-attention,其余部分结构相同,GPT由12个Decoder串联而成。
    • GPT的输入部分与transformer的Encoder的输入部分相同,与transformer不同的是位置向量采用随机初始化,并在训练中进行更新;
    • GPT的Decoder可以看做是 transformer 的 Decoder 去掉中间Cross-Attention后的结构。
    • Decoder主要有三个子模块组成:Masked Multi-Head Attention、残差网络&LayerNorm 以及 Feed Forward。
  • GPT的输出部分对应于下游任务,不同的任务使用不同的全连接层作为输出。

在这里插入图片描述

1.2 两阶段训练

GPT的训练过程分为两个阶段,第一阶段是采用大量文本预料进行无监督训练,第二个阶段是采用少量有标注的数据进行有监督的微调。

1.2.1 无监督预训练

  • GPT采用标准的语言模型进行无监督训练,即通过上文前 k k k个词来预测当前词
  • 预训练时只有Text Prediction,没有Task Classifier

在这里插入图片描述

预训练模型的参数设置

  • GPT-1使用BooksCorpus数据集(约5GB)来训练语言模型。BooksCorpus有大约7000本未出版的书籍,这些书籍帮助在不可见的数据上训练语言模型。
  • 采用L=12层decoder,每个自注意层有A=12个注意头
  • 使用了带有40,000个合并的字节对编码(BPE)词汇表
  • 分词嵌入维度H为768,位置编码嵌入维度为768,通过模型学习获得位置编码
  • 位置前馈层采用3072维
  • dropout为0.1
  • GELU函数作为激活函数
  • batch_size为64,序列长度为512,epoch为100, 参数共117M,即1.17 亿的参数量。

1.2.2 有监督微调

  • 无监督训练完毕后,并不能直接用于下游任务,需要在具体任务的标注数据集上进行微调(如下图右半部分)。
  • 对于大多数下游任务,监督的微调只需要3个epoch。
  • 对比预训练阶段,只是多增加了“Task Classifier”模块。

当得到无监督的预训练模型后,可以将该模型直接应用到有监督任务中。每个实例有m个输入 x 1 , x 2 , . . , x m x^1,x^2,..,x^m x1,x2,..,xm ,以及标签y组成。首先将这些token输入到预训练的GPT1中,得到最终的特征向量 h l m h_l^m hlm ,然后再通过一个全连接层得到预测结果y:
P ( y ∣ x 1 , . . . , x m ) = s o f t m a x ( h l m W y ) P(y|x^1,...,x^m)=softmax(h_l^mW_y) P(yx1,...,xm)=softmax(hlmWy)
有监督的目标是最大化以下损失:
L 2 ( C ) = ∑ l o g P ( y ∣ x 1 , . . . , x m ) L_2(C)=∑logP(y|x^1,...,x^m) L2(C)=logP(yx1,...,xm)
GPT的实验中发现,加入语言模型学习目标作为辅助任务,也就是损失函数中加入 L 1 ( u ) L_1(u) L1(u)能带来两点好处:不仅能提升监督模型的泛化能力;还能加快收敛;

因此,作者并没有直接使用L2,而是使用联合损失函数来进行微调(Text Prediction【L1】 + Task Classifier【L2】),并使用 λ λ λ 进行两个任务权值的调整, λ λ λ 的值一般为0.5:
L 3 ( C ) = L 2 ( C ) + λ ∗ L 1 ( C ) L_3(C)=L_2(C)+λ*L_1(C) L3(C)=L2(C)+λL1(C)
在这里插入图片描述

  • 文本分类:可以看到,有两个特殊符号(Start和Extract)通过Transformer后,添加了一个线性层。

  • 蕴含理解:给一段话,提出一个假设,看看假设是否成立。

    • 将前提(premise)和假设(hypothesis)通过分隔符(Delimiter)隔开
    • 两端加上起始和终止token
    • 再依次通过Transformer和全连接得到预测结果;
  • 文本相似:断两段文字是不是相似。相似是一个对称关系,A和B相似,那么B和A也是相似的

    • 所以先有Text1+分隔符+Text2,再有Text2+分隔符+Text1;
    • 两个序列分别经过Transformer后,各自得到输出的向量;
    • 我们把它按元素加到一起,然后送给一个线性层。这也是一个二分类问题。
  • 多项选择:多个序列,每个序列都由相同的问题Context和不同的Answer构成。

    • 如果有N个答案,就构造N个序列;
    • 每个QA序列都各自经过Transformers和线性层,对每个答案都计算出一个标量;
    • 最后经过softmax生成一个各个答案的概率密度分布。这是一个N分类问题。

2 GPT-2简介

  • Bert论文:https://arxiv.org/pdf/1810.04805

  • GPT-2论文:Language Models are Unsupervised Multitask Learners (openai.com)

  • 同年10月,谷歌发布了BERT模型。BERT用了Transformer中的Encoder部分,更类似完形填空,根据上下文来确定中间词:

    • 和GPT相比,BERT所使用的掩码语言模型任务(Masked Language Model)虽然让它失去了直接生成文本的能力,但换来的是双向编码的能力,这让模型拥有了更强的文本编码性能,直接的体现则是下游任务效果的大幅提升。
    • 而GPT为了保留生成文本的能力,只能采用单向编码。
  • 如下图所示,为了和GPT进行对比Bert论文作者设计了Base模型,模型参数大小和GPT大致相等。但是BERT数据集的数据量大概是GPT的四倍。
    在这里插入图片描述

  • 对比结果如下图所示,BERT-Large模型在多个NLU任务上取得了显著的性能提升,成为当时自然语言处理领域的明星模型,引领了一波研究热潮。

在这里插入图片描述

  • 以当年的眼光来看,BERT绝对是一个更加优秀的模型。因为既然BERT和GPT两者都是采用「预训练+微调」的范式,并且下游任务依然是分类、匹配、序列标注等等「经典」的NLP任务形式,那么像BERT模型这种更注重特征编码的质量,下游任务选一个合适的损失函数去配合任务做微调,显然比GPT这种以文本生成的方式去「迂回地」完成这些任务更加直接。
  • 如果当时OpenAI放弃生成式预训练这条路,也许我们要等更长的时间才能见到ChatGPT这样的模型。

2.1 GPT-2核心思想

Q:当模型被别人用更大的数据集、参数量(Bert)打败时,应该怎么做?

A:自己也增大数据集和模型参数量。

Q:但是当文本数据集增大,模型参数量增大的基础上,模型的优势没有那么高的情况下应该怎么办?

A:找到另一个卖点(另辟蹊径)——zero shot

如果这篇文章单纯讲结果比Bert好一些,其实大家一看没啥意思,太工程化了,但是换一个角度,做一个更难的问题,但是得到一个好一点的结果,文章新颖性一下子有了。做工程和做研究的区别就是,做工程可以一直死盯着精度,但是做研究需要另辟蹊径找到一个创新点。

来自《李沐GPT、GPT-2、GPT-3论文精读》

GPT-2的核心思想就是,当模型的容量非常大且数据量足够丰富时,仅仅靠语言模型的学习便可以完成其他有监督学习的任务,不需要在下游任务微调。

GPT-2和GPT的区别在于GPT-2使用了更多的网络参数和更大的数据集,以此来训练一个泛化能力更强的词向量模型。GPT-2相比于GPT有如下几点区别:

  • 主推zero-shot,而GPT-1为pre-train+fine-tuning;
  • 模型更大,参数量达到了15亿个,而GPT-1只有1.17亿个;
  • 数据集更大,WebText数据集包含了40GB的文本数据,而GPT-1只有5GB;
    • 没有选择Common Crawl这种具有很多冗余无用信息的项目
    • 选用的是reddit里面已经被人工筛选出的有意义的,并且具有至少3karma值的网页进行数据处理,大概有800万个文本,40GB的文字。
  • 训练参数变化,batch_size从64增加到 512,上文窗口大小从512增加到1024;

在这里插入图片描述

2.2 GPT-2模型结构

GPT-2 提供了四种规模的模型:

在这里插入图片描述

在模型结构方面,整个GPT-2的模型框架与GPT相同,调整更多的是被当作训练时的trick,而不作为GPT-2的创新,具体为以下几点:

  • 后置层归一化(post-norm)改为前置层归一化(pre-norm)
  • 在模型最后一个自注意力层之后,额外增加一个层归一化
  • 调整参数的初始化方式,按残差层个数进行缩放,缩放比例为 1 : sqrt(n)
  • 输入序列的最大长度从512扩充到1024

在这里插入图片描述

2.3 GPT-2的贡献

2.3.1 预训练和zero-shot

  • GPT-2的预训练和GPT基本没什么区别,但是对下游任务用了zero-shot。

  • GPT-2可以在zero-shot设定下实现下游任务,即不需要用有标签的数据进行微调。

    • 为了实现zero-shot,下游任务的输入就不能像GPT那样在构造输入时加入开始、中间和结束的特殊字符,这些是模型在预训练时没有见过的,而是应该和预训练模型看到的文本一样,更像一个自然语言。
    • 可以通过做prompt的方式来zero-shot。
      • 例如:在做句子翻译任务时,训练的句子可以被写为: (translate to french, english text, french text). 其中translate to french在后文叫做prompt也叫做提示,相当于做了一个特殊的提示词。
      • 如果要做阅读理解任务时:可以写作(answer the question, document(阅读的文本), question, answer),answer the question相当于任务提示
    • 为何zero-shot这种方式是有效的呢?从一个尽可能大且多样化的数据集中一定能收集到不同领域不同任务相关的自然语言描述示例,数据集里就存在展示了这些prompt示例,所以训练出来就自然而然有一定zero-shot的能力了。

2.3.2 总结

  • GPT-2的最大贡献是验证了通过海量数据和大量参数训练出来的词向量模型有迁移到其它类别任务中而不需要额外的训练,即zero-shot learning的能力。但是效果其实很一般。

  • GPT-2表明随着模型容量和数据量的增大,其潜能还有进一步开发的空间,基于这个思想,促使了GPT3的出现。

  • GPT-2虽然提出zero-shot,比Bert有新意,但是有效性方面不佳。GPT-3考虑few-shot,用少量文本提升有效性。

3 nanoGPT源码分析(GPT-2)

  • nanoGPT是李飞飞教授的学生,前特斯拉AI总监Andrej Karpathy基于OpenWebText重现 GPT-2 (124M)的开源项目。
  • nanoGPT删繁就简,代码非常简单,cpu也可以跑。无论是从头开始训练新模型,或者微调,都能很容易满足需求。

在这里插入图片描述

  • GPT-2原理解释可参考:https://jalammar.github.io/illustrated-gpt2/

  • nanoGPT源码仓库:https://github.com/karpathy/nanoGPT

  • 首先我们clone下整个仓库,并按照仓库的Install的部分安装需要的package即可

  • 如果我们想使用torch.compile的技术则需要安装2.x版本的PyTorch【torch.compile是PyTorch 2.x版本提出的一个新的技术,可以有效降低我们训练模型的时间】

pip install torch numpy transformers datasets tiktoken wandb tqdm
  • nanoGPT提供了两个案例:
    • 第一个案例是在shakespeare上的构建的字符级别(character-level)的GPT2。这个数据很小,并且训练很快,很适合理解原理。
    • 第二个就是根据原始论文在OpenWebText数据上重现GPT2,但是这个需要在8块A100 40GB机器上训练4天。

3.1 GPT-2模型

3.1.1 LayerNorm

import math
import inspect
from dataclasses import dataclassimport torch
import torch.nn as nn
from torch.nn import functional as Fclass LayerNorm(nn.Module):""" LayerNorm but with an optional bias. PyTorch doesn't support simply bias=False """def __init__(self, ndim, bias):super().__init__()self.weight = nn.Parameter(torch.ones(ndim))self.bias = nn.Parameter(torch.zeros(ndim)) if bias else Nonedef forward(self, input):return F.layer_norm(input, self.weight.shape, self.weight, self.bias, 1e-5)

3.1.2 CausalSelfAttention

class CausalSelfAttention(nn.Module):def __init__(self, config):super().__init__()assert config.n_embd % config.n_head == 0# key, query, value projections for all heads, but in a batchself.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)# output projectionself.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)# regularizationself.attn_dropout = nn.Dropout(config.dropout)self.resid_dropout = nn.Dropout(config.dropout)self.n_head = config.n_headself.n_embd = config.n_embdself.dropout = config.dropout# flash attention make GPU go brrrrr but support is only in PyTorch >= 2.0self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')if not self.flash:print("WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0")# causal mask to ensure that attention is only applied to the left in the input sequence# 目的:使用下三角矩阵屏蔽未来词汇# torch.ones(config.block_size, config.block_size) 创建一个 block_size * block_size的矩阵# torch.tril() 将上三角元素设置为0,下三角仍为1self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size)).view(1, 1, config.block_size, config.block_size))def forward(self, x):B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)# calculate query, key, values for all heads in batch and move head forward to be the batch dimq, k, v  = self.c_attn(x).split(self.n_embd, dim=2)# 多头注意力机制,需要拆头k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2) # (B, nh, T, hs)# causal self-attention; Self-attend: (B, nh, T, hs) x (B, nh, hs, T) -> (B, nh, T, T)if self.flash:# efficient attention using Flash Attention CUDA kernelsy = torch.nn.functional.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=self.dropout if self.training else 0, is_causal=True)else:# manual implementation of attention# QK^T / sqrt(d_k)att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))# 将矩阵的为0的地方填为-inf(即上三角填充为-inf、屏蔽未来词汇)att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))# 得到的softmax以后,矩阵上三角部分的注意力权重 -> 0att = F.softmax(att, dim=-1)# 如果提供了dropout,对注意力权重att进行dropout操作att = self.attn_dropout(att)# 对value进行加权求和y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)# 维度转换y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side# output projectiony = self.resid_dropout(self.c_proj(y))return y

3.1.3 FFN

class MLP(nn.Module):"""前馈神经网络FFN"""def __init__(self, config):super().__init__()# 注意:输出维度为n_embd的4倍self.c_fc    = nn.Linear(config.n_embd, 4 * config.n_embd, bias=config.bias)self.gelu    = nn.GELU()self.c_proj  = nn.Linear(4 * config.n_embd, config.n_embd, bias=config.bias)self.dropout = nn.Dropout(config.dropout)def forward(self, x):x = self.c_fc(x)x = self.gelu(x)x = self.c_proj(x)x = self.dropout(x)return x

3.1.4 组装Block

  • 包含两个子层,一个为Masked Multi-Head Attention,一个为FFN
  • 后置层归一化(post-norm)改为前置层归一化(pre-norm)
class Block(nn.Module):def __init__(self, config):super().__init__()self.ln_1 = LayerNorm(config.n_embd, bias=config.bias)self.attn = CausalSelfAttention(config)self.ln_2 = LayerNorm(config.n_embd, bias=config.bias)self.mlp = MLP(config)def forward(self, x):# 1、第一个子层# Sublayer(LayerNorm(x)) + x,其中Sublayer = Masked Multi-Head Attentionx = x + self.attn(self.ln_1(x))# 2、第二个子层# Sublayer(LayerNorm(x)) + x,其中Sublayer = FFNx = x + self.mlp(self.ln_2(x))return x

3.1.5 封装GPT-2模型

  • 这里主要看下初始化、forward函数、以及generate函数。当然作者还提供了从transformers库加载GPT2模型、配置优化器等函数,可以自己看下。

  • 这里作者把word embedding和lm_head权重进行共享,减少了模型参数

  • c_proj的权重初始化按n_layer层数进行缩放

class GPT(nn.Module):def __init__(self, config):super().__init__()assert config.vocab_size is not Noneassert config.block_size is not Noneself.config = configself.transformer = nn.ModuleDict(dict(wte = nn.Embedding(config.vocab_size, config.n_embd), # word table embeddingwpe = nn.Embedding(config.block_size, config.n_embd), # word position embeddingdrop = nn.Dropout(config.dropout),h = nn.ModuleList([Block(config) for _ in range(config.n_layer)]),ln_f = LayerNorm(config.n_embd, bias=config.bias),))# 定义一个线性层,将模型的输出维度映射到词汇表大小self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)# with weight tying when using torch.compile() some warnings get generated:# "UserWarning: functional_call was passed multiple values for tied weights.# This behavior is deprecated and will be an error in future versions"# not 100% sure what this is, so far seems to be harmless. TODO investigate# word embedding和lm_head权重共享self.transformer.wte.weight = self.lm_head.weight # https://paperswithcode.com/method/weight-tying# init all weightsself.apply(self._init_weights)# apply special scaled init to the residual projections, per GPT-2 paperfor pn, p in self.named_parameters():if pn.endswith('c_proj.weight'):# 初始化方式按n_layer层数进行缩放torch.nn.init.normal_(p, mean=0.0, std=0.02/math.sqrt(2 * config.n_layer))# report number of parameters(参数共享后的参数量)print("number of parameters: %.2fM" % (self.get_num_params()/1e6,))
  • 在模型最后一个自注意力层之后,额外增加一个层归一化
  • 推理时只用最后一个时刻的hidden state去预测logits
    ......def forward(self, idx, targets=None):device = idx.deviceb, t = idx.size()assert t <= self.config.block_size, f"Cannot forward sequence of length {t}, block size is only {self.config.block_size}"pos = torch.arange(0, t, dtype=torch.long, device=device) # shape (t)# forward the GPT model itselftok_emb = self.transformer.wte(idx) # token embeddings of shape (b, t, n_embd)pos_emb = self.transformer.wpe(pos) # position embeddings of shape (t, n_embd)x = self.transformer.drop(tok_emb + pos_emb)for block in self.transformer.h:x = block(x)# 在模型最后一个自注意力层之后,额外增加一个层归一化x = self.transformer.ln_f(x)if targets is not None:# if we are given some desired targets also calculate the losslogits = self.lm_head(x)loss = F.cross_entropy(logits.view(-1, logits.size(-1)), targets.view(-1), ignore_index=-1)else:# inference-time mini-optimization: only forward the lm_head on the very last position# 小优化: 只用最后一个时刻的hidden state去预测logitslogits = self.lm_head(x[:, [-1], :]) # note: using list [-1] to preserve the time dimloss = Nonereturn logits, loss
  • 推理过程中,如果上下文长度大于block_size(1024),就裁剪到block_size长度,再输入到模型中推理
  • temperature的是作用在logits上,通过temperature的大小来调整模型输出概率分布的"尖锐程度"。
    • 当temperature值较高(大于1)时,将使得概率分布变得更加均匀,模型的预测结果将更加多样化,但可能不太准确。换句话说,模型会有更大的概率去尝试预测不太可能的结果。
    • 相反,当temperature值较低(小于1)时,将使得概率分布变得更加尖锐,模型的预测结果将更加聚焦于最有可能的结果,但可能牺牲多样性。换句话说,模型会更倾向于预测最可能的结果。
  • top_k在这段代码中的作用就是在生成预测结果后,只保留得分最高的前k个选项,以减少计算量,并且可能提高模型的生成质量。
    ......@torch.no_grad()def generate(self, idx, max_new_tokens, temperature=1.0, top_k=None):"""Take a conditioning sequence of indices idx (LongTensor of shape (b,t)) and completethe sequence max_new_tokens times, feeding the predictions back into the model each time.Most likely you'll want to make sure to be in model.eval() mode of operation for this."""for _ in range(max_new_tokens):# if the sequence context is growing too long we must crop it at block_size# 如果上下文长度大于block_size,就裁剪到block_size长度,再输入到模型中推理idx_cond = idx if idx.size(1) <= self.config.block_size else idx[:, -self.config.block_size:]# forward the model to get the logits for the index in the sequencelogits, _ = self(idx_cond)# pluck the logits at the final step and scale by desired temperature"""temperature的是作用在logits上,通过temperature的大小来调整模型输出概率分布的"尖锐程度"。当temperature值较高(大于1)时,将使得概率分布变得更加均匀,模型的预测结果将更加多样化,但可能不太准确。换句话说,模型会有更大的概率去尝试预测不太可能的结果。相反,当temperature值较低(小于1)时,将使得概率分布变得更加尖锐,模型的预测结果将更加聚焦于最有可能的结果,但可能牺牲多样性。换句话说,模型会更倾向于预测最可能的结果。"""logits = logits[:, -1, :] / temperature# optionally crop the logits to only the top k options"""top_k在这段代码中的作用就是在生成预测结果后,只保留得分最高的前k个选项,以减少计算量,并且可能提高模型的生成质量。"""if top_k is not None:# 取概率最大的topk个v, _ = torch.topk(logits, min(top_k, logits.size(-1)))# 取topk个概率中最小的值,如果小于此值就将概率置为负无穷,这样softmax后就趋近于0logits[logits < v[:, [-1]]] = -float('Inf')# apply softmax to convert logits to (normalized) probabilitiesprobs = F.softmax(logits, dim=-1)# sample from the distribution# 函数作用是从模型输出的概率中采样,num_samples是采样次数idx_next = torch.multinomial(probs, num_samples=1)# append sampled index to the running sequence and continue# 在得到idx_next后,我们会将它和idx合在一起去预测新的下一个idxidx = torch.cat((idx, idx_next), dim=1)return idx

3.2 数据集的准备

我们运行下面的代码就可以下载Shakespeare数据集,并且将数据分为训练集和验证集。

python data/shakespeare_char/prepare.py

执行这个命令后,命令行会输出下面的内容:

length of dataset in characters:  1115394 # 统计所有数据有多少个字母
all the unique characters: 
!$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz # 统计有哪些字母和符号
vocab size: 65 # 构建vocabulary,即就是上面的这些符号
train has 1003854 tokens
val has 111540 tokens
  • 运行完后,我们也会得到train.bin和val.bin。这两个.bin文件其实就是将对应字母/符号映射到vocabulary后的序号存储而形成的。train.bin包含了100w个token,val.bin包括10w个token。

  • 因为我们训练的是character-level的GPT,所以我们的vocabulary就是26个英文字母的大小写和一些特殊符号组成,vocabulary的大小是65。运行完后,会将vocabulary信息:vocab_size、itos(id2token)、stoi(token2id)封装为dict,保存到meta.pkl文件中。

  • 具体可以看data/shakespeare_char/prepare.py源码。

3.3 模型训练过程

  • 其中如果pytorch版本<2.0,在训练时需要指定compile为False

  • 如果没有GPU,但是又想训练尝试下的话,需要设置device为cpu

  • train.py一共只有300多行代码,写的通俗易懂,支持多卡训练,混合精度训练,断点训练,梯度积累,动态学习率变化,wandb记录log,具体可以看源码。

这里主要介绍下DDP分布式训练的相关概念:

  • rank
    • 进程号,在多进程上下文中,我们通常假定rank 0是第一个进程或者主进程
    • 其它进程分别具有1,2,3不同rank号,这样总共具有4个进程
  • node
    • 物理节点,可以是一个容器也可以是一台机器,节点内部可以有多个GPU
    • nnodes指物理节点数量
    • nproc_per_node指每个物理节点上面进程的数量
  • local_rank
    • 指在一个node上进程的相对序号
    • local_rank在node之间相互独立
    • 这里需要注意的是,在torch.distributed.launch中,local_rank是隐式参数,即torch自动分配的。local_rank可以通过自动注入命令行参数来获得 。
    • torch1.10开始,官方建议使用环境变量的方式来获取local_rank。用终端命令torchrun来代替torch.distributed.launchlocal_rank不再支持用命令行隐式传递的方式,完全使用环境变量配置各类参数。
  • world_size
    • 全局进程总个数,即在一个分布式任务中rank的数量
  • group
    • 进程组,一个分布式任务对应了一个进程组。
    • 只有用户需要创立多个进程组时才会用到group来管理
    • 默认情况下只有一个group

如下图所示,共有3个节点(node),每个节点上有4个GPU,每台机器上起4个进程,每个进程占一块GPU,那么图中一共有12个rank(world_size),nproc_per_node=4,nnodes=3,每个节点都一个对应的node_rank。

在这里插入图片描述

  • backend

    • 通信后端,可选的包括:nccl(NVIDIA推出)、gloo(Facebook推出)、mpi(OpenMPI)。
    • 一般建议GPU训练选择nccl,CPU训练选择gloo
  • master_addr与master_port

    • 主节点的地址以及端口,供init_method 的tcp方式使用。
    • 因为pytorch中网络通信建立是从机去连接主机,运行ddp只需要指定主节点的IP与端口,其它节点的IP不需要填写。
    • 这个两个参数可以通过环境变量或者init_method传入
# 方式1:
os.environ['MASTER_ADDR'] = 'localhost'
os.environ['MASTER_PORT'] = '12345'
dist.init_process_group(“nccl”, rank=rank, world_size=world_size)# 方式2:
dist.init_process_group(“nccl”,init_method=“tcp://localhost:12345,rank=rank,world_size=world_size)

使用DDP分布式训练的话,主要包含下面步骤:

  • 初始化进程组 dist.init_process_group
  • 设置分布式采样器 DistributedSampler
  • 使用DistributedDataParallel封装模型
  • 使用torchrun 或者 mp.spawn 启动分布式训练

需要注意的点:

  • 模型保存的时候,注意调用model.module.state_dict()
  • 加载模型的时候,需要利用map_location参数指定加载的设备。
  • 有了sampler,在DataLoader中不需要shuffle
  • 在每个训练周期开始处,调用sampler.set_epoch(epoch)使得数据充分打乱

DDP代码实战可以参考:

Pytorch DDP分布式训练介绍

相关文章:

经典神经网络(13)GPT-1、GPT-2原理及nanoGPT源码分析(GPT-2)

经典神经网络(13)GPT-1、GPT-2原理及nanoGPT源码分析(GPT-2) 2022 年 11 月&#xff0c;ChatGPT 成功面世&#xff0c;成为历史上用户增长最快的消费者应用。与 Google、FaceBook等公司不同&#xff0c;OpenAI 从初代模型 GPT-1 开始&#xff0c;始终贯彻只有解码器&#xff0…...

MySQL库与表的操作

目录 一、登录并进入数据库 1、登录 2、USE 命令 检查当前数据库 二、库的操作 1、创建数据库语法 2、举例演示 3、退出 三、字符集和校对规则 1、字符集&#xff08;Character Set&#xff09; 2、校对集&#xff08;Collation&#xff09; 总结 3、操作命令 …...

TTS 语音合成技术学习

TTS 语音合成技术 TTS&#xff08;Text-to-Speech&#xff0c;文字转语音&#xff09;技术是一种能够将文字内容转换为自然语音的技术。通过 TTS&#xff0c;机器可以“说话”&#xff0c;这大大增强了人与机器之间的互动能力。无论是在语音助手、导航系统还是电子书朗读器中&…...

小公司做自动化的困境

1. 人员数量不够 非常常见的场景, 开发没几个, 凭什么测试要那么多, 假设这里面有3个测试, 是不是得有1个人会搞框架? 是不是得有2人搞功能测试, 一个人又搞框架, 有些脚本, 真来得及吗? 2. 人员基础不够 现在有的大公司, 是这样子协作的, 也就是某模块需求谁谁测试的, 那么…...

基于pytorch框架的手写数字识别(保姆级教学)

1、前言 本文基于PyTorch框架,采用CNN卷积神经网络实现MNIST手写数字识别,不仅可以在GPU上,同时也可以在CPU上运行。方便即使只有CPU的小伙伴也可以运行该模型。本博客手把手教学,如何手写网络层(3层),以及模型训练,详细介绍各参数含义与用途。 2、模型源码解读 该模型…...

注意力机制在大语言模型中的应用

在大语言模型中&#xff0c;注意力机制&#xff08;Attention Mechanism&#xff09;用于捕获输入序列中不同标记&#xff08;token&#xff09;之间的关系和依赖性。这种机制可以动态地调整每个标记对当前处理任务的重要性&#xff0c;从而提高模型的性能。具体来说&#xff0…...

qt 实现对字体高亮处理原理

在Qt中实现对文本的字体高亮处理&#xff0c;通常涉及到使用QTextDocument、QTextCharFormat和QSyntaxHighlighter。下面是一个简单的例子&#xff0c;演示如何为一个文本编辑器&#xff08;假设是QTextEdit&#xff09;添加简单的关键词高亮功能&#xff1a; 步骤 1: 定义关键…...

SAP中通过财务科目确定分析功能来定位解决BILLING问题实例

接用户反馈&#xff0c;一笔销售订单做发货后做销售发票时&#xff0c;没有成功过账到财务&#xff0c;提示财户确定错误。 这个之前可以通过VF02中点击小绿旗来重新执行过财动作&#xff0c;看看有没有相应日志来定位问题。本次尝试用此方法&#xff0c;也没有找到相关线索。 …...

充电站,正在杀死加油站

最近&#xff0c;深圳公布了一组数据&#xff0c;深圳的超级充电站数量已超过传统加油站数量&#xff0c;充电枪数量也已超过加油枪数量。 从全国范围看&#xff0c;加油站关停的速度在加快。 充电站正在杀死加油站。 加油站&#xff0c;未来何去何从&#xff1f; 01. 减少 我…...

哪个牌子的超声波清洗机好?四样超卓超声波清洗机独具特色!

眼镜是许多人日常生活中必不可少的工具&#xff0c;然而&#xff0c;相信很多人都有过清洗眼镜的烦恼。传统的清洗眼镜的方法往往不够彻底&#xff0c;容易留下污渍或者划伤镜片。因此&#xff0c;超声波洗眼镜机成为了现代人清洗眼镜的新选择。超声波洗眼镜机通过利用超声波震…...

vue3中若v-model绑定的响应字段出现三级,该如何实现rules验证规则

比如以下内容&#xff1a; 配置的rules内容 const rulesref({title:[{required:true,message:"请输入标题",trigger:"blur"},{max:50,message:"最大不能超过256个字",trigger:"blur"}],Category:[{required:true,message:"请选择…...

Docker-Compose一键部署项目

Docker-Compose一键部署项目 目录 Docker-Compose一键部署项目介绍部署Django项目项目目录结构 docker-compose.ymlnginx的default.conf文件后端Dockerfile文件mysql.env一键部署DNS域名解析引起的跨域问题 介绍 Docker Compose 是一个用于定义和运行多容器 Docker 应用程序的…...

【C++】相机标定源码笔记-线激光点云处理工具类

一个线激光点云处理工具类&#xff0c;它包含了一系列的方法用于处理和分析线激光扫描得到的点云数据。提供的功能包括&#xff1a; 通过文件或直接数据设置点云。计算线激光在机器人坐标系下的精度&#xff0c;输出内点的平均距离、最大距离、最小距离、总点数和内点数。提供了…...

解决Transformer根本缺陷,所有大模型都能获得巨大改进

即使最强大的 LLM 也难以通过 token 索引来关注句子等概念&#xff0c;现在有办法了。 最近两天&#xff0c;马斯克和 LeCun 的口水战妥妥成为大家的看点。这两位 AI 圈的名人你来我往&#xff0c;在推特&#xff08;现为 X&#xff09;上相互拆对方台。 LeCun 在宣传自家最新论…...

如何排查Java应用的死锁

排查Java应用中的死锁问题是一个复杂但重要的任务&#xff0c;因为死锁会导致应用程序停止响应&#xff0c;影响用户体验和系统稳定性。以下是一些方法和步骤&#xff0c;帮助你排查Java应用中的死锁。 1. 理解死锁的概念 在计算机科学中&#xff0c;死锁是指两个或多个线程相…...

JS面试题1

1. 延迟加载JS有哪些方式&#xff1f; defer: 等html全部解析完成&#xff0c;才会执行js代码&#xff0c;顺次执行js脚本 async&#xff1a;是和html解析同步的&#xff0c;不是顺次执行js脚本&#xff08;当有很多个js时&#xff09;&#xff0c;是谁先加载完谁先执行。 <…...

Linux网络 - 再谈、详谈UDP和TCP协议

文章目录 前言预备netstatpidofcat /etc/services 一、UDP协议UDP协议端格式UDP的缓冲区基于UDP的应用层协议 二、TCP协议1.TCP协议段格式确认应答(ACK)机制三次握手疑问1 最后一次客户端发给服务端的ACK请求怎么保证服务端能够收到&#xff1f; 四次挥手疑问2 为什么挥手是四次…...

el-form重置后input无法输入问题

新增用户遇到的问题&#xff1a; 如果你没有为 formData 设置默认值&#xff0c;而只是将其初始化为空对象 {}&#xff0c;则在打开dialog时&#xff0c;正常输入&#xff0c; formdata会变成如下 但是&#xff0c;打开后&#xff0c;直接使用 resetFields 或直接清空表单&…...

Java网络编程(JavaWeb的基础)

Java网络编程&#xff08;JavaWeb的基础&#xff09; 文章目录 Java网络编程&#xff08;JavaWeb的基础&#xff09;前言一、网络编程概述1.1 软件架构&网络基础1.2 网络通信要素:IP/端口/通信协议1.3 传输层协议:tcp/udp 二、网络编程API2.1 InetAddress类2.2 Socket类&am…...

鸿蒙Harmony开发实战案例:使用OpenGL绘制3D图形

XComponent控件常用于相机预览流的显示和游戏画面的绘制,在OpenHarmony上&#xff0c;可以配合Native Window创建OpenGL开发环境&#xff0c;并最终将OpenGL绘制的图形显示到XComponent控件。本文将采用"Native C"模板&#xff0c;调用OpenGL ES图形库绘制3D图形&…...

DM达梦数据库存储过程

&#x1f49d;&#x1f49d;&#x1f49d;首先&#xff0c;欢迎各位来到我的博客&#xff0c;很高兴能够在这里和您见面&#xff01;希望您在这里不仅可以有所收获&#xff0c;同时也能感受到一份轻松欢乐的氛围&#xff0c;祝你生活愉快&#xff01; &#x1f49d;&#x1f49…...

【python】OpenCV—Color Correction

文章目录 cv2.aruco 介绍imutils.perspective.four_point_transform 介绍skimage.exposure.match_histograms 介绍牛刀小试遇到的问题 参考学习来自 OpenCV基础&#xff08;18&#xff09;使用 OpenCV 和 Python 进行自动色彩校正 cv2.aruco 介绍 一、cv2.aruco模块概述 cv2.…...

Java基础知识整理笔记

目录 1.关于Java概念 1.1 谈谈对Java的理解&#xff1f; 1.2 Java的基础数据类型&#xff1f; 1.3 关于面向对象的设计理解 1.3.1 面向对象的特性有哪些&#xff1f; 1.3.2 重写和重载的区别&#xff1f; 1.3.3 面向对象的设计原则是什么&#xff1f; 1.4 关于变量与方…...

知识图谱——Neo4j数据库实战

数据与代码链接见文末 1.Neo4j数据库安装 JDK 安装:https://www.oracle.com/java/technologies/javase-downloads.html Neo4j 安装:https://neo4j.com/download-center/ 配置好 JDK 和 Neo4j 的环境变量...

第十一次Javaweb作业

4.登录校验 4.1会话 --用户打开浏览器&#xff0c;访问web服务器的资源&#xff0c;会话建立&#xff0c;直到有一方断开连接&#xff0c;会话结束。在一次会话中可以包含多次请求和响应。 会话跟踪&#xff1a;一种维护浏览器状态的方法&#xff0c;服务器需要识别多次请求…...

人工智能AI风口已开:如何赋予UI设计与视频剪辑新生命

随着科技的浪潮不断向前推进&#xff0c;人工智能&#xff08;AI&#xff09;正以惊人的速度重塑着我们的世界&#xff0c;特别是在创意产业的核心领域——UI设计与视频剪辑中&#xff0c;AI正逐步成为驱动行业创新与变革的关键力量。在这个AI技术全面开花的新时代&#xff0c;…...

计算机专业课面试常见问题-编程语言篇

目录 1. 程序的编译执行流程&#xff1f; 2. C浅拷贝和深拷贝的区别&#xff1f; 3. C虚函数&#xff1f; …...

CSS|05 继承性与优先级

继承性 一、继承性的特点&#xff1a; 1.外层元素身上的样式会被内层元素所继承 2.如果内层元素与外层元素身上的演示相同时&#xff0c;外层元素的样式会被内层元素所覆盖 二、关于继承性的问题 是不是所有样式都能被继承&#xff1f; 答&#xff1a;并不是所有样式能被继承…...

KVM性能优化之内存优化(宿主机)

linux系统自带了一技术叫透明巨型页&#xff08;transparent huge page&#xff09;&#xff0c;它允许所有的空余内存被用作缓存以提高性能&#xff0c;而且这个设置是默认开启的&#xff0c;我们不需要手动去操作。 Centos下&#xff0c;我们用cat /sys/kernel/mm/transpare…...

【Linux杂货铺】Linux学习之路:期末总结篇1

第一章 什么是Linux? Linux 是 UNIX 操作系统的一个克隆&#xff1b;它由林纳斯 本纳第克特 托瓦兹从零开始编写&#xff0c;并在网络上众多松散的黑客团队的帮助下得以发展和完善&#xff1b;它遵从可移植操作系统接口&#xff08;POSIX&#xff09;标准和单一 UNIX 规范…...

GPT-5的到来:智能飞跃与未来畅想

IT之家6月22日消息&#xff0c;在美国达特茅斯工程学院的采访中&#xff0c;OpenAI首席技术官米拉穆拉蒂确认了GPT-5的发布计划&#xff0c;预计将在一年半后推出。穆拉蒂形象地将GPT-4到GPT-5的飞跃比作高中生到博士生的成长。这一飞跃将给我们带来哪些变化&#xff1f;GPT-5的…...

gin中间件

在web应用服务中&#xff0c;完整的业务处理在技术上包含客户端操作&#xff0c;服务端处理&#xff0c;返回处理结果给客户端三个步骤。但是在在更负责的业务和需求场景。一个完整的系统可能要包含鉴权认证&#xff0c;权限管理&#xff0c;安全检查&#xff0c;日志记录等多维…...

swagger常用注解

最近查看接口文档的时候发现&#xff0c;POST方法中的query没法在swagger中显示&#xff0c;查了才发现这是因为Swagger或OpenAPI规范默认将HTTP POST请求的参数识别为请求体&#xff08;body&#xff09;参数&#xff0c;而不是查询字符串&#xff08;query&#xff09;参数。…...

【Flink metric(1)】Flink指标系统的系统性知识:获取metric以及注册自己的metric

文章目录 一. Registering metrics&#xff1a;向flink注册新自己的metrics1. 注册metrics2. Metric types:指标类型2.1. Counter2.2. Gauge2.3. Histogram(ing)2.4. Meter 二. Scope:指标作用域1. User Scope2. System Scope ing3. User Variables 三. Reporter ing四. System…...

命令模式(Command Pattern)

命令模式&#xff08;Command Pattern&#xff09; 定义 命令模式是对命令的封装&#xff0c;每一个命令都是一个操作&#xff1a;请求的一方发出请求要求执行一个操作&#xff1b;接收的一方收到请求&#xff0c;并执行操作。 命令模式解耦了请求方和接收方&#xff0c;请求…...

掌握Symfony的模板继承:构建强大且灵活的Web界面

掌握Symfony的模板继承&#xff1a;构建强大且灵活的Web界面 在Symfony框架中&#xff0c;模板继承是一个强大的功能&#xff0c;它允许开发者创建可重用的布局模板&#xff0c;并通过扩展这些模板来构建具体的页面。这种机制不仅提高了代码的可维护性&#xff0c;还使得页面结…...

uboot基本使用网络命令和从服务器端下载linux内核启动

网络命令ip地址设置: setenv gmac_debug 0; setenv mdio_intf rgmii; setenv bootdelay 1; setenv ethaddr 00:xxxx:81:70; // mac地址 setenv ipaddr xxx; //开发板 IP 地址 setenv netmask 255.255.255.0; setenv gatewayip xxx.1; setenv serverip xxxx; //服…...

解决ArcGIS导出的svg格式的图片插入Word后的字体问题

背景 在ArcGIS中设置字体为Times New Roman&#xff0c;但导入Word后字体转为等线。 ArcGIS中的Layout 导入Word​​​​​​ 原因分析 Word无法识别嵌入进SVG格式文件中的字体。 解决方案 在Export Layer窗口中&#xff0c;将Embed fonts取消勾选&#xff0c;Convert cha…...

如何确保 Puppet 配置在复杂网络环境中的可靠分发和同步?

在复杂网络环境中确保 Puppet 配置的可靠分发和同步可以采取以下措施&#xff1a; 网络拓扑规划&#xff1a;在复杂网络环境中&#xff0c;首先需要进行网络拓扑规划&#xff0c;确保网络结构合理&#xff0c;并能够支持可靠的分发和同步机制。 Puppet Master 多节点部署&…...

2024最新!将mysql的数据导入到Solr

Solr导入mysql的数据 如何安装导入数据前准备配置Solr的Jar包以及Mysql驱动包1.1、将solr-8.11.3\dist下的两个包进行移动1.2、将mysql-connect包也移动到该位置1.3、重启Solr项目 配置xml2.1、第一步我们需要创建核心2.2、第二步修改xml(这里是结合19年的教程)2.3、 创建data-…...

Python数据分析第二课:conda的基础命令

Python数据分析第二课&#xff1a;conda的基础命令 1.conda是什么? conda是一个开源的包管理系统&#xff0c;可以帮助我们进行管理多个不同版本的软件包&#xff0c;还可以帮助我们建立虚拟环境&#xff0c;以便对不同的项目进行隔离。 简单来说&#xff0c;conda是一个软…...

LayoutInflater加载流程

简介 LayoutInflater在日常的Android开发中是经常使用的类&#xff0c;常常用于XML中View的加载相关流程。本文主要总结一些其常见api的源码流程。 获取LayoutInflater 我们一般会在Activity的onCreate方法中会通过setContentView方法设置自己的布局layoutId&#xff0c;Act…...

PLC数据采集案例

--------天津三石峰科技案例分享 项目介绍 项目背景 本项目为天津某钢铁集团下数字化改造项目&#xff0c;主要解决天津大型钢厂加氢站数字化改造过程中遇到的数据采集需求。项目难点PLC已经在运行了&#xff0c;需要采集里面数据&#xff0c;不修改程序&#xff0c;不影响P…...

基于单片机和LabVIEW 的远程矿井水位监控系统设计

摘要 &#xff1a; 针 对 现 有 矿 井 水 位 监 控 系 统 存 在 结 构 复 杂 和 不 能 远 程 监 控 的 问 题 &#xff0c; 设计了基于单片机和&#xff2c;&#xff41;&#xff42;&#xff36;&#xff29;&#xff25;&#xff37; 的远程矿井水位监控系统 &#xff0c; 详…...

element 表格嵌套表单验证指定行

elementui表格嵌套动态表单&#xff0c;单独验证某一行输入项是否符合校验规则&#xff1b; input动态绑定校验 :prop"imgTable. scope.$index .bxName" <el-form :model"formTable" ref"formTable" inline size"small"><…...

CORE Mobility Errorr的调试

在运行CORE tutorial 3中的mobility示例时&#xff0c;出现如下错误&#xff1a; 当看到这个问题的时候&#xff0c;并没有仔细去分析日志和现象&#xff0c;在core-daemon的进程打印界面只看了一下最后的出错堆栈&#xff1a; 2024-06-27 10:43:48,614 - ERROR - _server:_ca…...

基于weixin小程序乡村旅游系统的设计

管理员账户功能包括&#xff1a;系统首页&#xff0c;个人中心&#xff0c;用户管理&#xff0c;商家管理&#xff0c;旅游景点管理&#xff0c;景点类型管理&#xff0c;景点路线管理&#xff0c;系统管理 商家帐号账号功能包括&#xff1a;系统首页&#xff0c;旅游景点管理&…...

详解三种常用标准化 Batch Norm Layer Norm RMSNorm

参考&#xff1a; BN究竟起了什么作用&#xff1f;一个闭门造车的分析《动手学深度学习》7.5 节 深度学习中&#xff0c;归一化是常用的稳定训练的手段&#xff0c;CV 中常用 Batch Norm&#xff1b; Transformer 类模型中常用 layer norm&#xff0c;而 RMSNorm 是近期很流行…...

云计算运维工程师面试

1. 云计算运维工程师的角色和职责是什么? 回答: 云计算运维工程师负责确保云计算环境(包括硬件和软件系统)的高可用性和稳定性。他们的主要职责包括: 监测系统和应用程序的性能,确保它们正常运行。故障排除,快速响应并解决系统或应用程序中出现的问题。容量规划,根据…...

聚观早报 | iPhone 16核心硬件曝光;三星Galaxy全球新品发布会

聚观早报每日整理最值得关注的行业重点事件&#xff0c;帮助大家及时了解最新行业动态&#xff0c;每日读报&#xff0c;就读聚观365资讯简报。 整理丨Cutie 6月28日消息 iPhone 16核心硬件曝光 三星Galaxy全球新品发布会 苹果正多方下注布局AI商店 黄仁勋2024年薪酬3400…...