import tensorflow as tf

import graphsage.models as models
import graphsage.layers as layers
from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator

flags = tf.app.flags
FLAGS = flags.FLAGS

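# Note: FLAGS.learning_rate and FLAGS.weight_decay are read below but are not
# defined in this module; the training script is expected to define them
# (e.g. via flags.DEFINE_float) before the model is built.
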
class SupervisedGraphsage(models.SampleAndAggregate):
    """Implementation of supervised GraphSAGE."""

    def __init__(self, num_classes,
                 placeholders, features, adj, degrees,
                 layer_infos, concat=True, aggregator_type="mean",
                 model_size="small", sigmoid_loss=False, identity_dim=0,
                 **kwargs):
        '''
        Args:
            - num_classes: number of output classes.
            - placeholders: Stanford TensorFlow placeholder object.
            - features: Numpy array with node features.
            - adj: Numpy array with adjacency lists (padded with random re-samples).
            - degrees: Numpy array with node degrees.
            - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
                  the recursive layers. See the SAGEInfo definition in graphsage.models.
            - concat: whether to concatenate hidden states during recursive iterations.
            - aggregator_type: how to aggregate neighbor information; one of
                  "mean", "seq", "pool", "pool_2", or "gcn".
            - model_size: one of "small" and "big".
            - sigmoid_loss: set to True if nodes can belong to multiple classes
                  (multi-label); uses sigmoid instead of softmax cross-entropy.
            - identity_dim: dimension of trainable per-node identity embeddings;
                  0 disables them.
        '''
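        # Example layer_infos construction (a sketch, not executed here; assumes
        # the SAGEInfo namedtuple fields (layer_name, neigh_sampler, num_samples,
        # output_dim) and UniformNeighborSampler from graphsage.neigh_samplers):
        #
        #   sampler = UniformNeighborSampler(adj_info)
        #   layer_infos = [SAGEInfo("node", sampler, 25, 128),
        #                  SAGEInfo("node", sampler, 10, 128)]
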
        models.GeneralizedModel.__init__(self, **kwargs)

        # Map the aggregator_type string onto the corresponding aggregator class.
        if aggregator_type == "mean":
            self.aggregator_cls = MeanAggregator
        elif aggregator_type == "seq":
            self.aggregator_cls = SeqAggregator
        elif aggregator_type == "pool":
            self.aggregator_cls = PoolingAggregator
        elif aggregator_type == "pool_2":
            self.aggregator_cls = TwoLayerPoolingAggregator
        elif aggregator_type == "gcn":
            self.aggregator_cls = GCNAggregator
        else:
            # Report the offending argument; self.aggregator_cls is never
            # assigned on this branch, so referencing it here would be a bug.
            raise Exception("Unknown aggregator: " + aggregator_type)

        # Get info from placeholders.
        self.inputs1 = placeholders["batch"]
        self.model_size = model_size
        self.adj_info = adj
        if identity_dim > 0:
            self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
        else:
            self.embeds = None
        if features is None:
            if identity_dim == 0:
                raise Exception("Must have a positive value for identity feature dimension if no input features given.")
            self.features = self.embeds
        else:
            self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
            if self.embeds is not None:
                self.features = tf.concat([self.embeds, self.features], axis=1)
        self.degrees = degrees
        self.concat = concat
        self.num_classes = num_classes
        self.sigmoid_loss = sigmoid_loss
        # Input dimension is the raw feature size plus any identity embedding;
        # subsequent entries are each layer's output dimension.
        self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
        self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
        self.batch_size = placeholders["batch_size"]
        self.placeholders = placeholders
        self.layer_infos = layer_infos

        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

        self.build()

    def build(self):
        # Sample fixed-size neighborhoods, then aggregate features through the
        # recursive layers to produce node embeddings.
        samples1, support_sizes1 = self.sample(self.inputs1, self.layer_infos)
        num_samples = [layer_info.num_samples for layer_info in self.layer_infos]
        self.outputs1, self.aggregators = self.aggregate(samples1, [self.features], self.dims, num_samples,
                support_sizes1, concat=self.concat, model_size=self.model_size)

        self.outputs1 = tf.nn.l2_normalize(self.outputs1, 1)

        # Concatenation doubles the hidden dimension at each layer.
        dim_mult = 2 if self.concat else 1
        self.node_pred = layers.Dense(dim_mult * self.dims[-1], self.num_classes,
                dropout=self.placeholders['dropout'],
                act=lambda x: x)  # identity activation: outputs are raw logits
        # TF graph management
        self.node_preds = self.node_pred(self.outputs1)

        self._loss()
        grads_and_vars = self.optimizer.compute_gradients(self.loss)
        # Clip gradients to [-5, 5] to stabilize training.
        clipped_grads_and_vars = [(tf.clip_by_value(grad, -5.0, 5.0) if grad is not None else None, var)
                for grad, var in grads_and_vars]
        self.grad, _ = clipped_grads_and_vars[0]
        self.opt_op = self.optimizer.apply_gradients(clipped_grads_and_vars)
        self.preds = self.predict()

    def _loss(self):
        # Weight decay loss
        for aggregator in self.aggregators:
            for var in aggregator.vars.values():
                self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
        for var in self.node_pred.vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)

        # Classification loss: sigmoid cross-entropy for multi-label targets,
        # softmax cross-entropy for mutually exclusive classes.
        if self.sigmoid_loss:
            self.loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.node_preds,
                    labels=self.placeholders['labels']))
        else:
            self.loss += tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.node_preds,
                    labels=self.placeholders['labels']))

        tf.summary.scalar('loss', self.loss)

    def predict(self):
        if self.sigmoid_loss:
            return tf.nn.sigmoid(self.node_preds)
        else:
            return tf.nn.softmax(self.node_preds)
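
# A minimal training-loop sketch (assumes a tf.Session `sess` and a `feed_dict`
# built from the placeholders; `model` is a SupervisedGraphsage constructed as in
# the example comment in __init__ -- none of this is defined in this module):
#
#   _, loss, preds = sess.run([model.opt_op, model.loss, model.preds],
#                             feed_dict=feed_dict)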