The Fréchet centroid (or intrinsic centroid) generalizes the concept of a mean to data points that lie on a manifold or in a non-Euclidean space. Whereas the Euclidean mean minimizes the sum of squared Euclidean distances to the data points, the Fréchet centroid minimizes the sum of squared distances measured with the intrinsic (geodesic) distance of the manifold.
What you will learn: How to compute the Fréchet centroid (or mean) of multiple points on a data manifold.
Notes:
- Environments: Python 3.12.5, Geomstats 2.8.0, Matplotlib 3.9, NumPy 2.2.0
- Source code is available on GitHub [ref 1]
- To enhance the readability of the algorithm implementations, we have omitted non-essential code elements like error checking, comments, exceptions, validation of class and method arguments, scoping qualifiers, and import statements.
Introduction
For readers unfamiliar with manifolds and basic differential geometry, I highly recommend starting with my two introductory posts:
- Foundation of Geometric Learning: This post introduces differential geometry in the context of machine learning, outlining its foundational components.
- Differentiable Manifolds for Geometric Learning: This post explores manifold concepts such as tangent vectors and geodesics, with Python implementations for the hypersphere using the Geomstats library.
Alternatively, I strongly encourage readers to consult tutorials [ref 2], video series [ref 3], or publications [ref 4, 5] for more in-depth information.
The following image illustrates the difference between the Euclidean and Fréchet means on a manifold.
Fig. 1. Illustration of Euclidean and Fréchet means on an arbitrary manifold
Let's consider n tensors x[i], each with d values. The Euclidean mean is computed as
\[\mathbf{\mu} = \frac{1}{n}\sum_{i=1}^{n} \mathbf{x}_{i} \] that is, for each coordinate (value) of index j: \[\mu^{(j)}=\frac{1}{n}\sum_{i=1}^{n}x_{i}^{(j)} \]
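As a small illustration of why this formula is not enough on a manifold, the sketch below applies the coordinate-wise mean to two points of the unit sphere embedded in 3D. The points are arbitrary choices for the example; the point is only that the arithmetic mean leaves the sphere.

```python
import numpy as np

# Two points on the unit sphere S2, expressed in ambient 3D coordinates
points = np.array([[1.0, 0.0, 0.0],
                   [0.0, 1.0, 0.0]])

# mu^(j) = (1/n) * sum_i x_i^(j), computed independently for each coordinate j
mu = points.mean(axis=0)

# The arithmetic mean is [0.5, 0.5, 0.0], whose norm is about 0.707 (< 1),
# so it does not lie on the sphere
mu_norm = np.linalg.norm(mu)
print(mu, mu_norm)
```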
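For a manifold M equipped with a distance function d(p, x) (typically the geodesic distance induced by its Riemannian metric), the Fréchet centroid is defined as \[ m=\underset{p \in M}{\arg\min}\sum_{i=1}^{n}d^{2}\left(p, \mathbf{x}_{i} \right) \] and the weighted Fréchet mean, with weights w[i], as \[m=\underset{p \in M}{\arg\min}\sum_{i=1}^{n}w_{i}\,d^{2}\left(p, \mathbf{x}_{i} \right) \] When M is the Euclidean space and d the Euclidean distance, the minimizer is exactly the arithmetic mean μ defined above.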
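To make the minimization concrete, here is a minimal NumPy-only sketch of the classic Riemannian gradient descent (often called the Karcher mean iteration) on the unit sphere. It is not the Geomstats implementation: the helper names `sphere_log`, `sphere_exp` and `frechet_mean_sphere` are hypothetical, and the sketch assumes the points are concentrated enough for the minimizer to be unique, which is not guaranteed for widely spread data. The key identity is that the Riemannian gradient of the objective is -2 times the weighted sum of log maps, so each step moves p along that sum.

```python
import numpy as np

def sphere_log(p, x):
    """Log map on the unit sphere: tangent vector at p toward x, of length d(p, x)."""
    cos_theta = np.clip(np.dot(p, x), -1.0, 1.0)
    theta = np.arccos(cos_theta)        # geodesic (great-circle) distance d(p, x)
    u = x - cos_theta * p               # component of x orthogonal to p
    norm_u = np.linalg.norm(u)
    if norm_u < 1e-12:                  # x == p (or antipodal, where log is undefined)
        return np.zeros_like(p)
    return theta * u / norm_u

def sphere_exp(p, v):
    """Exp map on the unit sphere: point reached by following the geodesic from p along v."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return p
    return np.cos(norm_v) * p + np.sin(norm_v) * v / norm_v

def frechet_mean_sphere(points, weights=None, step=1.0, max_iter=100, tol=1e-10):
    """Riemannian gradient descent for argmin_p sum_i w_i d^2(p, x_i) on the unit sphere."""
    points = np.asarray(points, dtype=float)
    w = np.ones(len(points)) if weights is None else np.asarray(weights, dtype=float)
    w = w / w.sum()                     # normalized weights
    p = points[0].copy()                # start from the first data point
    for _ in range(max_iter):
        # v = sum_i w_i log_p(x_i) equals -1/2 times the Riemannian gradient of the objective
        v = sum(wi * sphere_log(p, xi) for wi, xi in zip(w, points))
        p = sphere_exp(p, step * v)
        if np.linalg.norm(v) < tol:
            break
    return p

x = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
m = frechet_mean_sphere(x)                        # midpoint of the great-circle arc
m_w = frechet_mean_sphere(x, weights=[3.0, 1.0])  # pulled toward the first point
```

Unlike the arithmetic mean of the same two points, `m` stays on the sphere. For the weighted case, minimizing 3θ² + (π/2 - θ)² over the arc gives θ = π/8, so `m_w` lies at an angle of π/8 from the first point.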
Implementation
Let's explore the computation of means from the perspective of extrinsic geometry, focusing on surfaces embedded within a three-dimensional Euclidean space.
To begin, we'll define a class `FrechetEstimator` to encapsulate the essential components of the estimator:
- space: A smooth manifold (e.g., Sphere, Hyperbolic) or a Lie group (e.g., SO3, SE3).
- optimizer: An optimization algorithm (e.g., gradient descent) that iterates on the manifold to minimize the (weighted) sum of squared distances.
- weights: Optional weights that can be applied during the mean estimation process.
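Note: The constructor instantiates the Geomstats class FrechetMean associated with the given manifold.
from geomstats.geometry.manifold import Manifold
from geomstats.learning.frechet_mean import FrechetMean, BaseGradientDescent
import numpy as np

class FrechetEstimator(object):
    def __init__(self, space: Manifold, optimizer: BaseGradientDescent, weights: np.ndarray = None) -> None:
        # Geomstats estimator of the Fréchet mean on the given manifold, driven by 'optimizer'
        self.frechet_mean = FrechetMean(space)
        self.frechet_mean.optimizer = optimizer
        self.weights = weights
        self.space = space

    def estimate(self, X: np.ndarray) -> np.ndarray: ...  # implemented below
    def rand(self, num_samples: int) -> np.ndarray: ...  # implemented below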
The `rand` method generates random points on the manifold for evaluation purposes, using the `random_uniform` method of the Geomstats manifold classes.
The implementation ensures that:
- The manifold type is supported.
- All randomly generated points (as tensors) are verified to belong to the manifold.
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.geometry.special_orthogonal import _SpecialOrthogonalMatrices

class GeometricException(Exception):  # minimal stand-in for the author's custom exception
    pass

def rand(self, num_samples: int) -> np.ndarray:
    # Test if this manifold is supported (hypersphere or SO(n) in matrix form)
    if not isinstance(self.space, (Hypersphere, _SpecialOrthogonalMatrices)):
        raise GeometricException('Cannot generate random values on unsupported manifold')
    X = self.space.random_uniform(num_samples)
    # Validate that every randomly generated point belongs to the manifold 'self.space'
    are_points_valid = all(self.space.belongs(x) for x in X)
    if not are_points_valid:
        raise GeometricException('Some generated points do not belong to the manifold')
    return X
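Geomstats handles the sampling internally. For intuition only, a common way to draw uniformly distributed points on a hypersphere is to normalize isotropic Gaussian vectors; the NumPy sketch below pairs that idea with a simple membership test in the spirit of `belongs`. The helpers `random_uniform_sphere` and `belongs_sphere` are hypothetical and are not claimed to mirror the Geomstats source.

```python
import numpy as np

def random_uniform_sphere(num_samples: int, dim: int = 2, seed=None) -> np.ndarray:
    """Uniform samples on the unit sphere S^dim embedded in R^(dim+1)."""
    rng = np.random.default_rng(seed)
    # Isotropic Gaussian vectors have uniformly distributed directions
    x = rng.standard_normal((num_samples, dim + 1))
    # Project each vector radially onto the unit sphere
    return x / np.linalg.norm(x, axis=1, keepdims=True)

def belongs_sphere(points: np.ndarray, atol: float = 1e-8) -> bool:
    """Membership test: every point must have unit Euclidean norm."""
    return bool(np.allclose(np.linalg.norm(points, axis=-1), 1.0, atol=atol))

samples = random_uniform_sphere(8, seed=42)
```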
Our evaluation uses the Euclidean mean as a baseline, with a straightforward implementation using NumPy.
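@staticmethod  # static method of FrechetEstimator
def euclidean_mean(manifold_points: np.ndarray) -> np.ndarray:
    # Coordinate-wise arithmetic mean in the ambient Euclidean space
    return np.mean(manifold_points, axis=0)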
The `estimate` method computes the Fréchet centroid of a set of data points stacked into a NumPy array. It relies on the Geomstats method `FrechetMean.fit`, which runs the configured optimizer and stores the result in the `estimate_` attribute.
def estimate(self, X: np.ndarray) -> np.ndarray:
    # Fit the Geomstats estimator; weights=None yields the unweighted Fréchet mean
    self.frechet_mean.fit(X=X, y=None, weights=self.weights)
    return self.frechet_mean.estimate_
Evaluation
Centroid on Sphere (S2)
Our initial test involves calculating the non-weighted Fréchet centroid of 8 points located on the 2-dimensional hypersphere [ref 6]. The randomly generated points, `rand_points`, are visualized both on a 3D sphere and in Euclidean space using the `HyperspherePlot` and `EuclideanPlot` classes. The plotting code is omitted here for clarity but is available on GitHub [ref 1].
from geomstats.geometry.hypersphere import Hypersphere
from geomstats.learning.frechet_mean import GradientDescent

# HyperspherePlot and EuclideanPlot are the author's plotting classes (GitHub [ref 1])
frechet_estimator = FrechetEstimator(space=Hypersphere(dim=2),
                                     optimizer=GradientDescent(),
                                     weights=None)
# 1- Generate random points on the hypersphere
rand_points = frechet_estimator.rand(8)
# 2- Estimate the Frechet centroid, then test that it belongs to the manifold
frechet_mean = frechet_estimator.estimate(rand_points)
assert frechet_estimator.space.belongs(frechet_mean)
# 3- Display the points on the hypersphere
hypersphere_plot = HyperspherePlot(rand_points, frechet_mean)
hypersphere_plot.show()
# 4- Compute the Euclidean (arithmetic) centroid
euclidean_mean = FrechetEstimator.euclidean_mean(rand_points)
# 5- Display the points in 3D Euclidean space
euclidean_plot = EuclideanPlot(rand_points, euclidean_mean)
euclidean_plot.show()
print(f'\nFrechet mean: {frechet_mean}\nEuclidean mean: {euclidean_mean}')
Output:
Frechet mean: [ 0.0174 -0.9958 -0.0897]
Euclidean mean: [ 0.0220 -0.0792 0.0226]
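A quick NumPy check on the printed values (copied from the output above, so rounded to 4 decimals) shows the qualitative difference: the Fréchet mean lies on the unit sphere up to rounding, while the Euclidean mean falls deep inside it, close to the center.

```python
import numpy as np

# Values copied from the output above
frechet_mean = np.array([0.0174, -0.9958, -0.0897])
euclidean_mean = np.array([0.0220, -0.0792, 0.0226])

# Distance from the origin: 1.0 means the point lies on the unit sphere S2
frechet_norm = np.linalg.norm(frechet_mean)      # close to 1 (rounding only)
euclidean_norm = np.linalg.norm(euclidean_mean)  # well below 1
print(f'Frechet norm: {frechet_norm:.4f}  Euclidean norm: {euclidean_norm:.4f}')
```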
The 8 random points are visualized using a Matplotlib 3D scatter plot.
Fig 2. Visualization of points from a 2-dimensional hypersphere in 3D Euclidean space
The next graph visualizes the 8 random points on a 3D sphere.
Fig 3. Visualization of random points on a 2-dimensional hypersphere
Centroid on Special Orthogonal Group (SO3)
The special orthogonal group is a Lie group. In differential geometry, a Lie group is a mathematical structure that combines the properties of a group and a smooth manifold, so that both algebraic and geometric techniques apply to it. As a group, it has an operation (such as matrix multiplication) that satisfies the group axioms (closure, associativity, identity, and invertibility); as a manifold, it is smooth, and the group operation and inversion are smooth maps.
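The special orthogonal group in 3 dimensions, SO(3), is the group of all rotations of 3-dimensional space, represented as 3x3 orthogonal matrices R satisfying R^T R = I and det(R) = 1.
It is a 3-dimensional manifold: any rotation can be parameterized by three angles, for instance the angles of elementary rotations about the x, y, and z axes.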
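The sketch below illustrates these properties with NumPy: it builds rotations from three elementary axis rotations, checks membership in SO(3) (orthogonality and unit determinant), and verifies closure and invertibility. The helper names are hypothetical, and the z-y-x composition order is only one of several Euler-angle conventions.

```python
import numpy as np

def rot_x(a: float) -> np.ndarray:
    """Elementary rotation by angle a (radians) about the x axis."""
    c, s = np.cos(a), np.sin(a)
    return np.array([[1.0, 0.0, 0.0], [0.0, c, -s], [0.0, s, c]])

def rot_y(b: float) -> np.ndarray:
    """Elementary rotation by angle b (radians) about the y axis."""
    c, s = np.cos(b), np.sin(b)
    return np.array([[c, 0.0, s], [0.0, 1.0, 0.0], [-s, 0.0, c]])

def rot_z(g: float) -> np.ndarray:
    """Elementary rotation by angle g (radians) about the z axis."""
    c, s = np.cos(g), np.sin(g)
    return np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])

def in_so3(R: np.ndarray, atol: float = 1e-9) -> bool:
    """True if R is orthogonal (R^T R = I) with determinant +1."""
    orthogonal = np.allclose(R.T @ R, np.eye(3), atol=atol)
    return bool(orthogonal and np.isclose(np.linalg.det(R), 1.0, atol=atol))

# Two rotations, each defined by three angles (z-y-x composition)
R1 = rot_z(0.3) @ rot_y(-1.1) @ rot_x(2.0)
R2 = rot_z(-0.7) @ rot_y(0.4) @ rot_x(0.1)
R_prod = R1 @ R2   # closure: the product of two rotations is a rotation
R_inv = R1.T       # invertibility: the inverse of a rotation is its transpose
```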
Important note: The SO(3) manifold is described and evaluated with Python source code in a previous post [ref 7].
We utilize the Geomstats `SpecialOrthogonal` class to generate 4 random matrices, then compute their Euclidean and Fréchet centroids.
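from geomstats.geometry.special_orthogonal import SpecialOrthogonal
from geomstats.learning.frechet_mean import GradientDescent

manifold = SpecialOrthogonal(n=3, point_type="matrix")
frechet_estimator = FrechetEstimator(manifold, GradientDescent(), weights=None)
# Generate the random points on SO3
manifold_points = frechet_estimator.rand(4)
# Visualize the SO3 matrices on a 3D plot (SO3Plot is the author's class, GitHub [ref 1])
so3_plot = SO3Plot(manifold_points)
so3_plot.show()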
# Compute the Frechet and Euclidean centroids
frechet_mean = frechet_estimator.estimate(manifold_points)
euclidean_mean = FrechetEstimator.euclidean_mean(manifold_points)
print(f'\nFrechet mean:\n{frechet_mean}\nEuclidean mean:\n{euclidean_mean}')
Output:
Frechet mean:
[[-0.74841 -0.40675 -0.52385]
[-0.01445 -0.77966 0.62603]
[-0.66306 0.47610 0.57763]]
Euclidean mean:
[[-0.52282 -0.24432 -0.38322]
[ 0.03642 -0.55841 0.37230]
[-0.44746 0.26312 0.11775]]
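The printed matrices can be checked the same way as on the sphere. The NumPy sketch below (values copied from the output above, rounded to 5 decimals) measures how far each matrix is from being a rotation: the Fréchet mean is orthogonal with determinant close to +1, whereas the Euclidean mean is not orthogonal, so the coordinate-wise average of rotation matrices is not itself a rotation. The helper `so3_residuals` is hypothetical.

```python
import numpy as np

# Values copied from the output above
frechet_mean = np.array([[-0.74841, -0.40675, -0.52385],
                         [-0.01445, -0.77966,  0.62603],
                         [-0.66306,  0.47610,  0.57763]])
euclidean_mean = np.array([[-0.52282, -0.24432, -0.38322],
                           [ 0.03642, -0.55841,  0.37230],
                           [-0.44746,  0.26312,  0.11775]])

def so3_residuals(M: np.ndarray):
    """Return (||M^T M - I||_F, det(M)); a rotation gives (0, 1)."""
    return np.linalg.norm(M.T @ M - np.eye(3)), np.linalg.det(M)

frechet_ortho_err, frechet_det = so3_residuals(frechet_mean)
euclidean_ortho_err, euclidean_det = so3_residuals(euclidean_mean)
print(frechet_ortho_err, frechet_det, euclidean_ortho_err, euclidean_det)
```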
In this scenario, the 4 random 3x3 rotation matrices sampled from SO(3), which are generally not symmetric, are displayed on a 3D plot.