From 41f491b74772f17f0f8bb1690865915667f3bbc3 Mon Sep 17 00:00:00 2001
From: williamleif
Date: Wed, 31 May 2017 06:39:04 -0700
Subject: [PATCH] Cleaning up comments etc.

---
 graphsage/inits.py              |  3 +--
 graphsage/metrics.py            |  8 +++-----
 graphsage/minibatch.py          | 22 +++++++++++++++++++---
 graphsage/models.py             | 26 +++++++++++++++-----------
 graphsage/neigh_samplers.py     |  3 +--
 graphsage/prediction.py         |  6 +++---
 graphsage/supervised_models.py  | 16 ++++++++++++++++
 graphsage/unsupervised_train.py |  1 -
 8 files changed, 58 insertions(+), 27 deletions(-)

diff --git a/graphsage/inits.py b/graphsage/inits.py
index 73f00d6..c335149 100644
--- a/graphsage/inits.py
+++ b/graphsage/inits.py
@@ -4,8 +4,7 @@ import numpy as np
 # DISCLAIMER:
 # Parts of this code file are derived from
 # https://github.com/tkipf/gcn
-# (A full license with proper attributions will be provided in the
-# public repo of this code base)
+# which is under the same MIT license as GraphSAGE
 
 def uniform(shape, scale=0.05, name=None):
     """Uniform init."""
diff --git a/graphsage/metrics.py b/graphsage/metrics.py
index afafa47..c696306 100644
--- a/graphsage/metrics.py
+++ b/graphsage/metrics.py
@@ -4,10 +4,8 @@ import tensorflow as tf
 # Parts of this code file were originally forked from
 # https://github.com/tkipf/gcn
 # which itself was heavily inspired by the keras package
-# (A full license with de-anonymized attributions will be provided in the
-# public repo of this code base)
-
 def masked_logit_cross_entropy(preds, labels, mask):
+    """Logit cross-entropy loss with masking."""
     loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=preds, labels=labels)
     loss = tf.reduce_sum(loss, axis=1)
     mask = tf.cast(mask, dtype=tf.float32)
@@ -16,8 +14,8 @@ def masked_logit_cross_entropy(preds, labels, mask):
     return tf.reduce_mean(loss)
 
 def masked_softmax_cross_entropy(preds, labels, mask):
+    """Softmax cross-entropy loss with masking."""
     loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
-#    loss = tf.reduce_sum(loss, axis=1)
     mask = tf.cast(mask, dtype=tf.float32)
     mask /= tf.maximum(tf.reduce_sum(mask), tf.constant([1.]))
     loss *= mask
@@ -25,7 +23,7 @@ def masked_softmax_cross_entropy(preds, labels, mask):
 
 def masked_l2(preds, actuals, mask):
-    """Softmax cross-entropy loss with masking."""
+    """L2 loss with masking."""
     loss = tf.nn.l2(preds, actuals)
     mask = tf.cast(mask, dtype=tf.float32)
     mask /= tf.reduce_mean(mask)
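A note on the masked_l2 hunk above: the docstring fix is correct, but the body still calls tf.nn.l2, which is not an op in TensorFlow's public API (the closest built-in, tf.nn.l2_loss, reduces to a scalar and so cannot be masked per example). A minimal working sketch, assuming preds and actuals are [batch, features] tensors, could look like:

    import tensorflow as tf

    def masked_l2(preds, actuals, mask):
        """L2 loss with masking (sketch; replaces the nonexistent tf.nn.l2)."""
        # Per-example squared error, summed over the feature dimension.
        loss = tf.reduce_sum(tf.square(preds - actuals), axis=1)
        mask = tf.cast(mask, dtype=tf.float32)
        mask /= tf.reduce_mean(mask)  # renormalize so the masked mean is unbiased
        loss *= mask
        return tf.reduce_mean(loss)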
diff --git a/graphsage/minibatch.py b/graphsage/minibatch.py
index 6de9956..180648d 100644
--- a/graphsage/minibatch.py
+++ b/graphsage/minibatch.py
@@ -9,9 +9,18 @@ class EdgeMinibatchIterator(object):
 
     """ This minibatch iterator iterates over batches of sampled edges or
     random pairs of co-occurring nodes.
+
+    G -- networkx graph
+    id2idx -- dict mapping node ids to indices in the feature tensor
+    placeholders -- tensorflow placeholders object
+    context_pairs -- if not None, then a list of co-occurring node pairs (from random walks)
+    batch_size -- size of the minibatches
+    max_degree -- maximum size of the downsampled adjacency lists
+    n2v_retrain -- signals that the iterator is being used to add new embeddings to an n2v model
+    fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
    """
     def __init__(self, G, id2idx,
-            placeholders, context_pairs=None,batch_size=100, max_degree=25, num_neg_samples=20,
+            placeholders, context_pairs=None, batch_size=100, max_degree=25,
             n2v_retrain=False, fixed_n2v=False,
             **kwargs):
 
@@ -21,7 +30,6 @@ class EdgeMinibatchIterator(object):
         self.placeholders = placeholders
         self.batch_size = batch_size
         self.max_degree = max_degree
-        self.num_neg_samples = num_neg_samples
         self.batch_num = 0
 
         self.nodes = np.random.permutation(G.nodes())
@@ -162,9 +170,17 @@ class NodeMinibatchIterator(object):
 
     """
     This minibatch iterator iterates over nodes for supervised learning.
+
+    G -- networkx graph
+    id2idx -- dict mapping node ids to integer indices into the feature tensor
+    placeholders -- standard tensorflow placeholders object for feeding
+    label_map -- map from node ids to class values (integer or list)
+    num_classes -- number of output classes
+    batch_size -- size of the minibatches
+    max_degree -- maximum size of the downsampled adjacency lists
     """
     def __init__(self, G, id2idx,
-            placeholders, label_map, num_classes, context_pairs=None,
+            placeholders, label_map, num_classes,
             batch_size=100, max_degree=25,
             **kwargs):
 
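To make the new iterator docstrings concrete, a rough usage sketch for the edge iterator follows; G, id2idx, placeholders, and context_pairs are assumed to be prepared as described above, and the batch size is illustrative:

    # Unsupervised minibatching over edges / random-walk co-occurrences.
    minibatch = EdgeMinibatchIterator(G, id2idx, placeholders,
                                      context_pairs=context_pairs,
                                      batch_size=512, max_degree=25)
    while not minibatch.end():
        feed_dict = minibatch.next_minibatch_feed_dict()
        # ...run the training op with feed_dict, as in unsupervised_train.py

NodeMinibatchIterator is constructed analogously, with label_map and num_classes in place of context_pairs.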
diff --git a/graphsage/models.py b/graphsage/models.py
index 52c3181..b40b17f 100644
--- a/graphsage/models.py
+++ b/graphsage/models.py
@@ -16,8 +16,6 @@ FLAGS = flags.FLAGS
 # Boilerplate parts of this code file were originally forked from
 # https://github.com/tkipf/gcn
 # which itself was heavily inspired by the keras package
-# (A full license with proper attributions will be provided in the
-# public repo of this code base)
 
 class Model(object):
     def __init__(self, **kwargs):
@@ -97,6 +95,7 @@ class Model(object):
 
 class MLP(Model):
+    """ A standard multi-layer perceptron """
     def __init__(self, placeholders, dims, categorical=True, **kwargs):
         super(MLP, self).__init__(**kwargs)
 
@@ -177,7 +176,7 @@ class GeneralizedModel(Model):
         self.opt_op = self.optimizer.minimize(self.loss)
 
 # SAGEInfo is a namedtuple that specifies the parameters
-# of the recursive sampled GCN layers
+# of the recursive GraphSAGE layers
 SAGEInfo = namedtuple("SAGEInfo",
     ['layer_name', # name of the layer (to get feature embedding etc.)
      'neigh_sampler', # callable neigh_sampler constructor
@@ -187,8 +186,7 @@ SAGEInfo = namedtuple("SAGEInfo",
      'num_samples',
      'output_dim' # the output (i.e., hidden) dimension
     ])
 
 class SampleAndAggregate(GeneralizedModel):
     """
-    Implementation of a standard 2-step graph convolutional network
-    Uses random sampling on neighborhoods
+    Base implementation of unsupervised GraphSAGE
     """
 
     def __init__(self, placeholders, features, adj, degrees,
@@ -197,9 +195,15 @@ class SampleAndAggregate(GeneralizedModel):
             layer_infos, concat=True, aggregator_type="mean",
             model_size="small",
             **kwargs):
         '''
         Args:
-            - layer_infos: List of SGCInfo namedtuples that describe the parameters of all
-                   the recursive layers. See SGCInfo definition above.
-
+            - placeholders: Standard TensorFlow placeholder object.
+            - features: Numpy array with node features.
+            - adj: Numpy array with adjacency lists (padded with random re-samples)
+            - degrees: Numpy array with node degrees.
+            - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
+                   the recursive layers. See SAGEInfo definition above.
+            - concat: whether to concatenate during recursive iterations
+            - aggregator_type: how to aggregate neighbor information
+            - model_size: one of "small" or "big"
         '''
         super(SampleAndAggregate, self).__init__(**kwargs)
         if aggregator_type == "mean":
@@ -392,11 +396,11 @@ class SampleAndAggregate(GeneralizedModel):
 class Node2VecModel(GeneralizedModel):
     def __init__(self, placeholders, dict_size, degrees, name=None,
                  nodevec_dim=50, lr=0.001, **kwargs):
-        """ Simple version of Node2Vec algorithm.
+        """ Simple version of the Node2Vec/DeepWalk algorithm.
 
         Args:
-            dict_size1: the total number of nodes in set1.
-            dict_size2: the total number of nodes in set2.
+            dict_size: the total number of nodes.
+            degrees: numpy array of node degrees, ordered as in the data's id_map
             nodevec_dim: dimension of the vector representation of a node.
             lr: learning rate of optimizer.
         """
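As a concrete illustration of the SAGEInfo docstrings, a hypothetical two-layer configuration might be wired up as follows; UniformNeighborSampler comes from graphsage.neigh_samplers, adj_info is an assumed adjacency-list tensor, and all sample sizes and dimensions are illustrative:

    # Two recursive layers: sample 25 neighbors at depth 1 and 10 at
    # depth 2, each aggregated into 128-dimensional hidden states.
    sampler = UniformNeighborSampler(adj_info)
    layer_infos = [SAGEInfo("node", sampler, num_samples=25, output_dim=128),
                   SAGEInfo("node", sampler, num_samples=10, output_dim=128)]

    model = SampleAndAggregate(placeholders, features, adj_info, degrees,
                               layer_infos=layer_infos, concat=True,
                               aggregator_type="mean", model_size="small")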
diff --git a/graphsage/neigh_samplers.py b/graphsage/neigh_samplers.py
index f9fc607..9c83d55 100644
--- a/graphsage/neigh_samplers.py
+++ b/graphsage/neigh_samplers.py
@@ -9,8 +9,7 @@ FLAGS = flags.FLAGS
 
 """
-Classes that are used to sample node neighborhoods during
-convolutions.
+Classes that are used to sample node neighborhoods
 """
 
 class UniformNeighborSampler(Layer):
diff --git a/graphsage/prediction.py b/graphsage/prediction.py
index 41e4ab6..9bf0885 100644
--- a/graphsage/prediction.py
+++ b/graphsage/prediction.py
@@ -1,7 +1,7 @@
 from __future__ import division
 from __future__ import print_function
 
-from graphsage.inits import glorot, zeros
+from graphsage.inits import zeros
 from graphsage.layers import Layer
 
 import tensorflow as tf
@@ -13,6 +13,8 @@ class BipartiteEdgePredLayer(Layer):
     def __init__(self, input_dim1, input_dim2, placeholders, dropout=False,
                  act=tf.nn.sigmoid, bias=False, bilinear_weights=False, **kwargs):
         """
+        Basic class that applies a skip-gram-like loss (i.e., dot products
+        between a node and its positive target, and between the node and negative samples).
         Args:
             bilinear_weights: use a bilinear weight for affinity calculation: u^T A v. If set to
                 false, it is assumed that input dimensions are the same and the affinity will be
@@ -95,5 +97,3 @@ class BipartiteEdgePredLayer(Layer):
 
     def weights_norm(self):
         return tf.nn.l2_loss(self.vars['weights'])  # NB: tf.nn.l2_norm does not exist; l2_loss = sum(w**2)/2
-
-
diff --git a/graphsage/supervised_models.py b/graphsage/supervised_models.py
index fa00882..a8658b6 100644
--- a/graphsage/supervised_models.py
+++ b/graphsage/supervised_models.py
@@ -8,11 +8,27 @@ flags = tf.app.flags
 FLAGS = flags.FLAGS
 
 class SupervisedGraphsage(models.SampleAndAggregate):
+    """Implementation of supervised GraphSAGE."""
+
     def __init__(self, num_classes,
             placeholders, features, adj, degrees,
             layer_infos, concat=True, aggregator_type="mean",
             model_size="small", sigmoid_loss=False,
             **kwargs):
+        '''
+        Args:
+            - placeholders: Standard TensorFlow placeholder object.
+            - features: Numpy array with node features.
+            - adj: Numpy array with adjacency lists (padded with random re-samples)
+            - degrees: Numpy array with node degrees.
+            - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
+                   the recursive layers. See the SAGEInfo definition in models.py.
+            - concat: whether to concatenate during recursive iterations
+            - aggregator_type: how to aggregate neighbor information
+            - model_size: one of "small" or "big"
+            - sigmoid_loss: set to True if nodes can belong to multiple classes
+        '''
+
         models.GeneralizedModel.__init__(self, **kwargs)
 
         if aggregator_type == "mean":
diff --git a/graphsage/unsupervised_train.py b/graphsage/unsupervised_train.py
index a9919c6..945aa20 100644
--- a/graphsage/unsupervised_train.py
+++ b/graphsage/unsupervised_train.py
@@ -244,7 +244,6 @@ def train(train_data, test_data=None):
             epoch_val_costs.append(0)
         while not minibatch.end():
             # Construct feed dictionary
-            #feed_dict = construct_minibatch_feed_dict(features, G, y_train, train_mask, placeholders)
             feed_dict = minibatch.next_minibatch_feed_dict()
             feed_dict.update({placeholders['dropout']: FLAGS.dropout})
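Finally, a rough end-to-end sketch of how the supervised pieces fit together under the signatures documented in this patch. Every data-preparation name here (G, id2idx, features, adj_info, degrees, label_map, placeholders) is an assumption for illustration, and the node iterator is assumed to expose the same end()/next_minibatch_feed_dict() interface the edge iterator uses in unsupervised_train.py:

    import tensorflow as tf

    num_classes = 7  # illustrative
    sampler = UniformNeighborSampler(adj_info)
    layer_infos = [SAGEInfo("node", sampler, 25, 128),
                   SAGEInfo("node", sampler, 10, 128)]
    model = SupervisedGraphsage(num_classes, placeholders, features,
                                adj_info, degrees, layer_infos,
                                aggregator_type="mean", model_size="small",
                                sigmoid_loss=False)
    minibatch = NodeMinibatchIterator(G, id2idx, placeholders, label_map,
                                      num_classes, batch_size=128, max_degree=25)

    with tf.Session() as sess:
        sess.run(tf.global_variables_initializer())
        while not minibatch.end():
            feed_dict = minibatch.next_minibatch_feed_dict()
            feed_dict.update({placeholders['dropout']: FLAGS.dropout})
            # opt_op and loss are defined on the base Model classes in models.py
            _, loss = sess.run([model.opt_op, model.loss], feed_dict=feed_dict)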