Added support for identity features.

This commit is contained in:
williamleif 2017-09-16 14:17:14 -07:00
parent a373623c16
commit 87e978e415
6 changed files with 86 additions and 35 deletions

View File

@ -1,4 +1,4 @@
## GraphSage: Inductive Representation Learning on Large Graphs
## GraphSage: Representation Learning on Large Graphs
#### Authors: [William L. Hamilton](http://stanford.edu/~wleif) (wleif@stanford.edu), [Rex Ying](http://joy-of-thinking.weebly.com/) (rexying@stanford.edu)
#### [Project Website](http://snap.stanford.edu/graphsage/)
@ -10,16 +10,23 @@ This directory contains code necessary to run the GraphSage algorithm.
GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information.
See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm.
*Note:* GraphSage now also has better support for training on smaller, static graphs and graphs that don't have node features.
The original algorithm and paper are focused on the task of inductive generalization (i.e., generating embeddings for nodes that were not present during training),
but many benchmarks/tasks use simple static graphs that do not necessarily have features.
To support this use case, GraphSage now includes optional "identity features" that can be used with or without other node attributes.
Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting).
See the section on "Running the code" below.
The example_data subdirectory contains a small example of the protein-protein interaction data,
which includes 3 training graphs + one validation graph and one test graph.
The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/).
If you make use of this code or the GraphSage algorithm in your work, please cite the following paper:
@article{hamilton2017inductive,
@inproceedings{hamilton2017inductive,
author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure},
title = {Inductive Representation Learning on Large Graphs},
journal = {arXiv preprint, arXiv:1603.04467},
booktitle = {NIPS},
year = {2017}
}
@ -29,7 +36,13 @@ Recent versions of TensorFlow, numpy, scipy, and networkx are required.
### Running the code
The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSAGE, respectively.
The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively.
If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64,256].
This flag will make the model use embed unique node ids as attributes, which will increase the runtime but also potentially increase the performance.
Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity).
The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table.
Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance.
We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate).
@ -41,21 +54,19 @@ As input, at minimum the code requires that a --train_prefix option is specified
* <train_prefix>-G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively.
* <train_prefix>-id_map.json -- A json-stored dictionary mapping the graph node ids to consecutive integers.
* <train_prefix>-id_map.json -- A json-stored dictionary mapping the graph node ids to classes.
* <train_prefix>-feats.npy --- A numpy-stored array of node features; ordering given by id_map.json
* <train_prefix>-walks.txt --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage)
* <train_prefix>-feats.npy [optional] --- A numpy-stored array of node features; ordering given by id_map.json. Can be omitted and only identity features will be used.
* <train_prefix>-walks.txt [optional] --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage)
To run the model on a new dataset, you need to make data files in the format described above.
To run random walks for the unsupervised model and to generate the <prefix>-walks.txt file)
you can use the `run_walks` function in `graphsage.utils`.
#### Model variants
The user must also specify a --model, the variants of which are described in detail in the paper:
* graphsage_mean -- GraphSAGE with mean-based aggregator
* graphsage_seq -- GraphSAGE with LSTM-based aggregator
* graphsage_pool -- GraphSAGE with max-pooling aggregator
* gcn -- GraphSAGE with GCN-based aggregator
* graphsage_mean -- GraphSage with mean-based aggregator
* graphsage_seq -- GraphSage with LSTM-based aggregator
* graphsage_pool -- GraphSage with max-pooling aggregator
* gcn -- GraphSage with GCN-based aggregator
* n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.)
#### Logging directory

View File

@ -191,12 +191,13 @@ class SampleAndAggregate(GeneralizedModel):
def __init__(self, placeholders, features, adj, degrees,
layer_infos, concat=True, aggregator_type="mean",
model_size="small",
model_size="small", identity_dim=0,
**kwargs):
'''
Args:
- placeholders: Stanford TensorFlow placeholder object.
- features: Numpy array with node features.
NOTE: Pass a None object to train in featureless mode (identity features for nodes)!
- adj: Numpy array with adjacency lists (padded with random re-samples)
- degrees: Numpy array with node degrees.
- layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
@ -204,6 +205,7 @@ class SampleAndAggregate(GeneralizedModel):
- concat: whether to concatenate during recursive iterations
- aggregator_type: how to aggregate neighbor information
- model_size: one of "small" and "big"
- identity_dim: Set to positive int to use identity features (slow and cannot generalize, but better accuracy)
'''
super(SampleAndAggregate, self).__init__(**kwargs)
if aggregator_type == "mean":
@ -224,11 +226,22 @@ class SampleAndAggregate(GeneralizedModel):
self.inputs2 = placeholders["batch2"]
self.model_size = model_size
self.adj_info = adj
self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
if identity_dim > 0:
self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
else:
self.embeds = None
if features is None:
if identity_dim is None:
raise Exception("Must have a positive value for identity feature dimension if no input features given.")
self.features = self.embeds
else:
self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
if not self.embeds is None:
self.features = tf.concat([self.embeds, self.features], axis=1)
self.degrees = degrees
self.concat = concat
self.dims = [features.shape[1]]
self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
self.batch_size = placeholders["batch_size"]
self.placeholders = placeholders

View File

@ -13,7 +13,7 @@ class SupervisedGraphsage(models.SampleAndAggregate):
def __init__(self, num_classes,
placeholders, features, adj, degrees,
layer_infos, concat=True, aggregator_type="mean",
model_size="small", sigmoid_loss=False,
model_size="small", sigmoid_loss=False, identity_dim=0,
**kwargs):
'''
Args:
@ -48,13 +48,23 @@ class SupervisedGraphsage(models.SampleAndAggregate):
self.inputs1 = placeholders["batch"]
self.model_size = model_size
self.adj_info = adj
self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
if identity_dim > 0:
self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
else:
self.embeds = None
if features is None:
if identity_dim is None:
raise Exception("Must have a positive value for identity feature dimension if no input features given.")
self.features = self.embeds
else:
self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
if not self.embeds is None:
self.features = tf.concat([self.embeds, self.features], axis=1)
self.degrees = degrees
self.concat = concat
self.num_classes = num_classes
self.sigmoid_loss = sigmoid_loss
self.dims = [features.shape[1]]
self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
self.batch_size = placeholders["batch_size"]
self.placeholders = placeholders

View File

@ -46,6 +46,7 @@ flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if usi
flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
flags.DEFINE_boolean('sigmoid', False, 'whether to use sigmoid loss')
flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
#logging, saving, validation settings etc.
flags.DEFINE_string('base_log_dir', '.', 'base directory for logging and saving embeddings')
@ -129,8 +130,9 @@ def train(train_data, test_data=None):
else:
num_classes = len(set(class_map.values()))
# pad with dummy zero vector
features = np.vstack([features, np.zeros((features.shape[1],))])
if not features is None:
# pad with dummy zero vector
features = np.vstack([features, np.zeros((features.shape[1],))])
context_pairs = train_data[3] if FLAGS.random_context else None
placeholders = construct_placeholders(num_classes)
@ -164,6 +166,7 @@ def train(train_data, test_data=None):
layer_infos,
model_size=FLAGS.model_size,
sigmoid_loss = FLAGS.sigmoid,
identity_dim = FLAGS.identity_dim,
logging=True)
elif FLAGS.model == 'gcn':
# Create model
@ -180,6 +183,7 @@ def train(train_data, test_data=None):
model_size=FLAGS.model_size,
concat=False,
sigmoid_loss = FLAGS.sigmoid,
identity_dim = FLAGS.identity_dim,
logging=True)
elif FLAGS.model == 'graphsage_seq':
@ -195,6 +199,7 @@ def train(train_data, test_data=None):
aggregator_type="seq",
model_size=FLAGS.model_size,
sigmoid_loss = FLAGS.sigmoid,
identity_dim = FLAGS.identity_dim,
logging=True)
elif FLAGS.model == 'graphsage_pool':
@ -210,6 +215,7 @@ def train(train_data, test_data=None):
aggregator_type="pool",
model_size=FLAGS.model_size,
sigmoid_loss = FLAGS.sigmoid,
identity_dim = FLAGS.identity_dim,
logging=True)
else:
raise Exception('Error: model name unrecognized.')

View File

@ -43,6 +43,7 @@ flags.DEFINE_boolean('random_context', True, 'Whether to use random context or d
flags.DEFINE_integer('neg_sample_size', 20, 'number of negative samples')
flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
flags.DEFINE_integer('n2v_test_epochs', 1, 'Number of new SGD epochs for n2v.')
flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
#logging, saving, validation settings etc.
flags.DEFINE_boolean('save_embeddings', True, 'whether to save embeddings for all nodes after training')
@ -115,7 +116,7 @@ def save_val_embeddings(sess, model, minibatch_iter, size, out_dir, mod=""):
with open(out_dir + name + mod + ".txt", "w") as fp:
fp.write("\n".join(map(str,nodes)))
def construct_placeholders(feature_size):
def construct_placeholders():
# Define placeholders
placeholders = {
'batch1' : tf.placeholder(tf.int32, shape=(None), name='batch1'),
@ -133,12 +134,12 @@ def train(train_data, test_data=None):
features = train_data[1]
id_map = train_data[2]
# pad with dummy zero vector
features = np.vstack([features, np.zeros((features.shape[1],))])
feature_size = features.shape[1]
if not features is None:
# pad with dummy zero vector
features = np.vstack([features, np.zeros((features.shape[1],))])
context_pairs = train_data[3] if FLAGS.random_context else None
placeholders = construct_placeholders(feature_size)
placeholders = construct_placeholders()
minibatch = EdgeMinibatchIterator(G,
id_map,
placeholders, batch_size=FLAGS.batch_size,
@ -159,6 +160,7 @@ def train(train_data, test_data=None):
minibatch.deg,
layer_infos=layer_infos,
model_size=FLAGS.model_size,
identity_dim = FLAGS.identity_dim,
logging=True)
elif FLAGS.model == 'gcn':
# Create model
@ -173,6 +175,7 @@ def train(train_data, test_data=None):
layer_infos=layer_infos,
aggregator_type="gcn",
model_size=FLAGS.model_size,
identity_dim = FLAGS.identity_dim,
concat=False,
logging=True)
@ -186,6 +189,7 @@ def train(train_data, test_data=None):
adj_info,
minibatch.deg,
layer_infos=layer_infos,
identity_dim = FLAGS.identity_dim,
aggregator_type="seq",
model_size=FLAGS.model_size,
logging=True)
@ -202,6 +206,7 @@ def train(train_data, test_data=None):
layer_infos=layer_infos,
aggregator_type="pool",
model_size=FLAGS.model_size,
identity_dim = FLAGS.identity_dim,
logging=True)
elif FLAGS.model == 'n2v':
model = Node2VecModel(placeholders, features.shape[0],
@ -354,7 +359,7 @@ def train(train_data, test_data=None):
def main(argv=None):
print("Loading training data..")
train_data = load_data(FLAGS.train_prefix)
train_data = load_data(FLAGS.train_prefix, load_walks=True)
print("Done loading training data..")
train(train_data)

View File

@ -4,13 +4,14 @@ import numpy as np
import random
import json
import sys
import os
from networkx.readwrite import json_graph
WALK_LEN=5
N_WALKS=50
def load_data(prefix, normalize=True):
def load_data(prefix, normalize=True, load_walks=False):
G_data = json.load(open(prefix + "-G.json"))
G = json_graph.node_link_graph(G_data)
if isinstance(G.nodes()[0], int):
@ -18,7 +19,11 @@ def load_data(prefix, normalize=True):
else:
conversion = lambda n : n
feats = np.load(prefix + "-feats.npy")
if os.path.exists(prefix + "-feats.npy"):
feats = np.load(prefix + "-feats.npy")
else:
print("No features present.. Only identity features will be used.")
feats = None
id_map = json.load(open(prefix + "-id_map.json"))
id_map = {conversion(k):int(v) for k,v in id_map.iteritems()}
walks = []
@ -40,7 +45,7 @@ def load_data(prefix, normalize=True):
else:
G[edge[0]][edge[1]]['train_removed'] = False
if normalize:
if normalize and not feats is None:
from sklearn.preprocessing import StandardScaler
train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']])
train_feats = feats[train_ids]
@ -48,9 +53,10 @@ def load_data(prefix, normalize=True):
scaler.fit(train_feats)
feats = scaler.transform(feats)
with open(prefix + "-walks.txt") as fp:
for line in fp:
walks.append(map(conversion, line.split()))
if load_walks:
with open(prefix + "-walks.txt") as fp:
for line in fp:
walks.append(map(conversion, line.split()))
return G, feats, id_map, walks, class_map