word2vector训练代码详解
目录
1.代码实现
2.知识点
1.代码实现
#导包
import math
import torch
from torch import nn
import dltools
#加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压
#批次大小、窗口大小、噪声词大小
batch_size, max_window_size, num_noise_words = 512, 5, 5
#获取数据集迭代器、词汇表
data_iter, vocab = dltools.load_data_ptb(batch_size, max_window_size, num_noise_words)
#讲解嵌入层embedding的用法(此行代码无用)#嵌入层
#通过嵌入层来获取skip—gram的中心词向量和上下文词向量
embed = nn.Embedding(num_embeddings=20, embedding_dim=4)
# num_embeddings就是词表大小
# X的shape=(batch_size, num_steps)
# --one_hot编码--->(batch_size, num_steps, num_embedding(vocab_size))
# --点乘中心词矩阵-->(batch_size, num_steps, embed_size)
embed.weight.shape #讲解嵌入层embedding的用法(此行代码无用)
torch.Size([20, 4])
embedding层先one_hot编码,再进行与embedding层的矩阵(num_embeddings,embedding_dim)乘法
#构造skip_gram的前向传播
def skip_gram(center, contexts_and_negatives, embed_v, embed_u):"""embed_v:表示对中心词进行embedding层embed_u:对上下文词进行embedding层 """v = embed_v(center) #中心词的词向量表达u = embed_u(contexts_and_negatives) #上下文词的词向量表达#用中心词来预测上下文词#u_shape = (batch_size, num_steps, embed_size)---->(batch_size, embed_size, num_steps)进行矩阵乘法pred = torch.bmm(v, u.permute(0, 2, 1)) #矩阵乘法(bmm三维乘法),不用管batch_size维度return pred
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed)
tensor([[[3.1980, 3.1980, 3.1980, 3.1980]],[[3.1980, 3.1980, 3.1980, 3.1980]]], grad_fn=<BmmBackward0>)
#假设数据
skip_gram(torch.ones((2, 1), dtype=torch.long), torch.ones((2, 4), dtype=torch.long), embed, embed).shape
torch.Size([2, 1, 4])
#带掩码的二元交叉熵损失
class SigmoidBCELoss(nn.Module):def __init__(self):super().__init__() #直接继承父类的初始化属性和方法def forward(self, inputs, target, mask=None):#nn.functional.binary_cross_entropy_with_logits表示返回的不是转化后的概率,是原始计算的数据结果#weight=mask权重将掩码带上#reduction='none'表示不将计算结果聚合,算损失时(默认聚合)out = nn.functional.binary_cross_entropy_with_logits(inputs, target, weight=mask, reduction='none')return out.mean(dim=1) #计算结果是二维的,在索引1维度上聚合求平均
loss = SigmoidBCELoss()
[[1.1, -2.2, 3.3, -4.4]] * 2
[[1.1, -2.2, 3.3, -4.4], [1.1, -2.2, 3.3, -4.4]]
torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2).shape
torch.Size([2, 4])
#假设数据测试
pred = torch.tensor([[1.1, -2.2, 3.3, -4.4]] * 2)
label = torch.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
mask = torch.tensor([[1, 1, 1, 1], [1, 1, 0, 0]])
#mask每一行都有4个数值,所以* mask.shape[1]=4
#但是mask中的数值0表示权重,是补充步长的,不重要,需要计算有效序列的损失平均值,所以 / mask.sum(axis=1)
loss(pred, label, mask) * mask.shape[1] / mask.sum(axis=1)
tensor([0.9352, 1.8462])
#初始化模型参数,定义两个嵌入层
#一开始,embed_weights会标准正态分布的数据初始化
#两个embedding层的参数不一样,不能重复使用,需要初始化定义两个
embed_size = 100
net = nn.Sequential(nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size),nn.Embedding(num_embeddings=len(vocab), embedding_dim=embed_size))
#定义训练过程
def train(net, data_iter, lr, num_epochs, device=dltools.try_gpu()):#修改embedding层的初始化方法,使用nn.init.xavier_uniform_初始化embed.weight权重,在NLP中不使用标准正态分布的额数据初始化权重def init_weights(m):if type(m) == nn.Embedding:nn.init.xavier_uniform_(m.weight)net.apply(init_weights) net = net.to(device)#设置梯度下降的优化器optimizer = torch.optim.Adam(net.parameters(), lr=lr)#设置绘制可视化的动图(epoch——loss)animator = dltools.Animator(xlabel='epoch', ylabel='loss', xlim=[1, num_epochs])#设置累加metric = dltools.Accumulator(2) #2种数据需要累加for epoch in range(num_epochs): #遍历训练次数#设置计时器, 赋值批次数量timer, num_batches = dltools.Timer(), len(data_iter) #data_iter是分好批次的数据集,长度就是批次数量num_batchesfor i, batch in enumerate(data_iter): #i是索引, batch是取出的一批批数据#梯度清零optimizer.zero_grad()#接收中心词, 上下文词_噪声词, 掩码, 标记目标值 center, context_negative, mask, label = [data.to(device) for data in batch]#调用skip_gram模型预测pred = skip_gram(center, context_negative, embed_v=net[0], embed_u=net[1])#计算损失l = loss(pred.reshape(label.shape).float(), label.float(), mask) / mask.shape[1] * mask.sum(dim=1)#用loss反向传播 ,loss先sum()聚合变成标量(合并成一个数值), 只有标量才能反向传播l.sum().backward()#梯度更新optimizer.step()#累加metric.add(l.sum(), l.numel()) #l.sum()数值求和累加, l.numel()数量累加# % 取余数 # // 商向下取整#迭代到总数据量的5%的倍数时 或者 处理到最后一批数据时,执行下面操作# i+1是因为i是从0开始遍历的if (i + 1) % (num_batches // 5) == 0 or i == num_batches - 1: #epoch + (i+1) / num_batches当前迭代次数占整个数据集的比例animator.add(epoch + (i+1) / num_batches, (metric[0] / metric[1]))print(f'loss {metric[0] / metric[1]:.3f}', f'{metric[1] / timer.stop():.1f} tokens/sec on {str(device)}')
lr, num_epochs = 0.002, 50
train(net, data_iter, lr, num_epochs)
#如果能够找到词的近义词, 就说明训练的不错
def get_similar_tokens(query_token, k, embed):"""query_token:需要预测的词k:最高相似度的词数量embed:embedding层的哪一层"""#获取词向量权重 (词向量权重*词的one_hot编码,就是词向量)W = embed.weight.dataprint(f'W的shape:{W.shape}')x = W[vocab[query_token]] #embedding层是按照索引查表查词对应的权重-->优点print(f'x的shape:{x.shape}')#计算余弦相似度#torch.mv两个向量的点乘cos = torch.mv(W, x) / torch.sqrt(torch.sum(W * W, dim=1) * torch.sum(x * x) + 1e-9)print(f'cos的shape:{cos.shape}')#排序选择前k个对应的索引topk = torch.topk(cos, k=k+1)[1].cpu().numpy().astype('int32')for i in topk[1:]: #排除query_token他本身,自己与自己余弦相似度最高print(f'cosine sim={float(cos[i]):.3f}:{vocab.to_tokens(i)}')
get_similar_tokens('food', 3, net[0])
W的shape:torch.Size([6719, 100]) x的shape:torch.Size([100]) cos的shape:torch.Size([6719]) cosine sim=0.430:feed cosine sim=0.418:precious cosine sim=0.412:drink
2.知识点
相关文章:
word2vector训练代码详解
目录 1.代码实现 2.知识点 1.代码实现 #导包 import math import torch from torch import nn import dltools #加载PTB数据集 ,需要把PTB数据集的文件夹放在代码上一级目录的data文件中,不用解压 #批次大小、窗口大小、噪声词大小 batch_size, ma…...
Python的风格应该是怎样的?除语法外,有哪些规范?
写代码不那么pythonic风格的,多多少少都会让人有点难受。 什么是pythonic呢?简而言之,这是一种写代码时遵守的规范,主打简洁、清晰、可读性高,符合PEP 8(Python代码样式指南)约定的模式。 Pyth…...
net core mvc 数据绑定 《1》
其它的绑定 跟net mvc 一样 》》MVC core 、framework 一样 1 模型绑定数组类型 2 模型绑定集合类型 3 模型绑定复杂的集合类型 4 模型绑定源 》》》》 模型绑定 使用输入数据的原生请求集合是可以工作的【request[],Querystring,request.from[]】, 但是从可读…...
python为姓名注音实战案例
有如下数据,需要对名字注音。 数据样例:👇 一、实现过程 前提条件:由于会用到pypinyin库,所以一定得提前安装。 pip install pypinyin1、详细代码: from pypinyin import pinyin, Style# 输入数据 names…...
MATLAB中的艺术:用爱心形状控制坐标轴
在MATLAB中,坐标轴控制是绘图和数据可视化中的一个重要方面。通过精细地管理坐标轴,我们不仅可以改善图形的视觉效果,还可以赋予图形更深的情感寓意。本文将介绍如何在MATLAB中使用坐标轴控制来绘制一个爱心形状,并探讨其背后的技…...
基于mybatis-plus创建springboot,添加增删改查功能,使用postman来测试接口出现的常见错误
1 当你在使用postman检测 添加和更新功能时,报了一个500错误 查看idea发现是: Data truncation: Out of range value for column id at row 1 通过翻译:数据截断:表单第1行的“id”列出现范围外值。一般情况下,出现这个…...
Java:Object操作
目录 1、Object转List对象2、Object转实体对象 1、Object转List对象 List<User> userList MtUtils.ObjectToList(objData, User.class);/*** Object对象转 List集合** param object Object对象* param clazz 需要转换的集合* param <T> 泛型类* return*/ public s…...
Java-并发基础
启动线程的方式 只有: 1、X extends Thread;,然后X.start 2、X implements Runnable;然后交给Thread运行 有争议可以可以查看 Thread源码的注释: There are two ways to create a new thread of execution.Callable的方式需要…...
速盾:网页游戏部署高防服务器有什么优势?
在当前互联网发展的背景下,网页游戏的市场需求不断增长,相应地带来了对高防服务器的需求。高防服务器可以为网页游戏部署提供许多优势,下面就详细介绍一下。 第一,高防服务器具有强大的抗DDoS攻击能力。DDoS攻击是目前互联网上最…...
【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套
【从0开始自动驾驶】ros2编写自定义消息 msg文件和msg文件嵌套 详细解答和讨论请私信在工作空间内新建一个功能包在msg内创建对应的msg文件创建名为TestMsg.msg的文件创建名为TestSubMsg.msg的文件(在前一个msg文件中引用)修改CmakeList.txt修改package.…...
docker 部署 Seatunnel 和 Seatunnel Web
docker 部署 Seatunnel 和 Seatunnel Web 说明: 部署方式前置条件,已经在宿主机上运行成功运行文件采用挂载宿主机目录的方式部署SeaTunnel Engine 采用的是混合模式集群 编写Dockerfile并打包镜像 Seatunnel FROM openjdk:8 WORKDIR /opt/seatunne…...
【易上手快捷开发新框架技术】nicegui标签组件lable用法庖丁解牛深度解读和示例源代码IDE运行和调试通过截图为证
传奇开心果微博文系列 序言一、标签组件lable最基本用法示例1.在网页上显示出 Hello World 的标签示例2. 使用 style 参数改变标签样式示例 二、标签组件lable更多用法示例1. 添加按钮动态修改标签文字2. 点击按钮动态改变标签内容、颜色、大小和粗细示例代码3. 添加开关组件动…...
从HarmonyOS Next导出手机照片
1)打开DevEco Studio开发工具 2)插入USB数据线,连接手机 3)在DevEco Studio开发工具,通过View -> Tool Windows -> Device File Browser打开管理工具 4)选择storage -> cloud -> 100->fi…...
[Docker学习笔记]Docker的原理Docker常见命令
文章目录 什么是DockerDocker的优势Docker的原理Docker 的安装Docker 的 namespaces Docker的常见命令docker version:查看版本信息docker info 查看docker详细信息我们关注的信息 docker search:镜像搜索docker pull:镜像拉取到本地docker push:推送本地镜像到镜像仓库docker …...
【ESP 保姆级教程】小课设篇 —— 案例:20240507_esp01s+UNO的智能浇水系统
忘记过去,超越自己 ❤️ 博客主页 单片机菜鸟哥,一个野生非专业硬件IOT爱好者 ❤️❤️ 本篇创建记录 2024-09-30 ❤️❤️ 本篇更新记录 2023-09-30 ❤️🎉 欢迎关注 🔎点赞 👍收藏 ⭐️留言📝🙏 此博客均由博主单独编写,不存在任何商业团队运营,如发现错误,请…...
如何设置MySQL分布式架构主键ID,为什么不能使用自增ID或者UUID做主键?
MySQL分布式架构主键ID的设置方法 雪花算法(Snowflake) 原理:雪花算法是一种生成分布式唯一ID的算法。它由64位二进制数组成,结构如下:1位符号位(固定为0) 41位时间戳(表示从一个固…...
服务器虚拟化详解
服务器虚拟化详解 服务器虚拟化是一种将物理服务器资源转化为虚拟服务器资源的技术,它允许在一台物理服务器上运行多个虚拟服务器,每个虚拟服务器都拥有独立的操作系统、应用程序和资源配置。这种技术极大地提高了服务器的利用率、灵活性和可扩展性&…...
医疗陪诊APP开发实战:从互联网医院系统源码开始
本文将从互联网医院系统源码出发,深入探讨医疗陪诊APP的开发实战。 一、从互联网医院系统源码入手 开发医疗陪诊APP的基础在于互联网医院系统的源码。互联网医院系统通常包括以下几个模块: 1.用户管理:用户注册、登录、信息管理等功能。 …...
jenkins 构建报错ERROR: Error fetching remote repo ‘origin‘
问题描述 修改项目的仓库地址后,使用jenkins构建报错 Running as SYSTEM Building in workspace /var/jenkins_home/workspace/【测试】客户端/client-fonchain-main The recommended git tool is: NONE using credential 680a5841-cfa5-4d8a-bb38-977f796c26dd&g…...
初识C#(三)- 数组
我有17栋楼,在不同地域,都是不同价格租出去给不同的人~ 文章目录 前言一、数组1.1 我有17栋楼 - 数组的声明1.2 包租公&包租婆 - 数组赋值1.3 每个月都要交租的苦逼租客 - 数组的使用 二、字符串2.1 字符串的使用方法 总结 前言 本篇笔记重点描述C#…...
黑马智数Day3
渲染基础Table列表 封装接口: export function getCardListAPI(params) {return request({url: /parking/card/list,params}) } 具体实现: import { getCardListAPI } from /apis/cardexport default {data() {return {// 请求参数params: {page: 1,pa…...
【Java】再一次踩了整数溢出的坑
【Java】再一次踩了整数溢出的坑 一、起因原题示例 1示例 2提示 我的代码提交结果 二、思考修改后的代码如下 三、知识点1. int m l ((r - l) / 2)解释 2. if (m < x / m)解释 四、结尾 一、起因 我在做【力扣】69.x 的平方根 一题的时候,明明觉得逻辑没问题&…...
Windows开发工具使用技巧大揭秘:让编码效率翻倍的秘籍!
【ACM出版|厦大主办|EI稳定检索】第五届计算机科学与管理科技国际学术会议(ICCSMT 2024)_艾思科蓝_学术一站式服务平台 更多学术会议请看:学术会议-学术交流征稿-学术会议在线-艾思科蓝 目录 引言 1. 快捷键大全:加速你的编码…...
CSS外边距
元素的外边距(margin)是围绕在元素边框以外(不包括边框)的空白区域,这片区域不受 background 属性的影响,始终是透明的。 为元素设置外边距 默认情况下如果不设置外边距属性,HTML 元素就是不会…...
C++ set,multiset与map,multimap的基本使用
1. 序列式容器和关联式容器 string、vector、list、deque、array、forward_list等STL容器统称为序列式容器,因为逻辑结构为线性序列的数据结构,两个位置存储的值之间一般没有紧密的关联关系,比如交换一下,他依旧是序列式容器。顺…...
评估潜力无限:解读自闭症患者的工作能力评估
在星贝育园这片充满爱与希望的土地上,我们不仅见证了无数自闭症儿童在康复训练中的点滴进步,更深刻理解了他们内在潜力的无限可能。自闭症,这一复杂的神经发育障碍,常常让外界对其患者的工作能力产生误解和偏见。然而,…...
js 实现视频封面截图
今天给大家分享一下,如何实现视频封面截取功能,这里主要用到了 HTML5 的 canvas 相关的 api 和 js 相关的一些知识,话不多说,直接上代码: <template><div><div class"margin-tb-sm"><…...
Hadoop FileSystem Shell 常用操作命令
提示:本文章只总结一下常用的哈,详细的命令大家可以移步官方的文档(链接贴在下面了哈🤣)— HDFS官方命令手册链接。 目录 1. cat 命令:查看 HDFS 文件内容2. put 命令:将本地文件上传到 HDFS3.…...
uniapp EChars图表
1. uniapp EChars图表 (1)Apache ECharts 一个基于 JavaScript 的开源可视化图表库 https://echarts.apache.org/examples/zh/index.html (1)官网图例 (2)个人实现图例 1.1. 下载echart 1.1.1. 下…...
最新版ingress-nginx-controller安装 使用host主机模式
最新版ingress-nginx-controller安装 使用host主机模式 文章目录 最新版ingress-nginx-controller安装 使用host主机模式单节点安装方式多节点高可用安装方式 官方参考链接: https://github.com/kubernetes/ingress-nginx/ https://kubernetes.github.io/ingress-ng…...
廊坊做网站电话/seo关键词优化外包
问题描述:父组件传如lesser和larger两个参数,并且是ajax从服务器获取的。子组件定义created阶段输出lesser和larger。但larger为空。改成延迟输出则正确。问题来源:https://segmentfault.com/q/1010000008912491提问者的主要问题是没有搞清楚…...
app 网站开发公司/长沙网动网络科技有限公司
为什么80%的码农都做不了架构师?>>> UICollectionViewController中有collectionView;而collectionView有UICollectionViewCell; 因此UITableViewController会有collectionView 和collectionViewCell,控制器默认已经遵守数据源和代理方法了&a…...
wordpress统计插件下载/网络营销主要有哪些特点
智慧树知到_大数据可视化_2020章节测试答案更多相关问题(103)5______.计算:(-3)332______.计算:(-3)332______.若m为正整数,且a-1,则-(࿰…...
一个公司主体可以在多个网站做备案/竞价托管哪家便宜
1.content-type 是指 数据在 http网络通信的时候,字符串的类型 2.请求有 发送数据的 content-type,有 可以接收的content-type 3.而 编码格式,只是 这个字符串里面的 字符的 编码格式,content-type是这个字符串的 类型 4.而有些特定的 需求 例…...
聊城手机网站建设方案/怎么注册域名网址
前面写了一个参数估计,现在也顺便把假设检验也总结一下吧,主要参考书还是那本《概率论与数理统计》(陈希孺)。 假设检验就是提出一个命题,根据样本判断对错。 问题提法 有一个已知分布的总体,其中个别参数未知。现在抽取的一组样本…...
微信知彼网络网站建设/dw友情链接怎么设置
最近几年,在DDD的领域,我们经常会看到CQRS架构的概念。我个人也写了一个ENode框架,专门用来实现这个架构。CQRS架构本身的思想其实非常简单,就是读写分离。是一个很好理解的思想。就像我们用MySQL数据库的主备,数据写到…...