当前位置: 首页 > news >正文

深度学习:GPT-2的MindSpore实践

GPT-2简介

GPT-2是一个由OpenAI于2019年提出的自回归语言模型。与GPT-1相比,仍基于Transformer Decoder架构,但是做出了一定改进。

模型规格上:

GPT-1有117M参数,为下游微调任务提供预训练模型。

GPT-2显著增加了模型规模,提供了多种模型,如:124M、355M、774M和1.5B

数据集大小上:

GPT-2训练于数据量约有45GB的WebText数据集。数据集的数据收集于Reddit中的网络文章。

模型架构上:

GPT-2维持了GPT-1的Decoder-only架构,但是讲Decoder Block增加至48层,采用了更深层的注意力机制和更大的前馈网络维度并改进了正则化。同时,GPT-2加入了可学习的位置编码。将Layer Norm前置于模型获得输入后。一个额外的Layer Norm被添加于最后一个自注意力Block后。

参数初始化上:

参数初始化上,使用了一个Special Scaled Initialization。Special Scaled Initialization是Xavier Normalization的一种变体,使用了额外的缩放。将因子n调整为残差连接的次数,也就是block数量的两倍。

任务设定

一般语言模型的训练目标设置为:P(output | input)

但是,GPT-2通过同样的无监督模型来完成多个既定任务,学习目标变为:P(output | input, task)

这种修改被称为任务设定。对于同样的输入,模型应该根据不同的任务输出不同的结果。

翻译任务

文本总结

其他下游任务:基于zero-shot或few-shot的文本生成、文本总结、文本翻译、QA问答、文本分类。

基于MindSpore的GPT-2实践

复习Masked Multi Self-Attention

#安装mindnlp 0.4.0套件
!pip install mindnlp==0.4.0
!pip uninstall soundfile -y
!pip install https://ms-release.obs.cn-north-4.myhuaweicloud.com/2.3.1/MindSpore/unified/aarch64/mindspore-2.3.1-cp39-cp39-linux_aarch64.whl --trusted-host ms-release.obs.cn-north-4.myhuaweicloud.com -i https://pypi.tuna.tsinghua.edu.cn/simple

假设一个批大小为1,序列长度为10,特征维度为768的输入

 # GPT-2 Masked Self-Attention# assume an input of self-attention x, input_dim is 768
batch_size, seq_len, embed_dim = 1, 10, 768x = Tensor(np.random.randn(batch_size, seq_len, embed_dim), mindspore.float32)
x.shape

将输入复制三份作为Q、K、V。

import mindspore.ops as ops
from mindnlp.transformers.ms_utils import Conv1D# an input will be multipled by three matrixs Wq, Wk, Wv
# concat the three matrixs, you will get a matrix of (768, 768*3)
# x matmul matrix, the output would be (batch_size, seq_len, 768*3)
c_attn = Conv1D(3 * embed_dim, embed_dim)
output = c_attn(x)
# split the output into q, k, v
query, key, value = ops.Split(axis=2, output_num=3)(output)
query.shape, key.shape, value.shape

 将注意力分头

# split self-attention into multi_head attention
def split_heads(tensor, num_heads, attn_head_size):'''Spilit hidden_size dim into attn_head_size and num_headsArgs:tensor: tensor to splitnum_heads: how many heads to splitattn_head_size: hidden_size of each headReturn:Multi-Head tensor'''new_shape = tensor.shape[:-1] + (num_heads, attn_head_size)tensor = tensor.view(new_shape)return ops.transpose(tensor, (0, 2, 1, 3))num_heads = 12
attn_head_size = embed_dim // num_headsquery = split_heads(query, num_heads, attn_head_size)
key = split_heads(key, num_heads, attn_head_size)
value = split_heads(value, num_heads, attn_head_size)query.shape, key.shape, value.shape

将Q、K相乘,得到注意力分数

# get self-attention score
attn_weights = ops.matmul(query, key.swapaxes(-1, -2))
attn_weights.shape

将注意力分数加上掩码,防止模型看见“未来”的数据

# get mask attn_weighs
max_positions = seq_len
# create a mask matrix
bias = Tensor(np.tril(np.ones((max_positions, max_positions))).reshape((1, 1, max_positions, max_positions)), mindspore.bool_)# apply Mask Matrix to get masked scores
# this normalization helps stabilize gradients 
# and is common in scaled dot-product attention mechanisms
attn_weights = attn_weights / ops.sqrt(ops.scalar_to_tensor(value.shape[-1]))query_length, key_length = query.shape[-2], key.shape[-2]
causal_mask = bias[:, :, key_length - query_length: key_length, :key_length].bool()
mask_value = Tensor(np.finfo(np.float32).min, dtype=attn_weights.dtype)
attn_weights = ops.where(causal_mask, attn_weights, mask_value)

 经过SoftMax层,得到掩码分数

# get attn scores
attn_weights = ops.softmax(attn_weights, axis=-1)

掩码分数与V相乘,得到注意力输出

# get output of Masked Self-Attention
attn_output = ops.matmul(attn_weights, value)
attn_output.shape

 将头合并

# merge multi heads
def merge_heads(tensor, num_heads, attn_head_size):'''Merge attn_head_size dim and num_attn_heads dim to hidden_size'''tensor = ops.transpose(tensor, (0, 2, 1, 3))new_shape = tensor.shape[:-2] + (num_heads * attn_head_size, )return tensor.view(new_shape)attn_output = merge_heads(attn_output, num_heads, attn_head_size)
attn_output.shape

将输出与Wv相乘,得到最终输出

# project Attnetion results with Wv
projection = Conv1D(embed_dim, embed_dim)
attn_output = projection(attn_output)
attn_output.shape

基于MindSpore的GPT2文本摘要

基于GPT-2实现一个简单的文本摘要。

# 数据加载与预处理
from mindnlp.utils import http_get# download dataset
url = 'https://download.mindspore.cn/toolkits/mindnlp/dataset/text_generation/nlpcc2017/train_with_summ.txt'
path = http_get(url, './')from mindspore.dataset import TextFileDataset# load dataset
dataset = TextFileDataset(str(path), shuffle =False)
dataset.get_dataset_size()mini_dataset, _ = dataset.split([0.001, 0.999], randomize=False)
train_dataset, test_dataset = mini_dataset.split([0.9, 0.1], randomize=False)import json
import numpy as np
def process_dataset(dataset, tokenizer, batch_size=4, max_seq_len=1024, shuffle=False):'''数据预处理:原始数据格式:article:[CLS] article_context [SEP]summary:[CLS] summary_context [SEP]预处理后的数据格式:[CLS] article_context [SEP] summary_context [SEP]'''def read_map(text):'''sub function to change the form of data'''data = json.loads(text.tobytes())print(data)return np.array(data['article']), np.array(data['summarization'])def merge_and_pad(article, summary):# tokenization, pad to max_seq_length, only article will be truncatedtokenized = tokenizer(text=article, text_pari=summary, padding='max_length', truncation='only_first', max_length=max_seq_len)# Returns tokenized input IDs for both the input (input_ids) and the labels.return tokenized['input_ids'], tokenized['input_ids']# 'text': Input column to process.# ['article', 'summary']: Names of the output columns after processing.dataset = dataset.map(read_map, output_columns=['article', 'summary'])dataset = dataset.map(merge_and_pad, ['article', 'summary'], ['input_ids', 'labels'])dataset = dataset.batch(batch_size)if shuffle:dataset = dataset.shuffle(batch_size)return datasetfrom mindnlp.transformers import BertTokenizer# Load BERT-base-Chinese tokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-chinese')# load train dataset
train_dataset = process_dataset(train_dataset, tokenizer, batch_size=1)# model architecture of GPT2ForSummarization
from mindnlp.transformers import GPT2LMHeadModelclass GPT2ForSummarization(GPT2LMHeadModel):def forward(self, input_ids=None, attention_mask=None, labels=None):outputs = super().forward(input_ids=input_ids, attention_mask=attention_mask)shift_logits = outputs.logits[..., :-1, :]shift_labels = labels[..., 1:]loss = ops.cross_entropy(shift_logits.view(-1, shift_logits.shape[-1]), shift_labels.view(-1), ignore_index=tokenizer.pad_token_id)return (loss, )num_epochs = 1
warmup_steps = 100
lr = 1.5e-4
max_grad_norm = 1.0
num_training_steps = num_epochs * train_dataset.get_dataset_size()from mindspore import nn
from mindnlp.transformers import GPT2Config, GPT2LMHeadModelconfig = GPT2Config(vocab_size=len(tokenizer))
model = GPT2ForSummarization(config)from mindnlp.engine import TrainingArgumentstraining_args = TrainingArguments(output_dir="gpt2_summarization",save_steps=train_dataset.get_dataset_size(),save_total_limit=3,logging_steps=1000,max_steps=num_training_steps,learning_rate=lr,max_grad_norm=max_grad_norm,warmup_steps=warmup_steps)from mindnlp.engine import Trainertrainer = Trainer(model=model,args=training_args,train_dataset=train_dataset,
)trainer.train()def process_test_dataset(test_dataset, tokenizer, batch_size=1, max_seq_len=1024, max_summary_len=100):def read_map(text):data = json.loads(text.tobytes())return np.array(data['article']), np.array(data['summarization'])def pad(article):tokenized = tokenizer(text=article, truncation=True, max_length=max_seq_len-max_summary_len)return tokenized['input_ids']test_dataset = test_dataset.map(read_map, output_columns=['article', 'summary'])test_dataset = test_dataset.map(pad, 'article', ['input_ids'])test_dataset = test_dataset.batch(batch_size)return test_datasettokenizer_test = BertTokenizer.from_pretrained('bert-base-chinese')batched_test_dataset = process_test_dataset(test_dataset, tokenizer_test, batch_size=1)model = GPT2LMHeadModel.from_pretrained('./gpt2_summarization/checkpoint-45', config=config)model.set_train(False)
model.config.eos_token_id = model.config.sep_token_id
i = 0
for (input_ids, raw_summary) in batched_test_dataset.create_tuple_iterator():output_ids = model.generate(input_ids, max_new_tokens=50, num_beams=5, no_repeat_ngram_size=2)output_text = tokenizer.decode(output_ids[0].tolist())print('input', tokenizer.decode(input_ids[0].tolist()))print()print(output_text)i += 1if i == 1:break

相关文章:

深度学习:GPT-2的MindSpore实践

GPT-2简介 GPT-2是一个由OpenAI于2019年提出的自回归语言模型。与GPT-1相比,仍基于Transformer Decoder架构,但是做出了一定改进。 模型规格上: GPT-1有117M参数,为下游微调任务提供预训练模型。 GPT-2显著增加了模型规模&…...

【Oracle11g SQL详解】ORDER BY 子句的排序规则与应用

ORDER BY 子句的排序规则与应用 在 Oracle 11g 中,ORDER BY 子句用于对查询结果进行排序。通过使用 ORDER BY,可以使返回的数据按照指定的列或表达式以升序或降序排列,便于数据的分析和呈现。本文将详细讲解 ORDER BY 子句的规则及其常见应用…...

YOLO系列论文综述(从YOLOv1到YOLOv11)【第15篇(完结):讨论和未来展望】

总结 0 前言1 YOLO与人工通用智能(AGI)2 YOLO作为“能够行动的神经网络”3 具身人工智能(EAI)4 边缘设备上的YOLO5 评估统计指标的挑战6 YOLO与环境影响 YOLO系列博文: 【第1篇:概述物体检测算法发展史、YO…...

Java设计模式 —— 【创建型模式】原型模式(浅拷贝、深拷贝)详解

文章目录 前言原型模式一、浅拷贝1、案例2、引用数据类型 二、深拷贝1、重写clone()方法2、序列化 总结 前言 先看一下传统的对象克隆方式: 原型类: public class Student {private String name;public Student(String name) {this.name name;}publi…...

SciAssess——评估大语言模型在科学文献处理中关于模型的记忆、理解和分析能力的基准

概述 大规模语言模型(如 Llama、Gemini 和 GPT-4)的最新进展因其卓越的自然语言理解和生成能力而备受关注。对这些模型进行评估对于确定其局限性和潜力以及促进进一步的技术进步非常重要。为此,人们提出了一些特定的基准来评估大规模语言模型…...

SQLModel与FastAPI结合:构建用户增删改查接口

SQLModel简介 SQLModel是一个现代化的Python库,旨在简化与数据库的交互。它结合了Pydantic和SQLAlchemy的优势,使得定义数据模型、进行数据验证和与数据库交互变得更加直观和高效。SQLModel由FastAPI的创始人Sebastin Ramrez开发,专为与FastA…...

【RISC-V CPU debug 专栏 2.3 -- Run Control】

文章目录 Run ControlHart 运行控制状态位状态信号操作流程时间与实现注意事项Run Control 在 RISC-V 调试架构中,运行控制模块通过管理多个状态位来对硬件线程(harts)的执行进行调节和控制。这些状态位帮助调试器请求暂停或恢复 harts,并在 hart 复位时进行控制。以下是运…...

探索 IntelliJ IDEA 中 Spring Boot 运行配置

前言 IntelliJ IDEA 作为一款功能强大的集成开发环境(IDE),为 Spring Boot 应用提供了丰富的运行配置选项,定义了如何在 IntelliJ IDEA 中运行 Spring Boot 应用程序,当从主类文件运行应用程序时,IDE 将创建…...

三除数枚举

给你一个整数 n 。如果 n 恰好有三个正除数 ,返回 true ;否则,返回 false 。 如果存在整数 k ,满足 n k * m ,那么整数 m 就是 n 的一个 除数 。 输入:n 4 输出:true 解释:4 有三…...

【051】基于51单片机温度计【Proteus仿真+Keil程序+报告+原理图】

☆、设计硬件组成:51单片机最小系统DS18B20温度传感器LCD1602液晶显示按键设置蜂鸣器LED灯。 1、本设计采用STC89C51/52、AT89C51/52、AT89S51/52作为主控芯片; 2、采用DS18B20温度传感器测量温度,并且通过LCD1602实时显示温度;…...

[Java]微服务之服务保护

雪崩问题 微服务调用链路中的某个服务故障,引起整个链路中的所有微服务都不可用,这就是雪崩 雪崩问题产生的原因是什么? 微服务相互调用,服务提供者出现故障或阻塞。服务调用者没有做好异常处理,导致自身故障。调用链中的所有服…...

自动驾驶目标检测融合全貌

1、early fusion 早期融合,特点用到几何空间转换3d到2d或者2d到3d的转换,用像素找点云或者用点云找像素。 2、deep fusion 深度融合,也是特征级别融合,也叫多模态融合,如bevfusion范式 3、late fusion 晚融合&#x…...

消息框(Message Box)的测试方法和测试用例

我来帮你了解消息框(Message Box)的测试方法和测试用例的编写。 我已经创建了一个测试用例示例,让我为你解释消息框测试的主要方面: 测试维度: 功能性测试:验证消息框的基本功能是否正常样式测试:确认不同类型消息框…...

Ubuntu 包管理

APT&dpkg 查看已安装包 查看所有已经安装的包 dpkg -l 查找包 apt search <package_name>搜索软件包列表&#xff0c;找到与搜索关键字匹配的包 dpkg与grep结合查找特定的包 dpkg -s <package>&#xff1a;查看某个安装包的详细信息 安装包 apt安装命令 更新…...

[Ubuntu] linux之Ubuntu18.04的下载及在虚拟机中详细安装过程(附有下载链接)

前言 ubuntu 链接&#xff1a;https://pan.quark.cn/s/283509d0d36e 提取码&#xff1a;dfT1 链接失效&#xff08;可能被官方和谐&#xff09;可评论或私信我重发 下载压缩包后解压 &#xff01;&#xff01;安装路径不要有中文 下载后解压得到.iso文件&#xff0c;不要放在…...

ffmpeg安装(windows)

ffmpeg安装-windows 前言ffmpeg安装路径安装说明 前言 ffmpeg的安装也是开箱即用的,并没有小码哥说的那么难 ffmpeg安装路径 这就下载好了! 安装说明 将上面的bin目录加入到环境变量,然后在cmd中测试一下: C:\Users\12114\Desktop\test\TaskmgrPlayer\x64\Debug>ffmpe…...

服务器数据恢复—raid6阵列硬盘被误重组为raid5阵列的数据恢复案例

服务器存储数据恢复环境&#xff1a; 存储中有一组由12块硬盘组建的RAID6阵列&#xff0c;上层linux操作系统EXT3文件系统&#xff0c;该存储划分3个LUN。 服务器存储故障&分析&#xff1a; 存储中RAID6阵列不可用。为了抢救数据&#xff0c;运维人员使用原始RAID中的部分…...

linux内核编译启动总结

linux kernel 编译 升级汇总 写在前面内核编译获取kernel代码开始前的准备工作 编译过程1\.解压与净化将下载好的linux内核解压至/usr/src 2\. 得到源代码后,将其净化3\. 配置要进行编译的内核4.编译内核. &#xff08;15分钟&#xff09;5.编译模块.方法1:方法2&#xff1a; 6…...

Android Studio的AI工具插件使用介绍

Android Studio的AI工具插件使用介绍 一、前言 Android Studio 的 AI 工具插件具有诸多重要作用&#xff0c;以下是一些常见的方面&#xff1a; 代码生成与自动补全 代码优化与重构 代码解读 学习与知识获取 智能搜索与资源推荐实际使用中可以添加注释&#xff0c;解读某段代…...

本地部署 WireGuard 无需公网 IP 实现异地组网

WireGuard 是一个高性能、极简且易于配置的开源虚拟组网协议。使用路由侠内网穿透使其相互通讯。 第一步&#xff0c;服务端&#xff08;假设为公司电脑&#xff09;和客户端&#xff08;假设为公司外的电脑&#xff09;安装部署 WireGuard 1&#xff0c;点此下载&#xff08;…...

asyncio.ensure_future 与 asyncio.create_task:Python异步编程中的选择

asyncio.ensure_future 与 asyncio.create_task&#xff1a;Python异步编程中的选择 引言asyncio.ensure_futureasyncio.create_task两者的区别参数接受范围任务调度的保证代码可读性 哪个更好&#xff1f;使用asyncio.create_task使用asyncio.ensure_future 结论参考 引言 在…...

CTF之密码学(密码特征分析)

一.MD5,sha1,HMAC,NTLM 1.MD5&#xff1a;MD5一般由32/16位的数字(0-9)和字母(a-f)组成的字符串 2.sha1&#xff1a;这种加密的密文特征跟MD5差不多&#xff0c;只不过位数是40&#xff08;sha256&#xff1a;64位&#xff1b;sha512:128位&#xff09; 3.HMAC&#xff1a;这…...

JVM调优篇之JVM基础入门AND字节码文件解读

目录 Java程序编译class文件内容常量池附录-访问标识表附录-常量池类型列表 Java程序编译 Java文件通过编译成class文件后&#xff0c;通过JVM虚拟机解释字节码文件转为操作系统执行的二进制码运行。 规范 Java虚拟机有自己的一套规范&#xff0c;遵循这套规范&#xff0c;任…...

EXCEL截取某一列从第一个字符开始到特定字符结束的字符串到新的一列

使用EXCEL中的公式进行特定截取 假设列A是一组产品的编码&#xff0c;我们需要的数据是“-”之前的字段。 我们需要在B1单元格输入公式“LEFT(A1,SEARCH("-",A1)-1)”然后选中B1至B4单元格&#xff0c;按“CTRLD”向下填充&#xff0c;就可以得出其它几行“-”之前的…...

数据库期末复习题库

1. Mysql日志功能有哪些? 记录日常操作和错误信息&#xff0c;以便了解Mysql数据库的运行情况&#xff0c;日常操作&#xff0c;错误信息和进行相关的优化。 2. 数据库有哪些备份方法 完全备份&#xff1a;全部都备份一遍表备份&#xff1a;只提取数据库中的数据&#xff0…...

私有库gitea安装

一 gitea是什么 Gitea是一款自助Git服务&#xff0c;简单来说&#xff0c;就是可以一个私有的github。 搭建很容易。 Gitea依赖于Git。 类似Gitea的还有GitHub、Gitee、GitLab等。 以下是安装步骤。 二 安装sqilite 参考&#xff1a; 在windows上安装sqlite 三 安装git…...

关于最近win11不能使用ie,而不能使用考试客户端的解决方法

弄ie的那个我感觉是非常难的&#xff0c;所以我的是另一种的方法 下载360浏览器&#xff08;不是360全家桶&#xff09;360安全浏览器-全面保护上网安全&#xff0c;4亿用户共同选择&#xff08;上面的是官网&#xff0c;不要下载错了&#xff0c;还有安装界面注意不要勾选一下…...

深度学习之Mask-R-CNN

1.1 Mask-RCNN 的网络结构示意图 其中黑色部分为原来的Faster-RCNN&#xff0c;红色部分为在Faster网络上的修改&#xff1a;    1&#xff09;将ROI Pooling层替换成了ROIAlign&#xff1b;    2&#xff09;添加并列的FCN层&#xff08;Mask层&#xff09;&#xff1b;  …...

css包含块

包含块 出现 在css中一些属性的计算可能超出你的预料&#xff0c;在普遍情况下会认为定位属性和百分比的宽高是根据父元素计算的&#xff0c;但是准确来说他们都是根据元素所在的包含块来计算的&#xff0c;所以掌握包含块的知识是非常关键的。 内容 在CSS中&#xff0c;“…...

混沌工程/混沌测试/云原生测试/云平台测试

背景 私有云/公有云/混合云等具有复杂&#xff0c;分布式&#xff0c;环境多样性等特点&#xff0c;许多特殊场景引发的线上问题很难被有效发现。所以需要引入混沌工程&#xff0c;建立对系统抵御生产环境中失控条件的能力以及信心&#xff0c;提高系统面对未知风险得能力。 …...

wordpress 登录没反应/宁波seo专员

单链表&#xff08;single-linked list&#xff09;链表结构应用实例分析数据结构算法类方法对象代码实现插入向尾部直接插入节点思路分析算法实现按照顺序插入指定位置思路分析算法实现修改思路分析代码实现删除思路分析代码实现查找思路分析代码实现面试题有效元素的个数代码…...

做网站的成本在哪/必应搜索引擎怎么样

按照相关交通领域机构发布的数据显示&#xff0c;近年来&#xff0c;分心驾驶已经成为公路事故和死亡的主要原因。同时&#xff0c;随着高阶智能驾驶的陆续前装上车&#xff0c;驾驶员的监控&#xff08;保证对前方道路的持续注意力&#xff09;也成为安全风险的“重灾区”。 …...

人与马做网站/服装店营销策划方案

配套FPGA开发板&#xff08;含该设计的工程代码&#xff09;&#xff1a;https://item.taobao.com/item.htm?spma1z10.1-c.w4004-4676525296.4.6e8950ed57YPhv&id17848039135 基于FPGA的智力抢答器设计 功能说明 说明 4路抢答器&#xff0c;选手&#xff0c;主持人可以进行…...

怎么做传奇私服广告网站/站长工具麻豆

关于如何集成spring-data-mongodb到项目中&#xff0c;已经有很多人介绍了&#xff0c;这里只给出几个链接。GETTING STARTED Accessing Data with MongoDB&#xff1a; http://spring.io/guides/gs/accessing-data-mongodb/MongoDB初探(二)----使用spring-data配置mongodb &am…...

网站建设服务器百度云/竞价推广培训课程

2019独角兽企业重金招聘Python工程师标准>>> 公司有同事用foxmail无法正常收取james的邮件。而且在收取时&#xff0c;其他服务部的同事表示后台连接不上。判断是数据库连接问题或james连接并发的问题。 当你在具有很多TCP/IP连接的Windows上运行MySQL服务器&#x…...

企业建设网站的好处/国内10大搜索引擎

盼望着&#xff0c;盼望着。在其他省市的小伙伴早已开奖&#xff0c;奖金都快花完了的时候。北京的同学们&#xff0c;终于可以开奖退税了。点开之前&#xff0c;谁也不知道是喜是忧&#xff1b;点开之后&#xff0c;有人欢喜有人愁。拿出早已收藏好的办税攻略&#xff0c;把看…...