ggml_ot.train_sinkhorn

ggml_ot.train_sinkhorn(dataset, alpha=10.0, reg=0.001, reg_type='fro', n_comps=5, entropic_reg=0.5, sinkhorn_max_iter=100, sinkhorn_stop_thr=None, lr=0.05, max_iter=30, stop_thr=None, verbose=False, plot_iter=-1, batch_size=512, train_size=None, squared_ground_cost=False, eps=0.0001, **kwargs)

Train GGML with Sinkhorn-regularized OT.

Parameters:
dataset TripletDataset | AnnData_TripletDataset

Training dataset with empirical supports or Gaussian component means/covariances.

alpha float (default: 10.0)

Margin of the triplet loss.

reg float (default: 0.001)

Strength of the regularization term.

reg_type Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro')

Norm used for the regularization term: 1, 2, Frobenius ('fro'), or nuclear ('nuc').

n_comps int (default: 5)

Triplet margin and regularization controls.

entropic_reg float (default: 0.5)

Entropic OT regularization strength for Sinkhorn (> 0 required), expressed as a fraction of the mean ground-cost value (e.g. 0.5 regularizes at half the mean pairwise cost). The cost matrix is normalized internally, so this value is scale-invariant across datasets and training stages.

sinkhorn_max_iter int (default: 100)

Maximum number of iterations of the Sinkhorn inner solver.

sinkhorn_stop_thr Optional[float] (default: None)

Stopping threshold for the Sinkhorn inner solver.

lr float (default: 0.05)

Learning rate of the outer optimization.

max_iter int (default: 30)

Maximum number of iterations of the outer optimization.

stop_thr Optional[float] (default: None)

Stopping threshold for the outer optimization.

verbose bool (default: False)

Whether to log training progress.

plot_iter int | bool (default: -1)

Plotting cadence during training.

batch_size int (default: 512)

Batch size used by the DataLoader.

train_size Optional[float] (default: None)

DataLoader controls.

squared_ground_cost bool (default: False)

If True on Gaussian datasets, use squared Euclidean distance for the mean term instead of Euclidean distance.

eps float (default: 0.0001)

Numerical floor for covariance regularization and OT stability.
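The scale invariance described for entropic_reg can be illustrated with a small NumPy sketch. This is not the library's implementation: it assumes that "normalized internally" means dividing the cost matrix by its mean, and the function name sinkhorn_plan and its stopping rule are hypothetical.

```python
import numpy as np


def sinkhorn_plan(a, b, cost, entropic_reg=0.5, max_iter=100, stop_thr=None):
    """Entropic OT plan on a mean-normalized cost matrix (illustrative only)."""
    if entropic_reg <= 0:
        raise ValueError("entropic_reg must be > 0")
    # Assumption: the cost is normalized by its mean, so entropic_reg is
    # a fraction of the mean pairwise cost rather than an absolute value.
    cost = cost / cost.mean()
    kernel = np.exp(-cost / entropic_reg)  # Gibbs kernel
    u = np.ones_like(a)
    v = np.ones_like(b)
    for _ in range(max_iter):
        u_prev = u
        v = b / (kernel.T @ u)  # match the column marginals
        u = a / (kernel @ v)    # match the row marginals
        if stop_thr is not None and np.max(np.abs(u - u_prev)) < stop_thr:
            break
    return u[:, None] * kernel * v[None, :]


# Two small point clouds with uniform weights and a Euclidean ground cost.
rng = np.random.default_rng(0)
x = rng.normal(size=(6, 2))
y = rng.normal(size=(8, 2))
cost = np.linalg.norm(x[:, None, :] - y[None, :, :], axis=-1)
a = np.full(6, 1 / 6)
b = np.full(8, 1 / 8)

plan = sinkhorn_plan(a, b, cost)
plan_scaled = sinkhorn_plan(a, b, 100.0 * cost)  # same data, rescaled cost
```

Under this assumption, rescaling the ground cost leaves the transport plan unchanged, which is why the same entropic_reg value can be reused across datasets whose costs differ in magnitude.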

Return type:

tuple[ndarray, float]

Returns:

tuple[np.ndarray, float] – Learned ground metric (map_A) and mean epoch time.

Raises:

ValueError – If entropic_reg <= 0.
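A call might look like the following sketch. It uses only the keyword names and defaults from the signature above; build_params and fit_ground_metric are hypothetical wrappers, and dataset is assumed to be a TripletDataset or AnnData_TripletDataset constructed elsewhere.

```python
# Keyword defaults copied from the documented signature.
DOCUMENTED_DEFAULTS = {
    "alpha": 10.0,
    "reg": 0.001,
    "reg_type": "fro",
    "n_comps": 5,
    "entropic_reg": 0.5,
    "sinkhorn_max_iter": 100,
    "lr": 0.05,
    "max_iter": 30,
    "batch_size": 512,
}


def build_params(**overrides):
    """Return the documented defaults with overrides applied (hypothetical helper)."""
    params = dict(DOCUMENTED_DEFAULTS)
    params.update(overrides)
    return params


def fit_ground_metric(dataset, **overrides):
    """Train with ggml_ot.train_sinkhorn and unpack the documented return tuple."""
    import ggml_ot  # imported lazily; requires the ggml_ot package

    params = build_params(**overrides)
    try:
        map_A, mean_epoch_time = ggml_ot.train_sinkhorn(dataset, **params)
    except ValueError as err:
        # Documented to be raised when entropic_reg <= 0.
        raise ValueError(f"train_sinkhorn rejected the settings: {err}") from err
    return map_A, mean_epoch_time
```

For example, fit_ground_metric(dataset, entropic_reg=0.1, max_iter=10) would run a shorter training with weaker entropic smoothing, assuming dataset has already been prepared.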