Update aggregators.py

This commit is contained in:
William L Hamilton 2017-12-27 16:26:43 -06:00 committed by GitHub
parent 8e0d053da2
commit 314f98be08
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -6,7 +6,6 @@ import random
""" """
Set of modules for aggregating embeddings of neighbors. Set of modules for aggregating embeddings of neighbors.
These modules take as input embeddings of neighbors.
""" """
class MeanAggregator(nn.Module): class MeanAggregator(nn.Module):
@ -17,8 +16,9 @@ class MeanAggregator(nn.Module):
""" """
Initializes the aggregator for a specific graph. Initializes the aggregator for a specific graph.
features -- function mapping (node_list, features, offset) to feature values features -- function mapping LongTensor of node ids to FloatTensor of feature values.
see torch.nn.EmbeddingBag and forward function below docs for offset meaning. cuda -- whether to use GPU
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
""" """
super(MeanAggregator, self).__init__() super(MeanAggregator, self).__init__()
@ -28,6 +28,11 @@ class MeanAggregator(nn.Module):
self.gcn = gcn self.gcn = gcn
def forward(self, nodes, to_neighs, num_sample=10): def forward(self, nodes, to_neighs, num_sample=10):
"""
nodes --- list of nodes in a batch
to_neighs --- list of sets, each set is the set of neighbors for node in batch
num_sample --- number of neighbors to sample. No sampling if None.
"""
# Local pointers to functions (speed hack) # Local pointers to functions (speed hack)
_set = set _set = set
if not num_sample is None: if not num_sample is None: