ggml_ot.gmm.fit_gmm

ggml_ot.gmm.fit_gmm(dataset, *, component_sharing='sample_specific', k_comps=None, k_selection_metric='aic', refit='none', covariance_type='full', train_size=0.5, max_iter=100, tol=0.001, n_init=1, eps=0.0001, singularity_handling='guarded', gmm_key=None, verbose=True)

Fit per-patient GMM parameters from a GGML dataset.

Fits Gaussian mixture models to the per-patient cell distributions in dataset and attaches the fitted component means, covariances, and weights to it as tensors. The updated dataset can then be passed directly to ggml_ot.train_gmm().
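For orientation, the sketch below shows what a single per-patient fit amounts to: expectation-maximization (EM) for a Gaussian mixture, reduced to one dimension to keep it short. It is not ggml_ot's implementation, which fits multivariate components with full or diagonal covariances. The function name em_gmm_1d, the quantile-based initialization, and the stopping rule (stop when the mean log-likelihood changes by less than tol) are assumptions; max_iter, tol, and eps are only named after the fit_gmm parameters they loosely correspond to.

```python
import numpy as np


def em_gmm_1d(x, k, max_iter=100, tol=1e-3, eps=1e-4):
    """Hypothetical helper: fit a 1-D Gaussian mixture by EM (not ggml_ot's code)."""
    x = np.asarray(x, dtype=float)
    n = x.shape[0]
    # Deterministic initialization (an assumption): means at evenly spaced quantiles.
    means = np.quantile(x, (np.arange(k) + 0.5) / k)
    variances = np.full(k, x.var() + eps)
    weights = np.full(k, 1.0 / k)
    prev_ll = -np.inf
    converged = False
    n_iter = 0
    for n_iter in range(1, max_iter + 1):
        # E-step: log(weight_j * N(x_i | mean_j, var_j)), shape (n, k).
        log_p = (
            np.log(weights)
            - 0.5 * np.log(2.0 * np.pi * variances)
            - 0.5 * (x[:, None] - means) ** 2 / variances
        )
        # Log-sum-exp over components gives log p(x_i) under the current parameters.
        top = log_p.max(axis=1, keepdims=True)
        log_px = top + np.log(np.exp(log_p - top).sum(axis=1, keepdims=True))
        resp = np.exp(log_p - log_px)  # responsibilities; each row sums to 1
        ll = float(log_px.mean())
        # Stop once the mean log-likelihood changes by less than tol.
        if abs(ll - prev_ll) < tol:
            converged = True
            break
        prev_ll = ll
        # M-step: re-estimate weights, means, and variances from the responsibilities.
        nk = resp.sum(axis=0)
        weights = nk / n
        means = (resp * x[:, None]).sum(axis=0) / nk
        # eps keeps every variance strictly positive so no component collapses.
        variances = (resp * (x[:, None] - means) ** 2).sum(axis=0) / nk + eps
    return weights, means, variances, n_iter, converged


# Toy data: two well-separated clusters standing in for one patient's cells.
rng = np.random.default_rng(0)
x = np.concatenate([rng.normal(-5.0, 1.0, 500), rng.normal(5.0, 1.0, 500)])
w, mu, var, n_iter, ok = em_gmm_1d(x, k=2)
```

Because EM only reaches a local optimum, the result depends on the starting point; that sensitivity is what the n_init restarts are meant to reduce.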

Parameters:
dataset

A TripletDataset or AnnData_TripletDataset whose distributions will be modelled as GMMs.

component_sharing Literal['global', 'sample_specific', 'auto'] (default: 'sample_specific')

How components are shared across patients. "sample_specific" (default) fits an independent model per patient. "global" fits one shared model across all patients. "auto" picks "sample_specific" for AnnData datasets and "global" for generic datasets.

k_comps UnionType[int, list[int], tuple[int, ...], None] (default: None)

Number of GMM components. An int fixes the count; a list/tuple triggers automatic k-selection over the provided candidates using k_selection_metric. None uses a default candidate range.

k_selection_metric Literal['bic', 'aic', 'heldout_nll'] (default: 'aic')

Criterion used to select the best k when k_comps is a sequence: "aic" (default), "bic", or "heldout_nll". For all three, lower is better. The held-out NLL path scores each candidate on an internal validation split of the selected cells (its size is set by train_size) and stores the resulting diagnostics in the persisted GMM schema.

refit Literal['full', 'none'] (default: 'none')

Whether to refit the best selected model on all data after k-selection. Only used when k_comps is a sequence. "none" (default) reuses the model from the selection pass; "full" refits the best k on the full dataset.

covariance_type Literal['diag', 'full'] (default: 'full')

Covariance structure of each Gaussian component: "full" (default) or "diag".

train_size float (default: 0.5)

Training fraction used during k-selection. The remaining 1 - train_size cells form the validation split for k_selection_metric="heldout_nll" and are ignored by in-sample criteria such as "aic" and "bic". When refit="none", the selected model is the one fit on the training split, so with the default train_size=0.5 it has seen only half of the cells; when refit="full", the best selected k is refit on all cells. The default 0.5 yields a balanced train/validation split during selection.

max_iter int (default: 100)

Maximum EM iterations per fit.

tol float (default: 0.001)

EM convergence tolerance.

n_init int (default: 1)

Number of random EM restarts. Higher values reduce sensitivity to initialization at the cost of runtime.

eps float (default: 0.0001)

Numerical floor added to the diagonal of the covariance matrices to prevent singularities.

singularity_handling Literal['guarded', 'robust', 'strict'] (default: 'guarded')

How near-singular projected covariances are treated: "guarded" (default) applies a small jitter and continues; "robust" applies stronger stabilization; "strict" raises an error on any detected singularity.

gmm_key Optional[str] (default: None)

Key under which the fitted GMM is stored in dataset.adata.uns (AnnData datasets only). Defaults to "gmm_<use_rep>" when None.

verbose bool (default: True)

Print per-patient fit progress.

Returns:

TripletDataset | AnnData_TripletDataset

The input dataset augmented with the fitted GMM supports (component means), covariances, and weights.
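The component_sharing modes amount to a choice of which cells each model is fit on. The sketch below is a hypothetical illustration: resolve_component_sharing and fit_by_sharing are not ggml_ot functions, and fit_fn stands in for an actual GMM fit. This reference does not say how fit_gmm attaches a "global" model to individual patients (for example, whether per-patient weights are re-estimated), so the sketch simply hands every patient the same pooled model.

```python
from typing import Callable, Dict, TypeVar

import numpy as np

Model = TypeVar("Model")


def resolve_component_sharing(mode: str, is_anndata: bool) -> str:
    """Hypothetical helper: map "auto" to a concrete mode using the documented rule."""
    if mode == "auto":
        return "sample_specific" if is_anndata else "global"
    if mode not in ("global", "sample_specific"):
        raise ValueError(f"unknown component_sharing: {mode!r}")
    return mode


def fit_by_sharing(
    cells: Dict[str, np.ndarray],
    mode: str,
    fit_fn: Callable[[np.ndarray], Model],
) -> Dict[str, Model]:
    """Hypothetical helper: return one model per patient id."""
    if mode == "sample_specific":
        # An independent model for each patient's cells.
        return {pid: x_p and fit_fn(x_p) if False else fit_fn(x_p) for pid, x_p in cells.items()}
    # "global": a single model fit on the pooled cells, shared by every patient.
    shared = fit_fn(np.concatenate(list(cells.values()), axis=0))
    return {pid: shared for pid in cells}


# Toy usage: the "model" is just the mean cell, a stand-in for a real GMM fit.
cells = {"p1": np.array([[0.0], [2.0]]), "p2": np.array([[10.0], [12.0]])}
mode = resolve_component_sharing("auto", is_anndata=False)
models = fit_by_sharing(cells, mode, fit_fn=lambda x: x.mean(axis=0))
```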
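For the in-sample k_selection_metric options, a minimal sketch assuming the standard definitions (the same ones scikit-learn's GaussianMixture.aic and GaussianMixture.bic use): with p free parameters, total log-likelihood log L, and n cells, AIC = 2p - 2 log L and BIC = p ln n - 2 log L, and the candidate with the lowest score is selected. The parameter count depends on covariance_type. Whether fit_gmm counts parameters exactly this way is an assumption; the helper names and the log-likelihood values are made up for illustration.

```python
import math
from typing import Dict, Tuple


def n_gmm_params(k: int, d: int, covariance_type: str) -> int:
    """Free parameters of a k-component GMM in d dimensions (hypothetical helper)."""
    mean_params = k * d
    if covariance_type == "full":
        cov_params = k * d * (d + 1) // 2  # one symmetric d x d matrix per component
    elif covariance_type == "diag":
        cov_params = k * d  # one variance per dimension per component
    else:
        raise ValueError(f"unknown covariance_type: {covariance_type!r}")
    weight_params = k - 1  # weights sum to 1, so one is determined by the others
    return mean_params + cov_params + weight_params


def information_criterion(total_loglik: float, k: int, d: int, n: int,
                          covariance_type: str, metric: str) -> float:
    p = n_gmm_params(k, d, covariance_type)
    if metric == "aic":
        return 2 * p - 2 * total_loglik
    if metric == "bic":
        return p * math.log(n) - 2 * total_loglik
    raise ValueError(f"unknown metric: {metric!r}")


def select_k(total_logliks: Dict[int, float], d: int, n: int,
             covariance_type: str = "full",
             metric: str = "aic") -> Tuple[int, Dict[int, float]]:
    """Pick the candidate k with the lowest criterion value."""
    scores = {k: information_criterion(ll, k, d, n, covariance_type, metric)
              for k, ll in total_logliks.items()}
    best_k = min(scores, key=scores.get)
    return best_k, scores


# Made-up total log-likelihoods for k = 1, 2, 3 on n = 1000 cells in d = 2.
logliks = {1: -500.0, 2: -420.0, 3: -405.0}
best_aic, aic_scores = select_k(logliks, d=2, n=1000, metric="aic")
best_bic, bic_scores = select_k(logliks, d=2, n=1000, metric="bic")
```

With these toy numbers AIC prefers k = 3 while BIC prefers k = 2. BIC's per-parameter penalty ln n exceeds AIC's 2 once n > e^2 (about 7.4 cells), so at realistic cell counts it leans toward fewer components.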
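The held-out NLL path and the interplay of train_size and refit can be sketched as follows, under assumptions: the random-permutation split, the rounding of train_size * n, full (k, d, d) covariances, and the helper names split_indices, gmm_mean_nll, and final_fit_indices are all hypothetical. gmm_mean_nll evaluates the mixture density through a Cholesky factorization and a log-sum-exp over components, a standard way to keep the computation numerically stable.

```python
import numpy as np


def split_indices(n: int, train_size: float, rng: np.random.Generator):
    """Hypothetical helper: random train/validation split (rounding rule assumed)."""
    perm = rng.permutation(n)
    n_train = int(round(train_size * n))
    return perm[:n_train], perm[n_train:]


def gmm_mean_nll(x, weights, means, covs) -> float:
    """Mean negative log-likelihood of the rows of x under a full-covariance GMM."""
    x = np.atleast_2d(np.asarray(x, dtype=float))
    d = x.shape[1]
    log_terms = []
    for w, mu, cov in zip(weights, means, covs):
        chol = np.linalg.cholesky(cov)  # cov = L @ L.T
        z = np.linalg.solve(chol, (x - mu).T)  # L^{-1} (x - mu), shape (d, n)
        maha = np.sum(z ** 2, axis=0)  # squared Mahalanobis distances
        log_det = 2.0 * np.sum(np.log(np.diag(chol)))
        log_terms.append(np.log(w) - 0.5 * (d * np.log(2.0 * np.pi) + log_det + maha))
    log_terms = np.stack(log_terms)  # shape (k, n)
    top = log_terms.max(axis=0)
    log_px = top + np.log(np.exp(log_terms - top).sum(axis=0))
    return float(-log_px.mean())


def final_fit_indices(train_idx: np.ndarray, n: int, refit: str) -> np.ndarray:
    """Cells the returned model is fit on, following the refit description."""
    if refit == "none":
        return train_idx  # keep the model from the selection pass
    if refit == "full":
        return np.arange(n)  # refit the selected k on every cell
    raise ValueError(f"unknown refit: {refit!r}")


rng = np.random.default_rng(0)
train_idx, val_idx = split_indices(10, train_size=0.5, rng=rng)
# One standard-normal component evaluated at x = 0: NLL is 0.5 * log(2 * pi).
nll = gmm_mean_nll(np.zeros((1, 1)), weights=np.array([1.0]),
                   means=np.array([[0.0]]), covs=np.array([[[1.0]]]))
```

In k-selection, each candidate would be fit on the rows in train_idx and scored with gmm_mean_nll on the rows in val_idx.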
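A common way to use n_init restarts, and the one scikit-learn's GaussianMixture follows, is to run EM from several initializations and keep the best-scoring run. The sketch below assumes fit_gmm does the same; best_of_restarts and toy_fit are hypothetical names, and toy_fit stands in for a full EM run that returns a model and its final log-likelihood.

```python
from typing import Callable, Optional, Tuple, TypeVar

import numpy as np

Model = TypeVar("Model")


def best_of_restarts(
    fit_once: Callable[[np.random.Generator], Tuple[Model, float]],
    n_init: int,
    seed: int = 0,
) -> Tuple[Optional[Model], float]:
    """Hypothetical helper: run fit_once n_init times, keep the highest log-likelihood."""
    if n_init < 1:
        raise ValueError("n_init must be >= 1")
    rng = np.random.default_rng(seed)
    best_model, best_ll = None, -np.inf
    for _ in range(n_init):
        model, ll = fit_once(rng)  # each restart draws a fresh initialization
        if ll > best_ll:
            best_model, best_ll = model, ll
    return best_model, best_ll


def toy_fit(rng: np.random.Generator):
    """Stand-in for one EM run: a random start, scored by its distance to 0.3."""
    start = rng.uniform()
    return start, -abs(start - 0.3)


one_model, one_ll = best_of_restarts(toy_fit, n_init=1)
many_model, many_ll = best_of_restarts(toy_fit, n_init=20)
```

Both calls share the same seed, so the first restart is identical and n_init=20 can only match or improve on n_init=1, at about 20 times the cost.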
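Taking the eps description at face value (a constant added to every diagonal entry), its effect is to shift every eigenvalue of a covariance up by eps, which turns a singular matrix into a positive-definite one that can be inverted or Cholesky-factorized. add_diagonal_floor is a hypothetical helper, not part of ggml_ot.

```python
import numpy as np


def add_diagonal_floor(covs, eps: float = 1e-4) -> np.ndarray:
    """Hypothetical helper: add eps to the diagonal of a (d, d) or (k, d, d) array."""
    covs = np.asarray(covs, dtype=float)
    d = covs.shape[-1]
    return covs + eps * np.eye(d)  # broadcasts over any leading component axis


# Two perfectly correlated features give a rank-deficient covariance (eigenvalue 0).
singular = np.array([[1.0, 1.0], [1.0, 1.0]])
fixed = add_diagonal_floor(singular, eps=1e-4)
eigs = np.linalg.eigvalsh(fixed)  # eigenvalues shift from (0, 2) to (eps, 2 + eps)
chol = np.linalg.cholesky(fixed)  # succeeds because the matrix is now positive definite
```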
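The three singularity_handling modes can be pictured as policies applied after a near-singularity check. Everything numeric below is an assumption: the eigenvalue-based check, the cond_threshold of 1e10, and the jitter sizes (eps for "guarded", 100 * eps for "robust") are illustrative stand-ins, because this reference does not state fit_gmm's actual thresholds. handle_near_singular is a hypothetical name; the same policy would apply to whichever (projected) covariance is being checked.

```python
import numpy as np

# Hypothetical jitter multipliers of eps; fit_gmm's real values are not documented here.
JITTER_SCALE = {"guarded": 1.0, "robust": 100.0}


def handle_near_singular(cov, mode: str = "guarded", eps: float = 1e-4,
                         cond_threshold: float = 1e10) -> np.ndarray:
    """Hypothetical helper: return a usable covariance or raise, per the three modes."""
    if mode not in ("guarded", "robust", "strict"):
        raise ValueError(f"unknown singularity_handling: {mode!r}")
    cov = np.asarray(cov, dtype=float)
    eigs = np.linalg.eigvalsh(cov)  # ascending eigenvalues of a symmetric matrix
    smallest, largest = eigs[0], eigs[-1]
    # Near-singular: non-positive smallest eigenvalue or a huge condition number.
    near_singular = smallest <= 0.0 or largest / smallest > cond_threshold
    if not near_singular:
        return cov
    if mode == "strict":
        raise np.linalg.LinAlgError("near-singular covariance detected")
    # "guarded" adds a small jitter, "robust" a larger one (both scales assumed).
    return cov + JITTER_SCALE[mode] * eps * np.eye(cov.shape[0])
```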