ggml_ot.train_emd2

ggml_ot.train_emd2(dataset, alpha=10.0, reg=0.001, reg_type='fro', n_comps=5, lr=0.05, max_iter=30, stop_thr=None, verbose=False, plot_iter=-1, batch_size=512, train_size=None, squared_ground_cost=False, eps=0.0001, **kwargs)

Train GGML with exact optimal transport (EMD2) and return the learned ground metric.
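EMD2 usually names the exact OT cost (for example, `ot.emd2` in the POT library returns the optimal transport loss rather than the plan). The sketch below is not this package's implementation; it computes the same kind of quantity for the special case of two uniform point clouds of equal size, where an optimal plan can be taken to be a permutation, so brute force over permutations is exact but only practical for a handful of points. The helper name `exact_ot_cost_uniform` is hypothetical.

```python
import itertools
import math


def exact_ot_cost_uniform(xs, ys, squared=False):
    """Exact OT cost between two uniform empirical measures with n points each.

    Hypothetical helper for illustration. With weights 1/n on both sides, an
    optimal transport plan can be chosen to be a permutation
    (Birkhoff-von Neumann), so minimizing over all permutations is exact.
    The cost is O(n!), so this is only usable for tiny n.
    """
    if len(xs) != len(ys):
        raise ValueError("this sketch assumes equal-size point sets")
    n = len(xs)

    def ground_cost(a, b):
        d = math.dist(a, b)  # Euclidean distance between two points
        return d * d if squared else d

    best = min(
        sum(ground_cost(xs[i], ys[perm[i]]) for i in range(n))
        for perm in itertools.permutations(range(n))
    )
    return best / n  # each matched pair carries mass 1/n


# Two clouds offset by one unit along the y-axis: every point moves distance 1.
print(exact_ot_cost_uniform([(0, 0), (2, 0)], [(2, 1), (0, 1)]))  # 1.0
```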

Parameters:
dataset TripletDataset | AnnData_TripletDataset

Training dataset with empirical supports or Gaussian component means.

alpha float (default: 10.0)
reg float (default: 0.001)
reg_type Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro')
n_comps int (default: 5)

Triplet margin and regularization controls (shared description for alpha, reg, reg_type and n_comps).
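The grouped description does not spell out how these parameters enter the loss. One common way to combine a triplet margin with a matrix-norm penalty is sketched below. Treating alpha as the hinge margin, reg as the penalty weight, reg_type as the norm order, and n_comps as the number of rows of the learned map are all assumptions, as is the helper name `triplet_objective`. The accepted reg_type values coincide with matrix-norm orders of `numpy.linalg.norm`, so the sketch maps the string forms '1' and '2' onto the integers.

```python
import numpy as np


def triplet_objective(d_pos, d_neg, A, alpha=10.0, reg=0.001, reg_type="fro"):
    """Hinge triplet loss plus a norm penalty on the map A (hypothetical sketch).

    d_pos[i] / d_neg[i] are OT distances from anchor i to its positive /
    negative example (for instance EMD2 values under the current metric).
    """
    d_pos = np.asarray(d_pos, dtype=float)
    d_neg = np.asarray(d_neg, dtype=float)
    # Assumption: alpha is the margin by which negatives must exceed positives.
    hinge = np.maximum(0.0, d_pos - d_neg + alpha).mean()
    # Accept both int and string spellings of the 1- and 2-norms.
    order = {"1": 1, "2": 2}.get(reg_type, reg_type)
    penalty = np.linalg.norm(A, ord=order)  # matrix norm of the 2-D map A
    return hinge + reg * penalty


n_comps, n_features = 5, 20
A = np.eye(n_comps, n_features)  # assumed shape: (n_comps, n_features)
print(triplet_objective([1.0, 4.0], [12.0, 8.0], A, reg_type="nuc"))
```

Only triplets that violate the margin contribute to the hinge term; the first triplet above is already separated by more than alpha and adds nothing.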

lr float (default: 0.05)
max_iter int (default: 30)
stop_thr Optional[float] (default: None)

Outer optimization controls: lr is the learning rate, max_iter is the maximum number of outer iterations, and stop_thr is the gradient-norm threshold for stopping early.
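As a generic illustration of how these three settings usually interact (not the package's actual optimizer, which this page does not describe), the loop below runs plain gradient descent with step size lr for at most max_iter epochs and stops once the gradient norm falls below stop_thr. In this sketch stop_thr=None disables the check. It also averages per-epoch wall-clock time, mirroring the second return value. The toy objective and the name `optimize` are hypothetical.

```python
import math
import time


def optimize(grad_fn, x0, lr=0.05, max_iter=30, stop_thr=None):
    """Gradient descent with an optional gradient-norm stopping rule (sketch).

    Returns the final parameters and the mean wall-clock time per epoch.
    """
    x = list(x0)
    epoch_times = []
    for _ in range(max_iter):
        start = time.perf_counter()
        g = grad_fn(x)
        x = [xi - lr * gi for xi, gi in zip(x, g)]
        epoch_times.append(time.perf_counter() - start)
        # In this sketch, stop_thr=None means "always run max_iter epochs".
        if stop_thr is not None and math.sqrt(sum(gi * gi for gi in g)) < stop_thr:
            break
    mean_time = sum(epoch_times) / len(epoch_times) if epoch_times else 0.0
    return x, mean_time


# Toy objective f(x) = ||x||^2 / 2, whose gradient is x itself.
x_final, mean_epoch_time = optimize(lambda x: x, [1.0, -2.0], lr=0.5, max_iter=100, stop_thr=1e-6)
print(x_final, mean_epoch_time)
```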

verbose bool (default: False)
plot_iter int | bool (default: -1)

Logging and plotting cadence: verbose toggles progress logging, and plot_iter controls plotting. 0 or False disables plotting, -1 plots only after the final epoch, and an integer k >= 1 plots every k epochs (so 1 or True plots every epoch).
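The documented plotting rules fit in a small predicate. Which epochs count as "every k epochs" (here, after epochs k, 2k, ... counting from 1) and the `is_last` flag, which lets an early stop via stop_thr still count as the final epoch, are assumptions; the helper name `should_plot` is hypothetical.

```python
def should_plot(plot_iter, epoch, is_last):
    """Decide whether to plot after a (0-based) epoch, per the documented rules.

    Hypothetical helper: 0/False never plots, -1 plots only after the final
    epoch, and k >= 1 (True counts as 1) plots every k epochs.
    """
    k = int(plot_iter)  # bool is a subclass of int: False -> 0, True -> 1
    if k == 0:
        return False
    if k == -1:
        return is_last
    if k >= 1:
        # Assumption: "every k epochs" means after epochs k, 2k, ... (1-based).
        return (epoch + 1) % k == 0
    raise ValueError(f"unsupported plot_iter: {plot_iter!r}")


max_iter = 6
for p in (0, -1, True, 2):
    print(p, [e for e in range(max_iter) if should_plot(p, e, e == max_iter - 1)])
```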

batch_size int (default: 512)
train_size Optional[float] (default: None)

DataLoader controls: batch_size is the number of samples per batch, and train_size optionally sets the size of the training split.
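A sketch of what these two settings typically mean for batching, assuming train_size is the fraction of samples kept for training (None keeps all of them). The seeded shuffle and the helper name `make_batches` are hypothetical and not taken from the package.

```python
import random


def make_batches(n_samples, batch_size=512, train_size=None, seed=0):
    """Split sample indices into shuffled training batches (hypothetical sketch).

    Assumption: train_size is the fraction of samples kept for training,
    and None keeps all of them.
    """
    indices = list(range(n_samples))
    random.Random(seed).shuffle(indices)
    if train_size is not None:
        if not 0.0 < train_size <= 1.0:
            raise ValueError("train_size must be in (0, 1]")
        indices = indices[: int(round(train_size * n_samples))]
    # The last batch may be smaller than batch_size.
    return [indices[i : i + batch_size] for i in range(0, len(indices), batch_size)]


batches = make_batches(1200, batch_size=512, train_size=0.75)
print([len(b) for b in batches])  # [512, 388]
```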

squared_ground_cost bool (default: False)

For Gaussian datasets only: if True, use the squared Euclidean distance between component means for the mean term instead of the Euclidean distance.
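For two Gaussian components with means m1 and m2, the flag switches the mean term between ||m1 - m2|| and ||m1 - m2||^2. A minimal sketch follows; the helper name is hypothetical, and the covariance part of the cost, which this page does not describe, is left out.

```python
import math


def gaussian_mean_term(m1, m2, squared_ground_cost=False):
    """Mean-difference term of a Gaussian ground cost (hypothetical helper)."""
    d = math.dist(m1, m2)  # Euclidean distance between the component means
    return d * d if squared_ground_cost else d


print(gaussian_mean_term([0.0, 0.0], [3.0, 4.0]))        # 5.0
print(gaussian_mean_term([0.0, 0.0], [3.0, 4.0], True))  # 25.0
```

For reference, the squared form is the mean contribution to the squared 2-Wasserstein distance between Gaussians, and it penalizes far-apart components more heavily than the unsquared form.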

eps float (default: 0.0001)

Numerical floor for covariance regularization and OT stability.
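One common way to apply such a floor, assumed here rather than taken from the package, is to clip a symmetric covariance matrix's eigenvalues at eps so that it stays positive definite; adding eps times the identity is a simpler alternative. The helper name `floor_covariance` is hypothetical.

```python
import numpy as np


def floor_covariance(cov, eps=1e-4):
    """Return a symmetric covariance whose eigenvalues are at least eps (sketch)."""
    cov = np.asarray(cov, dtype=float)
    sym = (cov + cov.T) / 2.0  # remove tiny asymmetries from round-off
    eigvals, eigvecs = np.linalg.eigh(sym)
    clipped = np.clip(eigvals, eps, None)
    return (eigvecs * clipped) @ eigvecs.T  # V diag(clipped) V^T


singular = np.array([[1.0, 1.0], [1.0, 1.0]])  # rank 1, smallest eigenvalue 0
print(np.linalg.eigvalsh(floor_covariance(singular)))
```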

Return type:

tuple[ndarray, float]

Returns:

tuple[np.ndarray, float] Learned ground metric (map_A) and mean epoch time.
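This page does not document the shape or role of map_A. Assuming it is a linear map of shape (n_comps, n_features), as the n_comps parameter suggests, the induced ground distance between two feature vectors is ||A(x - y)||_2, which equals the Mahalanobis-type distance with matrix A^T A. A sketch of using the first return value that way (helper name and stand-in matrix are hypothetical):

```python
import numpy as np


def learned_ground_distance(map_A, x, y):
    """||A (x - y)||_2 under the assumption map_A has shape (n_comps, n_features)."""
    diff = np.asarray(x, dtype=float) - np.asarray(y, dtype=float)
    return float(np.linalg.norm(map_A @ diff))


# Stand-in for the returned map: keep only the first two of three features.
map_A = np.array([[1.0, 0.0, 0.0],
                  [0.0, 1.0, 0.0]])
print(learned_ground_distance(map_A, [3.0, 4.0, 100.0], [0.0, 0.0, 0.0]))  # 5.0
```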