当前位置: 首页 > news >正文

[图神经网络]PyTorch简单实现一个GCN

Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )forward( )两个函数,PyTorch必须额外重构propagate( )message( )函数。

一、环境构建

        ①安装torch_geometric包。

pip install torch_geometric

        ②导入相关库

import torch
import torch.nn.functional as F
import torch.nn as nn
import torch_geometric.nn as pyg_nn
from torch_geometric.datasets import Planetoid

二、PyG图学习架构

        构建方法:首先继承MessagePassing类,接下来重写构造函数和以下三个方法:

message()      #构建消息传递
aggregate()    #将消息聚合到目标节点
update()       #更新消息节点

        1.构造函数

def __init__(self, aggr: Optional[str] = "add",flow: str = "source_to_target", node_dim: int = -2,decomposed_layers: int = 1):
参数内容
aggr消息聚合的方式,常见的方法:addmeanminmax
flow

消息传播的方向,source_to_target--从源节点到目标节点

                             target_to_source--从目标节点到源节点

node_dim传播的维度

        2.propagate函数

                该函数为消息传播的启动函数,调用此函数后会依次执行:messageaggregateupdate来完成消息的传递、聚合、更新

                该函数声明如下:

propagate(self, edge_index: Adj, size: Size = None, **kwargs)
参数说明
edge_index边索引
size邻接矩阵的尺寸,若为None则表示方阵
**kwargs额外参数

        3.message函数

                用于构建节点消息,传递给propagatetensor可以映射到中心节点和邻居节点,仅需在相应的变量名后加上_i(中心节点)或_j(邻居节点)即可。

self.propagate(edge_index, x=x):passdef message(self, x_i, x_j, edge_index_i):pass
x_i中心节点构成的特征向量组成的矩阵
x_j邻居节点构成的特征向量组成的矩阵
edge_index_i中心节点的索引

        4.aggregate函数

                消息聚合函数,用以聚合来自邻居的消息,常见的方法有add、sum、mean、max,可以通过super().__init__()中的参数aggr来设定

        5.update函数

                用于更新节点的消息

三、GCN图卷积网络

        GCN网络的原理可见:图卷积神经网络--GCN

        需要注意 torch_scatter无法使用pip install加载可以参见 torch_scatter安装

        1.加载数据集

from torch_geometric.datasets import Planetoiddataset = Planetoid(root='Cora', name='Cora')

                Cora数据集是一个根据科学论文之间相互引用关系构建的图(Graph)数据集合,论文合计7类,共2708篇论文(2708个节点),10556条边。

        2.定义GCN层

class GCNConv(MessagePassing):def __init__(self, in_channels, out_channels, add_self_loops=True, bias=True):super(GCNConv, self).__init__()self.add_self_loops = add_self_loopsself.edge_index = Noneself.linear = pyg_nn.dense.linear.Linear(in_channels, out_channels, weight_initializer='glorot')if bias:self.bias = nn.Parameter(torch.Tensor(out_channels, 1))self.bias = pyg_nn.inits.glorot(self.bias)else:self.register_parameter('bias', None)# 1.消息传递def message(self, x, edge_index):# 1.对所有节点进行新的空间映射x = self.linear(x) # [num_nodes, feature_size]# 2.添加偏置if self.bias != None:x += self.bias.flatten()# 3.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 4.获得度矩阵deg = degree(col, x.shape[0], x.dtype) # [num_nodes]# 5.度矩阵归一化deg_inv_sqrt = deg.pow(-0.5) # [num_nodes]# 6.计算sqrt(deg(i)) * sqrt(deg(j))norm = deg_inv_sqrt[row] * deg_inv_sqrt[col] # [num_nodes]# 7.返回所有边的映射x_j = x[row] # [E, feature_size]# 8.计算归一化后的节点特征x_j = norm.view(-1, 1) * x_j # [E, feature_size]return x_j# 2.消息聚合def aggregate(self, x_j, edge_index):# 1.返回source、target信息,对应边的起点和终点row, col = edge_index # [E]# 2.聚合邻居特征aggr_out = scatter(x_j, row, dim=0, reduce='sum') # [num_nodes, feature_size]return aggr_out# 3.节点更新def update(self, aggr_out):# 对于GCN没有这个阶段,所以直接返回return aggr_outdef forward(self, x, edge_index):# 2.添加自环信息,考虑自身信息if self.add_self_loops:edge_index, _ = add_self_loops(edge_index, num_nodes=x.shape[0]) # [2, E]return self.propagate(edge_index, x=x)

        3.定义GCN网络

class GCN(nn.Module):def __init__(self, num_node_features, num_classes):super(GCN, self).__init__()self.conv1 = GCNConv(num_node_features, 16)self.conv2 = GCNConv(16, num_classes)def forward(self, data):x, edge_index = data.x, data.edge_indexx = self.conv1(x, edge_index)x = F.relu(x)x = F.dropout(x, training=self.training)x = self.conv2(x, edge_index)return F.log_softmax(x, dim=1)

        4.模型调用

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # 设备
epochs = 200 # 学习轮数
lr = 0.0003 # 学习率
num_node_features = dataset.num_node_features # 每个节点的特征数
num_classes = dataset.num_classes # 每个节点的类别数
data = dataset[0].to(device) # Cora的一张图# 4.定义模型
model = GCN(num_node_features, num_classes).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr) # 优化器
loss_function = nn.NLLLoss() # 损失函数# 训练模式
model.train()for epoch in range(epochs):optimizer.zero_grad()pred = model(data)loss = loss_function(pred[data.train_mask], data.y[data.train_mask]) # 损失correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item() # epoch正确分类数目acc_train = correct_count_train / data.train_mask.sum().item() # epoch训练精度loss.backward()optimizer.step()if epoch % 20 == 0:print("【EPOCH: 】%s" % str(epoch + 1))print('训练损失为:{:.4f}'.format(loss.item()), '训练精度为:{:.4f}'.format(acc_train))print('【Finished Training!】')# 模型验证
model.eval()
pred = model(data)# 训练集(使用了掩码)
correct_count_train = pred.argmax(axis=1)[data.train_mask].eq(data.y[data.train_mask]).sum().item()
acc_train = correct_count_train / data.train_mask.sum().item()
loss_train = loss_function(pred[data.train_mask], data.y[data.train_mask]).item()# 测试集
correct_count_test = pred.argmax(axis=1)[data.test_mask].eq(data.y[data.test_mask]).sum().item()
acc_test = correct_count_test / data.test_mask.sum().item()
loss_test = loss_function(pred[data.test_mask], data.y[data.test_mask]).item()print('Train Accuracy: {:.4f}'.format(acc_train), 'Train Loss: {:.4f}'.format(loss_train))
print('Test  Accuracy: {:.4f}'.format(acc_test), 'Test  Loss: {:.4f}'.format(loss_test))

相关文章:

[图神经网络]PyTorch简单实现一个GCN

Pytorch自带一个PyG的图神经网络库,和构建卷积神经网络类似。不同于卷积神经网络仅需重构__init__( )和forward( )两个函数,PyTorch必须额外重构propagate( )和message( )函数。 一、环境构建 ①安装torch_geometric包。 pip install torch_geometric …...

Elasticsearch(黑马)

初识elasticsearch ​​. 安装elasticsearch 1.部署单点es 1.1.创建网络 因为我们还需要部署kibana容器,因此需要让es和kibana容器互联。这里先创建一个网络: docker network create es-net 1.2.加载镜像 这里我们采用elasticsearch的7.12.1版本的…...

oracle数据库调整字段类型

oracle数据库更改字段类型比较墨迹,因为如果该字段有值,是不允许直接更改字段类型的。另外oralce不支持在指定的某个字段后面新增一个字段,但是mysql数据可以向指定的字段后面新增一个字段。 mysql向指定字段后面新增一个字段: al…...

面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码)

面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 目录 面部表情识别2:Pytorch实现表情识别(含表情识别数据集和训练代码) 1.面部表情识别方法 2.面部表情识别数据集 (1)表情识别数据集说明 (2&…...

赛效:如何在线给图片加水印

学会给图片加水印是一个非常实用的技能,可以让你的图片更具保护性和个性化。说到加水印,很多人不知道怎么操作。其实,给图片加水印非常简单,不用下载任何程序,在线就能完成。今天,我将介绍如何使用改图宝在…...

动力节点杜老师Vue笔记——Vue程序初体验

一、Vue程序初体验 我们可以先不去了解Vue框架的发展历史、Vue框架有什么特点、Vue是谁开发的,这些对我们编写Vue程序起不到太大的作用,更何况现在说了一些特点之后,我们也没有办法彻底理解它,因此我们可以先学会用,使…...

ajax上传图片存入到指定的文件夹并回显

html代码&#xff1a; <!DOCTYPE html> <html lang"en"> <head><meta charset"UTF-8"><title>Title</title><script src"js/jquery-2.1.0.js"></script> </head> <body> <form…...

cesium加载cesiumlab切的影像切片和标准TMS瓦片的区别

1.加载cesiumlab切的影像 var labImg viewer.scene.imageryLayers.addImageryProvider( new Cesium.UrlTemplateImageryProvider({url:http://192.168.1.25:8080/DOMtms/{z}/{x}/{y}.png,fileExtension : "png"})); 2.标准TMS瓦片 var labImg viewer.scene.im…...

第二周P9-P22

文章目录第三章 系统总线3.1、总线的基本概念一、为什么要用总线二、什么是总线三、总线上信息的传送四、总线结构的计算机举例1、单总线结构框图2、面向CPU的双总线结构框图3、以存储器为中心的双总线结构图3.2、总线的分类1、片内总线2、系统总线3、通信走线3.3、总线特性及性…...

java反射

文章目录何为反射&#xff1f;反射的应用场景了解么&#xff1f;谈谈反射机制的优缺点优点缺点反射实战获取 Class 对象的四种方式1. 知道具体类的情况下可以使用TargetObject.class&#xff1a;2. 通过 Class.forName()传入类的全路径获取&#xff1a;3. 通过对象实例instance…...

307 Temporary Redirect 解决办法(临时重定向)

背景&#xff1a;java后台服务请求python服务端 java服务报错&#xff1a;Unexpected response status&#xff1a;307 python服务端报错&#xff1a;307 Temporary Redirect 解决&#xff1a;查了好久找不到什么原因&#xff0c;请求路径问题 请求url&#xff1a;http//:w…...

SpringBoot案例

SpringBoot案例5&#xff0c;案例5.1 创建工程5.2 代码拷贝5.3 配置文件5.4 静态资源目标 基于SpringBoot的完成SSM整合项目开发 5&#xff0c;案例 SpringBoot 到这就已经学习完毕&#xff0c;接下来我们将学习 SSM 时做的三大框架整合的案例用 SpringBoot 来实现一下。我们完…...

Android 10.0 系统framework发送悬浮通知的流程分析

1.前言 在android10.0rom定制化开发中,在原生系统的systemui中,状态栏通知,和闹钟,wifi等悬浮通知也是很重要的, 悬浮通知也是系统通知的一种,也是在frameworks中发送出来的通知,接下来就分析下10.0中的悬浮通知的发送 流程,然后就可以实现自己自定义悬浮通知的相关功…...

傅里叶谱方法-傅里叶谱方法求解二维浅水方程组和二维粘性 Burgers 方程及其Matlab程序实现

3.3.2 二维浅水方程组 二维浅水方程组是描述水波运动的基本方程之一。它主要用于描述近岸浅水区域内的波浪、潮汐等水动力学现象。这个方程组由两个偏微分方程组成&#xff0c;一个是质量守恒方程&#xff0c;另一个是动量守恒方程。浅水方程描述了具有自由表面、密度均匀、深…...

算法训练营 - 广度优先BFS

目录 从层序遍历开始 N 叉树的层序遍历 经典BFS最短路模板 经典C queue 数组模拟队列 打印路径 示例1.bfs查找所有连接方块 Cqueue版 数组模拟队列 示例2.从多个位置同时开始BFS 示例3.抽象最短路类&#xff08;作图关键&#xff09; 示例4.跨过障碍的最短路 从层序遍历…...

​​​​​​​判断两个字符串是否匹配(1个通配符代表一个字符)

目录 判断两个字符串是否匹配(1个通配符代表一个字符) 程序设计 程序分析...

用css画一个csdn程序猿

效果如下&#xff1a; 头部 我们先来拆解一下&#xff0c;程序猿的结构 两只耳朵和头是圆形组成的&#xff0c;耳朵内的红色部分也是圆形 先画头部&#xff0c;利用圆角实现头部形状 借助工具来快速实现圆角效果 https://9elements.github.io/fancy-border-radius/ <div…...

Java多线程编程—wait/notify机制

文章目录1. 不使用wait/notify机制通信的缺点2. 什么是wait/notify机制3. wait/notify机制原理4. wait/notify方法的基本用法5. 线程状态的切换6. interrupt()遇到方法wait()7. notify/notifyAll方法8. wait(long)介绍9. 生产者/消费者模式10. 管道机制11. 利用wait/notify实现…...

Three.js教程:旋转动画、requestAnimationFrame周期性渲染

推荐&#xff1a;将NSDT场景编辑器加入你3D工具链其他工具系列&#xff1a;NSDT简石数字孪生基于WebGL技术开发在线游戏、商品展示、室内漫游往往都会涉及到动画&#xff0c;初步了解three.js可以做什么&#xff0c;深入讲解three.js动画之前&#xff0c;本节课先制作一个简单的…...

租车自驾app开发有什么作用?租车便利出行APP开发

在线租车APP有哪些优势&#xff0c;租车APP开发的基本功能&#xff0c;租车自驾app开发有什么作用?租车便利出行APP开发&#xff0c;租车服务平台小程序有哪些功能&#xff0c;租车软件开发需要多少钱&#xff0c;租车app都有哪些&#xff0c;租车平台定制开发&#xff0c;租车…...

未来机器人的大脑:如何用神经网络模拟器实现更智能的决策?

编辑&#xff1a;陈萍萍的公主一点人工一点智能 未来机器人的大脑&#xff1a;如何用神经网络模拟器实现更智能的决策&#xff1f;RWM通过双自回归机制有效解决了复合误差、部分可观测性和随机动力学等关键挑战&#xff0c;在不依赖领域特定归纳偏见的条件下实现了卓越的预测准…...

Chapter03-Authentication vulnerabilities

文章目录 1. 身份验证简介1.1 What is authentication1.2 difference between authentication and authorization1.3 身份验证机制失效的原因1.4 身份验证机制失效的影响 2. 基于登录功能的漏洞2.1 密码爆破2.2 用户名枚举2.3 有缺陷的暴力破解防护2.3.1 如果用户登录尝试失败次…...

python打卡day49

知识点回顾&#xff1a; 通道注意力模块复习空间注意力模块CBAM的定义 作业&#xff1a;尝试对今天的模型检查参数数目&#xff0c;并用tensorboard查看训练过程 import torch import torch.nn as nn# 定义通道注意力 class ChannelAttention(nn.Module):def __init__(self,…...

Xshell远程连接Kali(默认 | 私钥)Note版

前言:xshell远程连接&#xff0c;私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...

从WWDC看苹果产品发展的规律

WWDC 是苹果公司一年一度面向全球开发者的盛会&#xff0c;其主题演讲展现了苹果在产品设计、技术路线、用户体验和生态系统构建上的核心理念与演进脉络。我们借助 ChatGPT Deep Research 工具&#xff0c;对过去十年 WWDC 主题演讲内容进行了系统化分析&#xff0c;形成了这份…...

线程与协程

1. 线程与协程 1.1. “函数调用级别”的切换、上下文切换 1. 函数调用级别的切换 “函数调用级别的切换”是指&#xff1a;像函数调用/返回一样轻量地完成任务切换。 举例说明&#xff1a; 当你在程序中写一个函数调用&#xff1a; funcA() 然后 funcA 执行完后返回&…...

家政维修平台实战20:权限设计

目录 1 获取工人信息2 搭建工人入口3 权限判断总结 目前我们已经搭建好了基础的用户体系&#xff0c;主要是分成几个表&#xff0c;用户表我们是记录用户的基础信息&#xff0c;包括手机、昵称、头像。而工人和员工各有各的表。那么就有一个问题&#xff0c;不同的角色&#xf…...

macOS多出来了:Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用

文章目录 问题现象问题原因解决办法 问题现象 macOS启动台&#xff08;Launchpad&#xff09;多出来了&#xff1a;Google云端硬盘、YouTube、表格、幻灯片、Gmail、Google文档等应用。 问题原因 很明显&#xff0c;都是Google家的办公全家桶。这些应用并不是通过独立安装的…...

2025 后端自学UNIAPP【项目实战:旅游项目】6、我的收藏页面

代码框架视图 1、先添加一个获取收藏景点的列表请求 【在文件my_api.js文件中添加】 // 引入公共的请求封装 import http from ./my_http.js// 登录接口&#xff08;适配服务端返回 Token&#xff09; export const login async (code, avatar) > {const res await http…...

GC1808高性能24位立体声音频ADC芯片解析

1. 芯片概述 GC1808是一款24位立体声音频模数转换器&#xff08;ADC&#xff09;&#xff0c;支持8kHz~96kHz采样率&#xff0c;集成Δ-Σ调制器、数字抗混叠滤波器和高通滤波器&#xff0c;适用于高保真音频采集场景。 2. 核心特性 高精度&#xff1a;24位分辨率&#xff0c…...