wsingular.distance
- wsingular.distance.sinkhorn_map(A: torch.Tensor, C: torch.Tensor, eps: float, dtype: torch.dtype, device: str, R: Optional[torch.Tensor] = None, tau: float = 0, progress_bar: bool = False, stop_threshold: float = 1e-05, num_iter_max: int = 500) → torch.Tensor
This function maps a ground cost to the matrix of pairwise Sinkhorn divergences between the samples of a dataset, computed with that ground cost. R is an optional additional regularization term, controlled by tau.
- Parameters:
A (torch.Tensor) – The input dataset, rows as samples.
C (torch.Tensor) – The ground cost.
eps (float) – The entropic regularization parameter.
dtype (torch.dtype) – The dtype.
device (str) – The device.
R (torch.Tensor) – The regularization matrix. Defaults to None.
tau (float) – The regularization parameter. Defaults to 0.
progress_bar (bool) – Whether to show a progress bar during the computation. Defaults to False.
stop_threshold (float, optional) – Stopping threshold for Sinkhorn (please refer to POT). Defaults to 1e-5.
num_iter_max (int, optional) – Maximum number of Sinkhorn iterations (please refer to POT). Defaults to 500.
- Returns:
The pairwise Sinkhorn divergence matrix.
- Return type:
torch.Tensor
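To make the computation concrete, here is a minimal NumPy sketch of a pairwise Sinkhorn divergence matrix. It is not the wsingular implementation. It assumes that the rows of A are histograms, uses plain (non-log-domain) Sinkhorn iterations with a simple stopping rule on the change of the scaling vector, takes OT_eps(a, b) to be the transport cost <P, C> of the entropic plan, assumes the debiased form OT_eps(a, b) - (OT_eps(a, a) + OT_eps(b, b)) / 2, and assumes R enters as tau * R. The names sinkhorn_cost and sinkhorn_divergence_map are hypothetical, and the dtype/device arguments are omitted.

```python
import numpy as np


def sinkhorn_cost(a, b, C, eps, stop_threshold=1e-5, num_iter_max=500):
    """Transport cost <P, C> of the entropic plan between histograms a and b."""
    K = np.exp(-C / eps)  # Gibbs kernel
    u, v = np.ones(len(a)), np.ones(len(b))
    for _ in range(num_iter_max):
        u_prev = u
        u = a / (K @ v)    # match the row marginal a
        v = b / (K.T @ u)  # match the column marginal b
        # Stop once the scaling vector u barely changes between iterations.
        if np.max(np.abs(u - u_prev)) < stop_threshold:
            break
    P = u[:, None] * K * v[None, :]  # entropic transport plan
    return float(np.sum(P * C))


def sinkhorn_divergence_map(A, C, eps, R=None, tau=0.0, **sinkhorn_kw):
    """Pairwise (assumed debiased) Sinkhorn divergences between rows of A, plus tau * R.

    sinkhorn_kw forwards stop_threshold / num_iter_max to sinkhorn_cost.
    """
    m = A.shape[0]
    # Self-transport terms OT_eps(a_i, a_i), computed once and reused for every pair.
    self_cost = [sinkhorn_cost(A[i], A[i], C, eps, **sinkhorn_kw) for i in range(m)]
    D = np.zeros((m, m))
    for i in range(1, m):
        for j in range(i):
            cross = sinkhorn_cost(A[i], A[j], C, eps, **sinkhorn_kw)
            # Fill both triangles: the divergence is symmetric.
            D[i, j] = D[j, i] = cross - 0.5 * (self_cost[i] + self_cost[j])
    if R is not None:
        D = D + tau * R  # assumed form of the added regularization
    return D


# Usage: three histograms on the bins x = 0, 1, 2 with a squared-distance cost.
x = np.arange(3.0)
C = (x[:, None] - x[None, :]) ** 2
A = np.array([[0.8, 0.1, 0.1],
              [0.1, 0.1, 0.8],
              [0.8, 0.1, 0.1]])  # rows 0 and 2 are identical
D = sinkhorn_divergence_map(A, C, eps=0.5)
```

Plain Sinkhorn iterations can underflow when eps is small relative to C; a log-domain variant is the usual remedy in that regime.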
- wsingular.distance.stochastic_sinkhorn_map(A: torch.Tensor, D: torch.Tensor, C: torch.Tensor, sample_prop: float, gamma: float, eps: float, R: Optional[torch.Tensor] = None, tau: float = 0, progress_bar: bool = False, return_indices: bool = False, batch_size: int = 50, stop_threshold: float = 1e-05, num_iter_max: int = 100) → torch.Tensor
Returns the stochastic Sinkhorn divergence map: only a random subset of the entries of D is recomputed, and the other entries are left unchanged.
- Parameters:
A (torch.Tensor) – The input dataset, rows as samples.
D (torch.Tensor) – The initialization of the distance matrix.
C (torch.Tensor) – The ground cost.
sample_prop (float) – The proportion of indices to update.
gamma (float) – Rescaling parameter. In practice, one should rescale by an approximation of the singular value.
eps (float) – The entropic regularization parameter.
R (torch.Tensor) – The regularization matrix. Defaults to None.
tau (float) – The regularization parameter. Defaults to 0.
progress_bar (bool) – Whether to show a progress bar during the computation. Defaults to False.
return_indices (bool) – Whether to return the updated indices. Defaults to False.
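batch_size (int) – Batch size, i.e. how many distances to compute at the same time. Depends on your available GPU memory. Defaults to 50.
stop_threshold (float, optional) – Stopping threshold for Sinkhorn (please refer to POT). Defaults to 1e-5.
num_iter_max (int, optional) – Maximum number of Sinkhorn iterations (please refer to POT). Defaults to 100.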
- Returns:
The stochastically updated distance matrix.
- Return type:
torch.Tensor
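A minimal NumPy sketch of the stochastic update, under several assumptions that the reference above does not spell out: the pairs to update are drawn uniformly without replacement from the strict upper triangle (a proportion sample_prop of them), each recomputed entry is the debiased Sinkhorn divergence divided by gamma plus tau * R, and batch_size pairs go through one vectorized Sinkhorn loop. The names batched_sinkhorn_cost and stochastic_sinkhorn_sketch and the rng argument are hypothetical; this illustrates the idea rather than reproducing wsingular's code.

```python
import numpy as np


def batched_sinkhorn_cost(a, b, C, eps, stop_threshold=1e-5, num_iter_max=100):
    """Entropic transport costs <P_k, C> for a batch of histogram pairs (a[k], b[k])."""
    K = np.exp(-C / eps)                     # Gibbs kernel, shape (n, n)
    u, v = np.ones_like(a), np.ones_like(b)  # scalings, shape (B, n)
    for _ in range(num_iter_max):
        u_prev = u
        u = a / (v @ K.T)  # row k equals a[k] / (K @ v[k])
        v = b / (u @ K)    # row k equals b[k] / (K.T @ u[k])
        # Stop when the largest change over the whole batch is small.
        if np.max(np.abs(u - u_prev)) < stop_threshold:
            break
    P = u[:, :, None] * K[None] * v[:, None, :]  # plans, shape (B, n, n)
    return (P * C[None]).sum(axis=(1, 2))


def stochastic_sinkhorn_sketch(A, D, C, sample_prop, gamma, eps, R=None, tau=0.0,
                               return_indices=False, batch_size=50, rng=None,
                               **sinkhorn_kw):
    """Recompute a random fraction of the off-diagonal pairs of D; keep the rest."""
    rng = np.random.default_rng() if rng is None else rng
    m = A.shape[0]
    ii, jj = np.triu_indices(m, k=1)                # candidate pairs i < j
    n_sample = max(1, int(sample_prop * len(ii)))  # always update at least one pair
    pick = rng.choice(len(ii), size=n_sample, replace=False)
    ii, jj = ii[pick], jj[pick]

    D_new = D.copy()  # unsampled entries keep their previous values
    for s in range(0, n_sample, batch_size):  # batch_size pairs at a time
        i_b, j_b = ii[s:s + batch_size], jj[s:s + batch_size]
        cross = batched_sinkhorn_cost(A[i_b], A[j_b], C, eps, **sinkhorn_kw)
        self_i = batched_sinkhorn_cost(A[i_b], A[i_b], C, eps, **sinkhorn_kw)
        self_j = batched_sinkhorn_cost(A[j_b], A[j_b], C, eps, **sinkhorn_kw)
        vals = (cross - 0.5 * (self_i + self_j)) / gamma  # assumed rescaling
        if R is not None:
            vals = vals + tau * R[i_b, j_b]  # assumed form of the regularization
        D_new[i_b, j_b] = D_new[j_b, i_b] = vals
    return (D_new, (ii, jj)) if return_indices else D_new


# Usage: 4 histograms on bins 0..3; entries still equal to -1 were not updated.
x = np.arange(4.0)
C = (x[:, None] - x[None, :]) ** 2
A = np.full((4, 4), 0.1) + 0.6 * np.eye(4)  # row i puts mass 0.7 on bin i
D0 = np.full((4, 4), -1.0)
np.fill_diagonal(D0, 0.0)
D1, (ii, jj) = stochastic_sinkhorn_sketch(A, D0, C, sample_prop=0.5, gamma=2.0, eps=0.5,
                                          return_indices=True, batch_size=2,
                                          rng=np.random.default_rng(0))
```

Because the stopping test uses the maximum change over the whole batch, every pair in a batch runs for the same number of iterations; larger batches trade memory for fewer Python-level loops, which mirrors the GPU-memory trade-off described for batch_size.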
- wsingular.distance.stochastic_wasserstein_map(A: torch.Tensor, D: torch.Tensor, C: torch.Tensor, sample_prop: float, gamma: float, dtype: torch.dtype, device: str, R: Optional[torch.Tensor] = None, tau: float = 0, progress_bar: bool = False, return_indices: bool = False) → torch.Tensor
Returns the stochastic Wasserstein map: only a random subset of the entries of D is recomputed, and the other entries are left unchanged.
- Parameters:
A (torch.Tensor) – The input dataset, rows as samples.
D (torch.Tensor) – The initialization of the distance matrix.
C (torch.Tensor) – The ground cost.
sample_prop (float) – The proportion of indices to update.
gamma (float) – A scaling factor.
dtype (torch.dtype) – The dtype.
device (str) – The device.
R (torch.Tensor) – The regularization matrix. Defaults to None.
tau (float) – The regularization parameter. Defaults to 0.
progress_bar (bool) – Whether to show a progress bar during the computation. Defaults to False.
return_indices (bool) – Whether to return the updated indices. Defaults to False.
- Returns:
The stochastically updated distance matrix.
- Return type:
torch.Tensor
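Exact Wasserstein distances for a general ground cost C require a linear-programming OT solver. To keep this sketch self-contained, it assumes instead that every row of A is a histogram on the same uniform 1-D grid with spacing h and ground cost C_ij = |x_i - x_j|, where W1 has the closed form h * sum_k |F_a(k) - F_b(k)|, with F the cumulative sums. The pair sampling, the division by gamma and the tau * R term are assumptions, and w1_on_grid, stochastic_w1_map and the rng argument are hypothetical names.

```python
import numpy as np


def w1_on_grid(a, b, spacing=1.0):
    """Exact W1 between histograms on a uniform 1-D grid (ground cost |x_i - x_j|)."""
    # On the line, W1 is the L1 distance between the two cumulative distributions.
    return spacing * float(np.abs(np.cumsum(a - b)).sum())


def stochastic_w1_map(A, D, sample_prop, gamma, R=None, tau=0.0,
                      return_indices=False, spacing=1.0, rng=None):
    """Recompute a random fraction of the pairs of D with exact 1-D W1 distances."""
    rng = np.random.default_rng() if rng is None else rng
    m = A.shape[0]
    ii, jj = np.triu_indices(m, k=1)                # candidate pairs i < j
    n_sample = max(1, int(sample_prop * len(ii)))
    pick = rng.choice(len(ii), size=n_sample, replace=False)
    ii, jj = ii[pick], jj[pick]

    D_new = D.copy()  # pairs that are not sampled keep their previous values
    for i, j in zip(ii, jj):
        val = w1_on_grid(A[i], A[j], spacing) / gamma  # assumed rescaling
        if R is not None:
            val += tau * R[i, j]  # assumed form of the regularization
        D_new[i, j] = D_new[j, i] = val
    return (D_new, (ii, jj)) if return_indices else D_new


# Usage: Dirac masses at grid points 0..3, so W1(delta_i, delta_j) = |i - j|.
A = np.eye(4)
D_full = stochastic_w1_map(A, np.zeros((4, 4)), sample_prop=1.0, gamma=2.0)
D0 = np.full((4, 4), -1.0)
np.fill_diagonal(D0, 0.0)
D_half = stochastic_w1_map(A, D0, sample_prop=0.5, gamma=2.0,
                           rng=np.random.default_rng(0))
```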
- wsingular.distance.wasserstein_map(A: torch.Tensor, C: torch.Tensor, dtype: torch.dtype, device: str, R: Optional[torch.Tensor] = None, tau: float = 0, progress_bar: bool = False) → torch.Tensor
This function maps a ground cost to the matrix of pairwise Wasserstein distances between the samples of a dataset, computed with that ground cost. R is an optional additional regularization term, controlled by tau.
- Parameters:
A (torch.Tensor) – The input dataset, rows as samples.
C (torch.Tensor) – The ground cost.
dtype (torch.dtype) – The dtype.
device (str) – The device.
R (torch.Tensor) – The regularization matrix. Defaults to None.
tau (float) – The regularization parameter. Defaults to 0.
progress_bar (bool) – Whether to show a progress bar during the computation. Defaults to False.
- Returns:
The Wasserstein distance matrix with regularization.
- Return type:
torch.Tensor
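Assuming every row of A is a histogram on a shared uniform 1-D grid with ground cost |x_i - x_j|, the full pairwise Wasserstein (W1) matrix reduces to L1 distances between cumulative sums, which NumPy can compute in a single broadcast. This is a hypothetical sketch (w1_map_on_grid is not part of wsingular, and dtype/device are omitted); a general ground cost C needs an exact OT solver instead. The tau * R term is likewise an assumed form of the regularization.

```python
import numpy as np


def w1_map_on_grid(A, spacing=1.0, R=None, tau=0.0):
    """Pairwise exact W1 matrix between the rows of A, plus tau * R.

    Assumes every row is a histogram on the same uniform 1-D grid with
    ground cost |x_i - x_j|, where W1 is the L1 distance between CDFs.
    """
    F = np.cumsum(A, axis=1)  # row-wise cumulative distributions, shape (m, n)
    # Broadcast to shape (m, m, n), then sum |F_i - F_j| over the bins.
    D = spacing * np.abs(F[:, None, :] - F[None, :, :]).sum(axis=-1)
    if R is not None:
        D = D + tau * R  # assumed form of the added regularization
    return D


# Usage: three histograms on the grid points 0, 1, 2, 3.
A = np.array([[0.5, 0.0, 0.0, 0.5],
              [0.0, 0.5, 0.5, 0.0],
              [1.0, 0.0, 0.0, 0.0]])
D = w1_map_on_grid(A)
```

The broadcast materializes an (m, m, n) array, so memory grows as m^2 n; for larger datasets, looping over rows and computing |F_i - F| one row at a time keeps the temporary at O(m n).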