xent, skipgram and max margin loss options in predictions.py
This commit is contained in:
parent
d77df9ef65
commit
bc4f2570a3
@ -11,6 +11,7 @@ FLAGS = flags.FLAGS
|
||||
|
||||
class BipartiteEdgePredLayer(Layer):
|
||||
def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid,
|
||||
loss_fn='xent',
|
||||
bias=False, bilinear_weights=False, **kwargs):
|
||||
"""
|
||||
Basic class that applies skip-gram-like loss
|
||||
@ -26,6 +27,10 @@ class BipartiteEdgePredLayer(Layer):
|
||||
self.act = act
|
||||
self.bias = bias
|
||||
self.eps = 1e-7
|
||||
|
||||
# Margin for hinge loss
|
||||
self.margin = 0.1
|
||||
|
||||
self.bilinear_weights = bilinear_weights
|
||||
|
||||
if dropout:
|
||||
@ -49,6 +54,13 @@ class BipartiteEdgePredLayer(Layer):
|
||||
if self.bias:
|
||||
self.vars['bias'] = zeros([self.output_dim], name='bias')
|
||||
|
||||
if loss_fn == 'xent':
|
||||
self.loss_fn = self._xent_loss
|
||||
elif loss_fn == 'skipgram':
|
||||
self.loss_fn = self._skipgram_loss
|
||||
elif loss_fn == 'hinge':
|
||||
self.loss_fn = self._hinge_loss
|
||||
|
||||
if self.logging:
|
||||
self._log_vars()
|
||||
|
||||
@ -66,7 +78,7 @@ class BipartiteEdgePredLayer(Layer):
|
||||
result = tf.reduce_sum(inputs1 * inputs2, axis=1)
|
||||
return result
|
||||
|
||||
def neg_cost(self, inputs1, neg_samples):
|
||||
def neg_cost(self, inputs1, neg_samples, hard_neg_samples=None):
|
||||
""" For each input in batch, compute the sum of its affinity to negative samples.
|
||||
|
||||
Returns:
|
||||
@ -84,16 +96,32 @@ class BipartiteEdgePredLayer(Layer):
|
||||
neg_samples: tensor of shape [num_neg_samples x input_dim2]. Negative samples for all
|
||||
inputs in batch inputs1.
|
||||
"""
|
||||
return self.loss_fn(inputs1, inputs2, neg_samples)
|
||||
|
||||
def _xent_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
|
||||
aff = self.affinity(inputs1, inputs2)
|
||||
neg_aff = self.neg_cost(inputs1, neg_samples)
|
||||
neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
|
||||
true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=tf.ones_like(aff), logits=aff)
|
||||
negative_xent = tf.nn.sigmoid_cross_entropy_with_logits(
|
||||
labels=tf.zeros_like(neg_aff), logits=neg_aff)
|
||||
loss = tf.reduce_sum(true_xent) + tf.reduce_sum(negative_xent)
|
||||
loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent)
|
||||
return loss
|
||||
|
||||
return loss
|
||||
def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
|
||||
aff = self.affinity(inputs1, inputs2)
|
||||
neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
|
||||
neg_cost = tf.log(tf.reduce_sum(tf.exp(neg_aff), axis=1))
|
||||
loss = tf.reduce_sum(aff - neg_cost)
|
||||
return loss
|
||||
|
||||
def _hinge_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None):
|
||||
aff = self.affinity(inputs1, inputs2)
|
||||
neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples)
|
||||
diff = tf.nn.relu(tf.subtract(neg_aff, tf.expand_dims(aff, 1) - self.margin), name='diff')
|
||||
loss = tf.reduce_sum(diff)
|
||||
self.neg_shape = tf.shape(neg_aff)
|
||||
return loss
|
||||
|
||||
def weights_norm(self):
|
||||
return tf.nn.l2_norm(self.vars['weights'])
|
||||
|
Loading…
Reference in New Issue
Block a user