ggml_ot.data.TripletDataset


class ggml_ot.data.TripletDataset(supports, distribution_labels, n_triplets=3, weights=None, covariances=None, identical_supports=False, **kwargs)

Dataset for training GGML on array data.

This class stores a collection of distributions (“supports”) and generates triplets (i, j, k) that encode relative relationships: i and j belong to the same class, and k belongs to a different class. GGML is trained on these triplets so that distributions i and j are closer to each other than j and k by at least a margin alpha.

This class exposes the dataset to the standardized interfaces used by ggml_ot.train(), ggml_ot.tune(), ggml_ot.test() and ggml_ot.train_test().

Parameters:
supports : Sequence[np.ndarray]

Sequence of per-distribution supports. Each element is an array of points (for empirical distributions) or component means (for GMM-style representations).

distribution_labels : Sequence[int] | np.ndarray

Integer labels identifying the class/group of each distribution.

n_triplets : int, optional

Number of triplets to generate per “anchor” distribution (default: 3).

weights : Sequence[np.ndarray] | None, optional

Per-distribution probability weights (e.g., cluster proportions) or None for uniform weights (default: None).

covariances : Sequence[np.ndarray] | None, optional

Optional per-distribution covariance arrays when supports represent Gaussian mixture components (default: None).

identical_supports : bool, optional

If True, indicates that all distributions share the same supports (e.g., identical component locations across distributions). This changes the __getitem__ return format and allows faster OT evaluation (default: False).
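The parameters above can be assembled from plain NumPy arrays. The sketch below builds toy inputs for four 2-D empirical distributions and runs a few consistency checks. It assumes each support is an (n_points, dim) array and that weights=None corresponds to uniform weights. It also rebuilds what the points and points_labels attributes are described as holding (all points concatenated, one distribution label per point); that reconstruction is an assumption, not the library's code.

```python
import numpy as np

rng = np.random.default_rng(0)

# Four empirical distributions in 2-D with different numbers of points;
# the first two are centred at 0 (class 0), the last two at 3 (class 1).
supports = [rng.normal(loc=c, size=(n, 2)) for c, n in [(0, 20), (0, 25), (3, 30), (3, 15)]]
distribution_labels = np.array([0, 0, 1, 1])

# Uniform weights per distribution (assumed equivalent to weights=None).
weights = [np.full(len(s), 1.0 / len(s)) for s in supports]

# Consistency checks one might run before constructing the dataset.
assert len(supports) == len(distribution_labels) == len(weights)
dim = supports[0].shape[1]
assert all(s.ndim == 2 and s.shape[1] == dim for s in supports)
assert all(np.isclose(w.sum(), 1.0) for w in weights)

# Concatenated points and one distribution label per point.
points = np.concatenate(supports, axis=0)
points_labels = np.repeat(distribution_labels, [len(s) for s in supports])
print(dim, points.shape, points_labels.shape)  # 2 (90, 2) (90,)
```

With ggml_ot installed, these arrays would be passed to the documented constructor as TripletDataset(supports, distribution_labels, n_triplets=3, weights=weights). Setting identical_supports=True would only be appropriate if every element of supports were the same array of locations.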

Notes

  • For every distribution, the class generates triplets by sampling t (n_triplets) “positive” neighbors from the same class and t “negative” neighbors from each of the other classes.
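One way to read the sampling rule in the Notes is sketched below. It assumes t corresponds to n_triplets, excludes the anchor itself from its positives, samples with replacement only when a class has fewer than t candidates, and pairs the m-th positive with the m-th negative of each other class. These pairing and replacement choices are assumptions, and sample_triplets is a hypothetical helper, not the class's implementation.

```python
import numpy as np


def sample_triplets(labels, t=3, seed=0):
    """Hypothetical sampler: for each anchor i, draw t same-class positives j
    and t negatives k from every other class, then pair them index by index."""
    labels = np.asarray(labels)
    rng = np.random.default_rng(seed)
    classes = np.unique(labels)
    idx = np.arange(len(labels))
    triplets = []
    for i, yi in enumerate(labels):
        same = np.flatnonzero((labels == yi) & (idx != i))
        if len(same) == 0:
            continue  # singleton class: no positive available for this anchor
        # Sample with replacement only if there are fewer than t candidates.
        pos = rng.choice(same, size=t, replace=len(same) < t)
        for c in classes:
            if c == yi:
                continue
            other = np.flatnonzero(labels == c)
            neg = rng.choice(other, size=t, replace=len(other) < t)
            # Pair the m-th positive with the m-th negative of this class.
            triplets.extend((i, int(j), int(k)) for j, k in zip(pos, neg))
    return np.array(triplets, dtype=int).reshape(-1, 3)


labels = [0, 0, 0, 1, 1, 2, 2]
T = sample_triplets(labels, t=2)
print(T.shape)  # (28, 3)
```

Under this reading each anchor yields t triplets per other class, so the 7 anchors with 2 other classes each and t = 2 give 28 triplets.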

__init__(supports, distribution_labels, n_triplets=3, weights=None, covariances=None, identical_supports=False, **kwargs)

Methods

__init__(supports, distribution_labels[, ...])

compute_OT([precomputed_distances, ...])

Compute the Optimal Transport distances between all distributions.

fit_gmm(*[, component_sharing, k_comps, ...])

Fit per-patient GMM parameters from a GGML dataset.

normalize()

test([ground_metric, knn_k])

Test the ground metric on a given dataset.

train([alpha, reg, reg_type, n_comps, lr, ...])

Perform supervised optimal transport by ground metric learning.

train_emd2([alpha, reg, reg_type, n_comps, ...])

Train GGML with exact OT (EMD2).

train_sinkhorn([alpha, reg, reg_type, ...])

Train GGML with Sinkhorn-regularized OT.

train_test([n_splits, train_size, ...])

Train and cross-validate ground metrics on train-test splits.

tune([alpha, reg, reg_type, n_comps, ...])

Tune hyperparameters by performing a grid search with cross-validation.

validate_gmm(*[, gmm_key, patient_col, ...])

Run a one-call hold-out validation pass for a fitted GMM.
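The methods above are documented only by their summaries. The sketch below illustrates, under stated assumptions, what a pairwise OT matrix between distributions might look like when the ground cost is induced by a linear map A (compare map_A), and how a leave-one-out k-NN score in the spirit of test(knn_k=...) could be read off it. It assumes a squared cost ||A x - A y||^2 and uniform weights, and it uses plain Sinkhorn scaling in NumPy instead of an OT library. All function names are hypothetical; the library's actual solver, cost normalization and evaluation protocol may differ.

```python
import numpy as np


def ground_cost(X, Y, A):
    """Squared cost ||A x - A y||^2 between every row x of X and row y of Y."""
    XA, YA = X @ A.T, Y @ A.T
    diff = XA[:, None, :] - YA[None, :, :]
    return np.sum(diff ** 2, axis=-1)


def sinkhorn_cost(a, b, C, reg=1.0, n_iter=500):
    """Cost <P, C> of the entropy-regularized OT plan P (Sinkhorn scaling)."""
    K = np.exp(-C / reg)
    u = np.ones_like(a)
    for _ in range(n_iter):
        v = b / (K.T @ u)
        u = a / (K @ v)
    P = u[:, None] * K * v[None, :]
    return float(np.sum(P * C))


def pairwise_ot(supports, A, reg=1.0):
    """Symmetric matrix of Sinkhorn costs between uniformly weighted supports."""
    n = len(supports)
    D = np.zeros((n, n))
    for i in range(n):
        for j in range(i + 1, n):
            a = np.full(len(supports[i]), 1.0 / len(supports[i]))
            b = np.full(len(supports[j]), 1.0 / len(supports[j]))
            C = ground_cost(supports[i], supports[j], A)
            D[i, j] = D[j, i] = sinkhorn_cost(a, b, C, reg=reg)
    return D


def knn_accuracy(D, labels, k=1):
    """Leave-one-out k-NN accuracy of distribution labels given distances D."""
    labels = np.asarray(labels)
    correct = 0
    for i in range(len(labels)):
        order = [j for j in np.argsort(D[i]) if j != i]  # exclude the query itself
        votes = np.bincount(labels[order[:k]])
        correct += int(np.argmax(votes) == labels[i])
    return correct / len(labels)


rng = np.random.default_rng(0)
supports = [rng.normal(loc=c, size=(10, 2)) for c in (0, 0, 0, 3, 3, 3)]
labels = [0, 0, 0, 1, 1, 1]
A = np.eye(2)  # identity map: plain squared Euclidean ground cost
D = pairwise_ot(supports, A)
print(np.round(D, 2))
print("1-NN accuracy:", knn_accuracy(D, labels, k=1))
```

With the identity map the ground cost reduces to squared Euclidean distance; substituting a learned map for A changes only ground_cost. Plain Sinkhorn scaling is used for brevity; it can underflow for large costs or small reg, which is one reason log-domain solvers are common in practice.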

Attributes

covariances

dim

Dimension of the space underlying the distributions.

distribution_labels

Integer class labels for each distribution.

distribution_labels_str

identical_supports

Flag as passed to the constructor.

map_A

Learned ground metric as a linear map (emits a warning if the dataset has not been trained yet).

points

points_labels

List of distribution labels for all points, concatenated over all distributions (one label per point).

supports

Stored supports.

w_theta

Deprecated; use map_A instead.

weights

Stored per-distribution weights (if provided).
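A ground metric given as a linear map A corresponds to a Mahalanobis distance with the positive semidefinite matrix M = A^T A, since ||A(x - y)||^2 = (x - y)^T A^T A (x - y); a map with fewer rows than dim yields a low-rank metric. The check below illustrates this identity numerically with a random, hypothetical A. Whether map_A is stored in exactly this orientation (rows as output dimensions) is an assumption.

```python
import numpy as np

rng = np.random.default_rng(1)
A = rng.normal(size=(2, 3))  # hypothetical learned map from R^3 to R^2
x, y = rng.normal(size=3), rng.normal(size=3)

# Distance measured in the mapped space.
d_map = np.linalg.norm(A @ x - A @ y)

# The same distance written as a Mahalanobis distance with M = A^T A.
M = A.T @ A
d_mahal = np.sqrt((x - y) @ M @ (x - y))

print(np.isclose(d_map, d_mahal))  # True
```

Learning A rather than M directly keeps the induced metric positive semidefinite without an explicit projection step.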