Tuesday, December 10, 2024

Fréchet Centroid on Manifolds in Python

Target audience: Intermediate
Estimated reading time: 5 minutes


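The Fréchet centroid (or intrinsic centroid) generalizes the concept of a mean to data points that lie on a manifold or, more generally, in a non-Euclidean metric space. Whereas the Euclidean mean minimizes the sum of squared Euclidean distances to the data points, the Fréchet centroid minimizes the sum of squared distances measured with the intrinsic (geodesic) distance of the manifold.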



What you will learn: How to compute the Fréchet centroid (or mean) of multiple points on a data manifold.

Notes

  • Environments: Python 3.12.5, Geomstats 2.8.0, Matplotlib 3.9, NumPy 2.2.0
  • Source code is available on GitHub [ref 1]
  • To enhance the readability of the algorithm implementations, we have omitted most non-essential code elements such as error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.


Introduction

For readers unfamiliar with manifolds and basic differential geometry, I highly recommend starting with my two introductory posts:

  1. Foundation of Geometric Learning: This post introduces differential geometry in the context of machine learning, outlining its foundational components.
  2. Differentiable Manifolds for Geometric Learning: This post explores manifold concepts such as tangent vectors and geodesics, with Python implementations for the hypersphere using the Geomstats library.
Alternatively, I strongly encourage readers to consult tutorials [ref 2], video series [ref 3], or publications [ref 4, 5] for more in-depth information.

The following image illustrates the difference between the Euclidean and Fréchet means of points on a manifold.


Fig. 1. Illustration of Euclidean and Fréchet means on an arbitrary manifold


Let's consider n tensors x[i], each with d values. Their Euclidean mean is computed as
\[\boldsymbol{\mu} = \frac{1}{n}\sum_{i=1}^{n} \mathbf{x}_{i} \] that is, for each coordinate of index j: \[\mu^{(j)}=\frac{1}{n}\sum_{i=1}^{n}x_{i}^{(j)} \]
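A quick NumPy check of the formula above shows why the arithmetic mean is a poor centroid on a manifold. The three points used here (the coordinate axes of R3, which all lie on the unit sphere) are purely illustrative: their coordinate-wise mean has a norm below 1, so it does not belong to the sphere.

```python
import numpy as np

# Three illustrative points on the unit sphere S2 (unit-norm vectors of R3)
points = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
])

# Coordinate-wise arithmetic mean: mu^(j) = (1/n) * sum_i x_i^(j)
mu = points.sum(axis=0) / len(points)
assert np.allclose(mu, np.mean(points, axis=0))

# The Euclidean mean falls strictly inside the sphere: its norm is sqrt(1/3) < 1
print(mu, np.linalg.norm(mu))
```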

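For a manifold M equipped with a distance d (typically the geodesic distance induced by its Riemannian metric), the Fréchet centroid is defined as \[ m=\arg\min_{p \in M}\sum_{i=1}^{n}d^{2}\left(p, \mathbf{x}_{i} \right) \] and the weighted Fréchet mean, with non-negative weights w_i, as: \[ m=\arg\min_{p \in M}\sum_{i=1}^{n}w_{i}\,d^{2}\left(p, \mathbf{x}_{i} \right) \] When M is a Euclidean space and d is the Euclidean distance, the first minimizer is exactly the arithmetic mean μ, so the Fréchet centroid generalizes the Euclidean mean.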
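The definition above is an optimization problem. A minimal way to solve it on the unit sphere S2 is the classic fixed-point (Karcher mean) iteration: from the current estimate p, move along the geodesic in the direction of the weighted average of the log-map vectors log_p(x_i), which is the negative Riemannian gradient of half the objective. The NumPy-only sketch below is an illustration under these assumptions, not the Geomstats implementation; the helper names `sphere_log`, `sphere_exp`, and `frechet_mean_sphere` are hypothetical, and it assumes no data point is antipodal to the current estimate, where the log map is undefined.

```python
import numpy as np


def sphere_log(p, x):
    """Log map on the unit sphere: tangent vector at p pointing toward x, of length d(p, x)."""
    cos_theta = np.clip(np.dot(p, x), -1.0, 1.0)
    theta = np.arccos(cos_theta)  # geodesic distance d(p, x)
    if theta < 1e-12:
        return np.zeros_like(p)
    # Undefined for antipodal points (theta = pi), where sin(theta) = 0
    return theta / np.sin(theta) * (x - cos_theta * p)


def sphere_exp(p, v):
    """Exp map on the unit sphere: follow the geodesic from p with initial velocity v."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return p
    return np.cos(norm_v) * p + np.sin(norm_v) * v / norm_v


def frechet_mean_sphere(points, weights=None, max_iter=100, tol=1e-10):
    """Weighted Fréchet mean on the unit sphere by Riemannian gradient descent (unit step)."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    w = np.full(n, 1.0 / n) if weights is None else np.asarray(weights, dtype=float)
    w = w / w.sum()  # normalized weights w_i
    p = points[0] / np.linalg.norm(points[0])  # start from the first data point
    for _ in range(max_iter):
        # v = sum_i w_i log_p(x_i): negative Riemannian gradient of (1/2) sum_i w_i d^2(p, x_i)
        v = sum(w_i * sphere_log(p, x) for w_i, x in zip(w, points))
        p = sphere_exp(p, v)
        if np.linalg.norm(v) < tol:
            break
    return p


# Usage: by symmetry, the Fréchet mean of the three coordinate axes is (1, 1, 1)/sqrt(3)
print(frechet_mean_sphere(np.eye(3)))  # ≈ [0.5774 0.5774 0.5774]
```

Passing `weights` implements the weighted definition: for instance, with weights 0.25 and 0.75 on two points, the estimate lies 3/4 of the way along the geodesic from the first point to the second.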


Implementation

Let's explore the computation of means from the perspective of extrinsic geometry, focusing on surfaces embedded within a three-dimensional Euclidean space.
To begin, we'll define a class `FrechetEstimator` to encapsulate the essential components of the estimator:
  • space: A smooth manifold (e.g., Sphere, Hyperbolic) or a Lie group (e.g., SO3, SE3).  
  • optimizer: An optimization algorithm used to iterate through the space and minimize the sum of squared distances.  
  • weights: Optional weights that can be applied during the mean estimation process.

Note: The constructor instantiates the Geomstats class `FrechetMean` for the given manifold and overrides its default optimizer.


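import numpy as np
from geomstats.geometry.manifold import Manifold
from geomstats.learning.frechet_mean import FrechetMean, BaseGradientDescent


class FrechetEstimator:
    def __init__(self, space: Manifold, optimizer: BaseGradientDescent, weights: np.ndarray | None = None) -> None:
        # Geomstats Fréchet mean estimator associated with the given manifold
        self.frechet_mean = FrechetMean(space)
        # Override the default optimizer (e.g., GradientDescent)
        self.frechet_mean.optimizer = optimizer
        self.weights = weights
        self.space = space

    def estimate(self, X: np.ndarray) -> np.ndarray:
        ...  # Implemented below

    def rand(self, num_samples: int) -> np.ndarray:
        ...  # Implemented below

    @staticmethod
    def euclidean_mean(manifold_points: np.ndarray) -> np.ndarray:
        ...  # Implemented below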

The `rand` method generates random points on the manifold for evaluation purposes, using the Geomstats method `random_uniform`.

The implementation ensures that:
- The manifold type is supported.  
- All randomly generated points (as tensors) are verified to belong to the manifold.

class GeometricException(Exception):
    # Minimal definition assumed here so the snippet is self-contained
    pass


def rand(self, num_samples: int) -> np.ndarray:
    from geomstats.geometry.hypersphere import Hypersphere
    from geomstats.geometry.special_orthogonal import _SpecialOrthogonalMatrices

    # Test if this manifold is supported (hypersphere or SO(n) in matrix form)
    if not isinstance(self.space, (Hypersphere, _SpecialOrthogonalMatrices)):
        raise GeometricException('Cannot generate random values on unsupported manifold')

    # Draw points uniformly distributed on the manifold
    X = self.space.random_uniform(num_samples)
    # Validate that all the randomly generated points belong to the manifold 'self.space'
    if not all(self.space.belongs(x) for x in X):
        raise GeometricException('Some generated points do not belong to the manifold')
    return X
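For intuition about what sampling and membership testing involve on the sphere, here is a NumPy-only sketch of one standard technique: normalizing isotropic Gaussian vectors yields points uniformly distributed on the sphere, and membership reduces to checking that the norm equals 1 within a tolerance. The function names and the tolerance are hypothetical, and this is not necessarily how Geomstats implements `random_uniform` or `belongs`.

```python
import numpy as np


def random_uniform_sphere(num_samples, dim=2, seed=None):
    """Sample points uniformly on the unit hypersphere S^dim embedded in R^(dim+1)."""
    rng = np.random.default_rng(seed)
    # An isotropic Gaussian vector, once normalized, is uniformly distributed on the sphere
    x = rng.standard_normal((num_samples, dim + 1))
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def belongs_sphere(points, atol=1e-9):
    """Membership test: a point of R^(dim+1) lies on the unit sphere iff its norm is 1."""
    return np.abs(np.linalg.norm(points, axis=-1) - 1.0) < atol


samples = random_uniform_sphere(8, seed=42)
print(samples.shape, np.all(belongs_sphere(samples)))
```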

Our evaluation uses the Euclidean mean as a baseline, with a straightforward implementation using NumPy.


@staticmethod
def euclidean_mean(manifold_points: np.ndarray) -> np.ndarray:
    # Coordinate-wise arithmetic mean over the stacked points (axis 0)
    return np.mean(manifold_points, axis=0)
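Because the arithmetic mean usually falls inside the sphere, a cheap extrinsic alternative is to project it radially back onto the sphere. The sketch below (the name `projected_euclidean_mean` is hypothetical) computes this projected, or extrinsic, mean. It is not the Fréchet centroid in general, and it is undefined when the points cancel out to the origin.

```python
import numpy as np


def projected_euclidean_mean(points, eps=1e-12):
    """Extrinsic mean on the unit sphere: arithmetic mean projected radially onto the sphere."""
    mu = np.mean(points, axis=0)
    norm_mu = np.linalg.norm(mu)
    if norm_mu < eps:
        # Points cancel out (e.g., an antipodal pair): the projection is undefined
        raise ValueError('Euclidean mean is at the origin; projection is undefined')
    return mu / norm_mu


points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
print(projected_euclidean_mean(points))  # ≈ [0.7071, 0.7071, 0.0]
```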


The `estimate` method computes the Fréchet centroid of data points stacked into a NumPy array. It relies on the Geomstats method `FrechetMean.fit`, which stores the resulting centroid in the attribute `estimate_`.

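def estimate(self, X: np.ndarray) -> np.ndarray:
    # Minimize the (weighted) sum of squared distances with the configured optimizer
    self.frechet_mean.fit(X=X, y=None, weights=self.weights)
    # The resulting centroid is stored in the attribute estimate_
    return self.frechet_mean.estimate_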
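Since `FrechetMean.fit` relies on an iterative optimizer, a simple sanity check is to evaluate the objective from the definition at an estimate and at competing points. The NumPy sketch below (hypothetical helper `frechet_objective_sphere`) does this on the unit sphere, where the geodesic distance is d(p, x) = arccos(<p, x>). It uses two illustrative points weighted 0.25 and 0.75: the point 3/4 of the way along the geodesic from the first point to the second, which is their weighted Fréchet mean, scores lower than the geodesic midpoint.

```python
import numpy as np


def frechet_objective_sphere(p, points, weights=None):
    """Weighted sum of squared geodesic distances sum_i w_i d^2(p, x_i) on the unit sphere."""
    w = np.ones(len(points)) if weights is None else np.asarray(weights, dtype=float)
    # Geodesic (great-circle) distance on the unit sphere: d(p, x) = arccos(<p, x>)
    distances = np.arccos(np.clip(points @ p, -1.0, 1.0))
    return float(np.sum(w * distances ** 2))


points = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
weights = [0.25, 0.75]
# Candidate on the geodesic from x_1 to x_2, at angle 3*pi/8 from x_1
angle = 3 * np.pi / 8
candidate = np.array([np.cos(angle), np.sin(angle), 0.0])
midpoint = np.array([1.0, 1.0, 0.0]) / np.sqrt(2)

print(frechet_objective_sphere(candidate, points, weights))  # 3 * (pi/8)^2 ≈ 0.463
print(frechet_objective_sphere(midpoint, points, weights))   # (pi/4)^2 ≈ 0.617
```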

Evaluation

Centroid on Sphere (S2)

Our initial test computes the unweighted Fréchet centroid of 8 points located on a hypersphere [ref 6]. The randomly generated points, `rand_points`, are visualized both on a 3D sphere and in Euclidean space using the `HyperspherePlot` and `EuclideanPlot` classes. The code for these plotting classes is omitted here for clarity but is available on GitHub [ref 1].


from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.frechet_mean import GradientDescent
# HyperspherePlot and EuclideanPlot are the author's plotting classes, available on GitHub [ref 1]

frechet_estimator = FrechetEstimator(space=Hypersphere(dim=2),
                                     optimizer=GradientDescent(),
                                     weights=None)
# 1- Generate 8 random points on the hypersphere
rand_points = frechet_estimator.rand(8)

# 2- Estimate the Fréchet centroid, then test whether it belongs to the manifold
frechet_mean = frechet_estimator.estimate(rand_points)
assert frechet_estimator.space.belongs(frechet_mean)

# 3- Display the points and the Fréchet centroid on the hypersphere
hypersphere_plot = HyperspherePlot(rand_points, frechet_mean)
hypersphere_plot.show()

# 4- Compute the Euclidean (arithmetic) centroid
euclidean_mean = FrechetEstimator.euclidean_mean(rand_points)

# 5- Display the points and the Euclidean centroid in 3D Euclidean space
euclidean_plot = EuclideanPlot(rand_points, euclidean_mean)
euclidean_plot.show()

print(f'\nFrechet mean:   {frechet_mean}\nEuclidean mean: {euclidean_mean}')


Output:

Frechet mean:   [ 0.0174 -0.9958 -0.0897]
Euclidean mean: [ 0.0220 -0.0792  0.0226]
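The reported values can be checked directly by computing their norms, since any point on the unit sphere S2 has norm 1. The snippet below copies the rounded output values above.

```python
import numpy as np

# Values copied from the output above (rounded to 4 decimals)
frechet_mean = np.array([0.0174, -0.9958, -0.0897])
euclidean_mean = np.array([0.0220, -0.0792, 0.0226])

print(np.linalg.norm(frechet_mean))    # ≈ 1.0: the Fréchet mean lies on S2
print(np.linalg.norm(euclidean_mean))  # ≈ 0.085: the Euclidean mean lies deep inside the sphere
```

With a norm of about 0.085, the Euclidean mean shows that the 8 points nearly cancel out: their arithmetic mean sits close to the center of the sphere, far from the manifold.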


The 8 random points are visualized using a Matplotlib 3D scatter plot.



Fig. 2. Visualization of points from a 2-dimensional hypersphere in 3D Euclidean space


The next graph visualizes the 8 random points on a 3D sphere.

Fig. 3. Visualization of random points on a 2-dimensional hypersphere



Centroid on Special Orthogonal Group (SO3)

The special orthogonal group is a Lie group. In differential geometry, a Lie group is a mathematical structure that combines the properties of a group and a smooth manifold, so both algebraic and geometric techniques apply to it. As a group, it has an operation (for SO(3), matrix multiplication) that satisfies the group axioms (closure, associativity, identity, and invertibility); as a manifold, this operation and its inverse are smooth.
The special orthogonal group in 3 dimensions, SO(3), is the group of all rotation matrices in 3D space, that is, the 3x3 matrices R satisfying R^T R = I and det(R) = 1.
SO(3) is a 3-dimensional manifold: any rotation can be parameterized by three angles, for instance as a composition of elementary rotations about the x, y, and z axes.
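To make this parameterization concrete, the NumPy sketch below builds the elementary rotations about the x, y, and z axes, composes them (the z-y-x order is one convention among several), and checks the defining properties R^T R = I and det(R) = 1. It also shows that the average of two distinct rotations is not a rotation, which is why the Euclidean mean is a poor centroid on SO(3). The helper names and angles are illustrative.

```python
import numpy as np


def rotation_x(a):
    """Rotation by angle a (radians) about the x axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rotation_y(a):
    """Rotation by angle a (radians) about the y axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rotation_z(a):
    """Rotation by angle a (radians) about the z axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def is_rotation(R, atol=1e-9):
    """Membership in SO(3): orthogonal (R^T R = I) with determinant +1."""
    return bool(np.allclose(R.T @ R, np.eye(3), atol=atol)
                and np.isclose(np.linalg.det(R), 1.0, atol=atol))


# Closure: a product of rotations is a rotation
R = rotation_z(0.3) @ rotation_y(-1.2) @ rotation_x(0.7)
print(is_rotation(R))

# The arithmetic mean of R and the identity is not a rotation
print(is_rotation(0.5 * (R + np.eye(3))))
```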

Important note: The SO(3) manifold is described and evaluated with Python source code in a previous post [ref 7].


We utilize the Geomstats `SpecialOrthogonal` class to generate 4 random matrices, then compute their Euclidean and Fréchet centroids.

from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.learning.frechet_mean import GradientDescent
# SO3Plot is the author's plotting class, available on GitHub [ref 1]

manifold = SpecialOrthogonal(n=3, point_type="matrix")
frechet_estimator = FrechetEstimator(manifold, GradientDescent(), weights=None)

# Generate 4 random points (rotation matrices) on SO3
manifold_points = frechet_estimator.rand(4)

# Visualize the SO3 matrices on a 3D plot
so3_plot = SO3Plot(manifold_points)
so3_plot.show()

# Compute the Fréchet and Euclidean centroids
frechet_mean = frechet_estimator.estimate(manifold_points)
euclidean_mean = FrechetEstimator.euclidean_mean(manifold_points)
print(f'\nFrechet mean:\n{frechet_mean}\nEuclidean mean:\n{euclidean_mean}')

Output:
Frechet mean:   
[[-0.74841 -0.40675 -0.52385]
 [-0.01445 -0.77966  0.62603]
 [-0.66306  0.47610  0.57763]]
Euclidean mean: 
[[-0.52282 -0.24432 -0.38322]
 [ 0.03642 -0.55841  0.37230]
 [-0.44746  0.26312  0.11775]]
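The same membership check can be applied to the two matrices printed above (values copied from the output, rounded to 5 decimals): the Fréchet mean satisfies M^T M = I and det(M) = 1 up to rounding, whereas the Euclidean mean is far from orthogonal and therefore is not a rotation.

```python
import numpy as np

# Matrices copied from the output above (rounded to 5 decimals)
frechet_mean = np.array([[-0.74841, -0.40675, -0.52385],
                         [-0.01445, -0.77966,  0.62603],
                         [-0.66306,  0.47610,  0.57763]])
euclidean_mean = np.array([[-0.52282, -0.24432, -0.38322],
                           [ 0.03642, -0.55841,  0.37230],
                           [-0.44746,  0.26312,  0.11775]])

for name, M in [('Frechet', frechet_mean), ('Euclidean', euclidean_mean)]:
    # Frobenius norm of M^T M - I measures the deviation from orthogonality
    orthogonality_error = np.linalg.norm(M.T @ M - np.eye(3))
    print(f'{name}: |M^T M - I| = {orthogonality_error:.4f}, det = {np.linalg.det(M):.4f}')
```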

In this scenario, the 4 random 3x3 rotation matrices of SO(3), which are generally non-symmetric, are displayed on a 3D plot.

Fig 4. Visualization of 4 random SO3 matrices on a 3D plot

References

[4] Vector and Tensor Analysis with Applications - A. I. Borisenko, I. E. Tarapov - Dover Books on Mathematics, Dover Publications - 1979
[5] A Student's Guide to Vectors and Tensors - D. Fleisch - Cambridge University Press - 2008





-------------
Patrick Nicolas has over 25 years of experience in software and data engineering, architecture design and end-to-end deployment and support with extensive knowledge in machine learning. 
He has been director of data engineering at Aideo Technologies since 2017 and he is the author of "Scala for Machine Learning", Packt Publishing ISBN 978-1-78712-238-3 and 
Geometric Learning in Python Newsletter on LinkedIn.


