Added support for identity features.
commit 87e978e415 (parent a373623c16)
@@ -1,4 +1,4 @@
-## GraphSage: Inductive Representation Learning on Large Graphs
+## GraphSage: Representation Learning on Large Graphs
 
 #### Authors: [William L. Hamilton](http://stanford.edu/~wleif) (wleif@stanford.edu), [Rex Ying](http://joy-of-thinking.weebly.com/) (rexying@stanford.edu)
 #### [Project Website](http://snap.stanford.edu/graphsage/)
@@ -10,16 +10,23 @@ This directory contains code necessary to run the GraphSage algorithm.
 GraphSage can be viewed as a stochastic generalization of graph convolutions, and it is especially useful for massive, dynamic graphs that contain rich feature information.
 See our [paper](https://arxiv.org/pdf/1706.02216.pdf) for details on the algorithm.
 
+*Note:* GraphSage now also has better support for training on smaller, static graphs and graphs that don't have node features.
+The original algorithm and paper are focused on the task of inductive generalization (i.e., generating embeddings for nodes that were not present during training),
+but many benchmarks/tasks use simple static graphs that do not necessarily have features.
+To support this use case, GraphSage now includes optional "identity features" that can be used with or without other node attributes.
+Including identity features will increase the runtime, but also potentially increase performance (at the usual risk of overfitting).
+See the section on "Running the code" below.
+
 The example_data subdirectory contains a small example of the protein-protein interaction data,
 which includes 3 training graphs + one validation graph and one test graph.
 The full Reddit and PPI datasets (described in the paper) are available on the [project website](http://snap.stanford.edu/graphsage/).
 
 If you make use of this code or the GraphSage algorithm in your work, please cite the following paper:
 
-@article{hamilton2017inductive,
+@inproceedings{hamilton2017inductive,
      author = {Hamilton, William L. and Ying, Rex and Leskovec, Jure},
      title = {Inductive Representation Learning on Large Graphs},
-     journal = {arXiv preprint, arXiv:1603.04467},
+     booktitle = {NIPS},
      year = {2017}
    }
@@ -29,7 +36,13 @@ Recent versions of TensorFlow, numpy, scipy, and networkx are required.
 
 ### Running the code
 
-The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSAGE, respectively.
+The example_unsupervised.sh and example_supervised.sh files contain example usages of the code, which use the unsupervised and supervised variants of GraphSage, respectively.
+
+If your benchmark/task does not require generalizing to unseen data, we recommend you try setting the "--identity_dim" flag to a value in the range [64, 256].
+This flag will make the model embed unique node ids as attributes, which will increase the runtime but can also increase performance.
+Note that you should set this flag and *not* try to pass dense one-hot vectors as features (due to sparsity).
+The "dimension" of identity features specifies how many parameters there are per node in the sparse identity-feature lookup table.
 
 Note that example_unsupervised.sh sets a very small max iteration number, which can be increased to improve performance.
 We generally found that performance continued to improve even after the loss was very near convergence (i.e., even when the loss was decreasing at a very slow rate).
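To make the contrast between dense one-hot features and the sparse identity-feature lookup table concrete, here is a purely illustrative NumPy sketch; the node count and dimensions are invented for the example and are not taken from the paper or the example scripts.

```python
import numpy as np

# Illustrative sizes only: a graph with 100,000 nodes.
num_nodes = 100000
identity_dim = 128  # within the suggested [64, 256] range

# Dense one-hot "features" would need an N x N matrix just to identify nodes.
one_hot_params = num_nodes * num_nodes        # 10^10 entries

# The identity-feature lookup table stores identity_dim parameters per node.
lookup_params = num_nodes * identity_dim      # 1.28 * 10^7 entries

print(one_hot_params // lookup_params)        # ~781x fewer parameters

# Conceptually, an identity feature is just a trainable row in an embedding table:
embedding_table = np.random.randn(num_nodes, identity_dim).astype(np.float32)
node_ids = np.array([0, 42, 7])
identity_feats = embedding_table[node_ids]    # shape (3, identity_dim)
```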
@@ -41,21 +54,19 @@ As input, at minimum the code requires that a --train_prefix option is specified
 * <train_prefix>-G.json -- A networkx-specified json file describing the input graph. Nodes have 'val' and 'test' attributes specifying if they are a part of the validation and test sets, respectively.
-* <train_prefix>-id_map.json -- A json-stored dictionary mapping the graph node ids to consecutive integers.
+* <train_prefix>-id_map.json -- A json-stored dictionary mapping the graph node ids to classes.
-* <train_prefix>-feats.npy --- A numpy-stored array of node features; ordering given by id_map.json
-* <train_prefix>-walks.txt --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage)
+* <train_prefix>-feats.npy [optional] --- A numpy-stored array of node features; ordering given by id_map.json. Can be omitted, in which case only identity features will be used.
+* <train_prefix>-walks.txt [optional] --- A text file specifying random walk co-occurrences (one pair per line) (*only for unsupervised version of graphsage)
 
 To run the model on a new dataset, you need to make data files in the format described above.
 To run random walks for the unsupervised model (and to generate the <prefix>-walks.txt file),
 you can use the `run_walks` function in `graphsage.utils`.
 
 #### Model variants
 The user must also specify a --model, the variants of which are described in detail in the paper:
-* graphsage_mean -- GraphSAGE with mean-based aggregator
-* graphsage_seq -- GraphSAGE with LSTM-based aggregator
-* graphsage_pool -- GraphSAGE with max-pooling aggregator
-* gcn -- GraphSAGE with GCN-based aggregator
+* graphsage_mean -- GraphSage with mean-based aggregator
+* graphsage_seq -- GraphSage with LSTM-based aggregator
+* graphsage_pool -- GraphSage with max-pooling aggregator
+* gcn -- GraphSage with GCN-based aggregator
 * n2v -- an implementation of [DeepWalk](https://arxiv.org/abs/1403.6652) (called n2v for short in the code.)
 
 #### Logging directory
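For reference, a minimal sketch of one way to produce the input files listed in the hunk above for a toy graph. It assumes the networkx 1.x-style `G.node` attribute access used elsewhere in this repository; the graph, the 'val'/'test' assignments, and the 16-dimensional random features are all illustrative. The walks file (and any class/label map needed by the supervised code) is not shown; walks can be generated with the `run_walks` utility mentioned above.

```python
import json
import numpy as np
import networkx as nx
from networkx.readwrite import json_graph

prefix = "toy"  # corresponds to the --train_prefix option

# Build a toy graph; 'val'/'test' node attributes mark validation/test nodes.
G = nx.karate_club_graph()
for n in G.nodes():
    G.node[n]["val"] = (n % 10 == 8)
    G.node[n]["test"] = (n % 10 == 9)

# <prefix>-G.json: networkx node-link JSON description of the graph.
with open(prefix + "-G.json", "w") as f:
    json.dump(json_graph.node_link_data(G), f)

# <prefix>-id_map.json: node id -> consecutive integer index.
id_map = {str(n): i for i, n in enumerate(G.nodes())}
with open(prefix + "-id_map.json", "w") as f:
    json.dump(id_map, f)

# <prefix>-feats.npy (optional): node feature rows, ordered by id_map.
feats = np.random.rand(G.number_of_nodes(), 16).astype(np.float32)
np.save(prefix + "-feats.npy", feats)
```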
@@ -191,12 +191,13 @@ class SampleAndAggregate(GeneralizedModel):
     def __init__(self, placeholders, features, adj, degrees,
             layer_infos, concat=True, aggregator_type="mean",
-            model_size="small",
+            model_size="small", identity_dim=0,
             **kwargs):
         '''
         Args:
             - placeholders: Stanford TensorFlow placeholder object.
             - features: Numpy array with node features.
+                        NOTE: Pass a None object to train in featureless mode (identity features for nodes)!
             - adj: Numpy array with adjacency lists (padded with random re-samples)
             - degrees: Numpy array with node degrees.
             - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
@@ -204,6 +205,7 @@ class SampleAndAggregate(GeneralizedModel):
             - concat: whether to concatenate during recursive iterations
             - aggregator_type: how to aggregate neighbor information
            - model_size: one of "small" and "big"
+            - identity_dim: Set to positive int to use identity features (slow and cannot generalize, but better accuracy)
         '''
         super(SampleAndAggregate, self).__init__(**kwargs)
         if aggregator_type == "mean":
@@ -224,11 +226,22 @@ class SampleAndAggregate(GeneralizedModel):
         self.inputs2 = placeholders["batch2"]
         self.model_size = model_size
         self.adj_info = adj
-        self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+        if identity_dim > 0:
+            self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
+        else:
+            self.embeds = None
+        if features is None:
+            if identity_dim == 0:
+                raise Exception("Must have a positive value for identity feature dimension if no input features given.")
+            self.features = self.embeds
+        else:
+            self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+            if self.embeds is not None:
+                self.features = tf.concat([self.embeds, self.features], axis=1)
         self.degrees = degrees
         self.concat = concat
 
-        self.dims = [features.shape[1]]
+        self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
         self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
         self.batch_size = placeholders["batch_size"]
         self.placeholders = placeholders
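The feature-construction logic added above can be exercised on its own; the following stand-alone TensorFlow 1.x sketch mirrors the pattern from the diff (the session scaffolding and the toy sizes are illustrative and not part of the repository code).

```python
import numpy as np
import tensorflow as tf  # TensorFlow 1.x, as used by this repo

num_nodes, feat_dim, identity_dim = 10, 4, 8
features = np.random.rand(num_nodes, feat_dim).astype(np.float32)  # set to None for featureless mode

# Trainable per-node "identity" embeddings (the sparse lookup table).
embeds = tf.get_variable("node_embeddings", [num_nodes, identity_dim]) if identity_dim > 0 else None

if features is None:
    if identity_dim == 0:
        raise Exception("Must have a positive value for identity feature dimension "
                        "if no input features given.")
    node_feats = embeds
else:
    node_feats = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
    if embeds is not None:
        # Identity embeddings are concatenated with the fixed input features.
        node_feats = tf.concat([embeds, node_feats], axis=1)

input_dim = (0 if features is None else features.shape[1]) + identity_dim  # first entry of self.dims

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    print(sess.run(tf.shape(node_feats)))  # [10, 12] for these toy sizes
    print(input_dim)                       # 12
```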
@@ -13,7 +13,7 @@ class SupervisedGraphsage(models.SampleAndAggregate):
     def __init__(self, num_classes,
             placeholders, features, adj, degrees,
             layer_infos, concat=True, aggregator_type="mean",
-            model_size="small", sigmoid_loss=False,
+            model_size="small", sigmoid_loss=False, identity_dim=0,
             **kwargs):
         '''
         Args:
@@ -48,13 +48,23 @@ class SupervisedGraphsage(models.SampleAndAggregate):
         self.inputs1 = placeholders["batch"]
         self.model_size = model_size
         self.adj_info = adj
-        self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+        if identity_dim > 0:
+            self.embeds = tf.get_variable("node_embeddings", [adj.get_shape().as_list()[0], identity_dim])
+        else:
+            self.embeds = None
+        if features is None:
+            if identity_dim == 0:
+                raise Exception("Must have a positive value for identity feature dimension if no input features given.")
+            self.features = self.embeds
+        else:
+            self.features = tf.Variable(tf.constant(features, dtype=tf.float32), trainable=False)
+            if self.embeds is not None:
+                self.features = tf.concat([self.embeds, self.features], axis=1)
         self.degrees = degrees
         self.concat = concat
         self.num_classes = num_classes
         self.sigmoid_loss = sigmoid_loss
 
-        self.dims = [features.shape[1]]
+        self.dims = [(0 if features is None else features.shape[1]) + identity_dim]
         self.dims.extend([layer_infos[i].output_dim for i in range(len(layer_infos))])
         self.batch_size = placeholders["batch_size"]
         self.placeholders = placeholders
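As a quick sanity check on the new `self.dims` computation, here is the arithmetic for one illustrative configuration (the specific sizes are made up):

```python
# Illustrative only: 50-dim input features, identity_dim of 64, and two layers
# whose SAGEInfo output_dim values are both 128.
feature_dim = 50            # features.shape[1]; use 0 when features is None
identity_dim = 64
layer_output_dims = [128, 128]

dims = [feature_dim + identity_dim]   # 114 = input dimension seen by the first layer
dims.extend(layer_output_dims)
print(dims)                           # [114, 128, 128]
```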
@@ -46,6 +46,7 @@ flags.DEFINE_integer('dim_2', 128, 'Size of output dim (final is 2x this, if usi
 flags.DEFINE_boolean('random_context', True, 'Whether to use random context or direct edges')
 flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
 flags.DEFINE_boolean('sigmoid', False, 'whether to use sigmoid loss')
+flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
 
 #logging, saving, validation settings etc.
 flags.DEFINE_string('base_log_dir', '.', 'base directory for logging and saving embeddings')
@@ -129,8 +130,9 @@ def train(train_data, test_data=None):
     else:
         num_classes = len(set(class_map.values()))
 
-    # pad with dummy zero vector
-    features = np.vstack([features, np.zeros((features.shape[1],))])
+    if features is not None:
+        # pad with dummy zero vector
+        features = np.vstack([features, np.zeros((features.shape[1],))])
 
     context_pairs = train_data[3] if FLAGS.random_context else None
     placeholders = construct_placeholders(num_classes)
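The guarded padding above just appends one extra all-zero row to the feature matrix (the "dummy zero vector" named in the comment); a minimal NumPy illustration with made-up sizes:

```python
import numpy as np

features = np.random.rand(5, 3)   # 5 real nodes with 3-dim features
padded = np.vstack([features, np.zeros((features.shape[1],))])

print(padded.shape)   # (6, 3): one extra row appended
print(padded[-1])     # [0. 0. 0.] -- the dummy zero vector
```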
@@ -164,6 +166,7 @@ def train(train_data, test_data=None):
                      layer_infos,
                      model_size=FLAGS.model_size,
                      sigmoid_loss = FLAGS.sigmoid,
+                     identity_dim = FLAGS.identity_dim,
                      logging=True)
     elif FLAGS.model == 'gcn':
         # Create model
@@ -180,6 +183,7 @@ def train(train_data, test_data=None):
                      model_size=FLAGS.model_size,
                      concat=False,
                      sigmoid_loss = FLAGS.sigmoid,
+                     identity_dim = FLAGS.identity_dim,
                      logging=True)
 
     elif FLAGS.model == 'graphsage_seq':
@@ -195,6 +199,7 @@ def train(train_data, test_data=None):
                      aggregator_type="seq",
                      model_size=FLAGS.model_size,
                      sigmoid_loss = FLAGS.sigmoid,
+                     identity_dim = FLAGS.identity_dim,
                      logging=True)
 
     elif FLAGS.model == 'graphsage_pool':
@@ -210,6 +215,7 @@ def train(train_data, test_data=None):
                      aggregator_type="pool",
                      model_size=FLAGS.model_size,
                      sigmoid_loss = FLAGS.sigmoid,
+                     identity_dim = FLAGS.identity_dim,
                      logging=True)
     else:
         raise Exception('Error: model name unrecognized.')
@@ -43,6 +43,7 @@ flags.DEFINE_boolean('random_context', True, 'Whether to use random context or d
 flags.DEFINE_integer('neg_sample_size', 20, 'number of negative samples')
 flags.DEFINE_integer('batch_size', 512, 'minibatch size.')
 flags.DEFINE_integer('n2v_test_epochs', 1, 'Number of new SGD epochs for n2v.')
+flags.DEFINE_integer('identity_dim', 0, 'Set to positive value to use identity embedding features of that dimension. Default 0.')
 
 #logging, saving, validation settings etc.
 flags.DEFINE_boolean('save_embeddings', True, 'whether to save embeddings for all nodes after training')
@@ -115,7 +116,7 @@ def save_val_embeddings(sess, model, minibatch_iter, size, out_dir, mod=""):
     with open(out_dir + name + mod + ".txt", "w") as fp:
         fp.write("\n".join(map(str,nodes)))
 
-def construct_placeholders(feature_size):
+def construct_placeholders():
     # Define placeholders
     placeholders = {
         'batch1' : tf.placeholder(tf.int32, shape=(None), name='batch1'),
@@ -133,12 +134,12 @@ def train(train_data, test_data=None):
     features = train_data[1]
     id_map = train_data[2]
 
-    # pad with dummy zero vector
-    features = np.vstack([features, np.zeros((features.shape[1],))])
-    feature_size = features.shape[1]
+    if features is not None:
+        # pad with dummy zero vector
+        features = np.vstack([features, np.zeros((features.shape[1],))])
 
     context_pairs = train_data[3] if FLAGS.random_context else None
-    placeholders = construct_placeholders(feature_size)
+    placeholders = construct_placeholders()
     minibatch = EdgeMinibatchIterator(G,
             id_map,
             placeholders, batch_size=FLAGS.batch_size,
@@ -159,6 +160,7 @@ def train(train_data, test_data=None):
                      minibatch.deg,
                      layer_infos=layer_infos,
                      model_size=FLAGS.model_size,
+                     identity_dim = FLAGS.identity_dim,
                      logging=True)
     elif FLAGS.model == 'gcn':
         # Create model
@@ -173,6 +175,7 @@ def train(train_data, test_data=None):
                      layer_infos=layer_infos,
                      aggregator_type="gcn",
                      model_size=FLAGS.model_size,
+                     identity_dim = FLAGS.identity_dim,
                      concat=False,
                      logging=True)
 
@@ -186,6 +189,7 @@ def train(train_data, test_data=None):
                      adj_info,
                      minibatch.deg,
                      layer_infos=layer_infos,
+                     identity_dim = FLAGS.identity_dim,
                      aggregator_type="seq",
                      model_size=FLAGS.model_size,
                      logging=True)
@@ -202,6 +206,7 @@ def train(train_data, test_data=None):
                      layer_infos=layer_infos,
                      aggregator_type="pool",
                      model_size=FLAGS.model_size,
+                     identity_dim = FLAGS.identity_dim,
                      logging=True)
     elif FLAGS.model == 'n2v':
         model = Node2VecModel(placeholders, features.shape[0],
@@ -354,7 +359,7 @@ def train(train_data, test_data=None):
 
 def main(argv=None):
     print("Loading training data..")
-    train_data = load_data(FLAGS.train_prefix)
+    train_data = load_data(FLAGS.train_prefix, load_walks=True)
     print("Done loading training data..")
     train(train_data)
 
@@ -4,13 +4,14 @@ import numpy as np
 import random
 import json
 import sys
+import os
 
 from networkx.readwrite import json_graph
 
 WALK_LEN=5
 N_WALKS=50
 
-def load_data(prefix, normalize=True):
+def load_data(prefix, normalize=True, load_walks=False):
     G_data = json.load(open(prefix + "-G.json"))
     G = json_graph.node_link_graph(G_data)
     if isinstance(G.nodes()[0], int):
@@ -18,7 +19,11 @@ def load_data(prefix, normalize=True):
     else:
         conversion = lambda n : n
 
-    feats = np.load(prefix + "-feats.npy")
+    if os.path.exists(prefix + "-feats.npy"):
+        feats = np.load(prefix + "-feats.npy")
+    else:
+        print("No features present.. Only identity features will be used.")
+        feats = None
     id_map = json.load(open(prefix + "-id_map.json"))
     id_map = {conversion(k):int(v) for k,v in id_map.iteritems()}
     walks = []
|
||||
else:
|
||||
G[edge[0]][edge[1]]['train_removed'] = False
|
||||
|
||||
if normalize:
|
||||
if normalize and not feats is None:
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
train_ids = np.array([id_map[n] for n in G.nodes() if not G.node[n]['val'] and not G.node[n]['test']])
|
||||
train_feats = feats[train_ids]
|
||||
@@ -48,9 +53,10 @@ def load_data(prefix, normalize=True):
         scaler.fit(train_feats)
         feats = scaler.transform(feats)
 
-    with open(prefix + "-walks.txt") as fp:
-        for line in fp:
-            walks.append(map(conversion, line.split()))
+    if load_walks:
+        with open(prefix + "-walks.txt") as fp:
+            for line in fp:
+                walks.append(map(conversion, line.split()))
 
     return G, feats, id_map, walks, class_map
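For reference, a hedged usage sketch of the updated loader: the prefix path is illustrative, the return tuple follows the `return` statement above, and `load_walks=True` matches the call added to the unsupervised training script's `main`.

```python
from graphsage.utils import load_data

# The prefix below is illustrative; it should point at the <prefix>-G.json,
# <prefix>-id_map.json, optional <prefix>-feats.npy, and (for the unsupervised
# model) <prefix>-walks.txt files described in the README.
G, feats, id_map, walks, class_map = load_data("./example_data/toy-ppi", load_walks=True)

if feats is None:
    print("No node features found; train with --identity_dim > 0.")
else:
    print("Feature matrix shape: %s" % (feats.shape,))
print("Loaded %d walk co-occurrence pairs." % len(walks))
```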