当前位置: 首页 > news >正文

训练自己的GPT2

训练自己的GPT2

  • 1.预训练与微调
  • 2.准备工作
  • 2.在自己的数据上进行微调

1.预训练与微调

所谓的预训练,就是在海量的通用数据上训练大模型。比如,我把全世界所有的网页上的文本内容都整理出来,把全人类所有的书籍、论文都整理出来,然后进行训练。这个训练过程代价很大,首先模型很大,同时数据量又很大,比如GPT3参数量达到了175B,训练数据达到了45TB,训练一次就话费上千万美元。如此大代价学出来的是一个通用知识的模型,他确实很强,但是这样一个模型,可能无法在一些专业性很强的领域上取得比较好的表现,因为他没有针对这个领域的数据进行训练过。

因此,大模型火了之后,很多人都开始把大模型用在自己的领域。通常也就是把自己领域的一些数据,比如专业书、论文等等整理出来,使用预训练好的大模型在新的数据集上进行微调。微调的成本相比于预训练就要小得多了。

2.准备工作

首先需要安装第三方库transformerstransformers是一个用于自然语言处理(NLP)的Python第三方库,实现Bert、GPT-2和XLNET等比较新的模型,支持TensorFlow和PyTorch。以及下载预训练好的模型权重。

pip install transformers

安装完成之后,我们可以直接使用下面的代码,来构造一个预训练的GPT2

from transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("gpt2")
model = GPT2LMHeadModel.from_pretrained("gpt2")

当运行的时候,代码会自动从hugging face上下载模型。但是由于hugging face是国外网站,可能下载起来很慢或者无法下载,因此我们也可以自己手动下载之后在本地读取。

打开hugging face的网站,搜索GPT2。或者直接进入GPT2的页面。

下载上图中的几个文件到本地,假设下载到./gpt2文件夹

然后就可以使用下面的代码来尝试预训练的模型直接生成文本你的效果。

from transformers import GPT2Tokenizer, GPT2LMHeadModeltokenizer = GPT2Tokenizer.from_pretrained("./gpt2")
model = GPT2LMHeadModel.from_pretrained("./gpt2")q = "tell me a fairy story"ids = tokenizer.encode(q, return_tensors='pt')
final_outputs = model.generate(ids,do_sample=True,max_length=100,pad_token_id=model.config.eos_token_id,top_k=50,top_p=0.95,
)print(tokenizer.decode(final_outputs[0], skip_special_tokens=True))

回答如下:

2.在自己的数据上进行微调

首先把我们的数据,也就是文本,全部整理到一起。比如可以把所有文本拼接到一起。

假设所有的文本数据都存到一个文件中。那么可以直接使用下面的代码进行训练。

import torch
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Tokenizer, GPT2ForSequenceClassification, AdamW, GPT2LMHeadModel
from transformers import DataCollatorForLanguageModeling
from transformers import Trainer, TrainingArguments, TextDatasetdef load_data_collator(tokenizer, mlm = False):data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=mlm,)return data_collatordef load_dataset(file_path, tokenizer, block_size = 128):dataset = TextDataset(tokenizer = tokenizer,file_path = file_path,block_size = block_size,)return datasetdef train(train_file_path, model_name,output_dir,overwrite_output_dir,per_device_train_batch_size,num_train_epochs,save_steps):tokenizer = GPT2Tokenizer.from_pretrained(model_name)train_dataset = load_dataset(train_file_path, tokenizer)data_collator = load_data_collator(tokenizer)tokenizer.save_pretrained(output_dir)model = GPT2LMHeadModel.from_pretrained(model_name)model.save_pretrained(output_dir)training_args = TrainingArguments(output_dir=output_dir,overwrite_output_dir=overwrite_output_dir,per_device_train_batch_size=per_device_train_batch_size,num_train_epochs=num_train_epochs,)trainer = Trainer(model=model,args=training_args,data_collator=data_collator,train_dataset=train_dataset,)trainer.train()trainer.save_model()train_file_path = "./train.txt"   # 你自己的训练文本
model_name = './gpt2'  # 预训练的模型路径
output_dir = './custom_data'  # 你自己设定的模型保存路径
overwrite_output_dir = False
per_device_train_batch_size = 96  # 每一台机器上的batch size。
num_train_epochs = 50   
save_steps = 50000# Train
train(train_file_path=train_file_path,model_name=model_name,output_dir=output_dir,overwrite_output_dir=overwrite_output_dir,per_device_train_batch_size=per_device_train_batch_size,num_train_epochs=num_train_epochs,save_steps=save_steps
)     

训练完成之后,推理的话,直接使用第二节里的代码,将预训练模型路径换成自己训练的模型路径就行了

相关文章:

训练自己的GPT2

训练自己的GPT2 1.预训练与微调2.准备工作2.在自己的数据上进行微调 1.预训练与微调 所谓的预训练,就是在海量的通用数据上训练大模型。比如,我把全世界所有的网页上的文本内容都整理出来,把全人类所有的书籍、论文都整理出来,然…...

etcd储存安装

目录 etcd介绍: etcd工作原理 选举 复制日志 安全性 etcd工作场景 服务发现 etcd基本术语 etcd安装(centos) 设置:etcd后台运行 etcd 是云原生架构中重要的基础组件,由 CNCF 孵化托管。etcd 在微服务和 Kubernates 集群中不仅可以作为服务注册…...

如何彻底卸载Microsoft Edge浏览器

一、引语 随着微软推出全新的Edge浏览器,许多用户可能想要尝试或完全切换到其他浏览器。在这篇文章中,我们将向您介绍如何彻底卸载Microsoft Edge浏览器,以确保您的系统干净整洁。 二、通过系统设置卸载 1、首先,右键单击桌面上…...

Transformers 2023年度回顾 :从BERT到GPT4

人工智能已成为近年来最受关注的话题之一,由于神经网络的发展,曾经被认为纯粹是科幻小说中的服务现在正在成为现实。从对话代理到媒体内容生成,人工智能正在改变我们与技术互动的方式。特别是机器学习 (ML) 模型在自然语言处理 (NLP) 领域取得…...

判断两个对象某些字段的值是否相同

1、借助mybatis plus的方法 import com.baomidou.mybatisplus.core.toolkit.LambdaUtils; import com.baomidou.mybatisplus.core.toolkit.support.SFunction; import com.baomidou.mybatisplus.core.toolkit.support.SerializedLambda; import lombok.SneakyThrows; import o…...

TYPE-C接口取电芯片介绍和应用场景

随着科技的发展,USB PDTYPE-C已经成为越来越多设备的充电接口。而在这一领域中,LDR6328Q PD取电芯片作为设备端协议IC芯片,扮演着至关重要的角色。本文将详细介绍LDR6328Q PD取电芯片的工作原理、应用场景以及选型要点。 一、工作原理 LDR63…...

基于TI TPSXX系列 Buck电路应用计算-外围器件详细计算过程

TPS54202 Buck电路应用计算 1、电气特性2、内部框图3、典型应用电路4、设计需求5、计算EN引脚电阻6、FB引脚电阻估算7、查看反馈电压电压基准8、输入电容计算10、FB引脚反馈电阻计算11、功率电感计算12、输出电容计算13、前馈电容计算15、Layout布局TPS54202-中文版 1、电气特…...

NOIP2012提高组day1-T3:开车旅行

题目链接 [NOIP2012 提高组] 开车旅行 题目描述 小 A \text{A} A 和小 B \text{B} B 决定利用假期外出旅行,他们将想去的城市从 1 1 1 到 n n n 编号,且编号较小的城市在编号较大的城市的西边,已知各个城市的海拔高度互不相同&#xf…...

Golang Web框架性能对比

Golang Web框架性能对比 github star排名依次: Gin Beego Iris Echo Revel Buffalo 性能上gin、iris、echo网上是给的数据都是五星,beego三星,revel两星 beego是国产,有中文文档,文档齐全 根据star数,性能,易用程度…...

【OCR】 - Tesseract OCR在mac系统中安装

Tesseract OCR 在Mac环境下安装Tesseract OCR(Optical Character Recognition)通常可以通过Homebrew包管理器进行。以下是安装步骤: 安装Homebrew 如果你还没有安装Homebrew,请访问 https://brew.sh/ 并按照页面上的说明安装。…...

了解不同方式导入导出的速度之快

目录 一、用工具导出导入 Navicat(速度慢) 1.1、导入: 共耗时: 1.2、导出表 共耗时: 二、用命令语句导出导入 2.1、mysqldump速度快 导出表数据和表结构 共耗时: 只导出表结构 导入 共耗时&…...

2024年第九届计算机与通信系统国际会议(ICCCS2024) ,邀您相约西安!

会议官网: ICCCS2024 | Xian China 时间: 2024年4月19-22日 地点: 中国西安 会议简介: 近年来,信息通信在不断发展,为计算机网络的进步与发展提供了先进可靠的技术支持。随着计算机网络与通信技术的深入发展,计算机通信技术、数…...

获取直播间的最新评论 - python 取两个list的差集

python 取两个list的差集 作用:比如我要获取评论区列表,先获取了一遍,这个时候有人评论了几条,我再获取一遍后,找出多的那几条 使用set数据类型来取两个列表的差集。差集表示仅包含在第一个列表中而不在第二个列表中…...

2023年度总结:但行前路,不负韶华

​ 🦁作者简介:一名喜欢分享和记录学习的在校大学生 🐯个人主页:妄北y 🐧个人QQ:2061314755 🐻个人邮箱:2061314755qq.com 🦉个人WeChat:Vir2021GKBS &#x…...

智数融合|低代码入局,推动工业数字化转型走"深"向"实"

当下,“数字化、智能化”已经不再是新鲜词汇。事实上,早在几年前,就有企业开始大力推动数字化转型,并持续进行了一段时间。一些业内人士甚至认为,“如今的企业数字化已经走过了成熟期,进入了深水区。” 但事…...

初学者的基本 Python 面试问题和答案

文章目录 专栏导读1、什么是Python?列出 Python 在技术领域的一些流行应用。2、在目前场景下使用Python语言作为工具有什么好处?3、Python是编译型语言还是解释型语言?4、Python 中的“#”符号有什么作用?5、可变数据类型和不可变…...

支持向量机(Support Vector Machines,SVM)

什么是机器学习 支持向量机(Support Vector Machines,SVM)是一种强大的机器学习算法,可用于解决分类和回归问题。SVM的目标是找到一个最优的超平面,以在特征空间中有效地划分不同类别的样本。 基本原理 超平面 在二…...

golang一个轻量级基于内存的kv存储或缓存

golang一个轻量级基于内存的kv存储或缓存 go-cache是一个轻量级的基于内存的key:value 储存组件,类似于memcached,适用于在单机上运行的应用程序。 它的主要优点是,本质上是一个具有过期时间的线程安全map[string]interface{}。interface的结…...

henauOJ 1103: 统计元音

题目描述 统计每个元音字母在字符串中出现的次数。 输入 输入数据首先包括一个整数n,表示测试实例的个数,然后是n行长度不超过100的字符串。 输出 对于每个测试实例输出5行,格式如下: a:num1 e:num2 i:num3 o:num4 u:num5 多…...

虚幻引擎:开创视觉与创意的新纪元

先看看据说虚幻5做出来的东西吧: 虚幻引擎5!!!4K画质PS5实机演示! 好了,用文字认识一下吧: 虚幻引擎5.3对UE5的核心工具集作了进一步优化,涉及渲染、世界构建、程序化内容生成&…...

T527 Android 13 编译步骤

步骤1: cd longan./build.sh config (0 2 1) 选择 Android 平台: 步骤2:选择IC为t527: 步骤3:板子类型选为demo_car: 步骤4:选择 flash,默认选择 default 则可: 步骤5&…...

OpenAI ChatGPT-4开发笔记2024-04:Chat之Tool之2:multiple functions

从程序员到ai Expert 1 定义参数和函数2 第一轮chatgpt3 第一轮结果和function定义全部加入prompt再喂给chatgpt4 大结局7 参考资料 上一篇解决了调用一个函数的问题。这一篇扩展为调用3个。n个自行脑补。 1 定义参数和函数 #1.设定目标 import json import openai#1.定义para…...

14:00面试,14:07就出来了,问的问题有点变态。。。

前言 刚从小厂出来,没想到网盘我在另一家公司又寄了。 在这家公司上班,每天都要加班,但看在钱给的比较多的份上,也就不太计较了。但万万没想到一纸通知,所有人不准加班了,不仅加班费没有了,薪…...

206. 反转链表(Java)

题目描述: 给你单链表的头节点 head ,请你反转链表,并返回反转后的链表。 输入: head [1,2,3,4,5] 输出: [5,4,3,2,1] 代码实现: 1.根据题意创建一个结点类: public class ListNode {int val…...

LeetCode 2807. 在链表中插入最大公约数【链表,迭代,递归】1279

本文属于「征服LeetCode」系列文章之一,这一系列正式开始于2021/08/12。由于LeetCode上部分题目有锁,本系列将至少持续到刷完所有无锁题之日为止;由于LeetCode还在不断地创建新题,本系列的终止日期可能是永远。在这一系列刷题文章…...

Hive之set参数大全-3

D 是否启用本地任务调试模式 hive.debug.localtask 是 Apache Hive 中的一个配置参数,用于控制是否启用本地任务调试模式。在调试模式下,Hive 将尝试在本地模式下运行一些任务,以便更容易调试和分析问题。 具体来说,当 hive.de…...

Golang拼接字符串性能对比

g o l a n g golang golang的 s t r i n g string string类型是不可修改的,对于拼接字符串来说,本质上还是创建一个新的对象将数据放进去。主要有以下几种拼接方式 拼接方式介绍 1.使用 s t r i n g string string自带的运算符 ans ans s2. 使用…...

【问题解决】web页面html锚点定位后内容被遮挡问题解决【暗锚】

正常的锚点跳转 a标签的href填写目标元素的id即可 <a href"#my_target">to div1</a> <div id"my_target">div1</div> 内容被顶栏遮挡示例 但是当id所在元素被嵌套多层flex和relative布局之后&#xff0c;跳转后部分内容会被遮挡…...

easyui datagrid无数据时显示无数据

这里写自定义目录标题 需求解决办法 需求 使用datagrid显示记录时&#xff0c;结果查询记录数为0&#xff0c;此时需要显示无数据。 示例代码 <table id"dg"></table>$(#dg).datagrid({url:datagrid_data.json,columns:[[{field:code,title:Code,widt…...

动态规划python简单例子-斐波那契数列

def fibonacci(n):dp [0, 1] [0] * (n - 1) # 初始化动态规划数组for i in range(2, n 1):dp[i] dp[i - 1] dp[i - 2] # 计算斐波那契数列的第 i 项print(dp)return dp[n] # 返回斐波那契数列的第 n 项# 示例用法 n 10 # 计算斐波那契数列的第 10 项 result fibonac…...

自己开网店没有货源怎么办/百度seo通科

JDBC客户端操作步骤 转载于:https://www.cnblogs.com/ggzhangxiaochao/p/9222944.html...

中国人做代购的网站/西安网站推广

原文链接:http://click.aliyun.com/m/13858/ 免费开通大数据服务&#xff1a;https://www.aliyun.com/product/odps目前人人都在谈大数据&#xff0c;谈DT时代&#xff0c;但是&#xff0c;大数据是什么&#xff0c;每个人都有自己的一个看法&#xff0c;好比盲人摸象&#xff…...

无锡做网站哪家公司好/目前好的推广平台

python列出指定目录"c:\"所有的后缀名为*.txt 的文并输出每个文件的创建日期和大小下载文件:python列出指定目录.rarpython提取文件夹中所有子文件夹下所有文件的某一行每当看见这个软弱的自己&#xff0c;就想到自己曾经不堪回首的往事&#xff0c;和那个怀念至今&a…...

wordpress 七牛云冲突/app投放推广

三月份小编在美国参加MVP峰会的时候&#xff0c;有幸碰到了几个Uber的高级工程师&#xff0c;他们在当天还分享了Uber的消息总线系统如何在每日兆级信息量、PB级数据卷、数万个Topic的情况下&#xff0c;保证低延时&#xff08;小于5ms&#xff09;&#xff0c;高可用&#xff…...

帝国cms手机网站/拉新任务接单放单平台

在使用MySQL 时 有时会登录遇到错误 ERROR 2003 (HY000): Cant connect to MySQL server on localhost (10061) 以下方法几步解决问题&#xff0c; 让mysql恢复如初&#xff08;不会删除原有表和数据&#xff09;&#xff1a; 1.找到mysql安装目录&#xff0c;以我的为例&…...

容桂做网站/北京十大教育培训机构排名

let userId dsf.getCookie("userId");...