Notes:
- Environments: python 3.12.5, matplotlib 3.9, numpy 2.2.0, torch 2.5.1, torch-geometric 2.6.1, networkx 3.4.2
- Source code is available on GitHub [ref 1]
- To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
Introduction
- data.x: Node feature matrix with shape [num_nodes, num_node_features]
- data.edge_index: Graph connectivity in COO format with shape [2, num_edges] and type torch.long
- data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
- data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
- data.pos: Node position matrix with shape [num_nodes, num_dimensions]
NetworkX library
Overview
Sampling large graphs
Implementation
- graph: Reference to a NetworkX directed or undirected graph
- data: Data representation of the graph dataset
- sampled_node_index_range: Tuple (index of first sampled node, index of last sampled node), both inclusive; None keeps the whole graph
import networkx as nx
from networkx import Graph
from torch_geometric.data import Data
from typing import Tuple, AnyStr, Callable, Dict, Any, Self, List
import matplotlib.pyplot as plt
class GNNPlotter(object):
    """Sample a subgraph from a PyG ``Data`` graph and display it with NetworkX.

    The edges of ``data.edge_index`` are copied into ``graph`` (an undirected
    ``nx.Graph`` or a directed ``nx.DiGraph``). When ``sampled_node_index_range``
    is set, only a contiguous range of node indices is kept, so that very large
    graphs remain readable once plotted.
    """

    def __init__(self,
                 graph: Graph,
                 data: Data,
                 sampled_node_index_range: Tuple[int, int] | None = None) -> None:
        """Default constructor.

        @param graph: NetworkX graph (directed or undirected) receiving the sampled edges
        @param data: PyG data representation of the graph dataset
        @param sampled_node_index_range: Optional (first, last) node indices to sample;
                                         None keeps every edge
        """
        self.graph = graph
        self.data = data
        self.sampled_node_index_range = sampled_node_index_range

    @classmethod
    def build(cls,
              data: Data,
              sampled_node_index_range: Tuple[int, int] | None = None) -> Self:
        """Alternative constructor for an undirected graph (``nx.Graph``)."""
        return cls(nx.Graph(), data, sampled_node_index_range)

    @classmethod
    def build_directed(cls,
                       data: Data,
                       sampled_node_index_range: Tuple[int, int] | None = None) -> Self:
        """Alternative constructor for a directed graph (``nx.DiGraph``)."""
        return cls(nx.DiGraph(), data, sampled_node_index_range)

    def sample(self) -> int:
        """Sample/extract a subgraph from a large graph by selecting its nodes
        through a range of indices.

        The implementation is given in the next listing.

        @return: Number of sampled edges
        """
        ...

    def draw(self,
             layout_func: Callable[[Graph], Dict[Any, Any]],
             node_color: str,
             node_size: int,
             title: str) -> int:
        """Display/visualize a subgraph extracted/sampled from a large graph
        by selecting its nodes through a range of indices.

        The implementation is given in a later listing.

        @return: Number of sampled edges
        """
        ...
- Transpose the edge index matrix so that each row is a (source, target) pair.
- If a sampling range is defined, keep only the edges whose source node index lies between sampled_node_index_range[0] and sampled_node_index_range[1] (both inclusive); the target node may fall outside the range.
- Add the selected edges to the NetworkX graph and return their count.
def sample(self) -> int:
    """Copy the edges of ``self.data`` into ``self.graph``, optionally restricted
    to a range of node indices.

    When ``self.sampled_node_index_range`` is ``(first, last)``, only the edges
    whose *source* node index lies in ``[first, last]`` (inclusive) are kept;
    their target node may lie outside the range. When it is ``None``, every
    edge is added.

    @return: Number of edges passed to the graph (one per kept column of
             ``edge_index``, even if NetworkX merges duplicates)
    """
    # 1. Edge list as an array of (source, target) rows.
    #    .cpu() lets this also work for tensors stored on a GPU.
    edges = self.data.edge_index.cpu().numpy().T

    # 2. Keep the edges whose source node falls in the requested range.
    #    No clamping to the number of nodes is needed: indices beyond the
    #    last node simply match no edge (and len(data.y) is not the node
    #    count when targets are graph-level).
    if self.sampled_node_index_range is not None:
        first, last = self.sampled_node_index_range
        sources = edges[:, 0]
        edges = edges[(sources >= first) & (sources <= last)]

    # 3. Hand plain-int (source, target) tuples to NetworkX.
    sampled_edges = [tuple(edge) for edge in edges.tolist()]
    self.graph.add_edges_from(sampled_edges)
    return len(sampled_edges)
def draw(self,
         layout_func: Callable[[Graph], Dict[Any, Any]],
         node_color: str,
         node_size: int,
         title: str) -> int:
    """Sample the subgraph (see ``sample``) and display it with matplotlib.

    @param layout_func: Function mapping the NetworkX graph to a node-position
                        dictionary, e.g. ``lambda g: nx.spring_layout(g, k=1)``
    @param node_color: Matplotlib color of the nodes
    @param node_size: Size of the node markers
    @param title: Title of the plot
    @return: Number of sampled edges, as returned by ``self.sample()``
    """
    num_sampled_edges = self.sample()

    # Plot the graph using matplotlib
    plt.figure(figsize=(8, 8))
    pos = layout_func(self.graph)

    # Draw nodes and edges separately: nx.draw would already render opaque
    # edges, hiding the semi-transparent edges drawn on top of them.
    nx.draw_networkx_nodes(self.graph, pos, node_size=node_size, node_color=node_color)
    edge_kwargs: Dict[str, Any] = {"alpha": 0.5, "edge_color": "black"}
    # Arrow heads only apply to directed graphs; for undirected graphs
    # NetworkX ignores arrowsize and emits a warning.
    if self.graph.is_directed():
        edge_kwargs["arrowsize"] = 40
    nx.draw_networkx_edges(self.graph, pos, node_size=node_size, **edge_kwargs)

    # Configure plot
    plt.title(title)
    plt.axis("off")
    plt.show()
    return num_sampled_edges
Layouts
- Spring Layout: Positions nodes using a force-directed algorithm. Nodes repel each other, while edges act as springs pulling connected nodes closer.
- Circular Layout: Arranges nodes uniformly on a circle.
- Shell Layout: Arranges nodes in concentric circles (shells). Useful for graphs with hierarchical structures.
- Planar Layout: Positions nodes so that no edges cross; the graph must be planar (an exception is raised otherwise).
- Kamada-Kawai Layout: Positions nodes to minimize the "energy" of the graph. Produces aesthetically pleasing layouts similar to `spring_layout`.
- Spectral Layout: Positions nodes using the eigenvectors of the graph Laplacian. Captures graph structure in the arrangement.
- Random Layout: Places nodes randomly within a unit square.
- Spiral Layout: Positions nodes in a spiral pattern.
Graph representation
import os

from torch_geometric.data import Dataset
from torch_geometric.datasets import Flickr

# Load the Flickr data from <directory of this script>/../data/Flickr
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'Flickr')
# Invoke the built-in PyG data loader for Flickr data
_dataset: Dataset = Flickr(path)
_data = _dataset[0]

# Setup the sampling of the undirected graph: keep the edges whose source
# node index lies in [12, 21]
gnn_plotter = GNNPlotter.build(_data, sampled_node_index_range=(12, 21))
# Draw the sampled graph
num_sampled_nodes = gnn_plotter.draw(
    layout_func=lambda graph: nx.spring_layout(graph, k=1),
    node_color='blue',
    node_size=40,
    title='Flickr spring layout')
# NOTE: draw() returns the number of sampled edges (kept columns of edge_index)
print(f'Sample size: {num_sampled_nodes}')
Comparing layouts
References
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning.
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning", Packt Publishing ISBN 978-1-78712-238-3 and Geometric Learning in Python Newsletter on LinkedIn.