GPU Training with Sinkhorn

How to speed up GGML training by using the Sinkhorn OT solver on a GPU.

[1]:
import ggml_ot
import scanpy as sc
import pandas as pd

By default, GGML uses the exact Earth Mover’s Distance (EMD), which solves one network-simplex linear program per pair of distributions. This solver runs only on the CPU. The Sinkhorn solver instead expresses OT as a sequence of batched matrix operations that vectorize across all pairs in a triplet batch, which maps well onto a GPU.
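The batched structure can be sketched in plain NumPy. The function below is an illustration written for this tutorial, not the ggml_ot implementation; its name, the array shapes and the fixed iteration count are assumptions. It runs the standard Sinkhorn iterations for a whole batch of B distribution pairs at once, so every iteration is a single batched matrix-vector product instead of B separate solves.

```python
import numpy as np


def batched_sinkhorn(costs, a, b, entropic_reg, n_iter=1000):
    """Run Sinkhorn for B OT problems at once (illustrative sketch).

    costs: (B, n, m) ground-cost matrices, one per distribution pair
    a:     (B, n) source weights, each row sums to 1
    b:     (B, m) target weights, each row sums to 1
    Returns the transport plans (B, n, m) and transport costs (B,).
    """
    K = np.exp(-costs / entropic_reg)  # Gibbs kernels, shape (B, n, m)
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        # each update is one batched matrix-vector product covering all B pairs
        u = a / np.einsum("bnm,bm->bn", K, v)
        v = b / np.einsum("bnm,bn->bm", K, u)
    plans = u[:, :, None] * K * v[:, None, :]  # diag(u) @ K @ diag(v) per pair
    transport_costs = np.sum(plans * costs, axis=(1, 2))
    return plans, transport_costs


# a toy "triplet batch": 8 pairs of distributions with 5 and 6 support points
rng = np.random.default_rng(0)
B, n, m = 8, 5, 6
costs = rng.random((B, n, m))
a = np.full((B, n), 1.0 / n)
b = np.full((B, m), 1.0 / m)

plans, transport_costs = batched_sinkhorn(costs, a, b, entropic_reg=0.2)
print(transport_costs.shape)  # (8,)
```

Because all pairs share the same sequence of array operations, the loop body is identical for one pair or thousands, which is what makes this formulation easy to move onto a GPU.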

We reuse the breast cancer dataset from Kumar et al., 2023 (714,331 cells from 126 donors), which was also used in the fast approximations tutorial.

[ ]:
dataset_id = "b8b5be07-061b-4390-af0a-f9ced877a068.h5ad"  # CELLxGENE dataset file
adata = ggml_ot.data.load_cellxgene(dataset_id)
# keep only the 5,000 most highly variable genes
sc.pp.highly_variable_genes(adata, n_top_genes=5000, subset=True)

Set up the dataset with normalize()

Sinkhorn adds an entropic regularization term whose strength is set by the entropic_reg parameter. How strongly a given value smooths the transport plan depends on the scale of the ground costs, and therefore on the scale of the data. To make entropic_reg comparable and easier to tune across datasets, we first normalize the dataset with dataset.normalize().
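The scale dependence follows from the Gibbs kernel exp(-C / entropic_reg) that Sinkhorn works with: multiplying every ground cost C by a factor s has exactly the same effect as dividing entropic_reg by s. The NumPy sketch below is independent of ggml_ot and its helper names are made up for illustration. It expresses the same points in 10x larger units and shows that the plan only stays unchanged if entropic_reg is rescaled as well. It does not show what dataset.normalize() computes internally; the point is only that, once datasets share a common scale, one entropic_reg value has a comparable effect.

```python
import numpy as np


def sinkhorn_plan(cost, a, b, entropic_reg, n_iter=2000):
    """Entropic OT plan for one pair of distributions via Sinkhorn (sketch)."""
    K = np.exp(-cost / entropic_reg)  # the only place entropic_reg enters
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(n_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
    return u[:, None] * K * v[None, :]


def plan_entropy(plan):
    """Shannon entropy of a plan: larger values mean a more blurred coupling."""
    p = plan[plan > 0]
    return float(-np.sum(p * np.log(p)))


rng = np.random.default_rng(0)
x, y = rng.random((5, 2)), rng.random((5, 2))
cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)  # Euclidean distances
a = np.full(5, 0.2)
b = np.full(5, 0.2)

eps = 1.0
plan_base = sinkhorn_plan(cost, a, b, eps)
# the same points in 10x larger units: every distance is multiplied by 10
plan_same_eps = sinkhorn_plan(10 * cost, a, b, eps)       # entropic_reg unchanged
plan_rescaled = sinkhorn_plan(10 * cost, a, b, 10 * eps)  # entropic_reg scaled too

print(plan_entropy(plan_base), plan_entropy(plan_same_eps))
```

With entropic_reg left unchanged, the larger-scale data behaves as if it were regularized ten times more weakly, so its plan is sharper (lower entropy); rescaling entropic_reg together with the data recovers the original plan.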

[ ]:
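# one distribution of cells per donor (donor_id), labeled by reported_diseases
dataset = ggml_ot.from_anndata(adata, patient_col="donor_id", label_col="reported_diseases")
# bring the data to a common scale so that entropic_reg values are comparable
dataset.normalize()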
[ ]:
train_params = {"n_splits": 2}  # settings shared by all runs below
scores = {}  # score table of each setup, keyed by setup name

Sinkhorn on GPU

Passing entropic_reg > 0 to train() or train_test() switches the inner OT solver to Sinkhorn. With settings.device="cuda", the whole GGML optimization, including the inner OT solves, then runs on the GPU as batched, vectorized operations, which speeds up training substantially: about 0.32 s per epoch in the run below.
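The switch on entropic_reg can be pictured as a small dispatcher. Everything in this sketch is hypothetical and not ggml_ot's internal code: solve_ot_batch stands in for the library's dispatch, and a brute-force search over permutations stands in for the network-simplex solver. That stand-in is exact only for uniform weights on equally sized point sets and only feasible for tiny problems. The sketch shows the structural difference: the exact path loops over pairs, while the Sinkhorn path handles the whole batch in one vectorized call.

```python
import itertools

import numpy as np


def exact_ot_uniform(cost):
    """Exact OT cost between two uniform n-point distributions (brute force).

    For uniform weights and equal sizes, some optimal plan is a permutation
    (Birkhoff-von Neumann), so trying all n! permutations is exact but only
    feasible for tiny n. It stands in here for a network-simplex solver.
    """
    n = cost.shape[0]
    best = min(
        sum(cost[i, perm[i]] for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    return best / n  # each matched pair carries mass 1/n


def batched_sinkhorn_cost(costs, entropic_reg, n_iter=2000):
    """Sinkhorn transport costs for a batch of uniform n-vs-n problems."""
    batch, n, _ = costs.shape
    w = np.full((batch, n), 1.0 / n)  # uniform weights on both sides
    K = np.exp(-costs / entropic_reg)
    u, v = np.ones_like(w), np.ones_like(w)
    for _ in range(n_iter):
        u = w / np.einsum("bij,bj->bi", K, v)
        v = w / np.einsum("bij,bi->bj", K, u)
    plans = u[:, :, None] * K * v[:, None, :]
    return np.sum(plans * costs, axis=(1, 2))


def solve_ot_batch(costs, entropic_reg=0.0):
    """Hypothetical dispatcher: Sinkhorn if entropic_reg > 0, else exact per pair."""
    if entropic_reg > 0:
        return batched_sinkhorn_cost(costs, entropic_reg)  # one vectorized call
    return np.array([exact_ot_uniform(c) for c in costs])  # one solve per pair


rng = np.random.default_rng(1)
costs = rng.random((4, 4, 4))  # 4 pairs of 4-point distributions
exact = solve_ot_batch(costs)                     # entropic_reg = 0: exact path
approx = solve_ot_batch(costs, entropic_reg=0.1)  # entropic_reg > 0: Sinkhorn path
print(np.round(exact, 3), np.round(approx, 3))
```

The Sinkhorn costs are slightly larger than the exact ones; for a 4-vs-4 problem the gap is at most entropic_reg * log(16), and it shrinks as entropic_reg decreases.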

[4]:
ggml_ot.settings.device = "cuda"

_, scores["sinkhorn (GPU)"] = dataset.train_test(**train_params, entropic_reg=5.0)
Metric   knn_acc     mi          ari         vi          epoch_time(s)
Mean±SD  Mean  SD    Mean  SD    Mean  SD    Mean  SD    Mean    SD
0        0.85  0.01  0.67  0.01  0.78  0.00  0.40  0.01  0.32    0.01

Sinkhorn on CPU

Even on the CPU (settings.device="cpu"), Sinkhorn trains much faster than the exact EMD2 solver (entropic_reg=0): batched tensor operations make better use of a large number of CPU threads (settings.n_threads), bringing an epoch down to about 7 s compared with about 355 s for EMD2 below.
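The benefit of batching alone can be reproduced on the CPU with NumPy. The sketch below is an illustration, not ggml_ot code, and its sizes and names are arbitrary. It runs the same Sinkhorn iterations once per pair in a Python loop and once for the whole batch, and checks that both give the same plans. The timings depend on the machine, and the sketch does not model settings.n_threads or how ggml_ot spreads work across threads.

```python
import time

import numpy as np


def sinkhorn_plans(K, a, b, n_iter=200):
    """Sinkhorn plans for a stack of kernels K with shape (B, n, m); B may be 1."""
    u, v = np.ones_like(a), np.ones_like(b)
    Kt = np.swapaxes(K, 1, 2)  # (B, m, n), transposed kernels
    for _ in range(n_iter):
        u = a / (K @ v[:, :, None])[:, :, 0]  # batched matrix-vector products
        v = b / (Kt @ u[:, :, None])[:, :, 0]
    return u[:, :, None] * K * v[:, None, :]


rng = np.random.default_rng(0)
B, n, m, reg = 256, 30, 30, 0.5
K = np.exp(-rng.random((B, n, m)) / reg)
a = np.full((B, n), 1.0 / n)
b = np.full((B, m), 1.0 / m)

start = time.perf_counter()
looped = np.concatenate(
    [sinkhorn_plans(K[i : i + 1], a[i : i + 1], b[i : i + 1]) for i in range(B)]
)
t_loop = time.perf_counter() - start

start = time.perf_counter()
batched = sinkhorn_plans(K, a, b)
t_batch = time.perf_counter() - start

print(f"one pair at a time: {t_loop:.3f} s | whole batch at once: {t_batch:.3f} s")
```

The loop pays Python and call overhead once per pair and iteration, whereas the batched version pays it once per iteration for all pairs; the same reasoning applies, on a larger scale, to the GPU.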

[5]:
ggml_ot.settings.device = "cpu"
ggml_ot.settings.n_threads = 64

_, scores["sinkhorn (CPU)"] = dataset.train_test(**train_params, entropic_reg=5.0)
Metric   knn_acc     mi          ari         vi          epoch_time(s)
Mean±SD  Mean  SD    Mean  SD    Mean  SD    Mean  SD    Mean    SD
0        0.84  0.00  0.36  0.34  0.42  0.40  0.51  0.14  7.06    0.07

Exact EMD2 on GPU and CPU

For contrast, we run the exact EMD2 solver (entropic_reg=0.0, the default) on both devices. EMD2 is computed by the network-simplex LP solver shipped with POT, which is not batched and runs on the CPU. Setting settings.device="cuda" therefore only moves tensors to the GPU for the triplet loss; the OT solves themselves stay on the CPU, and train() emits a warning to make this visible.
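The warning logic described above can be mimicked with the standard warnings module. The helper select_ot_solver is hypothetical and is not the ggml_ot code; only the message text is copied from the warning shown in the output below. It picks the solver from entropic_reg and warns only when the exact solver is combined with device="cuda".

```python
import warnings


def select_ot_solver(device, entropic_reg):
    """Hypothetical helper: choose an OT solver and flag CPU-bound exact solves.

    Mirrors the behavior described in the text; it is not the ggml_ot code.
    """
    if entropic_reg > 0:
        return "sinkhorn"  # batched, can run on the selected device
    if device == "cuda":
        # exact EMD2 runs per pair on the CPU, so the GPU setting has no effect on OT
        warnings.warn(
            "Exact EMD2 solver is CPU-bound; no GPU acceleration is used for OT solves.",
            UserWarning,
            stacklevel=2,
        )
    return "emd2"


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    solvers = [
        select_ot_solver("cuda", 5.0),  # Sinkhorn on GPU: no warning
        select_ot_solver("cuda", 0.0),  # EMD2 with device="cuda": warns
        select_ot_solver("cpu", 0.0),   # EMD2 on CPU: no warning
    ]

print(solvers, len(caught))  # ['sinkhorn', 'emd2', 'emd2'] 1
```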

[6]:
ggml_ot.settings.device = "cuda"

_, scores["emd2 (GPU)"] = dataset.train_test(**train_params)
/home/kuehn/ot_metric_learning/gaussian-ground-metric-learning/code/ggml-ot_privat/ggml_ot/optimization/api.py:44: UserWarning: Exact EMD2 solver is CPU-bound; no GPU acceleration is used for OT solves.
  warnings.warn(
Metric   knn_acc     mi          ari         vi          epoch_time(s)
Mean±SD  Mean  SD    Mean  SD    Mean  SD    Mean  SD    Mean    SD
0        0.88  0.00  0.62  0.05  0.73  0.06  0.46  0.05  391.31  4.02
[7]:
ggml_ot.settings.device = "cpu"

_, scores["emd2 (CPU)"] = dataset.train_test(**train_params)
Metric   knn_acc     mi          ari         vi          epoch_time(s)
Mean±SD  Mean  SD    Mean  SD    Mean  SD    Mean  SD    Mean    SD
0        0.85  0.01  0.60  0.03  0.71  0.04  0.48  0.03  354.64  27.33

As expected, the non-batched EMD2 solver cannot benefit from the GPU or from massive parallelization: an epoch takes about 355 s on the CPU and about 391 s with settings.device="cuda".

Performance and classification comparison

train_test() reports the mean epoch time alongside the usual classification and clustering metrics. Sorting by mean epoch time orders the four setups from fastest to slowest.

[8]:
# stack the per-setup score tables and index them by setup name
scores_df = (
    pd.concat(list(scores.values()), ignore_index=True)
    .set_index(pd.Index(list(scores), name="setup"))
    .sort_values(("epoch_time(s)", "Mean"))  # fastest setup first
)

ggml_ot.pl.table(
    scores_df,
    style_performance=True,
    title="Sinkhorn vs. EMD2 on GPU and CPU",
)
Sinkhorn vs. EMD2 on GPU and CPU
Metric          knn_acc     mi          ari         vi          epoch_time(s)
Mean±SD         Mean  SD    Mean  SD    Mean  SD    Mean  SD    Mean    SD
setup
sinkhorn (GPU)  0.85  0.01  0.67  0.01  0.78  0.00  0.40  0.01  0.32    0.01
sinkhorn (CPU)  0.84  0.00  0.36  0.34  0.42  0.40  0.51  0.14  7.06    0.07
emd2 (CPU)      0.85  0.01  0.60  0.03  0.71  0.04  0.48  0.03  354.64  27.33
emd2 (GPU)      0.88  0.00  0.62  0.05  0.73  0.06  0.46  0.05  391.31  4.02
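To read the epoch-time column as speed-ups, the ratios can be computed directly from the reported means. The values below are copied from the table, and taking EMD2 on the CPU as the reference is our choice for illustration. Sinkhorn on the GPU comes out at roughly 1,100 times faster per epoch and Sinkhorn on the CPU at roughly 50 times faster, while EMD2 with device="cuda" is slightly slower than on the CPU.

```python
# mean epoch times in seconds, copied from the comparison table
epoch_time = {
    "sinkhorn (GPU)": 0.32,
    "sinkhorn (CPU)": 7.06,
    "emd2 (CPU)": 354.64,
    "emd2 (GPU)": 391.31,
}

baseline = epoch_time["emd2 (CPU)"]  # exact solver on the CPU as the reference
speedup = {setup: baseline / seconds for setup, seconds in epoch_time.items()}
ranking = sorted(epoch_time.items(), key=lambda item: item[1])  # fastest first

for setup, seconds in ranking:
    print(f"{setup:<15} {seconds:8.2f} s   {speedup[setup]:8.2f}x vs. emd2 (CPU)")
```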

Takeaways:

  • To benefit from GPU acceleration, use the Sinkhorn solver by passing entropic_reg > 0.

    • Call normalize() before tuning entropic_reg; it makes regularization strengths transferable across datasets.

  • The EMD2 solver is not batched, so its computations are CPU-bound regardless of settings.device.

See the fast approximations tutorial for further computational speed-ups that apply to both EMD2 and Sinkhorn.

Note

To reduce training time further, consider other speed-up strategies such as PCA embeddings, cell-type grouping, and cell subsampling; see the Tutorial on Computational Speed-ups.