ggml_ot.data.TripletDataset
- class ggml_ot.data.TripletDataset(supports, distribution_labels, n_triplets=3, weights=None, covariances=None, identical_supports=False, **kwargs)
Dataset to train GGML based on array data.
This class stores a collection of distributions (“supports”) and produces triplets (i, j, k) that encode relative relationships: i and j come from the same class, and k comes from a different class. These triplets are used to train GGML so that distributions i and j are closer to each other than j and k are, by at least a margin alpha.
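The margin condition can be written as a triplet hinge penalty. The sketch below is a minimal pure-Python illustration, assuming a precomputed symmetric distance matrix `dist` between distributions and a penalty of the form max(0, d(i, j) - d(j, k) + alpha) averaged over triplets. The function names are hypothetical; this is not the ggml_ot training code.

```python
from typing import Sequence, Tuple


def triplet_hinge(d_pos: float, d_neg: float, alpha: float) -> float:
    """Zero once d_pos + alpha <= d_neg, otherwise the size of the violation."""
    return max(0.0, d_pos - d_neg + alpha)


def triplet_loss(dist: Sequence[Sequence[float]],
                 triplets: Sequence[Tuple[int, int, int]],
                 alpha: float) -> float:
    """Mean hinge over (i, j, k), where i and j share a class and k does not.

    Hypothetical helper: uses d(i, j) as the same-class distance and
    d(j, k) as the cross-class distance, matching the class description.
    """
    total = sum(triplet_hinge(dist[i][j], dist[j][k], alpha) for i, j, k in triplets)
    return total / len(triplets)


# Toy distances between three distributions; 0 and 1 share a class, 2 does not.
dist = [[0.0, 1.0, 3.0],
        [1.0, 0.0, 2.0],
        [3.0, 2.0, 0.0]]
triplets = [(0, 1, 2), (1, 0, 2)]
print(triplet_loss(dist, triplets, alpha=1.0))  # 0.0: both margins satisfied
print(triplet_loss(dist, triplets, alpha=1.5))  # 0.25: (0, 1, 2) misses the margin by 0.5
```

A larger alpha demands a wider gap between same-class and cross-class distances, so more triplets contribute a nonzero penalty.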
This class exposes the dataset to the standardized interfaces used by ggml_ot.train(), ggml_ot.tune(), ggml_ot.test() and ggml_ot.train_test().
- Parameters:
- supports Sequence[np.ndarray]
Sequence of per-distribution supports. Each element is an array of points (for empirical distributions) or component means (for GMM-style representations).
- distribution_labels Sequence[int] | np.ndarray
Integer labels identifying the class/group of each distribution.
- n_triplets int, optional
Number of triplets to generate per “anchor” distribution (default: 3).
- weights Sequence[np.ndarray] | None, optional
Per-distribution probability weights (e.g., cluster proportions) or None for uniform weights (default: None).
- covariances Sequence[np.ndarray] | None, optional
Optional per-distribution covariance arrays when supports represent Gaussian mixture components (default: None).
- identical_supports bool, optional
If True, indicates that all distributions share the same supports (e.g., identical component locations across distributions). This changes the __getitem__ return format and allows faster OT evaluation (default: False).
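To make the expected input layout concrete, the following sketch builds toy inputs with NumPy: one 2-D point array per distribution, one integer label per distribution, and one probability vector per support. `check_inputs` is a hypothetical helper written for this illustration, not part of ggml_ot, and the checks it performs (one label per distribution, a common space dimension, weights summing to 1) are assumptions drawn from the parameter descriptions.

```python
import numpy as np

rng = np.random.default_rng(0)

# Three toy distributions in a 2-D space; support sizes may differ.
supports = [rng.normal(size=(n, 2)) for n in (5, 8, 6)]
distribution_labels = np.array([0, 0, 1])  # distributions 0 and 1 share class 0
# Explicit uniform weights, one probability vector per distribution.
weights = [np.full(len(s), 1.0 / len(s)) for s in supports]


def check_inputs(supports, labels, weights=None):
    """Hypothetical sanity check of the documented layout; returns the space dimension."""
    if len(supports) != len(labels):
        raise ValueError("expected one label per distribution")
    dims = {s.shape[1] for s in supports}
    if len(dims) != 1:
        raise ValueError("all supports must live in the same space")
    if weights is not None:
        for s, w in zip(supports, weights):
            if w.shape != (len(s),) or not np.isclose(w.sum(), 1.0):
                raise ValueError("each weight vector must match its support and sum to 1")
    return dims.pop()


print(check_inputs(supports, distribution_labels, weights))  # 2

# Shared-support variant: every distribution reuses the same component locations
# and differs only in its weights (the situation identical_supports=True describes).
shared = rng.normal(size=(4, 2))
shared_supports = [shared] * 3
shared_weights = [rng.dirichlet(np.ones(4)) for _ in range(3)]
print(check_inputs(shared_supports, distribution_labels, shared_weights))  # 2

# With ggml_ot installed, the arrays would be passed to the documented constructor, e.g.
# ggml_ot.data.TripletDataset(supports, distribution_labels, n_triplets=3, weights=weights)
```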
Notes
The class generates triplets by sampling, for every distribution, t = n_triplets “positive” neighbors from the same class and t “negative” neighbors from each other class.
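The sampling rule in the Notes can be sketched in pure Python. This is a hypothetical re-implementation, not the library's sampler: it assumes sampling with replacement (so small classes still yield t draws), skips anchors whose class has no other member, and pairs the q-th positive with the q-th negative drawn from each other class. How ggml_ot actually combines positives and negatives into triplets is not stated here.

```python
import random
from collections import defaultdict
from typing import List, Sequence, Tuple


def sample_triplets(labels: Sequence[int], t: int, seed: int = 0) -> List[Tuple[int, int, int]]:
    """Hypothetical sampler: t positives and t negatives per other class, for each anchor."""
    rng = random.Random(seed)
    by_class = defaultdict(list)
    for idx, lab in enumerate(labels):
        by_class[lab].append(idx)

    triplets = []
    for i, lab in enumerate(labels):
        same = [j for j in by_class[lab] if j != i]
        if not same:
            continue  # no positive partner for a singleton class
        positives = rng.choices(same, k=t)  # assumption: with replacement
        for other, members in by_class.items():
            if other == lab:
                continue
            negatives = rng.choices(members, k=t)
            # Assumption: pair the q-th positive with the q-th negative.
            triplets.extend(zip([i] * t, positives, negatives))
    return triplets


labels = [0, 0, 0, 1, 1, 2]
print(len(sample_triplets(labels, t=2)))  # 20
```

Under these assumptions, each anchor contributes t times the number of other classes, so the total grows with the number of classes as well as with n_triplets.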
- __init__(supports, distribution_labels, n_triplets=3, weights=None, covariances=None, identical_supports=False, **kwargs)
Methods
- __init__(supports, distribution_labels[, ...])
- compute_OT([precomputed_distances, ...]): Compute the Optimal Transport distances between all distributions.
- fit_gmm(*[, component_sharing, k_comps, ...]): Fit per-patient GMM parameters from a GGML dataset.
- normalize()
- test([ground_metric, knn_k]): Test the ground metric on a given dataset.
- train([alpha, reg, reg_type, n_comps, lr, ...]): Perform supervised optimal transport by ground metric learning.
- train_emd2([alpha, reg, reg_type, n_comps, ...]): Train GGML with exact OT (EMD2).
- train_sinkhorn([alpha, reg, reg_type, ...]): Train GGML with Sinkhorn-regularized OT.
- train_test([n_splits, train_size, ...]): Train and cross-validate ground metrics on train-test splits.
- tune([alpha, reg, reg_type, n_comps, ...]): Tune hyperparameters by performing a grid search with cross-validation.
- validate_gmm(*[, gmm_key, patient_col, ...]): Run a one-call hold-out validation pass for a fitted GMM.
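compute_OT and train_sinkhorn are listed without implementation details. As a rough sketch of what a Sinkhorn-regularized OT distance under a learned linear map involves, the code below assumes the map is a matrix A applied as x ↦ A x and that the ground cost is the squared Euclidean distance between mapped points. The helpers `ground_cost`, `sinkhorn` and `pairwise_ot` are hypothetical; this is the textbook Sinkhorn scaling algorithm, not the ggml_ot code.

```python
import numpy as np


def ground_cost(X: np.ndarray, Y: np.ndarray, A: np.ndarray) -> np.ndarray:
    """Squared Euclidean cost between rows of X and Y after mapping each point x to A @ x."""
    XA, YA = X @ A.T, Y @ A.T
    diff = XA[:, None, :] - YA[None, :, :]
    return (diff ** 2).sum(axis=-1)


def sinkhorn(a: np.ndarray, b: np.ndarray, C: np.ndarray, reg: float = 0.1, n_iter: int = 500):
    """Entropic OT via Sinkhorn scaling; returns (transport cost <P, C>, plan P)."""
    K = np.exp(-C / reg)      # Gibbs kernel
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)     # rescale to match column marginals b
        u = a / (K @ v)       # rescale to match row marginals a
    P = u[:, None] * K * v[None, :]
    return float((P * C).sum()), P


def pairwise_ot(supports, weights, A, reg=0.1):
    """Symmetric matrix of Sinkhorn costs between all distributions (diagonal set to 0)."""
    n = len(supports)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            C = ground_cost(supports[i], supports[j], A)
            D[i, j] = D[j, i] = sinkhorn(weights[i], weights[j], C, reg)[0]
    return D


X = np.array([[0.0, 0.0], [1.0, 0.0]])
a = np.full(2, 0.5)
A = np.eye(2)  # identity map: plain squared Euclidean ground cost
cost, P = sinkhorn(a, a, ground_cost(X, X, A))
print(round(cost, 4))  # 0.0: a distribution is (almost) at zero distance from itself
```

Small values of reg approach the exact (EMD2) cost but can underflow in exp(-C / reg); production code typically switches to log-domain updates in that regime.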
Attributes
- covariances
- dim: Dimension of the space underlying the distributions.
- distribution_labels: Integer class labels for each distribution.
- distribution_labels_str
- identical_supports: Flag as passed to the constructor.
- map_A: Learned ground metric as a linear map (raises a warning if the dataset is not trained yet).
- points
- points_labels: List of the distribution labels of all points, concatenated over all distributions.
- supports: Stored supports.
- w_theta: Use map_A instead.
- weights: Stored per-distribution weights (if provided).
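The points_labels description implies that each distribution's label is repeated once per point it contributes. The pure-Python sketch below mirrors that, under the assumption (not stated for the undocumented points attribute) that points is the concatenation of all supports in the same order. `concat_points` is a hypothetical name.

```python
from typing import List, Sequence


def concat_points(supports: Sequence[Sequence], distribution_labels: Sequence[int]):
    """Hypothetical mirror of points / points_labels: flatten all supports and
    repeat each distribution's label once per point it contributes."""
    points: List = [p for support in supports for p in support]
    points_labels: List[int] = [
        label for support, label in zip(supports, distribution_labels) for _ in support
    ]
    return points, points_labels


supports = [[(0.0, 0.0), (1.0, 1.0)], [(2.0, 2.0)], [(3.0, 3.0), (4.0, 4.0), (5.0, 5.0)]]
points, points_labels = concat_points(supports, [0, 1, 0])
print(points_labels)  # [0, 0, 1, 0, 0, 0]
```

With NumPy arrays, the same result follows from np.concatenate(supports) and np.repeat(distribution_labels, [len(s) for s in supports]).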