ggml_ot.train_emd2
- ggml_ot.train_emd2(dataset, alpha=10.0, reg=0.001, reg_type='fro', n_comps=5, lr=0.05, max_iter=30, stop_thr=None, verbose=False, plot_iter=-1, batch_size=512, train_size=None, squared_ground_cost=False, eps=0.0001, **kwargs)
Train GGML with exact optimal transport (EMD2).
- Parameters:
- dataset
TripletDataset | AnnData_TripletDataset. Training dataset with empirical supports or Gaussian component means.
- alpha
float (default: 10.0). Triplet margin and regularization controls.
- reg
float (default: 0.001). Triplet margin and regularization controls.
- reg_type
Literal[1, '1', 2, '2', 'fro', 'nuc'] (default: 'fro'). Triplet margin and regularization controls.
- n_comps
int (default: 5). Triplet margin and regularization controls.
- lr
float (default: 0.05). Learning rate; one of the outer optimization controls.
- max_iter
int (default: 30). Maximum number of iterations; one of the outer optimization controls.
- stop_thr
Optional[float] (default: None). Gradient-norm threshold used as the stopping criterion of the outer optimization.
- verbose
bool (default: False). Logging control.
- plot_iter
int | bool (default: -1). Plotting cadence: 0/False disables plotting, -1 plots only after the final epoch, 1/True plots every epoch, and k >= 1 plots every k epochs.
- batch_size
int (default: 512). DataLoader controls.
- train_size
Optional[float] (default: None). DataLoader controls.
- squared_ground_cost
bool (default: False). If True on Gaussian datasets, use squared Euclidean distance for the mean term instead of Euclidean distance.
- eps
float (default: 0.0001). Numerical floor for covariance regularization and OT stability.
- Return type:
tuple[ndarray, float]
- Returns:
tuple[np.ndarray, float]. Learned ground metric (map_A) and mean epoch time.
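The plot_iter rules listed under Parameters can be read as a small decision function. The sketch below is illustrative only: should_plot is a hypothetical helper, not part of ggml_ot, and it assumes epochs are counted from 1 and that "every k epochs" means epochs divisible by k. The library's actual implementation may differ.

```python
def should_plot(plot_iter, epoch, n_epochs):
    """Hypothetical helper mirroring the documented plot_iter semantics.

    Assumptions (not confirmed by the library): `epoch` is 1-based and
    "every k epochs" means epochs where epoch % k == 0.
    """
    k = int(plot_iter)  # True -> 1, False -> 0
    if k == 0:
        # 0/False disables plotting
        return False
    if k == -1:
        # -1 plots only after the final epoch
        return epoch == n_epochs
    if k >= 1:
        # 1/True plots every epoch, k >= 1 plots every k epochs
        return epoch % k == 0
    raise ValueError(f"unsupported plot_iter value: {plot_iter!r}")


# Epochs (out of the default max_iter=30) at which a plot would be drawn
# for plot_iter=10 under the assumptions above.
plotted = [e for e in range(1, 31) if should_plot(10, e, 30)]
```

With the default plot_iter=-1, this reading yields a single plot after the last epoch; passing plot_iter=0 or plot_iter=False to train_emd2 would disable plotting entirely.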