ggml_ot.gmm.fit_gmm
- ggml_ot.gmm.fit_gmm(dataset, *, component_sharing='sample_specific', k_comps=None, k_selection_metric='aic', refit='none', covariance_type='full', train_size=0.5, max_iter=100, tol=0.001, n_init=1, eps=0.0001, singularity_handling='guarded', gmm_key=None, verbose=True)
Fit per-patient GMM parameters from a GGML dataset.
Fits Gaussian mixture models to the per-patient cell distributions in dataset and attaches the result as Gaussian component tensors (means, covariances, and weights). The updated dataset can then be passed directly to ggml_ot.train_gmm().
- Parameters:
- dataset
A TripletDataset or AnnData_TripletDataset whose distributions will be modelled as GMMs.
- component_sharing
Literal['global', 'sample_specific', 'auto'] (default: 'sample_specific') How components are shared across patients. "sample_specific" (default) fits an independent model per patient; "global" fits one shared model across all patients; "auto" picks "sample_specific" for AnnData datasets and "global" for generic datasets.
- k_comps
UnionType[int, list[int], tuple[int, ...], None] (default: None) Number of GMM components. An int fixes the count; a list or tuple triggers automatic k-selection over the provided candidates using k_selection_metric. None uses a default candidate range.
- k_selection_metric
Literal['bic', 'aic', 'heldout_nll'] (default: 'aic') Criterion used to select the best k when k_comps is a sequence: "aic" (default), "bic", or "heldout_nll". The held-out NLL path uses an internal train/validation split of the selected cells and stores the resulting diagnostics in the persisted GMM schema.
- refit
Literal['full', 'none'] (default: 'none') Whether to refit the best selected model on all data after k-selection. Only used when k_comps is a sequence. "none" (default) reuses the model from the selection pass; "full" refits the best k on the full dataset.
- covariance_type
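Literal['diag', 'full'] (default: 'full') Covariance structure of each Gaussian component: "full" (default) estimates a complete covariance matrix per component, while "diag" estimates only per-dimension variances, with all off-diagonal terms fixed at zero.
- train_size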
float (default: 0.5) Fraction of cells used for training during k-selection. The remaining 1 - train_size cells form the validation split for k_selection_metric="heldout_nll" and are not used by in-sample criteria such as "aic" and "bic". When refit="none", the selected model is the one fit on that training split; when refit="full", the best selected k is refit on all cells. The default 0.5 gives an even train/validation split during selection.
- max_iter
int (default: 100) Maximum number of EM iterations per fit.
- tol
float (default: 0.001) EM convergence tolerance.
- n_init
int (default: 1) Number of random EM restarts. Higher values reduce sensitivity to initialization at the cost of runtime.
- eps
float (default: 0.0001) Small constant added to the diagonal of each covariance matrix as a numerical floor, preventing singular covariances.
- singularity_handling
Literal['guarded', 'robust', 'strict'] (default: 'guarded') How near-singular projected covariances are treated: "guarded" (default) applies a small jitter and continues, "robust" applies a larger stabilization, and "strict" raises an error on any detected singularity.
- gmm_key
Optional[str] (default: None) Key under which the fitted GMM is stored in dataset.adata.uns (AnnData datasets only). Defaults to "gmm_<use_rep>" when None.
- verbose
bool (default: True) Print per-patient fit progress.
- Returns:
TripletDataset | AnnData_TripletDataset The input dataset, augmented with the fitted GMM supports (component means), covariances, and weights.
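The component_sharing rules above can be restated as a small decision function. This is a minimal sketch, not ggml_ot's implementation: resolve_component_sharing, fitting_groups, and the is_anndata flag are hypothetical names, and the cells_by_patient dictionary is an assumed stand-in for however the dataset actually stores per-patient cells.

```python
import numpy as np


def resolve_component_sharing(mode: str, is_anndata: bool) -> str:
    """Map the documented "auto" rule onto a concrete sharing mode (hypothetical helper)."""
    if mode == "auto":
        # AnnData datasets -> per-patient models; generic datasets -> one shared model
        return "sample_specific" if is_anndata else "global"
    if mode not in ("global", "sample_specific"):
        raise ValueError(f"unknown component_sharing: {mode!r}")
    return mode


def fitting_groups(cells_by_patient: dict, mode: str) -> dict:
    """Return the cell matrix each GMM would be fit on (hypothetical helper)."""
    if mode == "sample_specific":
        # one independent model per patient, fit on that patient's cells only
        return {pid: np.asarray(x) for pid, x in cells_by_patient.items()}
    # "global": a single model fit on the pooled cells of every patient
    pooled = np.concatenate([np.asarray(x) for x in cells_by_patient.values()], axis=0)
    return {"global": pooled}


cells = {"p1": np.zeros((3, 2)), "p2": np.ones((5, 2))}
print(resolve_component_sharing("auto", is_anndata=True))   # sample_specific
print(fitting_groups(cells, "global")["global"].shape)      # (8, 2)
```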
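Among the k_selection_metric options, "aic" and "bic" trade the log-likelihood off against the number of free parameters (which depends on covariance_type), while "heldout_nll" scores validation cells directly. The sketch below assumes the textbook definitions AIC = 2p - 2 log L and BIC = p ln n - 2 log L, takes a per-cell mean for the held-out NLL, and treats the lowest score as best. All function names are hypothetical, not part of ggml_ot.

```python
import numpy as np


def gmm_n_params(k: int, d: int, covariance_type: str) -> int:
    """Free parameters of a k-component GMM in d dimensions."""
    means = k * d
    if covariance_type == "full":
        covs = k * d * (d + 1) // 2      # one symmetric d x d matrix per component
    elif covariance_type == "diag":
        covs = k * d                     # one variance per dimension per component
    else:
        raise ValueError(covariance_type)
    return means + covs + (k - 1)        # weights sum to 1, so only k - 1 are free


def aic(log_lik: float, n_params: int) -> float:
    return 2 * n_params - 2 * log_lik


def bic(log_lik: float, n_params: int, n_samples: int) -> float:
    return n_params * np.log(n_samples) - 2 * log_lik


def gmm_log_likelihood(X, weights, means, covs) -> float:
    """Total log-likelihood of X (n x d) under a full-covariance GMM."""
    X = np.asarray(X, dtype=float)
    n, d = X.shape
    log_comp = np.empty((n, len(weights)))
    for j, (w, mu, S) in enumerate(zip(weights, means, covs)):
        L = np.linalg.cholesky(S)
        z = np.linalg.solve(L, (X - mu).T)           # whitened residuals, shape (d, n)
        log_det = 2.0 * np.sum(np.log(np.diag(L)))
        log_comp[:, j] = np.log(w) - 0.5 * (d * np.log(2 * np.pi) + log_det + np.sum(z**2, axis=0))
    m = log_comp.max(axis=1, keepdims=True)          # log-sum-exp over components
    return float(np.sum(m[:, 0] + np.log(np.exp(log_comp - m).sum(axis=1))))


def heldout_nll(X_val, weights, means, covs) -> float:
    """Mean negative log-likelihood per validation cell (lower is better)."""
    return -gmm_log_likelihood(X_val, weights, means, covs) / len(X_val)


def select_k(scores: dict) -> int:
    """AIC, BIC and held-out NLL are all minimised: return the best candidate k."""
    return min(scores, key=scores.get)


print(gmm_n_params(3, 2, "full"), gmm_n_params(3, 2, "diag"))  # 17 14
```

The parameter count is why "full" covariances are penalised more heavily than "diag" under AIC and BIC as the dimensionality grows.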
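The interaction between train_size, k_selection_metric, and refit described above comes down to which cells each stage sees. A minimal sketch under stated assumptions: selection_plan is a hypothetical helper, the split is a uniform random permutation, and the training count is truncated with int(). The library may split or round differently.

```python
import numpy as np


def selection_plan(n_cells: int, train_size: float = 0.5, metric: str = "aic",
                   refit: str = "none", seed: int = 0) -> dict:
    """Cell indices used at each k-selection stage (hypothetical helper)."""
    rng = np.random.default_rng(seed)
    perm = rng.permutation(n_cells)
    n_train = int(train_size * n_cells)             # assumed: truncate, not round
    train, val = np.sort(perm[:n_train]), np.sort(perm[n_train:])
    return {
        # every candidate k is fit on the training split
        "fit_candidates": train,
        # held-out NLL scores validation cells; AIC/BIC are in-sample criteria
        "score": val if metric == "heldout_nll" else train,
        # refit="none" keeps the training-split model; "full" refits the best k on all cells
        "final_fit": train if refit == "none" else np.arange(n_cells),
    }


plan = selection_plan(10, train_size=0.5, metric="heldout_nll", refit="full")
print(len(plan["fit_candidates"]), len(plan["score"]), len(plan["final_fit"]))  # 5 5 10
```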
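max_iter, tol, and n_init are the usual EM controls. As an illustration only, here is a one-dimensional EM loop in NumPy. It stops when the mean log-likelihood improves by less than tol or after max_iter iterations, and it keeps the best of n_init random restarts. The stopping rule, the initialization (means drawn from the data), and the name em_1d are assumptions; ggml_ot fits multivariate models and may define convergence differently.

```python
import numpy as np


def _component_log_probs(x, w, mu, var):
    """log(w_j * N(x_i | mu_j, var_j)) for every cell i and component j, shape (n, k)."""
    return np.log(w) - 0.5 * np.log(2 * np.pi * var) - 0.5 * (x[:, None] - mu) ** 2 / var


def em_1d(x, k, max_iter=100, tol=1e-3, n_init=1, eps=1e-4, seed=0):
    """Fit a 1-D GMM with EM; returns (mean log-likelihood, weights, means, variances)."""
    x = np.asarray(x, dtype=float)
    rng = np.random.default_rng(seed)
    best = None
    for _ in range(n_init):                          # independent random restarts
        mu = rng.choice(x, size=k, replace=False)    # initial means drawn from the data
        var = np.full(k, x.var() + eps)
        w = np.full(k, 1.0 / k)
        prev = -np.inf
        for _ in range(max_iter):
            lp = _component_log_probs(x, w, mu, var)
            m = lp.max(axis=1, keepdims=True)
            log_norm = m[:, 0] + np.log(np.exp(lp - m).sum(axis=1))   # log p(x_i)
            ll = log_norm.mean()
            resp = np.exp(lp - log_norm[:, None])    # E-step: responsibilities
            nk = resp.sum(axis=0)                    # M-step: re-estimate parameters
            w = nk / len(x)
            mu = (resp * x[:, None]).sum(axis=0) / nk
            var = (resp * (x[:, None] - mu) ** 2).sum(axis=0) / nk + eps
            if ll - prev < tol:                      # improvement below tol: converged
                break
            prev = ll
        if best is None or ll > best[0]:             # keep the highest-likelihood restart
            best = (ll, w, mu, var)
    return best


rng = np.random.default_rng(1)
x = np.concatenate([rng.normal(-5, 1, 300), rng.normal(5, 1, 300)])
ll, w, mu, var = em_1d(x, k=2, max_iter=200, tol=1e-6, n_init=5)
print(np.sort(mu))  # expected to be close to -5 and 5
```

Restarts help because a single EM run only climbs to a local optimum; comparing final log-likelihoods across seeds is the simplest way to pick among them.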
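To see why eps is needed: if every cell of a patient has the same value for one feature, the sample covariance has zero variance on that dimension and cannot be Cholesky-factorised. Adding eps to the diagonal makes it positive definite. The cells below are made-up data, and is_positive_definite is a hypothetical helper.

```python
import numpy as np


def is_positive_definite(S) -> bool:
    """True if S admits a Cholesky factorisation (hypothetical helper)."""
    try:
        np.linalg.cholesky(S)
        return True
    except np.linalg.LinAlgError:
        return False


# Made-up cells: the second feature is constant, so its sample variance is exactly 0.
cells = np.array([[0.0, 1.0], [1.0, 1.0], [2.0, 1.0], [3.0, 1.0]])
cov = np.cov(cells, rowvar=False)                  # [[5/3, 0], [0, 0]]: singular

eps = 1e-4
cov_floored = cov + eps * np.eye(cov.shape[0])     # add eps to the diagonal only

print(is_positive_definite(cov), is_positive_definite(cov_floored))  # False True
```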
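The singularity_handling modes can be sketched as follows. This assumes "projected covariance" means a covariance mapped through a linear projection W as W Σ Wᵀ, which becomes singular whenever W is rank-deficient. The detection rule (smallest eigenvalue relative to the largest), the stabilize function, and the jitter sizes in JITTER are invented for illustration; the library's actual thresholds are not documented here.

```python
import numpy as np

# Invented jitter sizes (relative to the largest eigenvalue); not ggml_ot's values.
JITTER = {"guarded": 1e-6, "robust": 1e-3}


def stabilize(S, mode="guarded", rel_tol=1e-8):
    """Handle a near-singular covariance according to a singularity_handling-style mode."""
    S = 0.5 * (S + S.T)                               # symmetrise against round-off
    lam = np.linalg.eigvalsh(S)                       # eigenvalues in ascending order
    if lam[0] > rel_tol * lam[-1]:
        return S                                      # well conditioned: leave unchanged
    if mode == "strict":
        raise np.linalg.LinAlgError(f"near-singular covariance (min eigenvalue {lam[0]:.3g})")
    jitter = JITTER[mode] * max(lam[-1], 1.0)         # guarded: small, robust: larger
    return S + jitter * np.eye(S.shape[0])


# A rank-deficient (hypothetical) projection W makes W @ Sigma @ W.T singular.
Sigma = np.eye(3)
W = np.array([[1.0, 0.0, 0.0],
              [0.0, 1.0, 0.0],
              [1.0, 1.0, 0.0]])
projected = W @ Sigma @ W.T                           # eigenvalues 0, 1, 3

guarded = stabilize(projected, "guarded")
robust = stabilize(projected, "robust")
```

Scaling the jitter by the largest eigenvalue keeps it proportionate to the covariance's overall scale; an absolute jitter is an equally plausible design.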