【Graph Net学习】GNN/GCN代码实战
一、简介
GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。
二、代码
import time
import random
import os
import numpy as np
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Moduleimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport scipy.sparse as sp#配置项
class configs():def __init__(self):# Dataself.data_path = r'E:\code\Graph\data'self.save_model_dir = r'\code\Graph'self.model_name = r'GCN' #GNN/GCNself.seed = 2023self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.batch_size = 64self.epoch = 200self.in_features = 1433 #core ~ feature:1433self.hidden_features = 16 # 隐层数量self.output_features = 8 # core~paper-point~ 8类self.learning_rate = 0.01self.dropout = 0.5self.istrain = Trueself.istest = Truecfg = configs()def seed_everything(seed=2023):random.seed(seed)os.environ['PYTHONHASHSEED']=str(seed)np.random.seed(seed)torch.manual_seed(seed)seed_everything(seed = cfg.seed)#数据
class Graph_Data_Loader():def __init__(self):self.adj, self.features, self.labels, self.idx_train, self.idx_val, self.idx_test = self.load_data()self.adj = self.adj.to(cfg.device)self.features = self.features.to(cfg.device)self.labels = self.labels.to(cfg.device)self.idx_train = self.idx_train.to(cfg.device)self.idx_val = self.idx_val.to(cfg.device)self.idx_test = self.idx_test.to(cfg.device)def load_data(self,path=cfg.data_path, dataset="cora"):"""Load citation network dataset (cora only for now)"""print('Loading {} dataset...'.format(dataset))idx_features_labels = np.genfromtxt(os.path.join(path,dataset,dataset+'.content'),dtype=np.dtype(str))features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)labels = self.encode_onehot(idx_features_labels[:, -1])# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)idx_map = {j: i for i, j in enumerate(idx)}edges_unordered = np.genfromtxt(os.path.join(path,dataset,dataset+'.cites'),dtype=np.int32)edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),shape=(labels.shape[0], labels.shape[0]),dtype=np.float32)# build symmetric adjacency matrixadj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = self.normalize(features)adj = self.normalize(adj + sp.eye(adj.shape[0]))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)features = torch.FloatTensor(np.array(features.todense()))labels = torch.LongTensor(np.where(labels)[1])adj = self.sparse_mx_to_torch_sparse_tensor(adj)idx_train = torch.LongTensor(idx_train)idx_val = torch.LongTensor(idx_val)idx_test = torch.LongTensor(idx_test)return adj, features, labels, idx_train, idx_val, idx_testdef encode_onehot(self,labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef normalize(self,mx):"""Row-normalize sparse matrix"""rowsum = np.array(mx.sum(1))r_inv = np.power(rowsum, -1).flatten()r_inv[np.isinf(r_inv)] = 0.r_mat_inv = sp.diags(r_inv)mx = r_mat_inv.dot(mx)return mxdef sparse_mx_to_torch_sparse_tensor(self,sparse_mx):"""Convert a scipy sparse matrix to a torch sparse tensor."""sparse_mx = sparse_mx.tocoo().astype(np.float32)indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))values = torch.from_numpy(sparse_mx.data)shape = torch.Size(sparse_mx.shape)return torch.sparse.FloatTensor(indices, values, shape)#精度评价指标
def accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)#模型
#01-GNN
class GNNLayer(nn.Module):def __init__(self, in_features, output_features):super(GNNLayer, self).__init__()self.linear = nn.Linear(in_features, output_features)def forward(self, adj_matrix, features):hidden_features = torch.matmul(adj_matrix, features) # GNN公式:H' = A * Hhidden_features = self.linear(hidden_features) # 使用线性变换hidden_features = F.relu(hidden_features) # 使用ReLU作为激活函数return hidden_features
class GNN(nn.Module):def __init__(self, in_features, hidden_features, output_features, num_layers=2):super(GNN, self).__init__()#输入维度in_features、隐藏层维度hidden_features、输出维度output_features、GNN的层数num_layersself.layers = nn.ModuleList([GNNLayer(in_features, hidden_features) if i == 0 else GNNLayer(hidden_features, hidden_features) for i inrange(num_layers)])self.output_layer = nn.Linear(hidden_features, output_features)def forward(self, adj_matrix, features):hidden_features = featuresfor layer in self.layers:hidden_features = layer(adj_matrix, hidden_features)output = self.output_layer(hidden_features)return F.log_softmax(output,dim=1)#02-GCN
class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return outputdef __repr__(self):return self.__class__.__name__ + ' (' \+ str(self.in_features) + ' -> ' \+ str(self.out_features) + ')'class GCN(nn.Module):def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):super(GCN, self).__init__()self.gc1 = GraphConvolution(in_features, hidden_features)self.gc2 = GraphConvolution(hidden_features, output_features)self.dropout = dropoutdef forward(self, adj_matrix, features):x = F.relu(self.gc1(features, adj_matrix))x = F.dropout(x, self.dropout, training=self.training)x = self.gc2(x, adj_matrix)return F.log_softmax(x, dim=1)class graph_run():def train(self):t = time.time()#Create Train Processingall_data = Graph_Data_Loader()#创建一个模型model = eval(cfg.model_name)(in_features=cfg.in_features,hidden_features=cfg.hidden_features,output_features=cfg.output_features).to(cfg.device)optimizer = optim.Adam(model.parameters(),lr=cfg.learning_rate, weight_decay=5e-4)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()output = model(all_data.adj, all_data.features)loss_train = F.nll_loss(output[all_data.idx_train], all_data.labels[all_data.idx_train])acc_train = accuracy(output[all_data.idx_train], all_data.labels[all_data.idx_train])loss_train.backward()optimizer.step()loss_val = F.nll_loss(output[all_data.idx_val], all_data.labels[all_data.idx_val])acc_val = accuracy(output[all_data.idx_val], all_data.labels[all_data.idx_val])print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss_train.item()),'acc_train: {:.4f}'.format(acc_train.item()),'loss_val: {:.4f}'.format(loss_val.item()),'acc_val: {:.4f}'.format(acc_val.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth')) # 模型保存def infer(self):#Create Test Processingall_data = Graph_Data_Loader()model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()output = model(all_data.adj,all_data.features)loss_test = F.nll_loss(output[all_data.idx_test], all_data.labels[all_data.idx_test])acc_test = accuracy(output[all_data.idx_test], all_data.labels[all_data.idx_test])print("Test set results:","loss= {:.4f}".format(loss_test.item()),"accuracy= {:.4f}".format(acc_test.item()))if __name__ == '__main__':mygraph = graph_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()
三、结果与讨论
需要从网上下载cora数据集,数据组织形式如下图。

测了下Params和GFLOPs,还是比较大的,发现若作为一个Net的Block还是需要优化的哈哈~
| Model | Params | GFLOPs |
| GNN | 23.352K | 126.258M |
| Model | Cora(/train/val/test) |
| GNN | 1.0000/0.7800/0.7620 |
| GCN | 0.9714/0.7767/0.8290 |
四、展望
未来可以考虑用PyG(PyTorch Geometric),毕竟PyG实现GAT等图网络、图的数据组织、加载会更加方便。Graph Net通常用可以用于属性数据的embedding模式,将属性数据可以作为一种补充特征加入Net去训练,看能不能发挥效能。
相关文章:
【Graph Net学习】GNN/GCN代码实战
一、简介 GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。 二、代码 …...
RocketMQ 发送顺序消息
文章目录 顺序消息应用场景消息组(MessageGroup)顺序性生产的顺序性MQ 存储的顺序性消费的顺序性 rocketmq-client-java 示例(gRPC 协议)1. 创建 FIFO 主题生产者代码消费者代码解决办法解决后执行结果 rocketmq-client 示例&…...
【面试经典150 | 双指针】判断子序列
文章目录 写在前面Tag题目来源题目解题解题思路方法一:双指针方法二:动态规划 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对…...
自动驾驶信息安全方案
目录 1. 修订历史... 3 2. 概述... 4 2.1 目的... 4 2.2 适用范围... 4 2.3 参考文档... 4 2.4 术语和缩写... 4 3. 安全分析... 5 4. 总体设计... 6 4.1 ACU的安全防护... 7 4.1.1 系统安全引导... 7 4.1.2 密钥安全存储... 8 4.1.3 应…...
【云原生】kubernetes中pod(最小的资源管理组件)
目录 前言 一、pod 1.1pause容器使得Pod中的所有容器可以共享两种资源: 1.2 通常把Pod分为两类 1.2.1自主式Pod 1.2.2控制器管理的Pod 1.3 Pod 容器的分类 1.3.1基础容器(infrastructure container) 1.3.2初始化容器(initc…...
[DB]数据库--lowdb
[DB]数据库--lowdb lowdb基本应用获取数据数据变更写入文件 lodash的使用获取数据lodash方法使用数据变更写入文件 lowdb lowdb ,是一个基于文件存储的非关系型数据库 基于loadsh的轻量级数据库 可用于在json中存储数据,大小一般为0~200M的json文件 方便简单的数…...
Kotlin | 在for、forEach循环中正确的使用break、continue
文章目录 for循环中使用break、continueLabel标签forEach中如何退出循环资料 Kotlin 有三种结构化跳转表达式: return:默认从最直接包围它的函数或者匿名函数返回。break:终止最直接包围它的循环。continue:继续下一次最直接包围…...
【C++】详解std::mutex
2023年9月11日,周一中午开始 2023年9月11日,周一晚上23:25写完 目录 概述头文件std::mutex类的成员类型方法没有std::mutex会产生什么问题问题一:数据竞争问题二:不一致lock和unlock死锁 概述 std::mutex是C标准库中…...
Matlab图像处理-Lab模型
Lab模型 Lab模型是由CIE(国际照明委员会)制定的一种彩色模型。该模型与设备无关,弥补了RGB模型和CMYK模型必须依赖于设备颜色特性的不足; 另外,自然界中的任何颜色都可以在Lab空间中表现出来,也就是说RGB和…...
分布式ETL工具Sqoop实践
Mysql数据准备 1、在node02节点登录Mysql。 mysql -uroot -proot2、新建数据库testdb。 create database testdb;3、新建数据表ts。 use testdb; create table ts(id int, name varchar(10), age int, sex char(1));4、向表中插入数据。 insert into ts values(10001,张三…...
展会预告 | 图扑邀您共聚 IOTE 国际物联网展·深圳站
参展时间:9 月 20 日- 22 日 图扑展位:9 号馆 9B 35-1 参展地址:深圳国际会展中心(宝安新馆) IOTE 2023 第二十届国际物联网展深圳站,将于 9 月 20 日- 22 日在深圳国际会展中心(宝安…...
如何下载安装 WampServer 并结合 cpolar 内网穿透,轻松实现对本地服务的公网访问
文章目录 前言1.WampServer下载安装2.WampServer启动3.安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4.固定公网地址访问 前言 Wamp 是一个 Windows系统下的 Apache PHP Mysql 集成安装环境,是一组常用来…...
iOS添加Mapbox地图库
配置凭据 注册并导航到Account页面。你将需要: 公共访问令牌: 从帐户的tokens页面,你可以复制默认的公共令牌或单击"create a token"按钮来创建新的公共令牌。 带有Downloads:Read范围的秘密访问令牌: 从你帐户的t…...
destoon根据目录下的html文件生成地图索引
因为项目需要,destoon根据目录下的html文件生成地图索引,操作方法,代码如下: <?php $new_array array(); function loopDir($dir,&$new_array,$modurl) {$handle opendir($dir);header("Content-Type:text/xml&qu…...
gRPC之gRPC流
1、gRPC流 从其名称可以理解,流就是持续不断的传输。有一些业务场景请求或者响应的数据量比较大,不适合使用普通的 RPC 调用通过一次请求-响应处理,一方面是考虑数据量大对请求响应时间的影响,另一方面业务场景的设计不一 定需…...
Kafka Shell命令交互
Kafka提供了一个命令行工具,用于管理和与Kafka集群交互。这个命令行工具通常称为Kafka Shell,它允许您执行各种操作,如创建主题、发送和消费消息、查看主题列表等。 以下是一些常用的Kafka Shell命令: 创建主题(Topic): kafka-topics.sh --create --topic my-topic --pa…...
什么是回归测试?
什么是回归测试? 回归测试被定义为一种软件测试类型,以确认最近的程序或代码更改未对现有功能产生不利影响。 回归测试只不过是全部或部分选择已执行的测试用例,然后重新执行以确保现有功能正常运行。 进行此测试是为了确保新代码更改不会…...
ZTMap是如何在相关政策引导下让建筑更加智慧化的?
近几年随着智慧楼宇概念的深入,尤其是在“十四五规划”“新基建”“数字经济”等相关战略和政策的引导下,智慧楼宇也迎来了快速发展期,对推动智慧城市系统的建设越来越重要。那么究竟什么是智慧楼宇呢?智慧楼宇其实就是整合楼宇内…...
Python:函数和代码复用
嗨喽,大家好呀~这里是爱看美女的茜茜呐 👇 👇 👇 更多精彩机密、教程,尽在下方,赶紧点击了解吧~ python源码、视频教程、插件安装教程、资料我都准备好了,直接在文末名片自取就可 1、关于递归函…...
three.js——模型对象的使用材质和方法
模型对象的使用材质和方法 前言效果图1、旋转、缩放、平移,居中的使用1.1 旋转rotation(.rotateX()、.rotateY()、.rotateZ())1.2缩放.scale()1.3平移.translate()1.4居中.center() 2、材质属性.wireframe 前言 BufferGeometry通过.scale()、…...
装饰模式(Decorator Pattern)重构java邮件发奖系统实战
前言 现在我们有个如下的需求,设计一个邮件发奖的小系统, 需求 1.数据验证 → 2. 敏感信息加密 → 3. 日志记录 → 4. 实际发送邮件 装饰器模式(Decorator Pattern)允许向一个现有的对象添加新的功能,同时又不改变其…...
Xshell远程连接Kali(默认 | 私钥)Note版
前言:xshell远程连接,私钥连接和常规默认连接 任务一 开启ssh服务 service ssh status //查看ssh服务状态 service ssh start //开启ssh服务 update-rc.d ssh enable //开启自启动ssh服务 任务二 修改配置文件 vi /etc/ssh/ssh_config //第一…...
React Native在HarmonyOS 5.0阅读类应用开发中的实践
一、技术选型背景 随着HarmonyOS 5.0对Web兼容层的增强,React Native作为跨平台框架可通过重新编译ArkTS组件实现85%以上的代码复用率。阅读类应用具有UI复杂度低、数据流清晰的特点。 二、核心实现方案 1. 环境配置 (1)使用React Native…...
基于Uniapp开发HarmonyOS 5.0旅游应用技术实践
一、技术选型背景 1.跨平台优势 Uniapp采用Vue.js框架,支持"一次开发,多端部署",可同步生成HarmonyOS、iOS、Android等多平台应用。 2.鸿蒙特性融合 HarmonyOS 5.0的分布式能力与原子化服务,为旅游应用带来…...
江苏艾立泰跨国资源接力:废料变黄金的绿色供应链革命
在华东塑料包装行业面临限塑令深度调整的背景下,江苏艾立泰以一场跨国资源接力的创新实践,重新定义了绿色供应链的边界。 跨国回收网络:废料变黄金的全球棋局 艾立泰在欧洲、东南亚建立再生塑料回收点,将海外废弃包装箱通过标准…...
Python如何给视频添加音频和字幕
在Python中,给视频添加音频和字幕可以使用电影文件处理库MoviePy和字幕处理库Subtitles。下面将详细介绍如何使用这些库来实现视频的音频和字幕添加,包括必要的代码示例和详细解释。 环境准备 在开始之前,需要安装以下Python库:…...
《基于Apache Flink的流处理》笔记
思维导图 1-3 章 4-7章 8-11 章 参考资料 源码: https://github.com/streaming-with-flink 博客 https://flink.apache.org/bloghttps://www.ververica.com/blog 聚会及会议 https://flink-forward.orghttps://www.meetup.com/topics/apache-flink https://n…...
成都鼎讯硬核科技!雷达目标与干扰模拟器,以卓越性能制胜电磁频谱战
在现代战争中,电磁频谱已成为继陆、海、空、天之后的 “第五维战场”,雷达作为电磁频谱领域的关键装备,其干扰与抗干扰能力的较量,直接影响着战争的胜负走向。由成都鼎讯科技匠心打造的雷达目标与干扰模拟器,凭借数字射…...
图表类系列各种样式PPT模版分享
图标图表系列PPT模版,柱状图PPT模版,线状图PPT模版,折线图PPT模版,饼状图PPT模版,雷达图PPT模版,树状图PPT模版 图表类系列各种样式PPT模版分享:图表系列PPT模板https://pan.quark.cn/s/20d40aa…...
android13 app的触摸问题定位分析流程
一、知识点 一般来说,触摸问题都是app层面出问题,我们可以在ViewRootImpl.java添加log的方式定位;如果是touchableRegion的计算问题,就会相对比较麻烦了,需要通过adb shell dumpsys input > input.log指令,且通过打印堆栈的方式,逐步定位问题,并找到修改方案。 问题…...
