fix loss computation

This commit is contained in:
RexYing 2017-11-03 11:43:57 -07:00
parent bb355bb696
commit 826e715cb1

View File

@ -11,7 +11,7 @@ FLAGS = flags.FLAGS
class BipartiteEdgePredLayer(Layer): class BipartiteEdgePredLayer(Layer):
def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid, def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid,
loss_fn='xent', loss_fn='xent', neg_sample_weights=1.0,
bias=False, bilinear_weights=False, **kwargs): bias=False, bilinear_weights=False, **kwargs):
""" """
Basic class that applies skip-gram-like loss Basic class that applies skip-gram-like loss
@ -30,6 +30,7 @@ class BipartiteEdgePredLayer(Layer):
# Margin for hinge loss # Margin for hinge loss
self.margin = 0.1 self.margin = 0.1
self.neg_sample_weights = neg_sample_weights
self.bilinear_weights = bilinear_weights self.bilinear_weights = bilinear_weights
@ -105,7 +106,7 @@ class BipartiteEdgePredLayer(Layer):
labels=tf.ones_like(aff), logits=aff) labels=tf.ones_like(aff), logits=aff)
negative_xent = tf.nn.sigmoid_cross_entropy_with_logits( negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.zeros_like(neg_aff), logits=neg_aff) labels=tf.zeros_like(neg_aff), logits=neg_aff)
loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent) loss = tf.reduce_sum(true_xent) + self.neg_sample_weights * tf.reduce_sum(negative_xent)
return loss return loss
def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None): def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):