
Tuesday, January 14, 2025

Modeling Graph Neural Networks With PyG

  Target audience: Beginner
Estimated reading time: 6'


Have you ever wondered how to get started with Graph Neural Networks (GNNs)?  
Torch Geometric (PyG) provides a comprehensive toolkit to explore the various elements of a GNN and build your own learning path through hands-on experience and highly reusable components.


Table of contents
      Overview
      Features
      Flickr dataset
      GNN block
      GNN base model

What you will learn: How Torch Geometric can help you kickstart your exploration of Graph Neural Networks and build reusable models.

Notes

  • Environments: python 3.12.5,  matplotlib 3.9, numpy 2.2.0, torch 2.5.1, torch-geometric 2.6.1
  • Source code is available on GitHub [ref 1]
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
  • This article explores the PyG (PyTorch Geometric) Python library to evaluate various graph neural network (GNN) architectures. It is not intended as an introduction or overview of GNNs and assumes the reader has some prior knowledge of the subject.

Introduction

     Data on manifolds can often be represented as a graph, where the manifold's local structure is approximated by connections between nearby points. GNNs and their variants (like Graph Convolutional Networks (GCNs)) extend neural networks to process data on non-Euclidean domains by leveraging the graph structure, which may approximate the underlying manifold.
Applications: Social network analysis, molecular structure prediction, and 3D point cloud data can all be modeled using GNNs.
  
     Here is an overview of the major types of GNNs:
  • Graph Convolutional Networks (GCNs): GCNs generalize the concept of convolution from grids (e.g., images) to graphs. They aggregate information from a node's neighbors using normalized adjacency matrices and apply transformations to learn node embeddings.
  • Graph Attention Networks (GATs): GATs use attention mechanisms to learn the importance of neighboring nodes dynamically. Each edge is assigned a learned weight during aggregation.
  • GraphSAGE (Graph Sample and Aggregate): It learns node embeddings by sampling and aggregating features from a fixed-size neighborhood of each node, enabling scalable learning on large graphs.
  • Graph Isomorphism Networks (GINs): GINs are designed to be as powerful as the Weisfeiler-Lehman (WL) graph isomorphism test, distinguishing graph structures more effectively.
  • Spectral Graph Neural Networks (SGNN): These networks operate in the spectral domain using the graph Laplacian. They use eigenvectors of the Laplacian for convolution-like operations.
  • Graph Pooling Networks: They summarize graph information into a smaller representation, similar to pooling in CNNs. They can be categorized into global and hierarchical pooling.
  • Hyperbolic Graph Neural Networks: These networks operate in hyperbolic space, which is well-suited for representing hierarchical or tree-like graph structures.
  • Dynamic Graph Neural Networks: These networks are designed to handle temporal graphs, where nodes and edges evolve over time.
  • Relational Graph Convolutional Networks (R-GCNs): R-GCNs extend GCNs to handle heterogeneous graphs with different types of nodes and edges.
  • Graph Transformers: They adapt the Transformer architecture to graph-structured data using attention mechanisms and global context.
  • Graph Autoencoders: These are used for unsupervised learning on graphs, aiming to reconstruct graph structure and node features.
  • Diffusion-Based GNNs: These networks use graph diffusion processes to propagate information.
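
As a rough illustration of the GCN entry in the list above, the sketch below applies a single propagation step, H' = D^(-1/2) (A + I) D^(-1/2) H W, to a toy 3-node path graph using plain PyTorch. The graph, the feature values and the weight matrix are invented for illustration. The dense adjacency matrix is only practical for tiny graphs; PyG layers work from edge_index instead.

```python
import torch

# Toy undirected path graph 0 - 1 - 2 (an assumption for illustration)
adjacency = torch.tensor([[0., 1., 0.],
                          [1., 0., 1.],
                          [0., 1., 0.]])

# Add self-loops so each node keeps its own features: A_hat = A + I
a_hat = adjacency + torch.eye(3)

# Symmetric normalization: D^{-1/2} A_hat D^{-1/2}
deg_inv_sqrt = a_hat.sum(dim=1).pow(-0.5)
a_norm = deg_inv_sqrt.unsqueeze(1) * a_hat * deg_inv_sqrt.unsqueeze(0)

# One feature per node and a 1x1 unit weight, both made up
features = torch.tensor([[-1.], [0.], [1.]])
weight = torch.ones(1, 1)

# One GCN propagation step (without the non-linearity)
embeddings = a_norm @ features @ weight
print(embeddings)
```

Each node ends up with a degree-weighted average of its own feature and those of its neighbors, which is the aggregation that the learned weight matrix then transforms.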

     The inner workings and mathematical foundations of graph neural networks, including the message passing architecture and aggregation policies, are beyond the scope of this article.
     Some of the most relevant tutorials and presentations on GNNs are listed in references [ref 2, 3 & 4].


PyG

    Overview

    PyTorch Geometric (PyG) is a powerful and flexible Python library for graph neural networks (GNNs) and geometric deep learning. Built on top of PyTorch, PyG provides tools to handle graph-structured data efficiently and enables researchers and practitioners to build, train, and evaluate a wide range of graph-based machine learning models [ref 5].

     PyTorch Geometric is designed to work with non-Euclidean data (e.g., graphs, point clouds, and manifolds), where relationships between entities are represented as edges between nodes. Examples of such data include:
  • Graphs: Social networks, citation networks, knowledge graphs.
  • Molecular Data: Protein structures, chemical compounds.
  • Point Clouds: 3D object representations.

PyG provides high-level abstractions for graph operations and optimizations, making it easier to implement and train graph neural networks.

There are several benefits of using PyG:
  • Ease of use: Simplifies the implementation of graph-based machine learning tasks with high-level APIs and pre-built components.
  • Flexibility: Customizable architecture for experimenting with novel GNN models and algorithms.
  • Efficiency: Optimized for sparse matrix operations, making it memory- and computation-efficient, especially for large graphs.
  • Community support: Widely adopted in academia and industry, with active community support and frequent updates.

Features

The key tasks for graph neural networks are:

  • Node-Level Tasks: Predict labels or embeddings for nodes in a graph. Example: Social network analysis, fraud detection.
  • Edge-Level Tasks: Predict the presence or properties of edges. Example: Link prediction in recommendation systems.
  • Graph-Level Tasks: Predict properties of entire graphs. Example: Molecular property prediction, drug discovery.
  • Geometric Data Processing: Handle point clouds and 3D objects. Example: 3D shape classification.

Graph neural networks have numerous applications:
  • Node Classification: Assign categories to nodes within a graph (e.g., documents, videos, protein functions, web pages).  
  • Link Prediction: Predict relationships between pairs of nodes in a graph, such as determining whether a link exists (binary classification) or identifying the type/category of a link. Applications include recommendation systems and discovering new relationships in knowledge graphs.  
  • Community Detection: Group nodes into clusters based on domain-specific similarity measures, used in tasks like fraud detection, entity resolution, text clustering, and identifying cliques.  
  • Node and Edge Regression: Predict numerical values associated with nodes or edges, such as traffic flow on roads or internet data streams.  
  • Graph Classification and Regression: Predict properties of an entire graph based on its nodes and links, such as determining the type or behavior of the graph.  

Simple tutorial

The implementation of deep learning models often depends heavily on repetitive, boilerplate code. It makes sense to apply the same level of reusability and the same design patterns that are common in conventional software development.
 
Neural blocks are routinely used to wrap the components associated with a neural network layer into a single class [ref 6].
For instance, a convolutional model can be represented as:

Fig. 1 Neural blocks for a convolutional neural network


Let's apply the same principle of reusability and encapsulation to build a neural block that wraps a GNN layer with the following components:
  • Message passing (e.g. GraphConv, GATConv)
  • Activation function (e.g. ReLU)
  • Batch normalization
  • Dropout regularization

GNN data structure

A graph is described by an instance of the class torch_geometric.data.Data, which exposes the following attributes:

  • data.x: Node feature matrix with shape [num_nodes, num_node_features]
  • data.edge_index: Graph connectivity with shape [2, num_edges] and type torch.long
  • data.edge_attr: Edge feature matrix shape [num_edges, num_edge_features]              
  • data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
  • data.pos: Node position matrix with shape [num_nodes, num_dimensions].

Let's start with the implementation of a very simple graph as described below in PyG.

Fig. 2 Schema for a simple 3 node graph

The implementation consists of defining the node attribute data.x and the indices of nodes in the list of edges, edge_index.

import torch
from torch_geometric.data import Data

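# Each column is a directed edge: row 0 holds the source nodes, row 1 the targets
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]], dtype=torch.long)
# Node features with shape [num_nodes, num_node_features] = [3, 1]
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
# Check the consistency of the graph (e.g. edge indices within range)
data.validate(raise_on_error=True)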
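
To make the edge_index convention concrete, here is a small sketch in plain PyTorch that reads the same 3-node connectivity back as (source, target) pairs and counts the outgoing edges of each node. The variable names are illustrative only.

```python
import torch

# Same connectivity as above: row 0 holds source nodes, row 1 target nodes
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
num_nodes = 3

# Each column is one directed edge, so an undirected link appears twice
edges = [tuple(edge) for edge in edge_index.t().tolist()]
print(edges)

# Out-degree of each node: number of times it appears as a source
degrees = torch.bincount(edge_index[0], minlength=num_nodes)
print(degrees)
```

The middle node has two neighbors and the two end nodes have one each, which matches a graph where node 1 is linked to both node 0 and node 2.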


Flickr dataset

The Flickr dataset is a graph where nodes represent images and edges signify similarities between them. It includes 89,250 images and 899,756 relationships. Node features consist of image descriptions and shared properties. This dataset is commonly used for tasks like node classification, link prediction, and graph representation learning.

You can load the dataset with the torch_geometric.datasets.flickr.Flickr class, which downloads the data and stores it in a local directory. The dataset holds a single graph, accessed by indexing, such as _data = _dataset[0]. The _data object contains the graph structure and its associated features.

import os
from torch.utils.data import Dataset
from torch_geometric.datasets.flickr import Flickr

path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'Flickr')
_dataset: Dataset = Flickr(path)
_data = _dataset[0]

print(str(_data))

Output:
Data(
     x=[89250, 500],
     edge_index=[2, 899756], 
     y=[89250], 
     train_mask=[89250], 
     val_mask=[89250], 
     test_mask=[89250]
)
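
The train_mask, val_mask and test_mask entries are boolean tensors with one value per node. The sketch below shows how such masks select subsets of nodes, using a small synthetic graph instead of the real Flickr data. The sizes and the split are invented for illustration.

```python
import torch

torch.manual_seed(0)
num_nodes, num_features, num_classes = 10, 4, 3

# Synthetic stand-ins for data.x and data.y
x = torch.randn(num_nodes, num_features)
y = torch.randint(0, num_classes, (num_nodes,))

# Boolean masks with one entry per node, as in the Flickr Data object
train_mask = torch.zeros(num_nodes, dtype=torch.bool)
val_mask = torch.zeros(num_nodes, dtype=torch.bool)
test_mask = torch.zeros(num_nodes, dtype=torch.bool)
train_mask[:6] = True
val_mask[6:8] = True
test_mask[8:] = True

# Indexing with a mask keeps only the selected nodes
x_train, y_train = x[train_mask], y[train_mask]
print(x_train.shape, y_train.shape)
```

With the real dataset the same indexing would apply to _data.x and _data.y, while the full graph structure in edge_index stays available to the message passing layers.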


GNN block

Let's apply the principle of reusability and encapsulation described earlier to build a neural block that wraps a GNN layer with the following components:
  • Message passing (e.g. GraphConv, GATConv)
  • Activation function (e.g. ReLU)
  • Batch normalization
  • Dropout regularization
We define the GNNBaseBlock class as a PyTorch module that encapsulates these components, each implemented as a torch module.

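from typing import AnyStr, List, Optional

import torch
import torch.nn as nn
from torch_geometric.nn import BatchNorm
from torch_geometric.nn.conv import MessagePassing
from torch_geometric.typing import Adj


class GNNBaseBlock(nn.Module):
    def __init__(self,
                 _id: AnyStr,
                 message_passing: MessagePassing,
                 activation: Optional[nn.Module] = None,
                 batch_norm: Optional[BatchNorm] = None,
                 drop_out: float = 0.0) -> None:
        # Required before any sub-module is assigned to the block
        super().__init__()
        self.id = _id
        # The message passing layer always comes first
        modules: List[nn.Module] = [message_passing]
        if batch_norm is not None:
            modules.append(batch_norm)
        if drop_out > 0.0:
            modules.append(nn.Dropout(drop_out))
        if activation is not None:
            modules.append(activation)
        # Plain list kept under its original name (it shadows nn.Module.modules())
        self.modules = modules
        # Register the layers so that their parameters are visible to optimizers
        self.layers = nn.ModuleList(modules)

    def forward(self, x: torch.Tensor, edge_index: Adj) -> torch.Tensor:
        for idx, layer in enumerate(self.layers):
            # Only the message passing layer needs the graph connectivity
            x = layer(x, edge_index) if idx == 0 else layer(x)
        return x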

The forward method iteratively invokes the `__call__` method for each module in the block, which internally executes its own forward method. 

For graph layers, the forward method requires two arguments:
  • features: Typically passed as data.x.
  • edge indices: Provided as data.edge_index.
For all other modules (e.g., activation functions), the forward method only requires the default input features as its argument.
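
The dispatch rule described above can be sketched without PyG. The example below uses a hypothetical ToyMeanConv layer, which is not part of PyG and simply averages the features of incoming neighbors, followed by a ReLU. Only the first layer receives edge_index.

```python
import torch
import torch.nn as nn


class ToyMeanConv(nn.Module):
    """Hypothetical stand-in for a PyG layer: averages neighbor features."""

    def forward(self, x: torch.Tensor, edge_index: torch.Tensor) -> torch.Tensor:
        source, target = edge_index
        # Sum the features of the source nodes into their target nodes
        summed = torch.zeros_like(x).index_add_(0, target, x[source])
        # Count incoming edges per node, avoiding division by zero
        counts = torch.bincount(target, minlength=x.size(0)).clamp(min=1)
        return summed / counts.unsqueeze(1)


layers = nn.ModuleList([ToyMeanConv(), nn.ReLU()])

# Made-up features for the 3-node graph 0 - 1 - 2
x = torch.tensor([[1.], [-2.], [3.]])
edge_index = torch.tensor([[0, 1, 1, 2], [1, 0, 2, 1]])

# Same dispatch as the block: only the first layer receives edge_index
for idx, layer in enumerate(layers):
    x = layer(x, edge_index) if idx == 0 else layer(x)
print(x)
```

Nodes 0 and 2 receive the negative feature of node 1 and are zeroed by the ReLU, while node 1 receives the mean of its two neighbors.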

GNN base model

Our GNN model is composed of a sequence of GNN blocks, gnn_blocks, followed by a fully connected feed-forward layer, ffnn, and an activation function, output_activation, as defined in the constructor.
The build method implements an alternative, simplified constructor.

from typing import Self    # Python 3.11+

class GNNBaseModel(nn.Module):
    def __init__(self,
                 model_id: AnyStr,
                 gnn_blocks: List[GNNBaseBlock],
                 ffnn: nn.Module,
                 output_activation: nn.Module) -> None:
        # Required before any sub-module is assigned to the model
        super().__init__()
        self.model_id = model_id
        self.gnn_blocks = gnn_blocks
        self.ffnn = ffnn
        self.output_activation = output_activation

        # Extract the torch modules from GNN blocks
        modules: List[nn.Module] = [
            module for block in gnn_blocks for module in block.modules
        ]

        # Add a fully connected feed forward network
        modules.append(nn.Flatten())
        modules.append(ffnn)
        modules.append(output_activation)
        # Registers every layer and its parameters with the model.
        # Named 'layers' to avoid shadowing the nn.Module.modules() method.
        self.layers = nn.Sequential(*modules)

    @classmethod
    def build(cls, model_id: AnyStr, gnn_blocks: List[GNNBaseBlock]) -> Self:
        return cls(model_id, gnn_blocks=gnn_blocks, ffnn=None, output_activation=None)


The forward method invokes the __call__ -> forward method of each GNN block, concatenates the outputs of all the GNN blocks, then passes the result through the feed-forward layer and its activation function.

def forward(self, data: Data) -> torch.Tensor:
    x = data.x
    edge_index = data.edge_index

    output = []
    for gnn_block in self.gnn_blocks:
        x = gnn_block(x, edge_index)      # Invoke gnn_block.forward
        output.append(x)

    x = torch.cat(output, dim=-1)         # Concatenate the output of all GNN blocks
    x = self.ffnn(x)

    return self.output_activation(x)
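
Because the outputs of all blocks are concatenated along the last dimension, the feed-forward layer has to accept the sum of the block output widths. The sketch below checks the shapes with random tensors standing in for the block outputs. The sizes are made up for illustration.

```python
import torch
import torch.nn as nn

num_nodes, hidden_channels, num_classes = 5, 8, 7

# Stand-ins for the outputs of three GNN blocks, one row per node
block_outputs = [torch.randn(num_nodes, hidden_channels) for _ in range(3)]

# Concatenating along the last dimension keeps one row per node
x = torch.cat(block_outputs, dim=-1)
print(x.shape)

# The linear layer must therefore accept 3 * hidden_channels inputs
ffnn = nn.Linear(3 * hidden_channels, num_classes)
log_probs = nn.LogSoftmax(dim=-1)(ffnn(x))
print(log_probs.shape)
```

Each row of log_probs holds the log-probabilities of one node over the classes, so exponentiating a row and summing it gives 1.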


Here is an example architecture of a graph convolutional neural network used for classifying Flickr images.

Fig. 3 A 3-layer graph convolutional neural network for Flickr data set


... and its implementation using PyG:

from torch_geometric.nn import GraphConv

     
hidden_channels = 256
num_node_features = _dataset.num_node_features
num_classes = _dataset.num_classes

# First graph convolutional layer and block
conv_1 = GraphConv(in_channels=num_node_features, out_channels=hidden_channels)
gnn_block_1 = GNNBaseBlock(_id='K1',
                           message_passing=conv_1,
                           activation=nn.ReLU(),
                           drop_out=0.2)

# Second graph convolutional layer and block
conv_2 = GraphConv(in_channels=hidden_channels, out_channels=hidden_channels)
gnn_block_2 = GNNBaseBlock(_id='K2',
                           message_passing=conv_2,
                           activation=nn.ReLU(),
                           drop_out=0.2)

# Third graph convolutional layer and block
conv_3 = GraphConv(in_channels=hidden_channels, out_channels=hidden_channels)
gnn_block_3 = GNNBaseBlock(_id='K3',
                           message_passing=conv_3,
                           activation=nn.ReLU(),
                           drop_out=0.2)

# Our GNN model
gnn_model = GNNBaseModel(model_id='Flickr',
                         gnn_blocks=[gnn_block_1, gnn_block_2, gnn_block_3],
                         ffnn=nn.Linear(3*hidden_channels, num_classes),
                         output_activation=nn.LogSoftmax(dim=-1))
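
Since the model ends with LogSoftmax, a natural training objective is the negative log-likelihood loss restricted to the training nodes. The sketch below is not taken from the article's source code: it uses a plain linear layer as a stand-in for gnn_model and synthetic data instead of Flickr, only to show how the mask and the loss fit together.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
num_nodes, num_features, num_classes = 12, 6, 3

# Synthetic node features, labels and training mask (all made up)
x = torch.randn(num_nodes, num_features)
y = torch.randint(0, num_classes, (num_nodes,))
train_mask = torch.arange(num_nodes) < 8

# Stand-in model ending with LogSoftmax, like the GNN model above
model = nn.Sequential(nn.Linear(num_features, num_classes), nn.LogSoftmax(dim=-1))
optimizer = torch.optim.Adam(model.parameters(), lr=0.05)

losses = []
for epoch in range(20):
    model.train()
    optimizer.zero_grad()
    log_probs = model(x)
    # NLL loss expects log-probabilities; only training nodes contribute
    loss = F.nll_loss(log_probs[train_mask], y[train_mask])
    loss.backward()
    optimizer.step()
    losses.append(loss.item())

print(f"first loss: {losses[0]:.3f}, last loss: {losses[-1]:.3f}")
```

With the real model, the forward call would take the Data object (for instance gnn_model(_data)) so that the blocks can use edge_index, and the mask would be _data.train_mask.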



References

[5] pyg.org


-------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning", Packt Publishing ISBN 978-1-78712-238-3 and 
Geometric Learning in Python Newsletter on LinkedIn.