ggml_ot.train_test

ggml_ot.train_test(dataset, n_splits=5, train_size=0.6, test_size=None, scoring=('knn', 'ari'), knn_k=5, cluster_linkage='complete', plot_split=True, plot_type=True, print_table=True, print_latex=False, return_dataset=False, ground_metric=None, plot_split_dir=None, plot_title=None, **kwargs)

Trains and cross-validates ground metrics on train-test splits.

This function performs n_splits stratified train-test splits on the provided dataset. For each split, it trains a ground metric on the training set and evaluates it on the test set using k-NN classification and hierarchical clustering.

Classification accuracy and clustering metrics are summarized in a table, and results can be plotted as clustermaps and embeddings.
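The stratified splitting described above can be illustrated with a minimal pure-Python sketch. It is not ggml_ot's implementation: the helper name stratified_splits, the seed argument, and the toy labels are assumptions made for illustration, and the test fraction is simply taken as the complement of train_size.

```python
import random
from collections import defaultdict


def stratified_splits(labels, n_splits=5, train_size=0.6, seed=0):
    """Yield (train_idx, test_idx) pairs that preserve class proportions.

    Hypothetical helper for illustration only; the test set is whatever
    remains after the train fraction is taken from each class.
    """
    rng = random.Random(seed)

    # Group sample indices by their class label.
    by_class = defaultdict(list)
    for idx, label in enumerate(labels):
        by_class[label].append(idx)

    for _ in range(n_splits):
        train, test = [], []
        for members in by_class.values():
            shuffled = members[:]
            rng.shuffle(shuffled)
            # Take the same fraction from every class.
            cut = round(train_size * len(shuffled))
            train.extend(shuffled[:cut])
            test.extend(shuffled[cut:])
        yield sorted(train), sorted(test)


# Toy labels (assumed): two classes with ten samples each.
labels = ["healthy"] * 10 + ["disease"] * 10
splits = list(stratified_splits(labels, n_splits=5, train_size=0.6))

for train_idx, test_idx in splits:
    print(len(train_idx), len(test_idx))
```

Because the cut is applied per class, every split keeps the class ratio of the full dataset in both the train and the test part.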

Parameters:
dataset TripletDataset | AnnData_TripletDataset

Dataset to perform train-test splits on.

See also

The documentation for the provided interfaces to AnnData and numpy arrays.

n_splits int (default: 5)

Number of train-test splits.

train_size float (default: 0.6)

Proportion of dataset to include in train split.

test_size float | None (default: None)

Proportion of dataset to include in test split. If None, 1 - train_size is used.

scoring tuple[str, …] | list[str] (default: ('knn', 'ari'))

Tuple or list of evaluation scores to compute on each test split. "knn" computes k-NN classification accuracy; "ari" (Adjusted Rand Index), "mi" (Mutual Information), and "vi" (Variation of Information) are clustering scores obtained via hierarchical clustering.

knn_k int (default: 5)

Number of neighbors used for benchmark k-NN classification.

cluster_linkage str (default: 'complete')

Linkage method used by hierarchical clustering for clustering scores and clustermap plots. Defaults to "complete", which merges clusters by their largest pairwise distance and tends to yield compact clusters of similar diameter.

plot_split bool (default: True)

Whether to plot OT distances for each split.

plot_type Literal['clustermap_embedding', 'clustermap', 'embedding'] | list[str] | tuple[str, …] | bool (default: True)

Defines which plots to generate. One of "clustermap_embedding", "clustermap", "embedding", a list of those values, or False.

print_table bool (default: True)

Whether to print the results table.

print_latex bool (default: False)

Whether to print the results table in LaTeX format.

return_dataset bool (default: False)

If False, returns a tuple of a dict containing the trained ground metrics and a DataFrame of the test scores. If True, returns the dataset with the data projected using the best learned ground metric.

Attention

return_dataset=True only works if the ground metric is learned, i.e. ground_metric is left at its default of None.

ground_metric np.ndarray | str | callable | None (default: None)

If provided, this ground_metric is used for testing. You are encouraged to use ggml_ot.test() instead.

plot_split_dir str | Path | None (default: None)

Optional output directory for per-split plots. When provided, split plots are saved into split_XX/ subdirectories below this path.

plot_title str | None (default: None)

Base title used for split plots. When n_splits > 1, the split name is appended automatically.

**kwargs

Additional arguments passed to ggml_ot.train(), see the corresponding docs for details.
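To make the "knn" score and the knn_k parameter more concrete, the following sketch computes k-NN classification accuracy from a precomputed test-to-train distance matrix. It is an illustration under stated assumptions, not the library's code: the function name knn_accuracy, the majority-vote tie-breaking, and the toy distances and labels are all hypothetical.

```python
from collections import Counter


def knn_accuracy(dist, train_labels, test_labels, k=5):
    """Fraction of test items whose k nearest training items vote correctly.

    dist[i][j] is the distance from test item i to training item j.
    Hypothetical helper; ties in the vote go to the label seen first
    among the nearest neighbors.
    """
    correct = 0
    for row, truth in zip(dist, test_labels):
        # Indices of the k training items closest to this test item.
        nearest = sorted(range(len(row)), key=row.__getitem__)[:k]
        votes = Counter(train_labels[j] for j in nearest)
        predicted = votes.most_common(1)[0][0]
        correct += predicted == truth
    return correct / len(test_labels)


# Toy data (assumed): six training items, three test items.
train_labels = ["a", "a", "a", "b", "b", "b"]
test_labels = ["a", "b", "b"]
dist = [
    [0.1, 0.2, 0.3, 0.9, 1.0, 1.1],  # close to the "a" items
    [1.0, 0.9, 1.1, 0.2, 0.1, 0.3],  # close to the "b" items
    [0.2, 0.3, 0.1, 0.4, 0.5, 0.6],  # labeled "b" but nearest to "a"
]

accuracy = knn_accuracy(dist, train_labels, test_labels, k=3)
print(accuracy)
```

In this picture, a better ground metric is one whose distances place test items nearer to training items of the same class, which raises the accuracy.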

Return type:

TripletDataset | AnnData_TripletDataset | tuple[dict, pd.DataFrame]

Returns:

TripletDataset | AnnData_TripletDataset

If return_dataset is set to True, the dataset is returned with the best performing ground metric (dataset.map_A).

If the dataset is of type AnnData_TripletDataset, the cells are projected into the learned gene subspace (dataset.adata.obsm["X_ggml"]) and the loadings of the gene subspace are stored in dataset.adata.varm["W_ggml"].

tuple[dict, pd.DataFrame]

If return_dataset is False, a tuple is returned containing:
  • A dict with keys:
    • "Ws": List of learned ground metrics for each split
    • "best": Best performing ground metric based on k-NN accuracy
    • "mean": Mean ground metric across splits
    • "sd": Standard deviation of ground metrics across splits
  • A DataFrame summarizing the mean and standard deviation of evaluation metrics across splits.
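The aggregation behind the summary table and the "best" entry can be sketched as follows. This is only an illustration: the per-split numbers are made-up placeholders rather than real results, the helper names summarize and best_split are hypothetical, and the use of the sample standard deviation is an assumption, since the text above does not say which estimator is used.

```python
from statistics import mean, stdev

# Hypothetical per-split test scores (placeholder values, not real results).
split_scores = [
    {"knn": 0.80, "ari": 0.50},
    {"knn": 0.90, "ari": 0.70},
    {"knn": 0.70, "ari": 0.60},
]


def summarize(scores):
    """Return {score_name: (mean, sample standard deviation)} across splits."""
    names = scores[0].keys()
    return {
        name: (
            mean(s[name] for s in scores),
            stdev(s[name] for s in scores),
        )
        for name in names
    }


def best_split(scores, key="knn"):
    """Index of the split with the highest value of the given score."""
    return max(range(len(scores)), key=lambda i: scores[i][key])


summary = summarize(split_scores)
best_idx = best_split(split_scores)  # split whose metric would be "best"

for name, (avg, sd) in summary.items():
    print(f"{name}: {avg:.2f} +/- {sd:.2f}")
```

Selecting by k-NN accuracy mirrors the description of the "best" key; the index returned here would point into a list such as "Ws".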