xent, skipgram and max margin loss options in predictions.py

This commit is contained in:
RexYing 2017-10-13 14:50:36 -07:00
parent d77df9ef65
commit bc4f2570a3

View File

@ -11,6 +11,7 @@ FLAGS = flags.FLAGS
class BipartiteEdgePredLayer(Layer): class BipartiteEdgePredLayer(Layer):
def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid, def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid,
loss_fn='xent',
bias=False, bilinear_weights=False, **kwargs): bias=False, bilinear_weights=False, **kwargs):
""" """
Basic class that applies skip-gram-like loss Basic class that applies skip-gram-like loss
@ -26,6 +27,10 @@ class BipartiteEdgePredLayer(Layer):
self.act = act self.act = act
self.bias = bias self.bias = bias
self.eps = 1e-7 self.eps = 1e-7
# Margin for hinge loss
self.margin = 0.1
self.bilinear_weights = bilinear_weights self.bilinear_weights = bilinear_weights
if dropout: if dropout:
@ -49,6 +54,13 @@ class BipartiteEdgePredLayer(Layer):
if self.bias: if self.bias:
self.vars['bias'] = zeros([self.output_dim], name='bias') self.vars['bias'] = zeros([self.output_dim], name='bias')
if loss_fn == 'xent':
self.loss_fn = self._xent_loss
elif loss_fn == 'skipgram':
self.loss_fn = self._skipgram_loss
elif loss_fn == 'hinge':
self.loss_fn = self._hinge_loss
if self.logging: if self.logging:
self._log_vars() self._log_vars()
@ -66,7 +78,7 @@ class BipartiteEdgePredLayer(Layer):
result = tf.reduce_sum(inputs1 * inputs2, axis=1) result = tf.reduce_sum(inputs1 * inputs2, axis=1)
return result return result
def neg_cost(self, inputs1, neg_samples): def neg_cost(self, inputs1, neg_samples, hard_neg_samples=None):
""" For each input in batch, compute the sum of its affinity to negative samples. """ For each input in batch, compute the sum of its affinity to negative samples.
Returns: Returns:
@ -84,16 +96,32 @@ class BipartiteEdgePredLayer(Layer):
neg_samples: tensor of shape [num_neg_samples x input_dim2]. Negative samples for all neg_samples: tensor of shape [num_neg_samples x input_dim2]. Negative samples for all
inputs in batch inputs1. inputs in batch inputs1.
""" """
return self.loss_fn(inputs1, inputs2, neg_samples)
def _xent_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
aff = self.affinity(inputs1, inputs2) aff = self.affinity(inputs1, inputs2)
neg_aff = self.neg_cost(inputs1, neg_samples) neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
true_xent = tf.nn.sigmoid_cross_entropy_with_logits( true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.ones_like(aff), logits=aff) labels=tf.ones_like(aff), logits=aff)
negative_xent = tf.nn.sigmoid_cross_entropy_with_logits( negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
labels=tf.zeros_like(neg_aff), logits=neg_aff) labels=tf.zeros_like(neg_aff), logits=neg_aff)
loss = tf.reduce_sum(true_xent) + tf.reduce_sum(negative_xent) loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent)
return loss
return loss def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
aff = self.affinity(inputs1, inputs2)
neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
neg_cost = tf.log(tf.reduce_sum(tf.exp(neg_aff), axis=1))
loss = tf.reduce_sum(aff - neg_cost)
return loss
def _hinge_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
aff = self.affinity(inputs1, inputs2)
neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
diff = tf.nn.relu(tf.subtract(neg_aff, tf.expand_dims(aff, 1) - self.margin), name='diff')
loss = tf.reduce_sum(diff)
self.neg_shape = tf.shape(neg_aff)
return loss
def weights_norm(self): def weights_norm(self):
return tf.nn.l2_norm(self.vars['weights']) return tf.nn.l2_norm(self.vars['weights'])