word2vector训练代码详解
目录
1.代码实现
2.知识点
1.代码实现
#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])
embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法
#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):"""embed_v:表示对中心词进行embedding层embed_u:对上下文词进行embedding层 """v = embed_v(center) #中心词的词向量表达u = embed_u(contexts_and_negatives) #上下文词的词向量表达#用中心词来预测上下文词#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法pred = torch.bmm(v, u.permute(0, 2, 1)) #矩阵乘法(bmm三维乘法),不用管batch_size维度return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],[[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape
torch.Size([2, 1, 4])
#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):def __init__(self):super().__init__() #直接继承父类的初始化属性和方法def forward(self, inputs, target, mask=None):#nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果#weight=mask权重将掩码带上#reduction='none'表示不将计算结果聚合,算损失时(默认聚合)out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')return out.mean(dim=1) #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape
torch.Size([2, 4])
#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
tensor([0.9352, 1.8462])
#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))
#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):#修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重def init_weights(m):if type(m) == nn.Embedding:nn.init.xavier_uniform_(m.weight)net.apply(init_weights) net = net.to(device)#设置梯度下降的优化器optimizer = torch.optim.Adam(net.parameters(), lr=lr)#设置绘制可视化的动图(epoch——loss)animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])#设置累加metric = dltools.Accumulator(2) #2种数据需要累加for epoch in range(num_epochs): #遍历训练次数#设置计时器, 赋值批次数量timer, num_batches = dltools.Timer(), len(data_iter) #data_iter是分好批次的数据集,长度就是批次数量num_batchesfor i, batch in enumerate(data_iter): #i是索引, batch是取出的一批批数据#梯度清零optimizer.zero_grad()#接收中心词, 上下文词_噪声词, 掩码, 标记目标值 center, context_negative, mask, label = [data.to(device) for data in batch]#调用skip_gram模型预测pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])#计算损失l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)#用loss反向传播 ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播l.sum().backward()#梯度更新optimizer.step()#累加metric.add(l.sum(), l.numel()) #l.sum()数值求和累加, l.numel()数量累加# % 取余数 # // 商向下取整#迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作# i+1是因为i是从0开始遍历的if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1: #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)

#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):"""query_token:需要预测的词k:最高相似度的词数量embed:embedding层的哪一层"""#获取词向量权重 (词向量权重*词的one_hot编码,就是词向量)W = embed.weight.dataprint(f'W的shape:{W.shape}')x = W[vocab[query_token]] #embedding层是按照索引查表查词对应的权重-->优点print(f'x的shape:{x.shape}')#计算余弦相似度#torch.mv两个向量的点乘cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)print(f'cos的shape:{cos.shape}')#排序选择前k个对应的索引topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')for i in topk[1:]: #排除query_token他本身,自己与自己余弦相似度最高print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
get_similar_tokens('food', 3, net[0])
W的shape:torch.Size([6719, 100]) x的shape:torch.Size([100]) cos的shape:torch.Size([6719]) cosine sim=0.430:feed cosine sim=0.418:precious cosine sim=0.412:drink
2.知识点


相关文章:
word2vector训练代码详解
目录 1.代码实现 2.知识点 1.代码实现 #导包 import math import torch from torch import nn import dltools #加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压 #批次大小、窗口大小、噪声词大小 batch_size, ma…...
Python的风格应该是怎样的?除语法外,有哪些规范?
写代码不那么pythonic风格的,多多少少都会让人有点难受。 什么是pythonic呢?简而言之,这是一种写代码时遵守的规范,主打简洁、清晰、可读性高,符合PEP 8(Python代码样式指南)约定的模式。 Pyth…...
net core mvc 数据绑定 《1》
其它的绑定 跟net mvc 一样 》》MVC core 、framework 一样 1 模型绑定数组类型 2 模型绑定集合类型 3 模型绑定复杂的集合类型 4 模型绑定源 》》》》 模型绑定 使用输入数据的原生请求集合是可以工作的【request[],Querystring,request.from[]】, 但是从可读…...
python为姓名注音实战案例
有如下数据,需要对名字注音。 数据样例:👇 一、实现过程 前提条件:由于会用到pypinyin库,所以一定得提前安装。 pip install pypinyin1、详细代码: from pypinyin import pinyin, Style# 输入数据 names…...
MATLAB中的艺术:用爱心形状控制坐标轴
在MATLAB中,坐标轴控制是绘图和数据可视化中的一个重要方面。通过精细地管理坐标轴,我们不仅可以改善图形的视觉效果,还可以赋予图形更深的情感寓意。本文将介绍如何在MATLAB中使用坐标轴控制来绘制一个爱心形状,并探讨其背后的技…...
基于mybatis-plus创建springboot,添加增删改查功能,使用postman来测试接口出现的常见错误
1 当你在使用postman检测 添加和更新功能时,报了一个500错误 查看idea发现是: Data truncation: Out of range value for column id at row 1 通过翻译:数据截断:表单第1行的“id”列出现范围外值。一般情况下,出现这个…...
Java:Object操作
目录 1、Object转List对象2、Object转实体对象 1、Object转List对象 List<User> userList MtUtils.ObjectToList(objData, User.class);/*** Object对象转 List集合** param object Object对象* param clazz 需要转换的集合* param <T> 泛型类* return*/ public s…...
Java-并发基础
启动线程的方式 只有: 1、X extends Thread;,然后X.start 2、X implements Runnable;然后交给Thread运行 有争议可以可以查看 Thread源码的注释: There are two ways to create a new thread of execution.Callable的方式需要…...
速盾:网页游戏部署高防服务器有什么优势?
在当前互联网发展的背景下,网页游戏的市场需求不断增长,相应地带来了对高防服务器的需求。高防服务器可以为网页游戏部署提供许多优势,下面就详细介绍一下。 第一,高防服务器具有强大的抗DDoS攻击能力。DDoS攻击是目前互联网上最…...
【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套
【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套 详细解答和讨论请私信在工作空间内新建一个功能包在msg内创建对应的msg文件创建名为TestMsg.msg的文件创建名为TestSubMsg.msg的文件(在前一个msg文件中引用)修改CmakeList.txt修改package.…...
docker 部署 Seatunnel 和 Seatunnel Web
docker 部署 Seatunnel 和 Seatunnel Web 说明: 部署方式前置条件,已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…...
【易上手快捷开发新框架技术】nicegui标签组件lable用法庖丁解牛深度解读和示例源代码IDE运行和调试通过截图为证
传奇开心果微博文系列 序言一、标签组件lable最基本用法示例1.在网页上显示出 Hello World 的标签示例2. 使用 style 参数改变标签样式示例 二、标签组件lable更多用法示例1. 添加按钮动态修改标签文字2. 点击按钮动态改变标签内容、颜色、大小和粗细示例代码3. 添加开关组件动…...
从HarmonyOS Next导出手机照片
1)打开DevEco Studio开发工具 2)插入USB数据线,连接手机 3)在DevEco Studio开发工具,通过View -> Tool Windows -> Device File Browser打开管理工具 4)选择storage -> cloud -> 100->fi…...
[Docker学习笔记]Docker的原理Docker常见命令
文章目录 什么是DockerDocker的优势Docker的原理Docker 的安装Docker 的 namespaces Docker的常见命令docker version:查看版本信息docker info 查看docker详细信息我们关注的信息 docker search:镜像搜索docker pull:镜像拉取到本地docker push:推送本地镜像到镜像仓库docker …...
【ESP 保姆级教程】小课设篇 —— 案例:20240507_esp01s+UNO的智能浇水系统
忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2024-09-30 ❤️❤️ 本篇更新记录 2023-09-30 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...
如何设置MySQL分布式架构主键ID,为什么不能使用自增ID或者UUID做主键?
MySQL分布式架构主键ID的设置方法 雪花算法(Snowflake) 原理:雪花算法是一种生成分布式唯一ID的算法。它由64位二进制数组成,结构如下:1位符号位(固定为0) 41位时间戳(表示从一个固…...
服务器虚拟化详解
服务器虚拟化详解 服务器虚拟化是一种将物理服务器资源转化为虚拟服务器资源的技术,它允许在一台物理服务器上运行多个虚拟服务器,每个虚拟服务器都拥有独立的操作系统、应用程序和资源配置。这种技术极大地提高了服务器的利用率、灵活性和可扩展性&…...
医疗陪诊APP开发实战:从互联网医院系统源码开始
本文将从互联网医院系统源码出发,深入探讨医疗陪诊APP的开发实战。 一、从互联网医院系统源码入手 开发医疗陪诊APP的基础在于互联网医院系统的源码。互联网医院系统通常包括以下几个模块: 1.用户管理:用户注册、登录、信息管理等功能。 …...
jenkins 构建报错ERROR: Error fetching remote repo ‘origin‘
问题描述 修改项目的仓库地址后,使用jenkins构建报错 Running as SYSTEM Building in workspace /var/jenkins_home/workspace/【测试】客户端/client-fonchain-main The recommended git tool is: NONE using credential 680a5841-cfa5-4d8a-bb38-977f796c26dd&g…...
初识C#(三)- 数组
我有17栋楼,在不同地域,都是不同价格租出去给不同的人~ 文章目录 前言一、数组1.1 我有17栋楼 - 数组的声明1.2 包租公&包租婆 - 数组赋值1.3 每个月都要交租的苦逼租客 - 数组的使用 二、字符串2.1 字符串的使用方法 总结 前言 本篇笔记重点描述C#…...
uniapp 对接腾讯云IM群组成员管理(增删改查)
UniApp 实战:腾讯云IM群组成员管理(增删改查) 一、前言 在社交类App开发中,群组成员管理是核心功能之一。本文将基于UniApp框架,结合腾讯云IM SDK,详细讲解如何实现群组成员的增删改查全流程。 权限校验…...
进程地址空间(比特课总结)
一、进程地址空间 1. 环境变量 1 )⽤户级环境变量与系统级环境变量 全局属性:环境变量具有全局属性,会被⼦进程继承。例如当bash启动⼦进程时,环 境变量会⾃动传递给⼦进程。 本地变量限制:本地变量只在当前进程(ba…...
Unity3D中Gfx.WaitForPresent优化方案
前言 在Unity中,Gfx.WaitForPresent占用CPU过高通常表示主线程在等待GPU完成渲染(即CPU被阻塞),这表明存在GPU瓶颈或垂直同步/帧率设置问题。以下是系统的优化方案: 对惹,这里有一个游戏开发交流小组&…...
聊聊 Pulsar:Producer 源码解析
一、前言 Apache Pulsar 是一个企业级的开源分布式消息传递平台,以其高性能、可扩展性和存储计算分离架构在消息队列和流处理领域独树一帜。在 Pulsar 的核心架构中,Producer(生产者) 是连接客户端应用与消息队列的第一步。生产者…...
渗透实战PortSwigger靶场-XSS Lab 14:大多数标签和属性被阻止
<script>标签被拦截 我们需要把全部可用的 tag 和 event 进行暴力破解 XSS cheat sheet: https://portswigger.net/web-security/cross-site-scripting/cheat-sheet 通过爆破发现body可以用 再把全部 events 放进去爆破 这些 event 全部可用 <body onres…...
【Redis技术进阶之路】「原理分析系列开篇」分析客户端和服务端网络诵信交互实现(服务端执行命令请求的过程 - 初始化服务器)
服务端执行命令请求的过程 【专栏简介】【技术大纲】【专栏目标】【目标人群】1. Redis爱好者与社区成员2. 后端开发和系统架构师3. 计算机专业的本科生及研究生 初始化服务器1. 初始化服务器状态结构初始化RedisServer变量 2. 加载相关系统配置和用户配置参数定制化配置参数案…...
系统设计 --- MongoDB亿级数据查询优化策略
系统设计 --- MongoDB亿级数据查询分表策略 背景Solution --- 分表 背景 使用audit log实现Audi Trail功能 Audit Trail范围: 六个月数据量: 每秒5-7条audi log,共计7千万 – 1亿条数据需要实现全文检索按照时间倒序因为license问题,不能使用ELK只能使用…...
1.3 VSCode安装与环境配置
进入网址Visual Studio Code - Code Editing. Redefined下载.deb文件,然后打开终端,进入下载文件夹,键入命令 sudo dpkg -i code_1.100.3-1748872405_amd64.deb 在终端键入命令code即启动vscode 需要安装插件列表 1.Chinese简化 2.ros …...
【Go】3、Go语言进阶与依赖管理
前言 本系列文章参考自稀土掘金上的 【字节内部课】公开课,做自我学习总结整理。 Go语言并发编程 Go语言原生支持并发编程,它的核心机制是 Goroutine 协程、Channel 通道,并基于CSP(Communicating Sequential Processes࿰…...
在WSL2的Ubuntu镜像中安装Docker
Docker官网链接: https://docs.docker.com/engine/install/ubuntu/ 1、运行以下命令卸载所有冲突的软件包: for pkg in docker.io docker-doc docker-compose docker-compose-v2 podman-docker containerd runc; do sudo apt-get remove $pkg; done2、设置Docker…...
