分享

Github火爆图神经网络框架pytorch_geometric原理解析—基于边的高效GNN实现

 LibraryPKU 2019-07-20

近几年来,常用深度网络的实现,如多层感知机(MLP)、卷积神经网络(CNN)、循环神经网络(RNN)等的实现几乎已经形成了规范(如数据预处理、输入输出数据格式、代码的设计模式等)。然而,较晚出现的图神经网络却还没有形成一套规范体系。例如,Github上不同的GNN实现有着多种不同的数据结构来存放输入的图。

虽然在图卷积网络(GCN)、图注意力网络(GAT)等许多图神经网络的理论中,每一层图神经网络就是节点与邻节点特征的融合。直观上说,用循环遍历每个节点的邻节点,按照一定的规律加权平均就可以实现这些网络(如下图所示)。然而实际上,这样的实现方式与TensorFlow和PyTorch等深度学习框架并不兼容。由于要利用GPU的并行计算能力,这些深度学习框架需要我们将数据规整为整齐的矩阵,用矩阵运算而不是循环来实现深度网络。

为了将图神经网络的实现用矩阵运算形式实现,不同的算法可能需要采用不同的设计模式。例如GCN通常使用稀疏矩阵来实现,而GAT的一些版本由于需要使用Attention矩阵,稀疏矩阵在一些情况下就失效了。

为了解决这个问题,pytorch_geometric(https://github.com/rusty1s/pytorch_geometric)使用了一种基于边的实现方法。该方法使用scatter操作实现了上述的“用循环遍历每个节点的邻节点,按照一定的规律加权平均”的操作。该实现依赖于pytorch_scatter(https://github.com/rusty1s/pytorch_scatter)。

用(i, j)表示一个边,假设一个图中有8条边,我们用index表示i(起始点)的集合,用to_index表示j(目标点)的集合,用input表示to_index特征的集合,那么,一个简化版GCN(没有权重计算,以所有邻节点的平均值为输出;也没有全连接层)的示意图如下:

第一行index表示边的起始点,第二行是目标点的特征(邻节点的特征向量,这里简化为标量)。在GCN过程中,我们其实是根据边的起始点来聚合目标点的特征的(以起始点为核心,聚合与其相邻的邻节点的特征值),因此,我们对具有相同起始点(index)的特征(input)进行聚合(相加)即可完成上述操作。在pytorch_scatter中,上述操作可以用下面一行代码实现:

torch_scatter.scatter_add(srcindex)

其中,src对应input(邻节点特征向量集合)。

除了加法,pytorch_scatter还集成了许多其它的聚合操作。因此,pytorch_geometric基于pytorch_scatter构建了一个名为MessagePassing的类:

  • https://github.com/rusty1s/pytorch_geometric/blob/master/torch_geometric/nn/conv/message_passing.py

该类可以根据输入的边、特征和指定的聚合方式来对邻节点进行聚合。因此,在pytorch_geometric中,GCN和GAT的实现都是一个继承了MessagePassing的子类,分别实现了GCN和GAT的权重计算。这样的实现大幅度简化了GNN实现的门槛,使用者只要关注于权重的计算,而不需要干涉具体的与邻节点融合的过程。

另外,由于框架的输入的图的边,而不是邻接矩阵,避免了大量的不存在的边对网络性能的干扰(内存占用、计算效率)。例如经典的GAT实现会让非邻节点参与计算,为其赋予一个非常小的权重来降低其对效果的干扰,这样GAT的计算效率就会大大降低。

除了MessagePassing,pytorch_geometric还实现了使用其他机制的许多GNN。我们会在以后的文章中介绍。

参考链接:

  • https://github.com/rusty1s/pytorch_geometric

  • https://github.com/rusty1s/pytorch_scatter

    本站是提供个人知识管理的网络存储空间,所有内容均由用户发布,不代表本站观点。请注意甄别内容中的联系方式、诱导购买等信息,谨防诈骗。如发现有害或侵权内容,请点击一键举报。
    转藏 分享 献花(0

    0条评论

    发表

    请遵守用户 评论公约

    类似文章 更多