当前位置: 首页 > news >正文

从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)

🚩🚩🚩Hugging Face 实战系列 总目录

有任何问题欢迎在下面留言
本篇文章的代码运行界面均在PyCharm中进行
本篇文章配套的代码资源已经上传

从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练

3 数据加载函数

def load_dataset(logger, args):"""加载训练集"""logger.info("loading training dataset")train_path = args.train_pathwith open(train_path, "rb") as f:train_list = pickle.load(f)# test# train_list = train_list[:24]train_dataset = CPMDataset(train_list, args.max_len)return train_dataset
  1. List item

4 训练函数

def train(model, logger, train_dataset, args):train_dataloader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, collate_fn=collate_fn,drop_last=True)logger.info("total_steps:{}".format(len(train_dataloader)* args.epochs))t_total = len(train_dataloader) // args.gradient_accumulation_steps * args.epochsoptimizer = transformers.AdamW(model.parameters(), lr=args.lr, eps=args.eps)scheduler = transformers.get_linear_schedule_with_warmup(optimizer, num_warmup_steps=args.warmup_steps, num_training_steps=t_total)# 设置warmuplogger.info('start training')train_losses = []   # 记录每个epoch的平均loss# ========== start training ========== #for epoch in range(args.epochs):train_loss = train_epoch(model=model, train_dataloader=train_dataloader,optimizer=optimizer, scheduler=scheduler,logger=logger, epoch=epoch, args=args)train_losses.append(round(train_loss, 4))logger.info("train loss list:{}".format(train_losses))logger.info('training finished')logger.info("train_losses:{}".format(train_losses))

5 迭代训练函数

def train_epoch(model, train_dataloader, optimizer, scheduler, logger,epoch, args):model.train()device = args.deviceignore_index = args.ignore_indexepoch_start_time = datetime.now()total_loss = 0  # 记录下整个epoch的loss的总和epoch_correct_num = 0   # 每个epoch中,预测正确的word的数量epoch_total_num = 0  # 每个epoch中,预测的word的总数量for batch_idx, (input_ids, labels) in enumerate(train_dataloader):# 捕获cuda out of memory exceptiontry:input_ids = input_ids.to(device)labels = labels.to(device)outputs = model.forward(input_ids, labels=labels)logits = outputs.logitsloss = outputs.lossloss = loss.mean()# 统计该batch的预测token的正确数与总数batch_correct_num, batch_total_num = calculate_acc(logits, labels, ignore_index=ignore_index)# 统计该epoch的预测token的正确数与总数epoch_correct_num += batch_correct_numepoch_total_num += batch_total_num# 计算该batch的accuracybatch_acc = batch_correct_num / batch_total_numtotal_loss += loss.item()if args.gradient_accumulation_steps > 1:loss = loss / args.gradient_accumulation_stepsloss.backward()# 梯度裁剪torch.nn.utils.clip_grad_norm_(model.parameters(), args.max_grad_norm)# 进行一定step的梯度累计之后,更新参数if (batch_idx + 1) % args.gradient_accumulation_steps == 0:# 更新参数optimizer.step()# 更新学习率scheduler.step()# 清空梯度信息optimizer.zero_grad()if (batch_idx + 1) % args.log_step == 0:logger.info("batch {} of epoch {}, loss {}, batch_acc {}, lr {}".format(batch_idx + 1, epoch + 1, loss.item() * args.gradient_accumulation_steps, batch_acc, scheduler.get_lr()))del input_ids, outputsexcept RuntimeError as exception:if "out of memory" in str(exception):logger.info("WARNING: ran out of memory")if hasattr(torch.cuda, 'empty_cache'):torch.cuda.empty_cache()else:logger.info(str(exception))raise exception# 记录当前epoch的平均loss与accuracyepoch_mean_loss = total_loss / len(train_dataloader)epoch_mean_acc = epoch_correct_num / epoch_total_numlogger.info("epoch {}: loss {}, predict_acc {}".format(epoch + 1, epoch_mean_loss, epoch_mean_acc))# save modellogger.info('saving model for epoch {}'.format(epoch + 1))model_path = join(args.save_model_path, 'epoch{}'.format(epoch + 1))if not os.path.exists(model_path):os.mkdir(model_path)model_to_save = model.module if hasattr(model, 'module') else modelmodel_to_save.save_pretrained(model_path)logger.info('epoch {} finished'.format(epoch + 1))epoch_finish_time = datetime.now()logger.info('time for one epoch: {}'.format(epoch_finish_time - epoch_start_time))return epoch_mean_loss

从零构建属于自己的GPT系列1:文本数据预处理
从零构建属于自己的GPT系列2:语言模型训练

相关文章:

从零构建属于自己的GPT系列3:模型训练2(训练函数解读、模型训练函数解读、代码逐行解读)

🚩🚩🚩Hugging Face 实战系列 总目录 有任何问题欢迎在下面留言 本篇文章的代码运行界面均在PyCharm中进行 本篇文章配套的代码资源已经上传 从零构建属于自己的GPT系列1:文本数据预处理 从零构建属于自己的GPT系列2:语…...

Python词频统计(数据整理)

请编写程序,对一段英文文本,统计其中所有不同单词的个数,以及词频最大的前10%的单词。 输入格式: 输入给出一段非空文本,最后以符号#结尾。输入保证存在至少10个不同的单词。 输出格式: 在第一行中输出文本中所有不同单词的个数…...

基本面选股的方法

基本面选股是一种投资策略,主要关注公司的财务状况、盈利能力、行业地位等因素,以判断公司的价值并做出投资决策。以下是基本面选股的具体分析方法和重点: 财务状况分析: 利润表分析:关注公司的净利润、毛利率、营业…...

应用密码学期末复习(3)

目录 第三章 现代密码学应用案例 3.1安全电子邮件方案 3.1.1 PGP产生的背景 3.2 PGP提供了一个安全电子邮件解决方案 3.2.1 PGP加密流程 3.2.2 PGP解密流程 3.2.3 PGP整合了对称加密和公钥加密的方案 3.3 PGP数字签名和Hash函数 3.4 公钥分发与认证——去中心化模型 …...

PAD平板签约投屏-高端活动的选择

传统的现场纸质签约仪式除了缺乏仪式感之外还缺少互动性,如果要将签约的过程投放到大屏幕上更是需要额外的硬件设备成本。相比于传统的纸质签约仪式,平板现场电子签约的形式更加的新颖、更富有科技感、更具有仪式感。 平板签约投屏是应用于会议签字仪式的…...

分布式架构demo

1、外层创建pom 版本管理器 <parent><groupId>org.springframework.boot</groupId><artifactId>spring-boot-starter-parent</artifactId><version>2.7.15</version><relativePath/> <!-- lookup parent from repository…...

TA-Lib学习研究笔记(二)——Overlap Studies上

TA-Lib学习研究笔记&#xff08;二&#xff09;——Overlap Studies 1. Overlap Studies 指标 [BBANDS, DEMA, EMA, HT_TRENDLINE, KAMA, MA, MAMA, MAVP, MIDPOINT, MIDPRICE, SAR, SAREXT, SMA, T3, TEMA, TRIMA, WMA]2.数据准备 get_data函数参数&#xff08;代码&#x…...

牛客java基础考点1 标识符和变量

牛客java基础考点1 标识符和变量 标识符 字母和数字&#xff1a; 标识符由字母、数字、下划线&#xff08;_&#xff09;和美元符号&#xff08;$&#xff09;组成。其中&#xff0c;标识符必须以字母、下划线或美元符号开头。大小写敏感&#xff1a; Java 是大小写敏感的语言…...

Qt将打印信息输出到文件

将打印信息&#xff08;qDebug、qInfo、qWarning、qCritial等&#xff09;输出到指定文件来以实现简单的日志功能。 #include "mainwindow.h" #include <QApplication> #include <QLoggingCategory> #include <QMutex> #include <QDateTime>…...

【risc-v】易灵思efinix FPGA sapphire_soc IP配置参数分享

系列文章目录 分享一些fpga内使用riscv软核的经验&#xff0c;共大家参考。后续内容比较多&#xff0c;会做成一个系列。 本系列会覆盖以下FPGA厂商 易灵思 efinix 赛灵思 xilinx 阿尔特拉 Altera 本文内容隶属于【易灵思efinix】系列。 前言 在efinix fpga中使用riscv是一…...

直播的种类及类型

随着网络技术和移动设备的普及&#xff0c;直播已经成为人们娱乐、学习、商业交流等众多领域的重要工具。 直播的种类主要有以下几种: 1.视频直播:这是最常见的直播形式&#xff0c;包括电商直播、婚庆直播、培训直播、家居直播等。 2.图文直播:这种直播形式包括PPT互动直播…...

时间序列数据压缩算法简述

本文简单介绍了时间序列压缩任务的来源&#xff0c;压缩算法的分类&#xff0c;并对常见压缩算法的优缺点进行了简介&#xff0c;爱码士们快来一探究竟呀&#xff01; 引言 时间序列数据是在许多应用程序和领域中生成的一种基本数据类型&#xff0c;例如金融、医疗保健、交通和…...

智能锁-SI522TORC522方案资料

南京中科微这款SI522目前完全PinTOPin兼容的NXP&#xff1a;RC522、CV520 复旦微&#xff1a;FM17520、FM17522/FM17550 瑞盟&#xff1a;MS520、MS522 国民技术:NZ3801、NZ3802 SI522 是应用于13.56MHz 非接触式通信中高集成度读写卡系列芯片中的一员。是NXP 公司针对&quo…...

redux(4) -RTK简单使用

简单使用 1、下载 npm i reduxjs/toolkit react-redux 2、创建 1、在redux/user.js中创建模块user。从reduxjs/toolkit中引入createSlice创建模块片段&#xff0c;我们需要传入name、初始数据initialState、改state的reducers等。最后需要导出reducer和action。 代码如下&a…...

开源运维监控系统-Nightingale(夜莺)应用实践(未完)

一、前言 某业务系统因OS改造,原先的Zabbix监控系统推倒后未重建,本来计划用外部企业内其他监控系统接入,后又通知需要自建才能对接,考虑之前zabbix的一些不便,本次计划采用一个类Prometheus的监控系统,镜调研后发现Nightingale兼容Prometheus,又有一些其他功能增强,又…...

深入理解GMP模型

1、GMP模型的设计思想 1&#xff09;、GMP模型 GMP分别代表&#xff1a; G&#xff1a;goroutine&#xff0c;Go协程&#xff0c;是参与调度与执行的最小单位M&#xff1a;machine&#xff0c;系统级线程P&#xff1a;processor&#xff0c;包含了运行goroutine的资源&#…...

数学建模-基于集成学习的共享单车异常检测的研究

基于集成学习的共享单车异常检测的研究 整体求解过程概述(摘要) 近年来&#xff0c;共享单车的快速发展在方便了人们出行的同时&#xff0c;也对城市交通产生了一定的负面影响&#xff0c;其主要原因为单车资源配置的不合理。本文通过建立单车租赁数量的预测模型和异常检测模型…...

C语言-内存分配

内存分配 1. 引入 int nums[10] {0}; //对int len 10; int nums[len] {0}; //错是因为系统的内存分配原则导致的2. 概述 在程序运行时&#xff0c;系统为了 更好的管理进程中的内存&#xff0c;所以有了 内存分配机制。 分配原则&#xff1a; 2.1 静态分配 静态分配原…...

算法工程师-机器学习面试题总结(1)

目录 1-1 损失函数是什么&#xff0c;如何定义合理的损失函数&#xff1f; 1-2 回归模型和分类模型常用损失函数有哪些&#xff1f;各有什么优缺点 1-3 什么是结构误差和经验误差&#xff1f;训练模型的时候如何判断已经达到最优&#xff1f; 1-4 模型的“泛化”能力是指&a…...

【蓝桥杯选拔赛真题73】Scratch烟花特效 少儿编程scratch图形化编程 蓝桥杯创意编程选拔赛真题解析

目录 scratch烟花特效 一、题目要求 编程实现 二、案例分析 1、角色分析...

以下是对华为 HarmonyOS NETX 5属性动画(ArkTS)文档的结构化整理,通过层级标题、表格和代码块提升可读性:

一、属性动画概述NETX 作用&#xff1a;实现组件通用属性的渐变过渡效果&#xff0c;提升用户体验。支持属性&#xff1a;width、height、backgroundColor、opacity、scale、rotate、translate等。注意事项&#xff1a; 布局类属性&#xff08;如宽高&#xff09;变化时&#…...

前端开发面试题总结-JavaScript篇(一)

文章目录 JavaScript高频问答一、作用域与闭包1.什么是闭包&#xff08;Closure&#xff09;&#xff1f;闭包有什么应用场景和潜在问题&#xff1f;2.解释 JavaScript 的作用域链&#xff08;Scope Chain&#xff09; 二、原型与继承3.原型链是什么&#xff1f;如何实现继承&a…...

k8s业务程序联调工具-KtConnect

概述 原理 工具作用是建立了一个从本地到集群的单向VPN&#xff0c;根据VPN原理&#xff0c;打通两个内网必然需要借助一个公共中继节点&#xff0c;ktconnect工具巧妙的利用k8s原生的portforward能力&#xff0c;简化了建立连接的过程&#xff0c;apiserver间接起到了中继节…...

UR 协作机器人「三剑客」:精密轻量担当(UR7e)、全能协作主力(UR12e)、重型任务专家(UR15)

UR协作机器人正以其卓越性能在现代制造业自动化中扮演重要角色。UR7e、UR12e和UR15通过创新技术和精准设计满足了不同行业的多样化需求。其中&#xff0c;UR15以其速度、精度及人工智能准备能力成为自动化领域的重要突破。UR7e和UR12e则在负载规格和市场定位上不断优化&#xf…...

Springboot社区养老保险系统小程序

一、前言 随着我国经济迅速发展&#xff0c;人们对手机的需求越来越大&#xff0c;各种手机软件也都在被广泛应用&#xff0c;但是对于手机进行数据信息管理&#xff0c;对于手机的各种软件也是备受用户的喜爱&#xff0c;社区养老保险系统小程序被用户普遍使用&#xff0c;为方…...

SiFli 52把Imagie图片,Font字体资源放在指定位置,编译成指定img.bin和font.bin的问题

分区配置 (ptab.json) img 属性介绍&#xff1a; img 属性指定分区存放的 image 名称&#xff0c;指定的 image 名称必须是当前工程生成的 binary 。 如果 binary 有多个文件&#xff0c;则以 proj_name:binary_name 格式指定文件名&#xff0c; proj_name 为工程 名&…...

ubuntu系统文件误删(/lib/x86_64-linux-gnu/libc.so.6)修复方案 [成功解决]

报错信息&#xff1a;libc.so.6: cannot open shared object file: No such file or directory&#xff1a; #ls, ln, sudo...命令都不能用 error while loading shared libraries: libc.so.6: cannot open shared object file: No such file or directory重启后报错信息&…...

微服务通信安全:深入解析mTLS的原理与实践

&#x1f525;「炎码工坊」技术弹药已装填&#xff01; 点击关注 → 解锁工业级干货【工具实测|项目避坑|源码燃烧指南】 一、引言&#xff1a;微服务时代的通信安全挑战 随着云原生和微服务架构的普及&#xff0c;服务间的通信安全成为系统设计的核心议题。传统的单体架构中&…...

实战设计模式之模板方法模式

概述 模板方法模式定义了一个操作中的算法骨架&#xff0c;并将某些步骤延迟到子类中实现。模板方法使得子类可以在不改变算法结构的前提下&#xff0c;重新定义算法中的某些步骤。简单来说&#xff0c;就是在一个方法中定义了要执行的步骤顺序或算法框架&#xff0c;但允许子类…...

【1】跨越技术栈鸿沟:字节跳动开源TRAE AI编程IDE的实战体验

2024年初&#xff0c;人工智能编程工具领域发生了一次静默的变革。当字节跳动宣布退出其TRAE项目&#xff08;一款融合大型语言模型能力的云端AI编程IDE&#xff09;时&#xff0c;技术社区曾短暂叹息。然而这一退场并非终点——通过开源社区的接力&#xff0c;TRAE在WayToAGI等…...