GGML Method Overview#

[1]:
import ggml_ot

This tutorial introduces the Global Ground Metric Learning (GGML) method on a small toy dataset. We also provide intuition on how different concepts from the method relate to patient-level single-cell data. If you want to use the method on your own single-cell data, have a look at the basic tutorial on AnnData.

Motivating Example#

Let’s consider a small toy dataset containing three classes (colors) with 10 distributions each (symbols).

[2]:
dataset = ggml_ot.data.from_synth()
../_images/tutorials_0_ggml_method_introduction_4_0.png

Intuitively, you can think of the classes as disease states (e.g. control, acute, chronic) and of the distributions as the cells measured from different patients (•,⨯,▪ …) with the corresponding disease state. The x- and y-axes correspond to genes or gene loadings. The modes (or “clusters”) of the distributions correspond to different cell types.

In this example, the disease states differ only in the cell type at the center, and only along the x-axis.

Optimal Transport#

The Optimal Transport distance between two empirical distributions \(\mathcal{X}\) and \(\mathcal{Y}\) is defined as:

\[\quad OT(\mathcal{X},\mathcal{Y}) = \min_{\pi}\sum_{x,y} d(x,y) \pi_{x,y}\]

where \(d(x,y)\) is the ground metric, a function that determines the cost of moving mass between \(x \sim \mathcal{X}\) and \(y \sim \mathcal{Y}\). The transport plan \(\pi\) is a coupling: a joint distribution whose marginals are \(\mathcal{X}\) and \(\mathcal{Y}\), with \(\pi_{x,y}\) giving the mass moved from \(x\) to \(y\). The OT distance is the total cost of the cheapest such plan, so a small value means that \(\mathcal{X}\) and \(\mathcal{Y}\) are similar.
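To make the definition concrete, here is a minimal sketch that is independent of ggml_ot. It assumes two equally sized point sets with uniform weights; in that special case an optimal plan can always be chosen as a one-to-one matching, so a brute-force search over permutations is exact. The helper name `ot_uniform` is made up for this illustration, and the approach is only feasible for a handful of points.

```python
import itertools
import math


def ot_uniform(xs, ys, cost):
    """Exact OT cost between two equally sized point sets with uniform weights.

    Assumption of this sketch: uniform weights and equal sizes, so an optimal
    plan exists that is a one-to-one matching (each point carries mass 1/n).
    """
    n = len(xs)
    assert n == len(ys), "this sketch assumes equally sized point sets"
    best = math.inf
    for perm in itertools.permutations(range(n)):
        # Total cost of matching xs[i] to ys[perm[i]], each with mass 1/n.
        total = sum(cost(xs[i], ys[j]) for i, j in enumerate(perm)) / n
        best = min(best, total)
    return best


def euclidean(x, y):
    """Euclidean ground metric d(x, y) = ||x - y||_2."""
    return math.dist(x, y)


X = [(0.0, 0.0), (1.0, 0.0)]
Y = [(0.0, 1.0), (1.0, 1.0)]
print(ot_uniform(X, Y, euclidean))  # 1.0: each point moves straight up by 1
```

Practical solvers use linear programming or entropic regularization instead of enumeration; the sketch only shows what is being minimized.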

Euclidean Ground Metric#

A commonly used ground metric is the Euclidean distance \(d(x,y) = ||x-y||_2\). Using its square as the transport cost yields the popular (squared) Wasserstein-2 distance:

\[\quad W_2(\mathcal{X},\mathcal{Y})^2 = \min_{\pi}\sum_{x,y} ||x-y||_2^2 \pi_{x,y}\]

Let’s compute the Optimal Transport distance on the example dataset using the Euclidean distance as a ground metric.

[3]:
D_euclidean = dataset.compute_OT(ground_metric="euclidean", plot=True)
/tmp/ipykernel_2526874/1858770048.py:1: DeprecationWarning: `plot=` is deprecated, use `plot_type=` instead.
  D_euclidean = dataset.compute_OT(ground_metric="euclidean", plot=True)
../_images/tutorials_0_ggml_method_introduction_7_1.png

As you can see, the class differences along the x-axis are not captured: the Euclidean metric cannot distinguish class-related differences on the x-axis from unrelated differences on the y-axis.

Learn Ground Metric#

Now, we perform Supervised Optimal Transport with GGML. It trains a Ground Metric such that the Optimal Transport distance between distributions captures their known class relationships. We plot the learned metric after each epoch.

Ground Metric:#

\[\quad W_2(\mathcal{X},\mathcal{Y};\Theta)^2 = \inf_{\pi}\int_{\mathcal{X} \times \mathcal{Y}} d(x,y;\Theta)^2 d\pi(x,y)\]

Here, \(\Theta\) denotes the learnable parameters of the ground metric \(d\).

Triplet Learning:#

\[\mathcal{L}_\alpha(X,\mathcal{T};\Theta) = \sum\limits_{t\in\mathcal{T} } \mathcal{L}_\alpha(X,t;\Theta)\]
\[\begin{split}\text{where} \: \mathcal{L}_\alpha(X,(i,j,k);\Theta) = \\ \max \left( W(X_i,X_j;\Theta) - W(X_j,X_k;\Theta) + \alpha, 0 \right)\end{split}\]

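where the margin parameter \(\alpha\) (alpha) sets how much smaller the OT distance between distributions of the same class \(X_i,X_j\) must be than the distance between distributions of different classes \(X_j,X_k\) before a triplet stops contributing to the loss.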
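The triplet loss can be sketched in a few lines once the pairwise OT distances are available. The matrix `W` below holds made-up values for illustration, not output from ggml_ot, and `triplet_loss` is a hypothetical helper name.

```python
def triplet_loss(W, triplets, alpha):
    """Sum of hinge losses max(W[i][j] - W[j][k] + alpha, 0) over all triplets."""
    return sum(max(W[i][j] - W[j][k] + alpha, 0.0) for i, j, k in triplets)


# Hypothetical OT distances between three distributions:
# 0 and 1 share a class, 2 belongs to a different class.
W = [
    [0.0, 0.5, 2.0],
    [0.5, 0.0, 0.8],
    [2.0, 0.8, 0.0],
]

# Each triplet (i, j, k): i and j have the same class, j and k differ.
triplets = [(0, 1, 2), (1, 0, 2)]

# (0, 1, 2): 0.5 - 0.8 + 1.0 is about 0.7, so the margin is violated.
# (1, 0, 2): 0.5 - 2.0 + 1.0 is negative, so the hinge clips it to 0.
print(triplet_loss(W, triplets, alpha=1.0))
```

Only triplets that violate the margin produce a nonzero term, so increasing alpha demands a larger gap between same-class and different-class distances.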

Regularization:#

In practice, we optimize:

\[\mathcal{L}_\alpha(X,\mathcal{T};\Theta) + \lambda R(\Theta)\]

where \(R(\Theta)\) is a regularizer on \(\Theta\), weighted by the parameter \(\lambda\) (reg).

Generalized Mahalanobis Distance:#

\[\begin{split}\begin{align} d(\mathbf{x}_i,\mathbf{x}_j;\mathbf{M} ) &= \sqrt{(\mathbf{x}_i-\mathbf{x}_j)^T \mathbf{M} (\mathbf{x}_i - \mathbf{x}_j)} \nonumber \\ &= \lVert \mathbf{A}\mathbf{x}_i-\mathbf{A}\mathbf{x}_j \rVert =: d(\mathbf{x}_i,\mathbf{x}_j;\mathbf{A}) \nonumber \\ s.t. \;\mathbf{A}^T\mathbf{A} &= (\mathbf{Q} \Lambda^{\frac{1}{2}}) (\mathbf{Q} \Lambda^{\frac{1}{2}})^T = \mathbf{Q} \Lambda \mathbf{Q}^T = \mathbf{M} \nonumber \end{align}\end{split}\]
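The equivalence of the two forms can be checked numerically. The sketch below is independent of ggml_ot and uses plain Python lists; the helper names and the matrix `A` are chosen for illustration only. `A` keeps the x-axis and drops the y-axis, mirroring the toy dataset in which only the x-axis carries class information.

```python
import math


def mahalanobis_via_A(x, y, A):
    """||A x - A y||_2 for a matrix A given as a list of rows."""
    diff = [xi - yi for xi, yi in zip(x, y)]
    proj = [sum(a * d for a, d in zip(row, diff)) for row in A]
    return math.sqrt(sum(p * p for p in proj))


def mahalanobis_via_M(x, y, M):
    """sqrt((x - y)^T M (x - y)) for a symmetric positive semi-definite M."""
    diff = [xi - yi for xi, yi in zip(x, y)]
    Md = [sum(m * d for m, d in zip(row, diff)) for row in M]
    return math.sqrt(sum(d * md for d, md in zip(diff, Md)))


def gram(A):
    """Return M = A^T A."""
    cols = list(zip(*A))
    return [[sum(a * b for a, b in zip(ci, cj)) for cj in cols] for ci in cols]


# Illustrative projection: one row that keeps x and ignores y.
A = [[1.0, 0.0]]
M = gram(A)  # [[1.0, 0.0], [0.0, 0.0]]

x, y = (0.0, 0.0), (3.0, 4.0)
print(mahalanobis_via_A(x, y, A))  # 3.0: only the x-difference counts
print(mahalanobis_via_M(x, y, M))  # 3.0: same value computed through M
```

The plain Euclidean distance between the same two points is 5.0, so this choice of `A` removes the contribution of the y-axis entirely.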

Normalize first#

Hyperparameters like alpha and reg depend on the absolute scale of the data, because the triplet margin is measured in the same units as the ground cost. Calling dataset.normalize() centers each feature and rescales it to unit variance, so hyperparameter values transfer more readily across datasets.
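As an illustration of what centering and rescaling to unit variance does, here is a minimal z-score sketch that does not use ggml_ot. It assumes the statistics are computed per feature over all points pooled together and uses the population standard deviation; how dataset.normalize() computes them internally is not shown here, so treat both choices as assumptions.

```python
import statistics


def standardize(points):
    """Center each feature and rescale it to unit variance (z-score).

    Assumption of this sketch: population standard deviation, and no
    constant features (a constant feature would cause a division by zero).
    """
    columns = list(zip(*points))
    means = [statistics.fmean(c) for c in columns]
    stds = [statistics.pstdev(c) for c in columns]
    return [
        tuple((v - m) / s for v, m, s in zip(p, means, stds))
        for p in points
    ]


# Two features on very different scales.
points = [(1.0, 100.0), (2.0, 300.0), (3.0, 500.0)]
scaled = standardize(points)
for p in scaled:
    print(p)
```

After standardization both features span the same range, so a margin such as alpha=1 means the same thing regardless of the original units.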

[4]:
dataset.normalize()

map_A = dataset.train(alpha=1, reg_type=1, reg=0.001, n_comps=2, max_iter=6, plot_iter=1, verbose=True)
100%|██████████| 1/1 [00:17<00:00, 17.60s/it, obj_loss=0.9325, reg_loss=0.0028, obj_grad=9.985e-02]
Compute all OT distances after 1 iterations
../_images/tutorials_0_ggml_method_introduction_15_2.png
100%|██████████| 1/1 [00:21<00:00, 21.47s/it, obj_loss=0.9140, reg_loss=0.0028, obj_grad=1.114e-01]
Compute all OT distances after 2 iterations
../_images/tutorials_0_ggml_method_introduction_15_5.png
100%|██████████| 1/1 [00:22<00:00, 22.37s/it, obj_loss=0.8929, reg_loss=0.0028, obj_grad=1.236e-01]
Compute all OT distances after 3 iterations
../_images/tutorials_0_ggml_method_introduction_15_8.png
100%|██████████| 1/1 [00:22<00:00, 22.09s/it, obj_loss=0.8691, reg_loss=0.0028, obj_grad=1.369e-01]
Compute all OT distances after 4 iterations
../_images/tutorials_0_ggml_method_introduction_15_11.png
100%|██████████| 1/1 [00:28<00:00, 28.37s/it, obj_loss=0.8427, reg_loss=0.0028, obj_grad=1.518e-01]
Compute all OT distances after 5 iterations
../_images/tutorials_0_ggml_method_introduction_15_14.png
100%|██████████| 1/1 [00:20<00:00, 20.73s/it, obj_loss=0.8142, reg_loss=0.0028, obj_grad=1.688e-01]
Compute all OT distances after 6 iterations
../_images/tutorials_0_ggml_method_introduction_15_17.png

Cross-validation#

GGML also generalizes when it is trained on only a subset of the data. We can verify this by cross-validating across multiple train/test splits with train_test(), a single call that wraps train() and evaluates the learned metric on the held-out split. It reports classification (k-NN accuracy) and clustering metrics (ARI, NMI, VI) per split.

See the basic tutorial on AnnData for the same workflow on a real single-cell dataset.
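To give a feel for how a k-NN accuracy can be read off pairwise OT distances, here is a small sketch using leave-one-out 1-nearest-neighbor classification on a precomputed distance matrix. The distances and labels are made up, `loo_1nn_accuracy` is a hypothetical helper, and leave-one-out is used only for brevity; it is not a description of how train_test() splits or scores the data.

```python
def loo_1nn_accuracy(D, labels):
    """Leave-one-out 1-nearest-neighbor accuracy on a precomputed distance matrix."""
    n = len(labels)
    correct = 0
    for i in range(n):
        # Nearest other distribution according to the precomputed distances.
        nearest = min((j for j in range(n) if j != i), key=lambda j: D[i][j])
        correct += labels[nearest] == labels[i]
    return correct / n


# Hypothetical OT distances between four patients (two per disease state).
D = [
    [0.0, 0.2, 1.5, 1.7],
    [0.2, 0.0, 1.4, 1.6],
    [1.5, 1.4, 0.0, 0.3],
    [1.7, 1.6, 0.3, 0.0],
]
labels = ["control", "control", "acute", "acute"]
print(loo_1nn_accuracy(D, labels))  # 1.0: every nearest neighbor shares the class
```

A ground metric that separates the classes well yields small within-class and large between-class distances, which is what drives this score toward 1.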

[5]:
map_As, scores = dataset.train_test(alpha=1, reg_type=1, reg=0.001, n_comps=2, n_splits=3)
../_images/tutorials_0_ggml_method_introduction_17_0.png
../_images/tutorials_0_ggml_method_introduction_17_1.png
../_images/tutorials_0_ggml_method_introduction_17_2.png
Metric    knn_acc       mi            ari           vi            epoch_time(s)
Mean±SD   Mean   SD     Mean   SD     Mean   SD     Mean   SD     Mean   SD
0         1.00   0.00   1.00   0.00   1.00   0.00   0.00   0.00   6.35   0.72