diff --git a/graphsage/models.py b/graphsage/models.py index e3a904f..1c8490b 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -387,7 +387,7 @@ class SampleAndAggregate(GeneralizedModel): for var in aggregator.vars.values(): self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var) - self.loss = self.link_pred_layer.loss(self.outputs1, self.outputs2, self.neg_outputs) + self.loss += self.link_pred_layer.loss(self.outputs1, self.outputs2, self.neg_outputs) tf.summary.scalar('loss', self.loss) def _accuracy(self):