【python函数】torch.nn.Embedding函数用法图解
学习SAM模型的时候,第一次看见了nn.Embedding函数,以前接触CV比较多,很少学习词嵌入方面的,找了一些资料一开始也不是很理解,多看了两遍后,突然顿悟,特此记录。
SAM中PromptEncoder中运用nn.Embedding:
point_embeddings = [nn.Embedding(1, embed_dim) for i in range(self.num_point_embeddings)]
torch.nn.Embedding官方页面
1. torch.nn.Embedding介绍
(1)词嵌入简介
关于词嵌入,这篇文章讲的挺清楚的,相比于One-hot编码,Embedding方式更方便计算,例如在“就在江湖之上”整个词典中,要编码“江湖”两个字,One-hot编码需要 [ l e n g t h , w o r d _ c o u n t ] {[length, word\_count]} [length,word_count] 大小的张量,其中 w o r d _ c o u n t {word\_count} word_count 为词典中所有词的总数,而Embedding方式的嵌入维度 e m b e d d i n g _ d i m {embedding\_dim} embedding_dim 可远远小于 w o r d _ c o u n t {word\_count} word_count 。在运用Embedding方式编码的词典时,只需要词的索引,下图例子中: “江湖”——>[2, 3]

(2)重要参数介绍
nn.embedding就相当于一个词典嵌入表:
torch.nn.Embedding(num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None)
常用参数:
① num_embeddings (int): 词典中词的总数
② embedding_dim (int): 词典中每个词的嵌入维度
③ padding_idx (int, optional): 填充索引,在padding_idx处的嵌入向量在训练过程中没有更新,即它是一个固定的“pad”。对于新构造的Embedding,在padding_idx处的嵌入向量将默认为全零,但可以更新为另一个值以用作填充向量。
输入: I n p u t ( ∗ ) {Input(∗)} Input(∗): IntTensor 或者 LongTensor,为任意size的张量,包含要提取的所有词索引。
输出: O u t p u t ( ∗ , H ) {Output(∗, H)} Output(∗,H): ∗ {∗} ∗ 为输入张量的size, H {H} H = embedding_dim
2. torch.nn.Embedding用法
(1)基本用法
官方例子如下:
import torch
import torch.nn as nnembedding = nn.Embedding(10, 3)
x = torch.LongTensor([[1, 2, 4, 5], [4, 3, 2, 9]])y = embedding(x)print('权重:\n', embedding.weight)
print('输出:')
print(y)
查看权重与输出,打印如下:
权重:Parameter containing:
tensor([[ 1.4212, 0.6127, -1.1126],[ 0.4294, -1.0121, -1.8348],[-0.0315, -1.2234, -0.4589],[ 0.6131, -0.4381, 0.1253],[-1.0621, -0.1466, 1.7412],[ 1.0708, -0.7888, -0.0177],[-0.5979, 0.6465, 0.6508],[-0.5608, -0.3802, -0.4206],[ 1.1516, 0.4091, 1.2477],[-0.5753, 0.1394, 2.3447]], requires_grad=True)
输出:
tensor([[[ 0.4294, -1.0121, -1.8348],[-0.0315, -1.2234, -0.4589],[-1.0621, -0.1466, 1.7412],[ 1.0708, -0.7888, -0.0177]],[[-1.0621, -0.1466, 1.7412],[ 0.6131, -0.4381, 0.1253],[-0.0315, -1.2234, -0.4589],[-0.5753, 0.1394, 2.3447]]], grad_fn=<EmbeddingBackward0>)
家人们,发现了什么,输入 x {x} x 的 s i z e {size} size 大小为 [ 2 , 4 ] {[2, 4]} [2,4] ,输出 y {y} y 的 s i z e {size} size 大小为 [ 2 , 4 , 3 ] {[2, 4, 3]} [2,4,3] ,下图清晰的展示出nn.Embedding干了个什么事儿:

nn.Embedding相当于是一本词典,本例中,词典中一共有10个词 X 0 {X_0} X0~ X 9 {X_9} X9,每个词的嵌入维度为3,输入 x {x} x 中记录词在词典中的索引,输出 y {y} y 为输入 x {x} x 经词典编码后的映射。
注意:此时存在一个问题,词索引是不能超出词典的最大容量的,即本例中,输入 x {x} x 中的数值取值范围为 [ 0 , 9 ] {[0, 9]} [0,9]。
(2)自定义词典权重
如上所示,在未定义时,nn.Embedding的自动初始化权重满足 N ( 0 , 1 ) {N(0,1)} N(0,1) 分布,此外,nn.Embedding的权重也可以通过from_pretrained来自定义:
import torch
import torch.nn as nnweight = torch.FloatTensor([[1, 2.3, 3], [4, 5.1, 6.3]])
embedding = nn.Embedding.from_pretrained(weight)
x = torch.LongTensor([1, 0, 0])
y = embedding(x)
print(y)
输出为:
tensor([[4.0000, 5.1000, 6.3000],[1.0000, 2.3000, 3.0000],[1.0000, 2.3000, 3.0000]])
(3)padding_idx用法
padding_idx可用于指定词典中哪一个索引的词填充为0。
import torch
import torch.nn as nnembedding = nn.Embedding(10, 3, padding_idx=5)
x = torch.LongTensor([[5, 2, 0, 5]])
y = embedding(x)
print('权重:\n', embedding.weight)
print('输出:')
print(y)
输出为:
权重:Parameter containing:
tensor([[ 0.1831, -0.0200, 0.7023],[ 0.2751, -0.1189, -0.3325],[-0.5242, -0.2230, -1.1677],[-0.4078, -1.2141, 1.3185],[ 0.8973, -0.9650, 0.5420],[ 0.0000, 0.0000, 0.0000],[ 0.0597, 0.6810, -0.2595],[ 0.6543, -0.6242, 0.2337],[-0.0780, -0.9607, -0.0618],[ 0.2801, -0.6041, -1.4143]], requires_grad=True)
输出:
tensor([[[ 0.0000, 0.0000, 0.0000],[-0.5242, -0.2230, -1.1677],[ 0.1831, -0.0200, 0.7023],[ 0.0000, 0.0000, 0.0000]]], grad_fn=<EmbeddingBackward0>)
词典中,被padding_idx标定后的词嵌入向量可被重新定义:
import torch
import torch.nn as nnpadding_idx=2
embedding = nn.Embedding(3, 3, padding_idx=padding_idx)
print('权重:\n', embedding.weight)with torch.no_grad():embedding.weight[padding_idx] = torch.tensor([1.1, 2.2, 3.3])
print('权重:\n', embedding.weight)
输出为:
权重:Parameter containing:
tensor([[ 0.7247, 0.7553, -1.8226],[-1.3304, -0.5025, 0.5237],[ 0.0000, 0.0000, 0.0000]], requires_grad=True)
权重:Parameter containing:
tensor([[ 0.7247, 0.7553, -1.8226],[-1.3304, -0.5025, 0.5237],[ 1.1000, 2.2000, 3.3000]], requires_grad=True)
相关文章:
【python函数】torch.nn.Embedding函数用法图解
学习SAM模型的时候,第一次看见了nn.Embedding函数,以前接触CV比较多,很少学习词嵌入方面的,找了一些资料一开始也不是很理解,多看了两遍后,突然顿悟,特此记录。 SAM中PromptEncoder中运用nn.Emb…...
with ldid... /opt/MonkeyDev/bin/md: line 326: ldid: command not found
吐槽傻逼xcode 根据提示 执行了这个脚本/opt/MonkeyDev/bin/md 往这里面添加你brew install 安装文件的目录即可...
[golang gui]fyne框架代码示例
1、下载GO Go语言中文网 golang安装包 - 阿里镜像站(镜像站使用方法:查找最新非rc版本的golang安装包) golang安装包 - 中科大镜像站 go二进制文件下载 - 南京大学开源镜像站 Go语言官网(Google中国) Go语言官网(Go团队) 截至目前(2023年9月17日&#x…...
2000-2018年各省能源消费和碳排放数据
2000-2018年各省能源消费和碳排放数据 1、时间:2000-2018年 2、范围:30个省市 3、指标:id、year、ENERGY、COAL、碳排放倒数*100 4、来源:能源年鉴 5、指标解释: 2018年碳排放和能源数据为插值法推算得到 碳排放…...
C# ref 学习1
ref 关键字用在四种不同的上下文中; 1.在方法签名和方法调用中,按引用将参数传递给方法。 2.在方法签名中,按引用将值返回给调用方。 3.在成员正文中,指示引用返回值是否作为调用方欲修改的引用被存储在本地,或在一般…...
MQ - 08 基础篇_消费者客户端SDK设计(下)
文章目录 导图Pre概述消费分组协调者消费分区分配策略轮询粘性自定义消费确认确认后删除数据确认后保存消费进度数据消费失败处理从服务端拉取数据失败本地业务数据处理失败提交位点信息失败总结导图 Pre...
Flutter层对于Android 13存储权限的适配问题
感觉很久没有写博客了,不对,的确是很久没有写博客了。原因我不怎么想说,玩物丧志了。后面渐渐要恢复之前的写作节奏。今天来聊聊我最近遇到的一个问题: Android 13版本对于storage权限的控制问题。 我们都知道,Andro…...
Android kotlin开源项目-功能标题目录
目录 一、BRVAH二、开源项目1、RV列表动效(标题目录)2、拖拽与侧滑(标题目录)3、数据库(标题目录)4、树形图(多级菜单)(标题目录)5、轮播图与头条(标题目录)6…...
Linux下,基于TCP与UDP协议,不同进程下单线程通信服务器
C语言实现Linux下,基于TCP与UDP协议,不同进程下单线程通信服务器 一、TCP单线程通信服务器 先运行server端,再运行client端输入"exit" 是退出 1.1 server_TCP.c **#include <my_head.h>#define PORT 6666 #define IP &qu…...
qt功能自己创作
按钮按下三秒禁用 void MainWindow::on_pushButton_5_clicked(){// 锁定界面setWidgetsEnabled(ui->centralwidget, false);// 创建一个定时器,等待3秒后解锁界面QTimer::singleShot(3000, this, []() {setWidgetsEnabled(ui->centralwidget, true);;//ui-&g…...
Linux网络编程:使用UDP和TCP协议实现网络通信
目录 一. 端口号的概念 二. 对于UDP和TCP协议的认识 三. 网络字节序 3.1 字节序的概念 3.2 网络通信中的字节序 3.3 本地地址格式和网络地址格式 四. socket编程的常用函数 4.1 sockaddr结构体 4.2 socket编程常见函数的功能和使用方法 五. UDP协议实现网络通信 5.…...
【后端速成 Vue】初识指令(上)
前言: Vue 会根据不同的指令,针对标签实现不同的功能。 在 Vue 中,指定就是带有 v- 前缀 的特殊 标签属性,比如: <div v-htmlstr> </div> 这里问题就来了,既然 Vue 会更具不同的指令&#…...
爬虫 — Scrapy-Redis
目录 一、背景1、数据库的发展历史2、NoSQL 和 SQL 数据库的比较 二、Redis1、特性2、作用3、应用场景4、用法5、安装及启动6、Redis 数据库简单使用7、Redis 常用五大数据类型7.1 Redis-String7.2 Redis-List (单值多value)7.3 Redis-Hash7.4 Redis-Set (不重复的)7.5 Redis-Z…...
tcpdump常用命令
需要安装 tcpdump wireshark ifconfig找到网卡名称 eth0, ens192... tcpdump需要root权限 网卡eth0 经过221.231.92.240:80的流量写入到http.cap tcpdump -i eth0 host 221.231.92.240 and port 80 -vvv -w http.cap ssh登录到主机查看排除ssh 22端口的报文 tcpdump -i …...
计算机网络运输层网络层补充
1 CDMA是码分多路复用技术 和CMSA不是一个东西 UPD是只确保发送 但是接收端收到之后(使用检验和校验 除了检验的部分相加 对比检验和是否相等。如果不相同就丢弃。 复用和分用是发生在上层和下层的问题。通过比如时分多路复用 频分多路复用等。TCP IP 应用层的IO多路复用。网…...
java CAS详解(深入源码剖析)
CAS是什么 CAS是compare and swap的缩写,即我们所说的比较交换。该操作的作用就是保证数据一致性、操作原子性。 cas是一种基于锁的操作,而且是乐观锁。在java中锁分为乐观锁和悲观锁。悲观锁是将资源锁住,等之前获得锁的线程释放锁之后&am…...
1786_MTALAB代码生成把通用函数生成独立文件
全部学习汇总: GitHub - GreyZhang/g_matlab: MATLAB once used to be my daily tool. After many years when I go back and read my old learning notes I felt maybe I still need it in the future. So, start this repo to keep some of my old learning notes…...
2023/09/19 qt day3
头文件 #ifndef WIDGET_H #define WIDGET_H #include <QWidget> #include <QDebug> #include <QTime> #include <QTimer> #include <QPushButton> #include <QTextEdit> #include <QLineEdit> #include <QLabel> #include &l…...
Docker 学习总结(78)—— Docker Rootless 让你的容器更安全
前言 在以 root 用户身份运行 Docker 会带来一些潜在的危害和安全风险,这些风险包括: 容器逃逸:如果一个容器以 root 权限运行,并且它包含了漏洞或者被攻击者滥用,那么攻击者可能会成功逃出容器,并在宿主系统上执行恶意操作。这会导致宿主系统的安全性受到威胁。 特权升…...
如何使用ArcGIS Pro将等高线转DEM
通常情况下,我们拿到的等高线数据一般都是CAD格式,如果要制作三维地形模型,使用栅格格式的DEM数据是更好的选择,这里就为大家介绍一下如何使用ArcGIS Pro将等高线转DEM,希望能对你有所帮助。 创建TIN 在工具箱中选择“…...
css实现圆环展示百分比,根据值动态展示所占比例
代码如下 <view class""><view class"circle-chart"><view v-if"!!num" class"pie-item" :style"{background: conic-gradient(var(--one-color) 0%,#E9E6F1 ${num}%),}"></view><view v-else …...
R语言AI模型部署方案:精准离线运行详解
R语言AI模型部署方案:精准离线运行详解 一、项目概述 本文将构建一个完整的R语言AI部署解决方案,实现鸢尾花分类模型的训练、保存、离线部署和预测功能。核心特点: 100%离线运行能力自包含环境依赖生产级错误处理跨平台兼容性模型版本管理# 文件结构说明 Iris_AI_Deployme…...
React第五十七节 Router中RouterProvider使用详解及注意事项
前言 在 React Router v6.4 中,RouterProvider 是一个核心组件,用于提供基于数据路由(data routers)的新型路由方案。 它替代了传统的 <BrowserRouter>,支持更强大的数据加载和操作功能(如 loader 和…...
OPENCV形态学基础之二腐蚀
一.腐蚀的原理 (图1) 数学表达式:dst(x,y) erode(src(x,y)) min(x,y)src(xx,yy) 腐蚀也是图像形态学的基本功能之一,腐蚀跟膨胀属于反向操作,膨胀是把图像图像变大,而腐蚀就是把图像变小。腐蚀后的图像变小变暗淡。 腐蚀…...
C++:多态机制详解
目录 一. 多态的概念 1.静态多态(编译时多态) 二.动态多态的定义及实现 1.多态的构成条件 2.虚函数 3.虚函数的重写/覆盖 4.虚函数重写的一些其他问题 1).协变 2).析构函数的重写 5.override 和 final关键字 1&#…...
纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join
纯 Java 项目(非 SpringBoot)集成 Mybatis-Plus 和 Mybatis-Plus-Join 1、依赖1.1、依赖版本1.2、pom.xml 2、代码2.1、SqlSession 构造器2.2、MybatisPlus代码生成器2.3、获取 config.yml 配置2.3.1、config.yml2.3.2、项目配置类 2.4、ftl 模板2.4.1、…...
uniapp 小程序 学习(一)
利用Hbuilder 创建项目 运行到内置浏览器看效果 下载微信小程序 安装到Hbuilder 下载地址 :开发者工具默认安装 设置服务端口号 在Hbuilder中设置微信小程序 配置 找到运行设置,将微信开发者工具放入到Hbuilder中, 打开后出现 如下 bug 解…...
离线语音识别方案分析
随着人工智能技术的不断发展,语音识别技术也得到了广泛的应用,从智能家居到车载系统,语音识别正在改变我们与设备的交互方式。尤其是离线语音识别,由于其在没有网络连接的情况下仍然能提供稳定、准确的语音处理能力,广…...
【Linux】Linux安装并配置RabbitMQ
目录 1. 安装 Erlang 2. 安装 RabbitMQ 2.1.添加 RabbitMQ 仓库 2.2.安装 RabbitMQ 3.配置 3.1.启动和管理服务 4. 访问管理界面 5.安装问题 6.修改密码 7.修改端口 7.1.找到文件 7.2.修改文件 1. 安装 Erlang 由于 RabbitMQ 是用 Erlang 编写的,需要先安…...
多元隐函数 偏导公式
我们来推导隐函数 z z ( x , y ) z z(x, y) zz(x,y) 的偏导公式,给定一个隐函数关系: F ( x , y , z ( x , y ) ) 0 F(x, y, z(x, y)) 0 F(x,y,z(x,y))0 🧠 目标: 求 ∂ z ∂ x \frac{\partial z}{\partial x} ∂x∂z、 …...
