图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】
文章目录
- GAT
- 代码实现【PyGAT】
- GraphAttentionLayer【一个图注意力层实现】
- 用上面实现的单层网络测试
- 加入Multi-head机制的GAT
- 对数据集Cora的处理
- csr_matrix()处理稀疏矩阵
- encode_onehot()对label编号
- build graph
- 邻接矩阵构造
- GAT的推广
GAT
题:Graph Attention Networks
摘要:
提出了图形注意网络(GAT) ,这是一种基于图结构数据的新型神经网络结构,利用掩蔽的自我注意层来解决基于图卷积或其近似的先前方法的缺点。通过叠加层,节点能够参与其邻域的特征,我们能够(隐式地)为邻域中的不同节点指定不同的权重,而不需要任何代价高昂的矩阵操作(如反演) ,或者依赖于预先知道图的结构。通过这种方法,我们同时解决了基于谱的图形神经网络的几个关键问题,并使我们的模型容易地适用于归纳和转导问题。我们的 GAT 模型已经实现或匹配了四个已建立的转导和归纳图基准的最新结果: Cora,Citeseer 和 Pubmed 引用网络数据集,以及protein-protein interaction dataset(其中测试图在训练期间保持不可见)。
在Paper with code 网址,可找到对应论文和github源码,原论文使用TensorFlow实现,本篇主要对Pytorch版本的 PyGAT附详细注释帮助理解和测试。
GitHUb: keras版本实现
Pytorch版本实现 PyGAT
截图及下文代码注释参考自视频:GAT详解及代码实现
视频中的eij的实现与源码不同,视频中是先拼接两个W,再与a乘;
源码在_prepare_attentional_mechanism_input()函数中先分别与a乘,再拼接。
代码实现【PyGAT】
在PyGAT :
- layers.py中定义Simple GAT layer实现(GraphAttentionLayer)和Sparse version GAT layer实现(SpGraphAttentionLayer)。
- models.py 实现两个版本加入Multi-head机制
- trains.py 使用model定义的GAT构建模型进行训练,使用cora数据集
GraphAttentionLayer【一个图注意力层实现】
class GraphAttentionLayer(nn.Module):"""Simple GAT layer, similar to https://arxiv.org/abs/1710.10903"""def __init__(self, in_features, out_features, dropout, alpha, concat=True):super(GraphAttentionLayer, self).__init__()self.dropout = dropoutself.in_features = in_features#结点向量的特征维度self.out_features = out_features#经过GAT之后的特征维度self.alpha = alpha#dropout参数self.concat = concat#LeakyReLU参数# 定义可训练参数,即论文中的W和aself.W = nn.Parameter(torch.empty(size=(in_features, out_features)))nn.init.xavier_uniform_(self.W.data, gain=1.414)# xavier初始化self.a = nn.Parameter(torch.empty(size=(2*out_features, 1)))nn.init.xavier_uniform_(self.a.data, gain=1.414)# xavier初始化# 定义leakyReLU激活函数self.leakyrelu = nn.LeakyReLU(self.alpha)def forward(self, h, adj):'''adj图邻接矩阵,维度[N,N]非零即一h.shape: (N, in_features), self.W.shape:(in_features,out_features)Wh.shape: (N, out_features)'''Wh = torch.mm(h, self.W) # 对应eij的计算公式e = self._prepare_attentional_mechanism_input(Wh)#对应LeakyReLU(eij)计算公式zero_vec = -9e15*torch.ones_like(e)#将没有链接的边设置为负无穷attention = torch.where(adj > 0, e, zero_vec)#[N,N]# 表示如果邻接矩阵元素大于0时,则两个节点有连接,该位置的注意力系数保留# 否则需要mask设置为非常小的值,因为softmax的时候这个最小值会不考虑attention = F.softmax(attention, dim=1)# softmax形状保持不变[N,N],得到归一化的注意力全忠!attention = F.dropout(attention, self.dropout, training=self.training)# dropout,防止过拟合h_prime = torch.matmul(attention, Wh)#[N,N].[N,out_features]=>[N,out_features]# 得到由周围节点通过注意力权重进行更新后的表示if self.concat:return F.elu(h_prime)else:return h_primedef _prepare_attentional_mechanism_input(self, Wh):# Wh.shape (N, out_feature)# self.a.shape (2 * out_feature, 1)# Wh1&2.shape (N, 1)# e.shape (N, N)# 先分别与a相乘再进行拼接Wh1 = torch.matmul(Wh, self.a[:self.out_features, :])Wh2 = torch.matmul(Wh, self.a[self.out_features:, :])# broadcast adde = Wh1 + Wh2.Treturn self.leakyrelu(e)def __repr__(self):return self.__class__.__name__ + ' (' + str(self.in_features) + ' -> ' + str(self.out_features) + ')'
用上面实现的单层网络测试
x = torch.randn(6,10)
adj=torch.tensor([[0,1,1,0,0,0],[1,0,1,0,0,0],[1,1,0,1,0,0],[0,0,1,0,1,1],[0,0,0,1,0,0,],[0,0,0,1,1,0]])
my_gat = GraphAttentionLayer(10,5,0.2,0.2)
print(my_gat(x,adj))
输出:
tensor([[-0.2965, 2.8110, -0.6680, -0.9643, -0.9882],[ 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],[-0.4981, -0.7515, 1.1159, 0.3546, 1.3592],[ 0.4679, 1.7208, 0.3084, -0.5331, -0.1291],[-0.4375, -0.8778, 1.1767, -0.5869, 1.5154],[-0.2164, -0.5897, 0.4988, -0.3125, 0.6423]], grad_fn=<EluBackward>)
加入Multi-head机制的GAT
用不同head捕捉不同特征,使模型有更好的拟合能力。
class GAT(nn.Module):def __init__(self, nfeat, nhid, nclass, dropout, alpha, nheads):"""Dense version of GAT."""super(GAT, self).__init__()self.dropout = dropout# 加入Multi-head机制self.attentions = [GraphAttentionLayer(nfeat, nhid, dropout=dropout, alpha=alpha, concat=True) for _ in range(nheads)]for i, attention in enumerate(self.attentions):self.add_module('attention_{}'.format(i), attention)self.out_att = GraphAttentionLayer(nhid * nheads, nclass, dropout=dropout, alpha=alpha, concat=False)def forward(self, x, adj):x = F.dropout(x, self.dropout, training=self.training)x = torch.cat([att(x, adj) for att in self.attentions], dim=1)x = F.dropout(x, self.dropout, training=self.training)x = F.elu(self.out_att(x, adj))return F.log_softmax(x, dim=1)
对数据集Cora的处理
数据集中两个文件,cites:比如上图11行:编号25和编号1114331的文章
content文件:如下图,每篇文章的id、features及类别
csr_matrix()处理稀疏矩阵
utils.py中对数据进行的处理
#数据是稀疏的,csr_matrix操作从行开始将1的位置取出来,对数据进行压缩
features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)
labels = encode_onehot(idx_features_labels[:, -1])
encode_onehot()对label编号
有7个类别,通过classes_dict是7*7的对角阵把每个类别映射成不同向量,对所有label进行编号,再将编号转换为one_hot向量
def encode_onehot(labels):# The classes must be sorted before encoding to enable static class encoding.# In other words, make sure the first class always maps to index 0.classes = sorted(list(set(labels)))classes_dict = {c: np.identity(len(classes))[i, :] for i, c in enumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)), dtype=np.int32)return labels_onehot
build graph
见注释:
# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)#获取所有文章ididx_map = {j: i for i, j in enumerate(idx)}#按文章数目,对id重新映射# 读取数据集中文章和文章直接的引用关系edges_unordered = np.genfromtxt("{}{}.cites".format(path, dataset), dtype=np.int32)# 根据idx_map,将文章引用关系也重新映射edges = np.array(list(map(idx_map.get, edges_unordered.flatten())), dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])), shape=(labels.shape[0], labels.shape[0]), dtype=np.float32)# build symmetric adjacency matrix 生成邻接矩阵adj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = normalize_features(features)adj = normalize_adj(adj + sp.eye(adj.shape[0]))
邻接矩阵构造
csr_matrix()只记录了(0,1)1,忽略了(1,0)1。所以需要coo_matrix()操作!才能还原出无向图的邻接矩阵!
本文的一些代码注释及截图还可见视频
一个拓展:
GAT的推广
GAT的推广
GAT仅仅是应用在了单层图结构网络上,我们是否可以将它推广到多层网络结构呢?
这里我们假设一个有N层网络的结构,每层网络都定义了相同的节点,但是节点之间的关系有所差异。举一个简单的例子,假设有一个用户关系型网络,每层网络的节点都被定义成了网络中的用户,网络的第一层视图的关系可以定义为,两个用户之间是否具有好友关系;网络的第二层视图可以定义为,你评论过我的动态;网络的第三层视图可以定义为你转发过我的动态;第四层关系可以定义为,你at过我等等。
通过这样的定义我们就完成了一个多层网络的构建,他们共享相同的节点,但又分别具有不同的邻边,如果我们分别处理每一层视图视图,然后将他们得出的节点表示单纯相加的话,就可能会失去不同视图之间的协作关系,降低分类(预测)的精度。
基于以上观点,我们提出了一种新的方法:首先在每一层单视图中应用GAT进行学习,并计算出每层视图的节点表示。之后再不同视图之间引入attention机制来让网络自行学习不同视图的权重。之后根据学习的权重,将各个视图加权相加得到全局节点表示并进行后续的诸如节点表示,链接预测等任务。
同时,因为不同视图共享同样的节点,即使每一层视图都表示了不同的节点关系,最终得到的每一层的节点嵌入表示应具有一定的相关性。基于以上理论,我们在每层GAT的网络参数间引入正则化项来约束参数,使其向互相相近的方向学习。大致的网络流程图如下:
这部分来源于 链接:https://www.jianshu.com/p/d5d366ba1a57 来源:简书
相关文章:
图注意网络GAT理解及Pytorch代码实现【PyGAT代码详细注释】
文章目录GAT代码实现【PyGAT】GraphAttentionLayer【一个图注意力层实现】用上面实现的单层网络测试加入Multi-head机制的GAT对数据集Cora的处理csr_matrix()处理稀疏矩阵encode_onehot()对label编号build graph邻接矩阵构造GAT的推广GAT 题:Graph Attention Netwo…...
项目成本管理中的常见误区及解决方案
做过项目的人都明白,项目实施时间一般很长,在实施期间总有很多项目结果不尽人意的问题。要使一个项目取得成功,就要结合很多因素一起才能作用,其中做好项目成本的管理就是最重要的步骤之一,下面列出了常见的项目成本管…...
墨天轮2022年度数据库获奖名单
2022年,国家相继从高位部署、省级试点布局、地市重点深入三个维度,颁布了多项中国数据库行业发展的利好政策。但是我们也能清晰地看到,中国数据库行业发展之路道阻且长,而道路上的“拦路虎”之一则是生态。中国数据库的发展需要多…...
仓储调度|库存管理系统
技术:Java、JSP等摘要:随着电子商务技术和网络技术的快速发展,现代物流技术也在不断进步。物流技术是指与物流要素活动有关的所有专业技术的总称,包括各种操作方法、管理技能等,物流业采用某些现代信息技术方面的成功经…...
Canvas入门-01
导读: 读完全文需要2min。通过这篇文章,你可以了解到以下内容: Canvas标签基本属性如何使用Canvas画矩形、圆形、线条、曲线、笑脸😊 如果你曾经了解过Canvas,可以对照目录回忆一下能否回答上来 毕竟带着问题学习最有效…...
运算符优先级
醋坛酸味罐,位落跳福豆 醋:初等运算符: () [] -> . 坛:单目运算符: - ~ – * & ! sizeof 右结合 酸:算术运算符: - * / % 味:位移运算符:>> << …...
微信小程序使用scss编译wxss文件的配置步骤
文章目录1、在 vscode 中搜索 easysass 插件并安装2、在微信开发工具中导入安装的easysass插件3、修改 spook.easysass-0.0.6/package.json 文件中的配置4、重启开发者工具,就可用使用了微信小程序开发者工具集成了 vscode 编辑器,可以使用 vscode 中众多…...
一步一步教你如何使用 Visual Studio Code 编译一段 C# 代码
以下是一步一步教你如何使用 Visual Studio Code 编写使用 C# 语言输出当前日期和时间的代码: 1、下载并安装 .NET SDK。您可以从 Microsoft 官网下载并安装它。 2、打开 Visual Studio Code,并安装 C# 扩展。您可以在 Visual Studio Code 中通过扩展菜…...
vue-cli中的环境变量注意点
在客户端侧代码中使用环境变量只有以 VUE_APP_ 开头的变量会被 webpack.DefinePlugin 静态嵌入到客户端侧的包中。你可以在应用的代码中这样访问它们:console.log(process.env.VUE_APP_SECRET)在构建过程中,process.env.VUE_APP_SECRET 将会被相应的值所…...
2.3数据类型
文章目录1. 命名规则2.字符3.数字4.日期5.图片1. 命名规则 字段名必须以字母开头,尽量不要使用拼音长度不能超过30个字符(不同数据库,不同版本会有不同)不能使用SQL的保留字,如where,order,group只能使用如下字符a-z、…...
Kafka基本概念
什么是Kafka Kafka是一个消息系统。它可以集中收集生产者的消息,并由消费者按需获取。在Kafka中,也将消息称为日志(log)。 一个系统,若仅有一类或者少量的消息,可直接进行发送和接收。 随着业务量日益复杂,消息的种类…...
使用QueryBuilders、NativeSearchQuery实现复杂查询
使用QueryBuilders、NativeSearchQuery实现复杂查询 本文继续前面文章《ElasticSearch系列(二)springboot中集成使用ElasticSearch的Demo》,在前文中,我们介绍了使用springdata做一些简单查询,但是要实现一些高级的组…...
taobao.open.account.update( Open Account数据更新 )
¥开放平台免费API不需用户授权 Open Account数据更新 公共参数 请求地址: HTTP地址 http://gw.api.taobao.com/router/rest 公共请求参数: 公共响应参数: 响应参数 点击获取key和secret 请求示例 TaobaoClient client new DefaultTaobaoClient(url, appkey, sec…...
PT100铂电阻温度传感器
PT100温度传感器又叫做铂热电阻。 热电阻是中低温区﹡常用的一种温度检测器。它的主要特点是测量精度高,性能稳定。其中铂热电阻的测量精确度是﹡高的,它不仅广泛应用于工业测温,而且被制成标准的基准仪。金属热…...
蓝桥杯-本质上升序列
没有白走的路,每一步都算数🎈🎈🎈 题目描述: 小蓝特别喜欢单调递增的事物 在一个字符串中如果取出若干个字符,按照在原来字符串中的顺序排列在一起,组成的新的字符串如果是单调递增的…...
synchronized锁重入验证
文章目录synchronized锁重入验证1. 可重入锁2. synchronized锁重入2.1 本类同步方法内部调用本类其它同步方法2.2 子类同步方法内部调用父类的同步方法2.3 A类的同步方法内部调用B类的同步方法3. synchronized修饰方法写法synchronized锁重入验证 1. 可重入锁 可重入锁&#…...
超简单的计数排序!!
假设给定混乱数据为:3,0,1,3,6,5,4,2,1,9。 下面我们将通过使用计数排序的思想来完成对上面数据的排序。(先不谈负数) 计数排序 该排序的思路和它的名字一样…...
发现新大陆——原来软件开发根本不需要会编码(看我10分钟应用上线)
目录 一、前言 二、官网基础功能及搭建 三、体验过程 01、连接数据源 02、设计表单 03、流程设计 04、图表呈现 05、组织架构设置 五、效率评价 六、小结 一、前言 众所周知,每家公司在发展过程中都需要构建大量的内部系统, 如运营使用的用户…...
【Leedcode】栈和队列必备的面试题(第二期)
【Leedcode】栈和队列必备的面试题(第二期) 文章目录【Leedcode】栈和队列必备的面试题(第二期)一、题目(用两个队列实现栈)二、思路图解1.定义两个队列2.初始化两个队列3.往两个队列中放入数据4.两个队列出…...
Elasticsearch实战之(商品搜索API实现)
Elasticsearch实战之(商品搜索API实现) 1、案例介绍 某医药电商H5商城基于Elasticsearch实现商品搜索 2、案例分析 2.1、数据来源 商品库 - 平台运营维护商品库 - 供应商维护 2.2、数据同步 2.2.1、同步双写 写入 MySQL,直接也同步往…...
剑指 Offer 14-剪绳子
摘要 剑指 Offer 14- I. 剪绳子 剑指 Offer 14- II. 剪绳子 II 343. 整数拆分 一、动态规划解析 这道题给定一个大于1的正整数n,要求将n 拆分成至少两个正整数的和,并使这些正整数的乘积最大化,返回最大乘积。令x是拆分出的第…...
泰克示波器|MSO64示波器的应用
泰克新一代示波器MSO64为实例来讲解时频域信号分析技术。MSO64采用全新TEK049平台,不仅实现了4通道同时打开时25GS/s的高采样率,而且实现了12-bit高垂直分辨率。同时,由于采用了新型低噪声前端放大ASIC—TEK061,大大降低了噪声水平…...
1.4 黑群晖安装:SataPortMap和DiskIdxMap两种获取方式
tinycore及安装工具下载:工具:链接:https://pan.baidu.com/s/1CMLl6waOuW-Ys2gKZx7Jgg?pwdchct提取码:chcttinycore:链接:https://pan.baidu.com/s/19lchzLj-WDXPQu2cEcskBg?pwddcw2 提取码:d…...
JVM虚拟机概述(2)
3.JVM 运行时数据区 3.1.1 程序计数器(Program Counter Register) 是一块很小的内存空间,用来记录每个线程运行的指令位置,是线程私有的,每个线程都拥有一个程序计数器,生命周期与线程一致,是运行时数据区中唯一一个不…...
Intel CSME 简述
SME 算是 Intel X86 PC 上最神秘的部分了,本文根据 us-19-Hasarfaty-Behind-The-Scenes-Of-Intel-Security-And-Manageability-Engine 一文写成。讲述内容无法证伪,各位随便听听即可,了解这些能够帮助BIOS 工程师更好的理解一些操作的实现。文章基于 Intel 第八代第九代CPU(…...
复位理论基础
先收集资料,了解当前常用的基础理论和实现方式 复位 初始化微控制器内部电路 将所有寄存器恢复成默认值确认MCU的工作模式禁止全局中断关闭外设将IO设置为高阻输入状态等待时钟趋于稳定从固定地址取得复位向量并开始执行 造成复位的原因 有多种引起复位的因素&…...
Python基础知识——列表
列表 列表是可以存放任何数据,包括整型,浮点型,字符串,布尔型等等,是常用的数据类型之一。 1.列表的创建 列表也是一个可迭代对象 1. 普通形式l [1,2,3,4,5] ---整型列表l ["a","b","c&…...
如何使用工时表管理项目和非项目的资源?
对新机会做出反应的能力是企业竞争优势的关键。项目不断涌现,企业需要了解具体的可用性以及是否有资源来接受新事物。更进一步来说,企业需要知道员工将时间花在哪里。 使用 8Manage工时表解决方案,你将始终拥有做出正确业务决策所需的全面知…...
项目经理如何做好质量保证与标准维持?非技术项目经理如何做好质量管控?
项目经理如何做好质量保证与标准维持?非技术项目经理如何做好质量管控?01.质量保障需要重视哪些执行层面的细节02.非技术出身项目经理如何做好质量保障工作03.质量管理除了PDCA,还有哪些推荐的方法04.质量保证与标准维持,作为常态…...
[文件操作] File 类的用法和 InputStream, OutputStream 的用法
能吃是不是件幸福的事呢 文章目录前言1. 文件的相关定义2. 文件类型3. Java对文件系统的操作3.1 对文件的基础操作3.2 读文件3.3 写文件前言 从这章开始,我们就开始学文件操作相关的知识了~ 1. 文件的相关定义 1.文件的定义可以从狭义和广义两个方面解释. 狭义: 指硬盘上的文…...
大连的网站制作公司/百度一下就知道百度首页
Apache的mod_rewrite是提供了强大URL操作的杀手级的模块,可以实现几乎所有你梦想的URL操作类型,其代价是你必须接受其复杂性,因为mod_rewrite的主要障碍就是初学者不容易理解和运用,即使是Apache专家有时也会发掘出mod_rewrite的新…...
软件网站开发评估/seo关键词推广价格
饭卡 Time Limit: 5000/1000 MS (Java/Others) Memory Limit: 32768/32768 K (Java/Others) Total Submission(s): 514 Accepted Submission(s): 226Problem Description电子科大本部食堂的饭卡有一种很诡异的设计,即在购买之前判断余额。如果购买一个商品之前&…...
微信公众号网站建设/2345纯净版推广包
原文出处:https://wiki.ubuntu.com/IptablesHowTo 原文作者:UbuntuWiki 授权许可:创作共享协议 GNU FDL 翻译人员:denven 校对人员:5451vs5451 贡献者: 适用版本: 文章状态:翻译完成…...
用什么做网站开发/win7优化软件
今天主要是寻找板卡问题然后维修,这次生产了600PCS板卡,第一次小批量生产。记得第一次打样是4PCS,很多问题都无法暴露出来,这次600PCS就暴露不少问题了。首先功耗问题就有两个,然后其他的小问题有几个,所以…...
二手车网站制作贵吗/seo关键词查询工具
本位参考自:http://www.xifenfei.com/1527.html 目的:将已经offline掉的datafile 5 的scn信息改为与其他datafile一致。 db版本为11.2.0.4 背景知识: 1、datafile 的file header 存储在第一个block里 2、Oracle considers four attributes …...
网站推广的方式包括/最近韩国电影片
环境:STM32CubeMxClionMatlabPCAN利用STM32CubeMx配置STM32的CAN通信:默认这两项终端是Enabled的,这里去掉,否则CAN无法同时收发,会卡在接收中断服务里:通过STM32CubeMx配置的CAN后,还需要在生成…...