ggml_ot.tune
- ggml_ot.tune(dataset, alpha=[0.1, 1, 10], reg=[0.01, 0.1, 1, 10], reg_type=['fro'], n_comps=[2, 5], mi_reg=None, knn_k=5, print_latex=False, plot_contour=True, verbose=False, return_dataset=False, **kwargs)
Tune hyperparameters by performing a Grid Search and Cross-Validation.
- Parameters:
- dataset
TripletDataset|AnnData_TripletDataset A dataset containing triplets of distributions.
See also
The documentation for the provided interfaces to
AnnData and numpy arrays.
- alpha
float|list(default:[0.1, 1, 10]) Margin(s) between distributions from different classes (e.g. disease states), given as a single float or a list of floats. Large values lead to strong separation on the training set but risk overfitting.
- reg
float|list(default:[0.01, 0.1, 1, 10]) A list or float of regularization strength(s).
- reg_type
str|list(default:['fro']) A list or str of regularization type(s): 1 for L1, 2 or “fro” for L2/Frobenius, and “nuc” for nuclear norm.
- n_comps
int|list(default:[2, 5]) Number(s) of components of the learned subspace, i.e. its rank, given as a single int or a list of ints.
- mi_reg
float|list|None (default:None) Optional mutual-information regularization strength(s). Can only be used when the dataset has covariances (i.e. a GMM dataset). If not provided, the behavior is unchanged and the default from
ggml_ot.train_test()/training is used.
- knn_k
int(default:5) Number of neighbors used for benchmark k-NN classification during each train/test evaluation.
- print_latex
bool(default:False) Whether to print the hyperparameter tuning results as a LaTeX table.
- plot_contour
bool(default:True) Whether to plot the hyperparameter tuning results over alpha and reg for the best n_comps and reg_type. You can also manually create contour plots from the returned dataframe using
ggml_ot.pl.contour_hyperparams().
- verbose
bool(default:False) Whether to print progress information during training.
- return_dataset
bool(default:False) If False, returns a tuple containing the results of the hyperparameter tuning.
If True, returns the dataset with the best performing ground metric assigned to dataset.map_A. If the dataset is of type AnnData_TripletDataset, the cells are projected into the learned gene subspace (dataset.adata.obsm["X_ggml"]) and the loadings of the gene subspace are stored in dataset.adata.varm["W_ggml"].
- **kwargs
Additional arguments passed to
ggml_ot.train_test().
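The size of the search space follows from the list-valued parameters above. The sketch below is an illustration only, not the library's implementation: it assumes the grid is the Cartesian product of alpha, reg, reg_type and n_comps, and that a scalar is treated like a one-element list. The helper names as_list and build_grid are hypothetical.

```python
from itertools import product


def as_list(value):
    """Wrap a scalar in a list so scalars and lists are handled uniformly."""
    return list(value) if isinstance(value, (list, tuple)) else [value]


def build_grid(alpha=(0.1, 1, 10), reg=(0.01, 0.1, 1, 10),
               reg_type=("fro",), n_comps=(2, 5)):
    """Enumerate every hyperparameter combination as a dict (hypothetical helper)."""
    names = ("alpha", "reg", "reg_type", "n_comps")
    axes = [as_list(v) for v in (alpha, reg, reg_type, n_comps)]
    return [dict(zip(names, combo)) for combo in product(*axes)]


# Defaults from the signature: 3 alphas * 4 regs * 1 reg_type * 2 n_comps
grid = build_grid()
print(len(grid))

# Scalars collapse the corresponding axis to a single value
single = build_grid(alpha=1, reg=0.1, n_comps=2)
print(single)
```

Under these assumptions the default arguments yield 24 combinations, each of which is trained and evaluated on every cross-validation split, so shrinking any one list reduces the runtime proportionally.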
- Returns:
- tuple[dict, pd.DataFrame]
If return_dataset is set to False, a tuple is returned containing:
- A dictionary mapping hyperparameter combinations to the best performing ground metric for that combination.
- A DataFrame summarizing the mean and standard deviation of the evaluation metrics across test splits for each hyperparameter combination.
- TripletDataset | AnnData_TripletDataset
If return_dataset is set to True, the dataset is returned with the best performing ground metric (dataset.map_A).
If the dataset is of type AnnData_TripletDataset, the cells are projected into the learned gene subspace (dataset.adata.obsm["X_ggml"]) and the loadings of the gene subspace are stored in dataset.adata.varm["W_ggml"].
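To illustrate what "mean and standard deviation across test splits" amounts to, here is a minimal sketch in plain Python. It does not call ggml_ot: the scores are made-up placeholders, the key layout (alpha, reg, reg_type, n_comps) is an assumption, and the column names of the real DataFrame are not documented here. Whether the library reports a sample or a population standard deviation is also not stated; the sketch uses the sample version.

```python
from statistics import mean, stdev

# Hypothetical per-split accuracies, keyed by (alpha, reg, reg_type, n_comps).
# These numbers are placeholders, not results produced by ggml_ot.
split_scores = {
    (0.1, 0.01, "fro", 2): [0.80, 0.90, 0.85],
    (1, 0.1, "fro", 5): [0.90, 0.95, 0.85],
}


def summarize(scores_by_combo):
    """Aggregate the per-split scores of each combination into mean and std."""
    return {
        combo: {"mean": mean(scores), "std": stdev(scores)}
        for combo, scores in scores_by_combo.items()
    }


summary = summarize(split_scores)

# Select the combination with the highest mean score across splits
best = max(summary, key=lambda combo: summary[combo]["mean"])
print(best, summary[best])
```

Comparing the standard deviation alongside the mean helps to judge whether the gap between two combinations is larger than the variation between splits.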