ggml_ot.train_test
- ggml_ot.train_test(dataset, n_splits=5, train_size=0.6, test_size=None, scoring=('knn', 'ari'), knn_k=5, cluster_linkage='complete', plot_split=True, plot_type=True, print_table=True, print_latex=False, return_dataset=False, ground_metric=None, plot_split_dir=None, plot_title=None, **kwargs)
Trains and cross-validates ground metrics on train-test splits.
This function performs n_splits stratified train-test splits on the provided dataset. For each split, it trains a ground metric on the training set and evaluates it on the test set using k-NN classification and hierarchical clustering.
Classification accuracy and clustering scores are summarized in a table, and results can be plotted as clustermaps and embeddings.
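As an illustration of the splitting step described above, the following minimal sketch draws repeated stratified train-test splits in pure Python. It is not the library's implementation: the helper name `stratified_splits`, the `seed` argument, and the rounding rule are assumptions made for this example only.

```python
import random
from collections import defaultdict


def stratified_splits(labels, n_splits=5, train_size=0.6, seed=0):
    """Return n_splits (train_idx, test_idx) pairs that preserve class proportions.

    Hypothetical helper for illustration; ggml_ot may split differently.
    """
    rng = random.Random(seed)

    # Group sample indices by class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    splits = []
    for _ in range(n_splits):
        train_idx, test_idx = [], []
        for members in by_class.values():
            shuffled = members[:]
            rng.shuffle(shuffled)
            # Take a train_size fraction of each class; the remainder becomes
            # the test set, mirroring test_size=None (1 - train_size).
            n_train = round(train_size * len(shuffled))
            train_idx.extend(shuffled[:n_train])
            test_idx.extend(shuffled[n_train:])
        splits.append((sorted(train_idx), sorted(test_idx)))
    return splits


# Toy labels: 10 samples of one class, 5 of another.
labels = ["healthy"] * 10 + ["disease"] * 5
splits = stratified_splits(labels, n_splits=3, train_size=0.6)
```

Each split keeps the 2:1 class ratio in both the train and the test indices, which is the property that stratification is meant to guarantee.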
- Parameters:
- dataset TripletDataset | AnnData_TripletDataset
Dataset to perform train-test splits on.
See also
The documentation for the provided interfaces to AnnData and numpy arrays.
- n_splits int (default: 5)
Number of train-test splits.
- train_size float (default: 0.6)
Proportion of the dataset to include in the train split.
- test_size float | None (default: None)
Proportion of the dataset to include in the test split. If None, 1 - train_size is used.
- scoring tuple[str, …] | list[str] (default: ('knn', 'ari'))
Tuple or list of evaluation scores to compute on each test split. "knn" computes k-NN classification accuracy; "ari" (Adjusted Rand Index), "mi" (Mutual Information), and "vi" (Variation of Information) are clustering scores obtained via hierarchical clustering.
- knn_k int (default: 5)
Number of neighbors used for benchmark k-NN classification.
- cluster_linkage str (default: 'complete')
Linkage method used by hierarchical clustering for clustering scores and clustermap plots. Defaults to "complete", the canonical choice for disease subtyping; it tends to produce more balanced cuts than "average" and to avoid singleton outlier clusters that can destabilize ARI.
- plot_split bool (default: True)
Whether to plot OT distances for each split.
- plot_type Literal['clustermap_embedding', 'clustermap', 'embedding'] | list[str] | tuple[str, …] | bool (default: True)
Defines which plots to generate. One of "clustermap_embedding", "clustermap", "embedding", a list of those values, or False.
- print_table bool (default: True)
Whether to print the results table.
- print_latex bool (default: False)
Whether to print the results table in LaTeX format.
- return_dataset bool (default: False)
If False, returns a tuple of a dict containing the trained ground metrics and a DataFrame of the test scores. If True, returns the dataset with the data projected using the best learned ground metric.
Attention
return_dataset=True only works if the ground metric is learned (default: ground_metric=None).
- ground_metric np.ndarray | str | callable | None (default: None)
If provided, this ground metric is used for testing. You are encouraged to use ggml_ot.test() instead.
- plot_split_dir str | Path | None (default: None)
Optional output directory for per-split plots. When provided, split plots are saved into split_XX/ subdirectories below this path.
- plot_title str | None (default: None)
Base title used for split plots. When n_splits > 1, the split name is appended automatically.
- **kwargs
Additional arguments passed to ggml_ot.train(); see the corresponding docs for details.
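To make the "knn" and "ari" scores concrete, here is a minimal pure-Python sketch of how such scores can be computed from a precomputed test-to-train distance matrix and from two label assignments. The function names `knn_accuracy` and `adjusted_rand_index` are hypothetical and are not part of ggml_ot; the library computes its scores on OT distances under the learned ground metric and on clusters from hierarchical clustering, which this sketch does not reproduce.

```python
from collections import Counter
from math import comb


def knn_accuracy(dist, train_labels, test_labels, k=5):
    """Majority-vote k-NN accuracy, where dist[i][j] is the distance
    from test sample i to train sample j (illustrative helper)."""
    correct = 0
    for row, truth in zip(dist, test_labels):
        # Indices of the k closest training samples.
        nearest = sorted(range(len(row)), key=row.__getitem__)[:k]
        votes = Counter(train_labels[j] for j in nearest)
        predicted = votes.most_common(1)[0][0]
        correct += predicted == truth
    return correct / len(test_labels)


def adjusted_rand_index(labels_true, labels_pred):
    """Adjusted Rand Index computed from pair counts (illustrative helper)."""
    n = len(labels_true)
    if n < 2:
        return 1.0  # no pairs to compare
    pairs_both = sum(comb(c, 2) for c in Counter(zip(labels_true, labels_pred)).values())
    pairs_true = sum(comb(c, 2) for c in Counter(labels_true).values())
    pairs_pred = sum(comb(c, 2) for c in Counter(labels_pred).values())
    expected = pairs_true * pairs_pred / comb(n, 2)
    maximum = (pairs_true + pairs_pred) / 2
    if maximum == expected:
        return 1.0  # degenerate case, e.g. one cluster in both labelings
    return (pairs_both - expected) / (maximum - expected)


# Toy data: 2 test samples, 3 training samples.
dist = [[0.1, 0.2, 0.9], [0.8, 0.7, 0.1]]
train_labels = ["a", "a", "b"]
test_labels = ["a", "b"]
acc_k1 = knn_accuracy(dist, train_labels, test_labels, k=1)
ari = adjusted_rand_index([0, 0, 1, 1], [0, 0, 1, 2])
```

The ARI is invariant to relabeling of clusters, so a clustering that recovers the true groups under different cluster names still scores 1.0.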
- Return type:
TripletDataset | AnnData_TripletDataset | tuple[dict, pd.DataFrame]
- Returns:
- TripletDataset | AnnData_TripletDataset
If return_dataset is True, the dataset is returned with the best-performing ground metric (dataset.map_A).
If the dataset is of type AnnData_TripletDataset, the cells are projected into the learned gene subspace (dataset.adata.obsm["X_ggml"]) and the loadings of the gene subspace are stored in dataset.adata.varm["W_ggml"].
- tuple[dict, pd.DataFrame]
- If return_dataset is False, a tuple is returned containing:
- A dict with keys:
"Ws": List of learned ground metrics for each split
"best": Best performing ground metric based on k-NN accuracy
"mean": Mean ground metric across splits
"sd": Standard deviation of ground metrics across splits
- A DataFrame summarizing the mean and standard deviation of evaluation metrics across splits.
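The following sketch shows one way to call the function and unpack the documented return value when return_dataset is False. The wrapper `cross_validate` is a hypothetical helper, not part of ggml_ot; it takes the function to call as an argument so that the snippet stays self-contained, and it assumes a dataset has already been built through one of the library's dataset interfaces.

```python
def cross_validate(train_test_fn, dataset, n_splits=5):
    """Call a train_test-style function and unpack its (dict, DataFrame) result.

    Hypothetical wrapper for illustration; pass ggml_ot.train_test as
    train_test_fn together with a prepared TripletDataset.
    """
    results, scores = train_test_fn(
        dataset,
        n_splits=n_splits,
        train_size=0.6,
        scoring=("knn", "ari"),
        plot_split=False,      # do not plot OT distances per split
        print_table=False,     # suppress the printed results table
        return_dataset=False,  # request the (dict, DataFrame) tuple
    )
    # Keys documented above: "Ws", "best", "mean", "sd".
    best_metric = results["best"]
    return best_metric, scores


# Intended usage (requires ggml_ot and a prepared dataset):
#   import ggml_ot
#   best_W, scores = cross_validate(ggml_ot.train_test, dataset)
```

Passing return_dataset=True instead would return the dataset itself, so the tuple unpacking in this wrapper applies only to the default mode.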