
Sunday, January 19, 2025

Sampling & Visualizing PyG Graph Data

  Target audience: Beginner
Estimated reading time: 4'


Have you ever found it challenging to represent a graph from a very large dataset while building a graph neural network model?
This article presents a method to sample and visualize a subgraph from such datasets.

      Overview
      Implementation
      Layouts
   
What you will learn: How to sample and visualize a graph from a very large dataset for modeling graph neural networks.

Notes

  • Environments: Python 3.12.5, matplotlib 3.9, numpy 2.2.0, torch 2.5.1, torch-geometric 2.6.1, networkx 3.4.2
  • Source code is available on GitHub [ref 1]
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements such as error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.


Introduction

This article focuses on visualizing subgraphs within the context of graph neural networks. It does not cover the introduction or explanation of graph neural network architectures and models, as those topics are beyond its scope [ref 2, 3].


Note: There are several Python libraries available for analyzing and visualizing graphs, including Plotly, PyVis, and NetworKit. In some of our future articles on graph neural networks and geometric deep learning, we will use NetworkX for visualization.

As a reminder, a graph is fully defined by an instance of the Data class in the torch_geometric.data package, with the following attributes (see the short sketch after this list):
  • data.x: Node feature matrix with shape [num_nodes, num_node_features]
  • data.edge_index: Graph connectivity with shape [2, num_edges] and type torch.long
  • data.edge_attr: Edge feature matrix with shape [num_edges, num_edge_features]
  • data.y: Target to train against (may have arbitrary shape), e.g., node-level targets of shape [num_nodes, *] or graph-level targets of shape [1, *]
  • data.pos: Node position matrix with shape [num_nodes, num_dimensions]
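
As an illustration, here is a minimal, hypothetical sketch that builds a Data instance for a toy 3-node graph (the feature values and edges are made up for the example):

import torch
from torch_geometric.data import Data

# Three nodes, each with two features (hypothetical values)
x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
# Two undirected edges, 0-1 and 1-2, encoded as directed pairs
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# Node-level targets
y = torch.tensor([0, 1, 0])

data = Data(x=x, edge_index=edge_index, y=y)
print(data.num_nodes, data.num_edges)   # 3 4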

NetworkX library

Overview

NetworkX is a BSD-licensed, powerful, and flexible Python library for the creation, manipulation, and analysis of complex networks and graphs. It supports various types of graphs, including undirected, directed, and multi-graphs, allowing users to model relationships and structures efficiently [ref 4].
NetworkX provides a wide range of algorithms for graph theory and network analysis, such as shortest paths, clustering, centrality measures, and more. It is designed to handle graphs with millions of nodes and edges, making it suitable for applications in social networks, biology, transportation systems, and other domains. With its intuitive API and rich visualization capabilities, NetworkX is an essential tool for researchers and developers working with network data.

The library supports many standard graph algorithms, such as clustering, link analysis, minimum spanning trees, shortest paths, cliques, coloring, cuts, Erdős–Rényi graph generation, and graph polynomials, as illustrated in the brief sketch below.
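
A minimal sketch of a few of these algorithms, using the karate-club graph bundled with NetworkX (chosen here purely for illustration):

import networkx as nx

g = nx.karate_club_graph()                            # Built-in 34-node social graph
print(nx.shortest_path(g, source=0, target=33))       # Shortest path between two members
print(nx.average_clustering(g))                       # Average clustering coefficient
print(nx.minimum_spanning_tree(g).number_of_edges())  # Edges in a minimum spanning tree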

Sampling large graphs

Most datasets included in the PyG library contain an extremely large number of nodes and edges, making them impractical to visualize directly. To address this, we can extract (or sample) one or more subgraphs that are easier to display.

In our design, a subgraph is derived from the original large graph by sampling its nodes and edges based on a specified range of indices, as shown below:

Fig. 1 Illustration of sampling a large graph

In the illustration above, the sampled nodes are [12, ..., 19].

Implementation

For the sake of simplicity, let's wrap the visualization of graph neural network data into a class, GNNPlotter, whose constructor takes three parameters [ref 4]:
  • graph: Reference to a NetworkX directed or undirected graph
  • data: Data representation of the graph dataset
  • sampled_node_index_range: Tuple (index of the first sampled node, index of the last sampled node)
import networkx as nx
import numpy as np
from networkx import Graph
from torch_geometric.data import Data
from typing import Tuple, AnyStr, Callable, Dict, Any, Self
import matplotlib.pyplot as plt


class GNNPlotter(object):

    # Default constructor
    def __init__(self, graph: Graph, data: Data, sampled_node_index_range: Tuple[int, int] = None) -> None:
        self.graph = graph
        self.data = data
        self.sampled_node_index_range = sampled_node_index_range

    # Constructor for undirected graph
    @classmethod
    def build(cls, data: Data, sampled_node_index_range: Tuple[int, int] = None) -> Self:
        return cls(nx.Graph(), data, sampled_node_index_range)

    # Constructor for directed graph
    @classmethod
    def build_directed(cls, data: Data, sampled_node_index_range: Tuple[int, int] = None) -> Self:
        return cls(nx.DiGraph(), data, sampled_node_index_range)

    # Sample/extract a subgraph from a large graph by selecting
    # its nodes through a range of indices
    def sample(self) -> int:
        ...

    # Display/visualize a subgraph extracted/sampled from a large graph
    # by selecting its nodes through a range of indices
    def draw(self,
             layout_func: Callable[[Graph], Dict[Any, Any]],
             node_color: AnyStr,
             node_size: int,
             title: AnyStr) -> int:
        ...

Simplified constructors for undirected (build) and directed (build_directed) graphs are also provided.


Let's first review our implementation for sampling graph vertices and edges, which involves the following steps:
  1. Transpose the list of edge indices.
  2. If a range of sampling indices is defined, use it to extract the index of the first node, sampled_node_index_range[0], and the index of the last node, last_node_index, in the subgraph.
  3. Add the edge indices associated with the selected nodes.
def sample(self) -> int:
    # 1. Convert the edge index tensor into a [num_edges, 2] array
    edge_index = self.data.edge_index.numpy()
    transposed = edge_index.T

    # 2. Sample the edges whose source node falls within the index range
    if self.sampled_node_index_range is not None:
        last_node_index = len(self.data.y) \
            if self.sampled_node_index_range[1] >= len(self.data.y) \
            else self.sampled_node_index_range[1]
        condition = ((transposed[:, 0] >= self.sampled_node_index_range[0]) &
                     (transposed[:, 0] <= last_node_index))
        sampled_nodes = transposed[np.where(condition)]
    else:
        sampled_nodes = transposed

    # 3. Assign the sampled edges to the graph
    self.graph.add_edges_from(sampled_nodes)

    return len(sampled_nodes)

Finally, the visualization of the graph is achieved by drawing it with the draw method and overlaying the edges using draw_networkx_edges.

In our basic application, we define the layout, customize the color and size of the nodes, and set the title for the display.

def draw(self,
         layout_func: Callable[[Graph], Dict[Any, Any]],
         node_color: AnyStr,
         node_size: int,
         title: AnyStr) -> int:
    num_sampled_nodes = self.sample()

    # Plot the graph using matplotlib
    plt.figure(figsize=(8, 8))

    # Compute node positions, then draw nodes and edges
    pos = layout_func(self.graph)
    nx.draw(self.graph, pos, node_size=node_size, node_color=node_color)
    nx.draw_networkx_edges(self.graph, pos, arrowsize=40, alpha=0.5, edge_color="black")

    # Configure plot
    plt.title(title)
    plt.axis("off")
    plt.show()

    return num_sampled_nodes


Layouts

NetworkX provides several graph layouts to visually display undirected graphs. Each layout arranges nodes in a specific pattern, suited to different types of graphs and visualization purposes. Here's a brief overview of the common layouts available in NetworkX [ref 4]; each can be passed to GNNPlotter.draw as layout_func (see the sketch after this list).
  1. Spring Layout: Positions nodes using a force-directed algorithm. Nodes repel each other, while edges act as springs pulling connected nodes closer.
  2. Circular Layout: Arranges nodes uniformly on a circle.
  3. Shell Layout: Arranges nodes in concentric circles (shells). Useful for graphs with hierarchical structures.
  4. Planar Layout: Positions nodes to ensure no edges overlap, provided the graph is planar.
  5. Kamada-Kawai Layout: Positions nodes to minimize the "energy" of the graph. Produces aesthetically pleasing layouts similar to `spring_layout`.
  6. Spectral Layout: Positions nodes using the eigenvectors of the graph Laplacian. Captures graph structure in the arrangement.
  7. Random Layout: Places nodes randomly within a unit square.
  8. Spiral Layout: Positions nodes in a spiral pattern.
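
A minimal sketch of how these layout functions are invoked, here on a small built-in cycle graph chosen purely for illustration (note that the Kamada-Kawai and spectral layouts rely on SciPy/NumPy under the hood):

import networkx as nx

g = nx.cycle_graph(8)                    # Small planar test graph
pos_spring = nx.spring_layout(g, k=1)    # 1. Force-directed
pos_circular = nx.circular_layout(g)     # 2. Nodes on a circle
pos_shell = nx.shell_layout(g)           # 3. Concentric circles
pos_planar = nx.planar_layout(g)         # 4. No edge crossings (planar graphs only)
pos_kk = nx.kamada_kawai_layout(g)       # 5. Energy minimization
pos_spectral = nx.spectral_layout(g)     # 6. Laplacian eigenvectors
pos_random = nx.random_layout(g)         # 7. Uniform in the unit square
pos_spiral = nx.spiral_layout(g)         # 8. Spiral pattern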

Graph representation

Let's consider the Flickr dataset included in PyTorch Geometric (PyG), described in [ref 5]. As a reminder, the Flickr dataset is a graph whose nodes represent images and whose edges signify similarities between them [ref 6]. It includes 89,250 images and 899,756 relationships. Node features consist of image descriptions and shared properties.

Let's apply the GNNPlotter class methods to the Flickr dataset, selecting the edges for nodes with indices 12 through 21 inclusive, and using the spring layout.

import os
from torch_geometric.data import Dataset
from torch_geometric.datasets import Flickr

# Load the Flickr data
path = os.path.join(os.path.dirname(os.path.realpath(__file__)), '..', 'data', 'Flickr')

# Invoke the built-in PyG data loader for Flickr data
_dataset: Dataset = Flickr(path)
_data = _dataset[0]

# Set up the sampling of the undirected graph
gnn_plotter = GNNPlotter.build(_data, sampled_node_index_range=(12, 21))

# Draw the sampled graph
num_sampled_nodes = gnn_plotter.draw(
    layout_func=lambda graph: nx.spring_layout(graph, k=1),
    node_color='blue',
    node_size=40,
    title='Flickr spring layout')

print(f'Sample size: {num_sampled_nodes}')

Output: 319

The 319 sampled edges of the Flickr graph, together with their end nodes, are visualized using the spring layout; the sample covers roughly 319/899,756 ≈ 0.035% of the edges in the entire dataset.
Fig. 2 Display of sampled sub-graph from Flickr data set using spring layout
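
As a sanity check, one can query the underlying NetworkX graph directly (assuming the gnn_plotter instance created above). Note that sample returns the number of sampled edge-index rows, and nx.Graph merges duplicated undirected edges, so the two counts may differ:

print(gnn_plotter.graph.number_of_nodes())   # Distinct nodes in the subgraph
print(gnn_plotter.graph.number_of_edges())   # Distinct undirected edges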


Comparing layouts

Undirected graph representation
Let’s demonstrate six common layouts for displaying subgraphs sampled from the Flickr dataset using 67 nodes.


Fig. 3 Display of sampled sub-graph from Flickr data set using 6 common layouts


Directed graph representation
Finally, let's apply the same layouts to a directed graph derived from the Flickr dataset, as sketched below.
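
A minimal sketch (assuming _data is loaded as above); the only change is the build_directed constructor, which wraps an nx.DiGraph:

gnn_plotter = GNNPlotter.build_directed(_data, sampled_node_index_range=(12, 21))
gnn_plotter.draw(
    layout_func=lambda graph: nx.spring_layout(graph, k=1),
    node_color='blue',
    node_size=40,
    title='Flickr directed spring layout')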

Fig. 4 Display of directed sub-graph from Flickr data set using 6 common layouts


References




-------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design, and end-to-end deployment and support, with extensive knowledge in machine learning.
He has been director of data engineering at Aideo Technologies since 2017. He is the author of "Scala for Machine Learning" (Packt Publishing, ISBN 978-1-78712-238-3) and of the
Geometric Learning in Python newsletter on LinkedIn.

Thursday, May 5, 2016

Bootstrapping With Replacement in Scala

Target audience: Intermediate
Estimated reading time: 6'

Bootstrapping is a statistical method in which a dataset is repeatedly sampled at random with replacement. It allows data scientists to estimate the sampling distribution of a statistic for a broad range of probability distributions.

Background

A primary goal of bootstrapping is to evaluate the precision of various statistical measures such as the mean, standard deviation, median, mode, or error. These measures, sf, often termed estimators, serve to approximate a distribution. The most frequently used approximation is called the empirical distribution function. When the data points x are independent and identically distributed (iid), this empirical or approximate distribution can be assessed by employing resampling techniques.

The following diagram captures the essence of bootstrapping by resampling.

Generation of bootstrap replicates by resampling

Each of the B bootstrap samples has the same number of observations or data points as the original dataset from which the samples are created. Once the samples are created, a statistical function sf such as the mean, mode, median, or standard deviation is computed for each sample.
The standard deviation of the B statistics should be similar to the standard deviation of the original dataset; the standard formula is given below.
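
For reference, denoting by s_b the value of the statistic sf computed on bootstrap sample b, and by s̄ the average of the B replicates, the bootstrap estimate of the standard error is the standard formula (this is what the error method implemented below computes):

\[
\widehat{se}_{boot} = \sqrt{\frac{1}{B-1}\sum_{b=1}^{B}\left(s_b - \bar{s}\right)^2},
\qquad \bar{s} = \frac{1}{B}\sum_{b=1}^{B} s_b
\]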

Implementation in Scala

The purpose of this post is to illustrate some basic properties of bootstrapped sampling:
  • The profile of the distribution of a statistic sf for a given probability distribution
  • A comparison of the standard deviation of the statistic sf with the standard deviation of the original dataset
Let's implement a bootstrap by resampling in Scala, starting with a class Bootstrap.

import scala.collection.mutable
import scala.util.Random

class Bootstrap(
    numSamples: Int = 1000,
    sf: Vector[Double] => Double,
    inputDistribution: Vector[Double],
    randomizer: Random
) {
  // B bootstrap replicates: one statistic per bootstrap sample
  lazy val bootstrappedReplicates: Array[Double] =
    (0 until numSamples)./:(new mutable.ArrayBuffer[Double])(
      (buf, _) => buf += createBootstrapSample
    ).toArray

  // Statistic computed on one sample drawn with replacement (implemented below)
  def createBootstrapSample: Double = ???

  // Mean of the bootstrap replicates
  lazy val mean: Double = bootstrappedReplicates.reduce(_ + _) / numSamples

  // Standard error of the replicates (implemented below)
  def error: Double = ???
}

The class Bootstrap is instantiated with a pre-defined number of samples, numSamples, a statistic function sf, a dataset generated from a given distribution, inputDistribution, and a random number generator, randomizer.
The computation of the bootstrap replicates, bootstrappedReplicates, is central to resampling. As described in the introduction, each replicate is computed from a sample of the original dataset with the method createBootstrapSample.

Let's implement the method createBootstrapSample.

def createBootstrapSample: Double =
  sf(
    (0 until inputDistribution.size)./:(new mutable.ArrayBuffer[Double])(
      (buf, _) => {
        // Re-seed the generator, then draw an index uniformly, with replacement
        randomizer.setSeed(randomizer.nextLong)
        val randomValueIndex = randomizer.nextInt(inputDistribution.size)
        buf += inputDistribution(randomValueIndex)
      }
    ).toVector
  )

The method createBootstrapSample
- samples the original dataset, with replacement, using a uniform random index
- applies the statistic function sf to this sampled dataset.

The last step consists of computing the error (standard deviation) of the bootstrap replicates:

def error: Double = {
  // Sum of squared deviations of the replicates from their mean
  val sumOfSquaredDiff = bootstrappedReplicates
    .map(s => (s - mean) * (s - mean))
    .sum

  Math.sqrt(sumOfSquaredDiff / (numSamples - 1))
}


Evaluation

The first evaluation consists of comparing the distribution of the replicates with the original distribution. To this end, we generate an input dataset using:
  • a Normal distribution
  • a LogNormal distribution

Setup

Let's create a method, bootstrapEvaluation to compare the distribution of the bootstrap replicates with the dataset from which the bootstrap samples are generated.

import org.apache.commons.math3.distribution.RealDistribution

def bootstrapEvaluation(
    dist: RealDistribution,
    random: Random,
    gen: (Double, Double)
): (Double, Double) = {

  // Generate 5,000 data points (x, pdf(x)) via the linear transform a.r - b
  val inputDistribution = (0 until 5000)./:(new mutable.ArrayBuffer[(Double, Double)])(
    (buf, _) => {
      val x = gen._1 * random.nextDouble - gen._2
      buf += ((x, dist.density(x)))
    }
  ).toVector

  // The statistic is the mean of the sampled values
  val mean = (x: Vector[Double]) => x.sum / x.length
  val bootstrap = new Bootstrap(
    numReplicates,
    mean,
    inputDistribution.map(_._2),
    new Random(System.currentTimeMillis)
  )

  val meanS = bootstrap.bootstrappedReplicates.sum /
    bootstrap.bootstrappedReplicates.size
  val sProb = bootstrap.bootstrappedReplicates.map(_ - meanS)
  // .. plotting histogram of the distribution sProb
  (bootstrap.mean, bootstrap.error)
}

We are using the normal and log-normal probability density functions defined in the Apache Commons Math Java library, in the org.apache.commons.math3.distribution package.

The comparative method bootstrapEvaluation has the following arguments:
  • dist: A probability density function used to generate the dataset upon which sampling is performed
  • random: A random number generator
  • gen: A pair of parameters (a, b) for the linear transform a.r - b used to generate random values
The input distribution, inputDistribution = {(x, pdf(x))}, is generated for 5,000 data points.
Next, the bootstrap is created with the appropriate number of replicates, numReplicates, the mean of the input dataset as the statistic function sf, the input distribution, and the standard Scala random number generator as arguments.
Let's plot the distribution of the input dataset generated from a normal density function.

val (meanNormal, errorNormal) = bootstrapEvaluation(
    new NormalDistribution,
    new scala.util.Random,
    (5.0, 2.5)
)

Normally distributed dataset

The first graph plots the distribution of the input dataset using the Normal distribution.


The second graph illustrates the distribution (histogram) of the replicates s - mean.


The bootstrap replicates sf(x) are also normally distributed. The mean value of the bootstrap replicates is 0.1978 and the error is 0.001691.

Dataset with a log normal distribution

We repeat the same process for the lognormal distribution. This time around the dataset to sample from follows a log-normal distribution.

val (meanLogNormal, errorLogNormal) = bootstrapEvaluation(
    new LogNormalDistribution,
    new scala.util.Random,
    (2.0, 0.0)
)



Although the original dataset used to generate the bootstrap samples is log-normally distributed, the bootstrap replicates sf(x) are normally distributed. The mean of the bootstrap replicates is 0.3801 and the error is 0.002937.

The error for bootstrap resampling from a log-normal distribution is about twice the error obtained for the normal distribution.
The result is not surprising: the bootstrap replicates follow a normal distribution that closely matches the original dataset created with the same probability density function.

References

  • "Programming in Scala", 3rd edition - M. Odersky, L. Spoon, B. Venners - Artima, 2016
  • "The Elements of Statistical Learning: Data Mining, Inference, and Prediction", §7.11 Bootstrap Methods - T. Hastie, R. Tibshirani, J. Friedman - Springer, 2001
  • github.com/patnicolas