【模型学习之路】手写+分析bert
手写+分析bert
目录
前言
架构
embeddings
Bertmodel
预训练任务
MLM
NSP
Bert
后话
netron可视化
code2flow可视化
fine tuning
前言
Attention is all you need!
读本文前,建议至少看懂【模型学习之路】手写+分析Transformer-CSDN博客。
毕竟Bert是transformer的变种之一。
架构
embeddings
Bert可以说就是transformer的Encoder,就像训练卷积网络时可以利用现成的网络然后fine tune就投入使用一样,Bert的动机就是训练一种预训练模型,之后根据不同的场景可以做不同的fine tune。
这里我们还是B代表批次(对于Bert,一个Batch可以输入一到两个句子,输入两个句子时,两个直接拼接就好了),m代表一个batch的单词数,n表示词向量的长度。
Bert的输入是三种输入之和(维度设定我们与本系列上一篇文章保持相同):
token_embeddings 和Transformer完全一样。
segment_embeddings 用来标记句子。第一个句子每个单词标0,第二个句子的每个单词标1。
pos_embeddings 用来标记位置,维度和Transformer中的一样,但是Bert的pos_embeddings是训练出来的(这意味它成为了神经网络里要训练的参数了)。
def get_token_and_segments(tokens_a, tokens_b=None):"""bert的输入之一:token embeddingsbert的输入之二:segment embeddingspos_embeddings在后面的模型里面"""tokens = ['<cls>'] + tokens_a + ['<sep>']segments = [0] * (len(tokens_a) + 2)if tokens_b is not None:tokens += tokens_b + ['<sep>']segments += [1] * (len(tokens_b) + 1)return tokens, segments
Bertmodel
Bert的单个EncpderLayer和Transformer是一样的,我们直接把上一节的代码复制过来就好。
组装好。
class BertModel(nn.Module):def __init__(self, vocab, n, d_ff, h, n_layers,max_len=1000, k=768, v=768):super(BertModel, self).__init__()self.token_embeddings = nn.Embedding(vocab, n) # [B, m]->[B, m, vocab]->[B, m, n]self.segment_embeddings = nn.Embedding(2, n) # [B, m]->[B, m, 2]->[B, m, n]self.pos_embeddings = nn.Parameter(torch.randn(1, max_len, n)) # [1, max_len, n]self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)for _ in range(n_layers)])def forward(self, tokens, segments, m): # m是句子长度X = self.token_embeddings(tokens) + \self.segment_embeddings(segments)X += self.pos_embeddings[:, :X.shape[1], :]for layer in self.layers:X, attn = layer(X)return X
简单测试一下。
# 弄一点数据测试一下tokens = torch.randint(0, 100, (2, 10)) # [B, m]segments = torch.randint(0, 2, (2, 10)) # [B, m]m = 10bert = BertModel(100, 768, 3072, 12, 12)out = bert(tokens, segments, m)print(out.shape) # [2, 10, 768]
预训练任务
Bert在训练时要做两种训练,这里先画个图表示架构,后面给出分析和代码。
MLM
Maked language model,是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。
在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。(这里就不深入了)
class MLM(nn.Module):def __init__(self, vocab, n, mlm_hid):super(MLM, self).__init__()self.mlp = nn.Sequential(nn.Linear(n, mlm_hid),nn.ReLU(),nn.LayerNorm(mlm_hid),nn.Linear(mlm_hid, vocab))def forward(self, X, P):# X: [B, m, n]# P: [B, p]# 这里P指的是记录了要mask的元素的矩阵,若P(i,j)==k,表示X(i,k)被mask了p = P.shape[1]P = P.reshape(-1)batch_size = X.shape[0]batch_idx = torch.arange(batch_size)batch_idx = torch.repeat_interleave(batch_idx, p)X = X[batch_idx, P].reshape(batch_size, p, -1) # [B, p, n]out = self.mlp(X)return out
这里的forward的逻辑有点麻烦,要读懂的话可以要手推一下。p是每一个Batch中mask的词的个数。(即在一个Batch中,m个词挑出了p个)
NSP
Next Sentence Prediction的任务是判断句子B是否是句子A的下文。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合IsNext关系,另外50%的第二句话是随机从预料中提取的,它们的关系是NotNext的。这个关系由每个句子的第一个token——<cls>捕捉。
class NSP(nn.Module):def __init__(self, n, nsp_hid):super(NSP, self).__init__()self.mlp = nn.Sequential(nn.Linear(n, nsp_hid),nn.Tanh(),nn.Linear(nsp_hid, 2))def forward(self, X):# X: [B, m, n]X = X[:, 0, :] # [B, n]out = self.mlp(X) # [B, 2]return out
Bert
下面拼装Bert。
class Bert(nn.Module):def __init__(self, vocab, n, d_ff, h, n_layers,max_len=1000, k=768, v=768, mlm_feat=768, nsp_feat=768):super(Bert, self).__init__()self.encoder = BertModel(vocab, n, d_ff, h, n_layers, max_len, k, v)self.mlm = MLM(vocab, n, mlm_feat)self.nsp = NSP(n, nsp_feat)def forward(self, tokens, segments, m, P=None):X = self.encoder(tokens, segments, m)mlm_out = self.mlm(X, P) if P is not None else Nonensp_out = self.nsp(X)return X, mlm_out, nsp_out
后话
netron可视化
利用netron可视化。
test_tokens = torch.randint(0, 100, (2, 10)) # [B, m]
test_segments = torch.randint(0, 2, (2, 10)) # [B, m]
test_P = torch.tensor([[1, 2, 4, 6, 8], [1, 3, 4, 5, 6]])
test_m = 10
test_bert = Bert(100, 768, 3072, 12, 12)
test_X, test_mlm_out, test_nsp_out = test_bert(test_tokens, test_segments, test_m, test_P)modelData = "./demo.pth"
torch.onnx.export(test_bert, (test_tokens, test_segments), modelData)
netron.start(modelData)
截取部分看一下。
code2flow可视化
code2flow可以可视化代码函数和类的相互调用关系。
code2flow.code2flow([r'代码路径.py'], '输出路径.svg')
这里生成的png,其实svg清晰得多。
fine tuning
Bert的精髓在于,Bert只是一个编码器(Encoder),经过MLM和NSP两个任务的训练之后,可以自己在它的基础上训练一个Decoder来输出特定的值、得到特定的效果。这也是Bert的神奇和魅力所在!通过两个任务训练出一个编码器,然后可以通过不同的Decoder达到各种效果!
持续探索Bert......
相关文章:

【模型学习之路】手写+分析bert
手写分析bert 目录 前言 架构 embeddings Bertmodel 预训练任务 MLM NSP Bert 后话 netron可视化 code2flow可视化 fine tuning 前言 Attention is all you need! 读本文前,建议至少看懂【模型学习之路】手写分析Transformer-CSDN博客。 毕竟Bert是tr…...
Redis学习文档(常见面试题)
目录 Redis回收使用的是什么算法? Redis如何做大量数据插入? 为什么要做Redis分区? 你知道有哪些Redis分区实现方案? Redis分区有什么缺点? Redis持久化数据和缓存怎么做扩容? 分布式Redis是前期做还…...
【C++刷题】力扣-#594-最长和谐子序列
题目描述 和谐数组是指一个数组里元素的最大值和最小值之间的差别 正好是 1 。 给你一个整数数组 nums ,请你在所有可能的子序列中找到最长的和谐子序列的长度。 数组的 子序列是一个由数组派生出来的序列,它可以通过删除一些元素或不删除元素、且不改变…...

MoveIt 控制自己的真实机械臂【2】——编写 action server 端代码
完成了 MoveIt 这边 action client 的基本配置,MoveIt 理论上可以将规划好的 trajectory 以 action 的形式发布出来了,浅浅尝试一下,在 terminal 中运行 roslaunch xmate7_moveit_config_new demo.launch 报错提示他在等待 xmate_arm_control…...
C#制作学生管理系统
定义学生类 定义一个简单的类来表示学生,包括学号、姓名、性别、年龄、电话、地址。再给其添加一个方法利于后续添加方法查看学生信息。 //定义学生类 public class student {public int ID { get; set; }//开放读写权限public string Name { get; set; }public i…...

python Pandas合并(单元格、sheet、excel )
安装 Pandas 和 openpyxl 首先,确保已经安装了 Pandas 和 openpyxl。可以通过 pip 安装: pip install pandas openpyxl 创建 DataFrame import pandas as pd # 创建 DataFrame df1 pd.DataFrame({ 姓名: [张三, 李四, 王五], 年龄: [25, 30, 35]…...

OJ在线编程常见输入输出练习【JavaScript】
(注:本文是对【JavaScript Node 】 ACM模式,常见输入输出练习相关内容的介绍!!!) 牛客竞赛_ACM/NOI/CSP/CCPC/ICPC算法编程高难度练习赛_牛客竞赛OJ 一、ACM模式下的编辑页面 二、ACM模式下&a…...
新能源汽车空调系统:绿色出行的舒适保障
在新能源汽车迅速发展的今天,空调系统作为提升驾乘舒适度的重要组成部分,发挥着不可或缺的作用。新能源汽车空调系统主要由压缩机、冷凝器、节流装置和蒸发器四大件组成,它们协同工作,为车内提供适宜的温度和湿度环境。 一、压缩…...

Date工具类详细汇总-Date日期相关方法
# 1024程序员节 | 征文 # 目录 简介 Date工具类单元测试 Date工具类 简介 本文章是个人总结实际工作中常用到的Date工具类,主要包含Java-jdk8以下版本的Date相关使用方法,可以方便的在工作中灵活的应用,在个人工作期间频繁使用这些时间的格…...

TMUX1308PWR规格书 数据手册 具有注入电流控制功能的 5V 双向 8:1单通道和 4:1 双通道多路复用器芯片
TMUX1308 和 TMUX1309 为通用互补金属氧化物半导体 (CMOS) 多路复用器 (MUX)。TMUX1308 是 8:1单通道(单端)多路复用器,而 TMUX1309 是 4:1 双通道(差分)多路复用器。这些器件可在源极 (Sx) 和漏极 (Dx) 引脚上支持从 …...

证件照怎么换底色?简单又快速!不看后悔
一、引言 证件照在我们的生活中有着广泛的应用,无论是求职、考试还是办理各种证件,都需要用到不同底色的证件照。传统的换底色方法往往比较复杂,需要一定的专业技能和软件操作经验。但是现在,有了更简单快捷的方法,让你…...
Rust 基础语法与常用特性
Rust 跨界:全面掌握跨平台应用开发 第一章:快速上手 Rust 1.2 基础语法与常用特性 1.2.1 数据类型与控制流 数据类型 Rust 提供了丰富的内置数据类型,主要分为标量类型和复合类型。 标量类型 标量类型表示单一的值,Rust 中…...

一、开发环境的搭建
环境搭建步骤: 下载软件安装软件运行软件 其他: Visual studio 安装包文件:https://www.alipan.com/s/nd5RgzD4e3b 下载软件 在浏览器中搜索Visual studio,选择如图的选项 点击该区域,进入该页面,【或…...

Docker:存储原理
Docker:存储原理 镜像联合文件系统overlay镜像存储结构容器存储结构 存储卷绑定挂载存储卷结构 镜像 联合文件系统 联合文件系统Union File System是一种分层,轻量且高效的文件系统。其将整个文件系统分为多个层,层与层之间进行覆盖&#x…...

ts:数组的常用方法(push、pop、shift、unshift、splice、slice)
前端css中filter的使用 一、主要内容说明二、例子(一)、push方法(尾添加)1.源码1 (push方法)2.源码1运行效果 (二)、pop方法(尾删除)1.源码2(pop方…...
物联网网关确保设备安全
物联网(IoT)网关在确保设备安全方面扮演着至关重要的角色。 作为连接物联网设备和云端或企业系统的中介,物联网网关可以实施多种安全措施来保护设备和数据。 是物联网网关确保设备安全的关键方法: 1. 设备认证和授权 认证&…...

Vue学习笔记(五)
Class绑定 数据绑定的一个常见需求场景式操纵元素的CSS class列表,因为class是attribute,我们可以和其他attribute一样使用v-bind将它们和动态的字符串绑定。但是,在处理比较复杂的绑定时,通过拼接生成字符串是麻烦且易出错的。因此…...
Nestjs返回格式小结
在 NestJS 中,除了 text/event-stream(用于 Server-Sent Events)之外,还有多种格式的返回方式,具体取决于你的应用需求。以下是一些常见的返回格式及其示例: 1. JSON 格式 Get(json) getJsonResponse(Res…...

【力扣刷题实战】相同的树
大家好,我是小卡皮巴拉 文章目录 目录 力扣题目: 相同的树 题目描述 示例 1: 示例 2: 示例 3: 解题思路 题目理解 算法选择 具体思路 解题要点 完整代码(C语言) 兄弟们共勉 &#…...

Golang | Leetcode Golang题解之第515题在每个树行中找最大值
题目: 题解: func largestValues(root *TreeNode) (ans []int) {if root nil {return}q : []*TreeNode{root}for len(q) > 0 {maxVal : math.MinInt32tmp : qq nilfor _, node : range tmp {maxVal max(maxVal, node.Val)if node.Left ! nil {q …...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
React 第五十五节 Router 中 useAsyncError的使用详解
前言 useAsyncError 是 React Router v6.4 引入的一个钩子,用于处理异步操作(如数据加载)中的错误。下面我将详细解释其用途并提供代码示例。 一、useAsyncError 用途 处理异步错误:捕获在 loader 或 action 中发生的异步错误替…...

Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...
深入浅出:JavaScript 中的 `window.crypto.getRandomValues()` 方法
深入浅出:JavaScript 中的 window.crypto.getRandomValues() 方法 在现代 Web 开发中,随机数的生成看似简单,却隐藏着许多玄机。无论是生成密码、加密密钥,还是创建安全令牌,随机数的质量直接关系到系统的安全性。Jav…...

【网络安全产品大调研系列】2. 体验漏洞扫描
前言 2023 年漏洞扫描服务市场规模预计为 3.06(十亿美元)。漏洞扫描服务市场行业预计将从 2024 年的 3.48(十亿美元)增长到 2032 年的 9.54(十亿美元)。预测期内漏洞扫描服务市场 CAGR(增长率&…...

家政维修平台实战20:权限设计
目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系,主要是分成几个表,用户表我们是记录用户的基础信息,包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题,不同的角色…...

[ICLR 2022]How Much Can CLIP Benefit Vision-and-Language Tasks?
论文网址:pdf 英文是纯手打的!论文原文的summarizing and paraphrasing。可能会出现难以避免的拼写错误和语法错误,若有发现欢迎评论指正!文章偏向于笔记,谨慎食用 目录 1. 心得 2. 论文逐段精读 2.1. Abstract 2…...
C++ 基础特性深度解析
目录 引言 一、命名空间(namespace) C 中的命名空间 与 C 语言的对比 二、缺省参数 C 中的缺省参数 与 C 语言的对比 三、引用(reference) C 中的引用 与 C 语言的对比 四、inline(内联函数…...

智能分布式爬虫的数据处理流水线优化:基于深度强化学习的数据质量控制
在数字化浪潮席卷全球的今天,数据已成为企业和研究机构的核心资产。智能分布式爬虫作为高效的数据采集工具,在大规模数据获取中发挥着关键作用。然而,传统的数据处理流水线在面对复杂多变的网络环境和海量异构数据时,常出现数据质…...

Maven 概述、安装、配置、仓库、私服详解
目录 1、Maven 概述 1.1 Maven 的定义 1.2 Maven 解决的问题 1.3 Maven 的核心特性与优势 2、Maven 安装 2.1 下载 Maven 2.2 安装配置 Maven 2.3 测试安装 2.4 修改 Maven 本地仓库的默认路径 3、Maven 配置 3.1 配置本地仓库 3.2 配置 JDK 3.3 IDEA 配置本地 Ma…...