机器学习课程学习周报八
机器学习课程学习周报八
文章目录
- 机器学习课程学习周报八
- 摘要
- Abstract
- 一、机器学习部分
- 1.1 self-attention的计算量
- 1.2 人类理解代替自注意力计算
- 1.2.1 Local Attention/Truncated Attention
- 1.2.2 Stride Attention
- 1.2.3 Global Attention
- 1.2.4 聚类Query和Key
- 1.3 自动选择自注意力计算
- 1.4 Attention Matrix中的线性组合
- 1.5 通过矩阵乘法推导自注意力计算
- 1.6 Batch Normalization
- 总结
摘要
本周的学习重点是自注意力机制的计算优化。我探讨了如何通过Local Attention、Stride Attention、Global Attention等方法减少计算量。此外,还介绍了自动选择注意力计算和Attention Matrix的线性组合方法。最后,补充了Batch Normalization的知识,为模型训练提供了更好的稳定性。
Abstract
This week’s focus is on optimizing the computation of the self-attention mechanism. I explored methods like Local Attention, Stride Attention, and Global Attention to reduce computational load. Additionally, we discussed automatic selection of attention computation and linear combinations in the Attention Matrix. Lastly, we supplemented our understanding with Batch Normalization, enhancing model training stability.
一、机器学习部分
1.1 self-attention的计算量
如果现在自注意力模型输入的序列长度为 N N N,则对应的Query为 N N N个,对应的Key也为 N N N个。它们之间相互计算关联性(即注意力分数),可以得到上图中的Attention Matrix,这个矩阵的复杂度是 N 2 {N^2} N2,当 N N N的数值很大时,该矩阵的计算量就会变得很大。因此,这一节介绍多种方法以加速计算Attention Matrix的计算。
Notice:当 N N N很大时,self-attention的计算才会主导整个模型中计算量。例如:在Transformer模型中,除了self-attention还有其他模块的计算量,self-attention模块的计算量占模型整体计算量是与 N N N有关的,当 N N N过小时,对self-attention的改进计算并不会明显提高Transformer模型的运算速度。
1.2 人类理解代替自注意力计算
根据人类对问题的理解,对Attention Matrix某些位置的值直接赋值,跳过计算步骤,从而减少计算量。
1.2.1 Local Attention/Truncated Attention
计算self-attention时,并非计算整个序列间的self-attention分数,而是只看自己和左右的邻居,其他的关联性都设定为0。下图在Attention Matrix中,表示为灰色的部分都人工设定为0,只计算蓝色部分的self-attention分数。这种方法叫做Local Attention或Truncated Attention。
Local Attention与CNN较为相似,主要体现在它们的局部关注机制上。这种机制使得模型在处理输入数据时,只关注输入数据的局部区域,而不是整体。卷积神经网络(CNN)中,卷积层通过滑动窗口的方式在输入数据上提取特征。这种操作也可以看作是一种局部关注机制,通过卷积核仅关注输入数据的局部区域来提取特征。Local attention相比于之前介绍的包含全序列的注意力,更加注重输入数据的局部关系,与卷积核的滑动也很类似。
1.2.2 Stride Attention
根据自己对问题的理解,计算局部的self-attention并不一定是左右邻居,如下图,可以是分别计算序列中两步前或两步后的关联性,也可以是分别计算序列中一步前或一步后的关联性,灰色的地方设定为0。这种方法叫做Stride Attention。
1.2.3 Global Attention
前面介绍的方法都是以某一个位置为中心,分别计算左右的关联性。Global Attention注重于整个序列,其会添加特殊的token到原始的序列中,特殊的token分别与整个序列计算self-attention,具体做法有两种:
- 从原来的token序列中,选择一部分作为特殊的token。
- 外加一部分额外的token。
从上图的Attention Matrix观察得到,在原始的序列中,第一和第二个位置被选择为特殊的token。从横轴的角度看,第一和第二个位置的Query与整个序列的Key分别做了self-attention。从纵轴的角度看,序列每一个位置的Query都与第一和第二位置的Key做了self-attention。灰色的位置设定为0。
在Big Bird中提出了Random attention并且将其与前面的Local Attention和Global Attention一并融合。
1.2.4 聚类Query和Key
第一步,根据相似度聚类Query和Key,上图中根据不同颜色聚类为了4类。
第二步,相同类之间的Query和Key才做self-attention。
1.3 自动选择自注意力计算
通过神经网络学习出一个0-1矩阵,深色位置代表1,浅色位置代表0。只有深色位置计算self-attention,浅色位置不计算。
输入序列中的每一个位置都通过一个神经网络产生一个长度为 N N N的向量,然后将这些向量拼起来得到大小为 N × N N \times N N×N的矩阵。然而现在这个由向量拼成得到的矩阵中的值,是连续值,要转换为0-1矩阵,这一部分是可以微分的,所以可以通过学习得到,具体需要看Sinkhorn Sorting Network的论文。
1.4 Attention Matrix中的线性组合
计算Attention Matrix的Rank(秩),得到Low Rank,说明该矩阵的很多列是其它列的线性组合。由此可得,实际上并不需要 N × N N \times N N×N的矩阵,目前 N × N N \times N N×N的矩阵中包含很多重复的信息,也许可以通过减少Attention Matrix的大小(主要是列数量)实现减少运算量。
选择具有代表性的Key,得到K个Key,即得到大小为 N × K N \times K N×K的Attention Matrix。接下来考虑self-attention这一层的输出,同样地要从N个Value中挑出具有代表性的K个Value,一个Key对应一个Value向量。然后用Value矩阵乘上Attention Matrix可以得到self-attention层的输出。
为什么我们不能挑出K个代表的Query呢?
输出序列的长度与Query的数量是一致的,如果减少Query的数量,输出序列的长度就会变短。
挑选具有代表性的Key的方法为:
卷积降维和线性组合(K个向量是N个向量的K种线性组合,下图右)
1.5 通过矩阵乘法推导自注意力计算
简要复习一下自注意力机制的矩阵计算过程:第一步,输入序列分别做三种不同的变换,得到 d × N d \times N d×N大小的Query和 d × N d \times N d×N大小的Key,其中 d d d是Query和Key的维度, N N N代表序列的长度。并得到 d ′ × N d' \times N d′×N大小的Value,其中特别用 d ′ d' d′表示Value的维度,是因为Value的维度可以与Query、Key不一样。第二步, K T {K^{\rm T}} KT乘上 Q Q Q得到Attention Matrix,然后通过softmax做归一化。第三步,用 V V V乘上归一化后的Attention Matrix( A ′ A' A′)得到自注意力层的输出 O O O。
如果我们先忽略softmax的操作,self-attention的计算方法就是上图中第一行的计算过程,现在考虑第二行运算,先算 V V V乘上 K T {K^{\rm T}} KT的结果,再乘上 Q Q Q,这样的计算顺序与第一行有何不同?得到的结果是一样的,运算量是不一样的。
尽管 A ( C P ) = ( A C ) P A\left( {CP} \right) = \left( {AC} \right)P A(CP)=(AC)P,但是第一种计算方式的计算量是 1 0 6 {10^6} 106,第二种计算方式的计算量的 1 0 3 {10^3} 103,两者计算量之间的差异很大。因此我们这里先忽略softmax操作,考虑self-attention中矩阵计算的改进。
根据上图证明, V ( K T Q ) V({K^{\rm T}}Q) V(KTQ)的计算量通常大于 ( V K T ) Q (V{K^{\rm T}})Q (VKT)Q的计算量。
接下来加入softmax,写出计算self-attention的数学表达式:
下面通过数学证明的角度说明更换矩阵乘法顺序,计算self-attention的过程:
还有一个问题是, exp ( q ⋅ k ) ≈ Φ ( q ) ⋅ Φ ( k ) \exp (q \cdot k) \approx \Phi (q) \cdot \Phi (k) exp(q⋅k)≈Φ(q)⋅Φ(k)是如何实现的,具体需要参考下面的论文。
1.6 Batch Normalization
在Transformer的编码器中使用到了Layer Normalization,在上一周的周报中并将其与Batch Normalization做了比较,这里特别补充Batch Normalization的知识。
做标准化的原因是,希望能把不同维度的特征值规范到同样的数值范围,从而使得error surface比较平滑,更好训练。
Batch Normalization是对不同特征向量的同一维度,计算平均值和标准差,然后将特征值减去平均值再除以标准差,实现标准化。标准化后,同一维度上的数值的平均值是0,方差是1,接近高斯分布。
在神经网络中,输入特征 x ~ 1 {\tilde x^1} x~1、 x ~ 2 {\tilde x^2} x~2、 x ~ 3 {\tilde x^3} x~3已经做过了标准化,在经过 W 1 {W^1} W1层后,且输入 W 2 {W^2} W2层之前仍需要做标准化。至于是对激活函数前的 z 1 {z^1} z1、 z 2 {z^2} z2、 z 3 {z^3} z3还是之后的 a 1 {a^1} a1、 a 2 {a^2} a2、 a 3 {a^3} a3做标准化,差别不是很大。以 z 1 {z^1} z1、 z 2 {z^2} z2、 z 3 {z^3} z3为例, z 1 {z^1} z1、 z 2 {z^2} z2、 z 3 {z^3} z3都是向量,做标准化的方法如下:
μ = 1 3 ∑ i = 1 3 z i \mu = \frac{1}{3}\sum\limits_{i = 1}^3 {{z^i}} μ=31i=1∑3zi是对向量 z i {z^i} zi中对应元素进行相加,然后取平均。 σ = 1 3 ∑ i = 1 3 ( z i − μ ) 2 \sigma = \sqrt {\frac{1}{3}\sum\limits_{i = 1}^3 {{{\left( {{z^i} - \mu } \right)}^2}} } σ=31i=1∑3(zi−μ)2是向量 z i {z^i} zi与 μ \mu μ相减,然后逐元素平方,求和平均后,再对向量的逐元素开根号。如果直接看公式会有一些歧义,因为 z i {z^i} zi、 μ \mu μ、 σ \sigma σ都是向量,其中的求和,平方,开根号都是对向量中逐元素操作。最后标准化公式为:
z ~ i = z i − μ σ {{\tilde z}^i} = \frac{{{z^i} - \mu }}{\sigma } z~i=σzi−μ
实际上,GPU的内存不足以把整个dataset的数据一次性加载。因此,只考虑一个batch中的样本,对一个batch中的样本做Batch Normalization。在inference中,不可能等到整个batch数量的输入才做推理,具体方法为:在训练时计算 μ \mu μ和 σ \sigma σ的moving average,训练时的第一个batch为 μ 1 {\mu^1} μ1,第二个batch为 μ 1 {\mu^1} μ1,直到第t个batch为 μ t {\mu^t} μt,且不断地计算moving average:
μ ˉ ← p μ ˉ + ( 1 − p ) μ t \bar \mu \leftarrow p\bar \mu + \left( {1 - p} \right){\mu ^t} μˉ←pμˉ+(1−p)μt
inference中标准化的公式变为:
z ~ i = z i − μ ˉ σ ˉ {{\tilde z}^i} = \frac{{{z^i} - \bar \mu }}{{\bar \sigma }} z~i=σˉzi−μˉ
总结
通过本周的学习,我对自注意力机制的优化策略有了更深入的了解,不同的注意力方法提供了多样化的计算选择,有助于提高模型的效率。下周还会围绕自注意力机制进行拓展学习。
相关文章:

机器学习课程学习周报八
机器学习课程学习周报八 文章目录 机器学习课程学习周报八摘要Abstract一、机器学习部分1.1 self-attention的计算量1.2 人类理解代替自注意力计算1.2.1 Local Attention/Truncated Attention1.2.2 Stride Attention1.2.3 Global Attention1.2.4 聚类Query和Key 1.3 自动选择自…...

福泰轴承股份有限公司进销存系统pf
TOC springboot413福泰轴承股份有限公司进销存系统pf 绪论 1.1 研究背景 现在大家正处于互联网加的时代,这个时代它就是一个信息内容无比丰富,信息处理与管理变得越加高效的网络化的时代,这个时代让大家的生活不仅变得更加地便利化&#…...

【k8s从节点报错】error: You must be logged in to the server (Unauthorized)
k8s主节点可以获取nodes节点信息,但是从节点无法获取,且报错“error: You must be logged in to the server (Unauthorized)” 排查思路: 当时证书过期了,只处理的主节点的证书过期,没有处理从节点的 kubeadm alpha …...

风清扬/基于Java语言的光伏监控系统+光伏发电预测+光伏项目+光伏运维+光伏储能项目
基于Java语言的光伏监控系统光伏发电预测光伏项目光伏运维光伏储能项目 介绍 基于Java语言的光伏监控系统光伏发电系统光伏软件系统光伏监控系统源码光伏发电系统源码 基于Java语言的光伏监控系统光伏发电预测光伏项目光伏运维光伏储能项目 安装教程 参与贡献 Fork 本仓库新…...

Datawhale X 魔搭 AI夏令营第四期 魔搭-AIGC方向全过程笔记
task1: 传送门 task2: 传送门 task3: 传送门 目录 Task1 赛题内容 可图Kolors-LoRA风格故事挑战赛 baseline要点讲解(请配合Datawhale速通教程食用) Step1 设置算例及比赛账号的报名和授权 Step2 进行赛事报名并创建PAI实例 Step3 执行baseline Step4…...

数组---怎么样定义和引用数组
一怎么定义数组 例 int a[10]; //定义了一个一维数组,数组名为a,此数组包含10个整型元素 所以我们了解到数组的基本定义为 类型符 数组名 [常量表达式] 定义数组可以包括常量和符号常量如 int [ 35 ];但是不能利用变量定义如 int n; …...
Nginx—Rewrite
目录 一、Nginx—Rewrite概述 1、常用的Nginx正则表达式 2、Rewrite功能 3、Rewrite跳转实现 4、Rewrite执行顺序和语法格式 二、location概述 1、location分类 2、location 常用的匹配规则 3、location 优先级 案例一: 案例二: 案例三&…...

《深入浅出WPF》读书笔记.5控件与布局(上)
《深入浅出WPF》读书笔记.5控件与布局(上) 背景 深入浅出WPF书籍学习笔记附代码。WPF中数据是核心是主动的,UI是数据的表达是被动的。 程序的本质是数据算法;控件的本质是数据行为; 5.控件与布局 一、6类控件派生关系 1.布局控件:可以容纳多个控件…...

二叉树的判断
二叉树的判断 判断一颗二叉树是不是搜索二叉树 (左边的比根小,右边的比根大) 中序遍历一下,如果是的话就一定是升序的 如何判断一颗二叉树是否是完全二叉树 1.遍历任意的节点时候,如果返回右孩子没有左孩子&#x…...

Hive3:常用的内置函数
1、查看函数列表 -- 查看所有可用函数 show functions; -- 查看count函数使用方式 describe function extended count;2、数学函数 -- round 取整,设置小数精度 select round(3.1415926); -- 取整(四舍五入) select round(3.1415926, 4); -- 设置小数精度4位(四…...

设计模式---构建者模式(Builder Pattern)
构建者模式(Builder Pattern) 是一种创建型设计模式,旨在将复杂对象的构建过程与其表示分离。它允许使用相同的构建过程创建不同的表示。该模式通常用于构建复杂对象,这些对象由多个部分组成或具有多个可选属性。 构建者模式的核…...
Pytorch中transform的应用
在PyTorch中,transforms模块主要用于对图像进行预处理和数据增强,以便于训练深度学习模型。这些转换操作可以包括裁剪、缩放、旋转、翻转等,以及对图像进行标准化处理。下面将详细介绍一些常用的transforms操作及其应用。 1. 常用的transfor…...
okular阅读软件简介
okular阅读软件官网:https://okular.kde.org/zh-cn/ Okular 是一款由 KDE 开发的跨平台文档阅读器,以其功能丰富、轻巧快速而著称。它支持多种文件格式,包括 PDF、EPub、DjVu、MD 文档,以及 JPEG、PNG、GIF、Tiff 和 WebP 图像&a…...

【书生大模型实战营(暑假场)闯关材料】基础岛:第1关 书生大模型全链路开源体系
【书生大模型实战营(暑假场)闯关材料】基础岛:第1关 书生大模型全链路开源体系 简介一、背景介绍1.1 背景介绍1.2 全链路开源开放体系的优势 二、全链路开源开放体系的主要特点2.1 模型组件的公开和共享2.2 数据集的公开和共享2.3 模型的互操…...
掌握抽象工厂模式:打造灵活且强大的跨平台产品族
抽象工厂模式是一种创建型设计模式,它的核心思想是提供一个创建一系列相关或相互依赖对象的接口,而无需指定它们具体的类。这种模式通过使用抽象工厂来封装和隔离具体产品的创建过程,使得客户端可以通过工厂接口来创建一族产品,从…...
【Hadoop】建立圈内组件的宏观认识(大纲版)
Hadoop生态圈解析:各组件的主要功能及作用详解 Hadoop生态圈是由一系列开源组件组成的,这些组件共同构建了一个大规模分布式计算和存储平台。 01存储类型组件 HDFS Hadoop体系的核心组件之一,它是一个分布式文件系统,被设计用于存…...
NFS主从同步Rsync、sersync2
准备工作检查selinux 防火墙 #关闭 selinux sed -i s/^SELINUX.*/SELINUXdisabled/ /etc/selinux/config #关闭防火墙 systemctl stop firewalld;systemctl disable firewalld1.安装nfs相关包 # 所有节点安装nfs相关包 yum install nfs-utils -y systemctl enable nfs-utils …...
uniapp项目中,在原有数据中增加选中的状态,数据不改变
uniapp项目中,在原有数据中增加选中的状态,选中后打印的数据显示有变化,然而文本的数据并没有发生变化 看代码 export default {data() {return {thicate: [{ id: 1, text: "Item 1" },{ id: 2, text: "Item 2" },{ id…...

WPF自定义控件
控件模板 顾名思义就是在原有的控件上进行模版修改成自己需要的样式 把ProgressBar修改为一个水液面的进度条 <Window x:Class"XH.CustomLesson.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://s…...
Java中的全局异常处理器 -- GlobalExceptionHandler
开发记录:全局异常处理器笔记 import lombok.extern.slf4j.Slf4j; import org.mybatis.spring.MyBatisSystemException; import org.springframework.beans.factory.annotation.Value; import org.springframework.data.redis.RedisConnectionFailureException; im…...

eNSP-Cloud(实现本地电脑与eNSP内设备之间通信)
说明: 想象一下,你正在用eNSP搭建一个虚拟的网络世界,里面有虚拟的路由器、交换机、电脑(PC)等等。这些设备都在你的电脑里面“运行”,它们之间可以互相通信,就像一个封闭的小王国。 但是&#…...
重启Eureka集群中的节点,对已经注册的服务有什么影响
先看答案,如果正确地操作,重启Eureka集群中的节点,对已经注册的服务影响非常小,甚至可以做到无感知。 但如果操作不当,可能会引发短暂的服务发现问题。 下面我们从Eureka的核心工作原理来详细分析这个问题。 Eureka的…...

C/C++ 中附加包含目录、附加库目录与附加依赖项详解
在 C/C 编程的编译和链接过程中,附加包含目录、附加库目录和附加依赖项是三个至关重要的设置,它们相互配合,确保程序能够正确引用外部资源并顺利构建。虽然在学习过程中,这些概念容易让人混淆,但深入理解它们的作用和联…...

uniapp 开发ios, xcode 提交app store connect 和 testflight内测
uniapp 中配置 配置manifest 文档:manifest.json 应用配置 | uni-app官网 hbuilderx中本地打包 下载IOS最新SDK 开发环境 | uni小程序SDK hbulderx 版本号:4.66 对应的sdk版本 4.66 两者必须一致 本地打包的资源导入到SDK 导入资源 | uni小程序SDK …...

FFmpeg:Windows系统小白安装及其使用
一、安装 1.访问官网 Download FFmpeg 2.点击版本目录 3.选择版本点击安装 注意这里选择的是【release buids】,注意左上角标题 例如我安装在目录 F:\FFmpeg 4.解压 5.添加环境变量 把你解压后的bin目录(即exe所在文件夹)加入系统变量…...
【Kafka】Kafka从入门到实战:构建高吞吐量分布式消息系统
Kafka从入门到实战:构建高吞吐量分布式消息系统 一、Kafka概述 Apache Kafka是一个分布式流处理平台,最初由LinkedIn开发,后成为Apache顶级项目。它被设计用于高吞吐量、低延迟的消息处理,能够处理来自多个生产者的海量数据,并将这些数据实时传递给消费者。 Kafka核心特…...
加密通信 + 行为分析:运营商行业安全防御体系重构
在数字经济蓬勃发展的时代,运营商作为信息通信网络的核心枢纽,承载着海量用户数据与关键业务传输,其安全防御体系的可靠性直接关乎国家安全、社会稳定与企业发展。随着网络攻击手段的不断升级,传统安全防护体系逐渐暴露出局限性&a…...

SQL注入篇-sqlmap的配置和使用
在之前的皮卡丘靶场第五期SQL注入的内容中我们谈到了sqlmap,但是由于很多朋友看不了解命令行格式,所以是纯手动获取数据库信息的 接下来我们就用sqlmap来进行皮卡丘靶场的sql注入学习,链接:https://wwhc.lanzoue.com/ifJY32ybh6vc…...

MySQL 数据库深度剖析:事务、SQL 优化、索引与 Buffer Pool
在当今数据驱动的时代,数据库作为数据存储与管理的核心,其性能与可靠性至关重要。MySQL 作为一款广泛使用的开源数据库,在众多应用场景中发挥着关键作用。在这篇博客中,我将围绕 MySQL 数据库的核心知识展开,涵盖事务及…...

【靶场】XXE-Lab xxe漏洞
前言 学习xxe漏洞,搭了个XXE-Lab的靶场 一、搭建靶场 现在需要登录,不知道用户名密码,先随便试试抓包 二、判断是否存在xxe漏洞 1.首先登录抓包 看到xml数据解析,由此判断和xxe漏洞有关,但还不确定xxe漏洞是否存在。 2.尝试xxe 漏洞 判断是否存在xxe漏洞 A.send to …...