Compare commits

...

10 Commits

Author SHA1 Message Date
cf85cf7842 first 2024-01-04 16:15:33 +08:00
William Hamilton
d3105e5223 Small fix for type conversion. 2018-06-24 09:30:13 -04:00
William L Hamilton
146406c2c0
Update README.md 2018-03-10 10:33:12 -08:00
William L Hamilton
fcf5eb5d63
Update README.md 2017-12-27 16:31:49 -06:00
William L Hamilton
4583602eec
Update model.py 2017-12-27 16:27:30 -06:00
William L Hamilton
314f98be08
Update aggregators.py 2017-12-27 16:26:43 -06:00
William L Hamilton
8e0d053da2
Update encoders.py 2017-12-27 16:23:07 -06:00
William L Hamilton
b0bbc30be5
Delete README 2017-12-19 16:32:24 -06:00
William L Hamilton
664a8d633e
Create README.md 2017-12-19 16:32:14 -06:00
williamleif
1966d21b68 Added README. 2017-12-19 14:23:56 -08:00
4 changed files with 34 additions and 5 deletions

19
README.md Normal file
View File

@ -0,0 +1,19 @@
# Reference PyTorch GraphSAGE Implementation
### Author: William L. Hamilton
Basic reference PyTorch implementation of [GraphSAGE](https://github.com/williamleif/GraphSAGE).
This reference implementation is not as fast as the TensorFlow version for large graphs, but the code is easier to read and it performs better (in terms of speed) on small-graph benchmarks.
The code is also intended to be simpler, more extensible, and easier to work with than the TensorFlow version.
Currently, only supervised versions of GraphSAGE-mean and GraphSAGE-GCN are implemented.
#### Requirements
pytorch >0.2 is required.
#### Running examples
Execute `python -m graphsage.model` to run the Cora example.
It assumes that CUDA is not being used, but modifying the run functions in `model.py` in the obvious way can change this.
There is also a pubmed example (called via the `run_pubmed` function in model.py).

View File

@ -2,11 +2,11 @@ import torch
import torch.nn as nn import torch.nn as nn
from torch.autograd import Variable from torch.autograd import Variable
import random import random
""" """
Set of modules for aggregating embeddings of neighbors. Set of modules for aggregating embeddings of neighbors.
These modules take as input embeddings of neighbors.
""" """
class MeanAggregator(nn.Module): class MeanAggregator(nn.Module):
@ -17,8 +17,9 @@ class MeanAggregator(nn.Module):
""" """
Initializes the aggregator for a specific graph. Initializes the aggregator for a specific graph.
features -- function mapping (node_list, features, offset) to feature values features -- function mapping LongTensor of node ids to FloatTensor of feature values.
see torch.nn.EmbeddingBag and forward function below docs for offset meaning. cuda -- whether to use GPU
gcn --- whether to perform concatenation GraphSAGE-style, or add self-loops GCN-style
""" """
super(MeanAggregator, self).__init__() super(MeanAggregator, self).__init__()
@ -28,6 +29,11 @@ class MeanAggregator(nn.Module):
self.gcn = gcn self.gcn = gcn
def forward(self, nodes, to_neighs, num_sample=10): def forward(self, nodes, to_neighs, num_sample=10):
"""
nodes --- list of nodes in a batch
to_neighs --- list of sets, each set is the set of neighbors for node in batch
num_sample --- number of neighbors to sample. No sampling if None.
"""
# Local pointers to functions (speed hack) # Local pointers to functions (speed hack)
_set = set _set = set
if not num_sample is None: if not num_sample is None:

View File

@ -35,9 +35,8 @@ class Encoder(nn.Module):
Generates embeddings for a batch of nodes. Generates embeddings for a batch of nodes.
nodes -- list of nodes nodes -- list of nodes
mode -- string desiginating the mode of the nodes
""" """
neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[node] for node in nodes], neigh_feats = self.aggregator.forward(nodes, [self.adj_lists[int(node)] for node in nodes],
self.num_sample) self.num_sample)
if not self.gcn: if not self.gcn:
if self.cuda: if self.cuda:

View File

@ -12,6 +12,11 @@ from collections import defaultdict
from graphsage.encoders import Encoder from graphsage.encoders import Encoder
from graphsage.aggregators import MeanAggregator from graphsage.aggregators import MeanAggregator
"""
Simple supervised GraphSAGE model as well as examples running the model
on the Cora and Pubmed datasets.
"""
class SupervisedGraphSage(nn.Module): class SupervisedGraphSage(nn.Module):
def __init__(self, num_classes, enc): def __init__(self, num_classes, enc):