Update aggregators.py

This commit is contained in:
William L Hamilton 2017-12-27 16:26:43 -06:00 committed by GitHub
parent 8e0d053da2
commit 314f98be08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,7 +6,6 @@ import random
"""
Set of modules for aggregating embeddings of neighbors.
These modules take as input embeddings of neighbors.
"""
class MeanAggregator(nn.Module):
@ -17,8 +16,9 @@ class MeanAggregator(nn.Module):
"""
Initializes the aggregator for a specific graph.
features -- function mapping (node_list, features, offset) to feature values
see torch.nn.EmbeddingBag and forward function below docs for offset meaning.
features -- function mapping LongTensor of node ids to FloatTensor of feature values.
cuda -- whether to use GPU
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
"""
super(MeanAggregator, self).__init__()
@ -28,6 +28,11 @@ class MeanAggregator(nn.Module):
self.gcn = gcn
def forward(self, nodes, to_neighs, num_sample=10):
"""
nodes --- list of nodes in a batch
to_neighs --- list of sets, each set is the set of neighbors for node in batch
num_sample --- number of neighbors to sample. No sampling if None.
"""
# Local pointers to functions (speed hack)
_set = set
if not num_sample is None: