diff --git a/graphsage/prediction.py b/graphsage/prediction.py index 2e73d4c..0c00c68 100644 --- a/graphsage/prediction.py +++ b/graphsage/prediction.py @@ -11,7 +11,7 @@ FLAGS = flags.FLAGS class BipartiteEdgePredLayer(Layer): def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid, - loss_fn='xent', + loss_fn='xent', neg_sample_weights=1.0, bias=False, bilinear_weights=False, **kwargs): """ Basic class that applies skip-gram-like loss @@ -30,6 +30,7 @@ class BipartiteEdgePredLayer(Layer): # Margin for hinge loss self.margin = 0.1 + self.neg_sample_weights = neg_sample_weights self.bilinear_weights = bilinear_weights @@ -105,7 +106,7 @@ class BipartiteEdgePredLayer(Layer): labels=tf.ones_like(aff), logits=aff) negative_xent = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.zeros_like(neg_aff), logits=neg_aff) - loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent) + loss = tf.reduce_sum(true_xent) + self.neg_sample_weights * tf.reduce_sum(negative_xent) return loss def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):