From a373623c16ac6c026388e0a33c70b85e648961f5 Mon Sep 17 00:00:00 2001 From: William L Hamilton Date: Sat, 16 Sep 2017 11:43:32 -0700 Subject: [PATCH 01/10] Update README.md --- README.md | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/README.md b/README.md index 4ccf602..9cc74dd 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -## GraphSAGE: Inductive Representation Learning on Large Graphs +## GraphSage: Inductive Representation Learning on Large Graphs #### Authors: [William L. Hamilton](http://stanford.edu/~wleif) (wleif@stanford.edu), [Rex Ying](http://joy-of-thinking.weebly.com/) (rexying@stanford.edu) #### [Project Website](http://snap.stanford.edu/graphsage/) @@ -6,14 +6,15 @@ ### Overview -This directory contains code necessary to run the GraphSAGE algorithm. +This directory contains code necessary to run the GraphSage algorithm. +GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm. The example_data subdirectory contains a small example of the protein-protein interaction data, which includes 3 training graphs + one validation graph and one test graph. The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/). -If you make use of this code or the GraphSAGE algorithm in your work, please cite the following paper: +If you make use of this code or the GraphSage algorithm in your work, please cite the following paper: @article{hamilton2017inductive, author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure}, @@ -67,7 +68,7 @@ Note that the full log outputs and stored embeddings can be 5-10Gb in size (on t #### Using the output of the unsupervised models -The unsupervised variants of GraphSAGE will output embeddings to the logging directory as described above. +The unsupervised variants of GraphSage will output embeddings to the logging directory as described above. These embeddings can then be used in downstream machine learning applications. The `eval_scripts` directory contains examples of feeding the embeddings into simple logistic classifiers. From 87e978e415d8c5917e41bfc8bf0f3b7466c54fe8 Mon Sep 17 00:00:00 2001 From: williamleif Date: Sat, 16 Sep 2017 14:17:14 -0700 Subject: [PATCH 02/10] Added support for identity features. --- README.md | 35 ++++++++++++++++++++++----------- graphsage/models.py | 21 ++++++++++++++++---- graphsage/supervised_models.py | 18 +++++++++++++---- graphsage/supervised_train.py | 10 ++++++++-- graphsage/unsupervised_train.py | 17 ++++++++++------ graphsage/utils.py | 20 ++++++++++++------- 6 files changed, 86 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 9cc74dd..d2ac546 100644 --- a/README.md +++ b/README.md @@ -1,4 +1,4 @@ -## GraphSage: Inductive Representation Learning on Large Graphs +## GraphSage: Representation Learning on Large Graphs #### Authors: [William L. Hamilton](http://stanford.edu/~wleif) (wleif@stanford.edu), [Rex Ying](http://joy-of-thinking.weebly.com/) (rexying@stanford.edu) #### [Project Website](http://snap.stanford.edu/graphsage/) @@ -10,16 +10,23 @@ This directory contains code necessary to run the GraphSage algorithm. GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm. +*Note:* GraphSage now also has better support for training on smaller, static graphs and graphs that don't have node features. +The original algorithm and paper are focused on the task of inductive generalization (i.e., generating embeddings for nodes that were not present during training), +but many benchmarks/tasks use simple static graphs that do not necessarily have features. +To support this use case, GraphSage now includes optional "identity features" that can be used with or without other node attributes. +Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting). +See the section on "Running the code" below. + The example_data subdirectory contains a small example of the protein-protein interaction data, which includes 3 training graphs + one validation graph and one test graph. The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/). If you make use of this code or the GraphSage algorithm in your work, please cite the following paper: - @article{hamilton2017inductive, + @inproceedings{hamilton2017inductive, author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure}, title = {Inductive Representation Learning on Large Graphs}, - journal = {arXiv preprint, arXiv:1603.04467}, + booktitle = {NIPS}, year = {2017} } @@ -29,7 +36,13 @@ Recent versions of TensorFlow, numpy, scipy, and networkx are required. ### Running the code -The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSAGE, respectively. +The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively. + +If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256]. +This flag will make the model use embed unique node ids as attributes, which will increase the runtime but also potentially increase the performance. +Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity). +The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table. + Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance. We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate). @@ -41,21 +54,19 @@ As input, at minimum the code requires that a --train_prefix option is specified * -G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively. * -id_map.json -- A json-stored dictionary mapping the graph node ids to consecutive integers. * -id_map.json -- A json-stored dictionary mapping the graph node ids to classes. -* -feats.npy --- A numpy-stored array of node features; ordering given by id_map.json -* -walks.txt --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage) +* -feats.npy [optional] --- A numpy-stored array of node features; ordering given by id_map.json. Can be omitted and only identity features will be used. +* -walks.txt [optional] --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage) To run the model on a new dataset, you need to make data files in the format described above. To run random walks for the unsupervised model and to generate the -walks.txt file) you can use the `run_walks` function in `graphsage.utils`. - - #### Model variants The user must also specify a --model, the variants of which are described in detail in the paper: -* graphsage_mean -- GraphSAGE with mean-based aggregator -* graphsage_seq -- GraphSAGE with LSTM-based aggregator -* graphsage_pool -- GraphSAGE with max-pooling aggregator -* gcn -- GraphSAGE with GCN-based aggregator +* graphsage_mean -- GraphSage with mean-based aggregator +* graphsage_seq -- GraphSage with LSTM-based aggregator +* graphsage_pool -- GraphSage with max-pooling aggregator +* gcn -- GraphSage with GCN-based aggregator * n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.) #### Logging directory diff --git a/graphsage/models.py b/graphsage/models.py index b40b17f..861e382 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -191,12 +191,13 @@ class SampleAndAggregate(GeneralizedModel): def __init__(self, placeholders, features, adj, degrees, layer_infos, concat=True, aggregator_type="mean", - model_size="small", + model_size="small", identity_dim=0, **kwargs): ''' Args: - placeholders: Stanford TensorFlow placeholder object. - - features: Numpy array with node features. + - features: Numpy array with node features. + NOTE: Pass a None object to train in featureless mode (identity features for nodes)! - adj: Numpy array with adjacency lists (padded with random re-samples) - degrees: Numpy array with node degrees. - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all @@ -204,6 +205,7 @@ class SampleAndAggregate(GeneralizedModel): - concat: whether to concatenate during recursive iterations - aggregator_type: how to aggregate neighbor information - model_size: one of "small" and "big" + - identity_dim: Set to positive int to use identity features (slow and cannot generalize, but better accuracy) ''' super(SampleAndAggregate, self).__init__(**kwargs) if aggregator_type == "mean": @@ -224,11 +226,22 @@ class SampleAndAggregate(GeneralizedModel): self.inputs2 = placeholders["batch2"] self.model_size = model_size self.adj_info = adj - self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False) + if identity_dim > 0: + self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim]) + else: + self.embeds = None + if features is None: + if identity_dim is None: + raise Exception("Must have a positive value for identity feature dimension if no input features given.") + self.features = self.embeds + else: + self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False) + if not self.embeds is None: + self.features = tf.concat([self.embeds, self.features], axis=1) self.degrees = degrees self.concat = concat - self.dims = [features.shape[1]] + self.dims = [(0 if features is None else features.shape[1]) + identity_dim] self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))]) self.batch_size = placeholders["batch_size"] self.placeholders = placeholders diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py index a8658b6..4bff401 100644 --- a/graphsage/supervised_models.py +++ b/graphsage/supervised_models.py @@ -13,7 +13,7 @@ class SupervisedGraphsage(models.SampleAndAggregate): def __init__(self, num_classes, placeholders, features, adj, degrees, layer_infos, concat=True, aggregator_type="mean", - model_size="small", sigmoid_loss=False, + model_size="small", sigmoid_loss=False, identity_dim=0, **kwargs): ''' Args: @@ -48,13 +48,23 @@ class SupervisedGraphsage(models.SampleAndAggregate): self.inputs1 = placeholders["batch"] self.model_size = model_size self.adj_info = adj - self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False) + if identity_dim > 0: + self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim]) + else: + self.embeds = None + if features is None: + if identity_dim is None: + raise Exception("Must have a positive value for identity feature dimension if no input features given.") + self.features = self.embeds + else: + self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False) + if not self.embeds is None: + self.features = tf.concat([self.embeds, self.features], axis=1) self.degrees = degrees self.concat = concat self.num_classes = num_classes self.sigmoid_loss = sigmoid_loss - - self.dims = [features.shape[1]] + self.dims = [(0 if features is None else features.shape[1]) + identity_dim] self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))]) self.batch_size = placeholders["batch_size"] self.placeholders = placeholders diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py index fa52581..75d3451 100644 --- a/graphsage/supervised_train.py +++ b/graphsage/supervised_train.py @@ -46,6 +46,7 @@ flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if usi flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges') flags.DEFINE_integer('batch_size', 512, 'minibatch size.') flags.DEFINE_boolean('sigmoid', False, 'whether to use sigmoid loss') +flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.') #logging, saving, validation settings etc. flags.DEFINE_string('base_log_dir', '.', 'base directory for logging and saving embeddings') @@ -129,8 +130,9 @@ def train(train_data, test_data=None): else: num_classes = len(set(class_map.values())) - # pad with dummy zero vector - features = np.vstack([features, np.zeros((features.shape[1],))]) + if not features is None: + # pad with dummy zero vector + features = np.vstack([features, np.zeros((features.shape[1],))]) context_pairs = train_data[3] if FLAGS.random_context else None placeholders = construct_placeholders(num_classes) @@ -164,6 +166,7 @@ def train(train_data, test_data=None): layer_infos, model_size=FLAGS.model_size, sigmoid_loss = FLAGS.sigmoid, + identity_dim = FLAGS.identity_dim, logging=True) elif FLAGS.model == 'gcn': # Create model @@ -180,6 +183,7 @@ def train(train_data, test_data=None): model_size=FLAGS.model_size, concat=False, sigmoid_loss = FLAGS.sigmoid, + identity_dim = FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_seq': @@ -195,6 +199,7 @@ def train(train_data, test_data=None): aggregator_type="seq", model_size=FLAGS.model_size, sigmoid_loss = FLAGS.sigmoid, + identity_dim = FLAGS.identity_dim, logging=True) elif FLAGS.model == 'graphsage_pool': @@ -210,6 +215,7 @@ def train(train_data, test_data=None): aggregator_type="pool", model_size=FLAGS.model_size, sigmoid_loss = FLAGS.sigmoid, + identity_dim = FLAGS.identity_dim, logging=True) else: raise Exception('Error: model name unrecognized.') diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py index 945aa20..f3f6737 100644 --- a/graphsage/unsupervised_train.py +++ b/graphsage/unsupervised_train.py @@ -43,6 +43,7 @@ flags.DEFINE_boolean('random_context', True, 'Whether to use random context or d flags.DEFINE_integer('neg_sample_size', 20, 'number of negative samples') flags.DEFINE_integer('batch_size', 512, 'minibatch size.') flags.DEFINE_integer('n2v_test_epochs', 1, 'Number of new SGD epochs for n2v.') +flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.') #logging, saving, validation settings etc. flags.DEFINE_boolean('save_embeddings', True, 'whether to save embeddings for all nodes after training') @@ -115,7 +116,7 @@ def save_val_embeddings(sess, model, minibatch_iter, size, out_dir, mod=""): with open(out_dir + name + mod + ".txt", "w") as fp: fp.write("\n".join(map(str,nodes))) -def construct_placeholders(feature_size): +def construct_placeholders(): # Define placeholders placeholders = { 'batch1' : tf.placeholder(tf.int32, shape=(None), name='batch1'), @@ -133,12 +134,12 @@ def train(train_data, test_data=None): features = train_data[1] id_map = train_data[2] - # pad with dummy zero vector - features = np.vstack([features, np.zeros((features.shape[1],))]) - feature_size = features.shape[1] + if not features is None: + # pad with dummy zero vector + features = np.vstack([features, np.zeros((features.shape[1],))]) context_pairs = train_data[3] if FLAGS.random_context else None - placeholders = construct_placeholders(feature_size) + placeholders = construct_placeholders() minibatch = EdgeMinibatchIterator(G, id_map, placeholders, batch_size=FLAGS.batch_size, @@ -159,6 +160,7 @@ def train(train_data, test_data=None): minibatch.deg, layer_infos=layer_infos, model_size=FLAGS.model_size, + identity_dim = FLAGS.identity_dim, logging=True) elif FLAGS.model == 'gcn': # Create model @@ -173,6 +175,7 @@ def train(train_data, test_data=None): layer_infos=layer_infos, aggregator_type="gcn", model_size=FLAGS.model_size, + identity_dim = FLAGS.identity_dim, concat=False, logging=True) @@ -186,6 +189,7 @@ def train(train_data, test_data=None): adj_info, minibatch.deg, layer_infos=layer_infos, + identity_dim = FLAGS.identity_dim, aggregator_type="seq", model_size=FLAGS.model_size, logging=True) @@ -202,6 +206,7 @@ def train(train_data, test_data=None): layer_infos=layer_infos, aggregator_type="pool", model_size=FLAGS.model_size, + identity_dim = FLAGS.identity_dim, logging=True) elif FLAGS.model == 'n2v': model = Node2VecModel(placeholders, features.shape[0], @@ -354,7 +359,7 @@ def train(train_data, test_data=None): def main(argv=None): print("Loading training data..") - train_data = load_data(FLAGS.train_prefix) + train_data = load_data(FLAGS.train_prefix, load_walks=True) print("Done loading training data..") train(train_data) diff --git a/graphsage/utils.py b/graphsage/utils.py index c15f568..a8254d2 100644 --- a/graphsage/utils.py +++ b/graphsage/utils.py @@ -4,13 +4,14 @@ import numpy as np import random import json import sys +import os from networkx.readwrite import json_graph WALK_LEN=5 N_WALKS=50 -def load_data(prefix, normalize=True): +def load_data(prefix, normalize=True, load_walks=False): G_data = json.load(open(prefix + "-G.json")) G = json_graph.node_link_graph(G_data) if isinstance(G.nodes()[0], int): @@ -18,7 +19,11 @@ def load_data(prefix, normalize=True): else: conversion = lambda n : n - feats = np.load(prefix + "-feats.npy") + if os.path.exists(prefix + "-feats.npy"): + feats = np.load(prefix + "-feats.npy") + else: + print("No features present.. Only identity features will be used.") + feats = None id_map = json.load(open(prefix + "-id_map.json")) id_map = {conversion(k):int(v) for k,v in id_map.iteritems()} walks = [] @@ -40,17 +45,18 @@ def load_data(prefix, normalize=True): else: G[edge[0]][edge[1]]['train_removed'] = False - if normalize: + if normalize and not feats is None: from sklearn.preprocessing import StandardScaler train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']]) train_feats = feats[train_ids] scaler = StandardScaler() scaler.fit(train_feats) feats = scaler.transform(feats) - - with open(prefix + "-walks.txt") as fp: - for line in fp: - walks.append(map(conversion, line.split())) + + if load_walks: + with open(prefix + "-walks.txt") as fp: + for line in fp: + walks.append(map(conversion, line.split())) return G, feats, id_map, walks, class_map From b1427705767c4001dacf905b6a45cfb1eb29944b Mon Sep 17 00:00:00 2001 From: williamleif Date: Sat, 16 Sep 2017 14:26:47 -0700 Subject: [PATCH 03/10] Fixed error message when no features are present. --- graphsage/models.py | 2 +- graphsage/supervised_models.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/graphsage/models.py b/graphsage/models.py index 861e382..695cbc5 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -231,7 +231,7 @@ class SampleAndAggregate(GeneralizedModel): else: self.embeds = None if features is None: - if identity_dim is None: + if identity_dim == 0: raise Exception("Must have a positive value for identity feature dimension if no input features given.") self.features = self.embeds else: diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py index 4bff401..08fc01e 100644 --- a/graphsage/supervised_models.py +++ b/graphsage/supervised_models.py @@ -53,7 +53,7 @@ class SupervisedGraphsage(models.SampleAndAggregate): else: self.embeds = None if features is None: - if identity_dim is None: + if identity_dim == 0: raise Exception("Must have a positive value for identity feature dimension if no input features given.") self.features = self.embeds else: From 0d9c4a739261c9fd86736af95406d6004de4833d Mon Sep 17 00:00:00 2001 From: William L Hamilton Date: Sat, 16 Sep 2017 14:34:13 -0700 Subject: [PATCH 04/10] Update README.md --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index d2ac546..a4562fe 100644 --- a/README.md +++ b/README.md @@ -39,7 +39,7 @@ Recent versions of TensorFlow, numpy, scipy, and networkx are required. The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively. If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256]. -This flag will make the model use embed unique node ids as attributes, which will increase the runtime but also potentially increase the performance. +This flag will make the model embed unique node ids as attributes, which will increase the runtime and number of parameters but also potentially increase the performance. Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity). The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table. From 05c72baeb72c44f7923e548c4e1dbc2fe13fed22 Mon Sep 17 00:00:00 2001 From: williamleif Date: Tue, 10 Oct 2017 15:46:12 -0700 Subject: [PATCH 05/10] Added mean pooling. --- graphsage/aggregators.py | 87 +++++++++++++++++++++++++++++++-- graphsage/models.py | 10 ++-- graphsage/supervised_models.py | 10 ++-- graphsage/supervised_train.py | 23 +++++++-- graphsage/unsupervised_train.py | 19 ++++++- 5 files changed, 130 insertions(+), 19 deletions(-) diff --git a/graphsage/aggregators.py b/graphsage/aggregators.py index 705ec69..7dbd252 100644 --- a/graphsage/aggregators.py +++ b/graphsage/aggregators.py @@ -116,12 +116,12 @@ class GCNAggregator(Layer): return self.act(output) -class PoolingAggregator(Layer): +class MaxPoolingAggregator(Layer): """ Aggregates via max-pooling over MLP functions. """ def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): - super(PoolingAggregator, self).__init__(**kwargs) + super(MaxPoolingAggregator, self).__init__(**kwargs) self.dropout = dropout self.bias = bias @@ -194,12 +194,91 @@ class PoolingAggregator(Layer): return self.act(output) -class TwoLayerPoolingAggregator(Layer): +class MeanPoolingAggregator(Layer): + """ Aggregates via mean-pooling over MLP functions. + """ + def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, + dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): + super(MeanPoolingAggregator, self).__init__(**kwargs) + + self.dropout = dropout + self.bias = bias + self.act = act + self.concat = concat + + if neigh_input_dim is None: + neigh_input_dim = input_dim + + if name is not None: + name = '/' + name + else: + name = '' + + if model_size == "small": + hidden_dim = self.hidden_dim = 512 + elif model_size == "big": + hidden_dim = self.hidden_dim = 1024 + + self.mlp_layers = [] + self.mlp_layers.append(Dense(input_dim=neigh_input_dim, + output_dim=hidden_dim, + act=tf.nn.relu, + dropout=dropout, + sparse_inputs=False, + logging=self.logging)) + + with tf.variable_scope(self.name + name + '_vars'): + self.vars['neigh_weights'] = glorot([hidden_dim, output_dim], + name='neigh_weights') + + self.vars['self_weights'] = glorot([input_dim, output_dim], + name='self_weights') + if self.bias: + self.vars['bias'] = zeros([self.output_dim], name='bias') + + if self.logging: + self._log_vars() + + self.input_dim = input_dim + self.output_dim = output_dim + self.neigh_input_dim = neigh_input_dim + + def _call(self, inputs): + self_vecs, neigh_vecs = inputs + neigh_h = neigh_vecs + + dims = tf.shape(neigh_h) + batch_size = dims[0] + num_neighbors = dims[1] + # [nodes * sampled neighbors] x [hidden_dim] + h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim)) + + for l in self.mlp_layers: + h_reshaped = l(h_reshaped) + neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim)) + neigh_h = tf.reduce_mean(neigh_h, axis=1) + + from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights']) + from_self = tf.matmul(self_vecs, self.vars["self_weights"]) + + if not self.concat: + output = tf.add_n([from_self, from_neighs]) + else: + output = tf.concat([from_self, from_neighs], axis=1) + + # bias + if self.bias: + output += self.vars['bias'] + + return self.act(output) + + +class TwoMaxLayerPoolingAggregator(Layer): """ Aggregates via pooling over two MLP functions. """ def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): - super(TwoLayerPoolingAggregator, self).__init__(**kwargs) + super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs) self.dropout = dropout self.bias = bias diff --git a/graphsage/models.py b/graphsage/models.py index 695cbc5..e9fe791 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -7,7 +7,7 @@ import graphsage.layers as layers import graphsage.metrics as metrics from .prediction import BipartiteEdgePredLayer -from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator +from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator flags = tf.app.flags FLAGS = flags.FLAGS @@ -212,10 +212,10 @@ class SampleAndAggregate(GeneralizedModel): self.aggregator_cls = MeanAggregator elif aggregator_type == "seq": self.aggregator_cls = SeqAggregator - elif aggregator_type == "pool": - self.aggregator_cls = PoolingAggregator - elif aggregator_type == "pool_2": - self.aggregator_cls = TwoLayerPoolingAggregator + elif aggregator_type == "maxpool": + self.aggregator_cls = MaxPoolingAggregator + elif aggregator_type == "meanpool": + self.aggregator_cls = MeanPoolingAggregator elif aggregator_type == "gcn": self.aggregator_cls = GCNAggregator else: diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py index 08fc01e..9ea123c 100644 --- a/graphsage/supervised_models.py +++ b/graphsage/supervised_models.py @@ -2,7 +2,7 @@ import tensorflow as tf import graphsage.models as models import graphsage.layers as layers -from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator +from graphsage.aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator flags = tf.app.flags FLAGS = flags.FLAGS @@ -35,10 +35,10 @@ class SupervisedGraphsage(models.SampleAndAggregate): self.aggregator_cls = MeanAggregator elif aggregator_type == "seq": self.aggregator_cls = SeqAggregator - elif aggregator_type == "pool": - self.aggregator_cls = PoolingAggregator - elif aggregator_type == "pool_2": - self.aggregator_cls = TwoLayerPoolingAggregator + elif aggregator_type == "meanpool": + self.aggregator_cls = MeanPoolingAggregator + elif aggregator_type == "maxpool": + self.aggregator_cls = MaxPoolingAggregator elif aggregator_type == "gcn": self.aggregator_cls = GCNAggregator else: diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py index 75d3451..9580149 100644 --- a/graphsage/supervised_train.py +++ b/graphsage/supervised_train.py @@ -39,8 +39,8 @@ flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).') flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.') flags.DEFINE_integer('max_degree', 128, 'maximum node degree.') flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1') -flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2') -flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only or mean model)') +flags.DEFINE_integer('samples_2', 10, 'number of samples in layer 2') +flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only for mean model)') flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)') flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)') flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges') @@ -202,7 +202,7 @@ def train(train_data, test_data=None): identity_dim = FLAGS.identity_dim, logging=True) - elif FLAGS.model == 'graphsage_pool': + elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] @@ -217,6 +217,23 @@ def train(train_data, test_data=None): sigmoid_loss = FLAGS.sigmoid, identity_dim = FLAGS.identity_dim, logging=True) + + elif FLAGS.model == 'graphsage_meanpool': + sampler = UniformNeighborSampler(adj_info) + layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), + SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] + + model = SupervisedGraphsage(num_classes, placeholders, + features, + adj_info, + minibatch.deg, + layer_infos=layer_infos, + aggregator_type="meanpool", + model_size=FLAGS.model_size, + sigmoid_loss = FLAGS.sigmoid, + identity_dim = FLAGS.identity_dim, + logging=True) + else: raise Exception('Error: model name unrecognized.') diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py index f3f6737..b3162fd 100644 --- a/graphsage/unsupervised_train.py +++ b/graphsage/unsupervised_train.py @@ -194,7 +194,7 @@ def train(train_data, test_data=None): model_size=FLAGS.model_size, logging=True) - elif FLAGS.model == 'graphsage_pool': + elif FLAGS.model == 'graphsage_maxpool': sampler = UniformNeighborSampler(adj_info) layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] @@ -204,10 +204,25 @@ def train(train_data, test_data=None): adj_info, minibatch.deg, layer_infos=layer_infos, - aggregator_type="pool", + aggregator_type="maxpool", model_size=FLAGS.model_size, identity_dim = FLAGS.identity_dim, logging=True) + elif FLAGS.model == 'graphsage_meanpool': + sampler = UniformNeighborSampler(adj_info) + layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1), + SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)] + + model = SampleAndAggregate(placeholders, + features, + adj_info, + minibatch.deg, + layer_infos=layer_infos, + aggregator_type="meanpool", + model_size=FLAGS.model_size, + identity_dim = FLAGS.identity_dim, + logging=True) + elif FLAGS.model == 'n2v': model = Node2VecModel(placeholders, features.shape[0], minibatch.deg, From c29361aa7d95c49181b208acb58e33350ab232c2 Mon Sep 17 00:00:00 2001 From: williamleif Date: Wed, 11 Oct 2017 14:05:36 -0700 Subject: [PATCH 06/10] Python 3 support. --- graphsage/minibatch.py | 12 ++++++------ graphsage/supervised_train.py | 4 ++-- graphsage/utils.py | 8 ++++---- 3 files changed, 12 insertions(+), 12 deletions(-) diff --git a/graphsage/minibatch.py b/graphsage/minibatch.py index 180648d..60c7b97 100644 --- a/graphsage/minibatch.py +++ b/graphsage/minibatch.py @@ -42,15 +42,15 @@ class EdgeMinibatchIterator(object): self.train_edges = self.edges = np.random.permutation(edges) if not n2v_retrain: self.train_edges = self._remove_isolated(self.train_edges) - self.val_edges = [e for e in G.edges_iter() if G[e[0]][e[1]]['train_removed']] + self.val_edges = [e for e in G.edges() if G[e[0]][e[1]]['train_removed']] else: if fixed_n2v: self.train_edges = self.val_edges = self._n2v_prune(self.edges) else: self.train_edges = self.val_edges = self.edges - print(len([n for n in G.nodes_iter() if not G.node[n]['test'] and not G.node[n]['val']]), 'train nodes') - print(len([n for n in G.nodes_iter() if G.node[n]['test'] or G.node[n]['val']]), 'test nodes') + print(len([n for n in G.nodes() if not G.node[n]['test'] and not G.node[n]['val']]), 'train nodes') + print(len([n for n in G.nodes() if G.node[n]['test'] or G.node[n]['val']]), 'test nodes') self.val_set_size = len(self.val_edges) def _n2v_prune(self, edges): @@ -150,7 +150,7 @@ class EdgeMinibatchIterator(object): def label_val(self): train_edges = [] val_edges = [] - for n1, n2 in self.G.edges_iter(): + for n1, n2 in self.G.edges(): if (self.G.node[n1]['val'] or self.G.node[n1]['test'] or self.G.node[n2]['val'] or self.G.node[n2]['test']): val_edges.append((n1,n2)) @@ -197,8 +197,8 @@ class NodeMinibatchIterator(object): self.adj, self.deg = self.construct_adj() self.test_adj = self.construct_test_adj() - self.val_nodes = [n for n in self.G.nodes_iter() if self.G.node[n]['val']] - self.test_nodes = [n for n in self.G.nodes_iter() if self.G.node[n]['test']] + self.val_nodes = [n for n in self.G.nodes() if self.G.node[n]['val']] + self.test_nodes = [n for n in self.G.nodes() if self.G.node[n]['test']] self.no_train_nodes_set = set(self.val_nodes + self.test_nodes) self.train_nodes = set(G.nodes()).difference(self.no_train_nodes_set) diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py index 9580149..240d9aa 100644 --- a/graphsage/supervised_train.py +++ b/graphsage/supervised_train.py @@ -125,8 +125,8 @@ def train(train_data, test_data=None): features = train_data[1] id_map = train_data[2] class_map = train_data[4] - if isinstance(class_map.values()[0], list): - num_classes = len(class_map.values()[0]) + if isinstance(list(class_map.values())[0], list): + num_classes = len(list(class_map.values())[0]) else: num_classes = len(set(class_map.values())) diff --git a/graphsage/utils.py b/graphsage/utils.py index a8254d2..23c6b52 100644 --- a/graphsage/utils.py +++ b/graphsage/utils.py @@ -25,20 +25,20 @@ def load_data(prefix, normalize=True, load_walks=False): print("No features present.. Only identity features will be used.") feats = None id_map = json.load(open(prefix + "-id_map.json")) - id_map = {conversion(k):int(v) for k,v in id_map.iteritems()} + id_map = {conversion(k):int(v) for k,v in id_map.items()} walks = [] class_map = json.load(open(prefix + "-class_map.json")) - if isinstance(class_map.values()[0], list): + if isinstance(list(class_map.values())[0], list): lab_conversion = lambda n : n else: lab_conversion = lambda n : int(n) - class_map = {conversion(k):lab_conversion(v) for k,v in class_map.iteritems()} + class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()} ## Make sure the graph has edge train_removed annotations ## (some datasets might already have this..) print("Loaded data.. now preprocessing..") - for edge in G.edges_iter(): + for edge in G.edges(): if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or G.node[edge[0]]['test'] or G.node[edge[1]]['test']): G[edge[0]][edge[1]]['train_removed'] = True From 676c30f5f4383f2d9cbcdbda4c150a6f831b45b8 Mon Sep 17 00:00:00 2001 From: Can Guney Aksakalli Date: Thu, 12 Oct 2017 16:17:42 +0200 Subject: [PATCH 07/10] adding Dockerfiles --- .dockerignore | 3 +++ Dockerfile | 6 ++++++ Dockerfile.gpu | 6 ++++++ README.md | 42 +++++++++++++++++++++++++++++------------- 4 files changed, 44 insertions(+), 13 deletions(-) create mode 100644 .dockerignore create mode 100644 Dockerfile create mode 100644 Dockerfile.gpu diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..8dad90a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.git +Dockerfile* +.gitignore diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..71a833c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,6 @@ +FROM gcr.io/tensorflow/tensorflow:1.3.0 + +RUN pip install networkx==1.11 +RUN rm /notebooks/* + +COPY . /notebooks diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 0000000..681f22c --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,6 @@ +FROM gcr.io/tensorflow/tensorflow:1.3.0-gpu + +RUN pip install networkx==1.11 +RUN rm /notebooks/* + +COPY . /notebooks diff --git a/README.md b/README.md index a4562fe..3f77624 100644 --- a/README.md +++ b/README.md @@ -7,21 +7,21 @@ ### Overview This directory contains code necessary to run the GraphSage algorithm. -GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. +GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm. *Note:* GraphSage now also has better support for training on smaller, static graphs and graphs that don't have node features. The original algorithm and paper are focused on the task of inductive generalization (i.e., generating embeddings for nodes that were not present during training), but many benchmarks/tasks use simple static graphs that do not necessarily have features. To support this use case, GraphSage now includes optional "identity features" that can be used with or without other node attributes. -Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting). +Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting). See the section on "Running the code" below. The example_data subdirectory contains a small example of the protein-protein interaction data, which includes 3 training graphs + one validation graph and one test graph. The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/). -If you make use of this code or the GraphSage algorithm in your work, please cite the following paper: +If you make use of this code or the GraphSage algorithm in your work, please cite the following paper: @inproceedings{hamilton2017inductive, author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure}, @@ -38,30 +38,46 @@ Recent versions of TensorFlow, numpy, scipy, and networkx are required. The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively. -If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256]. -This flag will make the model embed unique node ids as attributes, which will increase the runtime and number of parameters but also potentially increase the performance. +If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256]. +This flag will make the model embed unique node ids as attributes, which will increase the runtime and number of parameters but also potentially increase the performance. Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity). The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table. Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance. -We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate). +We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate). -*Note:* For the PPI data, and any other multi-ouput dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting. +*Note:* For the PPI data, and any other multi-ouput dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting. + +#### Docker + +You can run GraphSage inside a [docker](https://docs.docker.com/) image. After cloning the project, build and run the image as following: + + $ docker build -t graphsage . + $ docker run -it graphsage bash + +or start a Jupyter Notebook instead of bash: + + $ docker run -it -p 8888:8888 graphsage + +You can also run the GPU image using [nvidia-docker](https://github.com/NVIDIA/nvidia-docker): + + $ docker build -t graphsage:gpu -f Dockerfile.gpu . + $ nvidia-docker run -it graphsage:gpu bash #### Input format As input, at minimum the code requires that a --train_prefix option is specified which specifies the following data files: -* -G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively. +* -G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively. * -id_map.json -- A json-stored dictionary mapping the graph node ids to consecutive integers. * -id_map.json -- A json-stored dictionary mapping the graph node ids to classes. * -feats.npy [optional] --- A numpy-stored array of node features; ordering given by id_map.json. Can be omitted and only identity features will be used. * -walks.txt [optional] --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage) -To run the model on a new dataset, you need to make data files in the format described above. +To run the model on a new dataset, you need to make data files in the format described above. To run random walks for the unsupervised model and to generate the -walks.txt file) you can use the `run_walks` function in `graphsage.utils`. -#### Model variants +#### Model variants The user must also specify a --model, the variants of which are described in detail in the paper: * graphsage_mean -- GraphSage with mean-based aggregator * graphsage_seq -- GraphSage with LSTM-based aggregator @@ -70,7 +86,7 @@ The user must also specify a --model, the variants of which are described in det * n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.) #### Logging directory -Finally, a --base_log_dir should be specified (it defaults to the current directory). +Finally, a --base_log_dir should be specified (it defaults to the current directory). The output of the model and log files will be stored in a subdirectory of the base_log_dir. The path to the logged data will be of the form `-/graphsage-/`. The supervised model will output F1 scores, while the unsupervised model will train embeddings and store them. @@ -86,5 +102,5 @@ The `eval_scripts` directory contains examples of feeding the embeddings into si #### Acknowledgements The original version of this code base was originally forked from https://github.com/tkipf/gcn/, and we owe many thanks to Thomas Kipf for making his code available. -We also thank Yuanfang Li and Xin Li who contributed to a course project that was based on this work. -Please see the [paper](https://arxiv.org/pdf/1706.02216.pdf) for funding details and additional (non-code related) acknowledgements. +We also thank Yuanfang Li and Xin Li who contributed to a course project that was based on this work. +Please see the [paper](https://arxiv.org/pdf/1706.02216.pdf) for funding details and additional (non-code related) acknowledgements. From 93d2aaf259bb0349f595093e36d4992b0c9e11f0 Mon Sep 17 00:00:00 2001 From: williamleif Date: Thu, 12 Oct 2017 14:15:21 -0700 Subject: [PATCH 08/10] Cleaned up experimental run files. --- eval_scripts/citation_eval.py | 4 ++-- eval_scripts/ppi_eval.py | 4 ++-- eval_scripts/reddit_eval.py | 4 ++-- graphsage/minibatch.py | 5 +++++ 4 files changed, 11 insertions(+), 6 deletions(-) diff --git a/eval_scripts/citation_eval.py b/eval_scripts/citation_eval.py index 3707a53..feb69cc 100644 --- a/eval_scripts/citation_eval.py +++ b/eval_scripts/citation_eval.py @@ -31,11 +31,11 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels): if __name__ == '__main__': parser = ArgumentParser("Run evaluation on citation data.") parser.add_argument("dataset_dir", help="Path to directory containing the dataset.") - parser.add_argument("data_dir", help="Path to directory containing the learned node embeddings.") + parser.add_argument("embed_dir", help="Path to directory containing the learned node embeddings.") parser.add_argument("setting", help="Either val or test.") args = parser.parse_args() dataset_dir = args.dataset_dir - data_dir = args.data_dir + data_dir = args.embed_dir setting = args.setting print("Loading data...") diff --git a/eval_scripts/ppi_eval.py b/eval_scripts/ppi_eval.py index 88b7a9e..9348926 100644 --- a/eval_scripts/ppi_eval.py +++ b/eval_scripts/ppi_eval.py @@ -21,11 +21,11 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels): if __name__ == '__main__': parser = ArgumentParser("Run evaluation on PPI data.") parser.add_argument("dataset_dir", help="Path to directory containing the dataset.") - parser.add_argument("data_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.") + parser.add_argument("embed_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.") parser.add_argument("setting", help="Either val or test.") args = parser.parse_args() dataset_dir = args.dataset_dir - data_dir = args.data_dir + data_dir = args.embed_dir setting = args.setting print("Loading data...") diff --git a/eval_scripts/reddit_eval.py b/eval_scripts/reddit_eval.py index a0f68c6..7161084 100644 --- a/eval_scripts/reddit_eval.py +++ b/eval_scripts/reddit_eval.py @@ -24,11 +24,11 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels): if __name__ == '__main__': parser = ArgumentParser("Run evaluation on Reddit data.") parser.add_argument("dataset_dir", help="Path to directory containing the dataset.") - parser.add_argument("data_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.") + parser.add_argument("embed_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.") parser.add_argument("setting", help="Either val or test.") args = parser.parse_args() dataset_dir = args.dataset_dir - data_dir = args.data_dir + data_dir = args.embed_dir setting = args.setting print("Loading data...") diff --git a/graphsage/minibatch.py b/graphsage/minibatch.py index 60c7b97..1cfd6d9 100644 --- a/graphsage/minibatch.py +++ b/graphsage/minibatch.py @@ -59,13 +59,18 @@ class EdgeMinibatchIterator(object): def _remove_isolated(self, edge_list): new_edge_list = [] + missing = 0 for n1, n2 in edge_list: + if not n1 in self.G.node or not n2 in self.G.node: + missing += 1 + continue if (self.deg[self.id2idx[n1]] == 0 or self.deg[self.id2idx[n2]] == 0) \ and (not self.G.node[n1]['test'] or self.G.node[n1]['val']) \ and (not self.G.node[n2]['test'] or self.G.node[n2]['val']): continue else: new_edge_list.append((n1,n2)) + print("Unexpected missing:", missing) return new_edge_list def construct_adj(self): From 21f5c9cd34d4cf4809e127fe4afb6846ba4e5dbc Mon Sep 17 00:00:00 2001 From: williamleif Date: Thu, 12 Oct 2017 14:17:40 -0700 Subject: [PATCH 09/10] Describing mean pooling. --- README.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index d2ac546..f6474e6 100644 --- a/README.md +++ b/README.md @@ -65,7 +65,8 @@ you can use the `run_walks` function in `graphsage.utils`. The user must also specify a --model, the variants of which are described in detail in the paper: * graphsage_mean -- GraphSage with mean-based aggregator * graphsage_seq -- GraphSage with LSTM-based aggregator -* graphsage_pool -- GraphSage with max-pooling aggregator +* graphsage_maxpool -- GraphSage with max-pooling aggregator (as described in the NIPS 2017 paper) +* graphsage_meanpool -- GraphSage with mean-pooling aggregator (a variant of the pooling aggregator, where the element-wie mean replaces the element-wise max). * gcn -- GraphSage with GCN-based aggregator * n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.) From 8062f032fd16acfb04f72c495d4829b9d07353b3 Mon Sep 17 00:00:00 2001 From: William L Hamilton Date: Thu, 12 Oct 2017 15:13:49 -0700 Subject: [PATCH 10/10] Updated Docker description --- README.md | 35 +++++++++++++++++++---------------- 1 file changed, 19 insertions(+), 16 deletions(-) diff --git a/README.md b/README.md index c1a522a..5f09e26 100644 --- a/README.md +++ b/README.md @@ -32,7 +32,25 @@ If you make use of this code or the GraphSage algorithm in your work, please cit ### Requirements -Recent versions of TensorFlow, numpy, scipy, and networkx are required. +Recent versions of TensorFlow, numpy, scipy, and networkx are required (but networkx must be <=1.11). To guarantee that you have the right package versions, you can use [docker](https://docs.docker.com/) to easily set up a virtual environment. See the Docker subsection below for more info. + +#### Docker + +If you do not have [docker](https://docs.docker.com/) installed, you will need to do so. (Just click on the preceding link, the installation is pretty painless). + +You can run GraphSage inside a [docker](https://docs.docker.com/) image. After cloning the project, build and run the image as following: + + $ docker build -t graphsage . + $ docker run -it graphsage bash + +or start a Jupyter Notebook instead of bash: + + $ docker run -it -p 8888:8888 graphsage + +You can also run the GPU image using [nvidia-docker](https://github.com/NVIDIA/nvidia-docker): + + $ docker build -t graphsage:gpu -f Dockerfile.gpu . + $ nvidia-docker run -it graphsage:gpu bash ### Running the code @@ -48,21 +66,6 @@ We generally found that performance continued to improve even after the loss was *Note:* For the PPI data, and any other multi-ouput dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting. -#### Docker - -You can run GraphSage inside a [docker](https://docs.docker.com/) image. After cloning the project, build and run the image as following: - - $ docker build -t graphsage . - $ docker run -it graphsage bash - -or start a Jupyter Notebook instead of bash: - - $ docker run -it -p 8888:8888 graphsage - -You can also run the GPU image using [nvidia-docker](https://github.com/NVIDIA/nvidia-docker): - - $ docker build -t graphsage:gpu -f Dockerfile.gpu . - $ nvidia-docker run -it graphsage:gpu bash #### Input format As input, at minimum the code requires that a --train_prefix option is specified which specifies the following data files: