Added mean pooling.
This commit is contained in:
parent
b142770576
commit
05c72baeb7
@ -116,12 +116,12 @@ class GCNAggregator(Layer):
|
|||||||
return self.act(output)
|
return self.act(output)
|
||||||
|
|
||||||
|
|
||||||
class PoolingAggregator(Layer):
|
class MaxPoolingAggregator(Layer):
|
||||||
""" Aggregates via max-pooling over MLP functions.
|
""" Aggregates via max-pooling over MLP functions.
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
|
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
|
||||||
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
|
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
|
||||||
super(PoolingAggregator, self).__init__(**kwargs)
|
super(MaxPoolingAggregator, self).__init__(**kwargs)
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
@ -194,12 +194,91 @@ class PoolingAggregator(Layer):
|
|||||||
|
|
||||||
return self.act(output)
|
return self.act(output)
|
||||||
|
|
||||||
class TwoLayerPoolingAggregator(Layer):
|
class MeanPoolingAggregator(Layer):
|
||||||
|
""" Aggregates via mean-pooling over MLP functions.
|
||||||
|
"""
|
||||||
|
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
|
||||||
|
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
|
||||||
|
super(MeanPoolingAggregator, self).__init__(**kwargs)
|
||||||
|
|
||||||
|
self.dropout = dropout
|
||||||
|
self.bias = bias
|
||||||
|
self.act = act
|
||||||
|
self.concat = concat
|
||||||
|
|
||||||
|
if neigh_input_dim is None:
|
||||||
|
neigh_input_dim = input_dim
|
||||||
|
|
||||||
|
if name is not None:
|
||||||
|
name = '/' + name
|
||||||
|
else:
|
||||||
|
name = ''
|
||||||
|
|
||||||
|
if model_size == "small":
|
||||||
|
hidden_dim = self.hidden_dim = 512
|
||||||
|
elif model_size == "big":
|
||||||
|
hidden_dim = self.hidden_dim = 1024
|
||||||
|
|
||||||
|
self.mlp_layers = []
|
||||||
|
self.mlp_layers.append(Dense(input_dim=neigh_input_dim,
|
||||||
|
output_dim=hidden_dim,
|
||||||
|
act=tf.nn.relu,
|
||||||
|
dropout=dropout,
|
||||||
|
sparse_inputs=False,
|
||||||
|
logging=self.logging))
|
||||||
|
|
||||||
|
with tf.variable_scope(self.name + name + '_vars'):
|
||||||
|
self.vars['neigh_weights'] = glorot([hidden_dim, output_dim],
|
||||||
|
name='neigh_weights')
|
||||||
|
|
||||||
|
self.vars['self_weights'] = glorot([input_dim, output_dim],
|
||||||
|
name='self_weights')
|
||||||
|
if self.bias:
|
||||||
|
self.vars['bias'] = zeros([self.output_dim], name='bias')
|
||||||
|
|
||||||
|
if self.logging:
|
||||||
|
self._log_vars()
|
||||||
|
|
||||||
|
self.input_dim = input_dim
|
||||||
|
self.output_dim = output_dim
|
||||||
|
self.neigh_input_dim = neigh_input_dim
|
||||||
|
|
||||||
|
def _call(self, inputs):
|
||||||
|
self_vecs, neigh_vecs = inputs
|
||||||
|
neigh_h = neigh_vecs
|
||||||
|
|
||||||
|
dims = tf.shape(neigh_h)
|
||||||
|
batch_size = dims[0]
|
||||||
|
num_neighbors = dims[1]
|
||||||
|
# [nodes * sampled neighbors] x [hidden_dim]
|
||||||
|
h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim))
|
||||||
|
|
||||||
|
for l in self.mlp_layers:
|
||||||
|
h_reshaped = l(h_reshaped)
|
||||||
|
neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim))
|
||||||
|
neigh_h = tf.reduce_mean(neigh_h, axis=1)
|
||||||
|
|
||||||
|
from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights'])
|
||||||
|
from_self = tf.matmul(self_vecs, self.vars["self_weights"])
|
||||||
|
|
||||||
|
if not self.concat:
|
||||||
|
output = tf.add_n([from_self, from_neighs])
|
||||||
|
else:
|
||||||
|
output = tf.concat([from_self, from_neighs], axis=1)
|
||||||
|
|
||||||
|
# bias
|
||||||
|
if self.bias:
|
||||||
|
output += self.vars['bias']
|
||||||
|
|
||||||
|
return self.act(output)
|
||||||
|
|
||||||
|
|
||||||
|
class TwoMaxLayerPoolingAggregator(Layer):
|
||||||
""" Aggregates via pooling over two MLP functions.
|
""" Aggregates via pooling over two MLP functions.
|
||||||
"""
|
"""
|
||||||
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
|
def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None,
|
||||||
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
|
dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs):
|
||||||
super(TwoLayerPoolingAggregator, self).__init__(**kwargs)
|
super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs)
|
||||||
|
|
||||||
self.dropout = dropout
|
self.dropout = dropout
|
||||||
self.bias = bias
|
self.bias = bias
|
||||||
|
@ -7,7 +7,7 @@ import graphsage.layers as layers
|
|||||||
import graphsage.metrics as metrics
|
import graphsage.metrics as metrics
|
||||||
|
|
||||||
from .prediction import BipartiteEdgePredLayer
|
from .prediction import BipartiteEdgePredLayer
|
||||||
from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator
|
from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
|
||||||
|
|
||||||
flags = tf.app.flags
|
flags = tf.app.flags
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
@ -212,10 +212,10 @@ class SampleAndAggregate(GeneralizedModel):
|
|||||||
self.aggregator_cls = MeanAggregator
|
self.aggregator_cls = MeanAggregator
|
||||||
elif aggregator_type == "seq":
|
elif aggregator_type == "seq":
|
||||||
self.aggregator_cls = SeqAggregator
|
self.aggregator_cls = SeqAggregator
|
||||||
elif aggregator_type == "pool":
|
elif aggregator_type == "maxpool":
|
||||||
self.aggregator_cls = PoolingAggregator
|
self.aggregator_cls = MaxPoolingAggregator
|
||||||
elif aggregator_type == "pool_2":
|
elif aggregator_type == "meanpool":
|
||||||
self.aggregator_cls = TwoLayerPoolingAggregator
|
self.aggregator_cls = MeanPoolingAggregator
|
||||||
elif aggregator_type == "gcn":
|
elif aggregator_type == "gcn":
|
||||||
self.aggregator_cls = GCNAggregator
|
self.aggregator_cls = GCNAggregator
|
||||||
else:
|
else:
|
||||||
|
@ -2,7 +2,7 @@ import tensorflow as tf
|
|||||||
|
|
||||||
import graphsage.models as models
|
import graphsage.models as models
|
||||||
import graphsage.layers as layers
|
import graphsage.layers as layers
|
||||||
from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator
|
from graphsage.aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
|
||||||
|
|
||||||
flags = tf.app.flags
|
flags = tf.app.flags
|
||||||
FLAGS = flags.FLAGS
|
FLAGS = flags.FLAGS
|
||||||
@ -35,10 +35,10 @@ class SupervisedGraphsage(models.SampleAndAggregate):
|
|||||||
self.aggregator_cls = MeanAggregator
|
self.aggregator_cls = MeanAggregator
|
||||||
elif aggregator_type == "seq":
|
elif aggregator_type == "seq":
|
||||||
self.aggregator_cls = SeqAggregator
|
self.aggregator_cls = SeqAggregator
|
||||||
elif aggregator_type == "pool":
|
elif aggregator_type == "meanpool":
|
||||||
self.aggregator_cls = PoolingAggregator
|
self.aggregator_cls = MeanPoolingAggregator
|
||||||
elif aggregator_type == "pool_2":
|
elif aggregator_type == "maxpool":
|
||||||
self.aggregator_cls = TwoLayerPoolingAggregator
|
self.aggregator_cls = MaxPoolingAggregator
|
||||||
elif aggregator_type == "gcn":
|
elif aggregator_type == "gcn":
|
||||||
self.aggregator_cls = GCNAggregator
|
self.aggregator_cls = GCNAggregator
|
||||||
else:
|
else:
|
||||||
|
@ -39,8 +39,8 @@ flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).')
|
|||||||
flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')
|
flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')
|
||||||
flags.DEFINE_integer('max_degree', 128, 'maximum node degree.')
|
flags.DEFINE_integer('max_degree', 128, 'maximum node degree.')
|
||||||
flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
|
flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
|
||||||
flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')
|
flags.DEFINE_integer('samples_2', 10, 'number of samples in layer 2')
|
||||||
flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only or mean model)')
|
flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only for mean model)')
|
||||||
flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
|
flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
|
||||||
flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
|
flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
|
||||||
flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
|
flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
|
||||||
@ -202,7 +202,7 @@ def train(train_data, test_data=None):
|
|||||||
identity_dim = FLAGS.identity_dim,
|
identity_dim = FLAGS.identity_dim,
|
||||||
logging=True)
|
logging=True)
|
||||||
|
|
||||||
elif FLAGS.model == 'graphsage_pool':
|
elif FLAGS.model == 'graphsage_maxpool':
|
||||||
sampler = UniformNeighborSampler(adj_info)
|
sampler = UniformNeighborSampler(adj_info)
|
||||||
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
|
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
|
||||||
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
|
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
|
||||||
@ -217,6 +217,23 @@ def train(train_data, test_data=None):
|
|||||||
sigmoid_loss = FLAGS.sigmoid,
|
sigmoid_loss = FLAGS.sigmoid,
|
||||||
identity_dim = FLAGS.identity_dim,
|
identity_dim = FLAGS.identity_dim,
|
||||||
logging=True)
|
logging=True)
|
||||||
|
|
||||||
|
elif FLAGS.model == 'graphsage_meanpool':
|
||||||
|
sampler = UniformNeighborSampler(adj_info)
|
||||||
|
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
|
||||||
|
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
|
||||||
|
|
||||||
|
model = SupervisedGraphsage(num_classes, placeholders,
|
||||||
|
features,
|
||||||
|
adj_info,
|
||||||
|
minibatch.deg,
|
||||||
|
layer_infos=layer_infos,
|
||||||
|
aggregator_type="meanpool",
|
||||||
|
model_size=FLAGS.model_size,
|
||||||
|
sigmoid_loss = FLAGS.sigmoid,
|
||||||
|
identity_dim = FLAGS.identity_dim,
|
||||||
|
logging=True)
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise Exception('Error: model name unrecognized.')
|
raise Exception('Error: model name unrecognized.')
|
||||||
|
|
||||||
|
@ -194,7 +194,7 @@ def train(train_data, test_data=None):
|
|||||||
model_size=FLAGS.model_size,
|
model_size=FLAGS.model_size,
|
||||||
logging=True)
|
logging=True)
|
||||||
|
|
||||||
elif FLAGS.model == 'graphsage_pool':
|
elif FLAGS.model == 'graphsage_maxpool':
|
||||||
sampler = UniformNeighborSampler(adj_info)
|
sampler = UniformNeighborSampler(adj_info)
|
||||||
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
|
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
|
||||||
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
|
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
|
||||||
@ -204,10 +204,25 @@ def train(train_data, test_data=None):
|
|||||||
adj_info,
|
adj_info,
|
||||||
minibatch.deg,
|
minibatch.deg,
|
||||||
layer_infos=layer_infos,
|
layer_infos=layer_infos,
|
||||||
aggregator_type="pool",
|
aggregator_type="maxpool",
|
||||||
model_size=FLAGS.model_size,
|
model_size=FLAGS.model_size,
|
||||||
identity_dim = FLAGS.identity_dim,
|
identity_dim = FLAGS.identity_dim,
|
||||||
logging=True)
|
logging=True)
|
||||||
|
elif FLAGS.model == 'graphsage_meanpool':
|
||||||
|
sampler = UniformNeighborSampler(adj_info)
|
||||||
|
layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
|
||||||
|
SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
|
||||||
|
|
||||||
|
model = SampleAndAggregate(placeholders,
|
||||||
|
features,
|
||||||
|
adj_info,
|
||||||
|
minibatch.deg,
|
||||||
|
layer_infos=layer_infos,
|
||||||
|
aggregator_type="meanpool",
|
||||||
|
model_size=FLAGS.model_size,
|
||||||
|
identity_dim = FLAGS.identity_dim,
|
||||||
|
logging=True)
|
||||||
|
|
||||||
elif FLAGS.model == 'n2v':
|
elif FLAGS.model == 'n2v':
|
||||||
model = Node2VecModel(placeholders, features.shape[0],
|
model = Node2VecModel(placeholders, features.shape[0],
|
||||||
minibatch.deg,
|
minibatch.deg,
|
||||||
|
Loading…
Reference in New Issue
Block a user