import tensorflow as tf

import graphsage.models as models
import graphsage.layers as layers
from graphsage.aggregators import (MeanAggregator, PoolingAggregator, SeqAggregator,
                                   GCNAggregator, TwoLayerPoolingAggregator)

flags = tf.app.flags
FLAGS = flags.FLAGS


class SupervisedGraphsage(models.SampleAndAggregate):
    """Supervised GraphSAGE node classifier.

    Samples neighbourhoods for a batch of nodes, aggregates their features
    through the layers described by ``layer_infos``, L2-normalises the result
    along axis 1 and maps it to ``num_classes`` logits with a dense layer.
    Training minimises cross-entropy (sigmoid or softmax) plus L2 weight decay
    with Adam, clipping gradients element-wise to [-5, 5].
    """

    # Maps the ``aggregator_type`` constructor argument to its aggregator class.
    _AGGREGATOR_CLASSES = {
        "mean": MeanAggregator,
        "seq": SeqAggregator,
        "pool": PoolingAggregator,
        "pool_2": TwoLayerPoolingAggregator,
        "gcn": GCNAggregator,
    }

    def __init__(self, num_classes, placeholders, features, adj, degrees,
                 layer_infos, concat=True, aggregator_type="mean",
                 model_size="small", sigmoid_loss=False, **kwargs):
        """Configure the model and build its TensorFlow graph.

        Args:
            num_classes: number of output classes (width of the logits).
            placeholders: dict of input tensors; this class reads the keys
                "batch", "batch_size", "dropout" and "labels".
            features: node feature matrix; ``features.shape[1]`` is taken as
                the input feature dimension. Stored as a non-trainable variable.
            adj: adjacency information, stored as ``self.adj_info``.
            degrees: node degrees, stored as ``self.degrees``.
            layer_infos: sequence of per-layer descriptors, each exposing
                ``output_dim`` and ``num_samples``.
            concat: forwarded to ``aggregate``; when True the prediction layer
                expects an input twice as wide as ``dims[-1]``.
            aggregator_type: one of "mean", "seq", "pool", "pool_2", "gcn".
            model_size: forwarded to ``aggregate``.
            sigmoid_loss: if True use independent per-class sigmoid
                cross-entropy, otherwise softmax cross-entropy.
            **kwargs: forwarded to ``models.GeneralizedModel.__init__``.

        Raises:
            ValueError: if ``aggregator_type`` is not one of the known names.
        """
        # NOTE(review): initialises via GeneralizedModel directly rather than
        # SampleAndAggregate, while still using the latter's sample/aggregate
        # methods — presumably intentional; confirm.
        models.GeneralizedModel.__init__(self, **kwargs)

        aggregator_cls = self._AGGREGATOR_CLASSES.get(aggregator_type)
        if aggregator_cls is None:
            raise ValueError("Unknown aggregator: {}".format(aggregator_type))
        self.aggregator_cls = aggregator_cls

        # get info from placeholders...
        self.inputs1 = placeholders["batch"]
        self.model_size = model_size
        self.adj_info = adj
        # Input features are held fixed during training.
        self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
        self.degrees = degrees
        self.concat = concat
        self.num_classes = num_classes
        self.sigmoid_loss = sigmoid_loss
        # Layer widths: input feature dimension followed by each layer's output_dim.
        self.dims = [features.shape[1]] + [info.output_dim for info in layer_infos]
        self.batch_size = placeholders["batch_size"]
        self.placeholders = placeholders
        self.layer_infos = layer_infos

        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

        self.build()

    def build(self):
        """Assemble the forward pass, loss and training op.

        Sets ``outputs1``, ``aggregators``, ``node_pred``, ``node_preds``,
        ``loss`` (via ``_loss``), ``grad``, ``opt_op`` and ``preds``.
        """
        samples1, support_sizes1 = self.sample(self.inputs1, self.layer_infos)
        num_samples = [layer_info.num_samples for layer_info in self.layer_infos]
        self.outputs1, self.aggregators = self.aggregate(
            samples1, [self.features], self.dims, num_samples, support_sizes1,
            concat=self.concat, model_size=self.model_size)
        self.outputs1 = tf.nn.l2_normalize(self.outputs1, 1)

        # With concat=True the aggregated embedding is 2 * dims[-1] wide
        # (presumably self and neighbour halves — confirm in the aggregators).
        dim_mult = 2 if self.concat else 1
        self.node_pred = layers.Dense(dim_mult * self.dims[-1], self.num_classes,
                                      dropout=self.placeholders['dropout'],
                                      act=lambda x: x)
        # TF graph management
        # Identity activation: node_preds are raw logits.
        self.node_preds = self.node_pred(self.outputs1)

        self._loss()
        grads_and_vars = self.optimizer.compute_gradients(self.loss)
        # Clip each gradient element-wise to [-5, 5]; variables without a
        # gradient keep None so apply_gradients can skip them.
        clipped_grads_and_vars = [
            (tf.clip_by_value(grad, -5.0, 5.0) if grad is not None else None, var)
            for grad, var in grads_and_vars]
        self.grad, _ = clipped_grads_and_vars[0]
        self.opt_op = self.optimizer.apply_gradients(clipped_grads_and_vars)
        self.preds = self.predict()

    def _loss(self):
        """Accumulate weight decay and classification loss into ``self.loss``.

        NOTE(review): uses ``+=`` on ``self.loss``, so it assumes the base
        class already initialised it (presumably to 0) — confirm.
        """
        # Weight decay loss on every aggregator and prediction-layer variable.
        for aggregator in self.aggregators:
            for var in aggregator.vars.values():
                self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
        for var in self.node_pred.vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)

        # Classification loss: mean over all elements of the cross-entropy tensor.
        if self.sigmoid_loss:
            xent = tf.nn.sigmoid_cross_entropy_with_logits
        else:
            xent = tf.nn.softmax_cross_entropy_with_logits
        self.loss += tf.reduce_mean(xent(logits=self.node_preds,
                                         labels=self.placeholders['labels']))

        tf.summary.scalar('loss', self.loss)

    def predict(self):
        """Return class probabilities: sigmoid if ``sigmoid_loss`` else softmax."""
        if self.sigmoid_loss:
            return tf.nn.sigmoid(self.node_preds)
        return tf.nn.softmax(self.node_preds)