ggml_ot.train_sinkhorn
- ggml_ot.train_sinkhorn(dataset, alpha=10.0, reg=0.001, reg_type='fro', n_comps=5, entropic_reg=0.5, sinkhorn_max_iter=100, sinkhorn_stop_thr=None, lr=0.05, max_iter=30, stop_thr=None, verbose=False, plot_iter=-1, batch_size=512, train_size=None, squared_ground_cost=False, eps=0.0001, **kwargs)
Train GGML with Sinkhorn-regularized OT.
- Parameters:
- dataset
TripletDataset|AnnData_TripletDataset Training dataset with empirical supports or Gaussian component means/covariances.
- alpha
float(default:10.0) Triplet margin and regularization controls.
- reg
float(default:0.001) Triplet margin and regularization controls.
- reg_type
Literal[1,'1',2,'2','fro','nuc'] (default:'fro') Triplet margin and regularization controls.
- n_comps
int(default:5) Triplet margin and regularization controls.
- entropic_reg
float(default:0.5) Entropic OT regularization strength for Sinkhorn (> 0 required), expressed as a fraction of the mean ground-cost value (e.g. 0.5 regularizes at half the mean pairwise cost). The cost matrix is normalized internally, so this value is scale-invariant across datasets and training stages.
- sinkhorn_max_iter
int(default:100) Sinkhorn inner-solver controls.
- sinkhorn_stop_thr
Optional[float] (default:None) Sinkhorn inner-solver controls.
- lr
float(default:0.05) Outer optimization controls.
- max_iter
int(default:30) Outer optimization controls.
- stop_thr
Optional[float] (default:None) Outer optimization controls.
- verbose
bool(default:False) Logging and plotting cadence.
- plot_iter
int|bool(default:-1) Logging and plotting cadence.
- batch_size
int(default:512) DataLoader controls.
- train_size
Optional[float] (default:None) DataLoader controls.
- squared_ground_cost
bool(default:False) If True on Gaussian datasets, use squared Euclidean distance for the mean term instead of Euclidean distance.
- eps
float(default:0.0001) Numerical floor for covariance regularization and OT stability.
- Return type:
tuple[ndarray,float]
- Returns:
tuple[np.ndarray, float] Learned ground metric (map_A) and mean epoch time.
- Raises:
ValueError – If entropic_reg <= 0.
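The scale invariance described for entropic_reg can be illustrated with a small standalone sketch. This is not ggml_ot's implementation: the helper name sinkhorn_plan, the pure-Python Sinkhorn loop, and the stopping rule are assumptions made for illustration, and normalizing the cost by its mean is one reading of "expressed as a fraction of the mean ground-cost value". The argument names merely mirror the documented parameters.

```python
import math


def sinkhorn_plan(a, b, cost, entropic_reg=0.5, max_iter=100, stop_thr=1e-9):
    """Entropic OT plan between histograms a and b (hypothetical helper)."""
    if entropic_reg <= 0:
        # Mirrors the documented ValueError for non-positive values.
        raise ValueError("entropic_reg must be > 0")
    n, m = len(a), len(b)
    # Assumption: the regularizer is entropic_reg times the mean cost,
    # which is equivalent to dividing the cost matrix by its mean.
    # Requires a strictly positive mean cost.
    mean_cost = sum(sum(row) for row in cost) / (n * m)
    scale = entropic_reg * mean_cost
    kernel = [[math.exp(-c / scale) for c in row] for row in cost]
    u, v = [1.0] * n, [1.0] * m
    for _ in range(max_iter):
        u_prev = u
        # Alternate the two marginal-matching scaling updates.
        v = [b[j] / sum(kernel[i][j] * u[i] for i in range(n)) for j in range(m)]
        u = [a[i] / sum(kernel[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Assumed stopping rule: largest change in u between iterations.
        if max(abs(x - y) for x, y in zip(u, u_prev)) < stop_thr:
            break
    return [[u[i] * kernel[i][j] * v[j] for j in range(m)] for i in range(n)]


if __name__ == "__main__":
    a, b = [0.5, 0.5], [0.5, 0.5]
    cost = [[0.0, 1.0], [1.0, 0.0]]
    scaled = [[1000.0 * c for c in row] for row in cost]
    # Both calls use the same entropic_reg; the plans should agree
    # because the mean-cost normalization cancels the factor of 1000.
    print(sinkhorn_plan(a, b, cost))
    print(sinkhorn_plan(a, b, scaled))
```

Given the documented return type, a call to the real function would be unpacked as `map_A, mean_epoch_time = ggml_ot.train_sinkhorn(dataset, entropic_reg=0.5)`, where dataset is an already constructed TripletDataset or AnnData_TripletDataset.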