graphsage-tf/graphsage/metrics.py

43 lines
1.5 KiB
Python
Raw Normal View History

2017-05-29 23:35:30 +08:00
import tensorflow as tf
# DISCLAIMER:
# Parts of this code file were originally forked from
# https://github.com/tkipf/gcn
# which itself was heavily inspired by the keras package
# (A full license with de-anonymized attributions will be provided in the
# public repo of this code base)
def masked_logit_cross_entropy(preds, labels, mask):
    """Sigmoid (multi-label) cross-entropy on logits with masking.

    Args:
        preds: Logits. They are summed along axis 1 after the element-wise
            loss, so presumably (num_examples, num_classes) -- confirm
            against callers.
        labels: Targets with the same shape as ``preds``.
        mask: Per-example weights / indicator; cast to float32.

    Returns:
        ``reduce_mean`` of the per-example loss multiplied by
        ``mask / max(sum(mask), 1)``. Note that this is the masked mean
        divided once more by the number of entries, because the mask is
        normalised by its sum and the result is then averaged.
    """
    elementwise = tf.nn.sigmoid_cross_entropy_with_logits(logits=preds, labels=labels)
    per_example = tf.reduce_sum(elementwise, axis=1)
    weights = tf.cast(mask, dtype=tf.float32)
    # Floor the normaliser at 1 so an all-zero mask yields 0 instead of NaN.
    normaliser = tf.maximum(tf.reduce_sum(weights), tf.constant([1.]))
    return tf.reduce_mean(per_example * (weights / normaliser))
def masked_softmax_cross_entropy(preds, labels, mask):
    """Softmax cross-entropy on logits with masking.

    Args:
        preds: Logits; presumably (num_examples, num_classes) -- confirm
            against callers.
        labels: Target distribution with the same shape as ``preds``.
        mask: Per-example weights / indicator; cast to float32.

    Returns:
        ``reduce_mean`` of the per-example loss multiplied by
        ``mask / max(sum(mask), 1)``. Note that this is the masked mean
        divided once more by the number of entries, because the mask is
        normalised by its sum and the result is then averaged.
    """
    # softmax_cross_entropy_with_logits already reduces over the class axis,
    # so no additional reduce_sum over axis 1 is applied here.
    per_example = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
    weights = tf.cast(mask, dtype=tf.float32)
    # Floor the normaliser at 1 so an all-zero mask yields 0 instead of NaN.
    normaliser = tf.maximum(tf.reduce_sum(weights), tf.constant([1.]))
    return tf.reduce_mean(per_example * (weights / normaliser))
def masked_l2(preds, actuals, mask):
    """Squared L2 error with masking.

    Args:
        preds: Predicted values. The squared error is summed along axis 1,
            so presumably (num_examples, dim) -- confirm against callers.
        actuals: Target values with the same shape as ``preds``.
        mask: Per-example weights / indicator; cast to float32.

    Returns:
        Scalar tensor: the mean of the per-example squared L2 distance after
        weighting by ``mask / mean(mask)``. For a 0/1 mask this equals the
        average over the selected examples. An all-zero mask divides by zero
        and yields NaN.
    """
    # ``tf.nn`` has no ``l2`` function (only ``l2_loss`` and ``l2_normalize``),
    # so the previous ``tf.nn.l2(preds, actuals)`` raised AttributeError.
    # ``tf.nn.l2_loss`` would collapse everything to one scalar and make the
    # per-example mask meaningless, hence the explicit per-row reduction.
    # NOTE(review): uses sum((preds - actuals)^2) per row, without the 1/2
    # factor of ``tf.nn.l2_loss`` -- confirm the intended scale.
    loss = tf.reduce_sum(tf.square(preds - actuals), axis=1)
    mask = tf.cast(mask, dtype=tf.float32)
    # Rescale so that reduce_mean below averages over masked-in rows only.
    mask /= tf.reduce_mean(mask)
    loss *= mask
    return tf.reduce_mean(loss)
def masked_accuracy(preds, labels, mask):
    """Classification accuracy with masking.

    Args:
        preds: Class scores; the arg-max is taken along axis 1.
        labels: Targets; the arg-max is taken along axis 1 (presumably
            one-hot rows -- confirm against callers).
        mask: Per-example weights / indicator; cast to float32.

    Returns:
        Scalar tensor: the mean of the 0/1 "prediction matches label"
        indicator after weighting by ``mask / mean(mask)``. For a 0/1 mask
        this is the accuracy over the selected examples. An all-zero mask
        divides by zero and yields NaN.
    """
    predicted_class = tf.argmax(preds, 1)
    true_class = tf.argmax(labels, 1)
    hits = tf.cast(tf.equal(predicted_class, true_class), tf.float32)
    weights = tf.cast(mask, dtype=tf.float32)
    # Rescale so that reduce_mean below averages over masked-in rows only.
    weights = weights / tf.reduce_mean(weights)
    return tf.reduce_mean(hits * weights)