From bc4f2570a350d51a885da4d2a8beb1a8875c21a7 Mon Sep 17 00:00:00 2001
From: RexYing
Date: Fri, 13 Oct 2017 14:50:36 -0700
Subject: [PATCH] xent, skipgram and max margin loss options in prediction.py

---
 graphsage/prediction.py | 36 ++++++++++++++++++++++++++++++++----
 1 file changed, 32 insertions(+), 4 deletions(-)

diff --git a/graphsage/prediction.py b/graphsage/prediction.py
index 9bf0885..2e73d4c 100644
--- a/graphsage/prediction.py
+++ b/graphsage/prediction.py
@@ -11,6 +11,7 @@ FLAGS = flags.FLAGS
 
 class BipartiteEdgePredLayer(Layer):
     def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid,
+            loss_fn='xent',
             bias=False, bilinear_weights=False, **kwargs):
         """ Basic class that applies skip-gram-like loss
         (i.e., dot product of node+target and node and negative samples)
@@ -26,6 +27,10 @@ class BipartiteEdgePredLayer(Layer):
         self.act = act
         self.bias = bias
         self.eps = 1e-7
+
+        # Margin for hinge loss
+        self.margin = 0.1
+
         self.bilinear_weights = bilinear_weights
 
         if dropout:
@@ -49,6 +54,13 @@ class BipartiteEdgePredLayer(Layer):
         if self.bias:
             self.vars['bias'] = zeros([self.output_dim], name='bias')
 
+        if loss_fn == 'xent':
+            self.loss_fn = self._xent_loss
+        elif loss_fn == 'skipgram':
+            self.loss_fn = self._skipgram_loss
+        elif loss_fn == 'hinge':
+            self.loss_fn = self._hinge_loss
+
         if self.logging:
             self._log_vars()
 
@@ -66,7 +78,7 @@ class BipartiteEdgePredLayer(Layer):
             result = tf.reduce_sum(inputs1 * inputs2, axis=1)
         return result
 
-    def neg_cost(self, inputs1, neg_samples):
+    def neg_cost(self, inputs1, neg_samples, hard_neg_samples=None):
         """ For each input in batch, compute the sum of its affinity to negative samples.
 
         Returns:
@@ -84,16 +96,32 @@ class BipartiteEdgePredLayer(Layer):
             neg_samples: tensor of shape [num_neg_samples x input_dim2]. Negative samples for all
             inputs in batch inputs1.
         """
+        return self.loss_fn(inputs1, inputs2, neg_samples)
 
+    def _xent_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
         aff = self.affinity(inputs1, inputs2)
-        neg_aff = self.neg_cost(inputs1, neg_samples)
+        neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
         true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
                 labels=tf.ones_like(aff), logits=aff)
         negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
                 labels=tf.zeros_like(neg_aff), logits=neg_aff)
-        loss = tf.reduce_sum(true_xent) + tf.reduce_sum(negative_xent)
+        loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent)
+        return loss
 
-        return loss
+    def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
+        aff = self.affinity(inputs1, inputs2)
+        neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
+        neg_cost = tf.log(tf.reduce_sum(tf.exp(neg_aff), axis=1))
+        loss = tf.reduce_sum(aff - neg_cost)
+        return loss
+
+    def _hinge_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
+        aff = self.affinity(inputs1, inputs2)
+        neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
+        diff = tf.nn.relu(tf.subtract(neg_aff, tf.expand_dims(aff, 1) - self.margin), name='diff')
+        loss = tf.reduce_sum(diff)
+        self.neg_shape = tf.shape(neg_aff)
+        return loss
 
     def weights_norm(self):
         return tf.nn.l2_norm(self.vars['weights'])
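
Reviewer note (not part of the patch): the new option is chosen once, when the
layer is constructed. A minimal sketch of a call site; hidden_dim, placeholders,
and the output tensors are illustrative names for whatever the calling model
uses, not identifiers from this patch:

    # Hypothetical call site; only the constructor and loss() signatures
    # come from the patch above.
    link_pred_layer = BipartiteEdgePredLayer(
        input_dim1=hidden_dim, input_dim2=hidden_dim, placeholders=placeholders,
        loss_fn='hinge',  # or 'xent' (the default) or 'skipgram'
        bilinear_weights=False)
    loss = link_pred_layer.loss(outputs1, outputs2, neg_outputs)

One caveat: the if/elif chain in __init__ has no else branch, so a misspelled
loss_fn value only surfaces later, as an AttributeError the first time loss()
is called.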
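
Reviewer note (not part of the patch): a self-contained NumPy sketch of the
three objectives exactly as the hunks above define them, handy for checking
shapes and signs off-graph. Here aff stands for the [batch] vector of
positive-pair affinities and neg_aff for the [batch x num_neg] matrix of
negative affinities:

    import numpy as np

    def sigmoid_xent(labels, logits):
        # Numerically stable form matching tf.nn.sigmoid_cross_entropy_with_logits.
        return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

    rng = np.random.default_rng(0)
    aff = rng.normal(size=3)           # [batch] positive-pair affinities
    neg_aff = rng.normal(size=(3, 5))  # [batch x num_neg] negative affinities
    margin = 0.1                       # the patch's self.margin default

    # 'xent': positive term plus negatives, with the patch's hard-coded
    # 0.01 weight on the negative term.
    xent = sigmoid_xent(1.0, aff).sum() + 0.01 * sigmoid_xent(0.0, neg_aff).sum()

    # 'skipgram': sum over the batch of aff - log(sum(exp(neg_aff))).
    skipgram = (aff - np.log(np.exp(neg_aff).sum(axis=1))).sum()

    # 'hinge': sum of max(0, neg_aff - (aff - margin)) over all
    # (input, negative sample) pairs.
    hinge = np.maximum(0.0, neg_aff - (aff[:, None] - margin)).sum()

    print(xent, skipgram, hinge)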
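
Reviewer note (not part of the patch): a few things worth double-checking
before merge. First, _skipgram_loss returns tf.reduce_sum(aff - neg_cost); if
the caller minimizes this value, the conventional skip-gram/negative-sampling
objective would be neg_cost - aff, since minimizing aff - neg_cost pushes
positive affinities down. Second, self.neg_shape = tf.shape(neg_aff) in
_hinge_loss looks like a debugging leftover. Third, the pre-existing
weights_norm (visible in the last hunk's context) calls tf.nn.l2_norm, which
is not a TensorFlow API; it only avoids failing because nothing calls it. A
possible follow-up, assuming the intent is the usual L2 penalty on the
bilinear weights:

    def weights_norm(self):
        # tf.nn.l2_loss(w) computes sum(w ** 2) / 2, the standard L2 penalty;
        # tf.nn.l2_norm does not exist in TensorFlow.
        return tf.nn.l2_loss(self.vars['weights'])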