scipy.stats.wasserstein_distance
scipy.stats.wasserstein_distance(u_values, v_values, u_weights=None, v_weights=None)
Compute the first Wasserstein distance between two 1D distributions.
This distance is also known as the earth mover’s distance, since it can be seen as the minimum amount of “work” required to transform $u$ into $v$, where “work” is measured as the amount of distribution weight that must be moved, multiplied by the distance it has to be moved.
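To make the notion of “work” concrete, here is a minimal pure-Python sketch. It assumes a transport plan is written as a list of `(source, target, mass)` triples; the helper name `transport_work` is hypothetical and not part of SciPy. It scores two plans that move half-unit masses at 0 and 2 onto 1 and 3; the distance is the smallest work over all valid plans.

```python
def transport_work(plan):
    """Hypothetical helper: total work of a transport plan.

    Each entry of `plan` is (source, target, mass); the work of one move is
    the mass moved multiplied by the distance it travels, and the plan's
    work is the sum over all moves.
    """
    return sum(mass * abs(target - source) for source, target, mass in plan)


# u puts mass 1/2 at 0 and at 2; v puts mass 1/2 at 1 and at 3.
parallel_plan = [(0, 1, 0.5), (2, 3, 0.5)]  # 0 -> 1 and 2 -> 3
crossing_plan = [(0, 3, 0.5), (2, 1, 0.5)]  # 0 -> 3 and 2 -> 1

print(transport_work(parallel_plan))  # 1.0
print(transport_work(crossing_plan))  # 2.0
```

The parallel plan is the cheaper of the two, and for these inputs it is optimal: the first Wasserstein distance between the two distributions is 1.0.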
New in version 1.0.0.
Parameters
u_values, v_values : array_like
    Values observed in the (empirical) distribution.
u_weights, v_weights : array_like, optional
    Weight for each value. If unspecified, each value is assigned the same weight. u_weights (resp. v_weights) must have the same length as u_values (resp. v_values). If the weight sum differs from 1, it must still be positive and finite so that the weights can be normalized to sum to 1.
Returns
distance : float
    The computed distance between the distributions.
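The weight rules in the Parameters section can be sketched as follows. This is a minimal illustration, not SciPy's internal code: `normalize_weights` is a hypothetical helper that applies only the checks stated there (equal lengths, a positive and finite sum).

```python
import math


def normalize_weights(values, weights=None):
    """Hypothetical helper: return weights that sum to 1.

    Mirrors the documented rules: missing weights mean equal weights, the
    lengths must match, and the sum must be positive and finite.
    """
    if weights is None:
        # Unspecified weights: every value gets the same share.
        return [1.0 / len(values)] * len(values)
    if len(weights) != len(values):
        raise ValueError("weights must have the same length as values")
    total = math.fsum(weights)
    if not (math.isfinite(total) and total > 0):
        raise ValueError("weight sum must be positive and finite")
    return [w / total for w in weights]


print(normalize_weights([0, 1], [3, 1]))  # [0.75, 0.25]
print(normalize_weights([0, 1], [6, 2]))  # [0.75, 0.25]
```

Because only the normalized weights matter, scaling all weights of one distribution by a positive constant leaves the distance unchanged.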
Notes
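The first Wasserstein distance between the distributions $u$ and $v$ is:
$$l_1(u, v) = \inf_{\pi \in \Gamma(u, v)} \int_{\mathbb{R} \times \mathbb{R}} |x - y| \, \mathrm{d}\pi(x, y)$$
where $\Gamma(u, v)$ is the set of (probability) distributions on $\mathbb{R} \times \mathbb{R}$ whose marginals are $u$ and $v$ on the first and second factors respectively.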
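For a small case, the infimum over couplings in the definition above can be checked by brute force. This is a minimal sketch under an explicit assumption: both distributions are uniform over the same number n of points. The coupling set is then a scaled Birkhoff polytope whose vertices are permutation matrices, so the linear transport cost attains its minimum at one of the n! matchings. The name `coupling_w1_uniform` is hypothetical; for general weights one would solve a linear program instead.

```python
from itertools import permutations


def coupling_w1_uniform(u_values, v_values):
    """Hypothetical helper: brute-force l_1(u, v) over all matchings.

    Assumes u and v each place mass 1/n on n points, so an optimal coupling
    can be taken to be a permutation (each u point sent to one v point).
    Runs in O(n! * n) time, so it is only suitable for tiny inputs.
    """
    n = len(u_values)
    if n == 0 or n != len(v_values):
        raise ValueError("this sketch needs two non-empty samples of equal size")
    best = float("inf")
    for perm in permutations(range(n)):
        # Each matched pair carries mass 1/n over a distance |x - y|.
        cost = sum(abs(u_values[i] - v_values[j]) for i, j in enumerate(perm)) / n
        best = min(best, cost)
    return best


print(coupling_w1_uniform([0, 1, 3], [5, 6, 8]))  # 5.0
```

In one dimension the best matching pairs the sorted values in order, which is consistent with the first example in the Examples section.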
If $U$ and $V$ are the respective CDFs of $u$ and $v$, this distance also equals:
$$l_1(u, v) = \int_{-\infty}^{+\infty} |U(x) - V(x)| \, \mathrm{d}x$$
See [2] for a proof of the equivalence of both definitions.
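The CDF form suggests a direct computation for discrete inputs: both CDFs are step functions that change only at the observed values, so the integral reduces to a finite sum over the gaps between consecutive sorted values. The sketch below is a pure-Python illustration of that idea, assuming non-empty inputs and positive, finite weight sums as required in the Parameters section. `cdf_w1` and `_step_cdf` are hypothetical names, and this is not claimed to be SciPy's implementation.

```python
from bisect import bisect_right
from itertools import accumulate


def _step_cdf(values, weights):
    """Return a function x -> P(X <= x) for a weighted discrete distribution."""
    if weights is None:
        weights = [1.0] * len(values)
    pairs = sorted(zip(values, weights))
    points = [x for x, _ in pairs]
    cumulative = list(accumulate(w for _, w in pairs))
    total = cumulative[-1]  # normalizing constant; assumed positive and finite

    def cdf(x):
        k = bisect_right(points, x)  # number of support points <= x
        return cumulative[k - 1] / total if k else 0.0

    return cdf


def cdf_w1(u_values, v_values, u_weights=None, v_weights=None):
    """Hypothetical helper: l_1(u, v) as the integral of |U - V|."""
    U = _step_cdf(u_values, u_weights)
    V = _step_cdf(v_values, v_weights)
    breakpoints = sorted(list(u_values) + list(v_values))
    distance = 0.0
    for left, right in zip(breakpoints, breakpoints[1:]):
        # Both CDFs are constant on [left, right), so the integral is exact.
        distance += abs(U(left) - V(left)) * (right - left)
    return distance


print(cdf_w1([0, 1, 3], [5, 6, 8]))
print(cdf_w1([0, 1], [0, 1], [3, 1], [2, 2]))
print(cdf_w1([3.4, 3.9, 7.5, 7.8], [4.5, 1.4], [1.4, 0.9, 3.1, 7.2], [3.2, 3.5]))
```

These three calls agree with the values in the Examples section up to floating-point rounding.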
The input distributions can be empirical, i.e. derived from samples whose values are passed directly to the function, or they can be seen as generalized functions, in which case they are weighted sums of Dirac delta functions located at the specified values.
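The two readings describe the same object. In the minimal pure-Python sketch below, `weighted_cdf` is a hypothetical helper. It compares an empirical sample containing a repeated value with the equivalent weighted sum of Dirac deltas: the sample [0, 0, 0, 1] and the values [0, 1] with weights [3, 1] have identical CDFs, so they define the same distribution.

```python
def weighted_cdf(values, weights, x):
    """Hypothetical helper: P(X <= x) for a weighted sum of Dirac deltas."""
    total = sum(weights)
    return sum(w for v, w in zip(values, weights) if v <= x) / total


sample = [0, 0, 0, 1]            # empirical reading: unit weight per observation
deltas, masses = [0, 1], [3, 1]  # generalized-function reading: 3*delta_0 + delta_1

for x in [-1, 0, 0.5, 1, 2]:
    # Both columns match at every x, e.g. 0.75 at x = 0.
    print(x, weighted_cdf(sample, [1] * len(sample), x), weighted_cdf(deltas, masses, x))
```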
References
[1] “Wasserstein metric”, https://en.wikipedia.org/wiki/Wasserstein_metric
[2] Ramdas, Garcia, Cuturi, “On Wasserstein Two Sample Testing and Related Families of Nonparametric Tests” (2015). arXiv:1509.02237.
Examples
>>> from scipy.stats import wasserstein_distance
>>> wasserstein_distance([0, 1, 3], [5, 6, 8])
5.0
>>> wasserstein_distance([0, 1], [0, 1], [3, 1], [2, 2])
0.25
>>> wasserstein_distance([3.4, 3.9, 7.5, 7.8], [4.5, 1.4],
...                      [1.4, 0.9, 3.1, 7.2], [3.2, 3.5])
4.0781331438047861