当前位置: 首页 > news >正文

大模型基础——从零实现一个Transformer(3)

大模型基础——从零实现一个Transformer(1)-CSDN博客


一、前言

之前两篇文章已经讲了Transformer的Embedding,Tokenizer,Attention,Position Encoding,
本文我们继续了解Transformer中剩下的其他组件.

二、归一化

2.1 Layer Normalization

layerNorm是针对序列数据提出的一种归一化方法,主要在layer维度进行归一化,即对整个序列进行归一化。

layerNorm会计算一个layer的所有activation的均值和方差,利用均值和方差进行归一化。

𝜇=∑𝑖=1𝑑𝑥𝑖

𝜎=1𝑑∑𝑖=1𝑑(𝑥𝑖−𝜇)2

归一化后的激活值如下:

𝑦=𝑥−𝜇𝜎+𝜖𝛾+𝛽

其中 𝛾 和 𝛽 是可训练的模型参数。 𝛾 是缩放参数,新分布的方差 𝛾2 ; 𝛽 是平移系数,新分布的均值为 𝛽 。 𝜖 为一个小数,添加到方差上,避免分母为0。

2.2 LayerNormalization 代码实现

import torch
import torch.nn as nnclass LayerNorm(nn.Module):def __init__(self,num_features,eps=1e-6):super().__init__()self.gamma = nn.Parameter(torch.ones(num_features))self.beta = nn.Parameter(torch.zeros(num_features))self.eps = epsdef forward(self,x):"""Args:x (Tensor): (batch_size, seq_length, d_model)Returns:Tensor: (batch_size, seq_length, d_model)"""mean = x.mean(dim=-1,keepdim=True)std = x.std(dim=-1,keepdim=True,unbiased=False)normalized_x = (x - mean) / (std + self.eps)return self.gamma * normalized_x + self.betaif __name__ == '__main__':batch_size = 2seqlen = 3hidden_dim = 4# 初始化一个随机tensorx = torch.randn(batch_size,seqlen,hidden_dim)print(x)# 初始化LayerNormlayer_norm  = LayerNorm(num_features=hidden_dim)output_tensor = layer_norm(x)print("output after layer norm:\n,",output_tensor)torch_layer_norm = torch.nn.LayerNorm(normalized_shape=hidden_dim)torch_output_tensor = torch_layer_norm(x)print("output after torch layer norm:\n",torch_output_tensor)

三、残差连接

残差连接(residual connection,skip residual,也称为残差块)其实很简单

x为网络层的输入,该网络层包含非线性激活函数,记为F(x),用公式描述的话就是:

代码简单实现

x = x + layer(x)

四、前馈神经网络

4.1 Position-wise Feed Forward

Position-wise Feed Forward(FFN),逐位置的前馈网络,其实就是一个全连接前馈网络。目的是为了增加非线性,增强模型的表示能力。

它一个简单的两层全连接神经网络,不是将整个嵌入序列处理成单个向量,而是独立地处理每个位置的嵌入。所以称为position-wise前馈网络层。也可以看为核大小为1的一维卷积。

目的是把输入投影到特定的空间,再投影回输入维度。

FFN具体的公式如下:

𝐹𝐹𝑁(𝑥)=𝑓(𝑥𝑊1+𝑏1)𝑊2+𝑏2

上述公式对应FFN中的向量变换操作,其中f为非线性激活函数。

4.2 FFN代码实现

from torch import nn,Tensor
from torch.nn import functional as Fclass PositonWiseFeedForward(nn.Module):def __init__(self,d_model:int ,d_ff: int ,dropout: float=0.1) -> None:''':param d_model:  dimension of embeddings:param d_ff: dimension of feed-forward network:param dropout: dropout ratio'''super().__init__()self.ff1 = nn.Linear(d_model,d_ff)self.ff2 = nn.Linear(d_ff,d_model)self.dropout = nn.Dropout(dropout)def forward(self,x: Tensor) -> Tensor:''':param x:  (batch_size, seq_length, d_model) output from attention:return: (batch_size, seq_length, d_model)'''return self.ff2(self.dropout(F.relu(self.ff1(x))))

五、Transformer Encoder Block

如图所示,编码器(Encoder)由N个编码器块(Encoder Block)堆叠而成,我们依次实现。

from torch import nn,Tensor
## 之前实现的函数引入
from llm_base.attention.MultiHeadAttention1 import MultiHeadAttention
from llm_base.layer_norm.normal_layernorm import LayerNorm
from llm_base.ffn.PositionWiseFeedForward import PositonWiseFeedForwardfrom typing import *class EncoderBlock(nn.Module):def __init__(self,d_model: int,n_heads: int,d_ff: int,dropout: float,norm_first: bool = False):''':param d_model: dimension of embeddings:param n_heads: number of heads:param d_ff: dimension of inner feed-forward network:param dropout:dropout ratio:param norm_first : if True, layer norm is done prior to attention and feedforward operations(Pre-Norm).Otherwise it's done after(Post-Norm). Default to False.'''super().__init__()self.norm_first = norm_firstself.attention = MultiHeadAttention(d_model,n_heads,dropout)self.norm1 = LayerNorm(d_model)self.ff = PositonWiseFeedForward(d_model,d_ff,dropout)self.norm2 = LayerNorm(d_model)self.dropout1 = nn.Dropout(dropout)self.dropout2 = nn.Dropout(dropout)# self attention sub layerdef _self_attention_sub_layer(self,x: Tensor, attn_mask: Tensor, keep_attentions: bool) -> Tensor:x = self.attention(x,x,x,attn_mask,keep_attentions)return self.dropout1(x)# ffn sub layerdef _ffn_sub_layer(self,x: Tensor) -> Tensor:x = self.ff(x)return self.dropout2(x)def forward(self,src: Tensor,src_mask: Tensor == None,keep_attentions: bool= False) -> Tuple[Tensor,Tensor]:''':param src: (batch_size, seq_length, d_model):param src_mask: (batch_size,  1, seq_length):param keep_attentions:whether keep attention weigths or not. Defaults to False.:return:(batch_size, seq_length, d_model) output of encoder block'''# pass througth multi-head attention# src (batch_size, seq_length, d_model)# attn_score (batch_size, n_heads, seq_length, k_length)x = src# post LN or pre LNif self.norm_first:# pre LNx = x + self._self_attention_sub_layer(self.norm1(x),src_mask,keep_attentions)x = x + self._ffn_sub_layer(self.norm2(x))else:x = self.norm1(x + self._self_attention_sub_layer(x,src_mask,keep_attentions))x = self.norm2(x + self._ffn_sub_layer(x))return x

5.1 Post Norm Vs Pre Norm

公式区别

Pre Norm 和 Post Norm 的式子分别如下:

在大模型的区别

Post-LN :是在 Transformer 的原始版本中使用的归一化方案。在此方案中,每个子层(例如,自注意力机制或前馈网络)的输出先通过子层自身的操作,然后再通过层归一化(Layer Normalization)

Pre-LN:是先对输入进行层归一化,然后再传递到子层操作中。这样的顺序对于训练更深的网络可能更稳定,因为归一化的输入可以帮助缓解训练过程中的梯度消失和梯度爆炸问题。

5.2为什么Pre效果弱于Post

相关文章:

大模型基础——从零实现一个Transformer(3)

大模型基础——从零实现一个Transformer(1)-CSDN博客 一、前言 之前两篇文章已经讲了Transformer的Embedding,Tokenizer,Attention,Position Encoding, 本文我们继续了解Transformer中剩下的其他组件. 二、归一化 2.1 Layer Normalization layerNorm是针对序列数据提出的一种…...

一二三应用开发平台应用开发示例——概述、应用开发示例简介及创建前后端模块

概述 对于应用开发平台的核心基石——系统管理模块,我精心撰写了一份详尽的说明手册。该手册旨在从使用者的角度出发,不仅全面阐述系统的各项属性和功能,更着重强调使用过程中的注意事项和最佳实践。 在手册的编写过程中,我特别…...

springboot+minio+kkfileview实现文件的在线预览

在原来的文章中已经讲述过springbootminio的开发过程,这里不做讲述。 原文章地址: https://blog.csdn.net/qq_39990869/article/details/131598884?spm1001.2014.3001.5501 如果你的项目只是需要在线预览图片或者视频那么可以使用minio自己的预览地址进…...

HTML5+CSS3小实例:粘性文字的滚动效果

实例:粘性文字的滚动效果 技术栈:HTML+CSS 效果: 源码: 【HTML】 <!DOCTYPE html> <html lang="zh-CN"><head><meta charset="UTF-8"><meta name="viewport" content="width=device-width, initial-sca…...

Java 关于抽象 -- Java 语言的抽象类、接口和函数式接口

大家好,我是栗筝i,这篇文章是我的 “栗筝i 的 Java 技术栈” 专栏的第 008 篇文章,在 “栗筝i 的 Java 技术栈” 这个专栏中我会持续为大家更新 Java 技术相关全套技术栈内容。专栏的主要目标是已经有一定 Java 开发经验,并希望进一步完善自己对整个 Java 技术体系来充实自…...

用 Notepad++ 写 Java 程序

安装包 百度网盘 提取码&#xff1a;6666 安装步骤 双击安装包开始安装。 安装完成&#xff1a; 配置编码 用 NotePad 写 Java 程序时&#xff0c;需要设置编码。 在 设置&#xff0c;首选项&#xff0c;新建 中进行设置&#xff0c;可以对每一个新建的文件起作用。 Note…...

malloc brk mmap

malloc 是一个库函数&#xff0c;通常在 C 标准库中实现&#xff0c;用于动态内存分配。malloc 的具体实现可能因库、操作系统和平台而异&#xff0c;但通常它会与底层操作系统提供的内存管理功能进行交互。 对于大多数现代操作系统&#xff08;如 Unix、Linux、Windows 等&am…...

java多线程相关概念

在Java多线程编程中&#xff0c;有几个关键的术语需要理解&#xff1a; 1.线程(Thread)&#xff1a;线程是操作系统能够进行运算调度的最小单位&#xff0c;它被包含在进程之中&#xff0c;是进程中的实际运作单位。 2.进程(Process)&#xff1a;进程是系统进行资源分配和调度…...

【html】简单网页模板源码

大家每一次在写网页的时候会不会因为布局而困扰今天就给大家带来一个我自己亲自编写的网页的基本的模板大家可以直接去利用&#xff0c;大家也可以利用自己的想法去做空间的美化和完善。 源码&#xff1a; html: <!DOCTYPE html> <html lang"zh"><…...

借助Historian Connector + TDengine,打造工业创新底座

在工业自动化的领域中&#xff0c;数据的采集、存储和分析是实现高效决策和操作的基石。AVEVA Historian (原 Wonderware Historian) 作为领先的工业实时数据库&#xff0c;专注于收集和存储高保真度的历史工艺数据。与此同时&#xff0c;TDengine 作为一款专为时序数据打造的高…...

51单片机-实机演示(LED点阵)

目录 前言: 一.线位置 二.扩展 三.总结 前言: 这是一篇关于51单片机实机LED点阵的插线图和代码说明.另外还有一篇我写的仿真的连接在这:http://t.csdnimg.cn/ZNLCl,欢迎大家的点赞,评论,关注. 一.线位置 接线实机图. 引脚位置注意: 1. *-* P00->RE8 P01->RE7 …...

STM32硬件接口I2C应用(基于MP6050)

目录 概述 1 STM32Cube控制配置I2C 1.1 I2C参数配置 1.2 使用STM32Cube产生工程 2 HAL库函数介绍 2.1 初始化函数 2.2 写数据函数 2.3 读数据函数 3 认识MP6050 3.1 MP6050功能介绍 3.2 加速计测量寄存器 ​编辑3.3 温度计量寄存器 3.4 陀螺仪测量寄存器 4 MP60…...

基于JSP的贝儿米幼儿教育管理系统

开头语&#xff1a; 你好呀&#xff0c;我是计算机学长猫哥&#xff01;如果您对本系统感兴趣或者有相关需求&#xff0c;文末可以找到我的联系方式。 开发语言&#xff1a; Java 数据库&#xff1a; MySQL 技术&#xff1a; JSP技术 工具&#xff1a; IDEA/Eclipse、…...

数字化与文化交融,树莓集团助力园区文化升级

树莓集团在产业园运营领域建设了特色空间布局&#xff0c;包括产业实训基地、产业办公中心、业务资源平台、产学研中心、数字资产空间、双创孵化空间、产业实验室和人才项目转化中心等八大板块&#xff0c;共同构建了一个全面而深入的产业支撑体系&#xff0c;为园区文化建设提…...

【原创课程】如何制作安装板

具体步骤如下: 第一步:新建页类型为“安装板布局图(交互式)”并修改页描述为“安装板布局图”。 第二步:新建安装板 第三步:设置图纸上符号元件的部件,双击符号,弹出常规设备窗口,点击部件进行选择 第四步:打开2D安装板导航器,将图纸中的设备拖拽到安装板上 第五步…...

简单聊聊【java.util.Stream】,更新中

public class Main {public static void main(String[] args) {List<Integer> numbers Arrays.asList(1, 2, 3, 4, 5, 6); // 原始容器&#xff1a;java.util.Arrays.ArrayList#ArrayList// 创建一个 Stream&#xff0c;过滤出偶数&#xff0c;并打印它们numbers.str…...

GIS之arcgis系列07:conda环境下安装arcpy环境

首先将python27环境下的“Desktop10.8.pth”拷贝到anaconda环境下。 路径如下&#xff08;仅参考&#xff09;&#xff1a; C:\Python27\ArcGIS10.8\Lib\site-packages\Desktop10.8.pth D:\Anaconda\Lib\site-packages 在anaconda prompt中穿创建一个新环境 conda create -…...

容器运行nslookup提示bash: nslookup: command not found【笔记】

在容器中提示bash: nslookup: command not found&#xff0c;表示容器中没有安装nslookup命令。 可以通过以下命令安装nslookup&#xff1a; 对于基于Debian/Ubuntu的容器&#xff0c;使用以下命令&#xff1a; apt-get update apt-get install -y dnsutils对于基于CentOS/R…...

解析 Spring 框架中的三种 BeanName 生成策略

在 Spring 框架中&#xff0c;定义 Bean 时不一定需要指定名称&#xff0c;Spring 会智能生成默认名称。本文将介绍 Spring 的三种 BeanName 生成器&#xff0c;包括在 XML 配置、Java 注解和组件扫描中使用的情况&#xff0c;并解释它们如何自动创建和管理 Bean 名称。 1. Be…...

细说ARM MCU的串口接收数据的实现过程

目录 一、硬件及工程 1、硬件 2、软件目的 3、创建.ioc工程 二、 代码修改 1、串口初始化函数MX_USART2_UART_Init() &#xff08;1&#xff09;MX_USART2_UART_Init()串口参数初始化函数 &#xff08;2&#xff09;HAL_UART_MspInit()串口功能模块初始化函数 2、串口…...

Java 语言特性(面试系列2)

一、SQL 基础 1. 复杂查询 &#xff08;1&#xff09;连接查询&#xff08;JOIN&#xff09; 内连接&#xff08;INNER JOIN&#xff09;&#xff1a;返回两表匹配的记录。 SELECT e.name, d.dept_name FROM employees e INNER JOIN departments d ON e.dept_id d.dept_id; 左…...

PHP和Node.js哪个更爽?

先说结论&#xff0c;rust完胜。 php&#xff1a;laravel&#xff0c;swoole&#xff0c;webman&#xff0c;最开始在苏宁的时候写了几年php&#xff0c;当时觉得php真的是世界上最好的语言&#xff0c;因为当初活在舒适圈里&#xff0c;不愿意跳出来&#xff0c;就好比当初活在…...

IoT/HCIP实验-3/LiteOS操作系统内核实验(任务、内存、信号量、CMSIS..)

文章目录 概述HelloWorld 工程C/C配置编译器主配置Makefile脚本烧录器主配置运行结果程序调用栈 任务管理实验实验结果osal 系统适配层osal_task_create 其他实验实验源码内存管理实验互斥锁实验信号量实验 CMISIS接口实验还是得JlINKCMSIS 简介LiteOS->CMSIS任务间消息交互…...

如何更改默认 Crontab 编辑器 ?

在 Linux 领域中&#xff0c;crontab 是您可能经常遇到的一个术语。这个实用程序在类 unix 操作系统上可用&#xff0c;用于调度在预定义时间和间隔自动执行的任务。这对管理员和高级用户非常有益&#xff0c;允许他们自动执行各种系统任务。 编辑 Crontab 文件通常使用文本编…...

Python 训练营打卡 Day 47

注意力热力图可视化 在day 46代码的基础上&#xff0c;对比不同卷积层热力图可视化的结果 import torch import torch.nn as nn import torch.optim as optim from torchvision import datasets, transforms from torch.utils.data import DataLoader import matplotlib.pypl…...

十九、【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建

【用户管理与权限 - 篇一】后端基础:用户列表与角色模型的初步构建 前言准备工作第一部分:回顾 Django 内置的 `User` 模型第二部分:设计并创建 `Role` 和 `UserProfile` 模型第三部分:创建 Serializers第四部分:创建 ViewSets第五部分:注册 API 路由第六部分:后端初步测…...

渗透实战PortSwigger靶场:lab13存储型DOM XSS详解

进来是需要留言的&#xff0c;先用做简单的 html 标签测试 发现面的</h1>不见了 数据包中找到了一个loadCommentsWithVulnerableEscapeHtml.js 他是把用户输入的<>进行 html 编码&#xff0c;输入的<>当成字符串处理回显到页面中&#xff0c;看来只是把用户输…...

Modbus RTU与Modbus TCP详解指南

目录 1. Modbus协议基础 1.1 什么是Modbus? 1.2 Modbus协议历史 1.3 Modbus协议族 1.4 Modbus通信模型 🎭 主从架构 🔄 请求响应模式 2. Modbus RTU详解 2.1 RTU是什么? 2.2 RTU物理层 🔌 连接方式 ⚡ 通信参数 2.3 RTU数据帧格式 📦 帧结构详解 🔍…...

接口 RESTful 中的超媒体:REST 架构的灵魂驱动

在 RESTful 架构中&#xff0c;** 超媒体&#xff08;Hypermedia&#xff09;** 是一个核心概念&#xff0c;它体现了 REST 的 “表述性状态转移&#xff08;Representational State Transfer&#xff09;” 的本质&#xff0c;也是区分 “真 RESTful API” 与 “伪 RESTful AP…...

day51 python CBAM注意力

目录 一、CBAM 模块简介 二、CBAM 模块的实现 &#xff08;一&#xff09;通道注意力模块 &#xff08;二&#xff09;空间注意力模块 &#xff08;三&#xff09;CBAM 模块的组合 三、CBAM 模块的特性 四、CBAM 模块在 CNN 中的应用 一、CBAM 模块简介 在之前的探索中…...