当前位置: 首页 > news >正文

鹅厂面试官:Transformer 为何需要位置编码?

最近这一两周看到不少互联网公司都已经开始秋招发放Offer。

不同以往的是,当前职场环境已不再是那个双向奔赴时代了。求职者在变多,HC 在变少,岗位要求还更高了。

最近,我们又陆续整理了很多大厂的面试题,帮助一些球友解惑答疑,分享技术面试中的那些弯弯绕绕。

  • 《大模型面试宝典》(2024版) 正式发布!

喜欢本文记得收藏、关注、点赞。更多实战和面试交流,文末加入我们

技术交流

在这里插入图片描述

本文基于 llama 模型的源码,学习相对位置编码的实现方法,本文不细究绝对位置编码和相对位置编码的数学原理。

大模型新人在学习中容易困惑的几个问题:

  • 为什么一定要在 transformer 中使用位置编码?

  • 相对位置编码在 llama 中是怎么实现的?

  • 大模型的超长文本预测和位置编码有什么关系?

01 为什么需要位置编码

很多初学者都会读到这样一句话:transformer 使用位置编码的原因是它不具备位置信息。大家都只把这句话当作公理,却很少思考这句话到底是什么意思?

这句话的意思是,如果没有位置编码,那么 “床前明月”、“前床明月”、“前明床月” 这几个输入,会预测出完全一样的文本。

也就是说,不管你输入的 prompt 顺序是什么,只要 prompt 的文本是相同的,那么模型 decode 的文本就只取决于 prompt 的最后一个 token。

import torch
from torch import nn
import mathbatch = 1
dim = 10
num_head = 2
embedding = nn.Embedding(5, dim)
q_matrix = nn.Linear(dim, dim, bias=False)
k_matrix = nn.Linear(dim, dim, bias=False)
v_matrix = nn.Linear(dim, dim, bias=False)x = embedding(torch.tensor([1,2,3])).unsqueeze(0)
y = embedding(torch.tensor([2,1,3])).unsqueeze(0)def attention(input):q = q_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)k = k_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)v = v_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim // num_head)attn_weights = nn.functional.softmax(attn_weights, dim=-1)outputs = torch.matmul(attn_weights, v).transpose(1, 2).reshape(1, len([1,2,3]), dim)print(outputs)attention(x)
attention(y)

执行上面的代码会发现,虽然 x 和 y 交换了第一个 token 和第二个 token 的输入顺序,但是第三个 token 的计算结果完全没有发生改变,那么模型预测第四个 token 时,便会得到相同的结果。

如果有读者对矩阵运算感到混淆的话,可以看看下面的简单推导:

图片

可以看出,当第一个 token 与第二个 token 交换顺序后,模型输出矩阵的第一维和第二维也交换了顺序,但输出的值完全没有变化。

第三个 token 的输出结果也是完全没有受到影响,这也就是前面说的:如果没有位置编码,模型 decode 的文本就只取决于 prompt 的最后一个 token

不过需要注意的是,由于 attention_mask 的存在(前置位 token 看不到后置位 token),所以即使不加位置编码,transformer 的输出还是会受到 token 的位置影响。

02 相对位置编码的实现

我们以 modeling_llama.py 的源码为例,来学习相对位置编码的实现方法。

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)def rotate_half(x):"""Rotates half the hidden dims of the input."""x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed

相对位置编码在 attention 中的应用方法如下:

self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)if past_key_value is not None:# reuse k, v, self_attentionkey_states = torch.cat([past_key_value[0], key_states], dim=1)value_states = torch.cat([past_key_value[1], value_states], dim=1)

根据 value_states 矩阵的形状去调取 cos 和 sin 两个 tensor, cos 与 sin 的维度均是 batch_size * head_num * seq_len * head_dim;

利用 apply_rotary_pos_emb 去修改 query_states 和 key_states 两个 tensor,得到新的 q,k 矩阵

需要注意的是,在解码时,position_ids 的长度是和输入 token 的长度保持一致的,prompt 是 4 个 token 的话。

第一次解码时,position_ids: tensor([[0, 1, 2, 3]], device=‘cuda:0’),q 矩阵与 k 矩阵的相对位置编码信息通过 apply_rotary_pos_emb() 获得;

第二次解码时,position_ids: tensor([[4]], device=‘cuda:0’),当前 token 的相对位置编码信息通过 apply_rotary_pos_emb() 获得。

前 4 个 token 的相对位置编码信息则是通过 key_states = torch.cat([past_key_value[0], key_states], dim=1) 集成到 k 矩阵中;

……

……

以上代码的公式,均可以从苏神原文中找到。

这些代码可以从 llama 模型中剥离出来直接执行,如果感到困惑,可以像下面一样,将 apply_rotary_pos_emb() 的整个过程给 print 出来观察一下:

head_num, head_dim, kv_seq_len = 8, 20, 5
position_ids = torch.tensor([[0, 1, 2, 3, 4]])
query_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
key_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
value_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
rotary_emb = LlamaRotaryEmbedding(head_dim)
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
print(cos, sin)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

03 位置编码与长度外推

长度外推指的是,大模型在训练的只见过长度为 X 的文本,但在实际应用时却有如下情况:

图片

我们假设 X 的取值为 4096,那么也就意味着,模型自始至终没有见到过 pos_id >= 4096 的位置编码,进而导致模型的预测结果完全不可控。

因此,解决长度外推问题的关键便是如何让模型见到比训练文本更长的位置编码。

图片

以上关于文本外推的介绍均是比较大白话的理解,只是为了强调位置编码很重要这一观点。

相关文章:

鹅厂面试官:Transformer 为何需要位置编码?

最近这一两周看到不少互联网公司都已经开始秋招发放Offer。 不同以往的是,当前职场环境已不再是那个双向奔赴时代了。求职者在变多,HC 在变少,岗位要求还更高了。 最近,我们又陆续整理了很多大厂的面试题,帮助一些球…...

MySQL数据库学习指南

一、数据库的库操作 1、创建数据库 2、删除数据库 3、查看数据库 4、选择数据库 5、修改数据库 6、数据库备份与恢复 7、数据库的权限管理 二、数据库的表操作 1、创建表 2、删除表 3、修改表 4、查看表的结构 5、查看表的数据 6、创建索引 7、删除索引 8、约束…...

算法刷题-小猫爬山

本题来源165. 小猫爬山 - AcWing题库 翰翰和达达饲养了 NN 只小猫&#xff0c;这天&#xff0c;小猫们要去爬山。 经历了千辛万苦&#xff0c;小猫们终于爬上了山顶&#xff0c;但是疲倦的它们再也不想徒步走下山了&#xff08;呜咕>_<&#xff09;。 翰翰和达达只好花…...

Maven项目管理工具-初始+环境配置

1. Maven的概念 1.1. 什么是Maven Maven是跨平台的项目管理工具。主要服务于基于Java平台的项目构建&#xff0c;依赖管理和项目信息管理。 理想的项目构建&#xff1a;高度自动化&#xff0c;跨平台&#xff0c;可重用的组件&#xff0c;标准化的流程 maven能够自动下载依…...

【JavaEE初阶】网络编程TCP协议实现回显服务器以及如何处理多个客户端的响应

前言 &#x1f31f;&#x1f31f;本期讲解关于TCP/UDP协议的原理理解~~~ &#x1f308;感兴趣的小伙伴看一看小编主页&#xff1a;GGBondlctrl-CSDN博客 &#x1f525; 你的点赞就是小编不断更新的最大动力 &#x1f386;那么废话不多说…...

Android 中的串口开发

一&#xff1a;背景 本文着重讲安卓下的串口。 由于开源的Android在各种智能设备上的使用越来越多&#xff0c;如车载系统等。在我们的认识中&#xff0c;Android OS的物理接口一般只有usb host接口和耳机接口&#xff0c;但其实安卓支持各种各样的工业接口&#xff0c;如HDM…...

TensorRt OP

在TensorRT中&#xff0c;OP&#xff08;Operations&#xff0c;操作&#xff09;是指网络中的基本计算单元&#xff0c;类似于数学中的运算符。每个OP执行一个特定的计算任务&#xff0c;例如卷积、矩阵乘法、激活函数等。TensorRT通过识别和优化这些OP来提高深度学习模型的推…...

构建负责任的人工智能:数据伦理与隐私保护

构建负责任的人工智能&#xff1a;数据伦理与隐私保护 目录 &#x1f31f; 数据伦理的重要性&#x1f4ca; 公平性评估&#xff1a;实现无偏差的模型&#x1f512; 数据去标识化&#xff1a;保护用户隐私的必要手段&#x1f50d; 透明性与问责&#xff1a;建立可信的数据处理…...

微信小程序live-pusher和video同时使用,video播放声音时时大时小

一、遇到的问题 微信小程序live-pusher和video同时使用,video播放声音时有时无时大时小 二、排查流程 业务是模拟面试,每道题一个推流live-pusher和一个面试题video,一次面试有多道面试题,页面就一个live-pusher和一个video,切换面试题时给live-pusher和video重新赋值u…...

MySQL 分库分表实战

在当今互联网时代&#xff0c;数据量的增长呈爆炸式趋势&#xff0c;传统的单库单表架构已经难以满足大规模数据存储和高并发访问的需求。MySQL 分库分表技术应运而生&#xff0c;它可以有效地提高数据库的性能、扩展性和可用性。本文将详细介绍 MySQL 分库分表的实战经验。 一…...

MySQL—CRUD—进阶—(二) (ಥ_ಥ)

文本目录&#xff1a; ❄️一、新增&#xff1a; ❄️二、查询&#xff1a; 1、聚合查询&#xff1a; 1&#xff09;、聚合函数&#xff1a; 2&#xff09;、GROUP BY子句&#xff1a; 3&#xff09;、HAVING 子句&#xff1a; 2、联合查询&#xff1a; 1&#xff09;、内连接…...

时序分解 | TTNRBO-VMD改进牛顿-拉夫逊算法优化变分模态分解

时序分解 | TTNRBO-VMD改进牛顿-拉夫逊算法优化变分模态分解 目录 时序分解 | TTNRBO-VMD改进牛顿-拉夫逊算法优化变分模态分解效果一览基本介绍程序设计参考资料 效果一览 基本介绍 (创新独家)TTNRBO-VMD改进牛顿-拉夫逊优化算优化变分模态分解TTNRBO–VMD 优化VMD分解层数K和…...

2024“源鲁杯“高校网络安全技能大赛-Misc-WP

Round 1 hide_png 题目给了一张图片&#xff0c;flag就在图片上&#xff0c;不过不太明显&#xff0c;写个python脚本处理一下 from PIL import Image ​ # 打开图像并转换为RGB模式 img Image.open("./attachments.png").convert("RGB") ​ # 获取图像…...

CSS行块标签的显示方式

块级元素 标签&#xff1a;h1-h6&#xff0c;p,div,ul,ol,li,dd,dt 特点&#xff1a; &#xff08;1&#xff09;如果块级元素不设置默认宽度&#xff0c;那么该元素的宽度等于其父元素的宽度。 &#xff08;2&#xff09;所有的块级元素独占一行显示. &#xff08;3&#xff…...

Go 语言中的 for range 循环教程

在 Go 语言中&#xff0c;for range 循环是一个方便的语法结构&#xff0c;用于遍历数组、切片、映射和字符串。本教程将通过示例代码来帮助理解如何在 Go 中使用 for range 循环。 package mainimport "fmt"func main() {// 遍历切片并计算和nums : []int{2, 3, 4}…...

青训营 X 豆包MarsCode 技术训练营--小M的比赛胜场计算

问题描述 小M参加了一场n个人的比赛&#xff0c;比赛规则是所有选手两两对决。每个人有一个能力值&#xff0c;对应着他们的序号。参赛者同时被分为黄色或蓝色两种颜色。比赛胜负的规则如下&#xff1a; 当比赛双方颜色不同时&#xff0c;能力值大的选手获胜&#xff1b; 当比…...

海王3纯源码

海王3是一款热门的捕鱼类游戏&#xff0c;其纯源码为开发者提供了一个完整的游戏开发基础。该源码包括客户端和服务端的完整架构&#xff0c;支持多人在线竞技模式和丰富的游戏玩法。服务端采用C语言编写&#xff0c;并使用MySQL数据库来存储玩家数据&#xff0c;确保数据处理的…...

【ShuQiHere】Linux 系统中的硬盘管理详解:命令与技巧

【ShuQiHere】 &#x1f4bd; 在 Linux 系统中&#xff0c;硬盘管理不仅仅是存储数据的操作&#xff0c;更涉及系统性能、数据安全和稳定性的优化。无论你是系统管理员、开发者还是 Linux 爱好者&#xff0c;掌握硬盘管理的基础操作都非常有用。本文将从硬盘健康检查、分区管理…...

数据结构之堆和二叉树的简介

1.树 1.1 树的概念与结构 如图所示&#xff0c;树是⼀种非线性的数据结构&#xff0c;它是由 n &#xff08;n>0&#xff09; 个有限结点组成⼀个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 …...

微信小程序上传图片添加水印

微信小程序使用wx.chooseMedia拍摄或从手机相册中选择图片并添加水印&#xff0c; 代码如下&#xff1a; // WXML代码&#xff1a;<canvas canvas-id"watermarkCanvas" style"width: {{canvasWidth}}px; height: {{canvasHeight}}px;"></canvas&…...

HBuilderX安装(uni-app和小程序开发)

下载HBuilderX 访问官方网站&#xff1a;https://www.dcloud.io/hbuilderx.html 根据您的操作系统选择合适版本&#xff1a; Windows版&#xff08;推荐下载标准版&#xff09; Windows系统安装步骤 运行安装程序&#xff1a; 双击下载的.exe安装文件 如果出现安全提示&…...

【Web 进阶篇】优雅的接口设计:统一响应、全局异常处理与参数校验

系列回顾&#xff1a; 在上一篇中&#xff0c;我们成功地为应用集成了数据库&#xff0c;并使用 Spring Data JPA 实现了基本的 CRUD API。我们的应用现在能“记忆”数据了&#xff01;但是&#xff0c;如果你仔细审视那些 API&#xff0c;会发现它们还很“粗糙”&#xff1a;有…...

深度学习习题2

1.如果增加神经网络的宽度&#xff0c;精确度会增加到一个特定阈值后&#xff0c;便开始降低。造成这一现象的可能原因是什么&#xff1f; A、即使增加卷积核的数量&#xff0c;只有少部分的核会被用作预测 B、当卷积核数量增加时&#xff0c;神经网络的预测能力会降低 C、当卷…...

如何更改默认 Crontab 编辑器 ?

在 Linux 领域中&#xff0c;crontab 是您可能经常遇到的一个术语。这个实用程序在类 unix 操作系统上可用&#xff0c;用于调度在预定义时间和间隔自动执行的任务。这对管理员和高级用户非常有益&#xff0c;允许他们自动执行各种系统任务。 编辑 Crontab 文件通常使用文本编…...

【Android】Android 开发 ADB 常用指令

查看当前连接的设备 adb devices 连接设备 adb connect 设备IP 断开已连接的设备 adb disconnect 设备IP 安装应用 adb install 安装包的路径 卸载应用 adb uninstall 应用包名 查看已安装的应用包名 adb shell pm list packages 查看已安装的第三方应用包名 adb shell pm list…...

深度学习之模型压缩三驾马车:模型剪枝、模型量化、知识蒸馏

一、引言 在深度学习中&#xff0c;我们训练出的神经网络往往非常庞大&#xff08;比如像 ResNet、YOLOv8、Vision Transformer&#xff09;&#xff0c;虽然精度很高&#xff0c;但“太重”了&#xff0c;运行起来很慢&#xff0c;占用内存大&#xff0c;不适合部署到手机、摄…...

《Docker》架构

文章目录 架构模式单机架构应用数据分离架构应用服务器集群架构读写分离/主从分离架构冷热分离架构垂直分库架构微服务架构容器编排架构什么是容器&#xff0c;docker&#xff0c;镜像&#xff0c;k8s 架构模式 单机架构 单机架构其实就是应用服务器和单机服务器都部署在同一…...

二维FDTD算法仿真

二维FDTD算法仿真&#xff0c;并带完全匹配层&#xff0c;输入波形为高斯波、平面波 FDTD_二维/FDTD.zip , 6075 FDTD_二维/FDTD_31.m , 1029 FDTD_二维/FDTD_32.m , 2806 FDTD_二维/FDTD_33.m , 3782 FDTD_二维/FDTD_34.m , 4182 FDTD_二维/FDTD_35.m , 4793...

鸿蒙Navigation路由导航-基本使用介绍

1. Navigation介绍 Navigation组件是路由导航的根视图容器&#xff0c;一般作为Page页面的根容器使用&#xff0c;其内部默认包含了标题栏、内容区和工具栏&#xff0c;其中内容区默认首页显示导航内容&#xff08;Navigation的子组件&#xff09;或非首页显示&#xff08;Nav…...

五、jmeter脚本参数化

目录 1、脚本参数化 1.1 用户定义的变量 1.1.1 添加及引用方式 1.1.2 测试得出用户定义变量的特点 1.2 用户参数 1.2.1 概念 1.2.2 位置不同效果不同 1.2.3、用户参数的勾选框 - 每次迭代更新一次 总结用户定义的变量、用户参数 1.3 csv数据文件参数化 1、脚本参数化 …...