fix loss computation
This commit is contained in:
parent
bb355bb696
commit
826e715cb1
@ -11,7 +11,7 @@ FLAGS = flags.FLAGS
|
||||
|
||||
class BipartiteEdgePredLayer(Layer):
|
||||
def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid,
|
||||
loss_fn='xent',
|
||||
loss_fn='xent', neg_sample_weights=1.0,
|
||||
bias=False, bilinear_weights=False, **kwargs):
|
||||
"""
|
||||
Basic class that applies skip-gram-like loss
|
||||
@ -30,6 +30,7 @@ class BipartiteEdgePredLayer(Layer):
|
||||
|
||||
# Margin for hinge loss
|
||||
self.margin = 0.1
|
||||
self.neg_sample_weights = neg_sample_weights
|
||||
|
||||
self.bilinear_weights = bilinear_weights
|
||||
|
||||
@ -105,7 +106,7 @@ class BipartiteEdgePredLayer(Layer):
|
||||
labels=tf.ones_like(aff), logits=aff)
|
||||
negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=tf.zeros_like(neg_aff), logits=neg_aff)
|
||||
loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent)
|
||||
loss = tf.reduce_sum(true_xent) + self.neg_sample_weights * tf.reduce_sum(negative_xent)
|
||||
return loss
|
||||
|
||||
def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
|
||||
|
Loading…
Reference in New Issue
Block a user