Update aggregators.py
This commit is contained in:
parent
8e0d053da2
commit
314f98be08
@ -6,7 +6,6 @@ import random
|
|||||||
|
|
||||||
"""
|
"""
|
||||||
Set of modules for aggregating embeddings of neighbors.
|
Set of modules for aggregating embeddings of neighbors.
|
||||||
These modules take as input embeddings of neighbors.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
class MeanAggregator(nn.Module):
|
class MeanAggregator(nn.Module):
|
||||||
@ -17,8 +16,9 @@ class MeanAggregator(nn.Module):
|
|||||||
"""
|
"""
|
||||||
Initializes the aggregator for a specific graph.
|
Initializes the aggregator for a specific graph.
|
||||||
|
|
||||||
features -- function mapping (node_list, features, offset) to feature values
|
features -- function mapping LongTensor of node ids to FloatTensor of feature values.
|
||||||
see torch.nn.EmbeddingBag and forward function below docs for offset meaning.
|
cuda -- whether to use GPU
|
||||||
|
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
|
||||||
"""
|
"""
|
||||||
|
|
||||||
super(MeanAggregator, self).__init__()
|
super(MeanAggregator, self).__init__()
|
||||||
@ -28,6 +28,11 @@ class MeanAggregator(nn.Module):
|
|||||||
self.gcn = gcn
|
self.gcn = gcn
|
||||||
|
|
||||||
def forward(self, nodes, to_neighs, num_sample=10):
|
def forward(self, nodes, to_neighs, num_sample=10):
|
||||||
|
"""
|
||||||
|
nodes --- list of nodes in a batch
|
||||||
|
to_neighs --- list of sets, each set is the set of neighbors for node in batch
|
||||||
|
num_sample --- number of neighbors to sample. No sampling if None.
|
||||||
|
"""
|
||||||
# Local pointers to functions (speed hack)
|
# Local pointers to functions (speed hack)
|
||||||
_set = set
|
_set = set
|
||||||
if not num_sample is None:
|
if not num_sample is None:
|
||||||
|
Loading…
Reference in New Issue
Block a user