diff --git a/.dockerignore b/.dockerignore
new file mode 100644
index 0000000..8dad90a
--- /dev/null
+++ b/.dockerignore
@@ -0,0 +1,3 @@
+.git
+Dockerfile*
+.gitignore
diff --git a/Dockerfile b/Dockerfile
new file mode 100644
index 0000000..71a833c
--- /dev/null
+++ b/Dockerfile
@@ -0,0 +1,6 @@
+FROM gcr.io/tensorflow/tensorflow:1.3.0
+
+RUN pip install networkx==1.11
+RUN rm /notebooks/*
+
+COPY . /notebooks
diff --git a/Dockerfile.gpu b/Dockerfile.gpu
new file mode 100644
index 0000000..681f22c
--- /dev/null
+++ b/Dockerfile.gpu
@@ -0,0 +1,6 @@
+FROM gcr.io/tensorflow/tensorflow:1.3.0-gpu
+
+RUN pip install networkx==1.11
+RUN rm /notebooks/*
+
+COPY . /notebooks
diff --git a/README.md b/README.md
index 4ccf602..5f09e26 100644
--- a/README.md
+++ b/README.md
@@ -1,4 +1,4 @@
-## GraphSAGE: Inductive Representation Learning on Large Graphs
+## GraphSage: Representation Learning on Large Graphs
 
 #### Authors: [William L. Hamilton](http://stanford.edu/~wleif) (wleif@stanford.edu), [Rex Ying](http://joy-of-thinking.weebly.com/) (rexying@stanford.edu)
 #### [Project Website](http://snap.stanford.edu/graphsage/)
@@ -6,59 +6,91 @@
 
 ### Overview
 
-This directory contains code necessary to run the GraphSAGE algorithm.
+This directory contains code necessary to run the GraphSage algorithm.
+GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm.
+
+*Note:* GraphSage now also has better support for training on smaller, static graphs and graphs that don't have node features.
+The original algorithm and paper are focused on the task of inductive generalization (i.e., generating embeddings for nodes that were not present during training),
+but many benchmarks/tasks use simple static graphs that do not necessarily have features.
+To support this use case, GraphSage now includes optional "identity features" that can be used with or without other node attributes.
+Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting).
+See the section on "Running the code" below.
+
 The example_data subdirectory contains a small example of the protein-protein interaction data,
 which includes 3 training graphs + one validation graph and one test graph.
 The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/).
 
-If you make use of this code or the GraphSAGE algorithm in your work, please cite the following paper:
+If you make use of this code or the GraphSage algorithm in your work, please cite the following paper:
 
-    @article{hamilton2017inductive,
+    @inproceedings{hamilton2017inductive,
       author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure},
      title = {Inductive Representation Learning on Large Graphs},
-      journal = {arXiv preprint, arXiv:1603.04467},
+      booktitle = {NIPS},
       year = {2017}
     }
 
 ### Requirements
 
-Recent versions of TensorFlow, numpy, scipy, and networkx are required.
+Recent versions of TensorFlow, numpy, scipy, and networkx are required (but networkx must be <=1.11). To guarantee that you have the right package versions, you can use [docker](https://docs.docker.com/) to easily set up a virtual environment. See the Docker subsection below for more info.
+
+#### Docker
+
+If you do not have [docker](https://docs.docker.com/) installed, you will need to do so.
+(Just click on the preceding link, the installation is pretty painless).
+
+You can run GraphSage inside a [docker](https://docs.docker.com/) image. After cloning the project, build and run the image as follows:
+
+    $ docker build -t graphsage .
+    $ docker run -it graphsage bash
+
+or start a Jupyter Notebook instead of bash:
+
+    $ docker run -it -p 8888:8888 graphsage
+
+You can also run the GPU image using [nvidia-docker](https://github.com/NVIDIA/nvidia-docker):
+
+    $ docker build -t graphsage:gpu -f Dockerfile.gpu .
+    $ nvidia-docker run -it graphsage:gpu bash
 
 ### Running the code
 
-The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSAGE, respectively.
-Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance.
-We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate).
+The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively.
+
+If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256].
+This flag will make the model embed unique node ids as attributes, which will increase the runtime and number of parameters but also potentially increase the performance.
+Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity).
+The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table.
+
+Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance.
+We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate).
+
+*Note:* For the PPI data, and any other multi-output dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting.
 
-*Note:* For the PPI data, and any other multi-ouput dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting.
 
 #### Input format
 As input, at minimum the code requires that a --train_prefix option is specified which specifies the following data files:
 
-* <train_prefix>-G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively.
+* <train_prefix>-G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively.
 * <train_prefix>-id_map.json -- A json-stored dictionary mapping the graph node ids to consecutive integers.
 * <train_prefix>-class_map.json -- A json-stored dictionary mapping the graph node ids to classes.
-* <train_prefix>-feats.npy --- A numpy-stored array of node features; ordering given by id_map.json
-* <train_prefix>-walks.txt --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage)
+* <train_prefix>-feats.npy [optional] --- A numpy-stored array of node features; ordering given by id_map.json. Can be omitted and only identity features will be used.
+* <train_prefix>-walks.txt [optional] --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage)
 
-To run the model on a new dataset, you need to make data files in the format described above.
+To run the model on a new dataset, you need to make data files in the format described above.
 To run random walks for the unsupervised model and to generate the <train_prefix>-walks.txt file, you can use the `run_walks` function in `graphsage.utils`.
-
-
-#### Model variants
+#### Model variants
 The user must also specify a --model, the variants of which are described in detail in the paper:
-* graphsage_mean -- GraphSAGE with mean-based aggregator
-* graphsage_seq -- GraphSAGE with LSTM-based aggregator
-* graphsage_pool -- GraphSAGE with max-pooling aggregator
-* gcn -- GraphSAGE with GCN-based aggregator
+* graphsage_mean -- GraphSage with mean-based aggregator
+* graphsage_seq -- GraphSage with LSTM-based aggregator
+* graphsage_maxpool -- GraphSage with max-pooling aggregator (as described in the NIPS 2017 paper)
+* graphsage_meanpool -- GraphSage with mean-pooling aggregator (a variant of the pooling aggregator, where the element-wise mean replaces the element-wise max).
+* gcn -- GraphSage with GCN-based aggregator
 * n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.)
 
 #### Logging directory
-Finally, a --base_log_dir should be specified (it defaults to the current directory).
+Finally, a --base_log_dir should be specified (it defaults to the current directory).
 The output of the model and log files will be stored in a subdirectory of the base_log_dir.
 The path to the logged data will be of the form `<sup/unsup>-<data_prefix>/graphsage-<model_description>/`.
 
 The supervised model will output F1 scores, while the unsupervised model will train embeddings and store them.
@@ -67,12 +99,12 @@ Note that the full log outputs and stored embeddings can be 5-10Gb in size (on t
 
 #### Using the output of the unsupervised models
 
-The unsupervised variants of GraphSAGE will output embeddings to the logging directory as described above.
+The unsupervised variants of GraphSage will output embeddings to the logging directory as described above.
 These embeddings can then be used in downstream machine learning applications.
 The `eval_scripts` directory contains examples of feeding the embeddings into simple logistic classifiers.
 
 #### Acknowledgements
 
 The original version of this code base was originally forked from https://github.com/tkipf/gcn/, and we owe many thanks to Thomas Kipf for making his code available.
-We also thank Yuanfang Li and Xin Li who contributed to a course project that was based on this work.
-Please see the [paper](https://arxiv.org/pdf/1706.02216.pdf) for funding details and additional (non-code related) acknowledgements.
+We also thank Yuanfang Li and Xin Li who contributed to a course project that was based on this work.
+Please see the [paper](https://arxiv.org/pdf/1706.02216.pdf) for funding details and additional (non-code related) acknowledgements.
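
As a concrete companion to the input format described above, the following sketch (an illustration only, not part of the change itself; the `toy` prefix, the 4-node graph, and the 16-dimensional random features are all made up) writes a minimal set of the required files using networkx 1.11-style node-attribute access:

    import json
    import numpy as np
    import networkx as nx
    from networkx.readwrite import json_graph

    G = nx.Graph()
    G.add_edges_from([(0, 1), (1, 2), (2, 0), (2, 3)])
    for n in G.nodes():
        # every node needs boolean 'val'/'test' flags; hold node 3 out for validation
        G.node[n]['val'] = (n == 3)
        G.node[n]['test'] = False

    # hypothetical "toy" prefix for the file names
    json.dump(json_graph.node_link_data(G), open("toy-G.json", "w"))
    json.dump({str(n): n for n in G.nodes()}, open("toy-id_map.json", "w"))
    json.dump({str(n): n % 2 for n in G.nodes()}, open("toy-class_map.json", "w"))
    np.save("toy-feats.npy", np.random.rand(len(G), 16).astype(np.float32))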
diff --git a/eval_scripts/citation_eval.py b/eval_scripts/citation_eval.py
index 3707a53..feb69cc 100644
--- a/eval_scripts/citation_eval.py
+++ b/eval_scripts/citation_eval.py
@@ -31,11 +31,11 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels):
 if __name__ == '__main__':
     parser = ArgumentParser("Run evaluation on citation data.")
     parser.add_argument("dataset_dir", help="Path to directory containing the dataset.")
-    parser.add_argument("data_dir", help="Path to directory containing the learned node embeddings.")
+    parser.add_argument("embed_dir", help="Path to directory containing the learned node embeddings.")
     parser.add_argument("setting", help="Either val or test.")
     args = parser.parse_args()
     dataset_dir = args.dataset_dir
-    data_dir = args.data_dir
+    data_dir = args.embed_dir
     setting = args.setting
 
     print("Loading data...")
diff --git a/eval_scripts/ppi_eval.py b/eval_scripts/ppi_eval.py
index a72c05d..c63c577 100644
--- a/eval_scripts/ppi_eval.py
+++ b/eval_scripts/ppi_eval.py
@@ -32,11 +32,11 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels):
 if __name__ == '__main__':
     parser = ArgumentParser("Run evaluation on PPI data.")
     parser.add_argument("dataset_dir", help="Path to directory containing the dataset.")
-    parser.add_argument("data_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.")
+    parser.add_argument("embed_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.")
     parser.add_argument("setting", help="Either val or test.")
     args = parser.parse_args()
     dataset_dir = args.dataset_dir
-    data_dir = args.data_dir
+    data_dir = args.embed_dir
     setting = args.setting
 
     print("Loading data...")
diff --git a/eval_scripts/reddit_eval.py b/eval_scripts/reddit_eval.py
index a0f68c6..7161084 100644
--- a/eval_scripts/reddit_eval.py
+++ b/eval_scripts/reddit_eval.py
@@ -24,11 +24,11 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels):
 if __name__ == '__main__':
     parser = ArgumentParser("Run evaluation on Reddit data.")
     parser.add_argument("dataset_dir", help="Path to directory containing the dataset.")
-    parser.add_argument("data_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.")
+    parser.add_argument("embed_dir", help="Path to directory containing the learned node embeddings. Set to 'feat' for raw features.")
    parser.add_argument("setting", help="Either val or test.")
     args = parser.parse_args()
     dataset_dir = args.dataset_dir
-    data_dir = args.data_dir
+    data_dir = args.embed_dir
     setting = args.setting
 
     print("Loading data...")
diff --git a/graphsage/aggregators.py b/graphsage/aggregators.py
index 705ec69..7dbd252 100644
--- a/graphsage/aggregators.py
+++ b/graphsage/aggregators.py
@@ -116,12 +116,12 @@ class GCNAggregator(Layer):
 
         return self.act(output)
 
-class PoolingAggregator(Layer):
+class MaxPoolingAggregator(Layer):
     """ Aggregates via max-pooling over MLP functions.
""" def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): - super(PoolingAggregator, self).__init__(**kwargs) + super(MaxPoolingAggregator, self).__init__(**kwargs) self.dropout = dropout self.bias = bias @@ -194,12 +194,91 @@ class PoolingAggregator(Layer): return self.act(output) -class TwoLayerPoolingAggregator(Layer): +class MeanPoolingAggregator(Layer): + """ Aggregates via mean-pooling over MLP functions. + """ + def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, + dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): + super(MeanPoolingAggregator, self).__init__(**kwargs) + + self.dropout = dropout + self.bias = bias + self.act = act + self.concat = concat + + if neigh_input_dim is None: + neigh_input_dim = input_dim + + if name is not None: + name = '/' + name + else: + name = '' + + if model_size == "small": + hidden_dim = self.hidden_dim = 512 + elif model_size == "big": + hidden_dim = self.hidden_dim = 1024 + + self.mlp_layers = [] + self.mlp_layers.append(Dense(input_dim=neigh_input_dim, + output_dim=hidden_dim, + act=tf.nn.relu, + dropout=dropout, + sparse_inputs=False, + logging=self.logging)) + + with tf.variable_scope(self.name + name + '_vars'): + self.vars['neigh_weights'] = glorot([hidden_dim, output_dim], + name='neigh_weights') + + self.vars['self_weights'] = glorot([input_dim, output_dim], + name='self_weights') + if self.bias: + self.vars['bias'] = zeros([self.output_dim], name='bias') + + if self.logging: + self._log_vars() + + self.input_dim = input_dim + self.output_dim = output_dim + self.neigh_input_dim = neigh_input_dim + + def _call(self, inputs): + self_vecs, neigh_vecs = inputs + neigh_h = neigh_vecs + + dims = tf.shape(neigh_h) + batch_size = dims[0] + num_neighbors = dims[1] + # [nodes * sampled neighbors] x [hidden_dim] + h_reshaped = tf.reshape(neigh_h, (batch_size * num_neighbors, self.neigh_input_dim)) + + for l in self.mlp_layers: + h_reshaped = l(h_reshaped) + neigh_h = tf.reshape(h_reshaped, (batch_size, num_neighbors, self.hidden_dim)) + neigh_h = tf.reduce_mean(neigh_h, axis=1) + + from_neighs = tf.matmul(neigh_h, self.vars['neigh_weights']) + from_self = tf.matmul(self_vecs, self.vars["self_weights"]) + + if not self.concat: + output = tf.add_n([from_self, from_neighs]) + else: + output = tf.concat([from_self, from_neighs], axis=1) + + # bias + if self.bias: + output += self.vars['bias'] + + return self.act(output) + + +class TwoMaxLayerPoolingAggregator(Layer): """ Aggregates via pooling over two MLP functions. 
""" def __init__(self, input_dim, output_dim, model_size="small", neigh_input_dim=None, dropout=0., bias=False, act=tf.nn.relu, name=None, concat=False, **kwargs): - super(TwoLayerPoolingAggregator, self).__init__(**kwargs) + super(TwoMaxLayerPoolingAggregator, self).__init__(**kwargs) self.dropout = dropout self.bias = bias diff --git a/graphsage/minibatch.py b/graphsage/minibatch.py index a480e15..1f53784 100644 --- a/graphsage/minibatch.py +++ b/graphsage/minibatch.py @@ -42,15 +42,15 @@ class EdgeMinibatchIterator(object): self.train_edges = self.edges = np.random.permutation(edges) if not n2v_retrain: self.train_edges = self._remove_isolated(self.train_edges) - self.val_edges = [e for e in G.edges_iter() if G[e[0]][e[1]]['train_removed']] + self.val_edges = [e for e in G.edges() if G[e[0]][e[1]]['train_removed']] else: if fixed_n2v: self.train_edges = self.val_edges = self._n2v_prune(self.edges) else: self.train_edges = self.val_edges = self.edges - print(len([n for n in G.nodes_iter() if not G.node[n]['test'] and not G.node[n]['val']]), 'train nodes') - print(len([n for n in G.nodes_iter() if G.node[n]['test'] or G.node[n]['val']]), 'test nodes') + print(len([n for n in G.nodes() if not G.node[n]['test'] and not G.node[n]['val']]), 'train nodes') + print(len([n for n in G.nodes() if G.node[n]['test'] or G.node[n]['val']]), 'test nodes') self.val_set_size = len(self.val_edges) def _n2v_prune(self, edges): @@ -59,13 +59,18 @@ class EdgeMinibatchIterator(object): def _remove_isolated(self, edge_list): new_edge_list = [] + missing = 0 for n1, n2 in edge_list: + if not n1 in self.G.node or not n2 in self.G.node: + missing += 1 + continue if (self.deg[self.id2idx[n1]] == 0 or self.deg[self.id2idx[n2]] == 0) \ and (not self.G.node[n1]['test'] or self.G.node[n1]['val']) \ and (not self.G.node[n2]['test'] or self.G.node[n2]['val']): continue else: new_edge_list.append((n1,n2)) + print("Unexpected missing:", missing) return new_edge_list def construct_adj(self): @@ -153,7 +158,7 @@ class EdgeMinibatchIterator(object): def label_val(self): train_edges = [] val_edges = [] - for n1, n2 in self.G.edges_iter(): + for n1, n2 in self.G.edges(): if (self.G.node[n1]['val'] or self.G.node[n1]['test'] or self.G.node[n2]['val'] or self.G.node[n2]['test']): val_edges.append((n1,n2)) @@ -200,8 +205,8 @@ class NodeMinibatchIterator(object): self.adj, self.deg = self.construct_adj() self.test_adj = self.construct_test_adj() - self.val_nodes = [n for n in self.G.nodes_iter() if self.G.node[n]['val']] - self.test_nodes = [n for n in self.G.nodes_iter() if self.G.node[n]['test']] + self.val_nodes = [n for n in self.G.nodes() if self.G.node[n]['val']] + self.test_nodes = [n for n in self.G.nodes() if self.G.node[n]['test']] self.no_train_nodes_set = set(self.val_nodes + self.test_nodes) self.train_nodes = set(G.nodes()).difference(self.no_train_nodes_set) diff --git a/graphsage/models.py b/graphsage/models.py index b40b17f..e9fe791 100644 --- a/graphsage/models.py +++ b/graphsage/models.py @@ -7,7 +7,7 @@ import graphsage.layers as layers import graphsage.metrics as metrics from .prediction import BipartiteEdgePredLayer -from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator +from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator flags = tf.app.flags FLAGS = flags.FLAGS @@ -191,12 +191,13 @@ class SampleAndAggregate(GeneralizedModel): def __init__(self, placeholders, features, adj, degrees, 
diff --git a/graphsage/models.py b/graphsage/models.py
index b40b17f..e9fe791 100644
--- a/graphsage/models.py
+++ b/graphsage/models.py
@@ -7,7 +7,7 @@ import graphsage.layers as layers
 import graphsage.metrics as metrics
 
 from .prediction import BipartiteEdgePredLayer
-from .aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator
+from .aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -191,12 +191,13 @@ class SampleAndAggregate(GeneralizedModel):
 
     def __init__(self, placeholders, features, adj, degrees,
             layer_infos, concat=True, aggregator_type="mean",
-            model_size="small",
+            model_size="small", identity_dim=0,
             **kwargs):
         '''
         Args:
            - placeholders: Stanford TensorFlow placeholder object.
-            - features: Numpy array with node features.
+            - features: Numpy array with node features.
+                        NOTE: Pass a None object to train in featureless mode (identity features for nodes)!
             - adj: Numpy array with adjacency lists (padded with random re-samples)
             - degrees: Numpy array with node degrees.
             - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
@@ -204,16 +205,17 @@ class SampleAndAggregate(GeneralizedModel):
             - concat: whether to concatenate during recursive iterations
             - aggregator_type: how to aggregate neighbor information
             - model_size: one of "small" and "big"
+            - identity_dim: Set to positive int to use identity features (slow and cannot generalize, but better accuracy)
         '''
         super(SampleAndAggregate, self).__init__(**kwargs)
         if aggregator_type == "mean":
             self.aggregator_cls = MeanAggregator
         elif aggregator_type == "seq":
             self.aggregator_cls = SeqAggregator
-        elif aggregator_type == "pool":
-            self.aggregator_cls = PoolingAggregator
-        elif aggregator_type == "pool_2":
-            self.aggregator_cls = TwoLayerPoolingAggregator
+        elif aggregator_type == "maxpool":
+            self.aggregator_cls = MaxPoolingAggregator
+        elif aggregator_type == "meanpool":
+            self.aggregator_cls = MeanPoolingAggregator
         elif aggregator_type == "gcn":
             self.aggregator_cls = GCNAggregator
         else:
@@ -224,11 +226,22 @@ class SampleAndAggregate(GeneralizedModel):
         self.inputs1 = placeholders["batch1"]
         self.inputs2 = placeholders["batch2"]
         self.model_size = model_size
         self.adj_info = adj
-        self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+        if identity_dim > 0:
+            self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
+        else:
+            self.embeds = None
+        if features is None:
+            if identity_dim == 0:
+                raise Exception("Must have a positive value for identity feature dimension if no input features given.")
+            self.features = self.embeds
+        else:
+            self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+            if not self.embeds is None:
+                self.features = tf.concat([self.embeds, self.features], axis=1)
         self.degrees = degrees
         self.concat = concat
 
-        self.dims = [features.shape[1]]
+        self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
         self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
         self.batch_size = placeholders["batch_size"]
         self.placeholders = placeholders
diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py
index a8658b6..9ea123c 100644
--- a/graphsage/supervised_models.py
+++ b/graphsage/supervised_models.py
@@ -2,7 +2,7 @@ import tensorflow as tf
 
 import graphsage.models as models
 import graphsage.layers as layers
-from graphsage.aggregators import MeanAggregator, PoolingAggregator, SeqAggregator, GCNAggregator, TwoLayerPoolingAggregator
+from graphsage.aggregators import MeanAggregator, MaxPoolingAggregator, MeanPoolingAggregator, SeqAggregator, GCNAggregator
 
 flags = tf.app.flags
 FLAGS = flags.FLAGS
@@ -13,7 +13,7 @@ class SupervisedGraphsage(models.SampleAndAggregate):
 
     def __init__(self, num_classes, placeholders, features, adj, degrees,
             layer_infos, concat=True, aggregator_type="mean",
-            model_size="small", sigmoid_loss=False,
+            model_size="small", sigmoid_loss=False, identity_dim=0,
             **kwargs):
         '''
         Args:
@@ -35,10 +35,10 @@ class SupervisedGraphsage(models.SampleAndAggregate):
             self.aggregator_cls = MeanAggregator
         elif aggregator_type == "seq":
             self.aggregator_cls = SeqAggregator
-        elif aggregator_type == "pool":
-            self.aggregator_cls = PoolingAggregator
-        elif aggregator_type == "pool_2":
-            self.aggregator_cls = TwoLayerPoolingAggregator
+        elif aggregator_type == "meanpool":
+            self.aggregator_cls = MeanPoolingAggregator
+        elif aggregator_type == "maxpool":
+            self.aggregator_cls = MaxPoolingAggregator
         elif aggregator_type == "gcn":
             self.aggregator_cls = GCNAggregator
         else:
@@ -48,13 +48,23 @@ class SupervisedGraphsage(models.SampleAndAggregate):
         self.inputs1 = placeholders["batch"]
         self.model_size = model_size
         self.adj_info = adj
-        self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+        if identity_dim > 0:
+            self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
+        else:
+            self.embeds = None
+        if features is None:
+            if identity_dim == 0:
+                raise Exception("Must have a positive value for identity feature dimension if no input features given.")
+            self.features = self.embeds
+        else:
+            self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+            if not self.embeds is None:
+                self.features = tf.concat([self.embeds, self.features], axis=1)
         self.degrees = degrees
         self.concat = concat
         self.num_classes = num_classes
         self.sigmoid_loss = sigmoid_loss
-
-        self.dims = [features.shape[1]]
+        self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
         self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
         self.batch_size = placeholders["batch_size"]
         self.placeholders = placeholders
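
Both model classes above now share the same identity-feature pattern: when identity_dim > 0, a trainable per-node embedding table is created and, if fixed input features are also given, concatenated onto them. Below is a standalone TensorFlow 1.x sketch of just that pattern, with made-up sizes (5 nodes, identity_dim=3, 4-dimensional frozen features):

    import numpy as np
    import tensorflow as tf

    num_nodes, identity_dim = 5, 3                            # made-up sizes
    feats = np.random.rand(num_nodes, 4).astype(np.float32)   # made-up fixed features

    # trainable "identity" embeddings, one row per node
    embeds = tf.get_variable("node_embeddings", [num_nodes, identity_dim])
    # frozen input features, as in SampleAndAggregate / SupervisedGraphsage
    features = tf.Variable(tf.constant(feats), trainable=False)
    combined = tf.concat([embeds, features], axis=1)          # model input, dim 3 + 4

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        print(sess.run(combined).shape)                       # (5, 7)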
diff --git a/graphsage/supervised_train.py b/graphsage/supervised_train.py
index fa52581..240d9aa 100644
--- a/graphsage/supervised_train.py
+++ b/graphsage/supervised_train.py
@@ -39,13 +39,14 @@
 flags.DEFINE_float('dropout', 0.0, 'dropout rate (1 - keep probability).')
 flags.DEFINE_float('weight_decay', 0.0, 'weight for l2 loss on embedding matrix.')
 flags.DEFINE_integer('max_degree', 128, 'maximum node degree.')
 flags.DEFINE_integer('samples_1', 25, 'number of samples in layer 1')
-flags.DEFINE_integer('samples_2', 10, 'number of users samples in layer 2')
-flags.DEFINE_integer('samples_3', 0, 'number of users samples in layer 3. (Only or mean model)')
+flags.DEFINE_integer('samples_2', 10, 'number of samples in layer 2')
+flags.DEFINE_integer('samples_3', 0, 'number of samples in layer 3. (Only for mean model)')
 flags.DEFINE_integer('dim_1', 128, 'Size of output dim (final is 2x this, if using concat)')
 flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if using concat)')
 flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
 flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
 flags.DEFINE_boolean('sigmoid', False, 'whether to use sigmoid loss')
+flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
 #logging, saving, validation settings etc.
 flags.DEFINE_string('base_log_dir', '.', 'base directory for logging and saving embeddings')
@@ -124,13 +125,14 @@ def train(train_data, test_data=None):
     features = train_data[1]
     id_map = train_data[2]
     class_map = train_data[4]
-    if isinstance(class_map.values()[0], list):
-        num_classes = len(class_map.values()[0])
+    if isinstance(list(class_map.values())[0], list):
+        num_classes = len(list(class_map.values())[0])
     else:
         num_classes = len(set(class_map.values()))
 
-    # pad with dummy zero vector
-    features = np.vstack([features, np.zeros((features.shape[1],))])
+    if not features is None:
+        # pad with dummy zero vector
+        features = np.vstack([features, np.zeros((features.shape[1],))])
 
     context_pairs = train_data[3] if FLAGS.random_context else None
     placeholders = construct_placeholders(num_classes)
@@ -164,6 +166,7 @@ def train(train_data, test_data=None):
                                      layer_infos,
                                      model_size=FLAGS.model_size,
                                      sigmoid_loss = FLAGS.sigmoid,
+                                     identity_dim = FLAGS.identity_dim,
                                      logging=True)
     elif FLAGS.model == 'gcn':
         # Create model
@@ -180,6 +183,7 @@ def train(train_data, test_data=None):
                                      model_size=FLAGS.model_size,
                                      concat=False,
                                      sigmoid_loss = FLAGS.sigmoid,
+                                     identity_dim = FLAGS.identity_dim,
                                      logging=True)
 
     elif FLAGS.model == 'graphsage_seq':
@@ -195,9 +199,10 @@ def train(train_data, test_data=None):
                                      aggregator_type="seq",
                                      model_size=FLAGS.model_size,
                                      sigmoid_loss = FLAGS.sigmoid,
+                                     identity_dim = FLAGS.identity_dim,
                                      logging=True)
 
-    elif FLAGS.model == 'graphsage_pool':
+    elif FLAGS.model == 'graphsage_maxpool':
         sampler = UniformNeighborSampler(adj_info)
         layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                        SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
@@ -210,7 +215,25 @@ def train(train_data, test_data=None):
-                                     aggregator_type="pool",
+                                     aggregator_type="maxpool",
                                      model_size=FLAGS.model_size,
                                      sigmoid_loss = FLAGS.sigmoid,
+                                     identity_dim = FLAGS.identity_dim,
                                      logging=True)
+
+    elif FLAGS.model == 'graphsage_meanpool':
+        sampler = UniformNeighborSampler(adj_info)
+        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
+                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
+
+        model = SupervisedGraphsage(num_classes, placeholders,
+                                     features,
+                                     adj_info,
+                                     minibatch.deg,
+                                     layer_infos=layer_infos,
+                                     aggregator_type="meanpool",
+                                     model_size=FLAGS.model_size,
+                                     sigmoid_loss = FLAGS.sigmoid,
+                                     identity_dim = FLAGS.identity_dim,
+                                     logging=True)
+
     else:
         raise Exception('Error: model name unrecognized.')
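
The `list(class_map.values())[0]` edits above are Python 3 fixes: on Python 3, `dict.values()` returns a view that cannot be indexed directly. A toy illustration of the num_classes logic, using a hypothetical two-node multi-label class map:

    class_map = {"n0": [0, 1], "n1": [1, 0]}  # hypothetical multi-label class map

    first = list(class_map.values())[0]       # works on both Python 2 and 3
    if isinstance(first, list):
        num_classes = len(first)              # multi-label: length of the label vector
    else:
        num_classes = len(set(class_map.values()))
    print(num_classes)                        # 2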
diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py
index 945aa20..b3162fd 100644
--- a/graphsage/unsupervised_train.py
+++ b/graphsage/unsupervised_train.py
@@ -43,6 +43,7 @@ flags.DEFINE_boolean('random_context', True, 'Whether to use random context or d
 flags.DEFINE_integer('neg_sample_size', 20, 'number of negative samples')
 flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
 flags.DEFINE_integer('n2v_test_epochs', 1, 'Number of new SGD epochs for n2v.')
+flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
 
 #logging, saving, validation settings etc.
 flags.DEFINE_boolean('save_embeddings', True, 'whether to save embeddings for all nodes after training')
@@ -115,7 +116,7 @@ def save_val_embeddings(sess, model, minibatch_iter, size, out_dir, mod=""):
     with open(out_dir + name + mod + ".txt", "w") as fp:
         fp.write("\n".join(map(str,nodes)))
 
-def construct_placeholders(feature_size):
+def construct_placeholders():
     # Define placeholders
     placeholders = {
         'batch1' : tf.placeholder(tf.int32, shape=(None), name='batch1'),
@@ -133,12 +134,12 @@ def train(train_data, test_data=None):
     features = train_data[1]
     id_map = train_data[2]
 
-    # pad with dummy zero vector
-    features = np.vstack([features, np.zeros((features.shape[1],))])
-    feature_size = features.shape[1]
+    if not features is None:
+        # pad with dummy zero vector
+        features = np.vstack([features, np.zeros((features.shape[1],))])
 
     context_pairs = train_data[3] if FLAGS.random_context else None
-    placeholders = construct_placeholders(feature_size)
+    placeholders = construct_placeholders()
     minibatch = EdgeMinibatchIterator(G,
             id_map,
             placeholders, batch_size=FLAGS.batch_size,
@@ -159,6 +160,7 @@ def train(train_data, test_data=None):
                                      minibatch.deg,
                                      layer_infos=layer_infos,
                                      model_size=FLAGS.model_size,
+                                     identity_dim = FLAGS.identity_dim,
                                      logging=True)
     elif FLAGS.model == 'gcn':
         # Create model
@@ -173,6 +175,7 @@ def train(train_data, test_data=None):
                                      layer_infos=layer_infos,
                                      aggregator_type="gcn",
                                      model_size=FLAGS.model_size,
+                                     identity_dim = FLAGS.identity_dim,
                                      concat=False,
                                      logging=True)
 
@@ -186,11 +189,12 @@ def train(train_data, test_data=None):
                                      adj_info,
                                      minibatch.deg,
                                      layer_infos=layer_infos,
+                                     identity_dim = FLAGS.identity_dim,
                                      aggregator_type="seq",
                                      model_size=FLAGS.model_size,
                                      logging=True)
 
-    elif FLAGS.model == 'graphsage_pool':
+    elif FLAGS.model == 'graphsage_maxpool':
         sampler = UniformNeighborSampler(adj_info)
         layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
                        SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
@@ -200,9 +204,25 @@ def train(train_data, test_data=None):
                                      adj_info,
                                      minibatch.deg,
                                      layer_infos=layer_infos,
-                                     aggregator_type="pool",
+                                     aggregator_type="maxpool",
                                      model_size=FLAGS.model_size,
+                                     identity_dim = FLAGS.identity_dim,
                                      logging=True)
 
+    elif FLAGS.model == 'graphsage_meanpool':
+        sampler = UniformNeighborSampler(adj_info)
+        layer_infos = [SAGEInfo("node", sampler, FLAGS.samples_1, FLAGS.dim_1),
+                       SAGEInfo("node", sampler, FLAGS.samples_2, FLAGS.dim_2)]
+
+        model = SampleAndAggregate(placeholders,
+                                     features,
+                                     adj_info,
+                                     minibatch.deg,
+                                     layer_infos=layer_infos,
+                                     aggregator_type="meanpool",
+                                     model_size=FLAGS.model_size,
+                                     identity_dim = FLAGS.identity_dim,
+                                     logging=True)
+
     elif FLAGS.model == 'n2v':
         model = Node2VecModel(placeholders, features.shape[0],
                               minibatch.deg,
@@ -354,7 +374,7 @@ def train(train_data, test_data=None):
 
 def main(argv=None):
     print("Loading training data..")
-    train_data = load_data(FLAGS.train_prefix)
+    train_data = load_data(FLAGS.train_prefix, load_walks=True)
     print("Done loading training data..")
     train(train_data)
 
diff --git a/graphsage/utils.py b/graphsage/utils.py
index 400b95e..15e3f97 100644
--- a/graphsage/utils.py
+++ b/graphsage/utils.py
@@ -4,13 +4,14 @@ import numpy as np
 import random
 import json
 import sys
+import os
 
 from networkx.readwrite import json_graph
 
 WALK_LEN=5
 N_WALKS=50
 
-def load_data(prefix, normalize=True):
+def load_data(prefix, normalize=True, load_walks=False):
     G_data = json.load(open(prefix + "-G.json"))
     G = json_graph.node_link_graph(G_data)
     if isinstance(G.nodes()[0], int):
@@ -18,39 +19,44 @@ def load_data(prefix, normalize=True):
     else:
         conversion = lambda n : n
 
-    feats = np.load(prefix + "-feats.npy")
+    if os.path.exists(prefix + "-feats.npy"):
+        feats = np.load(prefix + "-feats.npy")
+    else:
+        print("No features present.. Only identity features will be used.")
+        feats = None
     id_map = json.load(open(prefix + "-id_map.json"))
-    id_map = {conversion(k):int(v) for k,v in id_map.iteritems()}
+    id_map = {conversion(k):int(v) for k,v in id_map.items()}
     walks = []
     class_map = json.load(open(prefix + "-class_map.json"))
-    if isinstance(class_map.values()[0], list):
+    if isinstance(list(class_map.values())[0], list):
         lab_conversion = lambda n : n
     else:
         lab_conversion = lambda n : int(n)
 
-    class_map = {conversion(k):lab_conversion(v) for k,v in class_map.iteritems()}
+    class_map = {conversion(k):lab_conversion(v) for k,v in class_map.items()}
 
     ## Make sure the graph has edge train_removed annotations
     ## (some datasets might already have this..)
     print("Loaded data.. now preprocessing..")
-    for edge in G.edges_iter():
+    for edge in G.edges():
         if (G.node[edge[0]]['val'] or G.node[edge[1]]['val'] or
                 G.node[edge[0]]['test'] or G.node[edge[1]]['test']):
             G[edge[0]][edge[1]]['train_removed'] = True
         else:
             G[edge[0]][edge[1]]['train_removed'] = False
 
-    if normalize:
+    if normalize and not feats is None:
         from sklearn.preprocessing import StandardScaler
         train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']])
         train_feats = feats[train_ids]
         scaler = StandardScaler()
         scaler.fit(train_feats)
         feats = scaler.transform(feats)
-
-    with open(prefix + "-walks.txt") as fp:
-        for line in fp:
-            walks.append(map(conversion, line.split()))
+
+    if load_walks:
+        with open(prefix + "-walks.txt") as fp:
+            for line in fp:
+                walks.append(map(conversion, line.split()))
 
     return G, feats, id_map, walks, class_map
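
Finally, since `load_data` above now reads the walks file only when `load_walks=True`, unsupervised training still expects a `<prefix>-walks.txt` to exist. A rough sketch of how such a co-occurrence file could be produced with uniform random walks (a hypothetical stand-in for illustration only; the repository's own `run_walks` helper in `graphsage.utils` is the supported path):

    import random
    import networkx as nx

    def toy_walks(G, num_walks=5, walk_len=3):
        # pair each walk's start node with every node visited on the walk
        pairs = []
        for start in G.nodes():
            for _ in range(num_walks):
                curr = start
                for _ in range(walk_len):
                    curr = random.choice(list(G.neighbors(curr)))
                    pairs.append((start, curr))
        return pairs

    G = nx.cycle_graph(4)  # hypothetical toy graph; every node has neighbors
    with open("toy-walks.txt", "w") as fp:
        fp.write("\n".join("{} {}".format(a, b) for a, b in toy_walks(G)))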