diff --git a/graphsage/models.py b/graphsage/models.py index 1c8490b..b3b9db4 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -477,7 +477,7 @@ class Node2VecModel(GeneralizedModel): def _loss(self): aff = tf.reduce_sum(tf.multiply(self.outputs1, self.outputs2), 1) + self.outputs2_bias - neg_aff = tf.matmul(self.outputs2, tf.transpose(self.neg_outputs)) + self.neg_outputs_bias + neg_aff = tf.matmul(self.outputs1, tf.transpose(self.neg_outputs)) + self.neg_outputs_bias true_xent = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(aff), logits=aff) negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(