【Graph Net学习】GNN/GCN代码实战
一、简介
GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。
二、代码
import time
import random
import os
import numpy as np
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Moduleimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport scipy.sparse as sp#配置项
class configs():def __init__(self):# Dataself.data_path = r'E:\code\Graph\data'self.save_model_dir = r'\code\Graph'self.model_name = r'GCN' #GNN/GCNself.seed = 2023self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.batch_size = 64self.epoch = 200self.in_features = 1433 #core ~ feature:1433self.hidden_features = 16 # 隐层数量self.output_features = 8 # core~paper-point~ 8类self.learning_rate = 0.01self.dropout = 0.5self.istrain = Trueself.istest = Truecfg = configs()def seed_everything(seed=2023):random.seed(seed)os.environ['PYTHONHASHSEED']=str(seed)np.random.seed(seed)torch.manual_seed(seed)seed_everything(seed = cfg.seed)#数据
class Graph_Data_Loader():def __init__(self):self.adj, self.features, self.labels, self.idx_train, self.idx_val, self.idx_test = self.load_data()self.adj = self.adj.to(cfg.device)self.features = self.features.to(cfg.device)self.labels = self.labels.to(cfg.device)self.idx_train = self.idx_train.to(cfg.device)self.idx_val = self.idx_val.to(cfg.device)self.idx_test = self.idx_test.to(cfg.device)def load_data(self,path=cfg.data_path, dataset="cora"):"""Load citation network dataset (cora only for now)"""print('Loading {} dataset...'.format(dataset))idx_features_labels = np.genfromtxt(os.path.join(path,dataset,dataset+'.content'),dtype=np.dtype(str))features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)labels = self.encode_onehot(idx_features_labels[:, -1])# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)idx_map = {j: i for i, j in enumerate(idx)}edges_unordered = np.genfromtxt(os.path.join(path,dataset,dataset+'.cites'),dtype=np.int32)edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),shape=(labels.shape[0], labels.shape[0]),dtype=np.float32)# build symmetric adjacency matrixadj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = self.normalize(features)adj = self.normalize(adj + sp.eye(adj.shape[0]))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)features = torch.FloatTensor(np.array(features.todense()))labels = torch.LongTensor(np.where(labels)[1])adj = self.sparse_mx_to_torch_sparse_tensor(adj)idx_train = torch.LongTensor(idx_train)idx_val = torch.LongTensor(idx_val)idx_test = torch.LongTensor(idx_test)return adj, features, labels, idx_train, idx_val, idx_testdef encode_onehot(self,labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef normalize(self,mx):"""Row-normalize sparse matrix"""rowsum = np.array(mx.sum(1))r_inv = np.power(rowsum, -1).flatten()r_inv[np.isinf(r_inv)] = 0.r_mat_inv = sp.diags(r_inv)mx = r_mat_inv.dot(mx)return mxdef sparse_mx_to_torch_sparse_tensor(self,sparse_mx):"""Convert a scipy sparse matrix to a torch sparse tensor."""sparse_mx = sparse_mx.tocoo().astype(np.float32)indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))values = torch.from_numpy(sparse_mx.data)shape = torch.Size(sparse_mx.shape)return torch.sparse.FloatTensor(indices, values, shape)#精度评价指标
def accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)#模型
#01-GNN
class GNNLayer(nn.Module):def __init__(self, in_features, output_features):super(GNNLayer, self).__init__()self.linear = nn.Linear(in_features, output_features)def forward(self, adj_matrix, features):hidden_features = torch.matmul(adj_matrix, features) # GNN公式:H' = A * Hhidden_features = self.linear(hidden_features) # 使用线性变换hidden_features = F.relu(hidden_features) # 使用ReLU作为激活函数return hidden_features
class GNN(nn.Module):def __init__(self, in_features, hidden_features, output_features, num_layers=2):super(GNN, self).__init__()#输入维度in_features、隐藏层维度hidden_features、输出维度output_features、GNN的层数num_layersself.layers = nn.ModuleList([GNNLayer(in_features, hidden_features) if i == 0 else GNNLayer(hidden_features, hidden_features) for i inrange(num_layers)])self.output_layer = nn.Linear(hidden_features, output_features)def forward(self, adj_matrix, features):hidden_features = featuresfor layer in self.layers:hidden_features = layer(adj_matrix, hidden_features)output = self.output_layer(hidden_features)return F.log_softmax(output,dim=1)#02-GCN
class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return outputdef __repr__(self):return self.__class__.__name__ + ' (' \+ str(self.in_features) + ' -> ' \+ str(self.out_features) + ')'class GCN(nn.Module):def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):super(GCN, self).__init__()self.gc1 = GraphConvolution(in_features, hidden_features)self.gc2 = GraphConvolution(hidden_features, output_features)self.dropout = dropoutdef forward(self, adj_matrix, features):x = F.relu(self.gc1(features, adj_matrix))x = F.dropout(x, self.dropout, training=self.training)x = self.gc2(x, adj_matrix)return F.log_softmax(x, dim=1)class graph_run():def train(self):t = time.time()#Create Train Processingall_data = Graph_Data_Loader()#创建一个模型model = eval(cfg.model_name)(in_features=cfg.in_features,hidden_features=cfg.hidden_features,output_features=cfg.output_features).to(cfg.device)optimizer = optim.Adam(model.parameters(),lr=cfg.learning_rate, weight_decay=5e-4)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()output = model(all_data.adj, all_data.features)loss_train = F.nll_loss(output[all_data.idx_train], all_data.labels[all_data.idx_train])acc_train = accuracy(output[all_data.idx_train], all_data.labels[all_data.idx_train])loss_train.backward()optimizer.step()loss_val = F.nll_loss(output[all_data.idx_val], all_data.labels[all_data.idx_val])acc_val = accuracy(output[all_data.idx_val], all_data.labels[all_data.idx_val])print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss_train.item()),'acc_train: {:.4f}'.format(acc_train.item()),'loss_val: {:.4f}'.format(loss_val.item()),'acc_val: {:.4f}'.format(acc_val.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth')) # 模型保存def infer(self):#Create Test Processingall_data = Graph_Data_Loader()model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()output = model(all_data.adj,all_data.features)loss_test = F.nll_loss(output[all_data.idx_test], all_data.labels[all_data.idx_test])acc_test = accuracy(output[all_data.idx_test], all_data.labels[all_data.idx_test])print("Test set results:","loss= {:.4f}".format(loss_test.item()),"accuracy= {:.4f}".format(acc_test.item()))if __name__ == '__main__':mygraph = graph_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()
三、结果与讨论
需要从网上下载cora数据集,数据组织形式如下图。

测了下Params和GFLOPs,还是比较大的,发现若作为一个Net的Block还是需要优化的哈哈~
| Model | Params | GFLOPs |
| GNN | 23.352K | 126.258M |
| Model | Cora(/train/val/test) |
| GNN | 1.0000/0.7800/0.7620 |
| GCN | 0.9714/0.7767/0.8290 |
四、展望
未来可以考虑用PyG(PyTorch Geometric),毕竟PyG实现GAT等图网络、图的数据组织、加载会更加方便。Graph Net通常用可以用于属性数据的embedding模式,将属性数据可以作为一种补充特征加入Net去训练,看能不能发挥效能。
相关文章:
【Graph Net学习】GNN/GCN代码实战
一、简介 GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。 二、代码 …...
RocketMQ 发送顺序消息
文章目录 顺序消息应用场景消息组(MessageGroup)顺序性生产的顺序性MQ 存储的顺序性消费的顺序性 rocketmq-client-java 示例(gRPC 协议)1. 创建 FIFO 主题生产者代码消费者代码解决办法解决后执行结果 rocketmq-client 示例&…...
【面试经典150 | 双指针】判断子序列
文章目录 写在前面Tag题目来源题目解题解题思路方法一:双指针方法二:动态规划 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对…...
自动驾驶信息安全方案
目录 1. 修订历史... 3 2. 概述... 4 2.1 目的... 4 2.2 适用范围... 4 2.3 参考文档... 4 2.4 术语和缩写... 4 3. 安全分析... 5 4. 总体设计... 6 4.1 ACU的安全防护... 7 4.1.1 系统安全引导... 7 4.1.2 密钥安全存储... 8 4.1.3 应…...
【云原生】kubernetes中pod(最小的资源管理组件)
目录 前言 一、pod 1.1pause容器使得Pod中的所有容器可以共享两种资源: 1.2 通常把Pod分为两类 1.2.1自主式Pod 1.2.2控制器管理的Pod 1.3 Pod 容器的分类 1.3.1基础容器(infrastructure container) 1.3.2初始化容器(initc…...
[DB]数据库--lowdb
[DB]数据库--lowdb lowdb基本应用获取数据数据变更写入文件 lodash的使用获取数据lodash方法使用数据变更写入文件 lowdb lowdb ,是一个基于文件存储的非关系型数据库 基于loadsh的轻量级数据库 可用于在json中存储数据,大小一般为0~200M的json文件 方便简单的数…...
Kotlin | 在for、forEach循环中正确的使用break、continue
文章目录 for循环中使用break、continueLabel标签forEach中如何退出循环资料 Kotlin 有三种结构化跳转表达式: return:默认从最直接包围它的函数或者匿名函数返回。break:终止最直接包围它的循环。continue:继续下一次最直接包围…...
【C++】详解std::mutex
2023年9月11日,周一中午开始 2023年9月11日,周一晚上23:25写完 目录 概述头文件std::mutex类的成员类型方法没有std::mutex会产生什么问题问题一:数据竞争问题二:不一致lock和unlock死锁 概述 std::mutex是C标准库中…...
Matlab图像处理-Lab模型
Lab模型 Lab模型是由CIE(国际照明委员会)制定的一种彩色模型。该模型与设备无关,弥补了RGB模型和CMYK模型必须依赖于设备颜色特性的不足; 另外,自然界中的任何颜色都可以在Lab空间中表现出来,也就是说RGB和…...
分布式ETL工具Sqoop实践
Mysql数据准备 1、在node02节点登录Mysql。 mysql -uroot -proot2、新建数据库testdb。 create database testdb;3、新建数据表ts。 use testdb; create table ts(id int, name varchar(10), age int, sex char(1));4、向表中插入数据。 insert into ts values(10001,张三…...
展会预告 | 图扑邀您共聚 IOTE 国际物联网展·深圳站
参展时间:9 月 20 日- 22 日 图扑展位:9 号馆 9B 35-1 参展地址:深圳国际会展中心(宝安新馆) IOTE 2023 第二十届国际物联网展深圳站,将于 9 月 20 日- 22 日在深圳国际会展中心(宝安…...
如何下载安装 WampServer 并结合 cpolar 内网穿透,轻松实现对本地服务的公网访问
文章目录 前言1.WampServer下载安装2.WampServer启动3.安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4.固定公网地址访问 前言 Wamp 是一个 Windows系统下的 Apache PHP Mysql 集成安装环境,是一组常用来…...
iOS添加Mapbox地图库
配置凭据 注册并导航到Account页面。你将需要: 公共访问令牌: 从帐户的tokens页面,你可以复制默认的公共令牌或单击"create a token"按钮来创建新的公共令牌。 带有Downloads:Read范围的秘密访问令牌: 从你帐户的t…...
destoon根据目录下的html文件生成地图索引
因为项目需要,destoon根据目录下的html文件生成地图索引,操作方法,代码如下: <?php $new_array array(); function loopDir($dir,&$new_array,$modurl) {$handle opendir($dir);header("Content-Type:text/xml&qu…...
gRPC之gRPC流
1、gRPC流 从其名称可以理解,流就是持续不断的传输。有一些业务场景请求或者响应的数据量比较大,不适合使用普通的 RPC 调用通过一次请求-响应处理,一方面是考虑数据量大对请求响应时间的影响,另一方面业务场景的设计不一 定需…...
Kafka Shell命令交互
Kafka提供了一个命令行工具,用于管理和与Kafka集群交互。这个命令行工具通常称为Kafka Shell,它允许您执行各种操作,如创建主题、发送和消费消息、查看主题列表等。 以下是一些常用的Kafka Shell命令: 创建主题(Topic): kafka-topics.sh --create --topic my-topic --pa…...
什么是回归测试?
什么是回归测试? 回归测试被定义为一种软件测试类型,以确认最近的程序或代码更改未对现有功能产生不利影响。 回归测试只不过是全部或部分选择已执行的测试用例,然后重新执行以确保现有功能正常运行。 进行此测试是为了确保新代码更改不会…...
ZTMap是如何在相关政策引导下让建筑更加智慧化的?
近几年随着智慧楼宇概念的深入,尤其是在“十四五规划”“新基建”“数字经济”等相关战略和政策的引导下,智慧楼宇也迎来了快速发展期,对推动智慧城市系统的建设越来越重要。那么究竟什么是智慧楼宇呢?智慧楼宇其实就是整合楼宇内…...
Python:函数和代码复用
嗨喽,大家好呀~这里是爱看美女的茜茜呐 👇 👇 👇 更多精彩机密、教程,尽在下方,赶紧点击了解吧~ python源码、视频教程、插件安装教程、资料我都准备好了,直接在文末名片自取就可 1、关于递归函…...
three.js——模型对象的使用材质和方法
模型对象的使用材质和方法 前言效果图1、旋转、缩放、平移,居中的使用1.1 旋转rotation(.rotateX()、.rotateY()、.rotateZ())1.2缩放.scale()1.3平移.translate()1.4居中.center() 2、材质属性.wireframe 前言 BufferGeometry通过.scale()、…...
web vue 项目 Docker化部署
Web 项目 Docker 化部署详细教程 目录 Web 项目 Docker 化部署概述Dockerfile 详解 构建阶段生产阶段 构建和运行 Docker 镜像 1. Web 项目 Docker 化部署概述 Docker 化部署的主要步骤分为以下几个阶段: 构建阶段(Build Stage):…...
React hook之useRef
React useRef 详解 useRef 是 React 提供的一个 Hook,用于在函数组件中创建可变的引用对象。它在 React 开发中有多种重要用途,下面我将全面详细地介绍它的特性和用法。 基本概念 1. 创建 ref const refContainer useRef(initialValue);initialValu…...
Reasoning over Uncertain Text by Generative Large Language Models
https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829https://ojs.aaai.org/index.php/AAAI/article/view/34674/36829 1. 概述 文本中的不确定性在许多语境中传达,从日常对话到特定领域的文档(例如医学文档)(Heritage 2013;Landmark、Gulbrandsen 和 Svenevei…...
关于easyexcel动态下拉选问题处理
前些日子突然碰到一个问题,说是客户的导入文件模版想支持部分导入内容的下拉选,于是我就找了easyexcel官网寻找解决方案,并没有找到合适的方案,没办法只能自己动手并分享出来,针对Java生成Excel下拉菜单时因选项过多导…...
【LeetCode】算法详解#6 ---除自身以外数组的乘积
1.题目介绍 给定一个整数数组 nums,返回 数组 answer ,其中 answer[i] 等于 nums 中除 nums[i] 之外其余各元素的乘积 。 题目数据 保证 数组 nums之中任意元素的全部前缀元素和后缀的乘积都在 32 位 整数范围内。 请 不要使用除法,且在 O…...
系统掌握PyTorch:图解张量、Autograd、DataLoader、nn.Module与实战模型
本文较长,建议点赞收藏,以免遗失。更多AI大模型应用开发学习视频及资料,尽在聚客AI学院。 本文通过代码驱动的方式,系统讲解PyTorch核心概念和实战技巧,涵盖张量操作、自动微分、数据加载、模型构建和训练全流程&#…...
Xela矩阵三轴触觉传感器的工作原理解析与应用场景
Xela矩阵三轴触觉传感器通过先进技术模拟人类触觉感知,帮助设备实现精确的力测量与位移监测。其核心功能基于磁性三维力测量与空间位移测量,能够捕捉多维触觉信息。该传感器的设计不仅提升了触觉感知的精度,还为机器人、医疗设备和制造业的智…...
云原生周刊:k0s 成为 CNCF 沙箱项目
开源项目推荐 HAMi HAMi(原名 k8s‑vGPU‑scheduler)是一款 CNCF Sandbox 级别的开源 K8s 中间件,通过虚拟化 GPU/NPU 等异构设备并支持内存、计算核心时间片隔离及共享调度,为容器提供统一接口,实现细粒度资源配额…...
DAY 26 函数专题1
函数定义与参数知识点回顾:1. 函数的定义2. 变量作用域:局部变量和全局变量3. 函数的参数类型:位置参数、默认参数、不定参数4. 传递参数的手段:关键词参数5 题目1:计算圆的面积 任务: 编写一…...
【Ftrace 专栏】Ftrace 参考博文
ftrace、perf、bcc、bpftrace、ply、simple_perf的使用Ftrace 基本用法Linux 利用 ftrace 分析内核调用如何利用ftrace精确跟踪特定进程调度信息使用 ftrace 进行追踪延迟Linux-培训笔记-ftracehttps://www.kernel.org/doc/html/v4.18/trace/events.htmlhttps://blog.csdn.net/…...
