tf.contrib.losses.metric_learning.npairs_loss_multilabel( sparse_labels, embeddings_anchor, embeddings_positive, reg_lambda=0.002, print_losses=False )
Defined in tensorflow/contrib/losses/python/metric_learning/metric_loss_ops.py
.
Computes the npairs loss with multilabel data.
Npairs loss expects paired data where a pair is composed of samples from the same labels and each pairs in the minibatch have different labels. The loss has two components. The first component is the L2 regularizer on the embedding vectors. The second component is the sum of cross entropy loss which takes each row of the pair-wise similarity matrix as logits and the remapped one-hot labels as labels. Here, the similarity is defined by the dot product between two embedding vectors. S_{i,j} = f(x_i)^T f(x_j)
To deal with multilabel inputs, we use the count of label intersection i.e. L_{i,j} = | set_of_labels_for(i) \cap set_of_labels_for(j) | Then we normalize each rows of the count based label matrix so that each row sums to one.
sparse_labels
: List of 1-D Boolean SparseTensor
of dense_shape [batch_size/2, num_classes] labels for the anchor-pos pairs.embeddings_anchor
: 2-D Tensor
of shape [batch_size/2, embedding_dim] for the embedding vectors for the anchor images. Embeddings should not be l2 normalized.embeddings_positive
: 2-D Tensor
of shape [batch_size/2, embedding_dim] for the embedding vectors for the positive images. Embeddings should not be l2 normalized.reg_lambda
: Float. L2 regularization term on the embedding vectors.print_losses
: Boolean. Option to print the xent and l2loss.npairs_loss
: tf.float32 scalar.TypeError
: When the specified sparse_labels is not a SparseTensor
.
© 2018 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/api_docs/python/tf/contrib/losses/metric_learning/npairs_loss_multilabel