diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..8dad90a --- /dev/null +++ b/.dockerignore @@ -0,0 +1,3 @@ +.git +Dockerfile* +.gitignore diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..71a833c --- /dev/null +++ b/Dockerfile @@ -0,0 +1,6 @@ +FROM gcr.io/tensorflow/tensorflow:1.3.0 + +RUN pip install networkx==1.11 +RUN rm /notebooks/* + +COPY . /notebooks diff --git a/Dockerfile.gpu b/Dockerfile.gpu new file mode 100644 index 0000000..681f22c --- /dev/null +++ b/Dockerfile.gpu @@ -0,0 +1,6 @@ +FROM gcr.io/tensorflow/tensorflow:1.3.0-gpu + +RUN pip install networkx==1.11 +RUN rm /notebooks/* + +COPY . /notebooks diff --git a/README.md b/README.md index 976d87d..5f09e26 100644 --- a/README.md +++ b/README.md @@ -7,21 +7,21 @@ ### Overview This directory contains code necessary to run the GraphSage algorithm. -GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. +GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information. See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm. *Note:* GraphSage now also has better support for training on smaller, static graphs and graphs that don't have node features. The original algorithm and paper are focused on the task of inductive generalization (i.e., generating embeddings for nodes that were not present during training), but many benchmarks/tasks use simple static graphs that do not necessarily have features. To support this use case, GraphSage now includes optional "identity features" that can be used with or without other node attributes. -Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting). +Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting). See the section on "Running the code" below. The example_data subdirectory contains a small example of the protein-protein interaction data, which includes 3 training graphs + one validation graph and one test graph. The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/). -If you make use of this code or the GraphSage algorithm in your work, please cite the following paper: +If you make use of this code or the GraphSage algorithm in your work, please cite the following paper: @inproceedings{hamilton2017inductive, author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure}, @@ -32,36 +32,55 @@ If you make use of this code or the GraphSage algorithm in your work, please cit ### Requirements -Recent versions of TensorFlow, numpy, scipy, and networkx are required. +Recent versions of TensorFlow, numpy, scipy, and networkx are required (but networkx must be <=1.11). To guarantee that you have the right package versions, you can use [docker](https://docs.docker.com/) to easily set up a virtual environment. See the Docker subsection below for more info. + +#### Docker + +If you do not have [docker](https://docs.docker.com/) installed, you will need to do so. (Just click on the preceding link, the installation is pretty painless). + +You can run GraphSage inside a [docker](https://docs.docker.com/) image. After cloning the project, build and run the image as following: + + $ docker build -t graphsage . + $ docker run -it graphsage bash + +or start a Jupyter Notebook instead of bash: + + $ docker run -it -p 8888:8888 graphsage + +You can also run the GPU image using [nvidia-docker](https://github.com/NVIDIA/nvidia-docker): + + $ docker build -t graphsage:gpu -f Dockerfile.gpu . + $ nvidia-docker run -it graphsage:gpu bash ### Running the code The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively. -If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256]. -This flag will make the model embed unique node ids as attributes, which will increase the runtime and number of parameters but also potentially increase the performance. +If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256]. +This flag will make the model embed unique node ids as attributes, which will increase the runtime and number of parameters but also potentially increase the performance. Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity). The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table. Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance. -We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate). +We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate). + +*Note:* For the PPI data, and any other multi-ouput dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting. -*Note:* For the PPI data, and any other multi-ouput dataset that allows individual nodes to belong to multiple classes, it is necessary to set the `--sigmoid` flag during supervised training. By default the model assumes that the dataset is in the "one-hot" categorical setting. #### Input format As input, at minimum the code requires that a --train_prefix option is specified which specifies the following data files: -* -G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively. +* -G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively. * -id_map.json -- A json-stored dictionary mapping the graph node ids to consecutive integers. * -id_map.json -- A json-stored dictionary mapping the graph node ids to classes. * -feats.npy [optional] --- A numpy-stored array of node features; ordering given by id_map.json. Can be omitted and only identity features will be used. * -walks.txt [optional] --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage) -To run the model on a new dataset, you need to make data files in the format described above. +To run the model on a new dataset, you need to make data files in the format described above. To run random walks for the unsupervised model and to generate the -walks.txt file) you can use the `run_walks` function in `graphsage.utils`. -#### Model variants +#### Model variants The user must also specify a --model, the variants of which are described in detail in the paper: * graphsage_mean -- GraphSage with mean-based aggregator * graphsage_seq -- GraphSage with LSTM-based aggregator @@ -71,7 +90,7 @@ The user must also specify a --model, the variants of which are described in det * n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.) #### Logging directory -Finally, a --base_log_dir should be specified (it defaults to the current directory). +Finally, a --base_log_dir should be specified (it defaults to the current directory). The output of the model and log files will be stored in a subdirectory of the base_log_dir. The path to the logged data will be of the form `-/graphsage-/`. The supervised model will output F1 scores, while the unsupervised model will train embeddings and store them. @@ -87,5 +106,5 @@ The `eval_scripts` directory contains examples of feeding the embeddings into si #### Acknowledgements The original version of this code base was originally forked from https://github.com/tkipf/gcn/, and we owe many thanks to Thomas Kipf for making his code available. -We also thank Yuanfang Li and Xin Li who contributed to a course project that was based on this work. -Please see the [paper](https://arxiv.org/pdf/1706.02216.pdf) for funding details and additional (non-code related) acknowledgements. +We also thank Yuanfang Li and Xin Li who contributed to a course project that was based on this work. +Please see the [paper](https://arxiv.org/pdf/1706.02216.pdf) for funding details and additional (non-code related) acknowledgements. diff --git a/eval_scripts/ppi_eval.py b/eval_scripts/ppi_eval.py index 9348926..c63c577 100644 --- a/eval_scripts/ppi_eval.py +++ b/eval_scripts/ppi_eval.py @@ -5,6 +5,13 @@ import numpy as np from networkx.readwrite import json_graph from argparse import ArgumentParser +''' To evaluate the embeddings, we run a logistic regression. +Run this script after running unsupervised training. +Baseline of using features-only can be run by setting data_dir as 'feat' +Example: + python eval_scripts/ppi_eval.py ../data/ppi unsup-ppi/n2v_big_0.000010 test +''' + def run_regression(train_embeds, train_labels, test_embeds, test_labels): np.random.seed(1) from sklearn.linear_model import SGDClassifier @@ -15,8 +22,12 @@ def run_regression(train_embeds, train_labels, test_embeds, test_labels): dummy.fit(train_embeds, train_labels) log = MultiOutputClassifier(SGDClassifier(loss="log"), n_jobs=10) log.fit(train_embeds, train_labels) - print("F1 score", f1_score(test_labels, log.predict(test_embeds), average="micro")) - print("Random baseline F1 score", f1_score(test_labels, dummy.predict(test_embeds), average="micro")) + + f1 = 0 + for i in range(test_labels.shape[1]): + print("F1 score", f1_score(test_labels[:,i], log.predict(test_embeds)[:,i], average="micro")) + for i in range(test_labels.shape[1]): + print("Random baseline F1 score", f1_score(test_labels[:,i], dummy.predict(test_embeds)[:,i], average="micro")) if __name__ == '__main__': parser = ArgumentParser("Run evaluation on PPI data.") @@ -30,12 +41,14 @@ if __name__ == '__main__': print("Loading data...") G = json_graph.node_link_graph(json.load(open(dataset_dir + "/ppi-G.json"))) - labels = json.load(open("/dfs/scratch0/graphnet/ppi/ppi-class_map.json")) + labels = json.load(open(dataset_dir + "/ppi-class_map.json")) labels = {int(i):l for i, l in labels.iteritems()} train_ids = [n for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']] test_ids = [n for n in G.nodes() if G.node[n][setting]] train_labels = np.array([labels[i] for i in train_ids]) + if train_labels.ndim == 1: + train_labels = np.expand_dims(train_labels, 1) test_labels = np.array([labels[i] for i in test_ids]) print("running", data_dir) @@ -45,7 +58,7 @@ if __name__ == '__main__': ## Logistic gets thrown off by big counts, so log transform num comments and score feats[:,0] = np.log(feats[:,0]+1.0) feats[:,1] = np.log(feats[:,1]-min(np.min(feats[:,1]), -1)) - feat_id_map = json.load(open("/dfs/scratch0/graphnet/ppi/ppi-id_map.json")) + feat_id_map = json.load(open(dataset_dir + "/ppi-id_map.json")) feat_id_map = {int(id):val for id,val in feat_id_map.iteritems()} train_feats = feats[[feat_id_map[id] for id in train_ids]] test_feats = feats[[feat_id_map[id] for id in test_ids]] diff --git a/graphsage/minibatch.py b/graphsage/minibatch.py index 1cfd6d9..1f53784 100644 --- a/graphsage/minibatch.py +++ b/graphsage/minibatch.py @@ -130,6 +130,9 @@ class EdgeMinibatchIterator(object): batch_edges = self.train_edges[start : start + self.batch_size] return self.batch_feed_dict(batch_edges) + def num_training_batches(self): + return len(self.train_edges) // self.batch_size + 1 + def val_feed_dict(self, size=None): edge_list = self.val_edges if size is None: @@ -292,6 +295,9 @@ class NodeMinibatchIterator(object): ret_val = self.batch_feed_dict(val_node_subset) return ret_val[0], ret_val[1], (iter_num+1)*size >= len(val_nodes), val_node_subset + def num_training_batches(self): + return len(self.train_nodes) // self.batch_size + 1 + def next_minibatch_feed_dict(self): start = self.batch_num * self.batch_size self.batch_num += 1 diff --git a/graphsage/prediction.py b/graphsage/prediction.py index 9bf0885..2e73d4c 100644 --- a/graphsage/prediction.py +++ b/graphsage/prediction.py @@ -11,6 +11,7 @@ FLAGS = flags.FLAGS class BipartiteEdgePredLayer(Layer): def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid, + loss_fn='xent', bias=False, bilinear_weights=False, **kwargs): """ Basic class that applies skip-gram-like loss @@ -26,6 +27,10 @@ class BipartiteEdgePredLayer(Layer): self.act = act self.bias = bias self.eps = 1e-7 + + # Margin for hinge loss + self.margin = 0.1 + self.bilinear_weights = bilinear_weights if dropout: @@ -49,6 +54,13 @@ class BipartiteEdgePredLayer(Layer): if self.bias: self.vars['bias'] = zeros([self.output_dim], name='bias') + if loss_fn == 'xent': + self.loss_fn = self._xent_loss + elif loss_fn == 'skipgram': + self.loss_fn = self._skipgram_loss + elif loss_fn == 'hinge': + self.loss_fn = self._hinge_loss + if self.logging: self._log_vars() @@ -66,7 +78,7 @@ class BipartiteEdgePredLayer(Layer): result = tf.reduce_sum(inputs1 * inputs2, axis=1) return result - def neg_cost(self, inputs1, neg_samples): + def neg_cost(self, inputs1, neg_samples, hard_neg_samples=None): """ For each input in batch, compute the sum of its affinity to negative samples. Returns: @@ -84,16 +96,32 @@ class BipartiteEdgePredLayer(Layer): neg_samples: tensor of shape [num_neg_samples x input_dim2]. Negative samples for all inputs in batch inputs1. """ + return self.loss_fn(inputs1, inputs2, neg_samples) + def _xent_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None): aff = self.affinity(inputs1, inputs2) - neg_aff = self.neg_cost(inputs1, neg_samples) + neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples) true_xent = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.ones_like(aff), logits=aff) negative_xent = tf.nn.sigmoid_cross_entropy_with_logits( labels=tf.zeros_like(neg_aff), logits=neg_aff) - loss = tf.reduce_sum(true_xent) + tf.reduce_sum(negative_xent) + loss = tf.reduce_sum(true_xent) + 0.01*tf.reduce_sum(negative_xent) + return loss - return loss + def _skipgram_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None): + aff = self.affinity(inputs1, inputs2) + neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples) + neg_cost = tf.log(tf.reduce_sum(tf.exp(neg_aff), axis=1)) + loss = tf.reduce_sum(aff - neg_cost) + return loss + + def _hinge_loss(self, inputs1, inputs2, neg_samples, hard_neg_samples=None): + aff = self.affinity(inputs1, inputs2) + neg_aff = self.neg_cost(inputs1, neg_samples, hard_neg_samples) + diff = tf.nn.relu(tf.subtract(neg_aff, tf.expand_dims(aff, 1) - self.margin), name='diff') + loss = tf.reduce_sum(diff) + self.neg_shape = tf.shape(neg_aff) + return loss def weights_norm(self): return tf.nn.l2_norm(self.vars['weights']) diff --git a/graphsage/utils.py b/graphsage/utils.py index 23c6b52..15e3f97 100644 --- a/graphsage/utils.py +++ b/graphsage/utils.py @@ -87,4 +87,4 @@ if __name__ == "__main__": G = G.subgraph(nodes) pairs = run_random_walks(G, nodes) with open(out_file, "w") as fp: - fp.write("\n".join([p[0] + "\t" + p[1] for p in pairs])) + fp.write("\n".join([str(p[0]) + "\t" + str(p[1]) for p in pairs]))