Fixed weight decay support
This commit is contained in:
parent
75794f8a09
commit
2bb809dba5
@ -387,7 +387,7 @@ class SampleAndAggregate(GeneralizedModel):
|
|||||||
for var in aggregator.vars.values():
|
for var in aggregator.vars.values():
|
||||||
self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
|
self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
|
||||||
|
|
||||||
self.loss = self.link_pred_layer.loss(self.outputs1, self.outputs2, self.neg_outputs)
|
self.loss += self.link_pred_layer.loss(self.outputs1, self.outputs2, self.neg_outputs)
|
||||||
tf.summary.scalar('loss', self.loss)
|
tf.summary.scalar('loss', self.loss)
|
||||||
|
|
||||||
def _accuracy(self):
|
def _accuracy(self):
|
||||||
|
Loading…
Reference in New Issue
Block a user