arXiv: 1506.05439 · v3 · pith:UT4ND4DDnew · submitted 2015-06-17 · cs.LG · cs.CV · stat.ML

Learning with a Wasserstein Loss

Classification: cs.LG, cs.CV, stat.ML
Keywords: wasserstein, learning, distance, loss, measures, metric, describe, multi-label

Learning to predict multi-label outputs is challenging, but in many problems there is a natural metric on the outputs that can be used to improve predictions. In this paper we develop a loss function for multi-label learning, based on the Wasserstein distance. The Wasserstein distance provides a natural notion of dissimilarity for probability measures. Although optimizing with respect to the exact Wasserstein distance is costly, recent work has described a regularized approximation that is efficiently computed. We describe an efficient learning algorithm based on this regularization, as well as a novel extension of the Wasserstein distance from probability measures to unnormalized measures. We also describe a statistical learning bound for the loss. The Wasserstein loss can encourage smoothness of the predictions with respect to a chosen metric on the output space. We demonstrate this property on a real-data tag prediction problem, using the Yahoo Flickr Creative Commons dataset, outperforming a baseline that doesn't use the metric.
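The regularized approximation mentioned in the abstract can be illustrated with a small sketch. It assumes the approximation is the entropy-regularized transport problem solved by Sinkhorn matrix scaling; the abstract does not spell out the algorithm, so the function name `sinkhorn_distance`, the parameters `reg` and `n_iters`, and the toy ground metric are all hypothetical choices made for illustration, not the authors' implementation.

```python
import math


def sinkhorn_distance(a, b, cost, reg=0.1, n_iters=500):
    """Approximate transport cost between histograms a and b.

    a, b  : lists of non-negative weights, each summing to 1
    cost  : cost[i][j] is the ground-metric distance between labels i and j
    reg   : entropic regularization strength (assumed value, not from the paper)
    Returns (transport_cost, plan).
    """
    n, m = len(a), len(b)
    # Gibbs kernel K = exp(-cost / reg); strictly positive everywhere.
    K = [[math.exp(-cost[i][j] / reg) for j in range(m)] for i in range(n)]
    u = [1.0] * n
    v = [1.0] * m
    for _ in range(n_iters):
        # Rescale rows so the plan's row sums match a: u = a / (K v).
        u = [a[i] / sum(K[i][j] * v[j] for j in range(m)) for i in range(n)]
        # Rescale columns so the plan's column sums match b: v = b / (K^T u).
        v = [b[j] / sum(K[i][j] * u[i] for i in range(n)) for j in range(m)]
    # Transport plan T = diag(u) K diag(v).
    plan = [[u[i] * K[i][j] * v[j] for j in range(m)] for i in range(n)]
    total = sum(plan[i][j] * cost[i][j] for i in range(n) for j in range(m))
    return total, plan


if __name__ == "__main__":
    # Three labels on a line; the ground metric is |i - j|.
    cost = [[abs(i - j) for j in range(3)] for i in range(3)]
    target = [1.0, 0.0, 0.0]
    near, _ = sinkhorn_distance(target, [0.0, 1.0, 0.0], cost)
    far, _ = sinkhorn_distance(target, [0.0, 0.0, 1.0], cost)
    print(near, far)
```

In this toy setting, a prediction that puts its mass on a label close to the target under the ground metric is penalized less than one that puts it on a distant label. That is the smoothness property the abstract attributes to the Wasserstein loss; a loss that ignores the metric would treat both mistakes the same.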

