当前位置: 首页 > news >正文

【NLP练习】Transformer实战-单词预测

  • 🍨 本文为🔗365天深度学习训练营 中的学习记录博客
  • 🍖 原作者:K同学啊

任务:自定义输入一段英文文本进行预测

一、定义模型

from tempfile import TemporaryDirectory
from typing import Tuple
from torch import nn,Tensor
from torch.nn import TransformerEncoder, TransformerEncoderLayer
import math, os, torchclass TransformerModel(nn.Module):def __init__(self, ntoken: int, d_model: int, nhead: int, d_hid: int, nlayers: int, dropout: float = 0.5):super().__init__()self.pos_encoder = PositionalEncoding(d_model, dropout)#定义编码器层encoder_layers = TransformerEncoderLayer(d_model, nhead, d_hid, dropout)#定义编码器,pytorch将Transformer编码器进行了打包,这里直接调用即可self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)self.embedding = nn.Embedding(ntoken,d_model)self.d_model = d_modelself.linear = nn.Linear(d_model, ntoken)self.init_weights()#初始化权重def init_weights(self) -> None:initrange = 0.1self.embedding.weight.data.uniform_(-initrange, initrange)self.linear.bias.data.zeros_()self.linear.weight.data.uniform_(-initrange, initrange)def forward(self, src:Tensor, src_mask: Tensor = None) -> Tensor:"""Arguments:src:      Tensor, 形状为[seq_len, batch_size]src_mask: Tensor, 形状为[seq_len, seq_len]Returns:输出的Tensor,形状为[seq_len, batch_size, ntoken]"""src = self.embedding(src) * math.sqrt(self.d_model)src = self.pos_encoder(src)output = self.transformer_encoder(src, src_mask)output = self.linear(output)return output
class PositionalEncoding(nn.Module):def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000):super().__init__()self.dropout = nn.Dropout(p = dropout)#生成位置编码的位置张量position = torch.arange(max_len).unsqueeze(1)#计算位置编码的除数项div_term = torch.exp(torch.arange(0, d_model, 2) * (-math.log(10000.0) / d_model))#创建位置编码张量pe = torch.zeros(max_len, 1, d_model)#使用正弦函数计算位置编码中的基数维度部分pe[:, 0, 1::2] = torch.sin(position * div_term)#使用余弦函数计算位置编码中的偶数维度部分pe[:, 0, 1::2] = torch.cos(position * div_term)self.register_buffer('pe', pe)def forward(self, x: Tensor) -> Tensor:"""Arguments:x:      Tensor, 形状为[seq_len, batch_size, embedding_dim]"""#将位置编码添加到输入张量x = x + self.pe[:x.size(0)]#应用dropoutreturn self.dropout(x)

二、加载数据集

本实验使用torchtext生成Wikitext-2数据集。在此之前,你需要安装下面的包:

  • pip install portalocker
  • pip install torchdata
import torchtext
from torchtext.datasets.wikitext2 import WikiText2
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator#从torchtext库中导入WikiTetx2数据集
train_iter = WikiText2(split = 'train')#获取基本的英语分词器
tokenizer = get_tokenizer('basic_english')
#通过迭代器构建词汇表
vocab = build_vocab_from_iterator(map(tokenizer, train_iter), specials=['<unk>'])
#将默认索引设置为'<unk>'
vocab.set_default_index(vocab['<unk>'])def data_process(raw_text_iter: dataset.IterableDataset) -> Tensor:"""将原始文本转换为扁平的张量"""data = [torch.tensor(vocab(tokenizer(item)),dtype = torch.long) for item in raw_text_iter]return torch.cat(tuple(filer(lambda t: t.numel() > 0, data)))#由于构建词汇表时"train_iter"被使用了,所以需要重新创建
train_iter, val_iter, test_iter = WikiText2()#队训练、验证和测试数据进行处理
train_data = data_process(train_iter)
val_data = data_process(val_iter)
test_data = data_process(test_iter)#检查是否有可用的CUDA设备,将设备设置为GPU或者CPU
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')def batchify(data: Tensor, bsz: int) -> Tensor:"""将数据划分为bsz个单独的序列,去除不能完全容纳的额外元素。参数:data: Tensor,形状为``[N]``bsz:int,批大小返回:形状为[N // bsz, bsz]的张量"""seq_len = data.size(0) // bszdata = data[:seq_len * bsz]data = data.view(bsz, seq_len).t().contiguous()return data.to(device)#设置批大小和评估批大小
batch_size = 20
eval_batch_size = 10
#将训练、验证和测试数据进行批处理
train_data = batchify(train_data, batch_size)   #形状为[seq_len, batch_size]
val_data = batchify(val_data, eval_batch_size)
test_data = batchify(test_data, eval_batch_size)
bptt = 35#获取批次数据
def get_batch(source:Tensor, i: int) -> Tuple[Tensor, Tensor]:"""参数:source: Tensor,形状为``[full_seq_len, batch_size]``i : int, 当前批次索引返回:tuple(data, target),-data形状为[seq_len, batch_size],-target形状为[seq_len * batch_size]"""#计算当前批次的序列长度,最大为bptt,确保不超过source的长度seq_len = min(bptt, len(source) - 1 - i)#获取data,从i开始,长度为seq_lendata = source[i:i+seq_len]#获取target,从i+1开始,长度为seq_len,并将其形状转换为一维张量target = source[i+1:i+1+seq_len].reshape(-1)return data, target

三、初始化实例

ntokend = len(vocab)
emsize = 200
d_hid = 200
nlayers = 2
nhead = 2
dropout = 0.2
#创建transformer模型
model = TransformerModel(ntokend,emsize,nhead,d_hid,nlayers,dropout).to(device)

四、训练模型

结合使用CrossEntropyLoss与SGD(随机梯度下降优化器)。训练期间,使用torch.nn.utils.clip_grad_norm_来防止梯度爆炸

import time
criterion = nn.CrossEntropyLoss() #定义交叉熵损失函数
lr = 5.0
optimizer = torch.optim.SGD(model.parameters(),lr = lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, 1.0, gama = 0.95)def train(model: nn.Module) -> None:model.train() #开启训练模式total_loss = 0.log_interval = 200 #start_time = time.time()num_batches = len(train_data) // bpttfor batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):data, targets = get_batch(train_data, i)output = model(data)output_flat = output.view(-1, ntokens)loss = criterion(output_flat, targets) #计算损失optimizer.zero_grad()loss.backward()torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)optimizer.step()total_loss += loss.item()if batch % log_interval == 0 and batch > 0:lr = scheduler.get_last_lr()[0]ms_per_batch = (time.time() - start_time) * 1000 / log_intervalcur_loss = total_loss / log_intervalppl = math.exp(cur_loss)print(f'| epoch{epoch:3d} | {batch:5d} / {num_batches:5d} batches |'f'lr{lr:02.2f} | ms/batch {ms_per_batch:5.2f} |'f'loss {cur_loss:5.2f}|ppl{ppl:8.2f}')total_loss = 0start_time = time.time()def evaluate(model:nn.Module, eval_data:Tensor) -> float:model.eval()total_loss = 0.with torch.no_grad():for i in range(0,eval_data.size(0) - 1, bptt):data, targets = get_batch(eval_data,i)seq_len = data.size(0)output = model(data)output_flat = output.view(-1,ntokens)total_loss += seq_len * criterion(output_flat, targets).item()return total_loss / (len(eval_data) - 1)
best_val_loss = float('inf')
epochs = 1with TemporaryDirectory() as tempdir:best_model_params_path = os.path.join(tempdir, "best_model_params.pt")for epoch in range(1, epochs + 1):epoch_start_time = time.time()train(model)val_loss = evaluate(model, val_data)val_ppl = math_exp(val_loss)elapsed = time.time() - epoch_start_time#打印当前epoch的信息,包括耗时、验证损失和困惑度print('-' * 89)print(f'|end of epoch {epoch:3d} | time:{elapsed: 5.2f}s |'f'valid loss {val_loss:5.2f} | valid ppl {val_ppl: 8.2f}')print('-' * 89)if val_loss < best_val_loss:best_val_loss = val_losstorch.save(model.state_dict(), best_model_params_path)scheduler.step()    #更新学习率model.load_state_dict(torch.load(best_model_params_path))

代码输出:
在这里插入图片描述

五、总结

加载数据集时,注意包的版本关联关系。另外,注意结合使用优化器提升优化性能。

相关文章:

【NLP练习】Transformer实战-单词预测

&#x1f368; 本文为&#x1f517;365天深度学习训练营 中的学习记录博客&#x1f356; 原作者&#xff1a;K同学啊 任务&#xff1a;自定义输入一段英文文本进行预测 一、定义模型 from tempfile import TemporaryDirectory from typing import Tuple from torch import nn…...

使用Lua脚本保证原子性的Redis分布式锁实现

这是原来的代码&#xff1a; Override public void unlock() {// 获取线程标示String threadId ID_PREFIX Thread.currentThread().getId();// 判断标示是否一致String id stringRedisTemplate.opsForValue().get(KEY_PREFIX name);if (threadId.equals(id)) {// 释放锁st…...

什么是nginx到底怎么配置,什么是网关到底怎么配置?

那使用upstream里面具体哪个服务器是怎么决定和区分的呢? AI生成 在Nginx中,使用upstream里面具体哪个服务器是通过负载均衡算法来决定的。upstream块定义了一组服务器,Nginx会根据配置的负载均衡算法来选择一个服务器来处理当前的请求。常见的负载均衡算法包括轮询(round-…...

轻量级服务器内存不够编译的情况解决方案(以安装Ta-Lib库为例)

安装 TA-Lib 时遇到的问题通常与系统缺少必要的编译依赖项或者内存不足有关。以下是一些解决步骤,你可以按照这些步骤尝试解决问题: 问题描述:编译安装Tal-ib库出现以下问题: root@tianbaobao12:~/shipan/ta-lib# pip install ta-lib Collecting ta-libUsing cached TA-L…...

学校校园考场电子钟,同步授时,助力考场公平公正-讯鹏科技

随着教育技术的不断发展&#xff0c;学校对于考场管理的需求也日益提高。传统的考场时钟往往存在时间误差、维护不便等问题&#xff0c;这在一定程度上影响了考试的公平性和公正性。为了解决这些问题&#xff0c;越来越多的学校开始引入考场电子钟&#xff0c;通过同步授时技术…...

MySQL存储管理(一):删数据

从表中删除数据 从表中删除数据&#xff0c;也即是delete过程。 什么是表空间 表空间可以看做是InnoDB存储引擎逻辑结构的最高层&#xff0c;所有的数据都存放在表空间中。默认情况下&#xff0c;InnoDB存储引擎有一个共享表空间idbdata1&#xff0c;即所有数据都存放在这个表…...

深度剖析现阶段的多模态大模型做不了医疗

导读 在人工智能的这波浪潮中&#xff0c;以ChatGPT为首的大语言模型&#xff08;LLM&#xff09;不仅在自然语言处理&#xff08;NLP&#xff09;领域掀起了一场技术革命&#xff0c;更是在计算机视觉&#xff08;CV&#xff09;乃至多模态领域展现出了令人瞩目的潜力。 这些…...

Zabbix 监控 Kubernetes 集群

Zabbix 监控 Kubernetes 集群 Zabbix作为一个成熟且功能强大的监控系统&#xff0c;被许多企业广泛采用。它能够对各种IT基础设施进行全面的监控&#xff0c;包括服务器、网络设备、应用程序等。而将Zabbix与Kubernetes结合&#xff0c;可以实现对Kubernetes集群的全面监控&am…...

网上预约就医取号系统

摘 要 近年来&#xff0c;随着信息技术的发展和普及&#xff0c;我国医疗信息产业快速发展&#xff0c;各大医院陆续推出自己的信息系统来实现医疗服务的现代化转型。不可否认&#xff0c;对一些大型三级医院来说&#xff0c;其信息服务质量还是广泛被大众所认可的。这就更需要…...

概念描述——TCP/IP模型中的两个重要分界线

TCP/IP模型中的两个重要分界线 协议的层次概念包含了两个也许不太明显的分界线&#xff0c;一个是协议地址分界线&#xff0c;区分出高层与低层寻址操作&#xff1b;另一个是操作系统分界线&#xff0c;它把系统与应用程序区分开来。 高层协议地址界限 当我们看到TCP/P软件的…...

ECharts,拿来吧你!

作为一名前端程序员,在日常的项目开发中,我们会遇到各种各样的图表设计,那么,为了提高我们的开发效率,ECharts便应运而生了!它提供了丰富的图表样式和多浏览器支持的API接口,不仅能够将静态的数据转换为图表,还可以动态的请求后端传递过来的数据,将其以可视化的形式展现给用户,…...

【DICOM】BitsAllocated字段值为8和16时区别

一、读取dicom C# 使用fo-dicom操作dicom文件-CSDN博客 二、DICOM中BitsAllocated字段值为8和16时区别 位深度差异&#xff1a; 当BitsAllocated为8时&#xff0c;意味着每个像素使用8位来表示其灰度值。这允许每个像素有2^8256种不同的灰度等级&#xff0c;适用于那些不需要高…...

【MySQL】 -- 事务

如果对表中的数据进行CRUD操作时&#xff0c;不加控制&#xff0c;会带来一些问题。 比如下面这种场景&#xff1a; 有一个tickets表&#xff0c;这个数据库被两个客户端机器A和B用时连接对此表进行操作。客户端A检查tickets表中还有一张票的时候&#xff0c;将票出售了&#x…...

c#调用c++生成的dll,c++端使用opencv, c#端使用OpenCvSharp, 返回一张图像

c代码&#xff1a; // OpenCVImageLibrary.cpp #include <opencv2/opencv.hpp> #include <vector> extern "C" { __declspec(dllexport) unsigned char* ReadImageToBGR(const char* filePath, int* width, int* height, int* step) { cv::Mat i…...

【Android面试八股文】你能说一说View绘制流程与自定义View注意点吗?

文章目录 一、自定义View的构造函数以及各参数的用法二、自定义View的几种方式三、自定义View的绘制流程四、自定义View需要注意的一些点五、举个例子一、自定义View的构造函数以及各参数的用法 在Android中,自定义View通常需要提供多个构造函数,以适应不同的使用场景。主要…...

【第24章】Vue实战篇之用户信息展示

文章目录 前言一、准备1. 获取用户信息2. 存储用户信息3. 加载用户信息 二、用户信息1.昵称2.头像 三、展示总结 前言 这里我们来展示用户昵称和头像。 一、准备 1. 获取用户信息 export const userInfoService ()>{return request.get(/user/info) }2. 存储用户信息 i…...

“打造智能售货机系统,基于ruoyi微服务版本生成基础代码“

目录 # 开篇 1. 菜单 2. 字典配置 3. 表配置 3.1 导入表 3.2 区域管理 3.3 合作商管理 3.4 点位管理 4. 代码导入 4.1 后端代码生成 4.2 前端代码生成 5. 数据库代码执行 6. 点位管理菜单顺序修改 7. 页面展示 8. 附加设备表 8.1 新增设备管理菜单 8.2 创建字…...

oracle12c到19c adg搭建(五)dg搭建后进行切换19c进行数据字典升级

一、备库切主库升级 12c切换为19c主库的时候是由低版本到高版本所以cdb和pdb的数据字典需要进行升级才可以让数据与软件版本兼容。 1.1切换 SQL> alter database recover managed standby database finish; Database altered. SQL> alter database commit to switcho…...

在公司的一些笔记

6.19 记住挂载在windows上的账户是DAHUATECH\401593&#xff0c;不是401593Windows与linux不能同时挂载在虚拟盘上 6.21 /******************************************************************************* pdc_ledSy7806e.c* * Description: 提供I2C访问sy7806e。 * * …...

2020C++等级考试二级真题题解

202012数组指定部分逆序重放c #include <iostream> using namespace std; int main() {int a[110];int n, k;cin >> n >> k;for (int i 0; i < n; i) {cin >> a[i];}for (int i 0; i < k / 2; i) {swap(a[i], a[k - 1 - i]);}for (int i 0…...

面试官:聊聊 nextTick

前言 在最近的面试中,不少面试官叫我聊聊 nextTick,nextTick 是个啥,这篇文章咱来好好聊聊! 我的回答 nextTick 是官方提供的一个异步方法,用于在 DOM 更新之后执行回调。正好在我的项目中用到了,就拿它来形容一下,大概的场景是渲染一个列表,每次点击按钮就会往列表后…...

shell编程之条件语句(shell脚本)

条件测试操作 要使shell脚本程序具备一定的“智能”,面临的第一个问题就是如何区分不同的情况以确定执行何种操作。例如,当磁盘使用率超过95%时,发送告警信息;当备份目录不存在时,能够自动创建;当源码编译程序时,若配置失败则不再继续安装等。 shell环境根据命令执行后…...

QT中QSettings的使用系列之二:保存和恢复应用程序主窗口

1、核心代码 #include "widget.h" #include "ui_widget.h" #include <QSettings> #include <QDebug> #include <QColo...

Linux系统上安装Miniconda并安装特定版本的Python

要在Linux系统上安装Miniconda并安装特定版本的Python&#xff08;例如3.10.12&#xff09;&#xff0c;请按照以下步骤进行操作&#xff1a; 1. 下载并安装Miniconda 下载Miniconda安装脚本&#xff1a; 使用wget或curl下载Miniconda安装脚本。以下是使用wget的命令&#xff…...

解决Qt中 -lGL无法找到的问题

在使用Qt Creator创建并编译新项目时&#xff0c;可能会遇到以下错误&#xff1a; /usr/bin/ld: cannot find -lGL collect2: error: ld returned 1 exit status make: *** [untitled1] Error 1 18:07:41: The process "/usr/bin/make" exited with code 2. Error w…...

【重要】《HTML趣味编程》专栏内资源的下载链接

目录 关于专栏 博主简介 专栏内资源的下载链接 写在后面 关于专栏 本专栏将持续更新,至少含有30个案例,后续随着案例的增加可能会涨价,欢迎大家尽早订阅!(订阅后可查看专栏内所有文章,并且可以下载专栏内的所有资源) 博主简介 ⭐ 2024年百度文心智能体大赛 Top1⭐…...

苍穹外卖环境搭建

一、前端环境搭建 ①整体结构 ②前端工程基于nginx运行 启动nginx:双击 nginx.exe 即可启动 nginx 服务&#xff0c;访问端口号为 80 进入浏览器地址输入locallhost回车 二、后端环境搭建 后端初始工程基于maven进行项目构建&#xff0c;并且进行分模块开发 (1) idea打开初始…...

切割游戏介绍

简介 上大学时&#xff0c;在学校实验室里玩过一个貌似使用VC写的小游戏&#xff0c;一个小球在界面上四处游荡&#xff0c;玩家使用鼠标切割背景&#xff0c;将背景切割剩余到一定的百分比后&#xff0c;就胜利了&#xff0c;后边的背景图会全部展示出来。 使用qt的qml技术&a…...

对接Paypal、Stripe支付简单流程

一、Stripe卡支付简单流程&#xff1a; #mermaid-svg-bZxQh1bt4Z8agjJg {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-bZxQh1bt4Z8agjJg .error-icon{fill:#552222;}#mermaid-svg-bZxQh1bt4Z8agjJg .error-text{fi…...

微服务中不同服务使用openfeign 相互调用

首先 我们上文 已经知道了 nacos 的注册服务&#xff0c;现在 我们 在不同服务中相互调用就可以使用openfeign 直接调用&#xff0c;而不是 再写冗余的调用代码啦 首先 我们的微服务组件如下 因为我这个微服务是我在 员工登录demo 中 拆出来的&#xff0c;在userlogin模块中…...

网站制作布局/线上推广的三种方式

OpenGL ES 03 – 转化今天&#xff0c;我们要在之前的基础 上&#xff0c;在屏幕上同时显示三角形和矩形。为了做到这点&#xff0c;我们需要移动它们。移动物体这个动作我们称之为转化。&#xff08;坐标转换&#xff09;在OpenGL ES中&#xff0c;对模型&#xff0f;物体进行…...

互联网创业项目哪家好平台/上海快速排名优化

页面布局遇到一个奇怪现象&#xff0c;在RelativeLayout 时最下面的 一个view 设置的layout_marginBottom 在小米手机显示正常&#xff0c;在三星&#xff0c;华为设置的距离就变为 0 了。 <RelativeLayout xmlns:android"http://schemas.android.com/apk/res/android&…...

建立网站赚钱吗/最新病毒感染

组件就是创建html中不存在的标签。...

ddns做网站/今天刚刚发生的新闻

eclipse在tomcat运行自己的网页程序 Run on Server时发现RT错误 解决办法&#xff1a; Windows ->Show View -> Servers 在 Servers View 里删除相应的 Server 然后再Run on Server&#xff0c;就可以了 转载于:https://www.cnblogs.com/forgeting/p/4382952.html...

专业做幼儿园设计的网站/日本比分预测

参考 温绍景-Java虚拟机基础...

wordpress 新建栏目/关键词优化哪家好

1.论述 1.1.获取计算机系统所有实际可用物理区域信息 计算机系统的物理区域并不是连续的&#xff0c;且有的物理区域映射到特定硬件内部区域以便实现特定功能。 对计算机系统的物理区域进行分页管理&#xff0c;首先&#xff0c;要获取所有离散的物理区域信息。 INT 15h, AXE…...