当前位置: 首页 > news >正文

【模型学习之路】手写+分析bert

手写+分析bert

目录

前言

架构

embeddings

Bertmodel

预训练任务

MLM

NSP

Bert

后话

netron可视化

code2flow可视化

fine tuning


前言

Attention is all you need!

读本文前,建议至少看懂【模型学习之路】手写+分析Transformer-CSDN博客。

毕竟Bert是transformer的变种之一。

架构

embeddings

Bert可以说就是transformer的Encoder,就像训练卷积网络时可以利用现成的网络然后fine tune就投入使用一样,Bert的动机就是训练一种预训练模型,之后根据不同的场景可以做不同的fine tune。

这里我们还是B代表批次(对于Bert,一个Batch可以输入一到两个句子,输入两个句子时,两个直接拼接就好了),m代表一个batch的单词数,n表示词向量的长度。

Bert的输入是三种输入之和(维度设定我们与本系列上一篇文章保持相同):

token_embeddings  和Transformer完全一样。

segment_embeddings  用来标记句子。第一个句子每个单词标0,第二个句子的每个单词标1。

pos_embeddings  用来标记位置,维度和Transformer中的一样,但是Bert的pos_embeddings是训练出来的(这意味它成为了神经网络里要训练的参数了)。

def get_token_and_segments(tokens_a, tokens_b=None):"""bert的输入之一:token embeddingsbert的输入之二:segment embeddingspos_embeddings在后面的模型里面"""tokens = ['<cls>'] + tokens_a + ['<sep>']segments = [0] * (len(tokens_a) + 2)if tokens_b is not None:tokens += tokens_b + ['<sep>']segments += [1] * (len(tokens_b) + 1)return tokens, segments

Bertmodel

Bert的单个EncpderLayer和Transformer是一样的,我们直接把上一节的代码复制过来就好。

组装好。

class BertModel(nn.Module):def __init__(self, vocab, n, d_ff, h, n_layers,max_len=1000, k=768, v=768):super(BertModel, self).__init__()self.token_embeddings = nn.Embedding(vocab, n)  # [B, m]->[B, m, vocab]->[B, m, n]self.segment_embeddings = nn.Embedding(2, n)  # [B, m]->[B, m, 2]->[B, m, n]self.pos_embeddings = nn.Parameter(torch.randn(1, max_len, n))  # [1, max_len, n]self.layers = nn.ModuleList([EncoderLayer(n, h, k, v, d_ff)for _ in range(n_layers)])def forward(self, tokens, segments, m):  # m是句子长度X = self.token_embeddings(tokens) + \self.segment_embeddings(segments)X += self.pos_embeddings[:, :X.shape[1], :]for layer in self.layers:X, attn = layer(X)return X

简单测试一下。

# 弄一点数据测试一下tokens = torch.randint(0, 100, (2, 10))  # [B, m]segments = torch.randint(0, 2, (2, 10))  # [B, m]m = 10bert = BertModel(100, 768, 3072, 12, 12)out = bert(tokens, segments, m)print(out.shape)  # [2, 10, 768]

 

预训练任务

Bert在训练时要做两种训练,这里先画个图表示架构,后面给出分析和代码。

MLM

Maked language model,是指在训练的时候随即从输入预料上mask掉一些单词,然后通过的上下文预测该单词,该任务非常像我们在中学时期经常做的完形填空。

在BERT的实验中,15%的WordPiece Token会被随机Mask掉。在训练模型时,一个句子会被多次喂到模型中用于参数学习,但是Google并没有在每次都mask掉这些单词,而是在确定要Mask掉的单词之后,80%的时候会直接替换为[Mask],10%的时候将其替换为其它任意单词,10%的时候会保留原始Token。(这里就不深入了)

class MLM(nn.Module):def __init__(self, vocab, n, mlm_hid):super(MLM, self).__init__()self.mlp = nn.Sequential(nn.Linear(n, mlm_hid),nn.ReLU(),nn.LayerNorm(mlm_hid),nn.Linear(mlm_hid, vocab))def forward(self, X, P):# X: [B, m, n]# P: [B, p]# 这里P指的是记录了要mask的元素的矩阵,若P(i,j)==k,表示X(i,k)被mask了p = P.shape[1]P = P.reshape(-1)batch_size = X.shape[0]batch_idx = torch.arange(batch_size)batch_idx = torch.repeat_interleave(batch_idx, p)X = X[batch_idx, P].reshape(batch_size, p, -1)  # [B, p, n]out = self.mlp(X)return out

这里的forward的逻辑有点麻烦,要读懂的话可以要手推一下。p是每一个Batch中mask的词的个数。(即在一个Batch中,m个词挑出了p个)

NSP

Next Sentence Prediction的任务是判断句子B是否是句子A的下文。训练数据的生成方式是从平行语料中随机抽取的连续两句话,其中50%保留抽取的两句话,它们符合IsNext关系,另外50%的第二句话是随机从预料中提取的,它们的关系是NotNext的。这个关系由每个句子的第一个token——<cls>捕捉。

class NSP(nn.Module):def __init__(self, n, nsp_hid):super(NSP, self).__init__()self.mlp = nn.Sequential(nn.Linear(n, nsp_hid),nn.Tanh(),nn.Linear(nsp_hid, 2))def forward(self, X):# X: [B, m, n]X = X[:, 0, :]  # [B, n]out = self.mlp(X)  # [B, 2]return out

 

Bert

下面拼装Bert。

class Bert(nn.Module):def __init__(self, vocab, n, d_ff, h, n_layers,max_len=1000, k=768, v=768, mlm_feat=768, nsp_feat=768):super(Bert, self).__init__()self.encoder = BertModel(vocab, n, d_ff, h, n_layers, max_len, k, v)self.mlm = MLM(vocab, n, mlm_feat)self.nsp = NSP(n, nsp_feat)def forward(self, tokens, segments, m, P=None):X = self.encoder(tokens, segments, m)mlm_out = self.mlm(X, P) if P is not None else Nonensp_out = self.nsp(X)return X, mlm_out, nsp_out

后话

netron可视化

利用netron可视化。

test_tokens = torch.randint(0, 100, (2, 10))  # [B, m]
test_segments = torch.randint(0, 2, (2, 10))  # [B, m]
test_P = torch.tensor([[1, 2, 4, 6, 8], [1, 3, 4, 5, 6]])
test_m = 10
test_bert = Bert(100, 768, 3072, 12, 12)
test_X, test_mlm_out, test_nsp_out = test_bert(test_tokens, test_segments, test_m, test_P)modelData = "./demo.pth"
torch.onnx.export(test_bert, (test_tokens, test_segments), modelData)
netron.start(modelData)

截取部分看一下。

code2flow可视化

code2flow可以可视化代码函数和类的相互调用关系。

code2flow.code2flow([r'代码路径.py'], '输出路径.svg')

这里生成的png,其实svg清晰得多。

fine tuning

Bert的精髓在于,Bert只是一个编码器(Encoder),经过MLM和NSP两个任务的训练之后,可以自己在它的基础上训练一个Decoder来输出特定的值、得到特定的效果。这也是Bert的神奇和魅力所在!通过两个任务训练出一个编码器,然后可以通过不同的Decoder达到各种效果!

持续探索Bert......

相关文章:

【模型学习之路】手写+分析bert

手写分析bert 目录 前言 架构 embeddings Bertmodel 预训练任务 MLM NSP Bert 后话 netron可视化 code2flow可视化 fine tuning 前言 Attention is all you need! 读本文前&#xff0c;建议至少看懂【模型学习之路】手写分析Transformer-CSDN博客。 毕竟Bert是tr…...

Redis学习文档(常见面试题)

目录 Redis回收使用的是什么算法&#xff1f; Redis如何做大量数据插入&#xff1f; 为什么要做Redis分区&#xff1f; 你知道有哪些Redis分区实现方案&#xff1f; Redis分区有什么缺点&#xff1f; Redis持久化数据和缓存怎么做扩容&#xff1f; 分布式Redis是前期做还…...

【C++刷题】力扣-#594-最长和谐子序列

题目描述 和谐数组是指一个数组里元素的最大值和最小值之间的差别 正好是 1 。 给你一个整数数组 nums &#xff0c;请你在所有可能的子序列中找到最长的和谐子序列的长度。 数组的 子序列是一个由数组派生出来的序列&#xff0c;它可以通过删除一些元素或不删除元素、且不改变…...

MoveIt 控制自己的真实机械臂【2】——编写 action server 端代码

完成了 MoveIt 这边 action client 的基本配置&#xff0c;MoveIt 理论上可以将规划好的 trajectory 以 action 的形式发布出来了&#xff0c;浅浅尝试一下&#xff0c;在 terminal 中运行 roslaunch xmate7_moveit_config_new demo.launch 报错提示他在等待 xmate_arm_control…...

C#制作学生管理系统

定义学生类 定义一个简单的类来表示学生&#xff0c;包括学号、姓名、性别、年龄、电话、地址。再给其添加一个方法利于后续添加方法查看学生信息。 //定义学生类 public class student {public int ID { get; set; }//开放读写权限public string Name { get; set; }public i…...

python Pandas合并(单元格、sheet、excel )

安装 Pandas 和 openpyxl 首先&#xff0c;确保已经安装了 Pandas 和 openpyxl。可以通过 pip 安装&#xff1a; pip install pandas openpyxl 创建 DataFrame import pandas as pd # 创建 DataFrame df1 pd.DataFrame({ 姓名: [张三, 李四, 王五], 年龄: [25, 30, 35]…...

OJ在线编程常见输入输出练习【JavaScript】

&#xff08;注&#xff1a;本文是对【JavaScript Node 】 ACM模式&#xff0c;常见输入输出练习相关内容的介绍&#xff01;&#xff01;&#xff01;&#xff09; 牛客竞赛_ACM/NOI/CSP/CCPC/ICPC算法编程高难度练习赛_牛客竞赛OJ 一、ACM模式下的编辑页面 二、ACM模式下&a…...

新能源汽车空调系统:绿色出行的舒适保障

在新能源汽车迅速发展的今天&#xff0c;空调系统作为提升驾乘舒适度的重要组成部分&#xff0c;发挥着不可或缺的作用。新能源汽车空调系统主要由压缩机、冷凝器、节流装置和蒸发器四大件组成&#xff0c;它们协同工作&#xff0c;为车内提供适宜的温度和湿度环境。 一、压缩…...

Date工具类详细汇总-Date日期相关方法

# 1024程序员节 | 征文 # 目录 简介 Date工具类单元测试 Date工具类 简介 本文章是个人总结实际工作中常用到的Date工具类&#xff0c;主要包含Java-jdk8以下版本的Date相关使用方法&#xff0c;可以方便的在工作中灵活的应用&#xff0c;在个人工作期间频繁使用这些时间的格…...

TMUX1308PWR规格书 数据手册 具有注入电流控制功能的 5V 双向 8:1单通道和 4:1 双通道多路复用器芯片

TMUX1308 和 TMUX1309 为通用互补金属氧化物半导体 (CMOS) 多路复用器 (MUX)。TMUX1308 是 8:1单通道&#xff08;单端&#xff09;多路复用器&#xff0c;而 TMUX1309 是 4:1 双通道&#xff08;差分&#xff09;多路复用器。这些器件可在源极 (Sx) 和漏极 (Dx) 引脚上支持从 …...

证件照怎么换底色?简单又快速!不看后悔

一、引言 证件照在我们的生活中有着广泛的应用&#xff0c;无论是求职、考试还是办理各种证件&#xff0c;都需要用到不同底色的证件照。传统的换底色方法往往比较复杂&#xff0c;需要一定的专业技能和软件操作经验。但是现在&#xff0c;有了更简单快捷的方法&#xff0c;让你…...

Rust 基础语法与常用特性

Rust 跨界&#xff1a;全面掌握跨平台应用开发 第一章&#xff1a;快速上手 Rust 1.2 基础语法与常用特性 1.2.1 数据类型与控制流 数据类型 Rust 提供了丰富的内置数据类型&#xff0c;主要分为标量类型和复合类型。 标量类型 标量类型表示单一的值&#xff0c;Rust 中…...

一、开发环境的搭建

环境搭建步骤&#xff1a; 下载软件安装软件运行软件 其他&#xff1a; Visual studio 安装包文件&#xff1a;https://www.alipan.com/s/nd5RgzD4e3b 下载软件 在浏览器中搜索Visual studio&#xff0c;选择如图的选项 点击该区域&#xff0c;进入该页面&#xff0c;【或…...

Docker:存储原理

Docker&#xff1a;存储原理 镜像联合文件系统overlay镜像存储结构容器存储结构 存储卷绑定挂载存储卷结构 镜像 联合文件系统 联合文件系统Union File System是一种分层&#xff0c;轻量且高效的文件系统。其将整个文件系统分为多个层&#xff0c;层与层之间进行覆盖&#x…...

ts:数组的常用方法(push、pop、shift、unshift、splice、slice)

前端css中filter的使用 一、主要内容说明二、例子&#xff08;一&#xff09;、push方法&#xff08;尾添加&#xff09;1.源码1 &#xff08;push方法&#xff09;2.源码1运行效果 &#xff08;二&#xff09;、pop方法&#xff08;尾删除&#xff09;1.源码2&#xff08;pop方…...

物联网网关确保设备安全

物联网&#xff08;IoT&#xff09;网关在确保设备安全方面扮演着至关重要的角色。 作为连接物联网设备和云端或企业系统的中介&#xff0c;物联网网关可以实施多种安全措施来保护设备和数据。 是物联网网关确保设备安全的关键方法&#xff1a; 1. 设备认证和授权 认证&…...

Vue学习笔记(五)

Class绑定 数据绑定的一个常见需求场景式操纵元素的CSS class列表&#xff0c;因为class是attribute,我们可以和其他attribute一样使用v-bind将它们和动态的字符串绑定。但是&#xff0c;在处理比较复杂的绑定时&#xff0c;通过拼接生成字符串是麻烦且易出错的。因此&#xf…...

Nestjs返回格式小结

在 NestJS 中&#xff0c;除了 text/event-stream&#xff08;用于 Server-Sent Events&#xff09;之外&#xff0c;还有多种格式的返回方式&#xff0c;具体取决于你的应用需求。以下是一些常见的返回格式及其示例&#xff1a; 1. JSON 格式 Get(json) getJsonResponse(Res…...

【力扣刷题实战】相同的树

大家好&#xff0c;我是小卡皮巴拉 文章目录 目录 力扣题目&#xff1a; 相同的树 题目描述 示例 1&#xff1a; 示例 2&#xff1a; 示例 3&#xff1a; 解题思路 题目理解 算法选择 具体思路 解题要点 完整代码&#xff08;C语言&#xff09; 兄弟们共勉 &#…...

Golang | Leetcode Golang题解之第515题在每个树行中找最大值

题目&#xff1a; 题解&#xff1a; func largestValues(root *TreeNode) (ans []int) {if root nil {return}q : []*TreeNode{root}for len(q) > 0 {maxVal : math.MinInt32tmp : qq nilfor _, node : range tmp {maxVal max(maxVal, node.Val)if node.Left ! nil {q …...

Zookeeper 对于 Kafka 的作用是什么?

大家好&#xff0c;我是锋哥。今天分享关于【Zookeeper 对于 Kafka 的作用是什么&#xff1f;】面试题&#xff1f;希望对大家有帮助&#xff1b; Zookeeper 对于 Kafka 的作用是什么&#xff1f; 1000道 互联网大厂Java工程师 精选面试题-Java资源分享网 ZooKeeper 在 Kafka…...

Thread类及线程的核心操作

一. Thread类的常见构造方法 1. Thread() Thread类无参的构造方法, 用于创建Thread类的实例对象. 2. Thread(String name) 带一个参数的Thread类构造方法, 创建一个线程对象, 并给其命名. [注]: 如果不专门给线程命名, 那么线程默认的名字就是Thread-0, Thread-1, Thread-…...

算法|牛客网华为机试11-20C++

牛客网华为机试 上篇&#xff1a;算法|牛客网华为机试1-10C 文章目录 HJ11 数字颠倒HJ12 字符串反转HJ13 句子逆序HJ14 字符串排序HJ15 求int型正整数在内存中存储时1的个数HJ16 购物单HJ17 坐标移动HJ18 识别有效的IP地址和掩码并进行分类统计HJ19 简单错误记录HJ20 密码验证…...

OpenAI低调发布多智能体工具Swarm:让多个智能体协同工作!

大家好&#xff0c;我是木易&#xff0c;一个持续关注AI领域的互联网技术产品经理&#xff0c;国内Top2本科&#xff0c;美国Top10 CS研究生&#xff0c;MBA。我坚信AI是普通人变强的“外挂”&#xff0c;专注于分享AI全维度知识&#xff0c;包括但不限于AI科普&#xff0c;AI工…...

性能之光 年度电竞性能旗舰iQOO 13发布

2024年10月30日&#xff0c;被定义为“性能之光”的年度电竞性能旗舰——iQOO 13正式发布&#xff0c;售价3999元起。iQOO 13作为iQOO 品牌在性能上的又一次深入探索&#xff0c;它像是一束光&#xff0c;引领行业不断拉高性能上限&#xff0c;让用户看到更多的可能性。 iQOO …...

如何避免因不熟悉数据保护法规而受损

在当今数字化时代&#xff0c;数据保护法规的遵守对于企业至关重要。不熟悉新的数据保护法规会导致法律风险增加、财务损失、声誉受损、客户信任下降等多方面的负面影响。其中&#xff0c;法律风险增加尤为严重&#xff0c;因为不符合规定可能引发高额罚款和法律诉讼。企业若未…...

LLaMA Factory 核心原理讲解

大家好,我是herosunly。985院校硕士毕业,现担任算法研究员一职,热衷于大模型算法的研究与应用。曾担任百度千帆大模型比赛、BPAA算法大赛评委,编写微软OpenAI考试认证指导手册。曾获得阿里云天池比赛第一名,CCF比赛第二名,科大讯飞比赛第三名。授权多项发明专利。对机器学…...

Java题集练习5

Java题集练习5&#xff08;集合&#xff09; 1.三种集合差别&#xff0c;集合类都是什么&#xff0c;数据结构是什么&#xff0c;都什么时候用 三者关系 Set集合 Set接口是Collection接口的一个子接口是无序的&#xff0c;set中不包含重复的元素&#xff0c;也就是说set中不…...

操作系统学习笔记-2.3哲学家和管程问题

哲学家问题 问题描述 假设有五位哲学家围坐在一张圆桌旁&#xff0c;每位哲学家面前放有一盘意大利面&#xff0c;他们各自间隔放置一根叉子。哲学家的行为分为“思考”和“进餐”两种状态。为了进餐&#xff0c;哲学家需要同时拿起左手边和右手边的两根叉子。用餐结束后&…...

2023年信息安全工程师摸底测试卷

目录 1.密码算法 2.等级保护 3.密码学 4.安全评估 5.网络安全控制技术 6.恶意代码 7.身份认证 8.资产管理 9.密码分类 10.被动攻击 11.商用密码服务​编辑 12.超文本传输协议 13.数字水印技术 14.信息系统安全设计 15.重放攻击 16.信息资产保护 17.身份认证 …...

网站建设程序的步骤过程/百度官方客服平台

今天看到《银行主数据项目(MDM)的数据持久层,你选择hibernate还是ibatis(MyBatis)》跑到首页来了&#xff0c; 把我最近使用方式分享一下。Hiberante(Spring JDBC freemarker)两次结合&#xff0c;hibernate对简单的数据操作很方便&#xff0c;可以大量减少SQL语句的维护。对于…...

wordpress 支持数据库/百度搜索排名推广

前言在日常的android开发当中,按钮是必不可少控件。但是如果要实现下面的效果恐怕写shape文件都要写的头晕w(&#xff9f;Д&#xff9f;)ww(&#xff9f;Д&#xff9f;)w&#xff0c;所以为了以后的开发&#xff0c;我们就简单的封装下。代码块很简单我们通过GradientDrawabl…...

开发深圳网站建设/专业推广图片

NIS服务器的配置过程 以前在做实验的过程中总结和写的一些教程的一些资料&#xff0c;一直没时间发布到博客上面&#xff0c;五一到了&#xff0c;终于有点时间发布啦&#xff01;关于Linux上面还会有RHCE系列的学习笔记发表 NIS需要的软件包&#xff1a;rpm -ivh ypserv-2.13-…...

政务网站建设管理工作总结/网络营销比较好的企业

解决方法&#xff1a; sudo apt-get install -f ​​​​转载于:https://www.cnblogs.com/wulinmenghuantejing/p/8378005.html...

巴音郭楞网站建设/怎么建个人网站

在分类中&#xff0c;和自己的父类关联public class AssessQualityIndex extends IdEntity<AssessQualityIndex> {private static final long serialVersionUID 1L; private String scoreStandard; // 评分标准private AssessQualityIndex parent; // 父级Js…...

wordpress 邮箱变更/好的seo公司营销网

一、路由基础Routing protocol 用于路由器动态寻找最优路径&#xff0c;并使路由器都拥有路由表&#xff0c;R/p 决定了数据包的上行路径&#xff0c;eg&#xff1a;RIP IGRP EIGRP OSPF,被动路由协议被分配到接口上并决定数据数据包的传送方式&#xff0c; Router:把一个数据包…...