当前位置: 首页 > news >正文

word2vector训练代码详解

目录

1.代码实现

2.知识点 


 

1.代码实现

#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集  ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5  
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)  
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape   #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])

embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法 

#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):"""embed_v:表示对中心词进行embedding层embed_u:对上下文词进行embedding层 """v = embed_v(center)                 #中心词的词向量表达u = embed_u(contexts_and_negatives) #上下文词的词向量表达#用中心词来预测上下文词#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法pred = torch.bmm(v, u.permute(0, 2, 1))  #矩阵乘法(bmm三维乘法),不用管batch_size维度return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],[[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape

 torch.Size([2, 1, 4])

#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):def __init__(self):super().__init__()  #直接继承父类的初始化属性和方法def forward(self, inputs, target, mask=None):#nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果#weight=mask权重将掩码带上#reduction='none'表示不将计算结果聚合,算损失时(默认聚合)out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')return out.mean(dim=1)  #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape

 torch.Size([2, 4])

#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)

 tensor([0.9352, 1.8462])

#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))

 

#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):#修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重def init_weights(m):if type(m) == nn.Embedding:nn.init.xavier_uniform_(m.weight)net.apply(init_weights)  net = net.to(device)#设置梯度下降的优化器optimizer = torch.optim.Adam(net.parameters(), lr=lr)#设置绘制可视化的动图(epoch——loss)animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])#设置累加metric = dltools.Accumulator(2)   #2种数据需要累加for epoch in range(num_epochs):  #遍历训练次数#设置计时器, 赋值批次数量timer, num_batches = dltools.Timer(), len(data_iter)    #data_iter是分好批次的数据集,长度就是批次数量num_batchesfor i, batch in enumerate(data_iter):   #i是索引, batch是取出的一批批数据#梯度清零optimizer.zero_grad()#接收中心词, 上下文词_噪声词, 掩码, 标记目标值 center, context_negative, mask, label = [data.to(device) for data in batch]#调用skip_gram模型预测pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])#计算损失l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)#用loss反向传播  ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播l.sum().backward()#梯度更新optimizer.step()#累加metric.add(l.sum(), l.numel())   #l.sum()数值求和累加, l.numel()数量累加#   %  取余数      #  //  商向下取整#迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作#  i+1是因为i是从0开始遍历的if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1:  #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')      
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)

#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):"""query_token:需要预测的词k:最高相似度的词数量embed:embedding层的哪一层"""#获取词向量权重    (词向量权重*词的one_hot编码,就是词向量)W = embed.weight.dataprint(f'W的shape:{W.shape}')x = W[vocab[query_token]]     #embedding层是按照索引查表查词对应的权重-->优点print(f'x的shape:{x.shape}')#计算余弦相似度#torch.mv两个向量的点乘cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)print(f'cos的shape:{cos.shape}')#排序选择前k个对应的索引topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')for i in topk[1:]:   #排除query_token他本身,自己与自己余弦相似度最高print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
get_similar_tokens('food', 3, net[0])

 

W的shape:torch.Size([6719, 100])
x的shape:torch.Size([100])
cos的shape:torch.Size([6719])
cosine sim=0.430:feed
cosine sim=0.418:precious
cosine sim=0.412:drink

2.知识点 

 

相关文章:

word2vector训练代码详解

目录 1.代码实现 2.知识点 1.代码实现 #导包 import math import torch from torch import nn import dltools #加载PTB数据集 &#xff0c;需要把PTB数据集的文件夹放在代码上一级目录的data文件中&#xff0c;不用解压 #批次大小、窗口大小、噪声词大小 batch_size, ma…...

Python的风格应该是怎样的?除语法外,有哪些规范?

写代码不那么pythonic风格的&#xff0c;多多少少都会让人有点难受。 什么是pythonic呢&#xff1f;简而言之&#xff0c;这是一种写代码时遵守的规范&#xff0c;主打简洁、清晰、可读性高&#xff0c;符合PEP 8&#xff08;Python代码样式指南&#xff09;约定的模式。 Pyth…...

net core mvc 数据绑定 《1》

其它的绑定 跟net mvc 一样 》》MVC core 、framework 一样 1 模型绑定数组类型 2 模型绑定集合类型 3 模型绑定复杂的集合类型 4 模型绑定源 》》》》 模型绑定 使用输入数据的原生请求集合是可以工作的【request[],Querystring,request.from[]】&#xff0c; 但是从可读…...

python为姓名注音实战案例

有如下数据&#xff0c;需要对名字注音。 数据样例&#xff1a;&#x1f447; 一、实现过程 前提条件&#xff1a;由于会用到pypinyin库&#xff0c;所以一定得提前安装。 pip install pypinyin1、详细代码&#xff1a; from pypinyin import pinyin, Style# 输入数据 names…...

MATLAB中的艺术:用爱心形状控制坐标轴

在MATLAB中&#xff0c;坐标轴控制是绘图和数据可视化中的一个重要方面。通过精细地管理坐标轴&#xff0c;我们不仅可以改善图形的视觉效果&#xff0c;还可以赋予图形更深的情感寓意。本文将介绍如何在MATLAB中使用坐标轴控制来绘制一个爱心形状&#xff0c;并探讨其背后的技…...

基于mybatis-plus创建springboot,添加增删改查功能,使用postman来测试接口出现的常见错误

1 当你在使用postman检测 添加和更新功能时&#xff0c;报了一个500错误 查看idea发现是&#xff1a; Data truncation: Out of range value for column id at row 1 通过翻译&#xff1a;数据截断&#xff1a;表单第1行的“id”列出现范围外值。一般情况下&#xff0c;出现这个…...

Java:Object操作

目录 1、Object转List对象2、Object转实体对象 1、Object转List对象 List<User> userList MtUtils.ObjectToList(objData, User.class);/*** Object对象转 List集合** param object Object对象* param clazz 需要转换的集合* param <T> 泛型类* return*/ public s…...

Java-并发基础

启动线程的方式 只有&#xff1a; 1、X extends Thread;&#xff0c;然后X.start 2、X implements Runnable&#xff1b;然后交给Thread运行 有争议可以可以查看 Thread源码的注释&#xff1a; There are two ways to create a new thread of execution.Callable的方式需要…...

速盾:网页游戏部署高防服务器有什么优势?

在当前互联网发展的背景下&#xff0c;网页游戏的市场需求不断增长&#xff0c;相应地带来了对高防服务器的需求。高防服务器可以为网页游戏部署提供许多优势&#xff0c;下面就详细介绍一下。 第一&#xff0c;高防服务器具有强大的抗DDoS攻击能力。DDoS攻击是目前互联网上最…...

【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套

【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套 详细解答和讨论请私信在工作空间内新建一个功能包在msg内创建对应的msg文件创建名为TestMsg.msg的文件创建名为TestSubMsg.msg的文件&#xff08;在前一个msg文件中引用&#xff09;修改CmakeList.txt修改package.…...

docker 部署 Seatunnel 和 Seatunnel Web

docker 部署 Seatunnel 和 Seatunnel Web 说明&#xff1a; 部署方式前置条件&#xff0c;已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…...

【易上手快捷开发新框架技术】nicegui标签组件lable用法庖丁解牛深度解读和示例源代码IDE运行和调试通过截图为证

传奇开心果微博文系列 序言一、标签组件lable最基本用法示例1.在网页上显示出 Hello World 的标签示例2. 使用 style 参数改变标签样式示例 二、标签组件lable更多用法示例1. 添加按钮动态修改标签文字2. 点击按钮动态改变标签内容、颜色、大小和粗细示例代码3. 添加开关组件动…...

从HarmonyOS Next导出手机照片

1&#xff09;打开DevEco Studio开发工具 2&#xff09;插入USB数据线&#xff0c;连接手机 3&#xff09;在DevEco Studio开发工具&#xff0c;通过View -> Tool Windows -> Device File Browser打开管理工具 4&#xff09;选择storage -> cloud -> 100->fi…...

[Docker学习笔记]Docker的原理Docker常见命令

文章目录 什么是DockerDocker的优势Docker的原理Docker 的安装Docker 的 namespaces Docker的常见命令docker version:查看版本信息docker info 查看docker详细信息我们关注的信息 docker search:镜像搜索docker pull:镜像拉取到本地docker push:推送本地镜像到镜像仓库docker …...

【ESP 保姆级教程】小课设篇 —— 案例:20240507_esp01s+UNO的智能浇水系统

忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2024-09-30 ❤️❤️ 本篇更新记录 2023-09-30 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...

如何设置MySQL分布式架构主键ID,为什么不能使用自增ID或者UUID做主键?

MySQL分布式架构主键ID的设置方法 雪花算法&#xff08;Snowflake&#xff09; 原理&#xff1a;雪花算法是一种生成分布式唯一ID的算法。它由64位二进制数组成&#xff0c;结构如下&#xff1a;1位符号位&#xff08;固定为0&#xff09; 41位时间戳&#xff08;表示从一个固…...

服务器虚拟化详解

服务器虚拟化详解 服务器虚拟化是一种将物理服务器资源转化为虚拟服务器资源的技术&#xff0c;它允许在一台物理服务器上运行多个虚拟服务器&#xff0c;每个虚拟服务器都拥有独立的操作系统、应用程序和资源配置。这种技术极大地提高了服务器的利用率、灵活性和可扩展性&…...

医疗陪诊APP开发实战:从互联网医院系统源码开始

本文将从互联网医院系统源码出发&#xff0c;深入探讨医疗陪诊APP的开发实战。 一、从互联网医院系统源码入手 开发医疗陪诊APP的基础在于互联网医院系统的源码。互联网医院系统通常包括以下几个模块&#xff1a; 1.用户管理&#xff1a;用户注册、登录、信息管理等功能。 …...

jenkins 构建报错ERROR: Error fetching remote repo ‘origin‘

问题描述 修改项目的仓库地址后&#xff0c;使用jenkins构建报错 Running as SYSTEM Building in workspace /var/jenkins_home/workspace/【测试】客户端/client-fonchain-main The recommended git tool is: NONE using credential 680a5841-cfa5-4d8a-bb38-977f796c26dd&g…...

初识C#(三)- 数组

我有17栋楼&#xff0c;在不同地域&#xff0c;都是不同价格租出去给不同的人~ 文章目录 前言一、数组1.1 我有17栋楼 - 数组的声明1.2 包租公&包租婆 - 数组赋值1.3 每个月都要交租的苦逼租客 - 数组的使用 二、字符串2.1 字符串的使用方法 总结 前言 本篇笔记重点描述C#…...

React第五十七节 Router中RouterProvider使用详解及注意事项

前言 在 React Router v6.4 中&#xff0c;RouterProvider 是一个核心组件&#xff0c;用于提供基于数据路由&#xff08;data routers&#xff09;的新型路由方案。 它替代了传统的 <BrowserRouter>&#xff0c;支持更强大的数据加载和操作功能&#xff08;如 loader 和…...

Qt Http Server模块功能及架构

Qt Http Server 是 Qt 6.0 中引入的一个新模块&#xff0c;它提供了一个轻量级的 HTTP 服务器实现&#xff0c;主要用于构建基于 HTTP 的应用程序和服务。 功能介绍&#xff1a; 主要功能 HTTP服务器功能&#xff1a; 支持 HTTP/1.1 协议 简单的请求/响应处理模型 支持 GET…...

蓝桥杯3498 01串的熵

问题描述 对于一个长度为 23333333的 01 串, 如果其信息熵为 11625907.5798&#xff0c; 且 0 出现次数比 1 少, 那么这个 01 串中 0 出现了多少次? #include<iostream> #include<cmath> using namespace std;int n 23333333;int main() {//枚举 0 出现的次数//因…...

USB Over IP专用硬件的5个特点

USB over IP技术通过将USB协议数据封装在标准TCP/IP网络数据包中&#xff0c;从根本上改变了USB连接。这允许客户端通过局域网或广域网远程访问和控制物理连接到服务器的USB设备&#xff08;如专用硬件设备&#xff09;&#xff0c;从而消除了直接物理连接的需要。USB over IP的…...

SQL慢可能是触发了ring buffer

简介 最近在进行 postgresql 性能排查的时候,发现 PG 在某一个时间并行执行的 SQL 变得特别慢。最后通过监控监观察到并行发起得时间 buffers_alloc 就急速上升,且低水位伴随在整个慢 SQL,一直是 buferIO 的等待事件,此时也没有其他会话的争抢。SQL 虽然不是高效 SQL ,但…...

Golang——9、反射和文件操作

反射和文件操作 1、反射1.1、reflect.TypeOf()获取任意值的类型对象1.2、reflect.ValueOf()1.3、结构体反射 2、文件操作2.1、os.Open()打开文件2.2、方式一&#xff1a;使用Read()读取文件2.3、方式二&#xff1a;bufio读取文件2.4、方式三&#xff1a;os.ReadFile读取2.5、写…...

MinIO Docker 部署:仅开放一个端口

MinIO Docker 部署:仅开放一个端口 在实际的服务器部署中,出于安全和管理的考虑,我们可能只能开放一个端口。MinIO 是一个高性能的对象存储服务,支持 Docker 部署,但默认情况下它需要两个端口:一个是 API 端口(用于存储和访问数据),另一个是控制台端口(用于管理界面…...

springboot 日志类切面,接口成功记录日志,失败不记录

springboot 日志类切面&#xff0c;接口成功记录日志&#xff0c;失败不记录 自定义一个注解方法 import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target;/***…...

Python 高效图像帧提取与视频编码:实战指南

Python 高效图像帧提取与视频编码:实战指南 在音视频处理领域,图像帧提取与视频编码是基础但极具挑战性的任务。Python 结合强大的第三方库(如 OpenCV、FFmpeg、PyAV),可以高效处理视频流,实现快速帧提取、压缩编码等关键功能。本文将深入介绍如何优化这些流程,提高处理…...

鸿蒙(HarmonyOS5)实现跳一跳小游戏

下面我将介绍如何使用鸿蒙的ArkUI框架&#xff0c;实现一个简单的跳一跳小游戏。 1. 项目结构 src/main/ets/ ├── MainAbility │ ├── pages │ │ ├── Index.ets // 主页面 │ │ └── GamePage.ets // 游戏页面 │ └── model │ …...