Added mean pooling.
parent b142770576
commit 05c72baeb7
@@ -116,12 +116,12 @@ class GCNAggregator(Layer):
         return self.act(output)
 
 
-class PoolingAggregator(Layer):
+class MaxPoolingAggregator(Layer):
     """ Aggregates via max-pooling over MLP functions.
     """
     def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
             dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
-        super(PoolingAggregator, self).__init__(**kwargs)
+        super(MaxPoolingAggregator, self).__init__(**kwargs)
 
         self.dropout = dropout
         self.bias = bias
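Note: this hunk renames PoolingAggregator to MaxPoolingAggregator without keeping the old name, so downstream code that imports PoolingAggregator breaks. A hedged one-line shim, not part of this commit, would preserve the old spelling:

    # Hypothetical backwards-compatibility alias (not in this commit).
    PoolingAggregator = MaxPoolingAggregator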
@@ -194,12 +194,91 @@ class PoolingAggregator(Layer):
 
         return self.act(output)
 
-class TwoLayerPoolingAggregator(Layer):
+class MeanPoolingAggregator(Layer):
+    """ Aggregates via mean-pooling over MLP functions.
+    """
+    def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
+            dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
+        super(MeanPoolingAggregator, self).__init__(**kwargs)
+
+        self.dropout = dropout
+        self.bias = bias
+        self.act = act
+        self.concat = concat
+
+        if neigh_input_dim is None:
+            neigh_input_dim = input_dim
+
+        if name is not None:
+            name = '/' + name
+        else:
+            name = ''
+
+        if model_size == "small":
+            hidden_dim = self.hidden_dim = 512
+        elif model_size == "big":
+            hidden_dim = self.hidden_dim = 1024
+
+        self.mlp_layers = []
+        self.mlp_layers.append(Dense(input_dim=neigh_input_dim,
+                                     output_dim=hidden_dim,
+                                     act=tf.nn.relu,
+                                     dropout=dropout,
+                                     sparse_inputs=False,
+                                     logging=self.logging))
+
+        with tf.variable_scope(self.name + name + '_vars'):
+            self.vars['neigh_weights'] = glorot([hidden_dim, output_dim],
+                                                name='neigh_weights')
+
+            self.vars['self_weights'] = glorot([input_dim, output_dim],
+                                               name='self_weights')
+            if self.bias:
+                self.vars['bias'] = zeros([self.output_dim], name='bias')
+
+        if self.logging:
+            self._log_vars()
+
+        self.input_dim = input_dim
+        self.output_dim = output_dim
+        self.neigh_input_dim = neigh_input_dim
+
+    def _call(self, inputs):
+        self_vecs, neigh_vecs = inputs
+        neigh_h = neigh_vecs
+
+        dims = tf.shape(neigh_h)
+        batch_size = dims[0]
+        num_neighbors = dims[1]
+        # [nodes * sampled neighbors] x [hidden_dim]
+        h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim))
+
+        for l in self.mlp_layers:
+            h_reshaped = l(h_reshaped)
+        neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim))
+        neigh_h = tf.reduce_mean(neigh_h, axis=1)
+
+        from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights'])
+        from_self = tf.matmul(self_vecs, self.vars["self_weights"])
+
+        if not self.concat:
+            output = tf.add_n([from_self, from_neighs])
+        else:
+            output = tf.concat([from_self, from_neighs], axis=1)
+
+        # bias
+        if self.bias:
+            output += self.vars['bias']
+
+        return self.act(output)
+
+
+class TwoMaxLayerPoolingAggregator(Layer):
     """ Aggregates via pooling over two MLP functions.
     """
     def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
             dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
-        super(TwoLayerPoolingAggregator, self).__init__(**kwargs)
+        super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs)
 
         self.dropout = dropout
         self.bias = bias
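Note: the new MeanPoolingAggregator mirrors MaxPoolingAggregator, swapping only the reduction applied after the per-neighbor MLP (tf.reduce_mean in place of max-pooling). A minimal NumPy sketch of the forward pass above, with made-up dimensions and random matrices standing in for the Dense/glorot variables:

    import numpy as np

    rng = np.random.default_rng(0)
    batch_size, num_neighbors = 4, 10                  # illustrative sizes
    neigh_input_dim, hidden_dim, output_dim = 16, 32, 8

    self_vecs = rng.standard_normal((batch_size, neigh_input_dim))
    neigh_vecs = rng.standard_normal((batch_size, num_neighbors, neigh_input_dim))

    # One-layer MLP applied to every sampled neighbor independently.
    W_mlp = rng.standard_normal((neigh_input_dim, hidden_dim))
    h = np.maximum(neigh_vecs.reshape(-1, neigh_input_dim) @ W_mlp, 0)   # ReLU
    h = h.reshape(batch_size, num_neighbors, hidden_dim)

    # Mean-pool over the neighbor axis; a max here gives the maxpool variant.
    neigh_h = h.mean(axis=1)

    # Separate linear maps for the pooled neighbor summary and the self vector.
    W_neigh = rng.standard_normal((hidden_dim, output_dim))
    W_self = rng.standard_normal((neigh_input_dim, output_dim))
    output = np.maximum(neigh_h @ W_neigh + self_vecs @ W_self, 0)       # concat=False path
    print(output.shape)                                                  # (4, 8)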
@@ -7,7 +7,7 @@ import graphsage.layers as layers
 import graphsage.metrics as metrics
 
 from .prediction import BipartiteEdgePredLayer
-from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator
+from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -212,10 +212,10 @@ class SampleAndAggregate(GeneralizedModel):
             self.aggregator_cls = MeanAggregator
         elif aggregator_type == "seq":
             self.aggregator_cls = SeqAggregator
-        elif aggregator_type == "pool":
-            self.aggregator_cls = PoolingAggregator
-        elif aggregator_type == "pool_2":
-            self.aggregator_cls = TwoLayerPoolingAggregator
+        elif aggregator_type == "maxpool":
+            self.aggregator_cls = MaxPoolingAggregator
+        elif aggregator_type == "meanpool":
+            self.aggregator_cls = MeanPoolingAggregator
         elif aggregator_type == "gcn":
             self.aggregator_cls = GCNAggregator
         else:
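Note: after this hunk the accepted aggregator_type keys are "mean", "seq", "maxpool", "meanpool", and "gcn"; the old "pool" and "pool_2" strings now fall through to the else branch. The if/elif chain is equivalent to this lookup table (a sketch, not code from the commit):

    # Sketch of the dispatch above as a dict; unknown keys raise KeyError here.
    AGGREGATORS = {
        "mean": MeanAggregator,
        "seq": SeqAggregator,
        "maxpool": MaxPoolingAggregator,
        "meanpool": MeanPoolingAggregator,
        "gcn": GCNAggregator,
    }
    self.aggregator_cls = AGGREGATORS[aggregator_type]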
@@ -2,7 +2,7 @@ import tensorflow as tf
 
 import graphsage.models as models
 import graphsage.layers as layers
-from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator
+from graphsage.aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -35,10 +35,10 @@ class SupervisedGraphsage(models.SampleAndAggregate):
             self.aggregator_cls = MeanAggregator
         elif aggregator_type == "seq":
             self.aggregator_cls = SeqAggregator
-        elif aggregator_type == "pool":
-            self.aggregator_cls = PoolingAggregator
-        elif aggregator_type == "pool_2":
-            self.aggregator_cls = TwoLayerPoolingAggregator
+        elif aggregator_type == "meanpool":
+            self.aggregator_cls = MeanPoolingAggregator
+        elif aggregator_type == "maxpool":
+            self.aggregator_cls = MaxPoolingAggregator
         elif aggregator_type == "gcn":
             self.aggregator_cls = GCNAggregator
         else:
@@ -39,8 +39,8 @@ flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).')
 flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')
 flags.DEFINE_integer('max_degree', 128, 'maximum node degree.')
 flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
-flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')
-flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only or mean model)')
+flags.DEFINE_integer('samples_2', 10, 'number of samples in layer 2')
+flags.DEFINE_integer('samples_3', 0, 'number of samples in layer 3. (Only for mean model)')
 flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
 flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
 flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
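Note: samples_1 and samples_2 set the neighborhood fan-out that the pooling aggregators reduce over. With the defaults, each batch node draws 25 first-hop samples and each of those draws 10 more, so a two-layer model touches up to 25 + 25 × 10 = 275 sampled vectors per node:

    # Illustrative arithmetic with the default flag values above.
    samples_1, samples_2 = 25, 10
    print(samples_1 + samples_1 * samples_2)   # 275 sampled nodes over two hops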
@@ -202,7 +202,7 @@ def train(train_data, test_data=None):
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
 
-    elif FLAGS.model == 'graphsage_pool':
+    elif FLAGS.model == 'graphsage_maxpool':
         sampler = UniformNeighborSampler(adj_info)
         layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                        SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
@@ -217,6 +217,23 @@ def train(train_data, test_data=None):
                                     sigmoid_loss = FLAGS.sigmoid,
                                     identity_dim = FLAGS.identity_dim,
                                     logging=True)
+
+    elif FLAGS.model == 'graphsage_meanpool':
+        sampler = UniformNeighborSampler(adj_info)
+        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
+                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
+
+        model = SupervisedGraphsage(num_classes, placeholders,
+                                    features,
+                                    adj_info,
+                                    minibatch.deg,
+                                    layer_infos=layer_infos,
+                                    aggregator_type="meanpool",
+                                    model_size=FLAGS.model_size,
+                                    sigmoid_loss = FLAGS.sigmoid,
+                                    identity_dim = FLAGS.identity_dim,
+                                    logging=True)
+
 
     else:
         raise Exception('Error: model name unrecognized.')
@@ -194,7 +194,7 @@ def train(train_data, test_data=None):
                                    model_size=FLAGS.model_size,
                                    logging=True)
 
-    elif FLAGS.model == 'graphsage_pool':
+    elif FLAGS.model == 'graphsage_maxpool':
         sampler = UniformNeighborSampler(adj_info)
         layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                        SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
@@ -204,10 +204,25 @@ def train(train_data, test_data=None):
                                    adj_info,
                                    minibatch.deg,
                                    layer_infos=layer_infos,
-                                   aggregator_type="pool",
+                                   aggregator_type="maxpool",
                                    model_size=FLAGS.model_size,
                                    identity_dim = FLAGS.identity_dim,
                                    logging=True)
+    elif FLAGS.model == 'graphsage_meanpool':
+        sampler = UniformNeighborSampler(adj_info)
+        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
+                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
+
+        model = SampleAndAggregate(placeholders,
+                                   features,
+                                   adj_info,
+                                   minibatch.deg,
+                                   layer_infos=layer_infos,
+                                   aggregator_type="meanpool",
+                                   model_size=FLAGS.model_size,
+                                   identity_dim = FLAGS.identity_dim,
+                                   logging=True)
+
     elif FLAGS.model == 'n2v':
         model = Node2VecModel(placeholders, features.shape[0],
                               minibatch.deg,
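Note: both training entry points now dispatch on the new names, so runs that previously selected the pool model should pass graphsage_maxpool or graphsage_meanpool instead. Since tf.app.flags exposes each DEFINE as a command-line option, selection looks roughly like this (the script name is an assumption, not shown in this diff):

    # hypothetical invocation; FLAGS.model and FLAGS.model_size are read in train() above
    python supervised_train.py --model graphsage_meanpool --model_size small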