Target audience: Beginner
Estimated reading time: 4'
Have you ever found it challenging to represent a graph from a very large dataset while building a graph neural network model?
This article presents a method to sample and visualize a subgraph from such datasets.
What you will learn: How to sample and visualize a graph from a very large dataset for modeling graph neural networks.
- Environments: python 3.12.5, matplotlib 3.9, numpy 2.2.0, torch 2.5.1, torch-geometric 2.6.1, networkx 3.4.2
- Source code is available on GitHub [ref 1]
- To enhance the readability of the algorithm implementations, we have omitted non-essential code elements such as error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
This article focuses on visualizing subgraphs within the context of graph neural networks. It does not cover the introduction or explanation of graph neural network architectures and models, as those topics are beyond its scope [ref 2, 3].
Note: There are several Python libraries available for analyzing and visualizing graphs, including Plotly, PyVis, and NetworKit. In some of our future articles on graph neural networks and geometric deep learning, we will use NetworkX for visualization.
As a reminder, graph data is fully defined by an instance of the Data class of the torch_geometric.data package, with the following properties:
- data.x: Node feature matrix with shape [num_nodes, num_node_features]
- data.edge_index: Graph connectivity with shape [2, num_edges] and type torch.long
- data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
- data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
- data.pos: Node position matrix with shape [num_nodes, num_dimensions]
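To make these shape conventions concrete, here is a minimal sketch that uses NumPy arrays as stand-ins for the torch tensors held by a Data instance. The toy graph (4 nodes, 3 features per node, 5 edges) and all variable names are invented for illustration.

```python
import numpy as np

num_nodes, num_node_features = 4, 3

# Node feature matrix: one row per node -> shape [num_nodes, num_node_features]
x = np.random.default_rng(0).random((num_nodes, num_node_features))

# Connectivity: row 0 holds source nodes, row 1 holds target nodes
# -> shape [2, num_edges], integer type (torch.long in PyG)
edge_index = np.array([[0, 1, 1, 2, 3],
                       [1, 0, 2, 3, 0]], dtype=np.int64)

# Node-level targets: one label per node -> shape [num_nodes]
y = np.array([0, 1, 0, 1])

num_edges = edge_index.shape[1]
print(x.shape, edge_index.shape, y.shape, num_edges)
```

Every value stored in edge_index is a node index, so it must be smaller than num_nodes.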
NetworkX library
NetworkX is a powerful and flexible BSD-licensed Python library for the creation, manipulation, and analysis of complex networks and graphs. It supports various types of graphs, including undirected, directed, and multi-graphs, allowing users to model relationships and structures efficiently [ref 4].
NetworkX provides a wide range of algorithms for graph theory and network analysis, such as shortest paths, clustering, centrality measures, and more. It is designed to handle graphs with millions of nodes and edges, making it suitable for applications in social networks, biology, transportation systems, and other domains. With its intuitive API and rich visualization capabilities, NetworkX is an essential tool for researchers and developers working with network data.
The library supports many standard graph algorithms such as clustering, link analysis, minimum spanning trees, shortest paths, cliques, coloring, cuts, and graph polynomials, as well as random graph generators such as Erdős-Rényi.
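As a quick illustration of the API, the following sketch runs three of these algorithms on a small hand-made graph. The graph itself is invented for the example.

```python
import networkx as nx

# Small undirected graph: a triangle (0, 1, 2) with a tail 2 - 3 - 4
G = nx.Graph()
G.add_edges_from([(0, 1), (1, 2), (0, 2), (2, 3), (3, 4)])

# Shortest path between two nodes (unweighted, so fewest hops)
path = nx.shortest_path(G, source=0, target=4)

# Local clustering coefficient of each node
clustering = nx.clustering(G)

# Degree centrality: degree divided by (number of nodes - 1)
centrality = nx.degree_centrality(G)

print(path)            # [0, 2, 3, 4]
print(clustering[0])   # 1.0, both neighbors of node 0 are connected
print(centrality[2])   # 0.75, node 2 has 3 of the 4 possible neighbors
```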
Sampling large graphs
Most datasets included in the PyG library contain an extremely large number of nodes and edges, making them impractical to visualize directly. To address this, we can extract (or sample) one or more subgraphs that are easier to display.
In our design, a subgraph is derived from the original large graph by sampling its nodes and edges based on a specified range of indices, as shown below:
In the illustration above, the sampled nodes are [12, ..., 19].
For the sake of simplicity, let's wrap the visualization of graph neural network data into a class, GNNPlotter, whose constructor takes three parameters [ref 4]:
- graph: Reference to NetworkX directed or undirected graph
- data: Data representation of the graph dataset
- sampled_node_index_range: Tuple (index of first sampled node, index of last sampled node)
import networkx as nx
from networkx import Graph
from torch_geometric.data import Data
from typing import Tuple, AnyStr, Callable, Dict, Any, Self, List
import matplotlib.pyplot as plt
import numpy as np

class GNNPlotter(object):
    # Default constructor
    def __init__(self, graph: Graph, data: Data, sampled_node_index_range: Tuple[int, int] = None) -> None:
        self.graph = graph
        self.data = data
        self.sampled_node_index_range = sampled_node_index_range

    # Constructor for undirected graph
    @classmethod
    def build(cls, data: Data, sampled_node_index_range: Tuple[int, int] = None) -> Self:
        return cls(nx.Graph(), data, sampled_node_index_range)

    # Constructor for directed graph
    @classmethod
    def build_directed(cls, data: Data, sampled_node_index_range: Tuple[int, int] = None) -> Self:
        return cls(nx.DiGraph(), data, sampled_node_index_range)

    # Sample/extract a subgraph from a large graph by selecting
    # its nodes through a range of indices
    def sample(self) -> int:
        ...

    # Display/visualize a subgraph extracted/sampled from a large graph
    # by selecting its nodes through a range of indices
    def draw(self, layout_func: Callable[[Graph], Dict[Any, Any]], node_color: AnyStr, node_size: int, title: AnyStr) -> int:
        ...
Simplified constructors for undirected (build) and directed (build_directed) graphs are also provided.
Let’s first review our implementation for sampling graph vertices and edges, which involves the following steps:
- Transpose the list of edge indices.
- If a range for sampling indices is defined, use it to extract the indices of the first node, sampled_node_index_range[0], and the last node, last_node_index, in the subgraph.
- Add the edge indices associated with the selected nodes.
def sample(self) -> int:
    # 1. Create edge indices vector
    edge_index = self.data.edge_index.numpy()
    transposed = edge_index.T
    # 2. Sample the edges of the graph
    if self.sampled_node_index_range is not None:
        # Cap the upper bound at the number of nodes
        # (len(self.data.x), the row count of the node feature matrix, is assumed here)
        last_node_index = len(self.data.x) \
            if self.sampled_node_index_range[1] >= len(self.data.x) \
            else self.sampled_node_index_range[1]
        condition = ((transposed[:, 0] >= self.sampled_node_index_range[0]) & \
                     (transposed[:, 0] <= last_node_index))
        sampled_nodes = transposed[np.where(condition)]
    else:
        sampled_nodes = transposed
    # 3. Assign the samples to the edge of the graph
    self.graph.add_edges_from(sampled_nodes)
    # Number of sampled (source, target) pairs
    return len(sampled_nodes)
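The three steps can be traced on a toy edge list. This is a minimal standalone sketch, not the class method itself: the connectivity array and the index range (2, 4) are invented for illustration.

```python
import numpy as np
import networkx as nx

# Toy connectivity in the [2, num_edges] layout used by PyG
edge_index = np.array([[0, 1, 2, 2, 3, 4, 5],
                       [1, 2, 0, 3, 5, 5, 0]])
first, last = 2, 4  # hypothetical inclusive range of source-node indices

# 1. Transpose so that each row is one (source, target) pair
transposed = edge_index.T

# 2. Keep the edges whose source node lies in [first, last]
condition = (transposed[:, 0] >= first) & (transposed[:, 0] <= last)
sampled_edges = transposed[condition]

# 3. Load the sampled edges into a NetworkX graph
graph = nx.Graph()
graph.add_edges_from(sampled_edges.tolist())

print(len(sampled_edges))        # 4 sampled edges
print(sorted(graph.nodes))       # [0, 2, 3, 4, 5]
```

Two points are worth noting. The filter is applied to the source column only, so target nodes outside the range (0 and 5 here) still appear in the subgraph. And the length of the filtered array counts edges, which generally differs from the number of nodes in the resulting graph.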
Finally, the graph is visualized by the draw method, which renders it with nx.draw and overlays the edges using nx.draw_networkx_edges.
In our basic application, we define the layout, customize the color and size of the nodes, and set the title for the display.
def draw(self,
         layout_func: Callable[[Graph], Dict[Any, Any]],
         node_color: AnyStr,
         node_size: int,
         title: AnyStr) -> int:
    num_sampled_nodes = self.sample()
    # Plot the graph using matplotlib
    plt.figure(figsize=(8, 8))
    # Draw nodes and edges
    pos = layout_func(self.graph)
    nx.draw(self.graph, pos, node_size=node_size, node_color=node_color)
    nx.draw_networkx_edges(self.graph, pos, arrowsize=40, alpha=0.5, edge_color="black")
    # Configure plot
    plt.title(title)
    plt.show()
    return num_sampled_nodes
NetworkX provides several graph layouts to visually display undirected graphs. Each layout arranges nodes in a specific pattern, suited to different types of graphs and visualization purposes. Here's a brief overview of the common layouts available in NetworkX [ref 4]:
- Spring Layout: Positions nodes using a force-directed algorithm. Nodes repel each other, while edges act as springs pulling connected nodes closer.
- Circular Layout: Arranges nodes uniformly on a circle.
- Shell Layout: Arranges nodes in concentric circles (shells). Useful for graphs with hierarchical structures.
- Planar Layout: Positions nodes so that no edges cross, provided the graph is planar.
- Kamada-Kawai Layout: Positions nodes to minimize the "energy" of the graph. Produces aesthetically pleasing layouts similar to `spring_layout`.
- Spectral Layout: Positions nodes using the eigenvectors of the graph Laplacian. Captures graph structure in the arrangement.
- Random Layout: Places nodes randomly within a unit square.
- Spiral Layout: Positions nodes in a spiral pattern.
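Each of these layouts is exposed as a function that takes a graph and returns a dictionary mapping every node to its 2-D coordinates, which is why a layout can be passed around as a plain callable. The sketch below applies a few of them to a small cycle graph, a hypothetical stand-in for a sampled subgraph. The seed values are arbitrary and only make the randomized layouts reproducible.

```python
import networkx as nx

graph = nx.cycle_graph(8)  # hypothetical stand-in for a sampled subgraph

# Each layout function maps a graph to a dict {node: array([x, y])}
layouts = {
    "spring": lambda g: nx.spring_layout(g, seed=42),
    "circular": nx.circular_layout,
    "shell": nx.shell_layout,
    "spectral": nx.spectral_layout,
    "random": lambda g: nx.random_layout(g, seed=42),
    "spiral": nx.spiral_layout,
}

positions = {name: func(graph) for name, func in layouts.items()}
for name, pos in positions.items():
    print(name, pos[0])  # coordinates assigned to node 0 by each layout
```

Any entry of the layouts dictionary can be used wherever a callable from a graph to a position dictionary is expected.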
Graph representation
Let's consider the Flickr data set included in Torch Geometric (PyG) described in [ref 5]. As a reminder, the Flickr dataset is a graph where nodes represent images and edges signify similarities between them [ref 6]. It includes 89,250 images and 899,756 relationships. Node features consist of image descriptions and shared properties.
Let's apply the GNNPlotter methods to the Flickr data set, selecting the edges of the nodes with indices from 12 to 21 inclusive, and using the spring layout.
import os
from torch_geometric.data import Dataset
from torch_geometric.datasets import Flickr

# Load the Flickr data
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'Flickr')
# Invoke the built-in PyG data loader for Flickr data
_dataset: Dataset = Flickr(path)
_data = _dataset[0]
# Setup the sampling of the undirected graph
gnn_plotter = GNNPlotter.build(_data, sampled_node_index_range=(12, 21))
# Draw the sampled graph (the node_color and node_size values are illustrative)
num_sampled_nodes = gnn_plotter.draw(
    layout_func=lambda graph: nx.spring_layout(graph, k=1),
    node_color='blue',
    node_size=40,
    title='Flickr spring layout')
print(f'Sample size: {num_sampled_nodes}')
Output: Sample size: 319
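The sampled subgraph of the Flickr data set is visualized with its undirected edges using the spring layout. Note that the value returned by sample is the number of rows selected from the edge list, so 319 counts the edges whose source node has an index between 12 and 21. This is 319/899,756 ≈ 0.035% of the relationships in the dataset.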
Examples of graph display configurations are listed in the Appendix.
Comparing layouts
Undirected graph representation
Let’s demonstrate six common layouts for displaying subgraphs sampled from the Flickr dataset using 67 nodes.
Fig. 3 Display of sampled sub-graph from Flickr data set using 6 common layouts
Finally, let’s apply the same layout to a directed graph derived from the Flickr dataset.
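As a minimal sketch of what changes when the graph is directed (the edge list below is invented for illustration): nx.Graph merges the two directions of a link into a single edge, whereas nx.DiGraph keeps them as separate edges. NetworkX also draws the edges of a directed graph with arrowheads by default.

```python
import networkx as nx

# Hypothetical sampled edges; (0, 1) and (1, 0) are the two directions of one link
sampled_edges = [(0, 1), (1, 0), (1, 2), (2, 3)]

undirected = nx.Graph()
undirected.add_edges_from(sampled_edges)

directed = nx.DiGraph()
directed.add_edges_from(sampled_edges)

print(undirected.number_of_edges())  # 3, both directions collapse into one edge
print(directed.number_of_edges())    # 4, each direction is kept
print(directed.is_directed())        # True
```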