ggml_ot.data.AnnData_TripletDataset


class ggml_ot.data.AnnData_TripletDataset(adata, patient_col='sample', label_col='patient_group', n_cells=250, n_triplets=3, use_rep=None, group_by=None, gmm_key=None, sample_gmm=False, gmm_weights_source='auto')

Dataset to train GGML based on AnnData.

This subclass of TripletDataset builds triplets of patient-level cell distributions from an AnnData object. The triplets encode the relative relationships between patient groups (e.g., disease state) that GGML aims to learn.

By default, each patient's cells are represented as an empirical distribution in the gene space of the AnnData (.X). Use group_by to reduce the distribution to cell subtypes and/or use_rep to work in a low-dimensional representation instead of the gene space.
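The two kinds of patient-level distribution described above can be sketched with plain Python lists in place of AnnData. This is a minimal illustration, not the library's implementation: the helper names are hypothetical, sampling without replacement (and keeping all cells when a patient has fewer than n_cells) is an assumption, and the group_by variant only tracks group identities and their cell fractions, not how each group is represented in feature space. With use_rep set, the cell vectors would be rows of adata.obsm[use_rep] rather than of .X.

```python
import random
from collections import Counter


def empirical_distribution(cells, n_cells, seed=0):
    """Hypothetical: subsample up to n_cells cells of one patient, uniform weights."""
    rng = random.Random(seed)
    # Assumption: sample without replacement and keep all cells if fewer than n_cells
    k = min(n_cells, len(cells))
    supports = rng.sample(cells, k)
    weights = [1.0 / k] * k
    return supports, weights


def grouped_distribution(cell_groups, all_groups):
    """Hypothetical group_by variant: one support per group, weighted by cell fraction."""
    counts = Counter(cell_groups)
    total = len(cell_groups)
    # Every patient shares the same ordered support set, so supports are identical
    weights = [counts.get(g, 0) / total for g in all_groups]
    return list(all_groups), weights


# Usage: 1000 toy cells with 2 features; 4 cells annotated with subtypes
cells = [[float(i), 2.0 * float(i)] for i in range(1000)]
supports, weights = empirical_distribution(cells, n_cells=250)
groups, w = grouped_distribution(["T", "T", "B", "NK"], ["B", "NK", "T"])
print(len(supports), w)  # 250 [0.25, 0.25, 0.5]
```

Because every patient is expressed over the same ordered list of groups, the grouped variant yields identical supports across distributions, which is what the identical_supports attribute reports.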

This class exposes the dataset to the standardized interfaces used by ggml_ot.train(), ggml_ot.tune(), ggml_ot.test() and ggml_ot.train_test().

Parameters:
adata str | anndata.AnnData

The AnnData object.

patient_col str, optional

Column in adata.obs that identifies the patient / sample (default: “sample”).

label_col str, optional

Column in adata.obs that contains the patient group, e.g., disease state (default: “patient_group”).

n_cells int, optional

Number of cells to sample per patient (default: 250).

n_triplets int, optional

Number of triplets generated per patient to capture the relative relationships between patient groups (default: 3). In total, n_patients * n_triplets * n_labels triplets are generated.

use_rep None | str, optional

If provided, uses adata.obsm[use_rep] as the cell embedding representation; otherwise the .X matrix is used (default: None).

group_by None | str, optional

Optional column in adata.obs used to group cells, so that the ground metric is learned between cell groups instead of between individual cells (default: None).

gmm_key None | str, optional

If provided, loads a previously fitted GMM representation from adata.uns[gmm_key] (default: None).

sample_gmm bool, optional

If True, samples empirical supports from the fitted GMMs instead of using their parametric supports directly (default: False).

gmm_weights_source {"auto", "stored", "components"}, optional

Controls how per-distribution GMM weights are reconstructed when gmm_key is provided. "auto" uses stored weights if available and otherwise falls back to hard assignments predicted from the stored GMM parameters (default: “auto”).
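The sketch below shows what reconstructing per-distribution weights from hard assignments could look like: each cell is assigned to its most probable Gaussian component, and the weights are the resulting component fractions. The dictionary layout of the GMM parameters, the function names and the fallback logic mirroring "auto" are all assumptions for illustration; the actual structure stored in adata.uns[gmm_key] is not documented on this page.

```python
import numpy as np


def gaussian_log_density(X, mean, cov):
    """Log-density of a full-covariance Gaussian for each row of X."""
    d = X.shape[1]
    diff = X - mean
    _, logdet = np.linalg.slogdet(cov)
    # Solve rather than invert cov explicitly, for numerical stability
    maha = np.sum(diff * np.linalg.solve(cov, diff.T).T, axis=1)
    return -0.5 * (d * np.log(2.0 * np.pi) + logdet + maha)


def hard_assignment_weights(X, means, covariances, mix_weights):
    """Assign each cell to its most probable component; return component fractions."""
    # Unnormalized log-posteriors, shape (n_cells, n_components);
    # the normalizing constant does not change the argmax
    log_post = np.stack(
        [np.log(pi) + gaussian_log_density(X, m, c)
         for m, c, pi in zip(means, covariances, mix_weights)],
        axis=1,
    )
    labels = np.argmax(log_post, axis=1)
    counts = np.bincount(labels, minlength=len(means))
    return counts / counts.sum()


def reconstruct_weights(X, gmm, stored_weights=None):
    """Hypothetical 'auto' logic: prefer stored weights, else fall back to hard assignments."""
    if stored_weights is not None:
        return np.asarray(stored_weights, dtype=float)
    return hard_assignment_weights(X, gmm["means"], gmm["covariances"], gmm["weights"])


# Usage: a hypothetical two-component GMM in 2-D and one patient's cells
gmm = {
    "means": np.array([[0.0, 0.0], [5.0, 5.0]]),
    "covariances": np.array([np.eye(2), np.eye(2)]),
    "weights": np.array([0.5, 0.5]),
}
cells = np.array([[0.1, -0.2], [4.9, 5.1], [5.2, 4.8], [5.0, 5.0]])
w = reconstruct_weights(cells, gmm)  # component fractions 0.25 and 0.75
```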

Notes

Following scverse conventions, this class modifies the AnnData object in-place during dataset construction and training. In particular:

  • adata.uns["ggml_params"] — stores dataset construction parameters.

  • adata.uns["W_ggml"] — the learned linear map after training.

  • adata.varm["W_ggml"] — gene-space loadings of the learned ground metric.

  • adata.obsm["X_ggml"] — cells projected into the learned gene subspace.

To keep the original object unmodified, pass adata.copy() to the constructor instead.
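A ground metric given by a linear map W is commonly evaluated as d(x, y) = ||W(x - y)||, which equals the Euclidean distance between the projected cells and corresponds to the Mahalanobis matrix M = W^T W. The sketch below assumes W has shape (n_comps, n_genes), so that projected cells are X @ W.T and per-gene loadings (one row per gene, as adata.varm requires) correspond to W.T. Whether W_ggml follows this orientation, and the exact metric form GGML uses, are assumptions; the function names are hypothetical.

```python
import numpy as np


def project_cells(X, W):
    """Project cells (rows of X, genes as columns) with a map W of shape (n_comps, n_genes)."""
    return X @ W.T


def ground_distance(x, y, W):
    """Distance induced by the linear map: ||W (x - y)||_2."""
    return float(np.linalg.norm(W @ (x - y)))


# Usage: 2 cells over 3 genes, a 2-component map
X = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
W = np.array([[1.0, 0.0, 0.0], [0.0, 2.0, 0.0]])
Z = project_cells(X, W)              # [[1, 0], [0, 2]]
d = ground_distance(X[0], X[1], W)   # sqrt(5), same as ||Z[0] - Z[1]||
M = W.T @ W                          # equivalent Mahalanobis matrix
```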

See also

ggml_ot.data.generic.TripletDataset

Base class providing triplet creation and the dataset API.

__init__(adata, patient_col='sample', label_col='patient_group', n_cells=250, n_triplets=3, use_rep=None, group_by=None, gmm_key=None, sample_gmm=False, gmm_weights_source='auto')

Methods

__init__(adata[, patient_col, label_col, ...])

compute_OT([precomputed_distances, ...])

Compute the Optimal Transport distances between all distributions.

enrich_discriminant(**kwargs)

Enrich a supervised discriminant direction in latent space.

enrich_gmm_components(*[, gmm_key, ...])

Run pathway enrichment on GMM component gene signatures.

enrich_latent_axes(*[, axes, gene_symbols, ...])

Run pathway enrichment on per-axis gene rankings.

fit_gmm(*args, **kwargs)

Fit or refit a GMM representation for this AnnData-backed dataset.

normalize()

rank_latent_axes(*[, axes, gene_symbols])

Rank genes by their contribution to each GGML latent axis.

rotate_latent_axes([method])

Apply post-hoc rotation to latent axes (varimax, promax).

summarize_gmm_components(*[, gmm_key, ...])

Summarize GMM components by cell-type composition and patient weights.

test([ground_metric, knn_k])

Test a ground metric on a given dataset.

to_anndata()

Returns the AnnData of the dataset object.

train([alpha, reg, reg_type, n_comps, lr, ...])

Perform supervised optimal transport by ground metric learning.

train_emd2([alpha, reg, reg_type, n_comps, ...])

Train GGML with exact OT (EMD2).

train_sinkhorn([alpha, reg, reg_type, ...])

Train GGML with Sinkhorn-regularized OT.

train_test([n_splits, train_size, ...])

Trains and cross-validates ground metrics on train-test splits.

tune([alpha, reg, reg_type, n_comps, ...])

Tune hyperparameters by performing a Grid Search and Cross-Validation.

validate_gmm(*[, gmm_key, patient_col, ...])

Run a one-call hold-out validation pass for a fitted GMM.
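compute_OT, train_emd2 and train_sinkhorn rely on exact and entropy-regularized OT. As a self-contained illustration (not the library's solver), the sketch compares the two on a tiny example. For two uniform distributions with the same number of points, an optimal plan is a scaled permutation, so brute force is exact; Sinkhorn scaling approximates it for a regularization strength epsilon. The squared Euclidean cost, optionally computed after a linear map W, is an assumption. In practice a dedicated solver such as POT's ot.emd2 or ot.sinkhorn2 would be used; this page does not state which backend ggml_ot relies on.

```python
import itertools

import numpy as np


def cost_matrix(A, B, W=None):
    """Squared Euclidean costs between supports, optionally after a linear map W."""
    if W is not None:
        A, B = A @ W.T, B @ W.T
    diff = A[:, None, :] - B[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def exact_ot_uniform(C):
    """Exact OT between two uniform distributions of equal size n.

    An optimal plan is then a permutation matrix scaled by 1/n, so brute force
    over permutations is exact; only viable for tiny n.
    """
    n = C.shape[0]
    best = min(sum(C[i, p[i]] for i in range(n))
               for p in itertools.permutations(range(n)))
    return float(best) / n


def sinkhorn_ot(a, b, C, epsilon=0.1, n_iter=1000):
    """Transport cost <P, C> of the entropy-regularized plan found by Sinkhorn scaling."""
    K = np.exp(-C / epsilon)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    return float(np.sum(P * C))


# Usage: two 2-point distributions in 2-D; each point moves one unit up
A = np.array([[0.0, 0.0], [1.0, 0.0]])
B = np.array([[0.0, 1.0], [1.0, 1.0]])
C = cost_matrix(A, B)
a = b = np.full(2, 0.5)
exact = exact_ot_uniform(C)                    # 1.0
approx = sinkhorn_ot(a, b, C, epsilon=0.05)    # close to 1.0
```

Smaller epsilon brings the Sinkhorn cost closer to the exact one but makes the kernel exp(-C / epsilon) underflow sooner for large costs, which is the usual trade-off between the two training modes.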

Attributes

covariances

dim

Dimensions of space underlying the distributions.

distribution_labels

Patient group labels per patient-level distribution as int, indexing the unique classes of adata.obs[label_col].

distribution_labels_str

Patient group labels per patient-level distribution as string, taken from adata.obs[label_col].

identical_supports

If True, indicates supports were forced identical across distributions by group_by.

map_A

Learned ground metric as a linear map (raises a warning if dataset is not trained yet).

patient_labels

points

points_labels

Patient group labels per cell as int: the distribution_labels of all points, concatenated over all distributions.

points_labels_str

Patient group labels per cell as string, taken from adata.obs[label_col].

supports

Per-patient distribution supports, by default the cells of the patients.

w_theta

Deprecated; use map_A instead.

weights

Probability for each support, by default uniform over the cells of the patients.
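The relationship between distribution_labels, distribution_labels_str and points_labels can be sketched with plain lists. Encoding classes by their sorted order is an assumption (the page only says the ints reference the unique classes of adata.obs[label_col]), and the function names are hypothetical.

```python
def encode_labels(labels_str):
    """Map string group labels to ints indexing the unique classes (sorted order assumed)."""
    classes = sorted(set(labels_str))
    index = {c: i for i, c in enumerate(classes)}
    return [index[s] for s in labels_str], classes


def points_labels(distribution_labels, supports):
    """Repeat each distribution's label once per support point, concatenated over distributions."""
    return [label for label, pts in zip(distribution_labels, supports) for _ in pts]


# Usage: three patients with 2, 1 and 3 cells (1-D toy supports)
dist_labels_str = ["healthy", "disease", "healthy"]
dist_labels, classes = encode_labels(dist_labels_str)  # [1, 0, 1], ["disease", "healthy"]
supports = [[[0.0], [1.0]], [[2.0]], [[3.0], [4.0], [5.0]]]
pts = points_labels(dist_labels, supports)             # [1, 1, 0, 1, 1, 1]
```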