[图神经网络]PyTorch简单实现一个GCN
Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。
一、环境构建
①安装torch_geometric包。
pip install torch_geometric
②导入相关库
import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid
二、PyG图学习架构
构建方法:首先继承MessagePassing类,接下来重写构造函数和以下三个方法:
message() #构建消息传递
aggregate() #将消息聚合到目标节点
update() #更新消息节点
1.构造函数
def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):
参数 | 内容 |
aggr | 消息聚合的方式,常见的方法:add、mean、min、max |
flow | 消息传播的方向,source_to_target--从源节点到目标节点 target_to_source--从目标节点到源节点 |
node_dim | 传播的维度 |
2.propagate函数
该函数为消息传播的启动函数,调用此函数后会依次执行:message、aggregate、update来完成消息的传递、聚合、更新。
该函数声明如下:
propagate(self, edge_index: Adj, size: Size = None, **kwargs)
参数 | 说明 |
edge_index | 边索引 |
size | 邻接矩阵的尺寸,若为None则表示方阵 |
**kwargs | 额外参数 |
3.message函数
用于构建节点消息,传递给propagate的tensor可以映射到中心节点和邻居节点,仅需在相应的变量名后加上_i(中心节点)或_j(邻居节点)即可。
self.propagate(edge_index, x=x):passdef message(self, x_i, x_j, edge_index_i):pass
x_i | 中心节点构成的特征向量组成的矩阵 |
x_j | 邻居节点构成的特征向量组成的矩阵 |
edge_index_i | 中心节点的索引 |
4.aggregate函数
消息聚合函数,用以聚合来自邻居的消息,常见的方法有add、sum、mean、max,可以通过super().__init__()中的参数aggr来设定
5.update函数
用于更新节点的消息
三、GCN图卷积网络
GCN网络的原理可见:图卷积神经网络--GCN
需要注意 torch_scatter无法使用pip install加载可以参见 torch_scatter安装
1.加载数据集
from torch_geometric.datasets import Planetoiddataset = Planetoid(root='Cora', name='Cora')
Cora数据集是一个根据科学论文之间相互引用关系构建的图(Graph)数据集合,论文合计7类,共2708篇论文(2708个节点),10556条边。
2.定义GCN层
class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels, add_self_loops=True, bias=True):super(GCNConv, self).__init__()self.add_self_loops = add_self_loopsself.edge_index = Noneself.linear = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot')if bias:self.bias = nn.Parameter(torch.Tensor(out_channels, 1))self.bias = pyg_nn.inits.glorot(self.bias)else:self.register_parameter('bias', None)# 1.消息传递def message(self, x, edge_index):# 1.对所有节点进行新的空间映射x = self.linear(x) # [num_nodes, feature_size]# 2.添加偏置if self.bias != None:x += self.bias.flatten()# 3.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 4.获得度矩阵deg = degree(col, x.shape[0], x.dtype) # [num_nodes]# 5.度矩阵归一化deg_inv_sqrt = deg.pow(-0.5) # [num_nodes]# 6.计算sqrt(deg(i)) * sqrt(deg(j))norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # [num_nodes]# 7.返回所有边的映射x_j = x[row] # [E, feature_size]# 8.计算归一化后的节点特征x_j = norm.view(-1, 1) * x_j # [E, feature_size]return x_j# 2.消息聚合def aggregate(self, x_j, edge_index):# 1.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 2.聚合邻居特征aggr_out = scatter(x_j, row, dim=0, reduce='sum') # [num_nodes, feature_size]return aggr_out# 3.节点更新def update(self, aggr_out):# 对于GCN没有这个阶段,所以直接返回return aggr_outdef forward(self, x, edge_index):# 2.添加自环信息,考虑自身信息if self.add_self_loops:edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0]) # [2, E]return self.propagate(edge_index, x=x)
3.定义GCN网络
class GCN(nn.Module):def __init__(self, num_node_features, num_classes):super(GCN, self).__init__()self.conv1 = GCNConv(num_node_features, 16)self.conv2 = GCNConv(16, num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)
4.模型调用
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 200 # 学习轮数
lr = 0.0003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图# 4.定义模型
model = GCN(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数# 训练模式
model.train()for epoch in range(epochs):optimizer.zero_grad()pred = model(data)loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度loss.backward()optimizer.step()if epoch % 20 == 0:print("【EPOCH: 】%s" % str(epoch + 1))print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证
model.eval()
pred = model(data)# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test Accuracy: {:.4f}'.format(acc_test), 'Test Loss: {:.4f}'.format(loss_test))
相关文章:
![](https://www.ngui.cc/images/no-images.jpg)
[图神经网络]PyTorch简单实现一个GCN
Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。 一、环境构建 ①安装torch_geometric包。 pip install torch_geometric …...
![](https://img-blog.csdnimg.cn/d7c28023f39d435681d633efb9673696.png)
Elasticsearch(黑马)
初识elasticsearch . 安装elasticsearch 1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net 1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的…...
![](https://www.ngui.cc/images/no-images.jpg)
oracle数据库调整字段类型
oracle数据库更改字段类型比较墨迹,因为如果该字段有值,是不允许直接更改字段类型的。另外oralce不支持在指定的某个字段后面新增一个字段,但是mysql数据可以向指定的字段后面新增一个字段。 mysql向指定字段后面新增一个字段: al…...
![](https://img-blog.csdnimg.cn/e2fe2f2649c54c608fe6c150837009c3.gif)
面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)
面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 目录 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 1.面部表情识别方法 2.面部表情识别数据集 (1)表情识别数据集说明 (2&…...
![](https://img-blog.csdnimg.cn/img_convert/22f0529f04212c69cc4773cb607a6a5f.png)
赛效:如何在线给图片加水印
学会给图片加水印是一个非常实用的技能,可以让你的图片更具保护性和个性化。说到加水印,很多人不知道怎么操作。其实,给图片加水印非常简单,不用下载任何程序,在线就能完成。今天,我将介绍如何使用改图宝在…...
![](https://img-blog.csdnimg.cn/img_convert/d7f5fc50455382017acebbee5ed28fae.png)
动力节点杜老师Vue笔记——Vue程序初体验
一、Vue程序初体验 我们可以先不去了解Vue框架的发展历史、Vue框架有什么特点、Vue是谁开发的,这些对我们编写Vue程序起不到太大的作用,更何况现在说了一些特点之后,我们也没有办法彻底理解它,因此我们可以先学会用,使…...
![](https://img-blog.csdnimg.cn/20974e3d29fa446dbaa3a8a4cc5213b8.png)
ajax上传图片存入到指定的文件夹并回显
html代码: <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"js/jquery-2.1.0.js"></script> </head> <body> <form…...
![](https://www.ngui.cc/images/no-images.jpg)
cesium加载cesiumlab切的影像切片和标准TMS瓦片的区别
1.加载cesiumlab切的影像 var labImg viewer.scene.imageryLayers.addImageryProvider( new Cesium.UrlTemplateImageryProvider({url:http://192.168.1.25:8080/DOMtms/{z}/{x}/{y}.png,fileExtension : "png"})); 2.标准TMS瓦片 var labImg viewer.scene.im…...
![](https://img-blog.csdnimg.cn/5de72c299d62427ab9d7aa9e780c7798.png)
第二周P9-P22
文章目录第三章 系统总线3.1、总线的基本概念一、为什么要用总线二、什么是总线三、总线上信息的传送四、总线结构的计算机举例1、单总线结构框图2、面向CPU的双总线结构框图3、以存储器为中心的双总线结构图3.2、总线的分类1、片内总线2、系统总线3、通信走线3.3、总线特性及性…...
![](https://www.ngui.cc/images/no-images.jpg)
java反射
文章目录何为反射?反射的应用场景了解么?谈谈反射机制的优缺点优点缺点反射实战获取 Class 对象的四种方式1. 知道具体类的情况下可以使用TargetObject.class:2. 通过 Class.forName()传入类的全路径获取:3. 通过对象实例instance…...
![](https://www.ngui.cc/images/no-images.jpg)
307 Temporary Redirect 解决办法(临时重定向)
背景:java后台服务请求python服务端 java服务报错:Unexpected response status:307 python服务端报错:307 Temporary Redirect 解决:查了好久找不到什么原因,请求路径问题 请求url:http//:w…...
![](https://img-blog.csdnimg.cn/c0df94fdf0ef412d8da47086d594866a.png)
SpringBoot案例
SpringBoot案例5,案例5.1 创建工程5.2 代码拷贝5.3 配置文件5.4 静态资源目标 基于SpringBoot的完成SSM整合项目开发 5,案例 SpringBoot 到这就已经学习完毕,接下来我们将学习 SSM 时做的三大框架整合的案例用 SpringBoot 来实现一下。我们完…...
![](https://www.ngui.cc/images/no-images.jpg)
Android 10.0 系统framework发送悬浮通知的流程分析
1.前言 在android10.0rom定制化开发中,在原生系统的systemui中,状态栏通知,和闹钟,wifi等悬浮通知也是很重要的, 悬浮通知也是系统通知的一种,也是在frameworks中发送出来的通知,接下来就分析下10.0中的悬浮通知的发送 流程,然后就可以实现自己自定义悬浮通知的相关功…...
![](https://mweb-1307664364.cos.ap-chengdu.myqcloud.com/2023/04/04/220pxburgersequationtravelingwaveplot14.gif)
傅里叶谱方法-傅里叶谱方法求解二维浅水方程组和二维粘性 Burgers 方程及其Matlab程序实现
3.3.2 二维浅水方程组 二维浅水方程组是描述水波运动的基本方程之一。它主要用于描述近岸浅水区域内的波浪、潮汐等水动力学现象。这个方程组由两个偏微分方程组成,一个是质量守恒方程,另一个是动量守恒方程。浅水方程描述了具有自由表面、密度均匀、深…...
![](https://img-blog.csdnimg.cn/5239cb14ac994056a00cffb73b02a810.png)
算法训练营 - 广度优先BFS
目录 从层序遍历开始 N 叉树的层序遍历 经典BFS最短路模板 经典C queue 数组模拟队列 打印路径 示例1.bfs查找所有连接方块 Cqueue版 数组模拟队列 示例2.从多个位置同时开始BFS 示例3.抽象最短路类(作图关键) 示例4.跨过障碍的最短路 从层序遍历…...
![](https://www.ngui.cc/images/no-images.jpg)
判断两个字符串是否匹配(1个通配符代表一个字符)
目录 判断两个字符串是否匹配(1个通配符代表一个字符) 程序设计 程序分析...
![](https://img-blog.csdnimg.cn/4042c817dd68446395b969370a267cfb.png)
用css画一个csdn程序猿
效果如下: 头部 我们先来拆解一下,程序猿的结构 两只耳朵和头是圆形组成的,耳朵内的红色部分也是圆形 先画头部,利用圆角实现头部形状 借助工具来快速实现圆角效果 https://9elements.github.io/fancy-border-radius/ <div…...
![](https://img-blog.csdnimg.cn/54f0868043ee4a3ab851e51e90b1dc1a.png)
Java多线程编程—wait/notify机制
文章目录1. 不使用wait/notify机制通信的缺点2. 什么是wait/notify机制3. wait/notify机制原理4. wait/notify方法的基本用法5. 线程状态的切换6. interrupt()遇到方法wait()7. notify/notifyAll方法8. wait(long)介绍9. 生产者/消费者模式10. 管道机制11. 利用wait/notify实现…...
![](https://www.ngui.cc/images/no-images.jpg)
Three.js教程:旋转动画、requestAnimationFrame周期性渲染
推荐:将NSDT场景编辑器加入你3D工具链其他工具系列:NSDT简石数字孪生基于WebGL技术开发在线游戏、商品展示、室内漫游往往都会涉及到动画,初步了解three.js可以做什么,深入讲解three.js动画之前,本节课先制作一个简单的…...
![](https://www.ngui.cc/images/no-images.jpg)
租车自驾app开发有什么作用?租车便利出行APP开发
在线租车APP有哪些优势,租车APP开发的基本功能,租车自驾app开发有什么作用?租车便利出行APP开发,租车服务平台小程序有哪些功能,租车软件开发需要多少钱,租车app都有哪些,租车平台定制开发,租车…...
![](https://www.ngui.cc/images/no-images.jpg)
linux shell 文件分割
split 按照 10m 大小进行分割 split -b 10m large_file.bin new_file_prefix...
![](https://img-blog.csdnimg.cn/fbe945af31dd4f5380d812119afb8cd2.jpeg)
智慧农业系统开发功能有哪些?
农业从古至今都是备受关注的话题,新时代背景下农业发展已经融合了互联网,数字化技术等新型发展方式,形成了农业物联网管控系统,让农业生产更加科技化、智能化、高效化,对农业可持续发展有巨大的推动作用。所以…...
![](https://img-blog.csdnimg.cn/98a4a522f9d043d683c2063f3b0cfe96.png)
【C语言】 指针的进阶 看这一篇就够了
目录 1.字符指针 2.数组指针 3.指针数组 4.数组传参和指针传参 5.函数指针 6.函数指针数组 7.指向函数指针数组的指针 8.回调函数 9.qsort排序和冒泡排序 1.字符指针 让我们一起来回顾一下指针的概念! 1.指针就是一个变量,用来存放地址,地址…...
![](https://img-blog.csdnimg.cn/552d3df95f2d4baa91827e14b6f33133.png)
redis set list
Listlist: 插入命令:lpush / rpush 查看list列表所有数据(-1 表示最后一个):lrange key 0 -1 查看列表长度(key 不存在则长度返回0 ): llen key list长度 获取下表 为 0 的元素 修改下标为0的元素,改为haha 移除列表的第一个元素 或最后一…...
![](https://img-blog.csdnimg.cn/1379ee9f75a14aaeaeff0422a1e74add.jpeg)
如何解决DNS劫持
随着互联网的不断发展,DNS(域名系统)成为了构建网络基础的重要组成部分。而DNS遭到劫持,成为一种常见的安全问题。那么DNS遭到劫持是什么意思呢?如何解决DNS劫持问题呢?下面就让小编来为您一一解答。 DNS遭到劫持是什么意思? DNS遭到劫持指的是黑客通…...
![](https://img-blog.csdnimg.cn/354c8932b6cb4db7862df49118edfa69.png)
【LeetCode】剑指 Offer(28)
目录 题目:剑指 Offer 54. 二叉搜索树的第k大节点 - 力扣(Leetcode) 题目的接口: 解题思路: 代码: 过啦!!! 题目:剑指 Offer 55 - I. 二叉树的深度 - 力…...
![](https://img-blog.csdnimg.cn/707277595deb4f53b40578a82b8f868d.png#pic_center)
「ML 实践篇」模型训练
在训练不同机器学习算法模型时,遇到的各类训练算法大多对用户都是一个黑匣子,而理解它们实际怎么工作,对用户是很有帮助的; 快速定位到合适的模型与正确的训练算法,找到一套适当的超参数等;更高效的执行错…...
![](https://img-blog.csdnimg.cn/img_convert/0e43d5a30c7e97034ce74b3337f080e1.png)
域名解析协议-DNS
DNS(Domain Name System)是互联网上非常重要的一项服务,我们每天上网都要依靠大量的DNS服务。在Internet上,用户更容易记住的是域名,但是网络中的计算机的互相访问是通过 IP 地址实现的。DNS 最常用的功能是给用户提供…...
![](https://www.ngui.cc/images/no-images.jpg)
分享:包括 AI 绘画在内的超齐全免费可用的API 大全
AI 绘画已经火出圈了,你还不知道哪里可以用嘛?我给大家整理了超级齐全的免费可用 API,包括 AI 绘画在内,有需要的小伙伴赶紧收藏了。 AI 绘画/AI 作画 类 AI 绘画:通过AI 生成图片,包括图生文、文生图等。…...
![](https://img-blog.csdnimg.cn/img_convert/bce9558dc42ca30f719742ef8fab21a2.png)
虹科新闻 | 虹科与Overland-Tandberg正式建立合作伙伴关系
虹科Overland-Tandberg 近日,虹科与美国Overland-Tandberg公司达成战略合作,虹科正式成为Overland-Tandberg公司在中国区域的认证授权代理商。未来,虹科将携手Overland-Tandberg,共同致力于提供企业数据管理和保护解决方案。 虹科…...
![](https://img-blog.csdnimg.cn/img_convert/8de8d7eed9e40b7f2c3c47018156035b.png)
整合营销网站/seo是什么意思中文
原标题:一款卡通风格的“类恶魔城”独立游戏HELLO~大家好,这里是小白的每日一游推荐时间。世上的游戏千千万,有许多好玩的游戏由于缺乏宣传,所以不被广大玩家所熟知。在这里小白每天会为大家推荐一款评价很高但是不太出…...
![](https://img-my.csdn.net/uploads/201209/05/1346812446_9021.jpg)
分销系统商城定制开发/百度快速排名优化服务
类加载器,顾名思义,类加载器(class loader)用来加载 Java 类到 Java 虚拟机中。一般来说,Java 虚拟机使用 Java 类的方式如下:Java 源程序(.java 文件)在经过 Java 编译器编译之后就…...
![](/images/no-images.jpg)
用Java做知乎网站/考研培训机构排名前五的机构
安装完成后,创建J2ME项目时显示信息如下: Not all requested modules can be enabled: [StandardModule:org.netbeans.modules.mobility.kit jarFile:C:\Program Files\NetBeans 7.1.1\mobility\modules\org-netbeans-modules-mobility-kit.jar]...
![](/images/no-images.jpg)
天堂网长尾关键词挖掘网站/全球最大的中文搜索引擎
1、安装 Yum install -y freeradius freeradius-mysql freeradius-utils 2、配置 1)修改 clients.conf # vi /usr/local/etc/raddb/clients.conf 在最后增加如下几行: client 172.18.5.88 { 增加认证体,填写OMA的ip地址 s…...
![](https://img-blog.csdnimg.cn/img_convert/3ec624a19b2d1f4f798a04ecb3961441.png)
做网站模板哪里买/电商代运营一般收多少服务费
本文使用「署名 4.0 国际 (CC BY 4.0)」许可协议,欢迎转载、或重新修改使用,但需要注明来源。 署名 4.0 国际 (CC BY 4.0)本文作者: 苏洋创建时间: 2019年08月05日 统计字数: 7024字 阅读时间: 15分钟阅读 本文链接: https://soulteary.com/2019/08/05/p…...
![](https://img-blog.csdnimg.cn/20200430150735507.png)
日照做网站公司/如何设置友情链接
目的 序 新的一年又到了,该跳槽的跳槽了,那跳槽是不是又得面试了,面试中总会问些用不到的。或者有些用到的,但是却当时没说出来。后悔万分,之前被这种问题居然问懵逼了。天天上班写代码,忘记了理论知识。今天补充一下…...