Update aggregators.py
This commit is contained in:
parent
8e0d053da2
commit
314f98be08
@ -6,7 +6,6 @@ import random
|
||||
|
||||
"""
|
||||
Set of modules for aggregating embeddings of neighbors.
|
||||
These modules take as input embeddings of neighbors.
|
||||
"""
|
||||
|
||||
class MeanAggregator(nn.Module):
|
||||
@ -17,8 +16,9 @@ class MeanAggregator(nn.Module):
|
||||
"""
|
||||
Initializes the aggregator for a specific graph.
|
||||
|
||||
features -- function mapping (node_list, features, offset) to feature values
|
||||
see torch.nn.EmbeddingBag and forward function below docs for offset meaning.
|
||||
features -- function mapping LongTensor of node ids to FloatTensor of feature values.
|
||||
cuda -- whether to use GPU
|
||||
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
|
||||
"""
|
||||
|
||||
super(MeanAggregator, self).__init__()
|
||||
@ -28,6 +28,11 @@ class MeanAggregator(nn.Module):
|
||||
self.gcn = gcn
|
||||
|
||||
def forward(self, nodes, to_neighs, num_sample=10):
|
||||
"""
|
||||
nodes --- list of nodes in a batch
|
||||
to_neighs --- list of sets, each set is the set of neighbors for node in batch
|
||||
num_sample --- number of neighbors to sample. No sampling if None.
|
||||
"""
|
||||
# Local pointers to functions (speed hack)
|
||||
_set = set
|
||||
if not num_sample is None:
|
||||
|
Loading…
Reference in New Issue
Block a user