Cleaning up comments etc.

parent 095f63304a
commit 41f491b747

@@ -4,8 +4,7 @@ import numpy as np
 # DISCLAIMER:
 # Parts of this code file are derived from
 # https://github.com/tkipf/gcn
-# (A full license with proper attributions will be provided in the
-# public repo of this code base)
+# which is under an identical MIT license as GraphSAGE

 def uniform(shape, scale=0.05, name=None):
     """Uniform init."""

@@ -4,10 +4,8 @@ import tensorflow as tf
 # Parts of this code file were originally forked from
 # https://github.com/tkipf/gcn
 # which itself was very inspired by the keras package
-# (A full license with de-anonymized attributions will be provided in the
-# public repo of this code base)

 def masked_logit_cross_entropy(preds, labels, mask):
+    """Logit cross-entropy loss with masking."""
     loss = tf.nn.sigmoid_cross_entropy_with_logits(logits=preds, labels=labels)
     loss = tf.reduce_sum(loss, axis=1)
     mask = tf.cast(mask, dtype=tf.float32)
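
Reviewer note: the two loss lines above compute a per-class sigmoid cross-entropy and then sum over classes, giving one scalar per example. A minimal NumPy sketch of the same arithmetic (illustrative only, not part of the commit):

import numpy as np

def sigmoid_xent(logits, labels):
    # numerically stable -[y*log(sigmoid(x)) + (1-y)*log(1-sigmoid(x))],
    # the formula tf.nn.sigmoid_cross_entropy_with_logits documents
    return np.maximum(logits, 0) - logits * labels + np.log1p(np.exp(-np.abs(logits)))

preds = np.array([[2.0, -1.0], [0.5, 0.5]])
labels = np.array([[1.0, 0.0], [0.0, 1.0]])
per_example = sigmoid_xent(preds, labels).sum(axis=1)  # mirrors reduce_sum(axis=1)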

@@ -16,8 +14,8 @@ def masked_logit_cross_entropy(preds, labels, mask):
     return tf.reduce_mean(loss)

 def masked_softmax_cross_entropy(preds, labels, mask):
+    """Softmax cross-entropy loss with masking."""
     loss = tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels)
-    # loss = tf.reduce_sum(loss, axis=1)
     mask = tf.cast(mask, dtype=tf.float32)
     mask /= tf.maximum(tf.reduce_sum(mask), tf.constant([1.]))
     loss *= mask
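
Reviewer note: the mask normalization divides each mask entry by the total mask count (clamped to at least 1), so summing loss * mask yields the mean loss over masked-in examples. A NumPy sketch of that identity (hypothetical helper, not in the repo):

import numpy as np

def masked_mean(per_example_loss, mask):
    mask = mask.astype(np.float32)
    mask /= np.maximum(mask.sum(), 1.0)  # masked-in entries now sum to 1
    return (per_example_loss * mask).sum()

loss = np.array([1.0, 3.0, 5.0])
mask = np.array([1, 0, 1])
assert masked_mean(loss, mask) == 3.0  # mean of 1.0 and 5.0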

@@ -25,7 +23,7 @@ def masked_softmax_cross_entropy(preds, labels, mask):


 def masked_l2(preds, actuals, mask):
-    """Softmax cross-entropy loss with masking."""
+    """L2 loss with masking."""
     loss = tf.nn.l2(preds, actuals)
     mask = tf.cast(mask, dtype=tf.float32)
     mask /= tf.reduce_mean(mask)
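
Reviewer note: one thing the commit does not touch: tf.nn.l2(preds, actuals) is not a public TensorFlow op (tf.nn.l2_loss takes a single tensor), so this function would fail if called. A hedged sketch of what a working masked L2 might look like, in the same TF 1.x style (not the repo's actual fix):

import tensorflow as tf

def masked_l2_sketch(preds, actuals, mask):
    # per-example squared error, averaged over masked-in examples
    loss = tf.reduce_sum(tf.square(preds - actuals), axis=1)
    mask = tf.cast(mask, dtype=tf.float32)
    mask /= tf.reduce_mean(mask)
    loss *= mask
    return tf.reduce_mean(loss)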

@@ -9,9 +9,18 @@ class EdgeMinibatchIterator(object):

     """ This minibatch iterator iterates over batches of sampled edges or
     random pairs of co-occuring edges.

+    G -- networkx graph
+    id2idx -- dict mapping node ids to index in feature tensor
+    placeholders -- tensorflow placeholders object
+    context_pairs -- if not none, then a list of co-occuring node pairs (from random walks)
+    batch_size -- size of the minibatches
+    max_degree -- maximum size of the downsampled adjacency lists
+    n2v_retrain -- signals that the iterator is being used to add new embeddings to a n2v model
+    fixed_n2v -- signals that the iterator is being used to retrain n2v with only existing nodes as context
     """
     def __init__(self, G, id2idx,
-            placeholders, context_pairs=None,batch_size=100, max_degree=25, num_neg_samples=20,
+            placeholders, context_pairs=None, batch_size=100, max_degree=25,
             n2v_retrain=False, fixed_n2v=False,
             **kwargs):

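
Reviewer note: max_degree caps the "downsampled adjacency lists" mentioned in the new docstring: every node's neighbor list is forced to exactly max_degree entries. A sketch of that construction (illustrative, using the docstring's G and id2idx; not the repo's exact code):

import numpy as np

def build_padded_adj(G, id2idx, max_degree):
    # one fixed-width row of neighbor indices per node: subsample without
    # replacement when a node has too many neighbors, re-sample with
    # replacement when it has too few
    adj = np.zeros((len(id2idx), max_degree), dtype=np.int64)
    for node in G.nodes():
        neighbors = np.array([id2idx[n] for n in G.neighbors(node)])
        if len(neighbors) == 0:
            continue  # isolated node keeps the all-zeros padding row
        replace = len(neighbors) < max_degree
        adj[id2idx[node], :] = np.random.choice(neighbors, max_degree, replace=replace)
    return adj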

@@ -21,7 +30,6 @@ class EdgeMinibatchIterator(object):
         self.placeholders = placeholders
         self.batch_size = batch_size
         self.max_degree = max_degree
-        self.num_neg_samples = num_neg_samples
         self.batch_num = 0

         self.nodes = np.random.permutation(G.nodes())

@@ -162,9 +170,17 @@ class NodeMinibatchIterator(object):

    """
    This minibatch iterator iterates over nodes for supervised learning.

+    G -- networkx graph
+    id2idx -- dict mapping node ids to integer values indexing feature tensor
+    placeholders -- standard tensorflow placeholders object for feeding
+    label_map -- map from node ids to class values (integer or list)
+    num_classes -- number of output classes
+    batch_size -- size of the minibatches
+    max_degree -- maximum size of the downsampled adjacency lists
    """
    def __init__(self, G, id2idx,
-            placeholders, label_map, num_classes, context_pairs=None,
+            placeholders, label_map, num_classes,
            batch_size=100, max_degree=25,
            **kwargs):

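
Reviewer note: per the new docstring, label_map values may be either an integer class or a list (multi-label). A sketch of the conversion to a fixed-width label vector that a supervised iterator typically performs (hypothetical helper, not the repo's code):

import numpy as np

def make_label_vec(node, label_map, num_classes):
    label = label_map[node]
    if isinstance(label, (list, tuple, np.ndarray)):
        return np.asarray(label, dtype=np.float32)  # already a multi-label vector
    vec = np.zeros(num_classes, dtype=np.float32)   # integer class -> one-hot
    vec[label] = 1.0
    return vec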

@@ -16,8 +16,6 @@ FLAGS = flags.FLAGS
 # Boilerplate parts of this code file were originally forked from
 # https://github.com/tkipf/gcn
 # which itself was very inspired by the keras package
-# (A full license with proper attributions will be provided in the
-# public repo of this code base)

 class Model(object):
     def __init__(self, **kwargs):

@@ -97,6 +95,7 @@ class Model(object):


 class MLP(Model):
+    """ A standard multi-layer perceptron """
     def __init__(self, placeholders, dims, categorical=True, **kwargs):
         super(MLP, self).__init__(**kwargs)


@@ -177,7 +176,7 @@ class GeneralizedModel(Model):
         self.opt_op = self.optimizer.minimize(self.loss)

 # SAGEInfo is a namedtuple that specifies the parameters
-# of the recursive sampled GCN layers
+# of the recursive GraphSAGE layers
 SAGEInfo = namedtuple("SAGEInfo",
     ['layer_name', # name of the layer (to get feature embedding etc.)
      'neigh_sampler', # callable neigh_sampler constructor
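
Reviewer note: the hunk cuts off before the remaining namedtuple fields; in the public GraphSAGE repo they are num_samples and output_dim. Assuming that layout, and assuming the repo's UniformNeighborSampler and an adj_info tensor are in scope, a typical two-layer configuration looks like:

# illustrative sketch, assuming SAGEInfo(layer_name, neigh_sampler, num_samples, output_dim)
sampler = UniformNeighborSampler(adj_info)
layer_infos = [SAGEInfo("node", sampler, 25, 128),  # hop 1: sample 25 neighbors
               SAGEInfo("node", sampler, 10, 128)]  # hop 2: sample 10 neighbors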

@@ -187,8 +186,7 @@ SAGEInfo = namedtuple("SAGEInfo",

 class SampleAndAggregate(GeneralizedModel):
     """
-    Implementation of a standard 2-step graph convolutional network
-    Uses random sampling on neighborhoods
+    Base implementation of unsupervised GraphSAGE
     """

     def __init__(self, placeholders, features, adj, degrees,

@@ -197,9 +195,15 @@ class SampleAndAggregate(GeneralizedModel):
             **kwargs):
         '''
         Args:
-            - layer_infos: List of SGCInfo namedtuples that describe the parameters of all
-                   the recursive layers. See SGCInfo definition above.
+            - placeholders: Stanford TensorFlow placeholder object.
+            - features: Numpy array with node features.
+            - adj: Numpy array with adjacency lists (padded with random re-samples)
+            - degrees: Numpy array with node degrees.
+            - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
+                   the recursive layers. See SAGEInfo definition above.
+            - concat: whether to concatenate during recursive iterations
+            - aggregator_type: how to aggregate neighbor information
+            - model_size: one of "small" and "big"
         '''
         super(SampleAndAggregate, self).__init__(**kwargs)
         if aggregator_type == "mean":

@@ -392,11 +396,11 @@ class SampleAndAggregate(GeneralizedModel):
 class Node2VecModel(GeneralizedModel):
     def __init__(self, placeholders, dict_size, degrees, name=None,
                  nodevec_dim=50, lr=0.001, **kwargs):
-        """ Simple version of Node2Vec algorithm.
+        """ Simple version of Node2Vec/DeepWalk algorithm.

         Args:
-            dict_size1: the total number of nodes in set1.
-            dict_size2: the total number of nodes in set2.
+            dict_size: the total number of nodes.
+            degrees: numpy array of node degrees, ordered as in the data's id_map
             nodevec_dim: dimension of the vector representation of node.
             lr: learning rate of optimizer.
         """
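
Reviewer note: passing degrees ordered as in the id_map suggests degree-weighted negative sampling, as in word2vec's unigram sampler. One common TF 1.x construction, sketched under that assumption (labels and num_neg_samples are illustrative names, not necessarily this model's code):

import tensorflow as tf

# labels: [batch_size, 1] int64 tensor of context node ids (assumed prepared by caller)
neg_samples, _, _ = tf.nn.fixed_unigram_candidate_sampler(
    true_classes=labels,
    num_true=1,
    num_sampled=num_neg_samples,      # assumed hyperparameter
    unique=False,
    range_max=len(degrees),
    distortion=0.75,                  # word2vec-style flattening of the distribution
    unigrams=degrees.tolist())        # sampling weight proportional to node degree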

@@ -9,8 +9,7 @@ FLAGS = flags.FLAGS


 """
-Classes that are used to sample node neighborhoods during
-convolutions.
+Classes that are used to sample node neighborhoods
 """

 class UniformNeighborSampler(Layer):
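
Reviewer note: uniform neighborhood sampling over a padded adjacency table can be done entirely with TF ops: look up each batch node's row, shuffle the columns, and slice off the first num_samples entries. A TF 1.x sketch of that pattern (illustrative, not the file's verbatim body):

import tensorflow as tf

def uniform_neighbor_sample(adj_info, ids, num_samples):
    # adj_info: [num_nodes, max_degree] int matrix of padded neighbor ids
    adj_lists = tf.nn.embedding_lookup(adj_info, ids)                     # [batch, max_degree]
    adj_lists = tf.transpose(tf.random_shuffle(tf.transpose(adj_lists)))  # shuffle columns
    return tf.slice(adj_lists, [0, 0], [-1, num_samples])                 # keep num_samples per node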

@@ -1,7 +1,7 @@
 from __future__ import division
 from __future__ import print_function

-from graphsage.inits import glorot, zeros
+from graphsage.inits import zeros
 from graphsage.layers import Layer
 import tensorflow as tf


@@ -13,6 +13,8 @@ class BipartiteEdgePredLayer(Layer):
     def __init__(self, input_dim1, input_dim2, placeholders, dropout=False, act=tf.nn.sigmoid,
             bias=False, bilinear_weights=False, **kwargs):
         """
+        Basic class that applies skip-gram-like loss
+        (i.e., dot product of node+target and node and negative samples)
         Args:
             bilinear_weights: use a bilinear weight for affinity calculation: u^T A v. If set to
                 false, it is assumed that input dimensions are the same and the affinity will be
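
Reviewer note: the new docstring's "skip-gram-like loss" reduces to a dot-product affinity for true (node, context) pairs scored against affinities with negative samples. A condensed TF 1.x sketch of that objective (illustrative names, not the layer's actual attributes):

import tensorflow as tf

def skip_gram_loss(u, v, neg):
    # u, v: [batch, d] embeddings of co-occurring node pairs
    # neg:  [num_neg, d] embeddings of negative samples
    aff = tf.reduce_sum(u * v, axis=1)             # true-pair affinity, [batch]
    neg_aff = tf.matmul(u, neg, transpose_b=True)  # negative affinities, [batch, num_neg]
    true_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(aff), logits=aff)
    neg_xent = tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(neg_aff), logits=neg_aff)
    return tf.reduce_sum(true_xent) + tf.reduce_sum(neg_xent)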

@@ -95,5 +97,3 @@ class BipartiteEdgePredLayer(Layer):

     def weights_norm(self):
         return tf.nn.l2_norm(self.vars['weights'])
-
-
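
Reviewer note: as with tf.nn.l2 in the metrics file, tf.nn.l2_norm is not a public TensorFlow op; this method would raise an AttributeError if exercised. A working version would presumably be something like:

def weights_norm(self):
    # sketch: tf.nn.l2_loss computes 0.5 * sum(w ** 2); wrap in tf.sqrt(2 * ...)
    # if a true Euclidean norm is intended
    return tf.nn.l2_loss(self.vars['weights'])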

@@ -8,11 +8,27 @@ flags = tf.app.flags
 FLAGS = flags.FLAGS

 class SupervisedGraphsage(models.SampleAndAggregate):
+    """Implementation of supervised GraphSAGE."""

     def __init__(self, num_classes,
             placeholders, features, adj, degrees,
             layer_infos, concat=True, aggregator_type="mean",
             model_size="small", sigmoid_loss=False,
             **kwargs):
+        '''
+        Args:
+            - placeholders: Stanford TensorFlow placeholder object.
+            - features: Numpy array with node features.
+            - adj: Numpy array with adjacency lists (padded with random re-samples)
+            - degrees: Numpy array with node degrees.
+            - layer_infos: List of SAGEInfo namedtuples that describe the parameters of all
+                   the recursive layers. See SAGEInfo definition above.
+            - concat: whether to concatenate during recursive iterations
+            - aggregator_type: how to aggregate neighbor information
+            - model_size: one of "small" and "big"
+            - sigmoid_loss: Set to true if nodes can belong to multiple classes
+        '''

         models.GeneralizedModel.__init__(self, **kwargs)

         if aggregator_type == "mean":
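
Reviewer note: sigmoid_loss encodes the usual multi-label vs. single-label distinction: independent per-class sigmoids when a node can carry several labels, a softmax when classes are mutually exclusive. A condensed sketch of what the flag selects (hypothetical helper based on the docstring, not the model's verbatim loss):

import tensorflow as tf

def classification_loss(preds, labels, sigmoid_loss):
    if sigmoid_loss:  # multi-label: independent per-class decisions
        return tf.reduce_mean(
            tf.nn.sigmoid_cross_entropy_with_logits(logits=preds, labels=labels))
    # single-label: classes are mutually exclusive
    return tf.reduce_mean(
        tf.nn.softmax_cross_entropy_with_logits(logits=preds, labels=labels))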

@@ -244,7 +244,6 @@ def train(train_data, test_data=None):
         epoch_val_costs.append(0)
         while not minibatch.end():
             # Construct feed dictionary
-            #feed_dict = construct_minibatch_feed_dict(features, G, y_train, train_mask, placeholders)
             feed_dict = minibatch.next_minibatch_feed_dict()
             feed_dict.update({placeholders['dropout']: FLAGS.dropout})

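
Reviewer note: in context, the loop body typically ends by running the optimizer with that feed dict; roughly (a sketch assuming the surrounding script defines sess, minibatch, placeholders, and a model exposing the opt_op/loss set up in models.py above):

while not minibatch.end():
    feed_dict = minibatch.next_minibatch_feed_dict()
    feed_dict.update({placeholders['dropout']: FLAGS.dropout})
    # one gradient step; opt_op comes from self.optimizer.minimize(self.loss)
    _, train_cost = sess.run([model.opt_op, model.loss], feed_dict=feed_dict)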