ggml_ot.train
- ggml_ot.train(dataset, alpha=1.0, reg=0.001, reg_type='fro', n_comps=5, lr=0.1, max_iter=30, plot_iter=-1, verbose=False, batch_size=512, train_size=None, squared_ground_cost=False, return_dataset=True, measure_time=False, mi_reg=0.0, diag_bures_approx=False, mi_reg_weighting=None, entropic_reg=0.0, sinkhorn_max_iter=100, sinkhorn_stop_thr=1e-06, stop_thr=1e-06, eps=0.0001, **kwargs)
Perform supervised optimal transport by ground metric learning.
GGML learns a linear ground metric so distributions from the same class are closer than distributions from different classes.
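As a rough illustration of the idea, the sketch below applies a linear map to points and evaluates a triplet hinge with margin alpha. It is a simplification under stated assumptions: it compares single points rather than distributions (for single-point distributions the OT distance reduces to the ground cost), the hinge form is the standard triplet loss and is assumed here rather than taken from the ggml_ot source, and the names ground_cost and triplet_hinge are hypothetical.

```python
import math

def ground_cost(W, x, y):
    """Euclidean distance between x and y after applying the linear map W.

    W is a list of rows (n_comps x dim); hypothetical helper, not part of ggml_ot.
    """
    diff = [xi - yi for xi, yi in zip(x, y)]
    projected = [sum(w * d for w, d in zip(row, diff)) for row in W]
    return math.sqrt(sum(p * p for p in projected))

def triplet_hinge(d_pos, d_neg, alpha=1.0):
    """Standard triplet hinge: zero once d_neg exceeds d_pos by at least alpha."""
    return max(0.0, d_pos - d_neg + alpha)

# A rank-2 map on 3-D points: the third coordinate is ignored,
# the second coordinate is stretched by a factor of 2.
W = [[1.0, 0.0, 0.0],
     [0.0, 2.0, 0.0]]

anchor = [0.0, 0.0, 0.0]
same_class = [1.0, 0.0, 5.0]
other_class = [0.0, 2.0, 0.0]

d_pos = ground_cost(W, anchor, same_class)   # distance to the same-class point
d_neg = ground_cost(W, anchor, other_class)  # distance to the other-class point
loss = triplet_hinge(d_pos, d_neg, alpha=1.0)
```

With this map the same-class point is closer than the other-class point by more than the margin, so the hinge is zero; a larger alpha would make the same triplet contribute to the loss again.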
- Parameters:
- dataset
TripletDataset | AnnData_TripletDataset Dataset containing distributions and labels.
- alpha
float (default: 1.0) Triplet margin.
- reg
float (default: 0.001) Regularization strength.
- reg_type
Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro') Type of regularization (1/'1' for L1, 2/'2'/'fro' for L2/Frobenius, 'nuc' for nuclear norm).
- n_comps
int (default: 5) Number of learned components (rank of the linear map).
- lr
float (default: 0.1) Adam learning rate.
- max_iter
int (default: 30) Number of training epochs.
- plot_iter
int | bool (default: -1) Training progress plot frequency: 0/False disables plotting, -1 plots only after the final epoch, 1/True plots every epoch, k >= 1 plots every k epochs. Whether plots are displayed or saved follows the show/save conventions described in the plotting API (see ggml_ot.pl.embedding()).
- verbose
bool (default: False) Print optimization progress.
- batch_size
int (default: 512) DataLoader batch size.
- train_size
Optional[float] (default: None) Optional train split ratio used to create a training subset.
- return_dataset
bool (default: True) If True, assign the learned metric to dataset.map_A and return the dataset.
- measure_time
bool (default: False) If True and return_dataset=False, also return the mean epoch time.
- mi_reg
float (default: 0.0) Mutual-information regularization strength for GMM training. A bool is currently also accepted and is applied as a multiplicative factor.
- diag_bures_approx
bool (default: False) If True, use a diagonal approximation of the Bures term for GMMs.
- squared_ground_cost
bool (default: False) If True on Gaussian datasets, use the squared Euclidean distance for the mean term instead of the Euclidean distance.
- mi_reg_weighting
Optional[str] (default: None) Weighting scheme for MI regularization across Gaussian components.
- entropic_reg
float (default: 0.0) Entropic OT regularization. If > 0, dispatch to Sinkhorn training.
- sinkhorn_max_iter
int (default: 100) Maximum iterations of the Sinkhorn inner solver.
- sinkhorn_stop_thr
Optional[float] (default: 1e-06) Stopping threshold for the Sinkhorn inner solver.
- stop_thr
Optional[float] (default: 1e-06) Stopping threshold on the objective gradient norm in the outer optimization.
- eps
float (default: 0.0001) Numerical floor for covariance regularization and OT stability.
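The plot_iter values described above can be read as a small schedule. The following is a minimal sketch of that reading, not the library's code: should_plot is a hypothetical helper, epochs are assumed to be counted from 1, "every k epochs" is assumed to mean epochs k, 2k, 3k, and values below -1 are rejected here because the documentation does not define them.

```python
def should_plot(plot_iter, epoch, max_iter):
    """Return True if a progress plot would be drawn after this (1-based) epoch.

    Hypothetical helper mirroring the documented plot_iter semantics.
    """
    k = int(plot_iter)  # True -> 1, False -> 0
    if k == 0:
        return False              # plotting disabled
    if k == -1:
        return epoch == max_iter  # only after the final epoch
    if k < -1:
        raise ValueError("plot_iter below -1 is not documented")
    return epoch % k == 0         # every k-th epoch

# Epochs at which a plot would appear for a 30-epoch run
final_only = [e for e in range(1, 31) if should_plot(-1, e, 30)]
every_ten = [e for e in range(1, 31) if should_plot(10, e, 30)]
```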
- Return type:
TripletDataset | AnnData_TripletDataset | ndarray | tuple[ndarray, float]
- Returns:
- TripletDataset | AnnData_TripletDataset
Returned when return_dataset=True.
- np.ndarray
Learned ground metric when return_dataset=False.
- tuple[np.ndarray, float]
Learned ground metric and mean epoch time when return_dataset=False and measure_time=True.
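Because the return value has three possible shapes, calling code may want to normalize it. The sketch below shows one way to do that from the flags that were passed to the call. unpack_train_result is a hypothetical helper, and the stand-in object only imitates the documented map_A attribute so the snippet runs without the library.

```python
from types import SimpleNamespace

def unpack_train_result(result, return_dataset=True, measure_time=False):
    """Normalize a train() result to (dataset, metric, mean_epoch_time).

    Entries that the chosen flags do not produce are returned as None.
    """
    if return_dataset:
        # The learned metric is documented to be stored on dataset.map_A.
        return result, result.map_A, None
    if measure_time:
        metric, mean_epoch_time = result
        return None, metric, mean_epoch_time
    return None, result, None

# Stand-in for a returned dataset, for illustration only.
fake_dataset = SimpleNamespace(map_A=[[1.0, 0.0]])
dataset, metric, seconds = unpack_train_result(fake_dataset)
```

Note that, as documented, measure_time only affects the result when return_dataset=False.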
- Raises:
ValueError – If GMM-specific parameters are set but the dataset provides no covariances.
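The entropic_reg, sinkhorn_max_iter and sinkhorn_stop_thr parameters refer to entropy-regularized OT solved with Sinkhorn iterations. For readers unfamiliar with that inner solver, here is a textbook version in plain Python. It is a sketch, not the ggml_ot implementation: the function name, the argument layout and the stopping rule (total row-marginal error below the threshold) are assumptions.

```python
import math

def sinkhorn(a, b, cost, entropic_reg, max_iter=100, stop_thr=1e-6):
    """Entropy-regularized OT plan between histograms a and b (textbook Sinkhorn)."""
    n, m = len(a), len(b)
    # Gibbs kernel: small entropic_reg makes the plan sharper.
    K = [[math.exp(-cost[i][j] / entropic_reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(max_iter):
        # Alternately rescale rows and columns to match the marginals.
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
        # Columns match b exactly after the v update; measure the row error.
        row_err = sum(
            abs(u[i] * sum(K[i][j] * v[j] for j in range(m)) - a[i])
            for i in range(n)
        )
        if stop_thr is not None and row_err < stop_thr:
            break
    return [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]

# Two-point example: moving mass across costs 1, staying in place costs 0.
plan = sinkhorn([0.5, 0.5], [0.5, 0.5], [[0.0, 1.0], [1.0, 0.0]], entropic_reg=0.1)
```

In this example nearly all mass stays on the diagonal of the plan, since the off-diagonal moves are the only ones with nonzero cost.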