import tensorflow as tf

import graphsage.models as models
import graphsage.layers as layers
from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator

flags = tf.app.flags
FLAGS = flags.FLAGS

class SupervisedGraphsage(models.SampleAndAggregate):
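    """Implementation of supervised GraphSAGE: sample-and-aggregate node
    embeddings followed by a dense classification layer."""
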
    def __init__(self, num_classes,
            placeholders, features, adj, degrees,
            layer_infos, concat=True, aggregator_type="mean",
            model_size="small", sigmoid_loss=False,
            **kwargs):
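        """
        Args:
            num_classes: number of target classes for node classification.
            placeholders: dict of TensorFlow placeholders; this model reads
                'batch', 'batch_size', 'dropout', and 'labels'.
            features: numpy array of node input features, wrapped in a
                non-trainable tf.Variable for embedding lookup.
            adj: adjacency information, stored as self.adj_info.
            degrees: node degree array (stored but not used directly here).
            layer_infos: list of per-layer settings; each entry must expose
                num_samples and output_dim.
            concat: if True, self and neighbor representations are
                concatenated, doubling the output dimension.
            aggregator_type: one of "mean", "seq", "pool", "pool_2", "gcn".
            model_size: hidden-dimension preset passed through to the
                aggregators (default "small").
            sigmoid_loss: if True, use sigmoid cross-entropy (multi-label);
                otherwise softmax cross-entropy (single-label).
        """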
        models.GeneralizedModel.__init__(self, **kwargs)

        # Select the neighbor aggregator implementation.
        if aggregator_type == "mean":
            self.aggregator_cls = MeanAggregator
        elif aggregator_type == "seq":
            self.aggregator_cls = SeqAggregator
        elif aggregator_type == "pool":
            self.aggregator_cls = PoolingAggregator
        elif aggregator_type == "pool_2":
            self.aggregator_cls = TwoLayerPoolingAggregator
        elif aggregator_type == "gcn":
            self.aggregator_cls = GCNAggregator
        else:
            raise Exception("Unknown aggregator: ", aggregator_type)

        # get info from placeholders...
        self.inputs1 = placeholders["batch"]
        self.model_size = model_size
        self.adj_info = adj
        self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
        self.degrees = degrees
        self.concat = concat
        self.num_classes = num_classes
        self.sigmoid_loss = sigmoid_loss

        # dims[0] is the input feature dimension; each subsequent entry is a
        # layer's output dimension.
        self.dims = [features.shape[1]]
        self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
        self.batch_size = placeholders["batch_size"]
        self.placeholders = placeholders
        self.layer_infos = layer_infos

        self.optimizer = tf.train.AdamOptimizer(learning_rate=FLAGS.learning_rate)

        self.build()

    def build(self):
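        """Build the forward pass (sample neighborhoods, aggregate, L2-normalize),
        the dense prediction head, the loss, and the clipped-gradient train op."""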
        samples1, support_sizes1 = self.sample(self.inputs1, self.layer_infos)
        num_samples = [layer_info.num_samples for layer_info in self.layer_infos]
        self.outputs1, self.aggregators = self.aggregate(samples1, [self.features], self.dims, num_samples,
                support_sizes1, concat=self.concat, model_size=self.model_size)

        self.outputs1 = tf.nn.l2_normalize(self.outputs1, 1)

        dim_mult = 2 if self.concat else 1
        self.node_pred = layers.Dense(dim_mult*self.dims[-1], self.num_classes,
                dropout=self.placeholders['dropout'],
                act=lambda x: x)
        # TF graph management
        self.node_preds = self.node_pred(self.outputs1)

        self._loss()

        # Clip gradients to [-5, 5] before applying the update.
        grads_and_vars = self.optimizer.compute_gradients(self.loss)
        clipped_grads_and_vars = [(tf.clip_by_value(grad, -5.0, 5.0) if grad is not None else None, var)
                for grad, var in grads_and_vars]
        self.grad, _ = clipped_grads_and_vars[0]
        self.opt_op = self.optimizer.apply_gradients(clipped_grads_and_vars)
        self.preds = self.predict()

    def _loss(self):
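        """Accumulate L2 weight decay on aggregator and prediction-layer weights
        plus the sigmoid or softmax cross-entropy classification loss."""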
        # Weight decay loss
        for aggregator in self.aggregators:
            for var in aggregator.vars.values():
                self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)
        for var in self.node_pred.vars.values():
            self.loss += FLAGS.weight_decay * tf.nn.l2_loss(var)

        # classification loss
        if self.sigmoid_loss:
            self.loss += tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(
                    logits=self.node_preds,
                    labels=self.placeholders['labels']))
        else:
            self.loss += tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(
                    logits=self.node_preds,
                    labels=self.placeholders['labels']))

        tf.summary.scalar('loss', self.loss)

    def predict(self):
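        """Turn logits into probabilities: sigmoid if multi-label
        (sigmoid_loss=True), softmax otherwise."""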
        if self.sigmoid_loss:
            return tf.nn.sigmoid(self.node_preds)
        else:
            return tf.nn.softmax(self.node_preds)
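
# Minimal usage sketch (hypothetical names; the training script is expected to
# supply the placeholders dict, feature matrix, adjacency/degree arrays, and
# layer_infos entries exposing num_samples and output_dim):
#
#   model = SupervisedGraphsage(num_classes, placeholders, features, adj_info,
#                               degrees, layer_infos, aggregator_type="mean")
#   outs = sess.run([model.opt_op, model.loss, model.preds], feed_dict=feed_dict)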