Added mean pooling.

This commit is contained in:
williamleif 2017-10-10 15:46:12 -07:00
parent b142770576
commit 05c72baeb7
5 changed files with 130 additions and 19 deletions

View File

@ -116,12 +116,12 @@ class GCNAggregator(Layer):
return self.act(output) return self.act(output)
class PoolingAggregator(Layer): class MaxPoolingAggregator(Layer):
""" Aggregates via max-pooling over MLP functions. """ Aggregates via max-pooling over MLP functions.
""" """
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
super(PoolingAggregator, self).__init__(**kwargs) super(MaxPoolingAggregator, self).__init__(**kwargs)
self.dropout = dropout self.dropout = dropout
self.bias = bias self.bias = bias
@ -194,12 +194,91 @@ class PoolingAggregator(Layer):
return self.act(output) return self.act(output)
class TwoLayerPoolingAggregator(Layer): class MeanPoolingAggregator(Layer):
""" Aggregates via mean-pooling over MLP functions.
"""
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
             dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
    """Create the MLP layer and the weight variables of the aggregator.

    Args:
        input_dim: width of the self-vector input; first dimension of
            the 'self_weights' matrix.
        output_dim: second dimension of both weight matrices.
        model_size: "small" (hidden width 512) or "big" (hidden width 1024).
        neigh_input_dim: width of the neighbor vectors fed to the MLP;
            falls back to `input_dim` when None.
        dropout: dropout value forwarded to the Dense MLP layer.
        bias: when true, a zero-initialized 'bias' variable is created.
        act: activation applied to the aggregator output in `_call`.
        name: optional suffix for the variable scope name.
        concat: when true, `_call` concatenates the self and neighbor
            projections instead of summing them.
        **kwargs: forwarded to the parent `Layer` constructor.

    Raises:
        ValueError: if `model_size` is neither "small" nor "big".
    """
    super(MeanPoolingAggregator, self).__init__(**kwargs)

    self.dropout = dropout
    self.bias = bias
    self.act = act
    self.concat = concat

    if neigh_input_dim is None:
        neigh_input_dim = input_dim

    if name is not None:
        name = '/' + name
    else:
        name = ''

    if model_size == "small":
        hidden_dim = self.hidden_dim = 512
    elif model_size == "big":
        hidden_dim = self.hidden_dim = 1024
    else:
        # Previously an unknown size left `hidden_dim` unbound and failed
        # further down with an unrelated-looking UnboundLocalError.
        raise ValueError(
            "model_size must be 'small' or 'big', got %r" % (model_size,))

    self.mlp_layers = []
    self.mlp_layers.append(Dense(input_dim=neigh_input_dim,
                                 output_dim=hidden_dim,
                                 act=tf.nn.relu,
                                 dropout=dropout,
                                 sparse_inputs=False,
                                 logging=self.logging))

    with tf.variable_scope(self.name + name + '_vars'):
        self.vars['neigh_weights'] = glorot([hidden_dim, output_dim],
                                            name='neigh_weights')
        self.vars['self_weights'] = glorot([input_dim, output_dim],
                                           name='self_weights')
        if self.bias:
            # Size the bias from the local `output_dim`: `self.output_dim`
            # is only assigned at the end of the constructor, so reading it
            # here raised AttributeError. With concat, `_call` joins two
            # [*, output_dim] projections along axis 1, so the bias has to
            # span 2 * output_dim columns to be addable to that result.
            bias_dim = 2 * output_dim if concat else output_dim
            self.vars['bias'] = zeros([bias_dim], name='bias')

    if self.logging:
        self._log_vars()

    self.input_dim = input_dim
    self.output_dim = output_dim
    self.neigh_input_dim = neigh_input_dim
def _call(self, inputs):
    """Combine each node's own vector with the mean of its transformed neighbors.

    Args:
        inputs: a pair `(self_vecs, neigh_vecs)`. `neigh_vecs` is reshaped
            using its first two dimensions and `self.neigh_input_dim`, so
            it is treated as [nodes, sampled neighbors, neigh_input_dim].

    Returns:
        `self.act` applied to the sum (or, when `self.concat` is set, the
        axis-1 concatenation) of the self and neighbor projections, with
        the 'bias' variable added when `self.bias` is set.
    """
    self_vecs, neigh_vecs = inputs

    shape = tf.shape(neigh_vecs)
    n_nodes = shape[0]
    n_sampled = shape[1]

    # Flatten to [nodes * sampled neighbors, neigh_input_dim] so the MLP
    # layers see every neighbor vector as a separate row.
    hidden = tf.reshape(neigh_vecs, (n_nodes * n_sampled, self.neigh_input_dim))
    for layer in self.mlp_layers:
        hidden = layer(hidden)

    # Restore the per-node grouping, then average over the neighbor axis.
    grouped = tf.reshape(hidden, (n_nodes, n_sampled, self.hidden_dim))
    pooled = tf.reduce_mean(grouped, axis=1)

    from_neighs = tf.matmul(pooled, self.vars['neigh_weights'])
    from_self = tf.matmul(self_vecs, self.vars["self_weights"])

    if self.concat:
        output = tf.concat([from_self, from_neighs], axis=1)
    else:
        output = tf.add_n([from_self, from_neighs])

    # Optional bias term.
    if self.bias:
        output += self.vars['bias']

    return self.act(output)
class TwoMaxLayerPoolingAggregator(Layer):
""" Aggregates via pooling over two MLP functions. """ Aggregates via pooling over two MLP functions.
""" """
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
super(TwoLayerPoolingAggregator, self).__init__(**kwargs) super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs)
self.dropout = dropout self.dropout = dropout
self.bias = bias self.bias = bias

View File

@ -7,7 +7,7 @@ import graphsage.layers as layers
import graphsage.metrics as metrics import graphsage.metrics as metrics
from .prediction import BipartiteEdgePredLayer from .prediction import BipartiteEdgePredLayer
from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@ -212,10 +212,10 @@ class SampleAndAggregate(GeneralizedModel):
self.aggregator_cls = MeanAggregator self.aggregator_cls = MeanAggregator
elif aggregator_type == "seq": elif aggregator_type == "seq":
self.aggregator_cls = SeqAggregator self.aggregator_cls = SeqAggregator
elif aggregator_type == "pool": elif aggregator_type == "maxpool":
self.aggregator_cls = PoolingAggregator self.aggregator_cls = MaxPoolingAggregator
elif aggregator_type == "pool_2": elif aggregator_type == "meanpool":
self.aggregator_cls = TwoLayerPoolingAggregator self.aggregator_cls = MeanPoolingAggregator
elif aggregator_type == "gcn": elif aggregator_type == "gcn":
self.aggregator_cls = GCNAggregator self.aggregator_cls = GCNAggregator
else: else:

View File

@ -2,7 +2,7 @@ import tensorflow as tf
import graphsage.models as models import graphsage.models as models
import graphsage.layers as layers import graphsage.layers as layers
from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator from graphsage.aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
flags = tf.app.flags flags = tf.app.flags
FLAGS = flags.FLAGS FLAGS = flags.FLAGS
@ -35,10 +35,10 @@ class SupervisedGraphsage(models.SampleAndAggregate):
self.aggregator_cls = MeanAggregator self.aggregator_cls = MeanAggregator
elif aggregator_type == "seq": elif aggregator_type == "seq":
self.aggregator_cls = SeqAggregator self.aggregator_cls = SeqAggregator
elif aggregator_type == "pool": elif aggregator_type == "meanpool":
self.aggregator_cls = PoolingAggregator self.aggregator_cls = MeanPoolingAggregator
elif aggregator_type == "pool_2": elif aggregator_type == "maxpool":
self.aggregator_cls = TwoLayerPoolingAggregator self.aggregator_cls = MaxPoolingAggregator
elif aggregator_type == "gcn": elif aggregator_type == "gcn":
self.aggregator_cls = GCNAggregator self.aggregator_cls = GCNAggregator
else: else:

View File

@ -39,8 +39,8 @@ flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).')
flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.') flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')
flags.DEFINE_integer('max_degree', 128, 'maximum node degree.') flags.DEFINE_integer('max_degree', 128, 'maximum node degree.')
flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1') flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2') flags.DEFINE_integer('samples_2', 10, 'number of samples in layer 2')
flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only or mean model)') flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only for mean model)')
flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)') flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)') flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges') flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
@ -202,7 +202,7 @@ def train(train_data, test_data=None):
identity_dim = FLAGS.identity_dim, identity_dim = FLAGS.identity_dim,
logging=True) logging=True)
elif FLAGS.model == 'graphsage_pool': elif FLAGS.model == 'graphsage_maxpool':
sampler = UniformNeighborSampler(adj_info) sampler = UniformNeighborSampler(adj_info)
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
@ -217,6 +217,23 @@ def train(train_data, test_data=None):
sigmoid_loss = FLAGS.sigmoid, sigmoid_loss = FLAGS.sigmoid,
identity_dim = FLAGS.identity_dim, identity_dim = FLAGS.identity_dim,
logging=True) logging=True)
elif FLAGS.model == 'graphsage_meanpool':
sampler = UniformNeighborSampler(adj_info)
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
model = SupervisedGraphsage(num_classes, placeholders,
features,
adj_info,
minibatch.deg,
layer_infos=layer_infos,
aggregator_type="meanpool",
model_size=FLAGS.model_size,
sigmoid_loss = FLAGS.sigmoid,
identity_dim = FLAGS.identity_dim,
logging=True)
else: else:
raise Exception('Error: model name unrecognized.') raise Exception('Error: model name unrecognized.')

View File

@ -194,7 +194,7 @@ def train(train_data, test_data=None):
model_size=FLAGS.model_size, model_size=FLAGS.model_size,
logging=True) logging=True)
elif FLAGS.model == 'graphsage_pool': elif FLAGS.model == 'graphsage_maxpool':
sampler = UniformNeighborSampler(adj_info) sampler = UniformNeighborSampler(adj_info)
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
@ -204,10 +204,25 @@ def train(train_data, test_data=None):
adj_info, adj_info,
minibatch.deg, minibatch.deg,
layer_infos=layer_infos, layer_infos=layer_infos,
aggregator_type="pool", aggregator_type="maxpool",
model_size=FLAGS.model_size, model_size=FLAGS.model_size,
identity_dim = FLAGS.identity_dim, identity_dim = FLAGS.identity_dim,
logging=True) logging=True)
elif FLAGS.model == 'graphsage_meanpool':
sampler = UniformNeighborSampler(adj_info)
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
model = SampleAndAggregate(placeholders,
features,
adj_info,
minibatch.deg,
layer_infos=layer_infos,
aggregator_type="meanpool",
model_size=FLAGS.model_size,
identity_dim = FLAGS.identity_dim,
logging=True)
elif FLAGS.model == 'n2v': elif FLAGS.model == 'n2v':
model = Node2VecModel(placeholders, features.shape[0], model = Node2VecModel(placeholders, features.shape[0],
minibatch.deg, minibatch.deg,