word2vector训练代码详解
目录
1.代码实现
2.知识点
1.代码实现
#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])
embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法
#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):"""embed_v:表示对中心词进行embedding层embed_u:对上下文词进行embedding层 """v = embed_v(center) #中心词的词向量表达u = embed_u(contexts_and_negatives) #上下文词的词向量表达#用中心词来预测上下文词#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法pred = torch.bmm(v, u.permute(0, 2, 1)) #矩阵乘法(bmm三维乘法),不用管batch_size维度return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],[[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape
torch.Size([2, 1, 4])
#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):def __init__(self):super().__init__() #直接继承父类的初始化属性和方法def forward(self, inputs, target, mask=None):#nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果#weight=mask权重将掩码带上#reduction='none'表示不将计算结果聚合,算损失时(默认聚合)out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')return out.mean(dim=1) #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape
torch.Size([2, 4])
#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
tensor([0.9352, 1.8462])
#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))
#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):#修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重def init_weights(m):if type(m) == nn.Embedding:nn.init.xavier_uniform_(m.weight)net.apply(init_weights) net = net.to(device)#设置梯度下降的优化器optimizer = torch.optim.Adam(net.parameters(), lr=lr)#设置绘制可视化的动图(epoch——loss)animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])#设置累加metric = dltools.Accumulator(2) #2种数据需要累加for epoch in range(num_epochs): #遍历训练次数#设置计时器, 赋值批次数量timer, num_batches = dltools.Timer(), len(data_iter) #data_iter是分好批次的数据集,长度就是批次数量num_batchesfor i, batch in enumerate(data_iter): #i是索引, batch是取出的一批批数据#梯度清零optimizer.zero_grad()#接收中心词, 上下文词_噪声词, 掩码, 标记目标值 center, context_negative, mask, label = [data.to(device) for data in batch]#调用skip_gram模型预测pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])#计算损失l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)#用loss反向传播 ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播l.sum().backward()#梯度更新optimizer.step()#累加metric.add(l.sum(), l.numel()) #l.sum()数值求和累加, l.numel()数量累加# % 取余数 # // 商向下取整#迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作# i+1是因为i是从0开始遍历的if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1: #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)
#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):"""query_token:需要预测的词k:最高相似度的词数量embed:embedding层的哪一层"""#获取词向量权重 (词向量权重*词的one_hot编码,就是词向量)W = embed.weight.dataprint(f'W的shape:{W.shape}')x = W[vocab[query_token]] #embedding层是按照索引查表查词对应的权重-->优点print(f'x的shape:{x.shape}')#计算余弦相似度#torch.mv两个向量的点乘cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)print(f'cos的shape:{cos.shape}')#排序选择前k个对应的索引topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')for i in topk[1:]: #排除query_token他本身,自己与自己余弦相似度最高print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
get_similar_tokens('food', 3, net[0])
W的shape:torch.Size([6719, 100]) x的shape:torch.Size([100]) cos的shape:torch.Size([6719]) cosine sim=0.430:feed cosine sim=0.418:precious cosine sim=0.412:drink
2.知识点
相关文章:

word2vector训练代码详解
目录 1.代码实现 2.知识点 1.代码实现 #导包 import math import torch from torch import nn import dltools #加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压 #批次大小、窗口大小、噪声词大小 batch_size, ma…...

Python的风格应该是怎样的?除语法外,有哪些规范?
写代码不那么pythonic风格的,多多少少都会让人有点难受。 什么是pythonic呢?简而言之,这是一种写代码时遵守的规范,主打简洁、清晰、可读性高,符合PEP 8(Python代码样式指南)约定的模式。 Pyth…...

net core mvc 数据绑定 《1》
其它的绑定 跟net mvc 一样 》》MVC core 、framework 一样 1 模型绑定数组类型 2 模型绑定集合类型 3 模型绑定复杂的集合类型 4 模型绑定源 》》》》 模型绑定 使用输入数据的原生请求集合是可以工作的【request[],Querystring,request.from[]】, 但是从可读…...

python为姓名注音实战案例
有如下数据,需要对名字注音。 数据样例:👇 一、实现过程 前提条件:由于会用到pypinyin库,所以一定得提前安装。 pip install pypinyin1、详细代码: from pypinyin import pinyin, Style# 输入数据 names…...
MATLAB中的艺术:用爱心形状控制坐标轴
在MATLAB中,坐标轴控制是绘图和数据可视化中的一个重要方面。通过精细地管理坐标轴,我们不仅可以改善图形的视觉效果,还可以赋予图形更深的情感寓意。本文将介绍如何在MATLAB中使用坐标轴控制来绘制一个爱心形状,并探讨其背后的技…...

基于mybatis-plus创建springboot,添加增删改查功能,使用postman来测试接口出现的常见错误
1 当你在使用postman检测 添加和更新功能时,报了一个500错误 查看idea发现是: Data truncation: Out of range value for column id at row 1 通过翻译:数据截断:表单第1行的“id”列出现范围外值。一般情况下,出现这个…...
Java:Object操作
目录 1、Object转List对象2、Object转实体对象 1、Object转List对象 List<User> userList MtUtils.ObjectToList(objData, User.class);/*** Object对象转 List集合** param object Object对象* param clazz 需要转换的集合* param <T> 泛型类* return*/ public s…...
Java-并发基础
启动线程的方式 只有: 1、X extends Thread;,然后X.start 2、X implements Runnable;然后交给Thread运行 有争议可以可以查看 Thread源码的注释: There are two ways to create a new thread of execution.Callable的方式需要…...
速盾:网页游戏部署高防服务器有什么优势?
在当前互联网发展的背景下,网页游戏的市场需求不断增长,相应地带来了对高防服务器的需求。高防服务器可以为网页游戏部署提供许多优势,下面就详细介绍一下。 第一,高防服务器具有强大的抗DDoS攻击能力。DDoS攻击是目前互联网上最…...

【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套
【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套 详细解答和讨论请私信在工作空间内新建一个功能包在msg内创建对应的msg文件创建名为TestMsg.msg的文件创建名为TestSubMsg.msg的文件(在前一个msg文件中引用)修改CmakeList.txt修改package.…...

docker 部署 Seatunnel 和 Seatunnel Web
docker 部署 Seatunnel 和 Seatunnel Web 说明: 部署方式前置条件,已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…...

【易上手快捷开发新框架技术】nicegui标签组件lable用法庖丁解牛深度解读和示例源代码IDE运行和调试通过截图为证
传奇开心果微博文系列 序言一、标签组件lable最基本用法示例1.在网页上显示出 Hello World 的标签示例2. 使用 style 参数改变标签样式示例 二、标签组件lable更多用法示例1. 添加按钮动态修改标签文字2. 点击按钮动态改变标签内容、颜色、大小和粗细示例代码3. 添加开关组件动…...

从HarmonyOS Next导出手机照片
1)打开DevEco Studio开发工具 2)插入USB数据线,连接手机 3)在DevEco Studio开发工具,通过View -> Tool Windows -> Device File Browser打开管理工具 4)选择storage -> cloud -> 100->fi…...

[Docker学习笔记]Docker的原理Docker常见命令
文章目录 什么是DockerDocker的优势Docker的原理Docker 的安装Docker 的 namespaces Docker的常见命令docker version:查看版本信息docker info 查看docker详细信息我们关注的信息 docker search:镜像搜索docker pull:镜像拉取到本地docker push:推送本地镜像到镜像仓库docker …...
【ESP 保姆级教程】小课设篇 —— 案例:20240507_esp01s+UNO的智能浇水系统
忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2024-09-30 ❤️❤️ 本篇更新记录 2023-09-30 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...
如何设置MySQL分布式架构主键ID,为什么不能使用自增ID或者UUID做主键?
MySQL分布式架构主键ID的设置方法 雪花算法(Snowflake) 原理:雪花算法是一种生成分布式唯一ID的算法。它由64位二进制数组成,结构如下:1位符号位(固定为0) 41位时间戳(表示从一个固…...
服务器虚拟化详解
服务器虚拟化详解 服务器虚拟化是一种将物理服务器资源转化为虚拟服务器资源的技术,它允许在一台物理服务器上运行多个虚拟服务器,每个虚拟服务器都拥有独立的操作系统、应用程序和资源配置。这种技术极大地提高了服务器的利用率、灵活性和可扩展性&…...

医疗陪诊APP开发实战:从互联网医院系统源码开始
本文将从互联网医院系统源码出发,深入探讨医疗陪诊APP的开发实战。 一、从互联网医院系统源码入手 开发医疗陪诊APP的基础在于互联网医院系统的源码。互联网医院系统通常包括以下几个模块: 1.用户管理:用户注册、登录、信息管理等功能。 …...

jenkins 构建报错ERROR: Error fetching remote repo ‘origin‘
问题描述 修改项目的仓库地址后,使用jenkins构建报错 Running as SYSTEM Building in workspace /var/jenkins_home/workspace/【测试】客户端/client-fonchain-main The recommended git tool is: NONE using credential 680a5841-cfa5-4d8a-bb38-977f796c26dd&g…...
初识C#(三)- 数组
我有17栋楼,在不同地域,都是不同价格租出去给不同的人~ 文章目录 前言一、数组1.1 我有17栋楼 - 数组的声明1.2 包租公&包租婆 - 数组赋值1.3 每个月都要交租的苦逼租客 - 数组的使用 二、字符串2.1 字符串的使用方法 总结 前言 本篇笔记重点描述C#…...

盘古信息PCB行业解决方案:以全域场景重构,激活智造新未来
一、破局:PCB行业的时代之问 在数字经济蓬勃发展的浪潮中,PCB(印制电路板)作为 “电子产品之母”,其重要性愈发凸显。随着 5G、人工智能等新兴技术的加速渗透,PCB行业面临着前所未有的挑战与机遇。产品迭代…...

Keil 中设置 STM32 Flash 和 RAM 地址详解
文章目录 Keil 中设置 STM32 Flash 和 RAM 地址详解一、Flash 和 RAM 配置界面(Target 选项卡)1. IROM1(用于配置 Flash)2. IRAM1(用于配置 RAM)二、链接器设置界面(Linker 选项卡)1. 勾选“Use Memory Layout from Target Dialog”2. 查看链接器参数(如果没有勾选上面…...

Psychopy音频的使用
Psychopy音频的使用 本文主要解决以下问题: 指定音频引擎与设备;播放音频文件 本文所使用的环境: Python3.10 numpy2.2.6 psychopy2025.1.1 psychtoolbox3.0.19.14 一、音频配置 Psychopy文档链接为Sound - for audio playback — Psy…...
C++.OpenGL (10/64)基础光照(Basic Lighting)
基础光照(Basic Lighting) 冯氏光照模型(Phong Lighting Model) #mermaid-svg-GLdskXwWINxNGHso {font-family:"trebuchet ms",verdana,arial,sans-serif;font-size:16px;fill:#333;}#mermaid-svg-GLdskXwWINxNGHso .error-icon{fill:#552222;}#mermaid-svg-GLd…...
Unit 1 深度强化学习简介
Deep RL Course ——Unit 1 Introduction 从理论和实践层面深入学习深度强化学习。学会使用知名的深度强化学习库,例如 Stable Baselines3、RL Baselines3 Zoo、Sample Factory 和 CleanRL。在独特的环境中训练智能体,比如 SnowballFight、Huggy the Do…...

vue3+vite项目中使用.env文件环境变量方法
vue3vite项目中使用.env文件环境变量方法 .env文件作用命名规则常用的配置项示例使用方法注意事项在vite.config.js文件中读取环境变量方法 .env文件作用 .env 文件用于定义环境变量,这些变量可以在项目中通过 import.meta.env 进行访问。Vite 会自动加载这些环境变…...

Mysql中select查询语句的执行过程
目录 1、介绍 1.1、组件介绍 1.2、Sql执行顺序 2、执行流程 2.1. 连接与认证 2.2. 查询缓存 2.3. 语法解析(Parser) 2.4、执行sql 1. 预处理(Preprocessor) 2. 查询优化器(Optimizer) 3. 执行器…...

iview框架主题色的应用
1.下载 less要使用3.0.0以下的版本 npm install less2.7.3 npm install less-loader4.0.52./src/config/theme.js文件 module.exports {yellow: {theme-color: #FDCE04},blue: {theme-color: #547CE7} }在sass中使用theme配置的颜色主题,无需引入,直接可…...
Bean 作用域有哪些?如何答出技术深度?
导语: Spring 面试绕不开 Bean 的作用域问题,这是面试官考察候选人对 Spring 框架理解深度的常见方式。本文将围绕“Spring 中的 Bean 作用域”展开,结合典型面试题及实战场景,帮你厘清重点,打破模板式回答,…...

给网站添加live2d看板娘
给网站添加live2d看板娘 参考文献: stevenjoezhang/live2d-widget: 把萌萌哒的看板娘抱回家 (ノ≧∇≦)ノ | Live2D widget for web platformEikanya/Live2d-model: Live2d model collectionzenghongtu/live2d-model-assets 前言 网站环境如下,文章也主…...