Hyperparameter Tuning
[1]:
import ggml_ot
import scanpy as sc
import numpy as np
In this tutorial, we explore how to automatically tune important hyperparameters of ggml-ot:
- alpha: target margin in the OT distance between distributions from the same and from different classes
- reg: regularization strength on \(\mathbf{A}\)
- reg_type: norm used for the regularization: L1, L2, or nuclear norm (nuc)
- n_comps: rank of the latent subspace \(\mathbf{A}\)
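To make the three regularization norms concrete, here is a minimal NumPy sketch on a toy matrix. It assumes that L1 and L2 refer to entrywise norms of \(\mathbf{A}\) (the entrywise L2 norm being the Frobenius norm); how ggml-ot scales or applies these penalties internally is not shown here.

```python
import numpy as np

# Toy 2 x 2 matrix standing in for a learned transformation A (illustrative values only)
A = np.array([[1.0, 1.0],
              [1.0, -1.0]])

l1_norm = np.abs(A).sum()            # entrywise L1: sum of absolute values
l2_norm = np.linalg.norm(A, "fro")   # entrywise L2 (Frobenius): sqrt of the sum of squares
nuc_norm = np.linalg.norm(A, "nuc")  # nuclear norm: sum of singular values

print(f"L1={l1_norm:.3f}  L2/Frobenius={l2_norm:.3f}  nuclear={nuc_norm:.3f}")
```

An L1 penalty tends to push individual entries to zero, while the nuclear norm penalizes the singular values and therefore favors low-rank solutions.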
First, we set up the dataset from AnnData as before. To speed up the tutorial, we subsample each patient to n_cells=100; on your own data, you can increase this value or omit it.
[2]:
id = "c1f6034b-7973-45e1-85e7-16933d0550bc.h5ad"
adata = ggml_ot.data.load_cellxgene(id)
sc.pp.highly_variable_genes(adata, n_top_genes=1000, subset=True)
# Replace patient_col and label_col names to match the .obs in your AnnData
dataset = ggml_ot.from_anndata(adata, patient_col="sample", label_col="patient_group", n_cells=100)
dataset.normalize()
[2]:
<ggml_ot.data.anndata.AnnData_TripletDataset at 0x7fe6da127440>
Grid Search over alpha, reg, reg_type and n_comps
Here, we run the hyperparameter tuning to find the best combination of alpha and reg. For each hyperparameter combination, n_splits train-validation splits are trained and evaluated.
You can also tune reg_type and n_comps with this function. In general, n_comps=2 is sufficient and preferable for visualization and fast computation, but complex datasets may require more components. The choice of regularization type mostly depends on the downstream task. For example, L1 is useful for identifying distinct marker genes, whereas for gene enrichment analysis the model should be trained with L2 so that correlated genes are included.
Depending on the size of your dataset, the number of splits and hyperparameter combinations, the grid search may take a while!
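To get a feel for the cost before starting, the number of training runs can be estimated from the grid sizes. The sketch below uses the same grids as the tuning call in this tutorial; the value of n_splits is a hypothetical placeholder, so check the documentation of tune() for the value actually used.

```python
from itertools import product

import numpy as np

# Same grids as in the tuning call of this tutorial
alpha_grid = np.logspace(-1, 2, 4)   # 4 values: 0.1, 1, 10, 100
reg_grid = np.logspace(-4, -1, 4)    # 4 values: 1e-4, 1e-3, 1e-2, 1e-1
n_comps_grid = [2, 5, 10]
reg_types = ["fro"]

# Hypothetical number of train-validation splits, for illustration only
n_splits = 5

# Every combination is trained once per split
combinations = list(product(n_comps_grid, reg_types, alpha_grid, reg_grid))
n_runs = len(combinations) * n_splits

print(f"{len(combinations)} combinations x {n_splits} splits = {n_runs} training runs")
```

Multiplying the number of runs by the expected number of epochs and the epoch time reported in the results table gives a rough estimate of the total runtime.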
[3]:
# Grids for alpha (triplet margin) and reg (regularization strength)
alpha_grid = np.logspace(-1, 2, 4)  # [10^-1, 10^0, 10^1, 10^2]
reg_grid = np.logspace(-4, -1, 4)  # [10^-4, 10^-3, 10^-2, 10^-1]
As, scores = dataset.tune(
    alpha=alpha_grid,
    reg=reg_grid,
    reg_type="fro",
    n_comps=[2, 5, 10],
    train_size=0.7,
)
| n_comps | reg_type | alpha | reg | knn_acc mean | knn_acc SD | mi mean | mi SD | ari mean | ari SD | vi mean | vi SD | epoch_time(s) mean | epoch_time(s) SD |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 2 | fro | 0.10 | 0.0001 | 0.73 | 0.11 | 0.72 | 0.09 | 0.68 | 0.12 | 0.52 | 0.16 | 1.81 | 0.06 |
| 2 | fro | 0.10 | 0.001 | 0.87 | 0.08 | 0.76 | 0.17 | 0.71 | 0.22 | 0.44 | 0.31 | 1.62 | 0.17 |
| 2 | fro | 0.10 | 0.01 | 0.82 | 0.11 | 0.72 | 0.12 | 0.64 | 0.15 | 0.54 | 0.24 | 1.72 | 0.20 |
| 2 | fro | 0.10 | 0.1 | 0.78 | 0.07 | 0.75 | 0.08 | 0.68 | 0.12 | 0.48 | 0.15 | 1.64 | 0.11 |
| 2 | fro | 1.00 | 0.0001 | 0.76 | 0.13 | 0.78 | 0.12 | 0.78 | 0.15 | 0.43 | 0.23 | 1.74 | 0.09 |
| 2 | fro | 1.00 | 0.001 | 0.71 | 0.09 | 0.74 | 0.15 | 0.73 | 0.17 | 0.51 | 0.28 | 1.78 | 0.11 |
| 2 | fro | 1.00 | 0.01 | 0.80 | 0.13 | 0.80 | 0.15 | 0.79 | 0.20 | 0.38 | 0.26 | 1.53 | 0.07 |
| 2 | fro | 1.00 | 0.1 | 0.71 | 0.11 | 0.76 | 0.08 | 0.73 | 0.11 | 0.47 | 0.15 | 1.29 | 0.16 |
| 2 | fro | 10.00 | 0.0001 | 0.82 | 0.09 | 0.84 | 0.13 | 0.84 | 0.14 | 0.32 | 0.24 | 1.58 | 0.13 |
| 2 | fro | 10.00 | 0.001 | 0.80 | 0.13 | 0.78 | 0.10 | 0.79 | 0.10 | 0.44 | 0.19 | 1.67 | 0.16 |
| 2 | fro | 10.00 | 0.01 | 0.80 | 0.08 | 0.83 | 0.13 | 0.82 | 0.15 | 0.34 | 0.26 | 1.59 | 0.11 |
| 2 | fro | 10.00 | 0.1 | 0.87 | 0.04 | 0.79 | 0.10 | 0.79 | 0.13 | 0.42 | 0.20 | 1.62 | 0.18 |
| 2 | fro | 100.00 | 0.0001 | 0.78 | 0.12 | 0.76 | 0.04 | 0.76 | 0.06 | 0.48 | 0.07 | 1.78 | 0.07 |
| 2 | fro | 100.00 | 0.001 | 0.76 | 0.11 | 0.71 | 0.05 | 0.71 | 0.06 | 0.57 | 0.11 | 1.73 | 0.09 |
| 2 | fro | 100.00 | 0.01 | 0.78 | 0.07 | 0.73 | 0.11 | 0.71 | 0.14 | 0.51 | 0.20 | 1.63 | 0.10 |
| 2 | fro | 100.00 | 0.1 | 0.76 | 0.11 | 0.76 | 0.11 | 0.74 | 0.13 | 0.46 | 0.19 | 1.70 | 0.05 |
| 5 | fro | 0.10 | 0.0001 | 0.76 | 0.11 | 0.70 | 0.13 | 0.63 | 0.17 | 0.54 | 0.22 | 1.93 | 0.16 |
| 5 | fro | 0.10 | 0.001 | 0.78 | 0.12 | 0.68 | 0.09 | 0.60 | 0.13 | 0.60 | 0.15 | 1.91 | 0.21 |
| 5 | fro | 0.10 | 0.01 | 0.82 | 0.05 | 0.69 | 0.03 | 0.62 | 0.02 | 0.57 | 0.06 | 1.87 | 0.08 |
| 5 | fro | 0.10 | 0.1 | 0.80 | 0.08 | 0.73 | 0.04 | 0.67 | 0.07 | 0.51 | 0.09 | 1.80 | 0.12 |
| 5 | fro | 1.00 | 0.0001 | 0.80 | 0.15 | 0.74 | 0.06 | 0.70 | 0.10 | 0.50 | 0.11 | 1.82 | 0.14 |
| 5 | fro | 1.00 | 0.001 | 0.76 | 0.08 | 0.75 | 0.08 | 0.68 | 0.11 | 0.46 | 0.14 | 1.89 | 0.14 |
| 5 | fro | 1.00 | 0.01 | 0.78 | 0.14 | 0.72 | 0.04 | 0.65 | 0.07 | 0.52 | 0.07 | 1.86 | 0.23 |
| 5 | fro | 1.00 | 0.1 | 0.78 | 0.07 | 0.69 | 0.11 | 0.67 | 0.12 | 0.59 | 0.19 | 1.86 | 0.16 |
| 5 | fro | 10.00 | 0.0001 | 0.82 | 0.13 | 0.78 | 0.13 | 0.77 | 0.15 | 0.44 | 0.26 | 1.81 | 0.11 |
| 5 | fro | 10.00 | 0.001 | 0.89 | 0.00 | 0.85 | 0.12 | 0.85 | 0.14 | 0.29 | 0.23 | 1.91 | 0.12 |
| 5 | fro | 10.00 | 0.01 | 0.82 | 0.05 | 0.80 | 0.10 | 0.77 | 0.14 | 0.39 | 0.19 | 1.66 | 0.22 |
| 5 | fro | 10.00 | 0.1 | 0.87 | 0.04 | 0.81 | 0.12 | 0.80 | 0.13 | 0.37 | 0.24 | 2.34 | 0.17 |
| 5 | fro | 100.00 | 0.0001 | 0.80 | 0.13 | 0.76 | 0.10 | 0.74 | 0.12 | 0.46 | 0.19 | 1.88 | 0.18 |
| 5 | fro | 100.00 | 0.001 | 0.78 | 0.12 | 0.78 | 0.10 | 0.76 | 0.13 | 0.43 | 0.19 | 2.00 | 0.11 |
| 5 | fro | 100.00 | 0.01 | 0.80 | 0.13 | 0.80 | 0.06 | 0.80 | 0.06 | 0.40 | 0.11 | 1.79 | 0.11 |
| 5 | fro | 100.00 | 0.1 | 0.80 | 0.13 | 0.78 | 0.09 | 0.77 | 0.11 | 0.43 | 0.17 | 1.98 | 0.19 |
| 10 | fro | 0.10 | 0.0001 | 0.82 | 0.09 | 0.74 | 0.04 | 0.67 | 0.07 | 0.50 | 0.08 | 2.01 | 0.10 |
| 10 | fro | 0.10 | 0.001 | 0.80 | 0.08 | 0.76 | 0.08 | 0.70 | 0.11 | 0.45 | 0.14 | 1.94 | 0.15 |
| 10 | fro | 0.10 | 0.01 | 0.78 | 0.07 | 0.71 | 0.12 | 0.67 | 0.14 | 0.53 | 0.22 | 1.89 | 0.07 |
| 10 | fro | 0.10 | 0.1 | 0.80 | 0.08 | 0.65 | 0.10 | 0.55 | 0.10 | 0.62 | 0.17 | 2.01 | 0.16 |
| 10 | fro | 1.00 | 0.0001 | 0.76 | 0.08 | 0.74 | 0.08 | 0.69 | 0.11 | 0.48 | 0.14 | 2.10 | 0.08 |
| 10 | fro | 1.00 | 0.001 | 0.82 | 0.09 | 0.74 | 0.14 | 0.72 | 0.16 | 0.48 | 0.26 | 1.97 | 0.08 |
| 10 | fro | 1.00 | 0.01 | 0.80 | 0.11 | 0.70 | 0.12 | 0.65 | 0.13 | 0.55 | 0.21 | 1.89 | 0.14 |
| 10 | fro | 1.00 | 0.1 | 0.87 | 0.08 | 0.74 | 0.04 | 0.70 | 0.09 | 0.50 | 0.09 | 1.91 | 0.11 |
| 10 | fro | 10.00 | 0.0001 | 0.87 | 0.04 | 0.79 | 0.14 | 0.76 | 0.17 | 0.39 | 0.26 | 1.99 | 0.11 |
| 10 | fro | 10.00 | 0.001 | 0.87 | 0.04 | 0.78 | 0.09 | 0.74 | 0.12 | 0.43 | 0.17 | 2.16 | 0.15 |
| 10 | fro | 10.00 | 0.01 | 0.82 | 0.13 | 0.75 | 0.13 | 0.73 | 0.16 | 0.47 | 0.26 | 1.95 | 0.06 |
| 10 | fro | 10.00 | 0.1 | 0.84 | 0.05 | 0.85 | 0.12 | 0.85 | 0.14 | 0.29 | 0.23 | 2.05 | 0.04 |
| 10 | fro | 100.00 | 0.0001 | 0.82 | 0.13 | 0.80 | 0.12 | 0.79 | 0.14 | 0.38 | 0.21 | 1.87 | 0.07 |
| 10 | fro | 100.00 | 0.001 | 0.80 | 0.13 | 0.82 | 0.12 | 0.82 | 0.15 | 0.35 | 0.22 | 1.88 | 0.06 |
| 10 | fro | 100.00 | 0.01 | 0.84 | 0.09 | 0.81 | 0.10 | 0.80 | 0.12 | 0.37 | 0.18 | 1.83 | 0.07 |
| 10 | fro | 100.00 | 0.1 | 0.82 | 0.13 | 0.82 | 0.15 | 0.81 | 0.17 | 0.35 | 0.28 | 1.94 | 0.02 |
Once tuning is complete, the function displays the performance of each hyperparameter combination on the validation set and returns:
- As: a dict containing the learned transformation for each hyperparameter combination (the one with the best performance across splits)
- scores: a dataframe with the performance of each hyperparameter combination, as shown above

Using the scores dataframe and the contour_hyperparams plotting function, you can plot any evaluation metric over other pairs of hyperparameters, as shown below:
[ ]:
# ARI over alpha and reg
ggml_ot.pl.contour_hyperparams(
    scores, x="alpha", y="reg", fixed_params={"reg_type": "fro", "n_comps": 2}, value_col=("ari", "mean"), levels=20
)
# Epoch time over n_comps and reg
ggml_ot.pl.contour_hyperparams(
    scores,
    x="n_comps",
    y="reg",
    fixed_params={"reg_type": "fro"},
    value_col=("epoch_time(s)", "mean"),
    log_axis="y",
    levels=20,
)
(<Figure size 640x480 with 2 Axes>, <Axes: xlabel='n_comps', ylabel='reg'>)
Access best-performing hyperparameters
To get the best-performing hyperparameters, call .idxmax() on the relevant column (e.g. the mean classification accuracy ("knn","mean")) of the scores dataframe returned by the hyperparameter tuning.
[ ]:
best_params = scores[("knn", "mean")].idxmax()
print("as tuple to access As and scores: " + str(best_params))
# Cast numpy scalars (np.int64, np.float64) to Python scalars for cleaner printing and kwargs
best_params_dict = {p: (v.item() if hasattr(v, "item") else v) for p, v in zip(scores.index.names, best_params)}
print("as dict to pass as parameters: " + str(best_params_dict))
as tuple to access As and scores: (np.int64(5), 'fro', np.float64(10.0), np.float64(0.001))
as dict to pass as parameters: {'n_comps': 5, 'reg_type': 'fro', 'alpha': 10.0, 'reg': 0.001}
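The reason .idxmax() returns a tuple is that the rows of scores are indexed by all four hyperparameters at once. The following pandas sketch reproduces this on a toy dataframe built from three rows of the results table (n_comps=5, alpha=10); toy_scores and its reduced column set are stand-ins for illustration, not the object returned by the tuning.

```python
import pandas as pd

# Row index: one level per hyperparameter, as in the scores dataframe
index = pd.MultiIndex.from_tuples(
    [(5, "fro", 10.0, 0.0001), (5, "fro", 10.0, 0.001), (5, "fro", 10.0, 0.01)],
    names=["n_comps", "reg_type", "alpha", "reg"],
)
# Column index: (metric, statistic) pairs; only two columns are kept here
columns = pd.MultiIndex.from_tuples([("knn", "mean"), ("ari", "mean")])
toy_scores = pd.DataFrame(
    [[0.82, 0.77], [0.89, 0.85], [0.82, 0.77]], index=index, columns=columns
)

best = toy_scores[("knn", "mean")].idxmax()          # tuple with one entry per index level
best_dict = dict(zip(toy_scores.index.names, best))  # hyperparameter name -> value
best_row = toy_scores.loc[best]                      # all metrics of the best combination

print(best_dict)
print(best_row)
```

The same tuple can be used with .loc to look up the full row of metrics for a combination, and swapping the column (for example ("ari", "mean")) selects by a different criterion.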
We can also rerun the training (and testing) with the selected hyperparameters to show that they reliably yield a well-performing ground metric.
[6]:
_ = dataset.train_test(**best_params_dict, plot_split=False)
| | knn_acc mean | knn_acc SD | mi mean | mi SD | ari mean | ari SD | vi mean | vi SD | epoch_time(s) mean | epoch_time(s) SD |
|---|---|---|---|---|---|---|---|---|---|---|
| 0 | 0.85 | 0.06 | 0.80 | 0.08 | 0.79 | 0.11 | 0.38 | 0.15 | 1.34 | 0.11 |