【Graph Net学习】GNN/GCN代码实战
一、简介
GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。
二、代码
import time
import random
import os
import numpy as np
import math
from torch.nn.parameter import Parameter
from torch.nn.modules.module import Moduleimport torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optimimport scipy.sparse as sp#配置项
class configs():def __init__(self):# Dataself.data_path = r'E:\code\Graph\data'self.save_model_dir = r'\code\Graph'self.model_name = r'GCN' #GNN/GCNself.seed = 2023self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")self.batch_size = 64self.epoch = 200self.in_features = 1433 #core ~ feature:1433self.hidden_features = 16 # 隐层数量self.output_features = 8 # core~paper-point~ 8类self.learning_rate = 0.01self.dropout = 0.5self.istrain = Trueself.istest = Truecfg = configs()def seed_everything(seed=2023):random.seed(seed)os.environ['PYTHONHASHSEED']=str(seed)np.random.seed(seed)torch.manual_seed(seed)seed_everything(seed = cfg.seed)#数据
class Graph_Data_Loader():def __init__(self):self.adj, self.features, self.labels, self.idx_train, self.idx_val, self.idx_test = self.load_data()self.adj = self.adj.to(cfg.device)self.features = self.features.to(cfg.device)self.labels = self.labels.to(cfg.device)self.idx_train = self.idx_train.to(cfg.device)self.idx_val = self.idx_val.to(cfg.device)self.idx_test = self.idx_test.to(cfg.device)def load_data(self,path=cfg.data_path, dataset="cora"):"""Load citation network dataset (cora only for now)"""print('Loading {} dataset...'.format(dataset))idx_features_labels = np.genfromtxt(os.path.join(path,dataset,dataset+'.content'),dtype=np.dtype(str))features = sp.csr_matrix(idx_features_labels[:, 1:-1], dtype=np.float32)labels = self.encode_onehot(idx_features_labels[:, -1])# build graphidx = np.array(idx_features_labels[:, 0], dtype=np.int32)idx_map = {j: i for i, j in enumerate(idx)}edges_unordered = np.genfromtxt(os.path.join(path,dataset,dataset+'.cites'),dtype=np.int32)edges = np.array(list(map(idx_map.get, edges_unordered.flatten())),dtype=np.int32).reshape(edges_unordered.shape)adj = sp.coo_matrix((np.ones(edges.shape[0]), (edges[:, 0], edges[:, 1])),shape=(labels.shape[0], labels.shape[0]),dtype=np.float32)# build symmetric adjacency matrixadj = adj + adj.T.multiply(adj.T > adj) - adj.multiply(adj.T > adj)features = self.normalize(features)adj = self.normalize(adj + sp.eye(adj.shape[0]))idx_train = range(140)idx_val = range(200, 500)idx_test = range(500, 1500)features = torch.FloatTensor(np.array(features.todense()))labels = torch.LongTensor(np.where(labels)[1])adj = self.sparse_mx_to_torch_sparse_tensor(adj)idx_train = torch.LongTensor(idx_train)idx_val = torch.LongTensor(idx_val)idx_test = torch.LongTensor(idx_test)return adj, features, labels, idx_train, idx_val, idx_testdef encode_onehot(self,labels):classes = set(labels)classes_dict = {c: np.identity(len(classes))[i, :] for i, c inenumerate(classes)}labels_onehot = np.array(list(map(classes_dict.get, labels)),dtype=np.int32)return labels_onehotdef normalize(self,mx):"""Row-normalize sparse matrix"""rowsum = np.array(mx.sum(1))r_inv = np.power(rowsum, -1).flatten()r_inv[np.isinf(r_inv)] = 0.r_mat_inv = sp.diags(r_inv)mx = r_mat_inv.dot(mx)return mxdef sparse_mx_to_torch_sparse_tensor(self,sparse_mx):"""Convert a scipy sparse matrix to a torch sparse tensor."""sparse_mx = sparse_mx.tocoo().astype(np.float32)indices = torch.from_numpy(np.vstack((sparse_mx.row, sparse_mx.col)).astype(np.int64))values = torch.from_numpy(sparse_mx.data)shape = torch.Size(sparse_mx.shape)return torch.sparse.FloatTensor(indices, values, shape)#精度评价指标
def accuracy(output, labels):preds = output.max(1)[1].type_as(labels)correct = preds.eq(labels).double()correct = correct.sum()return correct / len(labels)#模型
#01-GNN
class GNNLayer(nn.Module):def __init__(self, in_features, output_features):super(GNNLayer, self).__init__()self.linear = nn.Linear(in_features, output_features)def forward(self, adj_matrix, features):hidden_features = torch.matmul(adj_matrix, features) # GNN公式:H' = A * Hhidden_features = self.linear(hidden_features) # 使用线性变换hidden_features = F.relu(hidden_features) # 使用ReLU作为激活函数return hidden_features
class GNN(nn.Module):def __init__(self, in_features, hidden_features, output_features, num_layers=2):super(GNN, self).__init__()#输入维度in_features、隐藏层维度hidden_features、输出维度output_features、GNN的层数num_layersself.layers = nn.ModuleList([GNNLayer(in_features, hidden_features) if i == 0 else GNNLayer(hidden_features, hidden_features) for i inrange(num_layers)])self.output_layer = nn.Linear(hidden_features, output_features)def forward(self, adj_matrix, features):hidden_features = featuresfor layer in self.layers:hidden_features = layer(adj_matrix, hidden_features)output = self.output_layer(hidden_features)return F.log_softmax(output,dim=1)#02-GCN
class GraphConvolution(Module):"""Simple GCN layer, similar to https://arxiv.org/abs/1609.02907"""def __init__(self, in_features, out_features, bias=True):super(GraphConvolution, self).__init__()self.in_features = in_featuresself.out_features = out_featuresself.weight = Parameter(torch.FloatTensor(in_features, out_features))if bias:self.bias = Parameter(torch.FloatTensor(out_features))else:self.register_parameter('bias', None)self.reset_parameters()def reset_parameters(self):stdv = 1. / math.sqrt(self.weight.size(1))self.weight.data.uniform_(-stdv, stdv)if self.bias is not None:self.bias.data.uniform_(-stdv, stdv)def forward(self, input, adj):support = torch.mm(input, self.weight)output = torch.spmm(adj, support)if self.bias is not None:return output + self.biaselse:return outputdef __repr__(self):return self.__class__.__name__ + ' (' \+ str(self.in_features) + ' -> ' \+ str(self.out_features) + ')'class GCN(nn.Module):def __init__(self, in_features, hidden_features, output_features, dropout=cfg.dropout):super(GCN, self).__init__()self.gc1 = GraphConvolution(in_features, hidden_features)self.gc2 = GraphConvolution(hidden_features, output_features)self.dropout = dropoutdef forward(self, adj_matrix, features):x = F.relu(self.gc1(features, adj_matrix))x = F.dropout(x, self.dropout, training=self.training)x = self.gc2(x, adj_matrix)return F.log_softmax(x, dim=1)class graph_run():def train(self):t = time.time()#Create Train Processingall_data = Graph_Data_Loader()#创建一个模型model = eval(cfg.model_name)(in_features=cfg.in_features,hidden_features=cfg.hidden_features,output_features=cfg.output_features).to(cfg.device)optimizer = optim.Adam(model.parameters(),lr=cfg.learning_rate, weight_decay=5e-4)#Trainmodel.train()for epoch in range(cfg.epoch):optimizer.zero_grad()output = model(all_data.adj, all_data.features)loss_train = F.nll_loss(output[all_data.idx_train], all_data.labels[all_data.idx_train])acc_train = accuracy(output[all_data.idx_train], all_data.labels[all_data.idx_train])loss_train.backward()optimizer.step()loss_val = F.nll_loss(output[all_data.idx_val], all_data.labels[all_data.idx_val])acc_val = accuracy(output[all_data.idx_val], all_data.labels[all_data.idx_val])print('Epoch: {:04d}'.format(epoch + 1),'loss_train: {:.4f}'.format(loss_train.item()),'acc_train: {:.4f}'.format(acc_train.item()),'loss_val: {:.4f}'.format(loss_val.item()),'acc_val: {:.4f}'.format(acc_val.item()),'time: {:.4f}s'.format(time.time() - t))torch.save(model, os.path.join(cfg.save_model_dir, 'latest.pth')) # 模型保存def infer(self):#Create Test Processingall_data = Graph_Data_Loader()model_path = os.path.join(cfg.save_model_dir, 'latest.pth')model = torch.load(model_path, map_location=torch.device(cfg.device))model.eval()output = model(all_data.adj,all_data.features)loss_test = F.nll_loss(output[all_data.idx_test], all_data.labels[all_data.idx_test])acc_test = accuracy(output[all_data.idx_test], all_data.labels[all_data.idx_test])print("Test set results:","loss= {:.4f}".format(loss_test.item()),"accuracy= {:.4f}".format(acc_test.item()))if __name__ == '__main__':mygraph = graph_run()if cfg.istrain == True:mygraph.train()if cfg.istest == True:mygraph.infer()
三、结果与讨论
需要从网上下载cora数据集,数据组织形式如下图。
测了下Params和GFLOPs,还是比较大的,发现若作为一个Net的Block还是需要优化的哈哈~
Model | Params | GFLOPs |
GNN | 23.352K | 126.258M |
Model | Cora(/train/val/test) |
GNN | 1.0000/0.7800/0.7620 |
GCN | 0.9714/0.7767/0.8290 |
四、展望
未来可以考虑用PyG(PyTorch Geometric),毕竟PyG实现GAT等图网络、图的数据组织、加载会更加方便。Graph Net通常用可以用于属性数据的embedding模式,将属性数据可以作为一种补充特征加入Net去训练,看能不能发挥效能。
相关文章:
【Graph Net学习】GNN/GCN代码实战
一、简介 GNN(Graph Neural Network)和GCN(Graph Convolutional Network)都是基于图结构的神经网络模型。本文目标就是打代码基础,未用PyG,来扒一扒Graph Net两个基础算法的原理。直接上代码。 二、代码 …...
RocketMQ 发送顺序消息
文章目录 顺序消息应用场景消息组(MessageGroup)顺序性生产的顺序性MQ 存储的顺序性消费的顺序性 rocketmq-client-java 示例(gRPC 协议)1. 创建 FIFO 主题生产者代码消费者代码解决办法解决后执行结果 rocketmq-client 示例&…...
【面试经典150 | 双指针】判断子序列
文章目录 写在前面Tag题目来源题目解题解题思路方法一:双指针方法二:动态规划 写在最后 写在前面 本专栏专注于分析与讲解【面试经典150】算法,两到三天更新一篇文章,欢迎催更…… 专栏内容以分析题目为主,并附带一些对…...
自动驾驶信息安全方案
目录 1. 修订历史... 3 2. 概述... 4 2.1 目的... 4 2.2 适用范围... 4 2.3 参考文档... 4 2.4 术语和缩写... 4 3. 安全分析... 5 4. 总体设计... 6 4.1 ACU的安全防护... 7 4.1.1 系统安全引导... 7 4.1.2 密钥安全存储... 8 4.1.3 应…...
【云原生】kubernetes中pod(最小的资源管理组件)
目录 前言 一、pod 1.1pause容器使得Pod中的所有容器可以共享两种资源: 1.2 通常把Pod分为两类 1.2.1自主式Pod 1.2.2控制器管理的Pod 1.3 Pod 容器的分类 1.3.1基础容器(infrastructure container) 1.3.2初始化容器(initc…...
[DB]数据库--lowdb
[DB]数据库--lowdb lowdb基本应用获取数据数据变更写入文件 lodash的使用获取数据lodash方法使用数据变更写入文件 lowdb lowdb ,是一个基于文件存储的非关系型数据库 基于loadsh的轻量级数据库 可用于在json中存储数据,大小一般为0~200M的json文件 方便简单的数…...
Kotlin | 在for、forEach循环中正确的使用break、continue
文章目录 for循环中使用break、continueLabel标签forEach中如何退出循环资料 Kotlin 有三种结构化跳转表达式: return:默认从最直接包围它的函数或者匿名函数返回。break:终止最直接包围它的循环。continue:继续下一次最直接包围…...
【C++】详解std::mutex
2023年9月11日,周一中午开始 2023年9月11日,周一晚上23:25写完 目录 概述头文件std::mutex类的成员类型方法没有std::mutex会产生什么问题问题一:数据竞争问题二:不一致lock和unlock死锁 概述 std::mutex是C标准库中…...
Matlab图像处理-Lab模型
Lab模型 Lab模型是由CIE(国际照明委员会)制定的一种彩色模型。该模型与设备无关,弥补了RGB模型和CMYK模型必须依赖于设备颜色特性的不足; 另外,自然界中的任何颜色都可以在Lab空间中表现出来,也就是说RGB和…...
分布式ETL工具Sqoop实践
Mysql数据准备 1、在node02节点登录Mysql。 mysql -uroot -proot2、新建数据库testdb。 create database testdb;3、新建数据表ts。 use testdb; create table ts(id int, name varchar(10), age int, sex char(1));4、向表中插入数据。 insert into ts values(10001,张三…...
展会预告 | 图扑邀您共聚 IOTE 国际物联网展·深圳站
参展时间:9 月 20 日- 22 日 图扑展位:9 号馆 9B 35-1 参展地址:深圳国际会展中心(宝安新馆) IOTE 2023 第二十届国际物联网展深圳站,将于 9 月 20 日- 22 日在深圳国际会展中心(宝安…...
如何下载安装 WampServer 并结合 cpolar 内网穿透,轻松实现对本地服务的公网访问
文章目录 前言1.WampServer下载安装2.WampServer启动3.安装cpolar内网穿透3.1 注册账号3.2 下载cpolar客户端3.3 登录cpolar web ui管理界面3.4 创建公网地址 4.固定公网地址访问 前言 Wamp 是一个 Windows系统下的 Apache PHP Mysql 集成安装环境,是一组常用来…...
iOS添加Mapbox地图库
配置凭据 注册并导航到Account页面。你将需要: 公共访问令牌: 从帐户的tokens页面,你可以复制默认的公共令牌或单击"create a token"按钮来创建新的公共令牌。 带有Downloads:Read范围的秘密访问令牌: 从你帐户的t…...
destoon根据目录下的html文件生成地图索引
因为项目需要,destoon根据目录下的html文件生成地图索引,操作方法,代码如下: <?php $new_array array(); function loopDir($dir,&$new_array,$modurl) {$handle opendir($dir);header("Content-Type:text/xml&qu…...
gRPC之gRPC流
1、gRPC流 从其名称可以理解,流就是持续不断的传输。有一些业务场景请求或者响应的数据量比较大,不适合使用普通的 RPC 调用通过一次请求-响应处理,一方面是考虑数据量大对请求响应时间的影响,另一方面业务场景的设计不一 定需…...
Kafka Shell命令交互
Kafka提供了一个命令行工具,用于管理和与Kafka集群交互。这个命令行工具通常称为Kafka Shell,它允许您执行各种操作,如创建主题、发送和消费消息、查看主题列表等。 以下是一些常用的Kafka Shell命令: 创建主题(Topic): kafka-topics.sh --create --topic my-topic --pa…...
什么是回归测试?
什么是回归测试? 回归测试被定义为一种软件测试类型,以确认最近的程序或代码更改未对现有功能产生不利影响。 回归测试只不过是全部或部分选择已执行的测试用例,然后重新执行以确保现有功能正常运行。 进行此测试是为了确保新代码更改不会…...
ZTMap是如何在相关政策引导下让建筑更加智慧化的?
近几年随着智慧楼宇概念的深入,尤其是在“十四五规划”“新基建”“数字经济”等相关战略和政策的引导下,智慧楼宇也迎来了快速发展期,对推动智慧城市系统的建设越来越重要。那么究竟什么是智慧楼宇呢?智慧楼宇其实就是整合楼宇内…...
Python:函数和代码复用
嗨喽,大家好呀~这里是爱看美女的茜茜呐 👇 👇 👇 更多精彩机密、教程,尽在下方,赶紧点击了解吧~ python源码、视频教程、插件安装教程、资料我都准备好了,直接在文末名片自取就可 1、关于递归函…...
three.js——模型对象的使用材质和方法
模型对象的使用材质和方法 前言效果图1、旋转、缩放、平移,居中的使用1.1 旋转rotation(.rotateX()、.rotateY()、.rotateZ())1.2缩放.scale()1.3平移.translate()1.4居中.center() 2、材质属性.wireframe 前言 BufferGeometry通过.scale()、…...
sql explain
目录 1. sql explain每个字段对应的含义1.1. id1.2. select_type1.3. table1.4. partitions1.5. type1.6. possible_keys1.7. key1.8. key_len1.9. ref1.10. rows1.11. Extra 索引实践联合索引最左列原则全值匹配不建议在索引列上做任何操作, 否则索引会失效转而全表扫描尽量使…...
【LeetCode-简单题】剑指 Offer 05. 替换空格
文章目录 题目方法一:常规做法:方法二:双指针做法 题目 方法一:常规做法: class Solution {public String replaceSpace(String s) {int len s.length() ;StringBuffer str new StringBuffer();for(int i 0 ; i &l…...
数字虚拟人制作简明指南
如何在线创建虚拟人? 虚拟人,也称为数字化身、虚拟助理或虚拟代理,是一种可以通过各种在线平台与用户进行逼真交互的人工智能人。 在线创建虚拟人变得越来越流行,因为它为个人和企业带来了许多好处。 推荐:用 NSDT编辑…...
Nginx 文件解析漏洞复现
一、漏洞说明 Nginx文件解析漏洞算是一个比较经典的漏洞,接下来我们就通过如下步骤进行漏洞复现,以及进行漏洞的修复。 版本条件:IIS 7.0/IIS 7.5/ Nginx <8.03 二、搭建环境 cd /vulhub/nginx/nginx_parsing_vulnerability docker-compos…...
Lombok依赖
一.介绍 Project Lombok 是一个 Java 库,它会自动插入编辑器和构建工具,为您的 Java 增添趣味。永远不要再写另一个 getter 或 equals 方法,使用一个注释,您的类有一个功能齐全的构建器,自动化您的日志记录变量等等。…...
XML 和 JSON 学习笔记(基础)
XML Why XML 的出现背景:在实际开发中,不同语言(如Java、JavaScript等)的应用程序之间数据传递的格式不同,导致它们进行数据交换时很困难,XML就应运而生了!(XML 是一种通用的数据交…...
L1-005 考试座位号分数 15
每个 PAT 考生在参加考试时都会被分配两个座位号,一个是试机座位,一个是考试座位。正常情况下,考生在入场时先得到试机座位号码,入座进入试机状态后,系统会显示该考生的考试座位号码,考试时考生需要换到考试…...
无涯教程-JavaScript - CEILING.MATH函数
描述 CEILING.MATH函数将数字四舍五入到最接近的整数或最接近的有效倍数。 Excel CEILING.MATH函数是Excel中的十五个舍入函数之一。 语法 CEILING.MATH (number, [significance], [mode])争论 Argument描述Required/OptionalNumberNumber must be less than 9.99E307 and …...
ChatGPT提示词(prompt)资源汇总
文章目录 awesome-chatgpt-promptsLearn PromptingSnack PromptFlow GPTPrompt VineChatGPT 指令大全AI Toolbox HubAI Short ChatGPT是一种强大的生成式AI模型,而提示词(prompt)则是与ChatGPT一起使用的指导性文本,用于引导模型生…...
MySQL 几种导数据的方法与遇到的问题
零、说在前面 MySQL导数据通常使用第三方工具和MySQL自身的工具,本文分别就这两类方法分别介绍。 一、第三方工具之 Navicat 1.1、Navicat的“数据传输”工具 打开Navicat,点击“工具”标签,找到“数据传输”,即可看到操作界面。…...
大连住房和城乡建设网站/企业关键词优化专业公司
1.flex-direction:设置伸缩项目的排列方式,即主轴的方向row 设置从左到右排列row-reverse 设置从右到左排列column 设置从上到下排列column-reverse 设置从下到上排列2.justify-content:调整主轴方向上的对齐方式,对于…...
php做网站首页/百度手机seo
在市场发展的趋势中,为了产生更好的视觉或者跳转的过程的必要,设计师都很清楚认知到在目前的APP设计和网站加载是无法避免的事了,所以在面对这种困难,必须把这种缺点一步步的变成为优点,打造优秀的用户体验。 目前许多…...
隐藏网站开发语言/写文章免费的软件
全国各地交通管理部门与中国气象局公共气象服务中心、高德地图联合发布了《2018年春运安全出行指南》,依据历年春运期间高德地图交通大数据,融合高德地图位置大数据预测城市交通拥堵趋势以及出行趋势,为广大群众节假日出行提供科学的参考建议…...
做博客用什么系统做网站好/搜索引擎营销简称为
海度数据 MySQL同步到Oracle 一、数据表结构同步 将mysql表结构导出 将sql改为oracle表结构 将双引号去掉 不同数据字段类型改变 将表名改为要同步到oracle的名字 二、编写etl同步脚本 选择要同步表所在的工程 复制一个原有的表 右击选择【属性】一个一个的检查&am…...
商丘企业网站建设费用多少钱/国内seo公司
一、MyBatis入门简要介绍(百科) MyBatis 是支持普通 SQL查询,存储过程和高级映射的优秀持久层框架。MyBatis 消除了几乎所有的JDBC代码和参数的手工设置以及结果集的检索。MyBatis 使用简单的 XML或注解用于配置和原始映射,将接口…...
做跨境电商如何自建站/网络推广公司运作
在编写程序的时候我常常陷入纠结,一个抽象对象,到底应该定义成 抽象类(Abstract Class) 还是 接口(Interface) 呢?二者具有很大的相似性,甚至可以相互替换,难以选择。在 Stackoverflow 上这个问题被问了很多次…...