ggml_ot.data.AnnData_TripletDataset
- class ggml_ot.data.AnnData_TripletDataset(adata, patient_col='sample', label_col='patient_group', n_cells=250, n_triplets=3, use_rep=None, group_by=None, gmm_key=None, sample_gmm=False, gmm_weights_source='auto')
Dataset to train GGML based on AnnData.
This subclass of TripletDataset formats triplets of patient-level cell distributions from an AnnData object. The triplets capture the relative relationship between patient groups (e.g. disease state) that GGML aims to learn.
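Each patient-level cell distribution is a set of support points with one weight per point. The following is a minimal, library-independent sketch of that idea. It uses plain Python lists as stand-ins for adata.X and adata.obs, and the helpers patient_distributions and grouped_distributions are hypothetical, not part of ggml_ot.

The first helper subsamples up to n_cells cells per patient and weights them uniformly. Capping at the number of available cells, rather than resampling with replacement, is an assumption. The second helper mimics grouping by a column such as group_by: every patient shares the same supports (the groups), and the weights are assumed to be each patient's group proportions.

```python
import random
from collections import Counter, defaultdict


def patient_distributions(cells, patient_of, n_cells=250, seed=0):
    """Subsample up to n_cells cells per patient and weight them uniformly."""
    rng = random.Random(seed)
    by_patient = defaultdict(list)
    for cell, patient in zip(cells, patient_of):
        by_patient[patient].append(cell)
    dists = {}
    for patient, patient_cells in by_patient.items():
        k = min(n_cells, len(patient_cells))  # assumption: cap instead of resampling
        support = rng.sample(patient_cells, k)  # sampling without replacement
        dists[patient] = (support, [1.0 / k] * k)  # uniform weights over kept cells
    return dists


def grouped_distributions(groups, patient_of):
    """Shared supports (one per cell group) weighted by each patient's group proportions."""
    support = sorted(set(groups))  # identical supports across all patients
    counts = defaultdict(Counter)
    for group, patient in zip(groups, patient_of):
        counts[patient][group] += 1
    dists = {}
    for patient, c in counts.items():
        total = sum(c.values())
        dists[patient] = (support, [c[g] / total for g in support])
    return dists


# Toy data: five cells with two genes each, from two patients.
patients = ["p1", "p1", "p1", "p2", "p2"]
cells = [[0.0, 1.0], [0.5, 0.5], [1.0, 0.0], [2.0, 2.0], [3.0, 1.0]]
cell_types = ["T", "B", "T", "B", "B"]

empirical = patient_distributions(cells, patients, n_cells=2)
grouped = grouped_distributions(cell_types, patients)
```

With grouping, all patients share the same supports and differ only in their weights. This is the situation the identical_supports attribute describes.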
By default, it represents the cells of each patient as an empirical distribution in the gene space of the AnnData (.X). Using the use_rep and/or group_by parameters, you can reduce the distribution to cell subtypes and/or low-dimensional gene representations.
This class exposes the dataset to the standardized interfaces used by ggml_ot.train(), ggml_ot.tune(), ggml_ot.test() and ggml_ot.train_test().
Parameters:
- adata str | anndata.AnnData
The AnnData object.
- patient_col str, optional
Column in adata.obs that identifies the patient / sample (default: “sample”).
- label_col str, optional
Column in adata.obs that contains the patient group, e.g., disease state (default: “patient_group”).
- n_cells int, optional
Number of cells to sample per patient (default: 250).
- n_triplets int, optional
Number of triplets generated for each patient to capture the relative relationship between patient groups (default: 3). In total, n_patients * n_triplets * n_labels triplets are generated.
- group_by None | str, optional
Optional column in adata.obs used to group cells (e.g., by cell subtype); the ground metric is then learned between cell groups instead of between individual cells (default: None).
- use_rep None | str, optional
If provided, uses adata.obsm[use_rep] as the cell embedding representation; otherwise the raw .X matrix is used (default: None).
- gmm_key None | str, optional
If provided, loads a previously fitted GMM representation from adata.uns[gmm_key] (default: None).
- sample_gmm bool, optional
If True, samples empirical supports from fitted GMM mixtures instead of using parametric supports directly (default: False).
- gmm_weights_source {"auto", "stored", "components"}, optional
Controls how per-distribution GMM weights are reconstructed when gmm_key is provided. "auto" tries stored weights first, then hard assignments predicted from the stored GMM parameters (default: “auto”).
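The sketch below illustrates what sample_gmm=True means. It draws an empirical support from a Gaussian mixture instead of keeping the component means and covariances as parametric supports. It assumes diagonal covariances and is not the library's implementation; sample_diag_gmm is a hypothetical helper. The drawn points then serve as supports with uniform weights.

```python
import math
import random


def sample_diag_gmm(weights, means, variances, n_samples, seed=0):
    """Draw n_samples points from a Gaussian mixture with diagonal covariances.

    Returns the sampled points and the component index each point came from.
    """
    rng = random.Random(seed)
    # Pick a component for every sample according to the mixture weights.
    comps = rng.choices(range(len(weights)), weights=weights, k=n_samples)
    # Then draw each coordinate independently from that component's Gaussian.
    points = [
        [rng.gauss(mu, math.sqrt(var)) for mu, var in zip(means[c], variances[c])]
        for c in comps
    ]
    return points, comps


weights = [0.7, 0.3]
means = [[0.0, 0.0], [5.0, 5.0]]
variances = [[1.0, 1.0], [0.5, 0.5]]
points, comps = sample_diag_gmm(weights, means, variances, n_samples=1000)
support_weights = [1.0 / len(points)] * len(points)  # empirical support: uniform weights
```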
Notes
Following scverse conventions, this class modifies the AnnData object in-place during dataset construction and training. In particular:
- adata.uns["ggml_params"]: stores dataset construction parameters.
- adata.uns["W_ggml"]: the learned linear map after training.
- adata.varm["W_ggml"]: gene-space loadings of the learned ground metric.
- adata.obsm["X_ggml"]: cells projected into the learned gene subspace.
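The stored quantities plausibly relate to the learned map as sketched below. The sketch assumes that:

- map_A is a matrix of shape (n_comps, n_genes);
- X_ggml holds each cell x mapped to A x;
- the gene-space loadings are the transpose of A;
- the ground metric is the Euclidean distance after the map, d(x, y) = ||A(x - y)||.

These shapes and the exact metric are assumptions, and the helper names are hypothetical.

```python
import math


def project(cells, A):
    """Map each cell x (length n_genes) to A x (length n_comps)."""
    return [[sum(a * x for a, x in zip(row, cell)) for row in A] for cell in cells]


def ground_distance(x, y, A):
    """Assumed ground metric: Euclidean distance between A x and A y, i.e. ||A(x - y)||."""
    diff = [xi - yi for xi, yi in zip(x, y)]
    (projected,) = project([diff], A)
    return math.sqrt(sum(v * v for v in projected))


def gene_loadings(A):
    """Transpose of A: one row per gene, one column per latent component."""
    return [list(col) for col in zip(*A)]


# Toy map with 2 latent components over 3 genes; the third gene is ignored.
A = [[1.0, 0.0, 0.0],
     [0.0, 2.0, 0.0]]
cells = [[1.0, 2.0, 3.0], [4.0, 6.0, 9.0]]

X_proj = project(cells, A)   # assumed analogue of obsm["X_ggml"]
loadings = gene_loadings(A)  # assumed analogue of varm["W_ggml"] (n_genes x n_comps)
d = ground_distance(cells[0], cells[1], A)
```

Because the map is linear, the distance between two cells under the metric equals the plain Euclidean distance between their projected coordinates.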
If you need to keep the original object unmodified, pass adata.copy() to the constructor instead of adata.
See also
ggml_ot.data.generic.TripletDataset: base class providing triplet creation and the dataset API.
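Below is a sketch of generic triplet sampling over patient-level distributions. The positive shares the anchor's patient group and the negative belongs to a different group, so a learned metric should place the anchor closer to the positive than to the negative. This is an assumed scheme, not the TripletDataset code, and sample_triplets is a hypothetical name.

The sketch draws n_triplets triplets per anchor and per other group. That yields n_patients * n_triplets * (n_labels - 1) triplets when every group has at least two patients. The class itself documents n_patients * n_triplets * n_labels, so the actual construction differs in detail.

```python
import random


def sample_triplets(labels, n_triplets=3, seed=0):
    """Return (anchor, positive, negative) index triplets over distributions."""
    rng = random.Random(seed)
    by_label = {}
    for i, lab in enumerate(labels):
        by_label.setdefault(lab, []).append(i)
    triplets = []
    for anchor, lab in enumerate(labels):
        positives = [i for i in by_label[lab] if i != anchor]
        if not positives:
            continue  # no other patient in this group can serve as the positive
        for other, negatives in by_label.items():
            if other == lab:
                continue
            for _ in range(n_triplets):
                triplets.append((anchor, rng.choice(positives), rng.choice(negatives)))
    return triplets


group_of_patient = ["healthy", "healthy", "disease", "disease", "disease"]
triplets = sample_triplets(group_of_patient, n_triplets=3)
```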
- __init__(adata, patient_col='sample', label_col='patient_group', n_cells=250, n_triplets=3, use_rep=None, group_by=None, gmm_key=None, sample_gmm=False, gmm_weights_source='auto')
Methods
- __init__(adata[, patient_col, label_col, ...])
- compute_OT([precomputed_distances, ...]): Compute the Optimal Transport distances between all distributions.
- enrich_discriminant(**kwargs): Enrich a supervised discriminant direction in latent space.
- enrich_gmm_components(*[, gmm_key, ...]): Run pathway enrichment on GMM component gene signatures.
- enrich_latent_axes(*[, axes, gene_symbols, ...]): Run pathway enrichment on per-axis gene rankings.
- fit_gmm(*args, **kwargs): Fit or refit a GMM representation for this AnnData-backed dataset.
- normalize()
- rank_latent_axes(*[, axes, gene_symbols]): Rank genes by their contribution to each GGML latent axis.
- rotate_latent_axes([method]): Apply post-hoc rotation to latent axes (varimax, promax).
- summarize_gmm_components(*[, gmm_key, ...]): Summarize GMM components by cell-type composition and patient weights.
- test([ground_metric, knn_k]): Test the ground metric on a given dataset.
- to_anndata(): Return the AnnData object of the dataset.
- train([alpha, reg, reg_type, n_comps, lr, ...]): Perform supervised optimal transport by ground metric learning.
- train_emd2([alpha, reg, reg_type, n_comps, ...]): Train GGML with exact OT (EMD2).
- train_sinkhorn([alpha, reg, reg_type, ...]): Train GGML with Sinkhorn-regularized OT.
- train_test([n_splits, train_size, ...]): Train and cross-validate ground metrics on train-test splits.
- tune([alpha, reg, reg_type, n_comps, ...]): Tune hyperparameters by performing a grid search with cross-validation.
- validate_gmm(*[, gmm_key, patient_col, ...]): Run a one-call hold-out validation pass for a fitted GMM.
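compute_OT, train_emd2 and train_sinkhorn contrast exact OT with Sinkhorn-regularized OT. The sketch below shows the difference on toy data using only the standard library. It assumes equal-size distributions with uniform weights and a Euclidean cost; the library's actual cost, solver and defaults may differ, and all helper names are hypothetical.

With uniform weights and equal sizes, an optimal transport plan can be taken to be a permutation, so a brute-force search over permutations gives the exact cost for tiny inputs. Sinkhorn instead alternately rescales the Gibbs kernel exp(-C / reg) to match both marginals, and approaches the exact cost as reg shrinks.

```python
import itertools
import math


def cost_matrix(xs, ys):
    """Pairwise Euclidean distances between two point sets (assumed ground cost)."""
    return [[math.dist(x, y) for y in ys] for x in xs]


def exact_ot_uniform(C):
    """Exact OT cost between two uniform distributions with the same number of points.

    An optimal plan can be taken to be a permutation scaled by 1/n, so brute force
    over permutations is exact (but only feasible for tiny n).
    """
    n = len(C)
    best = min(sum(C[i][p[i]] for i in range(n)) for p in itertools.permutations(range(n)))
    return best / n


def sinkhorn_cost(a, b, C, reg=0.05, n_iter=2000):
    """Transport cost <P, C> of the entropy-regularized (Sinkhorn) plan P."""
    K = [[math.exp(-c / reg) for c in row] for row in C]  # Gibbs kernel
    u = [1.0] * len(a)
    v = [1.0] * len(b)
    for _ in range(n_iter):
        # Alternately rescale rows and columns so P = diag(u) K diag(v) matches a and b.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(len(b))) for i in range(len(a))]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(len(a))) for j in range(len(b))]
    return sum(u[i] * K[i][j] * v[j] * C[i][j] for i in range(len(a)) for j in range(len(b)))


def pairwise_ot(distributions):
    """Symmetric matrix of exact OT costs between all pairs of distributions."""
    n = len(distributions)
    D = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1, n):
            D[i][j] = D[j][i] = exact_ot_uniform(cost_matrix(distributions[i], distributions[j]))
    return D


xs = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
ys = [[0.0, 0.1], [1.0, 0.1], [0.0, 1.1]]  # xs shifted by 0.1 along the second axis
zs = [[5.0, 5.0], [6.0, 5.0], [5.0, 6.0]]  # xs shifted by (5, 5)
C = cost_matrix(xs, ys)
uniform = [1 / 3] * 3
exact = exact_ot_uniform(C)
approx = sinkhorn_cost(uniform, uniform, C, reg=0.05)
D = pairwise_ot([xs, ys, zs])
```

Exhaustive permutation search grows factorially, so it only serves to check the idea. Realistic sizes need a proper exact solver or a vectorized Sinkhorn implementation.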
Attributes
- covariances
- dim: Dimensionality of the space underlying the distributions.
- distribution_labels: Patient group label of each patient-level distribution as an integer index into the unique classes of adata.obs[label_col].
- distribution_labels_str: Patient group label of each patient-level distribution as a string, taken from adata.obs[label_col].
- identical_supports: If True, the supports were forced to be identical across distributions by group_by.
- map_A: Learned ground metric as a linear map (raises a warning if the dataset is not trained yet).
- patient_labels
- points
- points_labels: List of the distribution_labels of all points, concatenated over all distributions.
- points_labels_str: Patient group label of each cell as a string, taken from adata.obs[label_col].
- supports: Per-patient distribution supports; by default, the cells of each patient.
- w_theta: Use map_A instead.
- weights: Probability of each support point; by default, uniform over the cells of each patient.
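The label attributes plausibly relate to one another as sketched below. The sketch assumes that the integer codes index the sorted unique classes of adata.obs[label_col]; the ordering the library actually uses is an assumption. It also assumes that points_labels repeats each distribution's label once per support point, as its description states. Both helpers are hypothetical.

```python
def encode_labels(label_strs):
    """Integer codes indexing the sorted unique classes (assumed ordering)."""
    classes = sorted(set(label_strs))
    index = {c: i for i, c in enumerate(classes)}
    return [index[s] for s in label_strs], classes


def points_labels(distribution_labels, supports):
    """Repeat each distribution's label once per support point, concatenated in order."""
    out = []
    for label, support in zip(distribution_labels, supports):
        out.extend([label] * len(support))
    return out


distribution_labels_str = ["disease", "healthy", "disease"]
distribution_labels, classes = encode_labels(distribution_labels_str)
supports = [
    [[0.0, 0.0], [1.0, 1.0]],              # patient 0: two cells
    [[2.0, 2.0]],                          # patient 1: one cell
    [[3.0, 3.0], [4.0, 4.0], [5.0, 5.0]],  # patient 2: three cells
]
labels_per_point = points_labels(distribution_labels, supports)
```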