diff --git a/graphsage/aggregators.py b/graphsage/aggregators.py index 23b7280..a23c891 100644 --- a/graphsage/aggregators.py +++ b/graphsage/aggregators.py @@ -6,7 +6,6 @@ import random """ Set of modules for aggregating embeddings of neighbors. -These modules take as input embeddings of neighbors. """ class MeanAggregator(nn.Module): @@ -17,8 +16,9 @@ class MeanAggregator(nn.Module): """ Initializes the aggregator for a specific graph. - features -- function mapping (node_list, features, offset) to feature values - see torch.nn.EmbeddingBag and forward function below docs for offset meaning. + features -- function mapping LongTensor of node ids to FloatTensor of feature values. + cuda -- whether to use GPU + gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style """ super(MeanAggregator, self).__init__() @@ -28,6 +28,11 @@ class MeanAggregator(nn.Module): self.gcn = gcn def forward(self, nodes, to_neighs, num_sample=10): + """ + nodes --- list of nodes in a batch + to_neighs --- list of sets, each set is the set of neighbors for node in batch + num_sample --- number of neighbors to sample. No sampling if None. + """ # Local pointers to functions (speed hack) _set = set if not num_sample is None: