当前位置: 首页 > news >正文

Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input   dec_input    dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:
在这里插入图片描述

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。

相关文章:

Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。 1.自定义数据中enc_input、dec_input及dec_output的区别 博文中给出了两对德语翻译成英语的例子: # S: de…...

系统参数SystemParameters.MinimumHorizontalDragDistance

SystemParameters.MinimumHorizontalDragDistance 是一个系统参数,它表示在拖放操作中鼠标水平移动的最小距离。 当用户按下鼠标左键并开始移动鼠标时,系统会检查鼠标的水平移动距离是否超过了 SystemParameters.MinimumHorizontalDragDistance。只有当…...

平屋顶安装光伏需要注意哪些事项?

我国对于房屋建设的屋顶形式,主要有平屋顶、斜屋顶、曲面屋顶和多波式折板屋顶等。今天来讲讲在平屋顶安装光伏,需要注意的事项。 1.屋顶结构:在安装光伏系统之前,需要对屋顶结构进行评估,确保屋顶能够承受光伏系统的…...

《Git 简易速速上手小册》第7章:处理大型项目(2024 最新版)

文章目录 7.1 Git Large File Storage (LFS)7.1.1 基础知识讲解7.1.2 重点案例:在 Python 项目中使用 Git LFS 管理数据集7.1.3 拓展案例 1:使用 Git LFS 管理大型静态资源7.1.4 拓展案例 2:优化现有项目中的大文件管理 7.2 性能优化技巧7.2.…...

从0开始学Docker ---Docker安装教程

Docker安装教程 本安装教程参考Docker官方文档,地址如下: https://docs.docker.com/engine/install/centos/ 1.卸载旧版 首先如果系统中已经存在旧的Docker,则先卸载: yum remove docker \docker-client \docker-client-latest…...

嵌入式学习之Linux入门篇笔记——15,Linux编写第一个自己的命令

配套视频学习链接:http://【【北京迅为】嵌入式学习之Linux入门篇】 https://www.bilibili.com/video/BV1M7411m7wT/?p4&share_sourcecopy_web&vd_sourcea0ef2c4953d33a9260910aaea45eaec8 1.什么是命令? 命令就是可执行程序。 比如 ls -a…...

【C语言】SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol)

一、SYSCALL_DEFINE3与系统调用 在Linux操作系统中,为了从用户空间跳转到内核空间执行特定的内核级操作,使用了一种机制叫做"系统调用"(System Call)。系统调用是操作系统提供给程序员访问和使用内核功能的接口。例如&…...

C++实现鼠标点击和获取鼠标位置(编译环境visual studio 2022)

1环境说明 2获取鼠标位置的接口 void GetMouseCurPoint() {POINT mypoint;for (int i 0; i < 100; i){GetCursorPos(&mypoint);//获取鼠标当前所在位置printf("% ld, % ld \n", mypoint.x, mypoint.y);Sleep(1000);} } 3操作鼠标左键和右键的接口 void Mo…...

Matplotlib绘制炫酷散点图:从二维到三维,再到散点图矩阵的完整指南与实战【第58篇—python:Matplotlib绘制炫酷散点图】

文章目录 Matplotlib绘制炫酷散点图&#xff1a;二维、三维和散点图矩阵的参数说明与实战引言二维散点图三维散点图散点图矩阵二维散点图进阶&#xff1a;辅助线、注释和子图三维散点图进阶&#xff1a;动画效果和交互性散点图矩阵进阶&#xff1a;调整样式和添加密度图总结与展…...

Docker-Learn(一)使用Dockerfile创建Docker镜像

1.创建并运行容器 编写Dockerfile&#xff0c;文件名字就是为Dockerfile 在自己的工作工作空间当中新建文件&#xff0c;名字为Docerfile vim Dockerfile写入以下内容&#xff1a; # 使用一个基础镜像 FROM ubuntu:latest # 设置工作目录 WORKDIR /app # 复制当前目…...

问题:银行账号建立以后,一般需要维护哪些设置,不包括() #学习方法#经验分享

问题&#xff1a;银行账号建立以后&#xff0c;一般需要维护哪些设置&#xff0c;不包括&#xff08;&#xff09; A&#xff0e;维护结算科目对照 B&#xff0e;期初余额初始化刷 C&#xff0e;自定义转账定义 D&#xff0e;对账单初始化 参考答案如图所示...

教授LLM思考和行动:ReAct提示词工程

ReAct&#xff1a;论文主页 原文链接&#xff1a;Teaching LLMs to Think and Act: ReAct Prompt Engineering 在人类从事一项需要多个步骤的任务时&#xff0c;而步骤和步骤之间&#xff0c;或者说动作和动作之间&#xff0c;往往会有一个推理过程。让LLM把内心独白说出来&am…...

FPGA_工程_按键控制的基于Rom数码管显示

一 信号 框图&#xff1a; 其中 key_filter seg_595_dynamic均为已有模块&#xff0c;直接例化即可使用&#xff0c;rom_8*256模块&#xff0c;调用rom ip实现。Rom_ctrl模块需要重新编写。 波形图&#xff1a; 二 代码 module key_fliter #(parameter CNT_MAX 24d9_999_99…...

WordPress Plugin HTML5 Video Player SQL注入漏洞复现(CVE-2024-1061)

0x01 产品简介 WordPress和WordPress plugin都是WordPress基金会的产品。WordPress是一套使用PHP语言开发的博客平台。该平台支持在PHP和MySQL的服务器上架设个人博客网站。WordPress plugin是一个应用插件。 0x02 漏洞概述 WordPress Plugin HTML5 Video Player 插件 get_v…...

【Kotlin】Kotlin基本数据类型

1 变量声明 var a : Int // 声明整数类型变量 var b : Int 1 // 声明整数类型变量, 同时赋初值为1 var c 1 // 声明整数类型变量, 同时赋初值为1 val d 1 // 声明整数类型常量, 值为1(后面不能改变d的值) 变量命名规范如下。 变量名可以由字母、数字、下划线&#xff08;_…...

UDP端口探活的那些细节

一 背景 商业客户反馈用categraf的net_response插件配置了udp探测, 遇到报错了&#xff0c;如图 udp是无连接的&#xff0c;无法用建立连接的形式判断端口。 插件最初的设计是需要配置udp的发送字符&#xff0c;并且配置期望返回的字符串&#xff0c; [[instances]] targets…...

拦截器配置,FeignClient根据业务规则实现微服务动态路由

文章目录 业务场景拦截器用法Open Feign介绍 业务场景 我们服务使用Spring Cloud微服务架构&#xff0c;使用Spring Cloud Gateway 作为网关&#xff0c;使用 Spring Cloud OpenFeign 作为服务间通信方式我们现在做的信控平台&#xff0c;主要功能之一就是对路口信号机进行管控…...

预测模型:MATLAB线性回归

1. 线性回归模型的基本原理 线性回归是统计学中用来预测连续变量之间关系的一种方法。它假设变量之间存在线性关系&#xff0c;可以通过一个或多个自变量&#xff08;预测变量&#xff09;来预测因变量&#xff08;响应变量&#xff09;的值。基本的线性回归模型可以表示为&…...

【人工智能】神奇的Embedding:文本变向量,大语言模型智慧密码解析(10)

什么是嵌入&#xff1f; OpenAI 的文本嵌入衡量文本字符串的相关性。嵌入通常用于&#xff1a; Search 搜索&#xff08;结果按与查询字符串的相关性排序&#xff09;Clustering 聚类&#xff08;文本字符串按相似性分组&#xff09;Recommendations 推荐&#xff08;推荐具有…...

Redis + Lua 实现分布式限流器

文章目录 Redis Lua 限流实现1. 导入依赖2. 配置application.properties3. 配置RedisTemplate实例4. 定义限流类型枚举类5. 自定义注解6. 切面代码实现7. 控制层实现8. 测试 相比 Redis事务&#xff0c; Lua脚本的优点&#xff1a; 减少网络开销&#xff1a;使用Lua脚本&…...

【大模型RAG】拍照搜题技术架构速览:三层管道、两级检索、兜底大模型

摘要 拍照搜题系统采用“三层管道&#xff08;多模态 OCR → 语义检索 → 答案渲染&#xff09;、两级检索&#xff08;倒排 BM25 向量 HNSW&#xff09;并以大语言模型兜底”的整体框架&#xff1a; 多模态 OCR 层 将题目图片经过超分、去噪、倾斜校正后&#xff0c;分别用…...

基于Flask实现的医疗保险欺诈识别监测模型

基于Flask实现的医疗保险欺诈识别监测模型 项目截图 项目简介 社会医疗保险是国家通过立法形式强制实施&#xff0c;由雇主和个人按一定比例缴纳保险费&#xff0c;建立社会医疗保险基金&#xff0c;支付雇员医疗费用的一种医疗保险制度&#xff0c; 它是促进社会文明和进步的…...

CMake基础:构建流程详解

目录 1.CMake构建过程的基本流程 2.CMake构建的具体步骤 2.1.创建构建目录 2.2.使用 CMake 生成构建文件 2.3.编译和构建 2.4.清理构建文件 2.5.重新配置和构建 3.跨平台构建示例 4.工具链与交叉编译 5.CMake构建后的项目结构解析 5.1.CMake构建后的目录结构 5.2.构…...

以光量子为例,详解量子获取方式

光量子技术获取量子比特可在室温下进行。该方式有望通过与名为硅光子学&#xff08;silicon photonics&#xff09;的光波导&#xff08;optical waveguide&#xff09;芯片制造技术和光纤等光通信技术相结合来实现量子计算机。量子力学中&#xff0c;光既是波又是粒子。光子本…...

视觉slam十四讲实践部分记录——ch2、ch3

ch2 一、使用g++编译.cpp为可执行文件并运行(P30) g++ helloSLAM.cpp ./a.out运行 二、使用cmake编译 mkdir build cd build cmake .. makeCMakeCache.txt 文件仍然指向旧的目录。这表明在源代码目录中可能还存在旧的 CMakeCache.txt 文件,或者在构建过程中仍然引用了旧的路…...

PAN/FPN

import torch import torch.nn as nn import torch.nn.functional as F import mathclass LowResQueryHighResKVAttention(nn.Module):"""方案 1: 低分辨率特征 (Query) 查询高分辨率特征 (Key, Value).输出分辨率与低分辨率输入相同。"""def __…...

莫兰迪高级灰总结计划简约商务通用PPT模版

莫兰迪高级灰总结计划简约商务通用PPT模版&#xff0c;莫兰迪调色板清新简约工作汇报PPT模版&#xff0c;莫兰迪时尚风极简设计PPT模版&#xff0c;大学生毕业论文答辩PPT模版&#xff0c;莫兰迪配色总结计划简约商务通用PPT模版&#xff0c;莫兰迪商务汇报PPT模版&#xff0c;…...

力扣热题100 k个一组反转链表题解

题目: 代码: func reverseKGroup(head *ListNode, k int) *ListNode {cur : headfor i : 0; i < k; i {if cur nil {return head}cur cur.Next}newHead : reverse(head, cur)head.Next reverseKGroup(cur, k)return newHead }func reverse(start, end *ListNode) *ListN…...

HybridVLA——让单一LLM同时具备扩散和自回归动作预测能力:训练时既扩散也回归,但推理时则扩散

前言 如上一篇文章《dexcap升级版之DexWild》中的前言部分所说&#xff0c;在叠衣服的过程中&#xff0c;我会带着团队对比各种模型、方法、策略&#xff0c;毕竟针对各个场景始终寻找更优的解决方案&#xff0c;是我个人和我司「七月在线」的职责之一 且个人认为&#xff0c…...

DBLP数据库是什么?

DBLP&#xff08;Digital Bibliography & Library Project&#xff09;Computer Science Bibliography是全球著名的计算机科学出版物的开放书目数据库。DBLP所收录的期刊和会议论文质量较高&#xff0c;数据库文献更新速度很快&#xff0c;很好地反映了国际计算机科学学术研…...