当前位置: 首页 > news >正文

鹅厂面试官:Transformer 为何需要位置编码?

最近这一两周看到不少互联网公司都已经开始秋招发放Offer。

不同以往的是,当前职场环境已不再是那个双向奔赴时代了。求职者在变多,HC 在变少,岗位要求还更高了。

最近,我们又陆续整理了很多大厂的面试题,帮助一些球友解惑答疑,分享技术面试中的那些弯弯绕绕。

  • 《大模型面试宝典》(2024版) 正式发布!

喜欢本文记得收藏、关注、点赞。更多实战和面试交流,文末加入我们

技术交流

在这里插入图片描述

本文基于 llama 模型的源码,学习相对位置编码的实现方法,本文不细究绝对位置编码和相对位置编码的数学原理。

大模型新人在学习中容易困惑的几个问题:

  • 为什么一定要在 transformer 中使用位置编码?

  • 相对位置编码在 llama 中是怎么实现的?

  • 大模型的超长文本预测和位置编码有什么关系?

01 为什么需要位置编码

很多初学者都会读到这样一句话:transformer 使用位置编码的原因是它不具备位置信息。大家都只把这句话当作公理,却很少思考这句话到底是什么意思?

这句话的意思是,如果没有位置编码,那么 “床前明月”、“前床明月”、“前明床月” 这几个输入,会预测出完全一样的文本。

也就是说,不管你输入的 prompt 顺序是什么,只要 prompt 的文本是相同的,那么模型 decode 的文本就只取决于 prompt 的最后一个 token。

import torch
from torch import nn
import mathbatch = 1
dim = 10
num_head = 2
embedding = nn.Embedding(5, dim)
q_matrix = nn.Linear(dim, dim, bias=False)
k_matrix = nn.Linear(dim, dim, bias=False)
v_matrix = nn.Linear(dim, dim, bias=False)x = embedding(torch.tensor([1,2,3])).unsqueeze(0)
y = embedding(torch.tensor([2,1,3])).unsqueeze(0)def attention(input):q = q_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)k = k_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)v = v_matrix(input).view(batch, -1, num_head, dim // num_head).transpose(1, 2)attn_weights = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(dim // num_head)attn_weights = nn.functional.softmax(attn_weights, dim=-1)outputs = torch.matmul(attn_weights, v).transpose(1, 2).reshape(1, len([1,2,3]), dim)print(outputs)attention(x)
attention(y)

执行上面的代码会发现,虽然 x 和 y 交换了第一个 token 和第二个 token 的输入顺序,但是第三个 token 的计算结果完全没有发生改变,那么模型预测第四个 token 时,便会得到相同的结果。

如果有读者对矩阵运算感到混淆的话,可以看看下面的简单推导:

图片

可以看出,当第一个 token 与第二个 token 交换顺序后,模型输出矩阵的第一维和第二维也交换了顺序,但输出的值完全没有变化。

第三个 token 的输出结果也是完全没有受到影响,这也就是前面说的:如果没有位置编码,模型 decode 的文本就只取决于 prompt 的最后一个 token

不过需要注意的是,由于 attention_mask 的存在(前置位 token 看不到后置位 token),所以即使不加位置编码,transformer 的输出还是会受到 token 的位置影响。

02 相对位置编码的实现

我们以 modeling_llama.py 的源码为例,来学习相对位置编码的实现方法。

class LlamaRotaryEmbedding(torch.nn.Module):def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):super().__init__()inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))self.register_buffer("inv_freq", inv_freq)# Build here to make `torch.jit.trace` work.self.max_seq_len_cached = max_position_embeddingst = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)def forward(self, x, seq_len=None):# x: [bs, num_attention_heads, seq_len, head_size]# This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.if seq_len > self.max_seq_len_cached:self.max_seq_len_cached = seq_lent = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)freqs = torch.einsum("i,j->ij", t, self.inv_freq)# Different from paper, but it uses a different permutation in order to obtain the same calculationemb = torch.cat((freqs, freqs), dim=-1).to(x.device)self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)return (self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),)def rotate_half(x):"""Rotates half the hidden dims of the input."""x1 = x[..., : x.shape[-1] // 2]x2 = x[..., x.shape[-1] // 2 :]return torch.cat((-x2, x1), dim=-1)def apply_rotary_pos_emb(q, k, cos, sin, position_ids):# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]q_embed = (q * cos) + (rotate_half(q) * sin)k_embed = (k * cos) + (rotate_half(k) * sin)return q_embed, k_embed

相对位置编码在 attention 中的应用方法如下:

self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)if past_key_value is not None:# reuse k, v, self_attentionkey_states = torch.cat([past_key_value[0], key_states], dim=1)value_states = torch.cat([past_key_value[1], value_states], dim=1)

根据 value_states 矩阵的形状去调取 cos 和 sin 两个 tensor, cos 与 sin 的维度均是 batch_size * head_num * seq_len * head_dim;

利用 apply_rotary_pos_emb 去修改 query_states 和 key_states 两个 tensor,得到新的 q,k 矩阵

需要注意的是,在解码时,position_ids 的长度是和输入 token 的长度保持一致的,prompt 是 4 个 token 的话。

第一次解码时,position_ids: tensor([[0, 1, 2, 3]], device=‘cuda:0’),q 矩阵与 k 矩阵的相对位置编码信息通过 apply_rotary_pos_emb() 获得;

第二次解码时,position_ids: tensor([[4]], device=‘cuda:0’),当前 token 的相对位置编码信息通过 apply_rotary_pos_emb() 获得。

前 4 个 token 的相对位置编码信息则是通过 key_states = torch.cat([past_key_value[0], key_states], dim=1) 集成到 k 矩阵中;

……

……

以上代码的公式,均可以从苏神原文中找到。

这些代码可以从 llama 模型中剥离出来直接执行,如果感到困惑,可以像下面一样,将 apply_rotary_pos_emb() 的整个过程给 print 出来观察一下:

head_num, head_dim, kv_seq_len = 8, 20, 5
position_ids = torch.tensor([[0, 1, 2, 3, 4]])
query_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
key_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
value_states = torch.randn(1, head_dim, kv_seq_len, head_dim)
rotary_emb = LlamaRotaryEmbedding(head_dim)
cos, sin = rotary_emb(value_states, seq_len=kv_seq_len)
print(cos, sin)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

03 位置编码与长度外推

长度外推指的是,大模型在训练的只见过长度为 X 的文本,但在实际应用时却有如下情况:

图片

我们假设 X 的取值为 4096,那么也就意味着,模型自始至终没有见到过 pos_id >= 4096 的位置编码,进而导致模型的预测结果完全不可控。

因此,解决长度外推问题的关键便是如何让模型见到比训练文本更长的位置编码。

图片

以上关于文本外推的介绍均是比较大白话的理解,只是为了强调位置编码很重要这一观点。

相关文章:

鹅厂面试官:Transformer 为何需要位置编码?

最近这一两周看到不少互联网公司都已经开始秋招发放Offer。 不同以往的是,当前职场环境已不再是那个双向奔赴时代了。求职者在变多,HC 在变少,岗位要求还更高了。 最近,我们又陆续整理了很多大厂的面试题,帮助一些球…...

MySQL数据库学习指南

一、数据库的库操作 1、创建数据库 2、删除数据库 3、查看数据库 4、选择数据库 5、修改数据库 6、数据库备份与恢复 7、数据库的权限管理 二、数据库的表操作 1、创建表 2、删除表 3、修改表 4、查看表的结构 5、查看表的数据 6、创建索引 7、删除索引 8、约束…...

算法刷题-小猫爬山

本题来源165. 小猫爬山 - AcWing题库 翰翰和达达饲养了 NN 只小猫&#xff0c;这天&#xff0c;小猫们要去爬山。 经历了千辛万苦&#xff0c;小猫们终于爬上了山顶&#xff0c;但是疲倦的它们再也不想徒步走下山了&#xff08;呜咕>_<&#xff09;。 翰翰和达达只好花…...

Maven项目管理工具-初始+环境配置

1. Maven的概念 1.1. 什么是Maven Maven是跨平台的项目管理工具。主要服务于基于Java平台的项目构建&#xff0c;依赖管理和项目信息管理。 理想的项目构建&#xff1a;高度自动化&#xff0c;跨平台&#xff0c;可重用的组件&#xff0c;标准化的流程 maven能够自动下载依…...

【JavaEE初阶】网络编程TCP协议实现回显服务器以及如何处理多个客户端的响应

前言 &#x1f31f;&#x1f31f;本期讲解关于TCP/UDP协议的原理理解~~~ &#x1f308;感兴趣的小伙伴看一看小编主页&#xff1a;GGBondlctrl-CSDN博客 &#x1f525; 你的点赞就是小编不断更新的最大动力 &#x1f386;那么废话不多说…...

Android 中的串口开发

一&#xff1a;背景 本文着重讲安卓下的串口。 由于开源的Android在各种智能设备上的使用越来越多&#xff0c;如车载系统等。在我们的认识中&#xff0c;Android OS的物理接口一般只有usb host接口和耳机接口&#xff0c;但其实安卓支持各种各样的工业接口&#xff0c;如HDM…...

TensorRt OP

在TensorRT中&#xff0c;OP&#xff08;Operations&#xff0c;操作&#xff09;是指网络中的基本计算单元&#xff0c;类似于数学中的运算符。每个OP执行一个特定的计算任务&#xff0c;例如卷积、矩阵乘法、激活函数等。TensorRT通过识别和优化这些OP来提高深度学习模型的推…...

构建负责任的人工智能:数据伦理与隐私保护

构建负责任的人工智能&#xff1a;数据伦理与隐私保护 目录 &#x1f31f; 数据伦理的重要性&#x1f4ca; 公平性评估&#xff1a;实现无偏差的模型&#x1f512; 数据去标识化&#xff1a;保护用户隐私的必要手段&#x1f50d; 透明性与问责&#xff1a;建立可信的数据处理…...

微信小程序live-pusher和video同时使用,video播放声音时时大时小

一、遇到的问题 微信小程序live-pusher和video同时使用,video播放声音时有时无时大时小 二、排查流程 业务是模拟面试,每道题一个推流live-pusher和一个面试题video,一次面试有多道面试题,页面就一个live-pusher和一个video,切换面试题时给live-pusher和video重新赋值u…...

MySQL 分库分表实战

在当今互联网时代&#xff0c;数据量的增长呈爆炸式趋势&#xff0c;传统的单库单表架构已经难以满足大规模数据存储和高并发访问的需求。MySQL 分库分表技术应运而生&#xff0c;它可以有效地提高数据库的性能、扩展性和可用性。本文将详细介绍 MySQL 分库分表的实战经验。 一…...

MySQL—CRUD—进阶—(二) (ಥ_ಥ)

文本目录&#xff1a; ❄️一、新增&#xff1a; ❄️二、查询&#xff1a; 1、聚合查询&#xff1a; 1&#xff09;、聚合函数&#xff1a; 2&#xff09;、GROUP BY子句&#xff1a; 3&#xff09;、HAVING 子句&#xff1a; 2、联合查询&#xff1a; 1&#xff09;、内连接…...

时序分解 | TTNRBO-VMD改进牛顿-拉夫逊算法优化变分模态分解

时序分解 | TTNRBO-VMD改进牛顿-拉夫逊算法优化变分模态分解 目录 时序分解 | TTNRBO-VMD改进牛顿-拉夫逊算法优化变分模态分解效果一览基本介绍程序设计参考资料 效果一览 基本介绍 (创新独家)TTNRBO-VMD改进牛顿-拉夫逊优化算优化变分模态分解TTNRBO–VMD 优化VMD分解层数K和…...

2024“源鲁杯“高校网络安全技能大赛-Misc-WP

Round 1 hide_png 题目给了一张图片&#xff0c;flag就在图片上&#xff0c;不过不太明显&#xff0c;写个python脚本处理一下 from PIL import Image ​ # 打开图像并转换为RGB模式 img Image.open("./attachments.png").convert("RGB") ​ # 获取图像…...

CSS行块标签的显示方式

块级元素 标签&#xff1a;h1-h6&#xff0c;p,div,ul,ol,li,dd,dt 特点&#xff1a; &#xff08;1&#xff09;如果块级元素不设置默认宽度&#xff0c;那么该元素的宽度等于其父元素的宽度。 &#xff08;2&#xff09;所有的块级元素独占一行显示. &#xff08;3&#xff…...

Go 语言中的 for range 循环教程

在 Go 语言中&#xff0c;for range 循环是一个方便的语法结构&#xff0c;用于遍历数组、切片、映射和字符串。本教程将通过示例代码来帮助理解如何在 Go 中使用 for range 循环。 package mainimport "fmt"func main() {// 遍历切片并计算和nums : []int{2, 3, 4}…...

青训营 X 豆包MarsCode 技术训练营--小M的比赛胜场计算

问题描述 小M参加了一场n个人的比赛&#xff0c;比赛规则是所有选手两两对决。每个人有一个能力值&#xff0c;对应着他们的序号。参赛者同时被分为黄色或蓝色两种颜色。比赛胜负的规则如下&#xff1a; 当比赛双方颜色不同时&#xff0c;能力值大的选手获胜&#xff1b; 当比…...

海王3纯源码

海王3是一款热门的捕鱼类游戏&#xff0c;其纯源码为开发者提供了一个完整的游戏开发基础。该源码包括客户端和服务端的完整架构&#xff0c;支持多人在线竞技模式和丰富的游戏玩法。服务端采用C语言编写&#xff0c;并使用MySQL数据库来存储玩家数据&#xff0c;确保数据处理的…...

【ShuQiHere】Linux 系统中的硬盘管理详解:命令与技巧

【ShuQiHere】 &#x1f4bd; 在 Linux 系统中&#xff0c;硬盘管理不仅仅是存储数据的操作&#xff0c;更涉及系统性能、数据安全和稳定性的优化。无论你是系统管理员、开发者还是 Linux 爱好者&#xff0c;掌握硬盘管理的基础操作都非常有用。本文将从硬盘健康检查、分区管理…...

数据结构之堆和二叉树的简介

1.树 1.1 树的概念与结构 如图所示&#xff0c;树是⼀种非线性的数据结构&#xff0c;它是由 n &#xff08;n>0&#xff09; 个有限结点组成⼀个具有层次关系的集合。把它叫做树是因为它看起来像一棵倒挂的树&#xff0c;也就是说它是根朝上&#xff0c;而叶朝下的。 …...

微信小程序上传图片添加水印

微信小程序使用wx.chooseMedia拍摄或从手机相册中选择图片并添加水印&#xff0c; 代码如下&#xff1a; // WXML代码&#xff1a;<canvas canvas-id"watermarkCanvas" style"width: {{canvasWidth}}px; height: {{canvasHeight}}px;"></canvas&…...

xshell5找不到匹配的host key算法

xshell5找不到匹配的host key算法&#xff0c;是因为电脑客户端不支持服务器的算法&#xff0c;因此需要再服务器增加算法。 下面以Ubuntu系统为例&#xff0c;修改下面的文件 sudo vim /etc/ssh/sshd_config 增加下面算法 KexAlgorithms diffie-hellman-group-exchange-…...

Linux中安装Tomcat

文章目录 一、Tomcat介绍1.1、Tomcat是什么1.2、Tomcat的工作原理1.3、Tomcat适用的场景1.4、Tomcat与Nginx、Apache比较1.4.1、优势1.4.2、劣势1.4.3、定位功能 1.5、Tomcat 的主要组件1.6、Tomcat 的主要配置文件 二、Tomcat安装2.1、查看可用的JDK2.2、安装OpenJDK 112.3、配…...

RV1126音视频学习(二)-----VI模块

文章目录 前言2.RV1126的视频输入vi模块2.1什么是VI模块2.3RV1126VI模块主要APIRK_MPI_SYS_Init()RK_MPI_VI_SetChnAttrRK_MPI_VI_EnableChnRK_S32 RK_MPI_VI_DisableChnRK_MPI_VI_StartStreamRK_MPI_SYS_GetMediaBufferRK_MPI_MB_GetPtrRK_MPI_MB_GetSizeRK_MPI_MB_ReleaseBuf…...

「C/C++」C++17 之 std::string_view 轻量级字符串视图

✨博客主页何曾参静谧的博客&#x1f4cc;文章专栏「C/C」C/C程序设计&#x1f4da;全部专栏「VS」Visual Studio「C/C」C/C程序设计「UG/NX」BlockUI集合「Win」Windows程序设计「DSA」数据结构与算法「UG/NX」NX二次开发「QT」QT5程序设计「File」数据文件格式「PK」Parasoli…...

Linux内核-内核模块内核参数

作者介绍&#xff1a;简历上没有一个精通的运维工程师。希望大家多多关注作者&#xff0c;下面的思维导图也是预计更新的内容和当前进度(不定时更新)。 我们的Linux进阶部分&#xff0c;到目前为止&#xff0c;已经讲过&#xff1a;硬件&#xff0c;日常运维&#xff0c;基础软…...

中电信翼康工程师:我在 Apache SeaTunnel 社区的贡献之旅

贡献者Github ID&#xff1a;luckyLJY 文章整理&#xff1a;曾辉 Apache SeaTunnel 作为一款强大的数据同步和转换工具&#xff0c;凭借其部署易用性、容错机制、数据源支持、性能优势、功能丰富性以及活跃的社区支持&#xff0c;成为了数据工程师们不可或缺的利器。 因其具有的…...

【ESP32S3】VSCode 开发环境搭建

ESP32S3 有多种开发方式&#xff0c;主流的有 Eclipse 和 VSCode 两种。本文来介绍一下基于 VSCode 的开发环境搭建。 VSCode 环境需要依赖于 ESP-IDF 插件&#xff0c;因此需要在 VSCode 插件市场中搜索并安装 ESP-IDF 插件&#xff1a; 安装完成后侧边栏会多出一个 ESP-IDF …...

大模型,多模态大模型面试问题基础记录24/10/24

大模型&#xff0c;多模态大模型面试问题基础记录24/10/24 问题一&#xff1a;LoRA是用在节省资源的场景下&#xff0c;那么LoRA具体是节省了内存带宽还是显存呢&#xff1f;问题二&#xff1a;假如用pytorch完成一个分类任务&#xff0c;那么具体的流程是怎么样的&#xff1f;…...

使用TimeShift备份和恢复Ubuntu Linux

您是否曾经想过如何备份和恢复您的Ubuntu或Debian系统&#xff1f;TimeShift是一个强大的备份和还原工具。TimeShift允许您创建系统快照&#xff0c;提供了一种在出现意外问题或系统故障时恢复到先前状态的简便方式。您可以使用RSYNC或BTRFS创建快照。 有了这个介绍&#xff0…...

win7现在还能用吗_哪些配置的电脑还可以安装win7系统

2024年了都&#xff0c;win7现在还能用吗&#xff1f;答案是肯定的。那么哪些配置的电脑还可以安装win7系统呢&#xff1f;下面就针对这两个问题详细分区。 win7现在还能用吗&#xff1f; Windows 7系统虽然已经停止官方支持&#xff0c;但仍然可以使用。以下是关于Windows 7系…...

人社局网站建设/网推是什么

ClassLoader翻译过来就是类加载器&#xff0c;普通的Java开发者其实用到的不多&#xff0c;但对于某些框架开发者来说却非常常见。理解ClassLoader的加载机制&#xff0c;也有利于我们编写出更高效的代码。ClassLoader的具体作用就是将class文件加载到jvm虚拟机中去&#xff0c…...

网络推广教育机构/seo网页优化工具

1: 查询错误日志地址show variables like log_error;2&#xff1a;查询慢查询是否开启show variables like log_slow_queries;3&#xff1a;查询慢查询时间show variables like long_query_time;4&#xff1a;设置慢查询[mysqld]slow_query_logonlog_slow_queriesonslow_launch…...

wordpress免费商城主题/潍坊自动seo

ASP.NET 支持两组性能计数器&#xff1a;系统和应用程序。前者在 ASP.NET 性能计数器对象中的 PerfMon 中公开&#xff1b;后者在 ASP.NET Applications 性能对象中公开。ASP.NET 性能对象中的 State Server Sessions 计数器&#xff08;仅适用于在其中运行状态服务器的服务器计…...

wordpress 热门排序/seo草根博客

Vue,v-for循环遍历方式 1.v-for循环普通数组 item是自定义名称&#xff0c; in后面加的是 list这个普通数组 1 <!DOCTYPE html>2 <html>3 <head>4 <meta charset"utf-8">5 <title></title>6 </head&…...

用什么软件做动漫视频网站/宁波网络营销推广公司

一、for语句 (1)、语句结构 for 定义变量 do 使用变量执行动作 done 结束标志(2)、示例 格式1 seq&#xff08;启始值&#xff0c;间隔值&#xff0c;最终值&#xff09; for WESTOS in $(seq 1 2 10) do echo $WESTOS done格…...

视频网站建设成本/关键词优化的发展趋势

题目来源&#xff1a;http://acm.hdu.edu.cn/showproblem.php?pid1024 (http://www.fjutacm.com/Problem.jsp?pid1375) 题意&#xff1a;长度为n的序列里&#xff0c;m段不相关区间的最大和 思路&#xff1a;我们先要确定一个东西&#xff0c;就是状态&#xff0c;这里我用dp…...