ggml_ot.train_gmm

ggml_ot.train_gmm(dataset, alpha=10.0, reg=0.001, reg_type='fro', n_comps=5, lr=0.1, max_iter=30, verbose=False, plot_iter=-1, mi_reg=0.0, mi_sqrt=False, diag_bures_approx=False, mi_reg_weighting='projection', stop_thr=1e-06, entropic_reg=0.0, sinkhorn_max_iter=100, sinkhorn_stop_thr=None, batch_size=512, train_size=None, squared_ground_cost=False, eps=0.0001, **kwargs)

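Train GGML on GMM datasets, using exact OT by default.

This convenience wrapper exposes GMM-specific regularization options (mutual-information regularization, a diagonal Bures approximation) and stays on the exact OT (EMD2) path unless entropic_reg > 0, in which case the Sinkhorn backend is used.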

Parameters:
dataset TripletDataset | AnnData_TripletDataset

Training dataset; it must contain Gaussian covariance tensors (a ValueError is raised otherwise).

alpha float (default: 10.0)

Triplet margin.

reg float (default: 0.001)

Strength of the regularization term.

reg_type Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro')

Norm used for the regularization term: 1 or '1', 2 or '2', 'fro' (Frobenius), or 'nuc' (nuclear).

n_comps int (default: 5)

Number of components of the learned ground metric (map_A).
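As a rough illustration of how alpha, reg and reg_type could fit together, the sketch below assumes a standard triplet hinge loss with margin alpha plus reg times a matrix norm of the learned map. The names triplet_hinge, metric_regularizer and objective are hypothetical; d_pos and d_neg stand for anchor-positive and anchor-negative OT distances under the learned metric, and averaging over triplets is an assumption. The library's actual loss may differ.

```python
import numpy as np

def triplet_hinge(d_pos, d_neg, alpha=10.0):
    """Mean hinge loss over triplets (sketch; the exact loss form is assumed)."""
    d_pos = np.asarray(d_pos, dtype=float)
    d_neg = np.asarray(d_neg, dtype=float)
    # A triplet is penalized unless the anchor-negative distance exceeds
    # the anchor-positive distance by at least the margin alpha.
    return float(np.mean(np.maximum(0.0, d_pos - d_neg + alpha)))

def metric_regularizer(A, reg=0.001, reg_type="fro"):
    """reg * ||A|| under the matrix norm selected by reg_type."""
    # numpy.linalg.norm expects integer orders, so map the string aliases '1' and '2'.
    order = {"1": 1, "2": 2}.get(reg_type, reg_type)
    return reg * float(np.linalg.norm(A, ord=order))

def objective(A, d_pos, d_neg, alpha=10.0, reg=0.001, reg_type="fro"):
    """Hypothetical total objective: triplet hinge loss plus norm penalty."""
    return triplet_hinge(d_pos, d_neg, alpha) + metric_regularizer(A, reg, reg_type)
```

For example, with d_pos = [1.0, 5.0] and d_neg = [20.0, 8.0], only the second triplet violates the margin alpha = 10 (by 7), so the mean hinge loss is 3.5.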

lr float (default: 0.1)

Learning rate for the outer optimization.

max_iter int (default: 30)

Maximum number of outer optimization iterations.

stop_thr Optional[float] (default: 1e-06)

Stopping threshold for the outer optimization.

verbose bool (default: False)

If True, log training progress.

plot_iter int | bool (default: -1)

Controls how often intermediate results are plotted during training.

mi_reg float | bool (default: 0.0)

Mutual-information (MI) regularization strength. If a bool is passed, the current implementation treats it as a multiplicative on/off flag (True acts as 1, False as 0).

mi_sqrt bool (default: False)

If True, minimize sum w_k sqrt(MI_k) instead of sum w_k MI_k. This matches the functional form of the diagonal approximation error bound (Theorem diag_mi_bounds).
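The sketch below shows how mi_reg, mi_sqrt and per-component weights w_k might combine into a single penalty. Using the Gaussian total correlation, 0.5 * (sum_i log Σ_ii - log det Σ), as MI_k is an assumption, chosen because it vanishes exactly when Σ is diagonal; the library may compute MI_k differently (for example on projected covariances). Both helper names are hypothetical.

```python
import numpy as np

def gaussian_total_correlation(cov):
    """MI among the coordinates of N(0, cov): 0.5 * (sum_i log cov_ii - log det cov)."""
    cov = np.asarray(cov, dtype=float)
    sign, logdet = np.linalg.slogdet(cov)
    if sign <= 0:
        raise ValueError("covariance must be positive definite")
    return 0.5 * (float(np.sum(np.log(np.diag(cov)))) - float(logdet))

def mi_penalty(mi, weights, mi_reg=0.0, mi_sqrt=False):
    """Hypothetical penalty: mi_reg * sum_k w_k MI_k, or sum_k w_k sqrt(MI_k) if mi_sqrt."""
    mi = np.asarray(mi, dtype=float)
    w = np.asarray(weights, dtype=float)
    terms = np.sqrt(mi) if mi_sqrt else mi
    # float(True) == 1.0 and float(False) == 0.0, so a bool mi_reg acts as an on/off switch.
    return float(mi_reg) * float(np.dot(w, terms))
```

With MI values [0.25, 1.0] and equal weights 0.5, the linear form gives 0.625 while the square-root form gives 0.75, so mi_sqrt shifts relative emphasis toward components with small MI.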

diag_bures_approx bool (default: False)

If True, use a diagonal approximation for the Bures component.

squared_ground_cost bool (default: False)

If True, use the squared Euclidean distance for the Gaussian mean term.
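For two Gaussians N(m1, S1) and N(m2, S2), the squared 2-Wasserstein distance decomposes into a mean term ||m1 - m2||^2 and a squared Bures term tr(S1) + tr(S2) - 2 tr((S1^{1/2} S2 S1^{1/2})^{1/2}). The sketch below computes both pieces: diag_bures_approx replaces the Bures term by its diagonal form sum_i (sqrt(S1_ii) - sqrt(S2_ii))^2, and squared_ground_cost chooses between the plain and squared Euclidean mean term. It returns the two terms separately rather than guessing how the library combines them; adding eps * I to both covariances is an assumed use of eps, and gaussian_cost is a hypothetical name.

```python
import numpy as np

def _psd_sqrt(S):
    """Symmetric PSD square root via eigendecomposition (tiny negative eigenvalues clipped)."""
    vals, vecs = np.linalg.eigh(S)
    return (vecs * np.sqrt(np.clip(vals, 0.0, None))) @ vecs.T

def gaussian_cost(m1, S1, m2, S2, diag_bures_approx=False,
                  squared_ground_cost=False, eps=1e-4):
    """Return (mean_term, bures_sq) for N(m1, S1) vs N(m2, S2)."""
    m1, m2 = np.asarray(m1, dtype=float), np.asarray(m2, dtype=float)
    d = m1.shape[0]
    # Assumed: eps acts as a covariance floor added to the diagonal.
    S1 = np.asarray(S1, dtype=float) + eps * np.eye(d)
    S2 = np.asarray(S2, dtype=float) + eps * np.eye(d)
    dist = float(np.linalg.norm(m1 - m2))
    mean_term = dist ** 2 if squared_ground_cost else dist
    if diag_bures_approx:
        # Keep only the diagonals: exact when both covariances are diagonal.
        bures_sq = float(np.sum((np.sqrt(np.diag(S1)) - np.sqrt(np.diag(S2))) ** 2))
    else:
        r1 = _psd_sqrt(S1)
        M = r1 @ S2 @ r1
        M = 0.5 * (M + M.T)  # symmetrize against round-off before eigh
        bures_sq = float(np.trace(S1) + np.trace(S2) - 2.0 * np.trace(_psd_sqrt(M)))
    return mean_term, bures_sq
```

For diagonal covariances the full Bures term and its diagonal approximation coincide; they differ only when off-diagonal correlations are present.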

mi_reg_weighting str (default: 'projection')

Weighting mode for the MI regularization across GMM components (how the per-component weights w_k are chosen).

entropic_reg float (default: 0.0)

If > 0, use the Sinkhorn backend with this entropic regularization strength instead of exact EMD2.

sinkhorn_max_iter int (default: 100)

Maximum number of Sinkhorn iterations; used only when entropic_reg > 0.

sinkhorn_stop_thr Optional[float] (default: None)

Sinkhorn convergence threshold; used only when entropic_reg > 0.
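To make the switch between exact EMD2 and Sinkhorn concrete, here is a NumPy-only sketch with uniform weights. The exact path brute-forces permutations, which is valid for uniform weights on equal-size supports (an optimal plan can then be taken to be a permutation) but only practical for a handful of points; a real implementation would use a proper EMD solver. The Sinkhorn path runs standard scaling iterations without log-domain stabilization. Treating sinkhorn_stop_thr=None as "run all iterations" and stopping on the row-marginal error are assumptions, and all function names are hypothetical.

```python
import itertools
import numpy as np

def sinkhorn_cost(C, reg, max_iter=100, stop_thr=None):
    """Transport cost <P, C> of the entropic plan between uniform marginals."""
    n, m = C.shape
    a, b = np.full(n, 1.0 / n), np.full(m, 1.0 / m)
    K = np.exp(-C / reg)  # Gibbs kernel (no log-domain stabilization)
    u, v = np.ones(n), np.ones(m)
    for _ in range(max_iter):
        u = a / (K @ v)
        v = b / (K.T @ u)
        # Assumed criterion: stop once the row marginal is within stop_thr of a.
        if stop_thr is not None and np.max(np.abs(u * (K @ v) - a)) < stop_thr:
            break
    P = u[:, None] * K * v[None, :]
    return float(np.sum(P * C))

def exact_cost_uniform(C):
    """Exact OT cost for uniform weights on equal-size supports (tiny n only)."""
    n, m = C.shape
    if n != m:
        raise ValueError("brute-force exact OT needs a square cost matrix")
    best = min(sum(C[i, p[i]] for i in range(n)) for p in itertools.permutations(range(n)))
    return float(best) / n

def ot_cost(C, entropic_reg=0.0, sinkhorn_max_iter=100, sinkhorn_stop_thr=None):
    """Dispatch on entropic_reg: exact OT at 0, Sinkhorn when > 0."""
    C = np.asarray(C, dtype=float)
    if entropic_reg > 0:
        return sinkhorn_cost(C, entropic_reg, sinkhorn_max_iter, sinkhorn_stop_thr)
    return exact_cost_uniform(C)
```

As entropic_reg shrinks, the Sinkhorn cost approaches the exact cost, at the price of slower convergence and a higher risk of underflow in the kernel.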

batch_size int (default: 512)

Batch size for the training DataLoader.

train_size Optional[float] (default: None)

Optional training-split size used when building the DataLoader.

eps float (default: 0.0001)

Numerical floor for covariance regularization and OT stability.

Return type:

tuple[ndarray, float]

Returns:

tuple[np.ndarray, float] – Learned ground metric map_A and the mean training time per epoch.
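The returned tuple can be unpacked as map_A, mean_epoch_time = ggml_ot.train_gmm(dataset). The sketch below assumes map_A is a linear map of shape (k, d), so the learned ground cost between points x and y is ||map_A (x - y)|| and a Gaussian N(m, Σ) is pushed forward to N(map_A m, map_A Σ map_Aᵀ). Both helpers are hypothetical and not part of ggml_ot.

```python
import numpy as np

def ground_cost(X, Y, map_A, squared=False):
    """Pairwise cost ||A x_i - A y_j|| (or its square) under an assumed (k, d) map A."""
    X, Y, A = (np.asarray(M, dtype=float) for M in (X, Y, map_A))
    PX, PY = X @ A.T, Y @ A.T  # project both point sets into the learned space
    D = np.linalg.norm(PX[:, None, :] - PY[None, :, :], axis=-1)
    return D ** 2 if squared else D

def project_gaussian(mean, cov, map_A):
    """Push N(mean, cov) through x -> A x, giving N(A mean, A cov A^T)."""
    A = np.asarray(map_A, dtype=float)
    return A @ np.asarray(mean, dtype=float), A @ np.asarray(cov, dtype=float) @ A.T
```

For instance, a map that keeps only the first coordinate reduces the distance between (0, 0) and (3, 4) from 5 to 3.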

Raises:

ValueError – If the dataset has no covariance tensors.