当前位置: 首页 > news >正文

Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。

1.自定义数据中enc_input、dec_input及dec_output的区别

博文中给出了两对德语翻译成英语的例子:

# S: decoding input 的起始符
# E: decoding output 的结束符
# P:意为padding,如果当前句子短于本batch的最长句子,那么用这个符号填补缺失的单词
sentence = [# enc_input   dec_input    dec_output['ich mochte ein bier P','S i want a beer .', 'i want a beer . E'],['ich mochte ein cola P','S i want a coke .', 'i want a coke . E'],
]

初看会对这其中的enc_input、dec_input及dec_output三个句子的作用不太理解,此处作详细解释:
-enc_input是模型需要翻译的输入句子,
-dec_input是用于指导模型开始翻译过程的信号
-dec_output是模型训练时的目标输出,模型的目标是使其产生的输出尽可能接近dec_output,即为翻译真实标签。他们在transformer block中的位置如下:
在这里插入图片描述

在使用Transformer进行翻译的时候,需要在Encoder端输入enc_input编码的向量,在decoder端最初只输入起始符S,然后让Transformer网络预测下一个token。

我们知道Transformer架构在进行预测时,每次推理时会获得下一个token,因此推理不是并行的,需要输出多少个token,理论上就要推理多少次。那么,在训练阶段,也需要像预测那样根据之前的输出预测下一个token,然而再所引出dec_output中对应的token做损失吗?实际并不是这样,如果真是这样做,就没有办法并行训练了。

实际我认为Transformer的并行应该是有两个层次:
(1)不同batch在训练和推理时是否可以实现并行?
(2)一个batch是否能并行得把所有的token推理出来?
Tranformer在训练时实现了上述的(1)(2),而推理时(1)(2)都没有实现。Transformer的推理似乎很难实现并行,原因是如果一次性推理两句话,那么如何保证这两句话一样长?难道有一句已经结束了,另一句没有结束,需要不断的把结束符E送入继续预测下一个结束符吗?此外,Transformer在预测下一个token时必须前面的token已经预测出来了,如果第i-1个token都没有,是无法得到第i个token。因此推理的时候都是逐句话预测,逐token预测。这儿实际也是我认为是transformer结构需要改进的地方。这样才可以提高transformer的推理效率。

2.Transformer的训练流程

此处给出博文中附带的非常简洁的Transformer训练代码:

from torch import optim
from model import *model = Transformer().cuda()
model.train()
# 损失函数,忽略为0的类别不对其计算loss(因为是padding无意义)
criterion = nn.CrossEntropyLoss(ignore_index=0)
optimizer = optim.SGD(model.parameters(), lr=1e-3, momentum=0.99)# 训练开始
for epoch in range(1000):for enc_inputs, dec_inputs, dec_outputs in loader:'''enc_inputs: [batch_size, src_len] [2,5]dec_inputs: [batch_size, tgt_len] [2,6]dec_outputs: [batch_size, tgt_len] [2,6]'''enc_inputs, dec_inputs, dec_outputs = enc_inputs.cuda(), dec_inputs.cuda(), dec_outputs.cuda() # [2, 6], [2, 6], [2, 6]outputs = model(enc_inputs, dec_inputs) # outputs: [batch_size * tgt_len, tgt_vocab_size]loss = criterion(outputs, dec_outputs.view(-1))  # 将dec_outputs展平成一维张量# 更新权重optimizer.zero_grad()loss.backward()optimizer.step()print(f'Epoch [{epoch + 1}/1000], Loss: {loss.item()}')
torch.save(model, f'MyTransformer_temp.pth')

这段代码非常简洁,可以看到输入的是batch为2的样本,送入Transformer网络中直接logits算损失。Transformer在训练时实际上使用了一个策略叫teacher forcing。要解释这个策略的意义,以本博文给出的样本为例,对于输入的样本:

ich mochte ein bier

在进行训练时,当我们给出起始符S,接下来应该预测出:

I

那训练时,有了SI后,则应该预测出

want

那么问题来了,如I就预测错了,假如预测成了a,那么在预测want时,还应该使用Sa来预测吗?当然不是,即使预测错了,也应该用对应位置正确的tokenSI去预测下一个token,这就是teacher forcing。

那么transformer是如何实现这样一个teacher forcing的机制的呢?且听下回分解。

相关文章:

Transformer的PyTorch实现之若干问题探讨(一)

《Transformer的PyTorch实现》这篇博文以一个机器翻译任务非常优雅简介的阐述了Transformer结构。在阅读时存在一些小困惑,此处权当一个记录。 1.自定义数据中enc_input、dec_input及dec_output的区别 博文中给出了两对德语翻译成英语的例子: # S: de…...

系统参数SystemParameters.MinimumHorizontalDragDistance

SystemParameters.MinimumHorizontalDragDistance 是一个系统参数,它表示在拖放操作中鼠标水平移动的最小距离。 当用户按下鼠标左键并开始移动鼠标时,系统会检查鼠标的水平移动距离是否超过了 SystemParameters.MinimumHorizontalDragDistance。只有当…...

平屋顶安装光伏需要注意哪些事项?

我国对于房屋建设的屋顶形式,主要有平屋顶、斜屋顶、曲面屋顶和多波式折板屋顶等。今天来讲讲在平屋顶安装光伏,需要注意的事项。 1.屋顶结构:在安装光伏系统之前,需要对屋顶结构进行评估,确保屋顶能够承受光伏系统的…...

《Git 简易速速上手小册》第7章:处理大型项目(2024 最新版)

文章目录 7.1 Git Large File Storage (LFS)7.1.1 基础知识讲解7.1.2 重点案例:在 Python 项目中使用 Git LFS 管理数据集7.1.3 拓展案例 1:使用 Git LFS 管理大型静态资源7.1.4 拓展案例 2:优化现有项目中的大文件管理 7.2 性能优化技巧7.2.…...

从0开始学Docker ---Docker安装教程

Docker安装教程 本安装教程参考Docker官方文档,地址如下: https://docs.docker.com/engine/install/centos/ 1.卸载旧版 首先如果系统中已经存在旧的Docker,则先卸载: yum remove docker \docker-client \docker-client-latest…...

嵌入式学习之Linux入门篇笔记——15,Linux编写第一个自己的命令

配套视频学习链接:http://【【北京迅为】嵌入式学习之Linux入门篇】 https://www.bilibili.com/video/BV1M7411m7wT/?p4&share_sourcecopy_web&vd_sourcea0ef2c4953d33a9260910aaea45eaec8 1.什么是命令? 命令就是可执行程序。 比如 ls -a…...

【C语言】SYSCALL_DEFINE3(socket, int, family, int, type, int, protocol)

一、SYSCALL_DEFINE3与系统调用 在Linux操作系统中,为了从用户空间跳转到内核空间执行特定的内核级操作,使用了一种机制叫做"系统调用"(System Call)。系统调用是操作系统提供给程序员访问和使用内核功能的接口。例如&…...

C++实现鼠标点击和获取鼠标位置(编译环境visual studio 2022)

1环境说明 2获取鼠标位置的接口 void GetMouseCurPoint() {POINT mypoint;for (int i 0; i < 100; i){GetCursorPos(&mypoint);//获取鼠标当前所在位置printf("% ld, % ld \n", mypoint.x, mypoint.y);Sleep(1000);} } 3操作鼠标左键和右键的接口 void Mo…...

Matplotlib绘制炫酷散点图:从二维到三维,再到散点图矩阵的完整指南与实战【第58篇—python:Matplotlib绘制炫酷散点图】

文章目录 Matplotlib绘制炫酷散点图&#xff1a;二维、三维和散点图矩阵的参数说明与实战引言二维散点图三维散点图散点图矩阵二维散点图进阶&#xff1a;辅助线、注释和子图三维散点图进阶&#xff1a;动画效果和交互性散点图矩阵进阶&#xff1a;调整样式和添加密度图总结与展…...

Docker-Learn(一)使用Dockerfile创建Docker镜像

1.创建并运行容器 编写Dockerfile&#xff0c;文件名字就是为Dockerfile 在自己的工作工作空间当中新建文件&#xff0c;名字为Docerfile vim Dockerfile写入以下内容&#xff1a; # 使用一个基础镜像 FROM ubuntu:latest # 设置工作目录 WORKDIR /app # 复制当前目…...

问题:银行账号建立以后,一般需要维护哪些设置,不包括() #学习方法#经验分享

问题&#xff1a;银行账号建立以后&#xff0c;一般需要维护哪些设置&#xff0c;不包括&#xff08;&#xff09; A&#xff0e;维护结算科目对照 B&#xff0e;期初余额初始化刷 C&#xff0e;自定义转账定义 D&#xff0e;对账单初始化 参考答案如图所示...

教授LLM思考和行动:ReAct提示词工程

ReAct&#xff1a;论文主页 原文链接&#xff1a;Teaching LLMs to Think and Act: ReAct Prompt Engineering 在人类从事一项需要多个步骤的任务时&#xff0c;而步骤和步骤之间&#xff0c;或者说动作和动作之间&#xff0c;往往会有一个推理过程。让LLM把内心独白说出来&am…...

FPGA_工程_按键控制的基于Rom数码管显示

一 信号 框图&#xff1a; 其中 key_filter seg_595_dynamic均为已有模块&#xff0c;直接例化即可使用&#xff0c;rom_8*256模块&#xff0c;调用rom ip实现。Rom_ctrl模块需要重新编写。 波形图&#xff1a; 二 代码 module key_fliter #(parameter CNT_MAX 24d9_999_99…...

WordPress Plugin HTML5 Video Player SQL注入漏洞复现(CVE-2024-1061)

0x01 产品简介 WordPress和WordPress plugin都是WordPress基金会的产品。WordPress是一套使用PHP语言开发的博客平台。该平台支持在PHP和MySQL的服务器上架设个人博客网站。WordPress plugin是一个应用插件。 0x02 漏洞概述 WordPress Plugin HTML5 Video Player 插件 get_v…...

【Kotlin】Kotlin基本数据类型

1 变量声明 var a : Int // 声明整数类型变量 var b : Int 1 // 声明整数类型变量, 同时赋初值为1 var c 1 // 声明整数类型变量, 同时赋初值为1 val d 1 // 声明整数类型常量, 值为1(后面不能改变d的值) 变量命名规范如下。 变量名可以由字母、数字、下划线&#xff08;_…...

UDP端口探活的那些细节

一 背景 商业客户反馈用categraf的net_response插件配置了udp探测, 遇到报错了&#xff0c;如图 udp是无连接的&#xff0c;无法用建立连接的形式判断端口。 插件最初的设计是需要配置udp的发送字符&#xff0c;并且配置期望返回的字符串&#xff0c; [[instances]] targets…...

拦截器配置,FeignClient根据业务规则实现微服务动态路由

文章目录 业务场景拦截器用法Open Feign介绍 业务场景 我们服务使用Spring Cloud微服务架构&#xff0c;使用Spring Cloud Gateway 作为网关&#xff0c;使用 Spring Cloud OpenFeign 作为服务间通信方式我们现在做的信控平台&#xff0c;主要功能之一就是对路口信号机进行管控…...

预测模型:MATLAB线性回归

1. 线性回归模型的基本原理 线性回归是统计学中用来预测连续变量之间关系的一种方法。它假设变量之间存在线性关系&#xff0c;可以通过一个或多个自变量&#xff08;预测变量&#xff09;来预测因变量&#xff08;响应变量&#xff09;的值。基本的线性回归模型可以表示为&…...

【人工智能】神奇的Embedding:文本变向量,大语言模型智慧密码解析(10)

什么是嵌入&#xff1f; OpenAI 的文本嵌入衡量文本字符串的相关性。嵌入通常用于&#xff1a; Search 搜索&#xff08;结果按与查询字符串的相关性排序&#xff09;Clustering 聚类&#xff08;文本字符串按相似性分组&#xff09;Recommendations 推荐&#xff08;推荐具有…...

Redis + Lua 实现分布式限流器

文章目录 Redis Lua 限流实现1. 导入依赖2. 配置application.properties3. 配置RedisTemplate实例4. 定义限流类型枚举类5. 自定义注解6. 切面代码实现7. 控制层实现8. 测试 相比 Redis事务&#xff0c; Lua脚本的优点&#xff1a; 减少网络开销&#xff1a;使用Lua脚本&…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

DockerHub与私有镜像仓库在容器化中的应用与管理

哈喽&#xff0c;大家好&#xff0c;我是左手python&#xff01; Docker Hub的应用与管理 Docker Hub的基本概念与使用方法 Docker Hub是Docker官方提供的一个公共镜像仓库&#xff0c;用户可以在其中找到各种操作系统、软件和应用的镜像。开发者可以通过Docker Hub轻松获取所…...

python/java环境配置

环境变量放一起 python&#xff1a; 1.首先下载Python Python下载地址&#xff1a;Download Python | Python.org downloads ---windows -- 64 2.安装Python 下面两个&#xff0c;然后自定义&#xff0c;全选 可以把前4个选上 3.环境配置 1&#xff09;搜高级系统设置 2…...

解锁数据库简洁之道:FastAPI与SQLModel实战指南

在构建现代Web应用程序时&#xff0c;与数据库的交互无疑是核心环节。虽然传统的数据库操作方式&#xff08;如直接编写SQL语句与psycopg2交互&#xff09;赋予了我们精细的控制权&#xff0c;但在面对日益复杂的业务逻辑和快速迭代的需求时&#xff0c;这种方式的开发效率和可…...

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility

Cilium动手实验室: 精通之旅---20.Isovalent Enterprise for Cilium: Zero Trust Visibility 1. 实验室环境1.1 实验室环境1.2 小测试 2. The Endor System2.1 部署应用2.2 检查现有策略 3. Cilium 策略实体3.1 创建 allow-all 网络策略3.2 在 Hubble CLI 中验证网络策略源3.3 …...

【Go】3、Go语言进阶与依赖管理

前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课&#xff0c;做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程&#xff0c;它的核心机制是 Goroutine 协程、Channel 通道&#xff0c;并基于CSP&#xff08;Communicating Sequential Processes&#xff0…...

python爬虫:Newspaper3k 的详细使用(好用的新闻网站文章抓取和解析的Python库)

更多内容请见: 爬虫和逆向教程-专栏介绍和目录 文章目录 一、Newspaper3k 概述1.1 Newspaper3k 介绍1.2 主要功能1.3 典型应用场景1.4 安装二、基本用法2.2 提取单篇文章的内容2.2 处理多篇文档三、高级选项3.1 自定义配置3.2 分析文章情感四、实战案例4.1 构建新闻摘要聚合器…...

如何为服务器生成TLS证书

TLS&#xff08;Transport Layer Security&#xff09;证书是确保网络通信安全的重要手段&#xff0c;它通过加密技术保护传输的数据不被窃听和篡改。在服务器上配置TLS证书&#xff0c;可以使用户通过HTTPS协议安全地访问您的网站。本文将详细介绍如何在服务器上生成一个TLS证…...

WEB3全栈开发——面试专业技能点P2智能合约开发(Solidity)

一、Solidity合约开发 下面是 Solidity 合约开发 的概念、代码示例及讲解&#xff0c;适合用作学习或写简历项目背景说明。 &#x1f9e0; 一、概念简介&#xff1a;Solidity 合约开发 Solidity 是一种专门为 以太坊&#xff08;Ethereum&#xff09;平台编写智能合约的高级编…...

【C++特殊工具与技术】优化内存分配(一):C++中的内存分配

目录 一、C 内存的基本概念​ 1.1 内存的物理与逻辑结构​ 1.2 C 程序的内存区域划分​ 二、栈内存分配​ 2.1 栈内存的特点​ 2.2 栈内存分配示例​ 三、堆内存分配​ 3.1 new和delete操作符​ 4.2 内存泄漏与悬空指针问题​ 4.3 new和delete的重载​ 四、智能指针…...