ggml_ot.train_gmm
- ggml_ot.train_gmm(dataset, alpha=10.0, reg=0.001, reg_type='fro', n_comps=5, lr=0.1, max_iter=30, verbose=False, plot_iter=-1, mi_reg=0.0, mi_sqrt=False, diag_bures_approx=False, mi_reg_weighting='projection', stop_thr=1e-06, entropic_reg=0.0, sinkhorn_max_iter=100, sinkhorn_stop_thr=None, batch_size=512, train_size=None, squared_ground_cost=False, eps=0.0001, **kwargs)
Train GGML with exact OT on GMM datasets.
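This convenience wrapper exposes GMM-specific regularization options (mutual-information regularization over GMM components, a diagonal Bures approximation) and uses exact OT (EMD2) by default. Setting entropic_reg > 0 switches to the Sinkhorn backend instead.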
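The exact OT (EMD2) path compares two GMMs through a ground cost between their Gaussian components. The following is a minimal NumPy-only sketch under stated assumptions, not the library's implementation. gaussian_cost and exact_ot_uniform are hypothetical helpers. Their keyword flags only mirror the squared_ground_cost, diag_bures_approx and eps parameters listed below. Combining the costs as "mean term + squared Bures term" is an assumption; with squared_ground_cost=True and the full Bures term, it equals the squared 2-Wasserstein distance between Gaussians. The exact solver is restricted to equal-size, uniformly weighted GMMs, where an optimal plan is a permutation matrix, so brute force stands in for a general EMD solver such as POT's ot.emd2.

```python
import itertools

import numpy as np


def psd_sqrt(S: np.ndarray) -> np.ndarray:
    """Symmetric square root of a PSD matrix via eigendecomposition."""
    vals, vecs = np.linalg.eigh(S)
    vals = np.clip(vals, 0.0, None)  # clip tiny negative eigenvalues from round-off
    return (vecs * np.sqrt(vals)) @ vecs.T


def gaussian_cost(m1, S1, m2, S2, *, squared_ground_cost=False,
                  diag_bures_approx=False, eps=1e-4) -> float:
    """Hypothetical ground cost between two Gaussian components.

    ASSUMPTION: cost = mean term + squared Bures term, with eps added to the
    covariance diagonals as a numerical floor.
    """
    d = len(m1)
    S1 = S1 + eps * np.eye(d)
    S2 = S2 + eps * np.eye(d)
    dist = np.linalg.norm(m1 - m2)
    mean_term = dist**2 if squared_ground_cost else dist
    if diag_bures_approx:
        # Bures term computed from the diagonals only (exact for diagonal covariances)
        bures_sq = np.sum((np.sqrt(np.diag(S1)) - np.sqrt(np.diag(S2))) ** 2)
    else:
        # tr(S1) + tr(S2) - 2 tr((S1^1/2 S2 S1^1/2)^1/2)
        r1 = psd_sqrt(S1)
        bures_sq = np.trace(S1) + np.trace(S2) - 2.0 * np.trace(psd_sqrt(r1 @ S2 @ r1))
    return float(mean_term + max(bures_sq, 0.0))


def exact_ot_uniform(C: np.ndarray) -> float:
    """Exact OT cost between two equal-size GMMs with uniform weights.

    With uniform marginals of equal size, some optimal plan is a permutation
    matrix (Birkhoff-von Neumann), so brute force over permutations is exact.
    """
    k = C.shape[0]
    best = min(sum(C[i, p[i]] for i in range(k))
               for p in itertools.permutations(range(k)))
    return float(best) / k


# Usage: GMM "b" is GMM "a" with its components reordered, so the exact OT cost is ~0.
means_a = np.array([[0.0, 0.0], [3.0, 0.0], [0.0, 3.0]])
covs_a = [np.diag([1.0, 2.0]), np.eye(2), np.diag([0.5, 0.5])]
order = [2, 0, 1]
means_b, covs_b = means_a[order], [covs_a[i] for i in order]
C = np.array([[gaussian_cost(ma, Sa, mb, Sb) for mb, Sb in zip(means_b, covs_b)]
              for ma, Sa in zip(means_a, covs_a)])
print(exact_ot_uniform(C))  # prints a value close to 0
```

Brute force grows factorially with the number of components, so it is only reasonable for a handful of them. It is used here because it is exact and needs no solver dependency.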
- Parameters:
- dataset
TripletDataset | AnnData_TripletDataset: Training dataset with Gaussian covariances.
- alpha
float (default: 10.0): Triplet margin and regularization controls.
- reg
float (default: 0.001): Triplet margin and regularization controls.
- reg_type
Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro'): Triplet margin and regularization controls. The value selects a norm order: 1, 2, Frobenius ('fro') or nuclear ('nuc').
- n_comps
int (default: 5): Triplet margin and regularization controls.
- lr
float (default: 0.1): Outer optimization controls.
- max_iter
int (default: 30): Outer optimization controls.
- stop_thr
Optional[float] (default: 1e-06): Outer optimization controls.
- verbose
bool (default: False): Logging and plotting cadence.
- plot_iter
int | bool (default: -1): Logging and plotting cadence.
- mi_reg
float | bool (default: 0.0): Mutual-information (MI) regularization strength. In the current implementation, a bool is interpreted as a multiplicative on/off flag, so True acts as 1.0 and False as 0.0.
- mi_sqrt
bool (default: False): If True, minimize sum_k w_k sqrt(MI_k) instead of sum_k w_k MI_k. This matches the functional form of the error bound for the diagonal approximation (Theorem diag_mi_bounds).
- diag_bures_approx
bool (default: False): If True, use a diagonal approximation for the Bures component.
- squared_ground_cost
bool (default: False): If True, use the squared Euclidean distance for the Gaussian mean term.
- mi_reg_weighting
str (default: 'projection'): Weighting mode for MI regularization over GMM components.
- entropic_reg
float (default: 0.0): If > 0, use the Sinkhorn backend instead of exact EMD2.
- sinkhorn_max_iter
int (default: 100): Sinkhorn inner-solver controls, used when entropic_reg > 0.
- sinkhorn_stop_thr
Optional[float] (default: None): Sinkhorn inner-solver controls, used when entropic_reg > 0.
- batch_size
int (default: 512): DataLoader controls.
- train_size
Optional[float] (default: None): DataLoader controls.
- eps
float (default: 0.0001): Numerical floor for covariance regularization and OT stability.
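The regularization parameters above can be read as terms added to a training objective. The following is a hypothetical sketch of how they might combine. Only three details come from this page: the allowed reg_type values, the on/off behavior of a bool mi_reg, and the sum_k w_k sqrt(MI_k) versus sum_k w_k MI_k forms. The hinge triplet loss with margin alpha is an assumption. The per-component MI values and weights w_k are passed in directly, so how mi_reg_weighting produces the weights is not modeled. All function names are illustrative.

```python
import numpy as np


def norm_of_map(A: np.ndarray, reg_type="fro") -> float:
    """Matrix norm of the learned map used in the regularization term."""
    # reg_type may be 1, '1', 2, '2', 'fro' or 'nuc' (per the signature);
    # numpy.linalg.norm expects ints for orders 1 and 2, so convert digit strings.
    order = int(reg_type) if reg_type in ("1", "2") else reg_type
    return float(np.linalg.norm(A, ord=order))


def mi_penalty(mi, weights, mi_reg=0.0, mi_sqrt=False) -> float:
    """mi_reg * sum_k w_k MI_k, or mi_reg * sum_k w_k sqrt(MI_k) when mi_sqrt is True."""
    mi = np.asarray(mi, dtype=float)
    weights = np.asarray(weights, dtype=float)
    per_component = np.sqrt(mi) if mi_sqrt else mi
    # A bool mi_reg behaves as a multiplicative on/off flag: float(True) == 1.0.
    return float(mi_reg) * float(np.sum(weights * per_component))


def triplet_hinge(d_pos, d_neg, alpha=10.0) -> float:
    """ASSUMPTION: standard hinge triplet loss with margin alpha."""
    gap = np.asarray(d_pos, dtype=float) - np.asarray(d_neg, dtype=float) + alpha
    return float(np.mean(np.maximum(0.0, gap)))


def total_objective(d_pos, d_neg, A, mi, weights, *, alpha=10.0, reg=1e-3,
                    reg_type="fro", mi_reg=0.0, mi_sqrt=False) -> float:
    """Hypothetical sum of the triplet, norm and MI terms."""
    return (triplet_hinge(d_pos, d_neg, alpha)
            + reg * norm_of_map(A, reg_type)
            + mi_penalty(mi, weights, mi_reg, mi_sqrt))


A = np.diag([3.0, 4.0])
print(norm_of_map(A, "fro"), norm_of_map(A, "1"))  # 5.0 4.0
```

Normalizing '1' and '2' to integers matters because NumPy treats the strings and the integers differently for matrix norms, while the signature accepts both spellings.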
- Return type:
tuple[ndarray, float]
- Returns:
tuple[np.ndarray, float]: Learned ground metric (map_A) and mean epoch time.
- Raises:
ValueError: If the dataset has no covariance tensors.
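The returned tuple can be unpacked as map_A, mean_epoch_time. The sketch below shows one possible use of map_A. It assumes, without confirmation from this page, that map_A is a linear map of shape (n_comps, n_features) acting as x -> map_A @ x. Under that assumption a Gaussian N(m, S) maps to N(A m, A S A^T), and distances between component means become ||A (m_i - m_j)||. A random matrix stands in for the trained map, and both helper functions are hypothetical.

```python
import numpy as np


def project_gaussian(A: np.ndarray, mean: np.ndarray, cov: np.ndarray):
    """Push N(mean, cov) through the linear map x -> A @ x (ASSUMED role of map_A)."""
    return A @ mean, A @ cov @ A.T


def pairwise_mean_distances(A: np.ndarray, means: np.ndarray) -> np.ndarray:
    """Euclidean distances between projected means, ||A (m_i - m_j)||."""
    Z = means @ A.T  # row i is A @ means[i]
    diff = Z[:, None, :] - Z[None, :, :]
    return np.linalg.norm(diff, axis=-1)


# Stand-in for the first element of the returned tuple; in practice something like
#   map_A, mean_epoch_time = ggml_ot.train_gmm(dataset, ...)
# The shape (n_comps, n_features) = (5, 10) is an assumption for illustration.
rng = np.random.default_rng(0)
map_A = rng.normal(size=(5, 10))
means = rng.normal(size=(4, 10))

D = pairwise_mean_distances(map_A, means)
m_proj, S_proj = project_gaussian(map_A, means[0], np.eye(10))
print(D.shape, m_proj.shape, S_proj.shape)  # (4, 4) (5,) (5, 5)
```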