
Tuesday, December 10, 2024

Fréchet Centroid on Manifolds in Python

  Target audience: Intermediate
Estimated reading time: 5'


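The Fréchet centroid (or intrinsic centroid) is a generalization of the concept of a mean to data points that lie on a manifold or in a non-Euclidean space. Where the Euclidean mean minimizes the sum of squared straight-line distances to the data points, the Fréchet centroid minimizes the sum of squared distances measured with the intrinsic geometry of the manifold.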



What you will learn: How to compute the Fréchet centroid (or mean) of multiple points on a data manifold.

Notes

  • Environments: Python 3.12.5,  GeomStats 2.8.0, Matplotlib 3.9, Numpy 2.2.0
  • Source code is available on GitHub [ref 1]
  • To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.


Introduction

For readers unfamiliar with manifolds and basic differential geometry, I highly recommend starting with my two introductory posts:

  1. Foundation of Geometric Learning: This post introduces differential geometry in the context of machine learning, outlining its foundational components.
  2. Differentiable Manifolds for Geometric Learning: This post explores manifold concepts such as tangent vectors and geodesics, with Python implementations for the hypersphere using the Geomstats library.
Alternatively, I strongly encourage readers to consult tutorials [ref 2], video series [ref 3], or publications [ref 4, 5] for more in-depth information.

The following image illustrates the difference between the Euclidean and Fréchet means on a manifold.


Fig. 1. Illustration of Euclidean and Fréchet means on an arbitrary manifold


Let's consider n tensors x[i], each with d values. The Euclidean mean is computed as
\[\mathbf{\mu} = \frac{1}{n}\sum_{i=1}^{n} \mathbf{x}_{i} \] with, for each coordinate (value) of index j: \[\mu^{(j)}=\frac{1}{n}\sum_{i=1}^{n}x_{i}^{(j)} \]
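As a quick illustration of the two formulas above, the following sketch computes the Euclidean mean both in vector form and coordinate by coordinate. The four sample points are arbitrary values chosen for the example, not data from this post.

```python
import numpy as np

# Four illustrative points (n = 4), each with three values (d = 3)
x = np.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0],
])

# Vector form: average over the n points (axis 0)
mu = x.mean(axis=0)

# Coordinate form: average each coordinate j separately
n, d = x.shape
mu_per_coordinate = np.array([x[:, j].sum() / n for j in range(d)])

print(mu)  # [0.5 0.5 0.5]
```

Both forms return the same vector, which is all that `np.mean(..., axis=0)` does in the baseline used later in this post.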

For a manifold M and a distance metric d (not to be confused with the number of values d above), the Fréchet centroid is defined as \[ m=\underset{p \in M}{\arg\min}\sum_{i=1}^{n}d^{2}\left(p, \mathbf{x}_{i} \right) \] and the weighted Fréchet mean as: \[m=\underset{p \in M}{\arg\min}\sum_{i=1}^{n}w_{i}d^{2}\left(p, \mathbf{x}_{i} \right) \]
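To make the definition concrete, here is a minimal NumPy sketch that minimizes the (weighted) sum of squared geodesic distances on the unit sphere S2 by Riemannian gradient descent. The function names `sphere_log`, `sphere_exp` and `frechet_mean_sphere` are hypothetical helpers written for this illustration; this is not the Geomstats implementation. The sketch assumes the points lie within a common hemisphere and that no point is antipodal to the current estimate.

```python
import numpy as np


def sphere_log(p, x):
    """Log map at p: tangent vector pointing to x with length d(p, x)."""
    cos_theta = np.clip(np.dot(p, x), -1.0, 1.0)
    theta = np.arccos(cos_theta)      # geodesic (great-circle) distance
    u = x - cos_theta * p             # part of x orthogonal to p
    norm_u = np.linalg.norm(u)
    if norm_u < 1e-12:                # x equals p (antipodal points are not handled)
        return np.zeros_like(p)
    return theta * u / norm_u


def sphere_exp(p, v):
    """Exp map at p: point reached by following the geodesic along v."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return p
    return np.cos(norm_v) * p + np.sin(norm_v) * v / norm_v


def frechet_mean_sphere(points, weights=None, max_iter=100, tol=1e-10):
    """Minimize sum_i w_i * d^2(p, x_i) over the unit sphere."""
    points = np.asarray(points, dtype=float)
    n = len(points)
    w = np.full(n, 1.0 / n) if weights is None else np.asarray(weights, dtype=float)
    w = w / w.sum()
    # Initial guess: Euclidean mean projected back onto the sphere
    m = w @ points
    m = m / np.linalg.norm(m)
    for _ in range(max_iter):
        # Weighted mean of the log maps: a descent direction for the objective
        step = sum(w_i * sphere_log(m, x_i) for w_i, x_i in zip(w, points))
        m = sphere_exp(m, step)
        if np.linalg.norm(step) < tol:
            break
    return m


x1, x2 = np.array([1.0, 0.0, 0.0]), np.array([0.0, 1.0, 0.0])
m = frechet_mean_sphere([x1, x2])                        # midpoint of the arc
m_w = frechet_mean_sphere([x1, x2], weights=[3.0, 1.0])  # pulled towards x1
```

For two points a quarter circle apart, the unweighted centroid is the midpoint of the arc, while weights of 3 and 1 move it to an angle of π/8 from the heavier point. Both results stay on the sphere.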


Implementation

Let's explore the computation of means from the perspective of extrinsic geometry, focusing on surfaces embedded within a three-dimensional Euclidean space.
To begin, we'll define a class `FrechetEstimator` to encapsulate the essential components of the estimator:
  • space: A smooth manifold (e.g., Sphere, Hyperbolic) or a Lie group (e.g., SO3, SE3).  
  • optimizer: An optimization algorithm used to iterate through the space and minimize the sum of squared distances.  
  • weights: Optional weights that can be applied during the mean estimation process.

Note: The constructor invokes the Geomstats class FrechetMean associated with the given manifold.


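import numpy as np
from geomstats.geometry.manifold import Manifold
from geomstats.learning.frechet_mean import BaseGradientDescent, FrechetMean

class FrechetEstimator(object):
    def __init__(self, space: Manifold, optimizer: BaseGradientDescent, weights: np.ndarray = None) -> None:
        # Geomstats estimator for the given manifold, using the supplied optimizer
        self.frechet_mean = FrechetMean(space)
        self.frechet_mean.optimizer = optimizer
        self.weights = weights
        self.space = space

    def estimate(self, X: np.ndarray) -> np.ndarray:
        ...  # Body shown later in this section

    def rand(self, num_samples: int) -> np.ndarray:
        ...  # Body shown later in this section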


The `rand` method generates random points on the manifold for evaluation purposes, utilizing the `random_uniform` function from the Geomstats library.

The implementation ensures that:
- The manifold type is supported.  
- All randomly generated points (as tensors) are verified to belong to the manifold.

def rand(self, num_samples: int) -> np.ndarray:
    from geomstats.geometry.hypersphere import Hypersphere
    from geomstats.geometry.special_orthogonal import _SpecialOrthogonalMatrices

    # Test if this manifold is supported
    if not isinstance(self.space, (Hypersphere, _SpecialOrthogonalMatrices)):
        raise GeometricException('Cannot generate random values on unsupported manifold')

    X = self.space.random_uniform(num_samples)
    # Validate that every randomly generated point belongs to the manifold 'self.space'
    are_points_valid = all(self.space.belongs(x) for x in X)
    if not are_points_valid:
        raise GeometricException('Some generated points do not belong to the manifold')
    return X
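For intuition about what sampling and validation involve on the sphere, the sketch below draws uniform points by normalizing Gaussian vectors and then checks that each point has unit norm. The helpers `random_uniform_sphere` and `belongs_to_sphere` are hypothetical names for this illustration; they only mimic the role of the Geomstats `random_uniform` and `belongs` methods and are not their implementation.

```python
import numpy as np


def random_uniform_sphere(num_samples, dim=2, seed=None):
    """Sample points uniformly on the unit sphere S^dim embedded in R^(dim+1)."""
    rng = np.random.default_rng(seed)
    # An isotropic Gaussian has no preferred direction, so normalizing
    # its samples yields a uniform distribution on the sphere
    samples = rng.standard_normal((num_samples, dim + 1))
    return samples / np.linalg.norm(samples, axis=1, keepdims=True)


def belongs_to_sphere(points, atol=1e-12):
    """Return, for each point, whether its Euclidean norm equals 1."""
    return np.isclose(np.linalg.norm(points, axis=-1), 1.0, atol=atol)


X = random_uniform_sphere(8, seed=42)
print(X.shape)                     # (8, 3)
print(belongs_to_sphere(X).all())  # True
```

A point on S2 is stored with 3 coordinates, which is why 8 samples produce an array of shape (8, 3).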

Our evaluation uses the Euclidean mean as a baseline, with a straightforward implementation using NumPy.


@staticmethod
def euclidean_mean(manifold_points: np.ndarray) -> np.ndarray:
    # Arithmetic mean over the points (axis 0), ignoring the manifold geometry
    return np.mean(manifold_points, axis=0)


The Fréchet centroid of a sequence of data points, given as stacked NumPy arrays, is computed by the method `estimate`, which relies on the Geomstats method `FrechetMean.fit`.

def estimate(self, X: np.ndarray) -> np.ndarray:
    # Fit the Geomstats estimator, then read the centroid from its 'estimate_' attribute
    self.frechet_mean.fit(X=X, y=None, weights=self.weights)
    return self.frechet_mean.estimate_

Evaluation

Centroid on Sphere (S2)

Our initial test involves calculating the non-weighted Fréchet centroid for 8 points located on a hypersphere [ref 6]. The randomly generated points, `rand_points`, are visualized both on a 3D sphere and in Euclidean space using the `HyperspherePlot` and `EuclideanPlot` classes. The code for these two plotting classes is omitted here for clarity, but is available on GitHub [ref 1].


frechet_estimator = FrechetEstimator(space=Hypersphere(dim=2),
                                     optimizer=GradientDescent(),
                                     weights=None)
# 1- Generate the random points on the hypersphere
rand_points = frechet_estimator.rand(8)

# 2- Estimate the Frechet centroid then test if it belongs to the manifold
frechet_mean = frechet_estimator.estimate(rand_points)
assert frechet_estimator.space.belongs(frechet_mean)

# 3- Display the points and their Frechet centroid on the hypersphere
hypersphere_plot = HyperspherePlot(rand_points, frechet_mean)
hypersphere_plot.show()

# 4- Compute the Euclidean or arithmetic centroid
euclidean_mean = FrechetEstimator.euclidean_mean(rand_points)

# 5- Display the points in 3D Euclidean space
euclidean_plot = EuclideanPlot(rand_points, euclidean_mean)
euclidean_plot.show()

print(f'\nFrechet mean:   {frechet_mean}\nEuclidean mean: {euclidean_mean}')


Output:

Frechet mean:     [ 0.0174    -0.9958   -0.0897]

Euclidean mean: [ 0.0220    -0.0792     0.0226]
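The two vectors printed above can be checked against the defining constraint of the unit sphere, namely a Euclidean norm of 1. The short sketch below simply copies the reported values (rounded to 4 decimals) into NumPy arrays and compares their norms.

```python
import numpy as np

# Values copied from the output above
frechet_mean = np.array([0.0174, -0.9958, -0.0897])
euclidean_mean = np.array([0.0220, -0.0792, 0.0226])

frechet_norm = np.linalg.norm(frechet_mean)
euclidean_norm = np.linalg.norm(euclidean_mean)

print(f'{frechet_norm:.3f}')    # 1.000
print(f'{euclidean_norm:.3f}')  # 0.085
```

The Fréchet mean has unit norm, so it lies on the sphere, whereas the Euclidean mean sits close to the center of the ball and is not a point of the manifold.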


The 8 random points are visualized using a Matplotlib 3D scatter plot.



Fig. 2. Visualization of points from a 2-dimensional hypersphere in 3D Euclidean space


The next graph visualizes the 8 random points on a 3D Sphere.

Fig. 3. Visualization of random points on a 2-dimensional hypersphere



Centroid on Special Orthogonal Group (SO3)

The special orthogonal group is a Lie group. In differential geometry, a Lie group is a mathematical structure that combines the properties of a group and a smooth manifold, which allows both algebraic and geometric techniques to be applied. As a group, it has an operation (such as multiplication) that satisfies the group axioms (closure, associativity, identity, and invertibility). As a smooth manifold, it is locally Euclidean, and its group operation and inversion are differentiable.
The special orthogonal group in 3 dimensions, SO(3), is the group of all rotation matrices in 3 spatial dimensions.
An element of SO(3) can be described by 3 rotation angles, one about each of the axes x, y, and z.
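The sketch below builds an element of SO(3) from three elementary rotations about the x, y, and z axes and checks the two defining properties of a rotation matrix: orthogonality and a determinant of +1. The helper names (`rot_x`, `rot_y`, `rot_z`, `is_rotation`), the angles, and the composition order are illustrative choices, not taken from Geomstats.

```python
import numpy as np


def rot_x(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])


def rot_y(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])


def rot_z(a):
    c, s = np.cos(a), np.sin(a)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])


def is_rotation(R, atol=1e-9):
    """R is in SO(3) if R^T R = I and det(R) = +1."""
    orthogonal = np.allclose(R.T @ R, np.eye(3), atol=atol)
    return bool(orthogonal and np.isclose(np.linalg.det(R), 1.0, atol=atol))


# Compose the three elementary rotations (arbitrary angles, in radians)
R = rot_z(0.3) @ rot_y(-1.1) @ rot_x(2.0)

print(is_rotation(R))               # True
print(is_rotation(R @ rot_x(0.5)))  # True: closure under multiplication
print(is_rotation(R.T))             # True: the inverse is the transpose

# The arithmetic average of two rotations is, in general, not a rotation
average = 0.5 * (rot_z(0.0) + rot_z(np.pi / 2))
print(is_rotation(average))         # False
```

The last check shows why an element-wise Euclidean mean is a poor centroid on SO(3): averaging matrices breaks orthogonality, so the result leaves the manifold.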

Important note: The SO(3) manifold is described and evaluated, with Python source code, in a previous post [ref 7].


We utilize the Geomstats `SpecialOrthogonal` class to generate 4 random matrices, then compute their Euclidean and Fréchet centroids.

manifold = SpecialOrthogonal(n=3, point_type="matrix")
frechet_estimator = FrechetEstimator(manifold, GradientDescent(), weights=None)

# Generate the random points on SO3
manifold_points = frechet_estimator.rand(4)

# Visualize the SO3 matrices on 3D plot
so3_plot = SO3Plot(manifold_points)
so3_plot.show()

# Compute the Frechet and Euclidean centroids
frechet_mean = frechet_estimator.estimate(manifold_points)
euclidean_mean = FrechetEstimator.euclidean_mean(manifold_points)
print(f'\nFrechet mean:\n{frechet_mean}\nEuclidean mean:\n{euclidean_mean}')

Output:
Frechet mean:   
[[-0.74841 -0.40675 -0.52385]
 [-0.01445 -0.77966  0.62603]
 [-0.66306  0.47610  0.57763]]
Euclidean mean: 
[[-0.52282 -0.24432 -0.38322]
 [ 0.03642 -0.55841  0.37230]
 [-0.44746  0.26312  0.11775]]
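As a sanity check on the output above, the following sketch copies the two printed matrices into NumPy arrays and tests whether each one is a rotation matrix, that is, orthogonal with determinant +1. The helper `is_rotation` is a hypothetical name for this illustration, and the tolerance of 1e-3 accounts for the values being rounded to 5 decimals.

```python
import numpy as np

# Matrices copied from the output above
frechet_mean = np.array([
    [-0.74841, -0.40675, -0.52385],
    [-0.01445, -0.77966,  0.62603],
    [-0.66306,  0.47610,  0.57763],
])
euclidean_mean = np.array([
    [-0.52282, -0.24432, -0.38322],
    [ 0.03642, -0.55841,  0.37230],
    [-0.44746,  0.26312,  0.11775],
])


def is_rotation(R, atol=1e-3):
    """R is in SO(3) if R^T R = I and det(R) = +1 (up to rounding)."""
    orthogonal = np.allclose(R.T @ R, np.eye(3), atol=atol)
    return bool(orthogonal and np.isclose(np.linalg.det(R), 1.0, atol=atol))


print(is_rotation(frechet_mean))    # True
print(is_rotation(euclidean_mean))  # False
```

The Fréchet mean is a valid element of SO(3), while the Euclidean mean of the four matrices is not a rotation.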

In this scenario, we display the 4 non-symmetric 3x3 matrices that represent the SO(3) rotations on a 3D plot.

Fig 4. Visualization of 4 random SO3 matrices on a 3D plot