ggml_ot.train

ggml_ot.train(dataset, alpha=1.0, reg=0.001, reg_type='fro', n_comps=5, lr=0.1, max_iter=30, plot_iter=-1, verbose=False, batch_size=512, train_size=None, squared_ground_cost=False, return_dataset=True, measure_time=False, mi_reg=0.0, diag_bures_approx=False, mi_reg_weighting=None, entropic_reg=0.0, sinkhorn_max_iter=100, sinkhorn_stop_thr=1e-06, stop_thr=1e-06, eps=0.0001, **kwargs)

Learn a ground metric for optimal transport from labeled distributions (supervised ground metric learning).

GGML learns a linear ground metric such that, under the induced optimal transport distance, distributions from the same class are closer than distributions from different classes.
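
A minimal NumPy/SciPy sketch of the idea, not the ggml_ot implementation. It assumes the linear ground metric is the cost ||A(x - y)||_2 for a learned map A of shape (n_comps, d), computes exact OT between two uniformly weighted point clouds of equal size (which reduces to a linear assignment problem), and scores one triplet with a hinge loss using the margin alpha. The helpers ground_cost, ot_distance and triplet_loss are hypothetical names used only for illustration.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment


def ground_cost(X, Y, A):
    """Pairwise costs ||A (x - y)||_2 between rows of X (n, d) and Y (m, d)."""
    # Project both point clouds with A (n_comps, d), then take Euclidean distances.
    PX, PY = X @ A.T, Y @ A.T
    return np.linalg.norm(PX[:, None, :] - PY[None, :, :], axis=-1)


def ot_distance(X, Y, A):
    """Exact OT cost between two uniformly weighted point clouds of equal size.

    With uniform weights and equal sizes, some optimal plan is a permutation,
    so the OT problem reduces to a linear assignment problem.
    """
    C = ground_cost(X, Y, A)
    rows, cols = linear_sum_assignment(C)
    return C[rows, cols].mean()  # each matched pair carries mass 1/n


def triplet_loss(anchor, positive, negative, A, alpha=1.0):
    """Hinge loss: the same-class pair should be closer by at least the margin alpha."""
    d_pos = ot_distance(anchor, positive, A)
    d_neg = ot_distance(anchor, negative, A)
    return max(0.0, d_pos - d_neg + alpha)


rng = np.random.default_rng(0)
d, n_comps, n = 10, 5, 20
A = rng.normal(size=(n_comps, d))            # linear map, rank <= n_comps
anchor = rng.normal(size=(n, d))             # class 0
positive = rng.normal(size=(n, d))           # class 0
negative = rng.normal(loc=3.0, size=(n, d))  # class 1 (shifted)
print(triplet_loss(anchor, positive, negative, A, alpha=1.0))
```

The sketch only evaluates the loss for a fixed A; training would update A to reduce this loss (plus the regularization term) over many triplets.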

Parameters:
dataset TripletDataset | AnnData_TripletDataset

Dataset containing distributions and labels.

alpha float (default: 1.0)

Margin of the triplet loss: a same-class pair must be closer than a different-class pair by at least alpha.

reg float (default: 0.001)

Regularization strength (weight of the norm penalty on the learned map).

reg_type Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro')

Type of regularization (1/"1" for L1, 2/"2"/"fro" for L2/Frobenius, "nuc" for nuclear norm).

n_comps int (default: 5)

Number of learned components (rank of the linear map).

lr float (default: 0.1)

Learning rate of the Adam optimizer.

max_iter int (default: 30)

Number of training epochs.

plot_iter int | bool (default: -1)

Training progress plot frequency: 0/False disables plotting, -1 plots only after the final epoch, 1/True plots every epoch, and an integer k > 1 plots every k-th epoch. Whether plots are displayed or saved follows the show/save conventions of the plotting API (see ggml_ot.pl.embedding()).

verbose bool (default: False)

Print optimization progress.

batch_size int (default: 512)

Batch size of the DataLoader used during training.

train_size Optional[float] (default: None)

If given, the fraction of the dataset (train split ratio) used to create a training subset.

return_dataset bool (default: True)

If True, store the learned metric in dataset.map_A and return the dataset; if False, return the learned metric itself.

measure_time bool (default: False)

If True and return_dataset=False, also return the mean time per epoch.

mi_reg float (default: 0.0)

Mutual-information regularization strength for GMM training. A bool is currently also accepted and acts as a multiplier (True behaves as 1.0, False as 0.0).

diag_bures_approx bool (default: False)

If True, use a diagonal approximation of the Bures term for GMMs.

squared_ground_cost bool (default: False)

For Gaussian datasets, if True, use the squared Euclidean distance instead of the Euclidean distance for the mean term.

mi_reg_weighting Optional[str] (default: None)

Weighting scheme for MI regularization across Gaussian components.

entropic_reg float (default: 0.0)

Entropic OT regularization strength. If > 0, training dispatches to the Sinkhorn (entropic OT) solver.

sinkhorn_max_iter int (default: 100)

Maximum iterations of the Sinkhorn inner solver.

sinkhorn_stop_thr Optional[float] (default: 1e-06)

Stopping threshold for the Sinkhorn inner solver.

stop_thr Optional[float] (default: 1e-06)

Stopping threshold on the gradient norm of the objective in the outer optimization loop.

eps float (default: 0.0001)

Numerical floor for covariance regularization and OT stability.
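
As an illustration of reg and reg_type, the hypothetical helper below applies the norm selected by reg_type to a map A and scales it by reg. Whether ggml_ot uses the plain or the squared norm, and exactly how the penalty is added to the triplet loss, are assumptions of this sketch.

```python
import numpy as np


def regularization(A, reg=0.001, reg_type="fro"):
    """Penalty reg * ||A|| for the norm named by reg_type (illustrative sketch)."""
    if reg_type in (1, "1"):
        norm = np.abs(A).sum()                # entrywise L1 norm
    elif reg_type in (2, "2", "fro"):
        norm = np.linalg.norm(A, "fro")       # Frobenius (entrywise L2) norm
    elif reg_type == "nuc":
        norm = np.linalg.norm(A, "nuc")       # nuclear norm: sum of singular values
    else:
        raise ValueError(f"unknown reg_type: {reg_type!r}")
    return reg * norm
```

The nuclear norm penalizes the sum of singular values, so it tends to push the learned map toward lower rank, whereas L1 favors sparse entries.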
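
The plot_iter rules can be read as the predicate below. It is a hypothetical helper that assumes epochs are counted from 1 and that "every k epochs" means epochs k, 2k, 3k, ...; the library may treat the final epoch differently when k does not divide max_iter.

```python
def should_plot(epoch: int, max_iter: int, plot_iter) -> bool:
    """True if a progress plot is due after `epoch` (1-based) under the plot_iter rules."""
    k = int(plot_iter)  # True -> 1, False -> 0
    if k == 0:
        return False                 # plotting disabled
    if k == -1:
        return epoch == max_iter     # only after the final epoch
    if k < 1:
        raise ValueError(f"unsupported plot_iter: {plot_iter!r}")
    return epoch % k == 0            # every k-th epoch (k == 1: every epoch)
```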
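
For Gaussian (GMM) datasets, the standard squared 2-Wasserstein distance between N(m1, S1) and N(m2, S2) is ||m1 - m2||^2 plus the squared Bures distance tr(S1 + S2 - 2(S1^{1/2} S2 S1^{1/2})^{1/2}). The sketch below compares the full Bures term with a diagonal version that keeps only the variances, one plausible reading of diag_bures_approx, and switches the mean term between ||m1 - m2|| and ||m1 - m2||^2 as squared_ground_cost describes. The function names are hypothetical, and simply adding the mean term to the Bures term is an assumption, not ggml_ot's documented formula.

```python
import numpy as np
from scipy.linalg import sqrtm


def bures_sq(S1, S2):
    """Squared Bures distance tr(S1 + S2 - 2 (S1^1/2 S2 S1^1/2)^1/2)."""
    r1 = sqrtm(S1)
    cross = sqrtm(r1 @ S2 @ r1)
    return float(np.real(np.trace(S1 + S2 - 2 * cross)))


def bures_sq_diag(S1, S2):
    """Diagonal approximation: keep only variances, giving sum (sqrt(s1) - sqrt(s2))^2."""
    s1, s2 = np.diag(S1), np.diag(S2)
    return float(np.sum((np.sqrt(s1) - np.sqrt(s2)) ** 2))


def gaussian_cost(m1, S1, m2, S2, squared_ground_cost=False, diag_bures_approx=False):
    """Mean term plus Bures term (the additive combination is an assumption)."""
    sq_dist = float(np.sum((m1 - m2) ** 2))
    mean_term = sq_dist if squared_ground_cost else np.sqrt(sq_dist)
    bures = bures_sq_diag(S1, S2) if diag_bures_approx else bures_sq(S1, S2)
    return mean_term + bures
```

For diagonal covariances the two Bures variants agree exactly; they differ only when covariances have off-diagonal entries, which the diagonal version ignores.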
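
With entropic_reg > 0 the inner OT problems go to a Sinkhorn solver. The textbook Sinkhorn-Knopp scaling below illustrates how entropic_reg, sinkhorn_max_iter and sinkhorn_stop_thr interact. It is a plain NumPy sketch, not ggml_ot's solver (which may, for example, work in the log domain), and its stopping rule on the row-marginal error is an assumption.

```python
import numpy as np


def sinkhorn(a, b, C, entropic_reg, sinkhorn_max_iter=100, sinkhorn_stop_thr=1e-6):
    """Entropic OT plan between weights a, b for cost matrix C (Sinkhorn-Knopp)."""
    K = np.exp(-C / entropic_reg)  # Gibbs kernel; smaller entropic_reg -> sharper plan
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(sinkhorn_max_iter):
        u = a / (K @ v)            # fit row marginals
        v = b / (K.T @ u)          # fit column marginals (now exact)
        P = u[:, None] * K * v[None, :]
        # Stop once the row marginals also match a to within the threshold.
        if np.abs(P.sum(axis=1) - a).sum() < sinkhorn_stop_thr:
            break
    return u[:, None] * K * v[None, :]
```

As entropic_reg shrinks, the plan approaches an exact OT plan, but very small values can underflow the kernel, which is one reason solvers add numerical floors such as eps.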

Return type:

TripletDataset | AnnData_TripletDataset | ndarray | tuple[ndarray, float]

Returns:

TripletDataset | AnnData_TripletDataset

The input dataset with the learned metric stored in map_A, returned when return_dataset=True.

np.ndarray

Learned ground metric, returned when return_dataset=False and measure_time=False.

tuple[np.ndarray, float]

Learned ground metric and mean epoch time, returned when return_dataset=False and measure_time=True.
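
Because the return type depends on return_dataset and measure_time, calling code may want to normalize it. The hypothetical helper below maps the three documented forms to a (metric, mean_epoch_time) pair. It assumes map_A is array-like (convertible with np.asarray), and the stand-in dataset is a SimpleNamespace, not a real TripletDataset.

```python
from types import SimpleNamespace
from typing import Any, Optional, Tuple

import numpy as np


def unpack_train_result(result: Any) -> Tuple[np.ndarray, Optional[float]]:
    """Map any documented return form of ggml_ot.train to (metric, mean_epoch_time)."""
    if isinstance(result, tuple):
        # return_dataset=False, measure_time=True -> (metric, mean epoch time)
        metric, epoch_time = result
        return np.asarray(metric), float(epoch_time)
    if isinstance(result, np.ndarray):
        # return_dataset=False, measure_time=False -> metric only
        return result, None
    # return_dataset=True -> dataset whose map_A holds the learned metric
    return np.asarray(result.map_A), None


# Stand-ins for the three return forms (placeholder values, not real training output).
metric = np.eye(3)[:2]
fake_dataset = SimpleNamespace(map_A=metric)
for result in (fake_dataset, metric, (metric, 0.25)):
    A, t = unpack_train_result(result)
    print(A.shape, t)
```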

Raises:

ValueError – If GMM-specific options (such as mi_reg or diag_bures_approx) are used with a dataset that has no covariances.