From 05c72baeb72c44f7923e548c4e1dbc2fe13fed22 Mon Sep 17 00:00:00 2001 From: williamleif Date: Tue, 10 Oct 2017 15:46:12 -0700 Subject: [PATCH] Added mean pooling. --- graphsage/aggregators.py | 87 +++++++++++++++++++++++++++++++-- graphsage/models.py | 10 ++-- graphsage/supervised_models.py | 10 ++-- graphsage/supervised_train.py | 23 +++++++-- graphsage/unsupervised_train.py | 19 ++++++- 5 files changed, 130 insertions(+), 19 deletions(-) diff --git a/graphsage/aggregators.py b/graphsage/aggregators.py index 705ec69..7dbd252 100644 --- a/graphsage/aggregators.py +++ b/graphsage/aggregators.py @@ -116,12 +116,12 @@ class GCNAggregator(Layer): return self.act(output) -class PoolingAggregator(Layer): +class MaxPoolingAggregator(Layer): """ Aggregates via max-pooling over MLP functions. """ def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): - super(PoolingAggregator, self).__init__(**kwargs) + super(MaxPoolingAggregator, self).__init__(**kwargs) self.dropout = dropout self.bias = bias @@ -194,12 +194,91 @@ class PoolingAggregator(Layer): return self.act(output) -class TwoLayerPoolingAggregator(Layer): +class MeanPoolingAggregator(Layer): + """ Aggregates via mean-pooling over MLP functions. + """ + def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, + dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): + super(MeanPoolingAggregator, self).__init__(**kwargs) + + self.dropout = dropout + self.bias = bias + self.act = act + self.concat = concat + + if neigh_input_dim is None: + neigh_input_dim = input_dim + + if name is not None: + name = '/' + name + else: + name = '' + + if model_size == "small": + hidden_dim = self.hidden_dim = 512 + elif model_size == "big": + hidden_dim = self.hidden_dim = 1024 + + self.mlp_layers = [] + self.mlp_layers.append(Dense(input_dim=neigh_input_dim, + output_dim=hidden_dim, + act=tf.nn.relu, + dropout=dropout, + sparse_inputs=False, + logging=self.logging)) + + with tf.variable_scope(self.name + name + '_vars'): + self.vars['neigh_weights'] = glorot([hidden_dim, output_dim], + name='neigh_weights') + + self.vars['self_weights'] = glorot([input_dim, output_dim], + name='self_weights') + if self.bias: + self.vars['bias'] = zeros([self.output_dim], name='bias') + + if self.logging: + self._log_vars() + + self.input_dim = input_dim + self.output_dim = output_dim + self.neigh_input_dim = neigh_input_dim + + def _call(self, inputs): + self_vecs, neigh_vecs = inputs + neigh_h = neigh_vecs + + dims = tf.shape(neigh_h) + batch_size = dims[0] + num_neighbors = dims[1] + # [nodes * sampled neighbors] x [hidden_dim] + h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim)) + + for l in self.mlp_layers: + h_reshaped = l(h_reshaped) + neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim)) + neigh_h = tf.reduce_mean(neigh_h, axis=1) + + from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights']) + from_self = tf.matmul(self_vecs, self.vars["self_weights"]) + + if not self.concat: + output = tf.add_n([from_self, from_neighs]) + else: + output = tf.concat([from_self, from_neighs], axis=1) + + # bias + if self.bias: + output += self.vars['bias'] + + return self.act(output) + + +class TwoMaxLayerPoolingAggregator(Layer): """ Aggregates via pooling over two MLP functions. """ def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): - super(TwoLayerPoolingAggregator, self).__init__(**kwargs) + super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs) self.dropout = dropout self.bias = bias diff --git a/graphsage/models.py b/graphsage/models.py index 695cbc5..e9fe791 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -7,7 +7,7 @@ import graphsage.layers as layers import graphsage.metrics as metrics from .prediction import BipartiteEdgePredLayer -from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator +from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator flags = tf.app.flags FLAGS = flags.FLAGS @@ -212,10 +212,10 @@ class SampleAndAggregate(GeneralizedModel): self.aggregator_cls = MeanAggregator elif aggregator_type == "seq": self.aggregator_cls = SeqAggregator - elif aggregator_type == "pool": - self.aggregator_cls = PoolingAggregator - elif aggregator_type == "pool_2": - self.aggregator_cls = TwoLayerPoolingAggregator + elif aggregator_type == "maxpool": + self.aggregator_cls = MaxPoolingAggregator + elif aggregator_type == "meanpool": + self.aggregator_cls = MeanPoolingAggregator elif aggregator_type == "gcn": self.aggregator_cls = GCNAggregator else: diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py index 08fc01e..9ea123c 100644 --- a/graphsage/supervised_models.py +++ b/graphsage/supervised_models.py @@ -2,7 +2,7 @@ import tensorflow as tf import graphsage.models as models import graphsage.layers as layers -from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator +from graphsage.aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator flags = tf.app.flags FLAGS = flags.FLAGS @@ -35,10 +35,10 @@ class SupervisedGraphsage(models.SampleAndAggregate): self.aggregator_cls = MeanAggregator elif aggregator_type == "seq": self.aggregator_cls = SeqAggregator - elif aggregator_type == "pool": - self.aggregator_cls = PoolingAggregator - elif aggregator_type == "pool_2": - self.aggregator_cls = TwoLayerPoolingAggregator + elif aggregator_type == "meanpool": + self.aggregator_cls = MeanPoolingAggregator + elif aggregator_type == "maxpool": + self.aggregator_cls = MaxPoolingAggregator elif aggregator_type == "gcn": self.aggregator_cls = GCNAggregator else: diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py index 75d3451..9580149 100644 --- a/graphsage/supervised_train.py +++ b/graphsage/supervised_train.py @@ -39,8 +39,8 @@ flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).') flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.') flags.DEFINE_integer('max_degree', 128, 'maximum node degree.') flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1') -flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2') -flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only or mean model)') +flags.DEFINE_integer('samples_2', 10, 'number of samples in layer 2') +flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only for mean model)') flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)') flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)') flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges') @@ -202,7 +202,7 @@ def train(train_data, test_data=None): identity_dim = FLAGS.identity_dim, logging=True) - elif FLAGS.model == 'graphsage_pool': + elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] @@ -217,6 +217,23 @@ def train(train_data, test_data=None): sigmoid_loss = FLAGS.sigmoid, identity_dim = FLAGS.identity_dim, logging=True) + + elif FLAGS.model == 'graphsage_meanpool': + sampler = UniformNeighborSampler(adj_info) + layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), + SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] + + model = SupervisedGraphsage(num_classes, placeholders, + features, + adj_info, + minibatch.deg, + layer_infos=layer_infos, + aggregator_type="meanpool", + model_size=FLAGS.model_size, + sigmoid_loss = FLAGS.sigmoid, + identity_dim = FLAGS.identity_dim, + logging=True) + else: raise Exception('Error: model name unrecognized.') diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py index f3f6737..b3162fd 100644 --- a/graphsage/unsupervised_train.py +++ b/graphsage/unsupervised_train.py @@ -194,7 +194,7 @@ def train(train_data, test_data=None): model_size=FLAGS.model_size, logging=True) - elif FLAGS.model == 'graphsage_pool': + elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] @@ -204,10 +204,25 @@ def train(train_data, test_data=None): adj_info, minibatch.deg, layer_infos=layer_infos, - aggregator_type="pool", + aggregator_type="maxpool", model_size=FLAGS.model_size, identity_dim = FLAGS.identity_dim, logging=True) + elif FLAGS.model == 'graphsage_meanpool': + sampler = UniformNeighborSampler(adj_info) + layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), + SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] + + model = SampleAndAggregate(placeholders, + features, + adj_info, + minibatch.deg, + layer_infos=layer_infos, + aggregator_type="meanpool", + model_size=FLAGS.model_size, + identity_dim = FLAGS.identity_dim, + logging=True) + elif FLAGS.model == 'n2v': model = Node2VecModel(placeholders, features.shape[0], minibatch.deg,